gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5#[cfg(feature = "reqwest")]
6use reqwest::Client;
7use serde_json::{Value, json};
8use std::time::Duration;
9
10const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
11
12#[derive(Clone, Default, Debug)]
13pub struct Gemini {
14 #[cfg(feature = "reqwest")]
15 client: Client,
16 api_key: String,
17 model: String,
18 sys_prompt: Option<SystemInstruction>,
19 generation_config: Option<Value>,
20 safety_settings: Option<Vec<SafetySetting>>,
21 tools: Option<Vec<Tool>>,
22 tool_config: Option<ToolConfig>,
23}
24
25impl Gemini {
26 #[cfg(feature = "reqwest")]
31 pub fn new(
32 api_key: impl Into<String>,
33 model: impl Into<String>,
34 sys_prompt: Option<SystemInstruction>,
35 ) -> Self {
36 Self {
37 client: Client::builder()
38 .timeout(Duration::from_secs(60))
39 .build()
40 .unwrap(),
41 api_key: api_key.into(),
42 model: model.into(),
43 sys_prompt,
44 generation_config: None,
45 safety_settings: None,
46 tools: None,
47 tool_config: None,
48 }
49 }
50 #[cfg(feature = "reqwest")]
52 pub fn new_with_timeout(
53 api_key: impl Into<String>,
54 model: impl Into<String>,
55 sys_prompt: Option<SystemInstruction>,
56 api_timeout: Duration,
57 ) -> Self {
58 Self {
59 client: Client::builder().timeout(api_timeout).build().unwrap(),
60 api_key: api_key.into(),
61 model: model.into(),
62 sys_prompt,
63 generation_config: None,
64 safety_settings: None,
65 tools: None,
66 tool_config: None,
67 }
68 }
69 pub fn set_generation_config(&mut self) -> &mut Value {
71 if let None = self.generation_config {
72 self.generation_config = Some(json!({}));
73 }
74 self.generation_config.as_mut().unwrap()
75 }
76 pub fn set_tool_config(mut self, config: ToolConfig) -> Self {
77 self.tool_config = Some(config);
78 self
79 }
80 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
81 if let Value::Object(map) = self.set_generation_config() {
82 if let Ok(thinking_value) = serde_json::to_value(config) {
83 map.insert("thinking_config".to_string(), thinking_value);
84 }
85 }
86 self
87 }
88 pub fn set_model(mut self, model: impl Into<String>) -> Self {
89 self.model = model.into();
90 self
91 }
92 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
93 self.sys_prompt = sys_prompt;
94 self
95 }
96 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
97 self.safety_settings = settings;
98 self
99 }
100 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
101 self.api_key = api_key.into();
102 self
103 }
104 pub fn set_json_mode(mut self, schema: Value) -> Self {
113 let config = self.set_generation_config();
114 config["response_mime_type"] = "application/json".into();
115 config["response_schema"] = schema.into();
116 self
117 }
118 pub fn unset_json_mode(mut self) -> Self {
119 if let Some(ref mut generation_config) = self.generation_config {
120 generation_config["response_schema"] = None::<Value>.into();
121 generation_config["response_mime_type"] = None::<Value>.into();
122 }
123 self
124 }
125 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
126 self.tools = Some(tools);
127 self
128 }
129 pub fn unset_tools(mut self) -> Self {
130 self.tools = None;
131 self
132 }
133
134 #[cfg(feature = "reqwest")]
135 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
136 if !session
137 .get_last_chat()
138 .is_some_and(|chat| *chat.role() != Role::Model)
139 {
140 return Err(GeminiResponseError::NothingToRespond);
141 }
142 let req_url = format!(
143 "{BASE_URL}/{}:generateContent?key={}",
144 self.model, self.api_key
145 );
146
147 let response = self
148 .client
149 .post(req_url)
150 .json(&GeminiRequestBody::new(
151 self.sys_prompt.as_ref(),
152 self.tools.as_deref(),
153 &session.get_history().as_slice(),
154 self.generation_config.as_ref(),
155 self.safety_settings.as_deref(),
156 self.tool_config.as_ref(),
157 ))
158 .send()
159 .await
160 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
161
162 if !response.status().is_success() {
163 let text = response
164 .text()
165 .await
166 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
167 return Err(GeminiResponseError::StatusNotOk(text));
168 }
169
170 let reply = GeminiResponse::new(response)
171 .await
172 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
173 session.update(&reply);
174 Ok(reply)
175 }
176 #[cfg(feature = "reqwest")]
193 pub async fn ask_as_stream_with_extractor<F, StreamType>(
194 &self,
195 session: Session,
196 data_extractor: F,
197 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
198 where
199 F: FnMut(&Session, GeminiResponse) -> StreamType,
200 {
201 if !session
202 .get_last_chat()
203 .is_some_and(|chat| *chat.role() != Role::Model)
204 {
205 return Err((session, GeminiResponseError::NothingToRespond));
206 }
207 let req_url = format!(
208 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
209 self.model, self.api_key
210 );
211
212 let request = self
213 .client
214 .post(req_url)
215 .json(&GeminiRequestBody::new(
216 self.sys_prompt.as_ref(),
217 self.tools.as_deref(),
218 session.get_history().as_slice(),
219 self.generation_config.as_ref(),
220 self.safety_settings.as_deref(),
221 self.tool_config.as_ref(),
222 ))
223 .send()
224 .await;
225 let response = match request {
226 Ok(response) => response,
227 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
228 };
229
230 if !response.status().is_success() {
231 let text = match response.text().await {
232 Ok(response) => response,
233 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
234 };
235 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
236 }
237
238 Ok(ResponseStream::new(
239 Box::new(response.bytes_stream()),
240 session,
241 data_extractor,
242 ))
243 }
244 #[cfg(feature = "reqwest")]
258 pub async fn ask_as_stream(
259 &self,
260 session: Session,
261 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
262 self.ask_as_stream_with_extractor(
263 session,
264 (|_, gemini_response| gemini_response)
265 as fn(&Session, GeminiResponse) -> GeminiResponse,
266 )
267 .await
268 }
269}