gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::caching::{CachedContent, CachedContentList, CachedContentUpdate};
3use super::types::request::*;
4use super::types::response::*;
5use super::types::sessions::Session;
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Base URL of the Gemini REST API's model collection (API version `v1beta`).
///
/// Generation endpoints are formed by appending `/{model}:{method}` to it.
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta",
    "/models"
);
11
12#[derive(Clone, Default, Debug)]
14pub struct Gemini {
15 client: Client,
16 api_key: String,
17 model: String,
18 sys_prompt: Option<SystemInstruction>,
19 generation_config: Option<Value>,
20 safety_settings: Option<Vec<SafetySetting>>,
21 tools: Option<Vec<Tool>>,
22 tool_config: Option<ToolConfig>,
23 cached_content: Option<String>,
24}
25
26impl Gemini {
27 pub fn new(
34 api_key: impl Into<String>,
35 model: impl Into<String>,
36 sys_prompt: Option<SystemInstruction>,
37 ) -> Self {
38 Self {
39 client: Client::builder()
40 .timeout(Duration::from_secs(60))
41 .build()
42 .unwrap(),
43 api_key: api_key.into(),
44 model: model.into(),
45 sys_prompt,
46 generation_config: None,
47 safety_settings: None,
48 tools: None,
49 tool_config: None,
50 cached_content: None,
51 }
52 }
53 #[deprecated]
61 pub fn new_with_timeout(
62 api_key: impl Into<String>,
63 model: impl Into<String>,
64 sys_prompt: Option<SystemInstruction>,
65 api_timeout: Duration,
66 ) -> Self {
67 Self {
68 client: Client::builder().timeout(api_timeout).build().unwrap(),
69 api_key: api_key.into(),
70 model: model.into(),
71 sys_prompt,
72 generation_config: None,
73 safety_settings: None,
74 tools: None,
75 tool_config: None,
76 cached_content: None,
77 }
78 }
79 pub fn new_with_client(
87 api_key: impl Into<String>,
88 model: impl Into<String>,
89 sys_prompt: Option<SystemInstruction>,
90 client: Client,
91 ) -> Self {
92 Self {
93 client,
94 api_key: api_key.into(),
95 model: model.into(),
96 sys_prompt,
97 generation_config: None,
98 safety_settings: None,
99 tools: None,
100 tool_config: None,
101 cached_content: None,
102 }
103 }
104 pub fn set_generation_config(&mut self) -> &mut Value {
109 if let None = self.generation_config {
110 self.generation_config = Some(json!({}));
111 }
112 self.generation_config.as_mut().unwrap()
113 }
114 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
115 self.tool_config = Some(config);
116 self
117 }
118 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
119 if let Value::Object(map) = self.set_generation_config() {
120 if let Ok(thinking_value) = serde_json::to_value(config) {
121 map.insert("thinking_config".to_string(), thinking_value);
122 }
123 }
124 self
125 }
126 pub fn set_model(mut self, model: impl Into<String>) -> Self {
127 self.model = model.into();
128 self
129 }
130 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
133 self.sys_prompt = sys_prompt;
134 self
135 }
136 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
137 self.safety_settings = settings;
138 self
139 }
140 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
141 self.api_key = api_key.into();
142 self
143 }
144 pub fn set_json_mode(mut self, schema: Value) -> Self {
152 let config = self.set_generation_config();
153 config["response_mime_type"] = "application/json".into();
154 config["response_schema"] = schema.into();
155 self
156 }
157 pub fn remove_json_mode(mut self) -> Self {
158 if let Some(ref mut generation_config) = self.generation_config {
159 generation_config["response_schema"] = None::<Value>.into();
160 generation_config["response_mime_type"] = None::<Value>.into();
161 }
162 self
163 }
164 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
166 self.tools = Some(tools);
167 self
168 }
169 pub fn remove_tools(mut self) -> Self {
171 self.tools = None;
172 self
173 }
174 pub fn set_cached_content(mut self, name: impl Into<String>) -> Self {
175 self.cached_content = Some(name.into());
176 self
177 }
178 pub fn remove_cached_content(mut self) -> Self {
179 self.cached_content = None;
180 self
181 }
182
183 pub async fn create_cache(
186 &self,
187 cached_content: &CachedContent,
188 ) -> Result<CachedContent, GeminiResponseError> {
189 let req_url = format!(
190 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
191 self.api_key
192 );
193
194 let response = self
195 .client
196 .post(req_url)
197 .json(cached_content)
198 .send()
199 .await
200 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
201
202 if !response.status().is_success() {
203 let error = response
204 .json()
205 .await
206 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
207 return Err(GeminiResponseError::StatusNotOk(error));
208 }
209
210 let cached_content: CachedContent = response
211 .json()
212 .await
213 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
214 Ok(cached_content)
215 }
216
217 pub async fn list_caches(&self) -> Result<CachedContentList, GeminiResponseError> {
218 let req_url = format!(
219 "https://generativelanguage.googleapis.com/v1beta/cachedContents?key={}",
220 self.api_key
221 );
222
223 let response = self
224 .client
225 .get(req_url)
226 .send()
227 .await
228 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
229
230 if !response.status().is_success() {
231 let error = response
232 .json()
233 .await
234 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
235 return Err(GeminiResponseError::StatusNotOk(error));
236 }
237
238 let list: CachedContentList = response
239 .json()
240 .await
241 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
242 Ok(list)
243 }
244
245 pub async fn get_cache(&self, name: &str) -> Result<CachedContent, GeminiResponseError> {
246 let req_url = format!(
247 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
248 name, self.api_key
249 );
250
251 let response = self
252 .client
253 .get(req_url)
254 .send()
255 .await
256 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
257
258 if !response.status().is_success() {
259 let error = response
260 .json()
261 .await
262 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
263 return Err(GeminiResponseError::StatusNotOk(error));
264 }
265
266 let cached_content: CachedContent = response
267 .json()
268 .await
269 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
270 Ok(cached_content)
271 }
272
273 pub async fn update_cache(
274 &self,
275 name: &str,
276 update: &CachedContentUpdate,
277 ) -> Result<CachedContent, GeminiResponseError> {
278 let req_url = format!(
279 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
280 name, self.api_key
281 );
282
283 let response = self
284 .client
285 .patch(req_url)
286 .json(update)
287 .send()
288 .await
289 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
290
291 if !response.status().is_success() {
292 let error = response
293 .json()
294 .await
295 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
296 return Err(GeminiResponseError::StatusNotOk(error));
297 }
298
299 let cached_content: CachedContent = response
300 .json()
301 .await
302 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
303 Ok(cached_content)
304 }
305
306 pub async fn delete_cache(&self, name: &str) -> Result<(), GeminiResponseError> {
307 let req_url = format!(
308 "https://generativelanguage.googleapis.com/v1beta/{}?key={}",
309 name, self.api_key
310 );
311
312 let response = self
313 .client
314 .delete(req_url)
315 .send()
316 .await
317 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
318
319 if !response.status().is_success() {
320 let error = response
321 .json()
322 .await
323 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
324 return Err(GeminiResponseError::StatusNotOk(error));
325 }
326
327 Ok(())
328 }
329
330 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
337 if session
338 .get_last_chat()
339 .is_some_and(|chat| *chat.role() == Role::Model)
340 {
341 return Err(GeminiResponseError::NothingToRespond);
342 }
343 let req_url = format!(
344 "{BASE_URL}/{}:generateContent?key={}",
345 self.model, self.api_key
346 );
347
348 let response = self
349 .client
350 .post(req_url)
351 .json(&GeminiRequestBody::new(
352 self.sys_prompt.as_ref(),
353 self.tools.as_deref(),
354 &session.get_history().as_slice(),
355 self.generation_config.as_ref(),
356 self.safety_settings.as_deref(),
357 self.tool_config.as_ref(),
358 self.cached_content.clone(),
359 ))
360 .send()
361 .await
362 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
363
364 if !response.status().is_success() {
365 let error = response
366 .json()
367 .await
368 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
369 return Err(GeminiResponseError::StatusNotOk(error));
370 }
371
372 let reply = GeminiResponse::new(response)
373 .await
374 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
375 session.update(&reply);
376 Ok(reply)
377 }
378 pub async fn ask_as_stream_with_extractor<F, StreamType>(
392 &self,
393 session: Session,
394 data_extractor: F,
395 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
396 where
397 F: FnMut(&Session, GeminiResponse) -> StreamType,
398 {
399 if session
400 .get_last_chat()
401 .is_some_and(|chat| *chat.role() == Role::Model)
402 {
403 return Err((session, GeminiResponseError::NothingToRespond));
404 }
405 let req_url = format!(
406 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
407 self.model, self.api_key
408 );
409
410 let request = self
411 .client
412 .post(req_url)
413 .json(&GeminiRequestBody::new(
414 self.sys_prompt.as_ref(),
415 self.tools.as_deref(),
416 session.get_history().as_slice(),
417 self.generation_config.as_ref(),
418 self.safety_settings.as_deref(),
419 self.tool_config.as_ref(),
420 self.cached_content.clone(),
421 ))
422 .send()
423 .await;
424 let response = match request {
425 Ok(response) => response,
426 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
427 };
428
429 if !response.status().is_success() {
430 let error = match response.json().await {
431 Ok(response) => response,
432 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
433 };
434 return Err((session, GeminiResponseError::StatusNotOk(error)));
435 }
436
437 Ok(ResponseStream::new(
438 Box::new(response.bytes_stream()),
439 session,
440 data_extractor,
441 ))
442 }
443 pub async fn ask_as_stream(
462 &self,
463 session: Session,
464 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
465 self.ask_as_stream_with_extractor(
466 session,
467 (|_, gemini_response| gemini_response)
468 as fn(&Session, GeminiResponse) -> GeminiResponse,
469 )
470 .await
471 }
472}