gemini_client_api/gemini/
ask.rs1use super::error::GeminiResponseError;
2use super::types::request::*;
3use super::types::response::*;
4use super::types::sessions::Session;
5use reqwest::Client;
6use serde_json::{Value, json};
7use std::time::Duration;
8
/// Root of the Gemini REST API's `v1beta` model endpoints.
///
/// Request URLs are formed as `{BASE_URL}/{model}:{method}`.
const BASE_URL: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta/models"
);
10
11#[derive(Clone, Default, Debug)]
12pub struct Gemini {
13 client: Client,
14 api_key: String,
15 model: String,
16 sys_prompt: Option<SystemInstruction>,
17 generation_config: Option<Value>,
18 safety_settings: Option<Vec<SafetySetting>>,
19 tools: Option<Vec<Tool>>,
20}
21impl Gemini {
22 pub fn new(
27 api_key: impl Into<String>,
28 model: impl Into<String>,
29 sys_prompt: Option<SystemInstruction>,
30 ) -> Self {
31 Self {
32 client: Client::builder()
33 .timeout(Duration::from_secs(60))
34 .build()
35 .unwrap(),
36 api_key: api_key.into(),
37 model: model.into(),
38 sys_prompt,
39 generation_config: None,
40 safety_settings: None,
41 tools: None,
42 }
43 }
44 pub fn new_with_timeout(
46 api_key: impl Into<String>,
47 model: impl Into<String>,
48 sys_prompt: Option<SystemInstruction>,
49 api_timeout: Duration,
50 ) -> Self {
51 Self {
52 client: Client::builder().timeout(api_timeout).build().unwrap(),
53 api_key: api_key.into(),
54 model: model.into(),
55 sys_prompt,
56 generation_config: None,
57 safety_settings: None,
58 tools: None,
59 }
60 }
61 pub fn set_generation_config(&mut self) -> &mut Value {
63 if let None = self.generation_config {
64 self.generation_config = Some(json!({}));
65 }
66 self.generation_config.as_mut().unwrap()
67 }
68 pub fn set_thinking_config(mut self, config: ThinkingConfig) -> Self {
69 if let Value::Object(map) = self.set_generation_config() {
70 if let Ok(thinking_value) = serde_json::to_value(config) {
71 map.insert("thinking_config".to_string(), thinking_value);
72 }
73 }
74 self
75 }
76 pub fn set_model(mut self, model: impl Into<String>) -> Self {
77 self.model = model.into();
78 self
79 }
80 pub fn set_sys_prompt(mut self, sys_prompt: Option<SystemInstruction>) -> Self {
81 self.sys_prompt = sys_prompt;
82 self
83 }
84 pub fn set_safety_settings(mut self, settings: Option<Vec<SafetySetting>>) -> Self {
85 self.safety_settings = settings;
86 self
87 }
88 pub fn set_api_key(mut self, api_key: impl Into<String>) -> Self {
89 self.api_key = api_key.into();
90 self
91 }
92 pub fn set_json_mode(mut self, schema: Value) -> Self {
98 let config = self.set_generation_config();
99 config["response_mime_type"] = "application/json".into();
100 config["response_schema"] = schema.into();
101 self
102 }
103 pub fn unset_json_mode(mut self) -> Self {
104 if let Some(ref mut generation_config) = self.generation_config {
105 generation_config["response_schema"] = None::<Value>.into();
106 generation_config["response_mime_type"] = None::<Value>.into();
107 }
108 self
109 }
110 pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
111 self.tools = Some(tools);
112 self
113 }
114 pub fn unset_tools(mut self) -> Self {
115 self.tools = None;
116 self
117 }
118
119 pub async fn ask(&self, session: &mut Session) -> Result<GeminiResponse, GeminiResponseError> {
120 let req_url = format!(
121 "{BASE_URL}/{}:generateContent?key={}",
122 self.model, self.api_key
123 );
124
125 let response = self
126 .client
127 .post(req_url)
128 .json(&GeminiRequestBody::new(
129 self.sys_prompt.as_ref(),
130 self.tools.as_deref(),
131 &session.get_history().as_slice(),
132 self.generation_config.as_ref(),
133 self.safety_settings.as_deref(),
134 ))
135 .send()
136 .await
137 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
138
139 if !response.status().is_success() {
140 let text = response
141 .text()
142 .await
143 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
144 return Err(GeminiResponseError::StatusNotOk(text));
145 }
146
147 let reply = GeminiResponse::new(response)
148 .await
149 .map_err(|e| GeminiResponseError::ReqwestError(e))?;
150 session.update(&reply);
151 Ok(reply)
152 }
153 pub async fn ask_as_stream_with_extractor<F, StreamType>(
170 &self,
171 session: Session,
172 data_extractor: F,
173 ) -> Result<ResponseStream<F, StreamType>, (Session, GeminiResponseError)>
174 where
175 F: FnMut(&Session, GeminiResponse) -> StreamType,
176 {
177 let req_url = format!(
178 "{BASE_URL}/{}:streamGenerateContent?alt=sse&key={}",
179 self.model, self.api_key
180 );
181
182 let request = self
183 .client
184 .post(req_url)
185 .json(&GeminiRequestBody::new(
186 self.sys_prompt.as_ref(),
187 self.tools.as_deref(),
188 session.get_history().as_slice(),
189 self.generation_config.as_ref(),
190 self.safety_settings.as_deref(),
191 ))
192 .send()
193 .await;
194 let response = match request {
195 Ok(response) => response,
196 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
197 };
198
199 if !response.status().is_success() {
200 let text = match response.text().await {
201 Ok(response) => response,
202 Err(e) => return Err((session, GeminiResponseError::ReqwestError(e))),
203 };
204 return Err((session, GeminiResponseError::StatusNotOk(text.into())));
205 }
206
207 Ok(ResponseStream::new(
208 Box::new(response.bytes_stream()),
209 session,
210 data_extractor,
211 ))
212 }
213 pub async fn ask_as_stream(
227 &self,
228 session: Session,
229 ) -> Result<GeminiResponseStream, (Session, GeminiResponseError)> {
230 self.ask_as_stream_with_extractor(
231 session,
232 (|_, gemini_response| gemini_response)
233 as fn(&Session, GeminiResponse) -> GeminiResponse,
234 )
235 .await
236 }
237}