gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;

/// Root of the Gemini REST API (v1beta) model endpoints.
///
/// Request URLs are formed as `{BASE_URL}/{model}:{method}`, where `method` is
/// `generateContent` or `streamGenerateContent`.
const BASE_URL: &str = concat!("https://generativelanguage.googleapis.com", "/v1beta/models");

11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 tools: Option<Vec<Tool>>,
19}
20impl Gemini {
21 pub fn new(
23 api_key: impl Into<String>,
24 model: impl Into<String>,
25 sys_prompt: Option<SystemInstruction>,
26 ) -> Self {
27 Self {
28 client: Client::builder()
29 .timeout(Duration::from_secs(60))
30 .build()
31 .unwrap(),
32 api_key: api_key.into(),
33 model: model.into(),
34 sys_prompt,
35 generation_config: None,
36 tools: None,
37 }
38 }
39 pub fn new_with_timeout(
41 api_key: impl Into<String>,
42 model: impl Into<String>,
43 sys_prompt: Option<SystemInstruction>,
44 api_timeout: Duration,
45 ) -> Self {
46 Self {
47 client: Client::builder().timeout(api_timeout).build().unwrap(),
48 api_key: api_key.into(),
49 model: model.into(),
50 sys_prompt,
51 generation_config: None,
52 tools: None,
53 }
54 }
55 pub fn set_generation_config(&mut self, generation_config: Value) -> &mut Self {
57 self.generation_config = Some(generation_config);
58 self
59 }
60 pub fn set_model(&mut self, model: impl Into<String>) -> &mut Self {
61 self.model = model.into();
62 self
63 }
64 pub fn set_api_key(&mut self, api_key: impl Into<String>) -> &mut Self {
65 self.api_key = api_key.into();
66 self
67 }
68 pub fn set_json_mode(&mut self, schema: Value) -> &mut Self {
70 if let None = self.generation_config {
71 self.generation_config = Some(json!({
72 "response_mime_type": "application/json",
73 "response_schema":schema
74 }))
75 } else if let Some(config) = self.generation_config.as_mut() {
76 config["response_mime_type"] = "application/json".into();
77 config["response_schema"] = schema.into();
78 }
79 self
80 }
81 pub fn unset_json_mode(&mut self) -> &mut Self {
82 if let Some(ref mut generation_config) = self.generation_config {
83 generation_config["response_schema"] = None::<Value>.into();
84 generation_config["response_mime_type"] = None::<Value>.into();
85 }
86 self
87 }
88 pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) -> &mut Self {
91 self.tools = tools;
92 self
93 }
94 pub fn unset_code_execution_mode(&mut self) -> &mut Self {
95 self.tools.take();
96 self
97 }
98
99 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
100 let req_url = format!(
101 "{BASE_URL}/{}:generateContent?key={}",
102 self.model, self.api_key
103 );
104
105 let response = self
106 .client
107 .post(req_url)
108 .json(&GeminiRequestBody::new(
109 self.sys_prompt.as_ref(),
110 self.tools.as_deref(),
111 &session.get_history().as_slice(),
112 self.generation_config.as_ref(),
113 ))
114 .send()
115 .await?;
116
117 if !response.status().is_success() {
118 let text = response.text().await?;
119 return Err(text.into());
120 }
121
122 let reply = GeminiResponse::new(response).await?;
123 session.update(&reply);
124 Ok(reply)
125 }
126 pub async fn ask_as_stream<F, StreamType>(
143 &self,
144 session: Session,
145 data_extractor: F,
146 ) -> Result<GeminiResponseStream<F, StreamType>, GeminiResponseError>
147 where
148 F: FnMut(&Session, GeminiResponse) -> StreamType,
149 {
150 let req_url = format!(
151 "{BASE_URL}/{}:streamGenerateContent?key={}",
152 self.model, self.api_key
153 );
154
155 let response = self
156 .client
157 .post(req_url)
158 .json(&GeminiRequestBody::new(
159 self.sys_prompt.as_ref(),
160 self.tools.as_deref(),
161 session.get_history().as_slice(),
162 self.generation_config.as_ref(),
163 ))
164 .send()
165 .await?;
166
167 if !response.status().is_success() {
168 let text = response.text().await?;
169 return Err(text.into());
170 }
171
172 Ok(GeminiResponseStream::new(
173 Box::new(response.bytes_stream()),
174 session,
175 data_extractor,
176 ))
177 }
178}