// gemini_client_api/gemini/ask.rs

1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini REST model endpoints.
/// Requests are sent to `{BASE_URL}/{model}:{method}`.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13    client: Client,
14    api_key: String,
15    model: String,
16    sys_prompt: Option<SystemInstruction>,
17    generation_config: Option<Value>,
18    tools: Option<Vec<Tool>>,
19}
20impl Gemini {
21    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
22    pub fn new(
23        api_key: impl Into<String>,
24        model: impl Into<String>,
25        sys_prompt: Option<SystemInstruction>,
26    ) -> Self {
27        Self {
28            client: Client::builder()
29                .timeout(Duration::from_secs(60))
30                .build()
31                .unwrap(),
32            api_key: api_key.into(),
33            model: model.into(),
34            sys_prompt,
35            generation_config: None,
36            tools: None,
37        }
38    }
39    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
40    pub fn new_with_timeout(
41        api_key: impl Into<String>,
42        model: impl Into<String>,
43        sys_prompt: Option<SystemInstruction>,
44        api_timeout: Duration,
45    ) -> Self {
46        Self {
47            client: Client::builder().timeout(api_timeout).build().unwrap(),
48            api_key: api_key.into(),
49            model: model.into(),
50            sys_prompt,
51            generation_config: None,
52            tools: None,
53        }
54    }
55    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/gemini-api/docs/text-generation#configuration-parameters)
56    pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
57        self.generation_config = Some(generation_config);
58        self
59    }
60    pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
61        self.model = model.into();
62        self
63    }
64    pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
65        self.api_key = api_key.into();
66        self
67    }
68    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
69    pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
70        if let None = self.generation_config {
71            self.generation_config = Some(json!({
72                "response_mime_type": "application/json",
73                "response_schema":schema
74            }))
75        } else if let Some(config) = self.generation_config.as_mut() {
76            config["response_mime_type"] = "application/json".into();
77            config["response_schema"] = schema.into();
78        }
79        self
80    }
81    pub fn unset_json_mode(&mut self) -> &mut Self {
82        if let Some(ref mut generation_config) = self.generation_config {
83            generation_config["response_schema"] = None::<Value>.into();
84            generation_config["response_mime_type"] = None::<Value>.into();
85        }
86        self
87    }
88    ///- `tools` can be None to unset tools from using.  
89    ///- Or Vec tools to be allowed
90    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
91        self.tools = tools;
92        self
93    }
94    pub fn unset_code_execution_mode(&mut self) -> &mut Self {
95        self.tools.take();
96        self
97    }
98
99    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
100        let req_url = format!(
101            "{BASE_URL}/{}:generateContent?key={}",
102            self.model, self.api_key
103        );
104
105        let response = self
106            .client
107            .post(req_url)
108            .json(&GeminiRequestBody::new(
109                self.sys_prompt.as_ref(),
110                self.tools.as_deref(),
111                &session.get_history().as_slice(),
112                self.generation_config.as_ref(),
113            ))
114            .send()
115            .await?;
116
117        if !response.status().is_success() {
118            let text = response.text().await?;
119            return Err(text.into());
120        }
121
122        let reply = GeminiResponse::new(response).await?;
123        session.update(&reply);
124        Ok(reply)
125    }
126    /// # Warning
127    /// You must read the response stream to get reply stored context in sessions.
128    /// `data_extractor` is used to extract data that you get as a stream of futures.
129    /// # Example
130    ///```ignore
131    ///use futures::StreamExt
132    ///let mut response_stream = gemini.ask_as_stream(session,
133    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
134    ///.await.unwrap(); // Use _gemini_response.get_text("") for text received in every chunk
135    ///
136    ///while let Some(response) = response_stream.next().await {
137    ///    if let Ok(response) = response {
138    ///        println!("{}", response.get_text(""));
139    ///    }
140    ///}
141    ///```
142    pub async fn ask_as_stream<F, StreamType>(
143        &self,
144        session: Session,
145        data_extractor: F,
146    ) -> Result<GeminiResponseStream<F, StreamType>, GeminiResponseError>
147    where
148        F: FnMut(&Session, GeminiResponse) -> StreamType,
149    {
150        let req_url = format!(
151            "{BASE_URL}/{}:streamGenerateContent?key={}",
152            self.model, self.api_key
153        );
154
155        let response = self
156            .client
157            .post(req_url)
158            .json(&GeminiRequestBody::new(
159                self.sys_prompt.as_ref(),
160                self.tools.as_deref(),
161                session.get_history().as_slice(),
162                self.generation_config.as_ref(),
163            ))
164            .send()
165            .await?;
166
167        if !response.status().is_success() {
168            let text = response.text().await?;
169            return Err(text.into());
170        }
171
172        Ok(GeminiResponseStream::new(
173            Box::new(response.bytes_stream()),
174            session,
175            data_extractor,
176        ))
177    }
178}