// gemini_client_api/gemini/ask.rs

1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
9const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
10
/// Client for the Gemini `generateContent` / `streamGenerateContent` REST API.
///
/// Construct once via [`Gemini::new`] (or `new_with_timeout`) and reuse; the
/// builder-style `set_*` methods configure optional request fields.
// NOTE(review): `Debug` is derived, so `api_key` will appear in debug output —
// confirm this is acceptable wherever `Gemini` values are logged.
#[derive(Clone, Default, Debug)]
pub struct Gemini {
    /// Shared reqwest HTTP client, built with a request timeout.
    client: Client,
    /// API key, appended as the `key` query parameter on every request.
    api_key: String,
    /// Model name; interpolated into the request URL path.
    model: String,
    /// Optional system instruction sent with every request.
    sys_prompt: Option<SystemInstruction>,
    /// Free-form `generationConfig` JSON; `None` until first configured.
    generation_config: Option<Value>,
    /// Optional safety settings forwarded in the request body.
    safety_settings: Option<Vec<SafetySetting>>,
    /// Optional tool declarations (e.g. function calling).
    tools: Option<Vec<Tool>>,
    /// Optional tool invocation configuration.
    tool_config: Option<ToolConfig>,
}
22impl Gemini {
23    /// # Arguments
24    /// `api_key` get one from [Google AI studio](https://aistudio.google.com/app/apikey)
25    /// `model` should be of those mentioned [here](https://ai.google.dev/gemini-api/docs/models#model-variations) in bold black color
26    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
27    pub fn new(
28        api_key: impl Into<String>,
29        model: impl Into<String>,
30        sys_prompt: Option<SystemInstruction>,
31    ) -> Self {
32        Self {
33            client: Client::builder()
34                .timeout(Duration::from_secs(60))
35                .build()
36                .unwrap(),
37            api_key: api_key.into(),
38            model: model.into(),
39            sys_prompt,
40            generation_config: None,
41            safety_settings: None,
42            tools: None,
43            tool_config: None,
44        }
45    }
46    /// `sys_prompt` should follow [gemini doc](https://ai.google.dev/gemini-api/docs/text-generation#image-input)
47    pub fn new_with_timeout(
48        api_key: impl Into<String>,
49        model: impl Into<String>,
50        sys_prompt: Option<SystemInstruction>,
51        api_timeout: Duration,
52    ) -> Self {
53        Self {
54            client: Client::builder().timeout(api_timeout).build().unwrap(),
55            api_key: api_key.into(),
56            model: model.into(),
57            sys_prompt,
58            generation_config: None,
59            safety_settings: None,
60            tools: None,
61            tool_config: None,
62        }
63    }
64    /// The generation config Schema should follow [Gemini docs](https://ai.google.dev/api/generate-content#generationconfig)
65    pub fn set_generation_config(&mut self) -> &mut Value {
66        if let None = self.generation_config {
67            self.generation_config = Some(json!({}));
68        }
69        self.generation_config.as_mut().unwrap()
70    }
71    pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
72        self.tool_config = Some(config);
73        self
74    }
75    pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
76        if let Value::Object(map) = self.set_generation_config() {
77            if let Ok(thinking_value) = serde_json::to_value(config) {
78                map.insert("thinking_config".to_string(), thinking_value);
79            }
80        }
81        self
82    }
83    pub fn set_model(mut self, model: impl Into<String>) -> Self {
84        self.model = model.into();
85        self
86    }
87    pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
88        self.sys_prompt = sys_prompt;
89        self
90    }
91    pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
92        self.safety_settings = settings;
93        self
94    }
95    pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
96        self.api_key = api_key.into();
97        self
98    }
99    /// `schema` should follow [Schema of gemini](https://ai.google.dev/api/caching#Schema)
100    /// To verify your schema visit [here](https://aistudio.google.com/prompts/new_chat):
101    /// - Under tools, toggle on Structured output
102    /// - Click Edit
103    /// - Here you can create schema with `Visual Editor` or `Code Editor` with error detection
104    pub fn set_json_mode(mut self, schema: Value) -> Self {
105        let config = self.set_generation_config();
106        config["response_mime_type"] = "application/json".into();
107        config["response_schema"] = schema.into();
108        self
109    }
110    pub fn unset_json_mode(mut self) -> Self {
111        if let Some(ref mut generation_config) = self.generation_config {
112            generation_config["response_schema"] = None::<Value>.into();
113            generation_config["response_mime_type"] = None::<Value>.into();
114        }
115        self
116    }
117    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
118        self.tools = Some(tools);
119        self
120    }
121    pub fn unset_tools(mut self) -> Self {
122        self.tools = None;
123        self
124    }
125
126    pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
127        let req_url = format!(
128            "{BASE_URL}/{}:generateContent?key={}",
129            self.model, self.api_key
130        );
131
132        let response = self
133            .client
134            .post(req_url)
135            .json(&GeminiRequestBody::new(
136                self.sys_prompt.as_ref(),
137                self.tools.as_deref(),
138                &session.get_history().as_slice(),
139                self.generation_config.as_ref(),
140                self.safety_settings.as_deref(),
141                self.tool_config.as_ref(),
142            ))
143            .send()
144            .await
145            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
146
147        if !response.status().is_success() {
148            let text = response
149                .text()
150                .await
151                .map_err(|e| GeminiResponseError::ReqwestError(e))?;
152            return Err(GeminiResponseError::StatusNotOk(text));
153        }
154
155        let reply = GeminiResponse::new(response)
156            .await
157            .map_err(|e| GeminiResponseError::ReqwestError(e))?;
158        session.update(&reply);
159        Ok(reply)
160    }
161    /// # Warning
162    /// You must read the response stream to get reply stored context in `session`.
163    /// `data_extractor` is used to extract data that you get as a stream of futures.
164    /// # Example
165    ///```ignore
166    ///use futures::StreamExt
167    ///let mut response_stream = gemini.ask_as_stream_with_extractor(session,
168    ///|session, _gemini_response| session.get_last_message_text("").unwrap())
169    ///.await.unwrap(); // Use _gemini_response.get_text("") to just get the text received in every chunk
170    ///
171    ///while let Some(response) = response_stream.next().await {
172    ///    if let Ok(response) = response {
173    ///        println!("{}", response);
174    ///    }
175    ///}
176    ///```
177    pub async fn ask_as_stream_with_extractor<F, StreamType>(
178        &self,
179        session: Session,
180        data_extractor: F,
181    ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
182    where
183        F: FnMut(&Session, GeminiResponse) -> StreamType,
184    {
185        let req_url = format!(
186            "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
187            self.model, self.api_key
188        );
189
190        let request = self
191            .client
192            .post(req_url)
193            .json(&GeminiRequestBody::new(
194                self.sys_prompt.as_ref(),
195                self.tools.as_deref(),
196                session.get_history().as_slice(),
197                self.generation_config.as_ref(),
198                self.safety_settings.as_deref(),
199                self.tool_config.as_ref(),
200            ))
201            .send()
202            .await;
203        let response = match request {
204            Ok(response) => response,
205            Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
206        };
207
208        if !response.status().is_success() {
209            let text = match response.text().await {
210                Ok(response) => response,
211                Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
212            };
213            return Err((session, GeminiResponseError::StatusNotOk(text.into())));
214        }
215
216        Ok(ResponseStream::new(
217            Box::new(response.bytes_stream()),
218            session,
219            data_extractor,
220        ))
221    }
222    /// # Warning
223    /// You must read the response stream to get reply stored context in `session`.
224    /// # Example
225    ///```ignore
226    ///use futures::StreamExt
227    ///let mut response_stream = gemini.ask_as_stream(session).await.unwrap();
228    ///
229    ///while let Some(response) = response_stream.next().await {
230    ///    if let Ok(response) = response {
231    ///        println!("{}", response.get_text(""));
232    ///    }
233    ///}
234    ///```
235    pub async fn ask_as_stream(
236        &self,
237        session: Session,
238    ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
239        self.ask_as_stream_with_extractor(
240            session,
241            (|_, gemini_response| gemini_response)
242                as fn(&Session, GeminiResponse) -> GeminiResponse,
243        )
244        .await
245    }
246}