gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6#[cfg(feature = "reqwest")]
7use reqwest::Client;
8use serde_json::{Value, json};
9use std::time::Duration;
10
/// Base URL of the v1beta `models` collection; per-model endpoints are formed as
/// `{BASE_URL}/{model}:{method}`.
const BASE_URL: &str = concat!("https://generativelanguage.googleapis.com", "/v1beta/models");
12
13#[derive(Clone, Default, Debug)]
19pub struct Gemini {
20 #[cfg(feature = "reqwest")]
21 client: Client,
22 api_key: String,
23 model: String,
24 sys_prompt: Option<SystemInstruction>,
25 generation_config: Option<Value>,
26 safety_settings: Option<Vec<SafetySetting>>,
27 tools: Option<Vec<Tool>>,
28 tool_config: Option<ToolConfig>,
29 cached_content: Option<String>,
30}
31
32impl Gemini {
33 #[cfg(feature = "reqwest")]
40 pub fn new(
41 api_key: impl Into<String>,
42 model: impl Into<String>,
43 sys_prompt: Option<SystemInstruction>,
44 ) -> Self {
45 Self {
46 client: Client::builder()
47 .timeout(Duration::from_secs(60))
48 .build()
49 .unwrap(),
50 api_key: api_key.into(),
51 model: model.into(),
52 sys_prompt,
53 generation_config: None,
54 safety_settings: None,
55 tools: None,
56 tool_config: None,
57 cached_content: None,
58 }
59 }
60 #[cfg(feature = "reqwest")]
68 pub fn new_with_timeout(
69 api_key: impl Into<String>,
70 model: impl Into<String>,
71 sys_prompt: Option<SystemInstruction>,
72 api_timeout: Duration,
73 ) -> Self {
74 Self {
75 client: Client::builder().timeout(api_timeout).build().unwrap(),
76 api_key: api_key.into(),
77 model: model.into(),
78 sys_prompt,
79 generation_config: None,
80 safety_settings: None,
81 tools: None,
82 tool_config: None,
83 cached_content: None,
84 }
85 }
86 pub fn set_generation_config(&mut self) -> &mut Value {
91 if let None = self.generation_config {
92 self.generation_config = Some(json!({}));
93 }
94 self.generation_config.as_mut().unwrap()
95 }
96 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
97 self.tool_config = Some(config);
98 self
99 }
100 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
101 if let Value::Object(map) = self.set_generation_config() {
102 if let Ok(thinking_value) = serde_json::to_value(config) {
103 map.insert("thinking_config".to_string(), thinking_value);
104 }
105 }
106 self
107 }
108 pub fn set_model(mut self, model: impl Into<String>) -> Self {
109 self.model = model.into();
110 self
111 }
112 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
113 self.sys_prompt = sys_prompt;
114 self
115 }
116 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
117 self.safety_settings = settings;
118 self
119 }
120 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
121 self.api_key = api_key.into();
122 self
123 }
124 pub fn set_json_mode(mut self, schema: Value) -> Self {
132 let config = self.set_generation_config();
133 config["response_mime_type"] = "application/json".into();
134 config["response_schema"] = schema.into();
135 self
136 }
137 pub fn remove_json_mode(mut self) -> Self {
138 if let Some(ref mut generation_config) = self.generation_config {
139 generation_config["response_schema"] = None::<Value>.into();
140 generation_config["response_mime_type"] = None::<Value>.into();
141 }
142 self
143 }
144 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
146 self.tools = Some(tools);
147 self
148 }
149 pub fn remove_tools(mut self) -> Self {
151 self.tools = None;
152 self
153 }
154 pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
155 self.cached_content = Some(name.into());
156 self
157 }
158 pub fn remove_cached_content(mut self) -> Self {
159 self.cached_content = None;
160 self
161 }
162
163 #[cfg(feature = "reqwest")]
166 pub async fn create_cache(
167 &self,
168 cached_content: &CachedContent,
169 ) -> Result<CachedContent, GeminiResponseError> {
170 let req_url = format!(
171 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
172 self.api_key
173 );
174
175 let response = self
176 .client
177 .post(req_url)
178 .json(cached_content)
179 .send()
180 .await
181 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
182
183 if !response.status().is_success() {
184 let text = response
185 .text()
186 .await
187 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
188 return Err(GeminiResponseError::StatusNotOk(text));
189 }
190
191 let cached_content: CachedContent = response
192 .json()
193 .await
194 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
195 Ok(cached_content)
196 }
197
198 #[cfg(feature = "reqwest")]
199 pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
200 let req_url = format!(
201 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
202 self.api_key
203 );
204
205 let response = self
206 .client
207 .get(req_url)
208 .send()
209 .await
210 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
211
212 if !response.status().is_success() {
213 let text = response
214 .text()
215 .await
216 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
217 return Err(GeminiResponseError::StatusNotOk(text));
218 }
219
220 let list: CachedContentList = response
221 .json()
222 .await
223 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
224 Ok(list)
225 }
226
227 #[cfg(feature = "reqwest")]
228 pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
229 let req_url = format!(
230 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
231 name, self.api_key
232 );
233
234 let response = self
235 .client
236 .get(req_url)
237 .send()
238 .await
239 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
240
241 if !response.status().is_success() {
242 let text = response
243 .text()
244 .await
245 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
246 return Err(GeminiResponseError::StatusNotOk(text));
247 }
248
249 let cached_content: CachedContent = response
250 .json()
251 .await
252 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
253 Ok(cached_content)
254 }
255
256 #[cfg(feature = "reqwest")]
257 pub async fn update_cache(
258 &self,
259 name: &str,
260 update: &CachedContentUpdate,
261 ) -> Result<CachedContent, GeminiResponseError> {
262 let req_url = format!(
263 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
264 name, self.api_key
265 );
266
267 let response = self
268 .client
269 .patch(req_url)
270 .json(update)
271 .send()
272 .await
273 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
274
275 if !response.status().is_success() {
276 let text = response
277 .text()
278 .await
279 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
280 return Err(GeminiResponseError::StatusNotOk(text));
281 }
282
283 let cached_content: CachedContent = response
284 .json()
285 .await
286 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
287 Ok(cached_content)
288 }
289
290 #[cfg(feature = "reqwest")]
291 pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
292 let req_url = format!(
293 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
294 name, self.api_key
295 );
296
297 let response = self
298 .client
299 .delete(req_url)
300 .send()
301 .await
302 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
303
304 if !response.status().is_success() {
305 let text = response
306 .text()
307 .await
308 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
309 return Err(GeminiResponseError::StatusNotOk(text));
310 }
311
312 Ok(())
313 }
314
315 #[cfg(feature = "reqwest")]
322 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
323 if session
324 .get_last_chat()
325 .is_some_and(|chat| *chat.role() == Role::Model)
326 {
327 return Err(GeminiResponseError::NothingToRespond);
328 }
329 let req_url = format!(
330 "{BASE_URL}/{}:generateContent?key={}",
331 self.model, self.api_key
332 );
333
334 let response = self
335 .client
336 .post(req_url)
337 .json(&GeminiRequestBody::new(
338 self.sys_prompt.as_ref(),
339 self.tools.as_deref(),
340 &session.get_history().as_slice(),
341 self.generation_config.as_ref(),
342 self.safety_settings.as_deref(),
343 self.tool_config.as_ref(),
344 self.cached_content.clone(),
345 ))
346 .send()
347 .await
348 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
349
350 if !response.status().is_success() {
351 let text = response
352 .text()
353 .await
354 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
355 return Err(GeminiResponseError::StatusNotOk(text));
356 }
357
358 let reply = GeminiResponse::new(response)
359 .await
360 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
361 session.update(&reply);
362 Ok(reply)
363 }
364 #[cfg(feature = "reqwest")]
378 pub async fn ask_as_stream_with_extractor<F, StreamType>(
379 &self,
380 session: Session,
381 data_extractor: F,
382 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
383 where
384 F: FnMut(&Session, GeminiResponse) -> StreamType,
385 {
386 if session
387 .get_last_chat()
388 .is_some_and(|chat| *chat.role() == Role::Model)
389 {
390 return Err((session, GeminiResponseError::NothingToRespond));
391 }
392 let req_url = format!(
393 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
394 self.model, self.api_key
395 );
396
397 let request = self
398 .client
399 .post(req_url)
400 .json(&GeminiRequestBody::new(
401 self.sys_prompt.as_ref(),
402 self.tools.as_deref(),
403 session.get_history().as_slice(),
404 self.generation_config.as_ref(),
405 self.safety_settings.as_deref(),
406 self.tool_config.as_ref(),
407 self.cached_content.clone(),
408 ))
409 .send()
410 .await;
411 let response = match request {
412 Ok(response) => response,
413 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
414 };
415
416 if !response.status().is_success() {
417 let text = match response.text().await {
418 Ok(response) => response,
419 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
420 };
421 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
422 }
423
424 Ok(ResponseStream::new(
425 Box::new(response.bytes_stream()),
426 session,
427 data_extractor,
428 ))
429 }
430 #[cfg(feature = "reqwest")]
449 pub async fn ask_as_stream(
450 &self,
451 session: Session,
452 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
453 self.ask_as_stream_with_extractor(
454 session,
455 (|_, gemini_response| gemini_response)
456 as fn(&Session, GeminiResponse) -> GeminiResponse,
457 )
458 .await
459 }
460}