
gemini_client_api/gemini/ask.rs

use super::error::GeminiResponseError;
use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;
#[cfg(feature = "reqwest")]
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";

/// The main client for interacting with the Gemini API.
///
/// Use `Gemini::new` or `Gemini::new_with_timeout` to create an instance.
/// You can configure various aspects of the request like model, system instructions,
/// generation config, safety settings, and tools using the provided builder-like methods.
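///
/// # Example
/// A minimal sketch of builder-style configuration (the API key and model names are placeholders):
/// ```ignore
/// let gemini = Gemini::new("YOUR_API_KEY", "gemini-2.5-flash", None)
///     .set_model("gemini-2.5-pro"); // any builder method can be chained like this
/// ```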
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    #[cfg(feature = "reqwest")]
    client: Client,
    api_key: String,
    model: String,
    sys_prompt: Option<SystemInstruction>,
    generation_config: Option<Value>,
    safety_settings: Option<Vec<SafetySetting>>,
    tools: Option<Vec<Tool>>,
    tool_config: Option<ToolConfig>,
    cached_content: Option<String>,
}

impl Gemini {
    /// Creates a new `Gemini` client.
    ///
    /// # Arguments
    /// * `api_key` - Your Gemini API key. Get one from [Google AI studio](https://aistudio.google.com/app/apikey).
    /// * `model` - The model variation to use (e.g., "gemini-2.5-flash"). See [model variations](https://ai.google.dev/gemini-api/docs/models#model-variations).
    /// * `sys_prompt` - Optional system instructions. See [system instructions](https://ai.google.dev/gemini-api/docs/text-generation#image-input).
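    ///
    /// # Example
    /// A minimal sketch, reading the key from an environment variable:
    /// ```ignore
    /// let gemini = Gemini::new(std::env::var("GEMINI_API_KEY").unwrap(), "gemini-2.5-flash", None);
    /// ```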
    #[cfg(feature = "reqwest")]
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
    ) -> Self {
        Self {
            client: Client::builder()
                .timeout(Duration::from_secs(60))
                .build()
                .unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
            cached_content: None,
        }
    }
    /// Creates a new `Gemini` client with a custom API timeout.
    ///
    /// # Arguments
    /// * `api_key` - Your Gemini API key.
    /// * `model` - The model variation to use.
    /// * `sys_prompt` - Optional system instructions.
    /// * `api_timeout` - Custom duration for request timeouts.
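    ///
    /// # Example
    /// A minimal sketch with a two-minute timeout:
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let gemini = Gemini::new_with_timeout("YOUR_API_KEY", "gemini-2.5-flash", None, Duration::from_secs(120));
    /// ```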
    #[cfg(feature = "reqwest")]
    pub fn new_with_timeout(
        api_key: impl Into<String>,
        model: impl Into<String>,
        sys_prompt: Option<SystemInstruction>,
        api_timeout: Duration,
    ) -> Self {
        Self {
            client: Client::builder().timeout(api_timeout).build().unwrap(),
            api_key: api_key.into(),
            model: model.into(),
            sys_prompt,
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_config: None,
            cached_content: None,
        }
    }
    /// Returns a mutable reference to the generation configuration.
    /// If not already set, initializes it to an empty object.
    ///
    /// See [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig) for schema details.
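    ///
    /// # Example
    /// A minimal sketch; field names follow the linked schema:
    /// ```ignore
    /// let mut gemini = Gemini::new("YOUR_API_KEY", "gemini-2.5-flash", None);
    /// gemini.set_generation_config()["temperature"] = 0.7.into();
    /// ```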
    pub fn set_generation_config(&mut self) -> &mut Value {
        self.generation_config.get_or_insert_with(|| json!({}))
    }
    /// Sets the tool configuration (e.g., the function-calling mode).
    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
        self.tool_config = Some(config);
        self
    }
    /// Sets the thinking configuration inside the generation config.
    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
        if let Value::Object(map) = self.set_generation_config() {
            if let Ok(thinking_value) = serde_json::to_value(config) {
                map.insert("thinking_config".to_string(), thinking_value);
            }
        }
        self
    }
    /// Sets the model variation to use.
    pub fn set_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
    /// Sets or clears the system instructions.
    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
        self.sys_prompt = sys_prompt;
        self
    }
    /// Sets or clears the safety settings.
    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
        self.safety_settings = settings;
        self
    }
    /// Sets the API key.
    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = api_key.into();
        self
    }
    /// Sets the response format to JSON mode with a specific schema.
    ///
    /// To use a Rust struct as a schema, decorate it with `#[gemini_schema]` and pass
    /// `StructName::gemini_schema()`.
    ///
    /// # Arguments
    /// * `schema` - The JSON schema for the response. See [Gemini Schema docs](https://ai.google.dev/api/caching#Schema).
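    ///
    /// # Example
    /// A minimal sketch with a hand-written schema (see the linked Schema docs for the full format):
    /// ```ignore
    /// let gemini = Gemini::new("YOUR_API_KEY", "gemini-2.5-flash", None)
    ///     .set_json_mode(serde_json::json!({
    ///         "type": "OBJECT",
    ///         "properties": { "answer": { "type": "STRING" } }
    ///     }));
    /// ```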
    pub fn set_json_mode(mut self, schema: Value) -> Self {
        let config = self.set_generation_config();
        config["response_mime_type"] = "application/json".into();
        config["response_schema"] = schema;
        self
    }
    /// Removes JSON mode, reverting the response format to plain text.
    pub fn remove_json_mode(mut self) -> Self {
        if let Some(Value::Object(map)) = self.generation_config.as_mut() {
            map.remove("response_schema");
            map.remove("response_mime_type");
        }
        self
    }
    /// Sets the tools (functions) available to the model.
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }
    /// Removes all tools.
    pub fn remove_tools(mut self) -> Self {
        self.tools = None;
        self
    }
    /// Sets the cached content to use, by its full resource name (typically `cachedContents/...`).
    pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
        self.cached_content = Some(name.into());
        self
    }
    /// Stops using cached content for subsequent requests.
    pub fn remove_cached_content(mut self) -> Self {
        self.cached_content = None;
        self
    }

    // Cache management methods

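    /// Creates a cached content resource on the server.
    ///
    /// The returned [`CachedContent`] carries the server-assigned resource name, which can be
    /// passed to [`Gemini::set_cached_content`] so that later requests reuse the cache.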
    #[cfg(feature = "reqwest")]
    pub async fn create_cache(
        &self,
        cached_content: &CachedContent,
    ) -> Result<CachedContent, GeminiResponseError> {
        let req_url = format!(
            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
            self.api_key
        );

        let response = self
            .client
            .post(req_url)
            .json(cached_content)
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let cached_content: CachedContent = response
            .json()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        Ok(cached_content)
    }

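    /// Lists the cached content resources available to the current API key.
    ///
    /// # Example
    /// A minimal sketch:
    /// ```ignore
    /// let caches = gemini.list_caches().await?;
    /// ```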
    #[cfg(feature = "reqwest")]
    pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
        let req_url = format!(
            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
            self.api_key
        );

        let response = self
            .client
            .get(req_url)
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let list: CachedContentList = response
            .json()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        Ok(list)
    }

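    /// Fetches a single cached content resource by its full resource name.
    ///
    /// # Example
    /// A minimal sketch (the id is illustrative):
    /// ```ignore
    /// let cache = gemini.get_cache("cachedContents/abc123").await?;
    /// ```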
    #[cfg(feature = "reqwest")]
    pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
        let req_url = format!(
            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
            name, self.api_key
        );

        let response = self
            .client
            .get(req_url)
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let cached_content: CachedContent = response
            .json()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        Ok(cached_content)
    }

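    /// Updates an existing cached content resource, identified by its full resource name.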
    #[cfg(feature = "reqwest")]
    pub async fn update_cache(
        &self,
        name: &str,
        update: &CachedContentUpdate,
    ) -> Result<CachedContent, GeminiResponseError> {
        let req_url = format!(
            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
            name, self.api_key
        );

        let response = self
            .client
            .patch(req_url)
            .json(update)
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let cached_content: CachedContent = response
            .json()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        Ok(cached_content)
    }

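    /// Deletes a cached content resource identified by its full resource name.
    ///
    /// # Example
    /// A minimal sketch (the id is illustrative):
    /// ```ignore
    /// gemini.delete_cache("cachedContents/abc123").await?;
    /// ```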
    #[cfg(feature = "reqwest")]
    pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
        let req_url = format!(
            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
            name, self.api_key
        );

        let response = self
            .client
            .delete(req_url)
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        Ok(())
    }

    /// Sends a prompt to the model and waits for the full response.
    ///
    /// Updates the `session` history with the model's reply.
    ///
    /// # Errors
    /// Returns `GeminiResponseError::NothingToRespond` if the last message in history is from the model.
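    ///
    /// # Example
    /// A minimal sketch; assumes `session` already ends with a user message:
    /// ```ignore
    /// let reply = gemini.ask(&mut session).await?;
    /// println!("{}", reply.get_text(""));
    /// ```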
    #[cfg(feature = "reqwest")]
    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
        if session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() == Role::Model)
        {
            return Err(GeminiResponseError::NothingToRespond);
        }
        let req_url = format!(
            "{BASE_URL}/{}:generateContent?key={}",
            self.model, self.api_key
        );

        let response = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
                self.cached_content.clone(),
            ))
            .send()
            .await
            .map_err(GeminiResponseError::ReqwestError)?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .map_err(GeminiResponseError::ReqwestError)?;
            return Err(GeminiResponseError::StatusNotOk(text));
        }

        let reply = GeminiResponse::new(response)
            .await
            .map_err(GeminiResponseError::ReqwestError)?;
        session.update(&reply);
        Ok(reply)
    }
    /// Sends a prompt to the model and returns a stream of responses, using `data_extractor`
    /// to map each received chunk to the items yielded by the stream.
    ///
    /// # Warning
    /// You must read the response stream for the reply to be stored in the `session` history.
    ///
    /// # Example
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut response_stream = gemini
    ///     .ask_as_stream_with_extractor(session, |session, _gemini_response| {
    ///         // Use _gemini_response.get_text("") to just get the text received in every chunk.
    ///         session.get_last_chat().unwrap().get_text_no_think("\n")
    ///     })
    ///     .await
    ///     .unwrap();
    /// while let Some(response) = response_stream.next().await {
    ///     println!("{}", response.unwrap());
    /// }
    /// ```
    #[cfg(feature = "reqwest")]
    pub async fn ask_as_stream_with_extractor<F, StreamType>(
        &self,
        session: Session,
        data_extractor: F,
    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
    where
        F: FnMut(&Session, GeminiResponse) -> StreamType,
    {
        if session
            .get_last_chat()
            .is_some_and(|chat| *chat.role() == Role::Model)
        {
            return Err((session, GeminiResponseError::NothingToRespond));
        }
        let req_url = format!(
            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
            self.model, self.api_key
        );

        let request = self
            .client
            .post(req_url)
            .json(&GeminiRequestBody::new(
                self.sys_prompt.as_ref(),
                self.tools.as_deref(),
                session.get_history().as_slice(),
                self.generation_config.as_ref(),
                self.safety_settings.as_deref(),
                self.tool_config.as_ref(),
                self.cached_content.clone(),
            ))
            .send()
            .await;
        let response = match request {
            Ok(response) => response,
            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
        };

        if !response.status().is_success() {
            let text = match response.text().await {
                Ok(response) => response,
                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
            };
            return Err((session, GeminiResponseError::StatusNotOk(text)));
        }

        Ok(ResponseStream::new(
            Box::new(response.bytes_stream()),
            session,
            data_extractor,
        ))
    }
    /// Sends a prompt to the model and returns a stream of responses.
    ///
    /// # Warning
    /// You must exhaust the response stream to ensure the `session` history is correctly updated.
    ///
    /// # Example
    /// ```no_run
    /// use futures::StreamExt;
    /// # async fn run(gemini: gemini_client_api::gemini::ask::Gemini, session: gemini_client_api::gemini::types::sessions::Session) {
    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
    ///
    /// while let Some(response) = response_stream.next().await {
    ///     if let Ok(response) = response {
    ///         println!("{}", response.get_chat().get_text_no_think("\n"));
    ///     }
    /// }
    /// # }
    /// ```
    #[cfg(feature = "reqwest")]
    pub async fn ask_as_stream(
        &self,
        session: Session,
    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
        self.ask_as_stream_with_extractor(
            session,
            (|_, gemini_response| gemini_response)
                as fn(&Session, GeminiResponse) -> GeminiResponse,
        )
        .await
    }
}