gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5#[cfg(feature = "reqwest")]
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
/// Root of the Gemini REST API (`v1beta`) model endpoints; request URLs are
/// formed by appending `/<model>:<action>` to it.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
12#[derive(Clone, Default, Debug)]
18pub struct Gemini {
19 #[cfg(feature = "reqwest")]
20 client: Client,
21 api_key: String,
22 model: String,
23 sys_prompt: Option<SystemInstruction>,
24 generation_config: Option<Value>,
25 safety_settings: Option<Vec<SafetySetting>>,
26 tools: Option<Vec<Tool>>,
27 tool_config: Option<ToolConfig>,
28}
29
30impl Gemini {
31 #[cfg(feature = "reqwest")]
38 pub fn new(
39 api_key: impl Into<String>,
40 model: impl Into<String>,
41 sys_prompt: Option<SystemInstruction>,
42 ) -> Self {
43 Self {
44 client: Client::builder()
45 .timeout(Duration::from_secs(60))
46 .build()
47 .unwrap(),
48 api_key: api_key.into(),
49 model: model.into(),
50 sys_prompt,
51 generation_config: None,
52 safety_settings: None,
53 tools: None,
54 tool_config: None,
55 }
56 }
57 #[cfg(feature = "reqwest")]
65 pub fn new_with_timeout(
66 api_key: impl Into<String>,
67 model: impl Into<String>,
68 sys_prompt: Option<SystemInstruction>,
69 api_timeout: Duration,
70 ) -> Self {
71 Self {
72 client: Client::builder().timeout(api_timeout).build().unwrap(),
73 api_key: api_key.into(),
74 model: model.into(),
75 sys_prompt,
76 generation_config: None,
77 safety_settings: None,
78 tools: None,
79 tool_config: None,
80 }
81 }
82 pub fn set_generation_config(&mut self) -> &mut Value {
87 if let None = self.generation_config {
88 self.generation_config = Some(json!({}));
89 }
90 self.generation_config.as_mut().unwrap()
91 }
92 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
93 self.tool_config = Some(config);
94 self
95 }
96 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
97 if let Value::Object(map) = self.set_generation_config() {
98 if let Ok(thinking_value) = serde_json::to_value(config) {
99 map.insert("thinking_config".to_string(), thinking_value);
100 }
101 }
102 self
103 }
104 pub fn set_model(mut self, model: impl Into<String>) -> Self {
105 self.model = model.into();
106 self
107 }
108 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
109 self.sys_prompt = sys_prompt;
110 self
111 }
112 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
113 self.safety_settings = settings;
114 self
115 }
116 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
117 self.api_key = api_key.into();
118 self
119 }
120 pub fn set_json_mode(mut self, schema: Value) -> Self {
128 let config = self.set_generation_config();
129 config["response_mime_type"] = "application/json".into();
130 config["response_schema"] = schema.into();
131 self
132 }
133 pub fn unset_json_mode(mut self) -> Self {
134 if let Some(ref mut generation_config) = self.generation_config {
135 generation_config["response_schema"] = None::<Value>.into();
136 generation_config["response_mime_type"] = None::<Value>.into();
137 }
138 self
139 }
140 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
142 self.tools = Some(tools);
143 self
144 }
145 pub fn unset_tools(mut self) -> Self {
147 self.tools = None;
148 self
149 }
150
151 #[cfg(feature = "reqwest")]
158 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
159 if session
160 .get_last_chat()
161 .is_some_and(|chat| *chat.role() == Role::Model)
162 {
163 return Err(GeminiResponseError::NothingToRespond);
164 }
165 let req_url = format!(
166 "{BASE_URL}/{}:generateContent?key={}",
167 self.model, self.api_key
168 );
169
170 let response = self
171 .client
172 .post(req_url)
173 .json(&GeminiRequestBody::new(
174 self.sys_prompt.as_ref(),
175 self.tools.as_deref(),
176 &session.get_history().as_slice(),
177 self.generation_config.as_ref(),
178 self.safety_settings.as_deref(),
179 self.tool_config.as_ref(),
180 ))
181 .send()
182 .await
183 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
184
185 if !response.status().is_success() {
186 let text = response
187 .text()
188 .await
189 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
190 return Err(GeminiResponseError::StatusNotOk(text));
191 }
192
193 let reply = GeminiResponse::new(response)
194 .await
195 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
196 session.update(&reply);
197 Ok(reply)
198 }
199 #[cfg(feature = "reqwest")]
216 pub async fn ask_as_stream_with_extractor<F, StreamType>(
217 &self,
218 session: Session,
219 data_extractor: F,
220 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
221 where
222 F: FnMut(&Session, GeminiResponse) -> StreamType,
223 {
224 if session
225 .get_last_chat()
226 .is_some_and(|chat| *chat.role() == Role::Model)
227 {
228 return Err((session, GeminiResponseError::NothingToRespond));
229 }
230 let req_url = format!(
231 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
232 self.model, self.api_key
233 );
234
235 let request = self
236 .client
237 .post(req_url)
238 .json(&GeminiRequestBody::new(
239 self.sys_prompt.as_ref(),
240 self.tools.as_deref(),
241 session.get_history().as_slice(),
242 self.generation_config.as_ref(),
243 self.safety_settings.as_deref(),
244 self.tool_config.as_ref(),
245 ))
246 .send()
247 .await;
248 let response = match request {
249 Ok(response) => response,
250 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
251 };
252
253 if !response.status().is_success() {
254 let text = match response.text().await {
255 Ok(response) => response,
256 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
257 };
258 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
259 }
260
261 Ok(ResponseStream::new(
262 Box::new(response.bytes_stream()),
263 session,
264 data_extractor,
265 ))
266 }
267 #[cfg(feature = "reqwest")]
286 pub async fn ask_as_stream(
287 &self,
288 session: Session,
289 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
290 self.ask_as_stream_with_extractor(
291 session,
292 (|_, gemini_response| gemini_response)
293 as fn(&Session, GeminiResponse) -> GeminiResponse,
294 )
295 .await
296 }
297}