Skip to main content

gemini_client_api/gemini/
ask.rs

1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6#[cfg(feature = "reqwest")]
7use reqwest::Client;
8use serde_json::{Value, json};
9use std::time::Duration;
10
/// Base URL of the Gemini `v1beta` model endpoints; request URLs append `/{model}:{method}`
/// (e.g. `:generateContent`, `:streamGenerateContent`).
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
12
13/// The main client for interacting with the Gemini API.
14///
15/// Use `Gemini::new` or `Gemini::new_with_timeout` to create an instance.
16/// You can configure various aspects of the request like model, system instructions,
17/// generation config, safety settings, and tools using the provided builder-like methods.
18#[derive(Clone, Default, Debug)]
19pub struct Gemini {
20    #[cfg(feature = "reqwest")]
21    client: Client,
22    api_key: String,
23    model: String,
24    sys_prompt: Option<SystemInstruction>,
25    generation_config: Option<Value>,
26    safety_settings: Option<Vec<SafetySetting>>,
27    tools: Option<Vec<Tool>>,
28    tool_config: Option<ToolConfig>,
29    cached_content: Option<String>,
30}
31
32impl Gemini {
33    /// Creates a new `Gemini` client.
34    ///
35    /// # Arguments
36    /// * `api_key` - Your Gemini API key. Get one from [Google AI studio](https://aistudio.google.com/app/apikey).
37    /// * `model` - The model variation to use (e.g., "gemini-2.5-flash"). See [model variations](https://ai.google.dev/gemini-api/docs/models#model-variations).
38    /// * `sys_prompt` - Optional system instructions. See [system instructions](https://ai.google.dev/gemini-api/docs/text-generation#image-input).
39    #[cfg(feature = "reqwest")]
40    pub fn new(
41        api_key: impl Into<String>,
42        model: impl Into<String>,
43        sys_prompt: Option<SystemInstruction>,
44    ) -> Self {
45        Self {
46            client: Client::builder()
47                .timeout(Duration::from_secs(60))
48                .build()
49                .unwrap(),
50            api_key: api_key.into(),
51            model: model.into(),
52            sys_prompt,
53            generation_config: None,
54            safety_settings: None,
55            tools: None,
56            tool_config: None,
57            cached_content: None,
58        }
59    }
60    /// Creates a new `Gemini` client with a custom API timeout.
61    ///
62    /// # Arguments
63    /// * `api_key` - Your Gemini API key.
64    /// * `model` - The model variation to use.
65    /// * `sys_prompt` - Optional system instructions.
66    /// * `api_timeout` - Custom duration for request timeouts.
67    #[deprecated]
68    #[cfg(feature = "reqwest")]
69    pub fn new_with_timeout(
70        api_key: impl Into<String>,
71        model: impl Into<String>,
72        sys_prompt: Option<SystemInstruction>,
73        api_timeout: Duration,
74    ) -> Self {
75        Self {
76            client: Client::builder().timeout(api_timeout).build().unwrap(),
77            api_key: api_key.into(),
78            model: model.into(),
79            sys_prompt,
80            generation_config: None,
81            safety_settings: None,
82            tools: None,
83            tool_config: None,
84            cached_content: None,
85        }
86    }
87    /// Creates a new `Gemini` client with a custom API reqwest::Client.
88    ///
89    /// # Arguments
90    /// * `api_key` - Your Gemini API key.
91    /// * `model` - The model variation to use.
92    /// * `sys_prompt` - Optional system instructions.
93    /// * `client` - reqwest::Client to request gemini API.
94    #[cfg(feature = "reqwest")]
95    pub fn new_with_client(
96        api_key: impl Into<String>,
97        model: impl Into<String>,
98        sys_prompt: Option<SystemInstruction>,
99        client: Client,
100    ) -> Self {
101        Self {
102            client,
103            api_key: api_key.into(),
104            model: model.into(),
105            sys_prompt,
106            generation_config: None,
107            safety_settings: None,
108            tools: None,
109            tool_config: None,
110            cached_content: None,
111        }
112    }
113    /// Returns a mutable reference to the generation configuration.
114    /// If not already set, initializes it to an empty object.
115    ///
116    /// See [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig) for schema details.
117    pub fn set_generation_config(&mut self) -> &mut Value {
118        if let None = self.generation_config {
119            self.generation_config = Some(json!({}));
120        }
121        self.generation_config.as_mut().unwrap()
122    }
123    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
124        self.tool_config = Some(config);
125        self
126    }
127    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
128        if let Value::Object(map) = self.set_generation_config() {
129            if let Ok(thinking_value) = serde_json::to_value(config) {
130                map.insert("thinking_config".to_string(), thinking_value);
131            }
132        }
133        self
134    }
135    pub fn set_model(mut self, model: impl Into<String>) -> Self {
136        self.model = model.into();
137        self
138    }
139    /// # Warning
140    /// Changing sys_prompt in middle of a conversation can confuse the model.
141    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
142        self.sys_prompt = sys_prompt;
143        self
144    }
145    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
146        self.safety_settings = settings;
147        self
148    }
149    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
150        self.api_key = api_key.into();
151        self
152    }
153    /// Sets the response format to JSON mode with a specific schema.
154    ///
155    /// To use a Rust struct as a schema, decorate it with `#[gemini_schema]` and pass
156    /// `StructName::gemini_schema()`.
157    ///
158    /// # Arguments
159    /// * `schema` - The JSON schema for the response. See [Gemini Schema docs](https://ai.google.dev/api/caching#Schema).
160    pub fn set_json_mode(mut self, schema: Value) -> Self {
161        let config = self.set_generation_config();
162        config["response_mime_type"] = "application/json".into();
163        config["response_schema"] = schema.into();
164        self
165    }
166    pub fn remove_json_mode(mut self) -> Self {
167        if let Some(ref mut generation_config) = self.generation_config {
168            generation_config["response_schema"] = None::<Value>.into();
169            generation_config["response_mime_type"] = None::<Value>.into();
170        }
171        self
172    }
173    /// Sets the tools (functions) available to the model.
174    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
175        self.tools = Some(tools);
176        self
177    }
178    /// Removes all tools.
179    pub fn remove_tools(mut self) -> Self {
180        self.tools = None;
181        self
182    }
183    pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
184        self.cached_content = Some(name.into());
185        self
186    }
187    pub fn remove_cached_content(mut self) -> Self {
188        self.cached_content = None;
189        self
190    }
191
192    // Cache management methods
193
194    #[cfg(feature = "reqwest")]
195    pub async fn create_cache(
196        &self,
197        cached_content: &CachedContent,
198    ) -> Result<CachedContent, GeminiResponseError> {
199        let req_url = format!(
200            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
201            self.api_key
202        );
203
204        let response = self
205            .client
206            .post(req_url)
207            .json(cached_content)
208            .send()
209            .await
210            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
211
212        if !response.status().is_success() {
213            let error = response
214                .json()
215                .await
216                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
217            return Err(GeminiResponseError::StatusNotOk(error));
218        }
219
220        let cached_content: CachedContent = response
221            .json()
222            .await
223            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
224        Ok(cached_content)
225    }
226
227    #[cfg(feature = "reqwest")]
228    pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
229        let req_url = format!(
230            "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
231            self.api_key
232        );
233
234        let response = self
235            .client
236            .get(req_url)
237            .send()
238            .await
239            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
240
241        if !response.status().is_success() {
242            let error = response
243                .json()
244                .await
245                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
246            return Err(GeminiResponseError::StatusNotOk(error));
247        }
248
249        let list: CachedContentList = response
250            .json()
251            .await
252            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
253        Ok(list)
254    }
255
256    #[cfg(feature = "reqwest")]
257    pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
258        let req_url = format!(
259            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
260            name, self.api_key
261        );
262
263        let response = self
264            .client
265            .get(req_url)
266            .send()
267            .await
268            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
269
270        if !response.status().is_success() {
271            let error = response
272                .json()
273                .await
274                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
275            return Err(GeminiResponseError::StatusNotOk(error));
276        }
277
278        let cached_content: CachedContent = response
279            .json()
280            .await
281            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
282        Ok(cached_content)
283    }
284
285    #[cfg(feature = "reqwest")]
286    pub async fn update_cache(
287        &self,
288        name: &str,
289        update: &CachedContentUpdate,
290    ) -> Result<CachedContent, GeminiResponseError> {
291        let req_url = format!(
292            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
293            name, self.api_key
294        );
295
296        let response = self
297            .client
298            .patch(req_url)
299            .json(update)
300            .send()
301            .await
302            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
303
304        if !response.status().is_success() {
305            let error = response
306                .json()
307                .await
308                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
309            return Err(GeminiResponseError::StatusNotOk(error));
310        }
311
312        let cached_content: CachedContent = response
313            .json()
314            .await
315            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
316        Ok(cached_content)
317    }
318
319    #[cfg(feature = "reqwest")]
320    pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
321        let req_url = format!(
322            "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
323            name, self.api_key
324        );
325
326        let response = self
327            .client
328            .delete(req_url)
329            .send()
330            .await
331            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
332
333        if !response.status().is_success() {
334            let error = response
335                .json()
336                .await
337                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
338            return Err(GeminiResponseError::StatusNotOk(error));
339        }
340
341        Ok(())
342    }
343
344    /// Sends a prompt to the model and waits for the full response.
345    ///
346    /// Updates the `session` history with the model's reply.
347    ///
348    /// # Errors
349    /// Returns `GeminiResponseError::NothingToRespond` if the last message in history is from the model.
350    #[cfg(feature = "reqwest")]
351    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
352        if session
353            .get_last_chat()
354            .is_some_and(|chat| *chat.role() == Role::Model)
355        {
356            return Err(GeminiResponseError::NothingToRespond);
357        }
358        let req_url = format!(
359            "{BASE_URL}/{}:generateContent?key={}",
360            self.model, self.api_key
361        );
362
363        let response = self
364            .client
365            .post(req_url)
366            .json(&GeminiRequestBody::new(
367                self.sys_prompt.as_ref(),
368                self.tools.as_deref(),
369                &session.get_history().as_slice(),
370                self.generation_config.as_ref(),
371                self.safety_settings.as_deref(),
372                self.tool_config.as_ref(),
373                self.cached_content.clone(),
374            ))
375            .send()
376            .await
377            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
378
379        if !response.status().is_success() {
380            let error = response
381                .json()
382                .await
383                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
384            return Err(GeminiResponseError::StatusNotOk(error));
385        }
386
387        let reply = GeminiResponse::new(response)
388            .await
389            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
390        session.update(&reply);
391        Ok(reply)
392    }
393    /// # Warning
394    /// You must read the response stream to get reply stored context in `session`.
395    /// `data_extractor` is used to extract data that you get as a stream of futures.
396    /// # Example
397    ///```ignore
398    ///use futures::StreamExt
399    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
400    ///     |session, _gemini_response| session.get_last_chat().unwrap().get_text_no_think("\n"))
401    ///    .await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
402    ///while let Some(response) = response_stream.next().await {
403    ///    println!("{}", response);
404    ///}
405    ///```
406    #[cfg(feature = "reqwest")]
407    pub async fn ask_as_stream_with_extractor<F, StreamType>(
408        &self,
409        session: Session,
410        data_extractor: F,
411    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
412    where
413        F: FnMut(&Session, GeminiResponse) -> StreamType,
414    {
415        if session
416            .get_last_chat()
417            .is_some_and(|chat| *chat.role() == Role::Model)
418        {
419            return Err((session, GeminiResponseError::NothingToRespond));
420        }
421        let req_url = format!(
422            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
423            self.model, self.api_key
424        );
425
426        let request = self
427            .client
428            .post(req_url)
429            .json(&GeminiRequestBody::new(
430                self.sys_prompt.as_ref(),
431                self.tools.as_deref(),
432                session.get_history().as_slice(),
433                self.generation_config.as_ref(),
434                self.safety_settings.as_deref(),
435                self.tool_config.as_ref(),
436                self.cached_content.clone(),
437            ))
438            .send()
439            .await;
440        let response = match request {
441            Ok(response) => response,
442            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
443        };
444
445        if !response.status().is_success() {
446            let error = match response.json().await {
447                Ok(response) => response,
448                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
449            };
450            return Err((session, GeminiResponseError::StatusNotOk(error)));
451        }
452
453        Ok(ResponseStream::new(
454            Box::new(response.bytes_stream()),
455            session,
456            data_extractor,
457        ))
458    }
459    /// Sends a prompt to the model and returns a stream of responses.
460    ///
461    /// # Warning
462    /// You must exhaust the response stream to ensure the `session` history is correctly updated.
463    ///
464    /// # Example
465    /// ```no_run
466    /// use futures::StreamExt;
467    /// # async fn run(gemini: gemini_client_api::gemini::ask::Gemini, session: gemini_client_api::gemini::types::sessions::Session) {
468    /// let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
469    ///
470    /// while let Some(response) = response_stream.next().await {
471    ///     if let Ok(response) = response {
472    ///         println!("{}", response.get_chat().get_text_no_think("\n"));
473    ///     }
474    /// }
475    /// # }
476    /// ```
477    #[cfg(feature = "reqwest")]
478    pub async fn ask_as_stream(
479        &self,
480        session: Session,
481    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
482        self.ask_as_stream_with_extractor(
483            session,
484            (|_, gemini_response| gemini_response)
485                as fn(&Session, GeminiResponse) -> GeminiResponse,
486        )
487        .await
488    }
489}