// gemini_client_api/gemini/ask.rs
use std::fmt;
use std::time::Duration;

#[cfg(feature = "reqwest")]
use reqwest::Client;
use serde_json::{Value, json};

use super::error::GeminiResponseError;
use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
use super::types::request::*;
use super::types::response::*;
use super::types::sessions::Session;

/// Prefix for model-scoped endpoints of the Gemini REST API (`v1beta`);
/// a model name and an RPC suffix such as `:generateContent` are appended to it.
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta",
    "/models"
);

13#[derive(Clone, Default, Debug)]
19pub struct Gemini {
20 #[cfg(feature = "reqwest")]
21 client: Client,
22 api_key: String,
23 model: String,
24 sys_prompt: Option<SystemInstruction>,
25 generation_config: Option<Value>,
26 safety_settings: Option<Vec<SafetySetting>>,
27 tools: Option<Vec<Tool>>,
28 tool_config: Option<ToolConfig>,
29 cached_content: Option<String>,
30}
31
32impl Gemini {
33 #[cfg(feature = "reqwest")]
40 pub fn new(
41 api_key: impl Into<String>,
42 model: impl Into<String>,
43 sys_prompt: Option<SystemInstruction>,
44 ) -> Self {
45 Self {
46 client: Client::builder()
47 .timeout(Duration::from_secs(60))
48 .build()
49 .unwrap(),
50 api_key: api_key.into(),
51 model: model.into(),
52 sys_prompt,
53 generation_config: None,
54 safety_settings: None,
55 tools: None,
56 tool_config: None,
57 cached_content: None,
58 }
59 }
60 #[deprecated]
68 #[cfg(feature = "reqwest")]
69 pub fn new_with_timeout(
70 api_key: impl Into<String>,
71 model: impl Into<String>,
72 sys_prompt: Option<SystemInstruction>,
73 api_timeout: Duration,
74 ) -> Self {
75 Self {
76 client: Client::builder().timeout(api_timeout).build().unwrap(),
77 api_key: api_key.into(),
78 model: model.into(),
79 sys_prompt,
80 generation_config: None,
81 safety_settings: None,
82 tools: None,
83 tool_config: None,
84 cached_content: None,
85 }
86 }
87 #[cfg(feature = "reqwest")]
95 pub fn new_with_client(
96 api_key: impl Into<String>,
97 model: impl Into<String>,
98 sys_prompt: Option<SystemInstruction>,
99 client: Client,
100 ) -> Self {
101 Self {
102 client,
103 api_key: api_key.into(),
104 model: model.into(),
105 sys_prompt,
106 generation_config: None,
107 safety_settings: None,
108 tools: None,
109 tool_config: None,
110 cached_content: None,
111 }
112 }
113 pub fn set_generation_config(&mut self) -> &mut Value {
118 if let None = self.generation_config {
119 self.generation_config = Some(json!({}));
120 }
121 self.generation_config.as_mut().unwrap()
122 }
123 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
124 self.tool_config = Some(config);
125 self
126 }
127 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
128 if let Value::Object(map) = self.set_generation_config() {
129 if let Ok(thinking_value) = serde_json::to_value(config) {
130 map.insert("thinking_config".to_string(), thinking_value);
131 }
132 }
133 self
134 }
135 pub fn set_model(mut self, model: impl Into<String>) -> Self {
136 self.model = model.into();
137 self
138 }
139 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
142 self.sys_prompt = sys_prompt;
143 self
144 }
145 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
146 self.safety_settings = settings;
147 self
148 }
149 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
150 self.api_key = api_key.into();
151 self
152 }
153 pub fn set_json_mode(mut self, schema: Value) -> Self {
161 let config = self.set_generation_config();
162 config["response_mime_type"] = "application/json".into();
163 config["response_schema"] = schema.into();
164 self
165 }
166 pub fn remove_json_mode(mut self) -> Self {
167 if let Some(ref mut generation_config) = self.generation_config {
168 generation_config["response_schema"] = None::<Value>.into();
169 generation_config["response_mime_type"] = None::<Value>.into();
170 }
171 self
172 }
173 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
175 self.tools = Some(tools);
176 self
177 }
178 pub fn remove_tools(mut self) -> Self {
180 self.tools = None;
181 self
182 }
183 pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
184 self.cached_content = Some(name.into());
185 self
186 }
187 pub fn remove_cached_content(mut self) -> Self {
188 self.cached_content = None;
189 self
190 }
191
192 #[cfg(feature = "reqwest")]
195 pub async fn create_cache(
196 &self,
197 cached_content: &CachedContent,
198 ) -> Result<CachedContent, GeminiResponseError> {
199 let req_url = format!(
200 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
201 self.api_key
202 );
203
204 let response = self
205 .client
206 .post(req_url)
207 .json(cached_content)
208 .send()
209 .await
210 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
211
212 if !response.status().is_success() {
213 let error = response
214 .json()
215 .await
216 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
217 return Err(GeminiResponseError::StatusNotOk(error));
218 }
219
220 let cached_content: CachedContent = response
221 .json()
222 .await
223 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
224 Ok(cached_content)
225 }
226
227 #[cfg(feature = "reqwest")]
228 pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
229 let req_url = format!(
230 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
231 self.api_key
232 );
233
234 let response = self
235 .client
236 .get(req_url)
237 .send()
238 .await
239 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
240
241 if !response.status().is_success() {
242 let error = response
243 .json()
244 .await
245 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
246 return Err(GeminiResponseError::StatusNotOk(error));
247 }
248
249 let list: CachedContentList = response
250 .json()
251 .await
252 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
253 Ok(list)
254 }
255
256 #[cfg(feature = "reqwest")]
257 pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
258 let req_url = format!(
259 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
260 name, self.api_key
261 );
262
263 let response = self
264 .client
265 .get(req_url)
266 .send()
267 .await
268 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
269
270 if !response.status().is_success() {
271 let error = response
272 .json()
273 .await
274 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
275 return Err(GeminiResponseError::StatusNotOk(error));
276 }
277
278 let cached_content: CachedContent = response
279 .json()
280 .await
281 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
282 Ok(cached_content)
283 }
284
285 #[cfg(feature = "reqwest")]
286 pub async fn update_cache(
287 &self,
288 name: &str,
289 update: &CachedContentUpdate,
290 ) -> Result<CachedContent, GeminiResponseError> {
291 let req_url = format!(
292 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
293 name, self.api_key
294 );
295
296 let response = self
297 .client
298 .patch(req_url)
299 .json(update)
300 .send()
301 .await
302 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
303
304 if !response.status().is_success() {
305 let error = response
306 .json()
307 .await
308 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
309 return Err(GeminiResponseError::StatusNotOk(error));
310 }
311
312 let cached_content: CachedContent = response
313 .json()
314 .await
315 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
316 Ok(cached_content)
317 }
318
319 #[cfg(feature = "reqwest")]
320 pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
321 let req_url = format!(
322 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
323 name, self.api_key
324 );
325
326 let response = self
327 .client
328 .delete(req_url)
329 .send()
330 .await
331 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
332
333 if !response.status().is_success() {
334 let error = response
335 .json()
336 .await
337 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
338 return Err(GeminiResponseError::StatusNotOk(error));
339 }
340
341 Ok(())
342 }
343
344 #[cfg(feature = "reqwest")]
351 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
352 if session
353 .get_last_chat()
354 .is_some_and(|chat| *chat.role() == Role::Model)
355 {
356 return Err(GeminiResponseError::NothingToRespond);
357 }
358 let req_url = format!(
359 "{BASE_URL}/{}:generateContent?key={}",
360 self.model, self.api_key
361 );
362
363 let response = self
364 .client
365 .post(req_url)
366 .json(&GeminiRequestBody::new(
367 self.sys_prompt.as_ref(),
368 self.tools.as_deref(),
369 &session.get_history().as_slice(),
370 self.generation_config.as_ref(),
371 self.safety_settings.as_deref(),
372 self.tool_config.as_ref(),
373 self.cached_content.clone(),
374 ))
375 .send()
376 .await
377 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
378
379 if !response.status().is_success() {
380 let error = response
381 .json()
382 .await
383 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
384 return Err(GeminiResponseError::StatusNotOk(error));
385 }
386
387 let reply = GeminiResponse::new(response)
388 .await
389 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
390 session.update(&reply);
391 Ok(reply)
392 }
393 #[cfg(feature = "reqwest")]
407 pub async fn ask_as_stream_with_extractor<F, StreamType>(
408 &self,
409 session: Session,
410 data_extractor: F,
411 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
412 where
413 F: FnMut(&Session, GeminiResponse) -> StreamType,
414 {
415 if session
416 .get_last_chat()
417 .is_some_and(|chat| *chat.role() == Role::Model)
418 {
419 return Err((session, GeminiResponseError::NothingToRespond));
420 }
421 let req_url = format!(
422 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
423 self.model, self.api_key
424 );
425
426 let request = self
427 .client
428 .post(req_url)
429 .json(&GeminiRequestBody::new(
430 self.sys_prompt.as_ref(),
431 self.tools.as_deref(),
432 session.get_history().as_slice(),
433 self.generation_config.as_ref(),
434 self.safety_settings.as_deref(),
435 self.tool_config.as_ref(),
436 self.cached_content.clone(),
437 ))
438 .send()
439 .await;
440 let response = match request {
441 Ok(response) => response,
442 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
443 };
444
445 if !response.status().is_success() {
446 let error = match response.json().await {
447 Ok(response) => response,
448 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
449 };
450 return Err((session, GeminiResponseError::StatusNotOk(error)));
451 }
452
453 Ok(ResponseStream::new(
454 Box::new(response.bytes_stream()),
455 session,
456 data_extractor,
457 ))
458 }
459 #[cfg(feature = "reqwest")]
478 pub async fn ask_as_stream(
479 &self,
480 session: Session,
481 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
482 self.ask_as_stream_with_extractor(
483 session,
484 (|_, gemini_response| gemini_response)
485 as fn(&Session, GeminiResponse) -> GeminiResponse,
486 )
487 .await
488 }
489}