gemini_client_api/gemini/ask.rs

use super::error::GeminiResponseError;
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

#[derive(Clone, Default, Debug)]
pub struct Gemini {
    client: Client,
    api_key: String,
    model: String,
    sys_prompt: Option<SystemInstruction>,
    generation_config: Option<Value>,
    tools: Option<Vec<Tool>>,
}

impl Gemini {
    /// # Arguments
    /// * `api_key`: get one from [Google AI Studio](https://aistudio.google.com/app/apikey)
    /// * `model`: should be one of the model codes listed in bold [here](https://ai.google.dev/gemini-api/docs/models#model-variations)
    /// * `sys_prompt`: should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
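    /// # Example
    /// A minimal sketch; the key and model name below are placeholders:
    ///```ignore
    ///let gemini = Gemini::new("YOUR_API_KEY", "gemini-2.0-flash", None);
    ///```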
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
    ) -> Self {
        Self {
            client: Client::builder()
                .timeout(Duration::from_secs(60))
                .build()
                .unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            tools: None,
        }
    }
    /// Like [`Gemini::new`], but with a caller-supplied request timeout.
    /// `sys_prompt` should follow the [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
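    /// # Example
    /// A minimal sketch with a placeholder key and a 2-minute timeout:
    ///```ignore
    ///use std::time::Duration;
    ///let gemini = Gemini::new_with_timeout(
    ///    "YOUR_API_KEY",
    ///    "gemini-2.0-flash",
    ///    None,
    ///    Duration::from_secs(120),
    ///);
    ///```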
    pub fn new_with_timeout(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
        api_timeout: Duration,
    ) -> Self {
        Self {
            client: Client::builder().timeout(api_timeout).build().unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            tools: None,
        }
    }
    /// Returns a mutable reference to the generation config, creating an empty one if unset.
    /// Its schema should follow the [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
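    /// # Example
    /// A minimal sketch; `temperature` is one of the fields that schema allows:
    ///```ignore
    ///let mut gemini = Gemini::new("YOUR_API_KEY", "gemini-2.0-flash", None);
    ///gemini.set_generation_config()["temperature"] = json!(0.7);
    ///```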
    pub fn set_generation_config(&mut self) -> &mut Value {
        self.generation_config.get_or_insert_with(|| json!({}))
    }
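    /// Writes `config` into the generation config as `thinking_config`.
    /// # Example
    /// A sketch assuming the `ThinkingConfig::new(enabled, budget)` constructor used elsewhere in this file:
    ///```ignore
    ///let gemini = gemini.set_thinking_config(ThinkingConfig::new(true, 1024));
    ///```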
    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
        if let Value::Object(map) = self.set_generation_config() {
            if let Ok(thinking_value) = serde_json::to_value(config) {
                map.insert("thinking_config".to_string(), thinking_value);
            }
        }
        self
    }
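    /// Convenience wrapper: enables or disables thoughts with a budget of `-1`,
    /// which in the Gemini API requests a dynamic thinking budget.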
    pub fn set_thoughts_enabled(mut self, enabled: bool) -> Self {
        let thinking = serde_json::to_value(ThinkingConfig::new(enabled, -1)).unwrap();
        if let Value::Object(map) = self.set_generation_config() {
            map.insert("thinking_config".to_string(), thinking);
        }
        self
    }
    pub fn set_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
        self.sys_prompt = sys_prompt;
        self
    }
    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = api_key.into();
        self
    }
    /// `schema` should follow the Gemini [Schema](https://ai.google.dev/api/caching#Schema) format.
    /// To verify your schema, visit [Google AI Studio](https://aistudio.google.com/prompts/new_chat):
    /// - Under Tools, toggle on "Structured output"
    /// - Click "Edit"
    /// - There you can build the schema with the `Visual Editor` or the `Code Editor`, which has error detection
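    /// # Example
    /// A sketch with a minimal hypothetical schema:
    ///```ignore
    ///let gemini = gemini.set_json_mode(json!({
    ///    "type": "object",
    ///    "properties": {"answer": {"type": "string"}},
    ///    "required": ["answer"]
    ///}));
    ///```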
    pub fn set_json_mode(mut self, schema: Value) -> Self {
        let config = self.set_generation_config();
        config["response_mime_type"] = "application/json".into();
        config["response_schema"] = schema;
        self
    }
    pub fn unset_json_mode(mut self) -> Self {
        if let Some(Value::Object(map)) = self.generation_config.as_mut() {
            map.remove("response_schema");
            map.remove("response_mime_type");
        }
        self
    }
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }
    pub fn unset_tools(mut self) -> Self {
        self.tools = None;
        self
    }
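    /// Sends the whole `session` history to the API and stores the reply back into `session`.
    /// # Example
    /// A minimal sketch, assuming a `Session` created elsewhere:
    ///```ignore
    ///let reply = gemini.ask(&mut session).await.unwrap();
    ///println!("{}", reply.get_text(""));
    ///```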
    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
        let req_url = format!(
            "{BASE_URL}/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let response = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
            ))
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let reply = GeminiResponse::new(response)
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        session.update(&reply);
        Ok(reply)
    }
    /// # Warning
    /// You must read the response stream to completion for the reply to be stored back in `session`.
    /// `data_extractor` extracts the data you receive from each chunk of the response stream.
    /// # Example
    ///```ignore
    ///use futures::StreamExt;
    ///let mut response_stream = gemini
    ///    .ask_as_stream_with_extractor(session, |session, _gemini_response| {
    ///        session.get_last_message_text("").unwrap()
    ///    }) // Use _gemini_response.get_text("") to get just the text received in each chunk
    ///    .await
    ///    .unwrap();
    ///
    ///while let Some(response) = response_stream.next().await {
    ///    if let Ok(response) = response {
    ///        println!("{}", response);
    ///    }
    ///}
    ///```
    pub async fn ask_as_stream_with_extractor<F, StreamType>(
        &self,
        session: Session,
        data_extractor: F,
    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
    where
        F: FnMut(&Session, GeminiResponse) -> StreamType,
    {
        let req_url = format!(
            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
            self.model, self.api_key
        );

        let request = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
            ))
            .send()
            .await;
        let response = match request {
            Ok(response) => response,
            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
        };

        if !response.status().is_success() {
            let text = match response.text().await {
                Ok(text) => text,
                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
            };
            return Err((session, GeminiResponseError::StatusNotOk(text)));
        }

        Ok(ResponseStream::new(
            Box::new(response.bytes_stream()),
            session,
            data_extractor,
        ))
    }
    /// # Warning
    /// You must read the response stream to completion for the reply to be stored back in `session`.
    /// # Example
    ///```ignore
    ///use futures::StreamExt;
    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
    ///
    ///while let Some(response) = response_stream.next().await {
    ///    if let Ok(response) = response {
    ///        println!("{}", response.get_text(""));
    ///    }
    ///}
    ///```
    pub async fn ask_as_stream(
        &self,
        session: Session,
    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
        self.ask_as_stream_with_extractor(
            session,
            (|_, gemini_response| gemini_response)
                as fn(&Session, GeminiResponse) -> GeminiResponse,
        )
        .await
    }
}