use std::fmt;
use std::pin::Pin;
use std::sync::Arc;

use futures::stream::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use url::Url;

use crate::{
    models::{
        Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
        GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
    },
    tools::{FunctionDeclaration, Tool},
    Error, Result,
};
15
/// Root of the Gemini REST API (v1beta).
///
/// The model resource name is concatenated directly after this prefix, which is why it
/// must keep its trailing slash.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";

/// Model resource name used by `Gemini::new`.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
18
19pub struct ContentBuilder {
21 client: Arc<GeminiClient>,
22 pub contents: Vec<Content>,
23 generation_config: Option<GenerationConfig>,
24 tools: Option<Vec<Tool>>,
25 tool_config: Option<ToolConfig>,
26}
27
28impl ContentBuilder {
29 fn new(client: Arc<GeminiClient>) -> Self {
31 Self {
32 client,
33 contents: Vec::new(),
34 generation_config: None,
35 tools: None,
36 tool_config: None,
37 }
38 }
39
40 pub fn with_system_prompt(mut self, text: impl Into<String>) -> Self {
42 let content = Content::text(text).with_role(Role::System);
43 self.contents.push(content);
44 self
45 }
46
47 pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
49 let content = Content::text(text).with_role(Role::User);
50 self.contents.push(content);
51 self
52 }
53
54 pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
56 let content = Content::text(text).with_role(Role::Model);
57 self.contents.push(content);
58 self
59 }
60
61 pub fn with_function_response(
63 mut self,
64 name: impl Into<String>,
65 response: serde_json::Value,
66 ) -> Self {
67 let content = Content::function_response_json(name, response)
68 .with_role(Role::Function);
69 self.contents.push(content);
70 self
71 }
72
73 pub fn with_function_response_str(
75 mut self,
76 name: impl Into<String>,
77 response: impl Into<String>,
78 ) -> std::result::Result<Self, serde_json::Error> {
79 let response_str = response.into();
80 let json = serde_json::from_str(&response_str)?;
81 let content = Content::function_response_json(name, json)
82 .with_role(Role::Function);
83 self.contents.push(content);
84 Ok(self)
85 }
86
87 pub fn with_message(mut self, message: Message) -> Self {
89 let content = message.content.clone();
90 match &content.role {
91 Some(role) => {
92 let role_clone = role.clone();
93 self.contents.push(content.with_role(role_clone));
94 }
95 None => {
96 self.contents.push(content.with_role(message.role));
97 }
98 }
99 self
100 }
101
102 pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
104 for message in messages {
105 self = self.with_message(message);
106 }
107 self
108 }
109
110 pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
112 self.generation_config = Some(config);
113 self
114 }
115
116 pub fn with_temperature(mut self, temperature: f32) -> Self {
118 if self.generation_config.is_none() {
119 self.generation_config = Some(GenerationConfig::default());
120 }
121 if let Some(config) = &mut self.generation_config {
122 config.temperature = Some(temperature);
123 }
124 self
125 }
126
127 pub fn with_top_p(mut self, top_p: f32) -> Self {
129 if self.generation_config.is_none() {
130 self.generation_config = Some(GenerationConfig::default());
131 }
132 if let Some(config) = &mut self.generation_config {
133 config.top_p = Some(top_p);
134 }
135 self
136 }
137
138 pub fn with_top_k(mut self, top_k: i32) -> Self {
140 if self.generation_config.is_none() {
141 self.generation_config = Some(GenerationConfig::default());
142 }
143 if let Some(config) = &mut self.generation_config {
144 config.top_k = Some(top_k);
145 }
146 self
147 }
148
149 pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
151 if self.generation_config.is_none() {
152 self.generation_config = Some(GenerationConfig::default());
153 }
154 if let Some(config) = &mut self.generation_config {
155 config.max_output_tokens = Some(max_output_tokens);
156 }
157 self
158 }
159
160 pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
162 if self.generation_config.is_none() {
163 self.generation_config = Some(GenerationConfig::default());
164 }
165 if let Some(config) = &mut self.generation_config {
166 config.candidate_count = Some(candidate_count);
167 }
168 self
169 }
170
171 pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
173 if self.generation_config.is_none() {
174 self.generation_config = Some(GenerationConfig::default());
175 }
176 if let Some(config) = &mut self.generation_config {
177 config.stop_sequences = Some(stop_sequences);
178 }
179 self
180 }
181
182 pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
184 if self.generation_config.is_none() {
185 self.generation_config = Some(GenerationConfig::default());
186 }
187 if let Some(config) = &mut self.generation_config {
188 config.response_mime_type = Some(mime_type.into());
189 }
190 self
191 }
192
193 pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
195 if self.generation_config.is_none() {
196 self.generation_config = Some(GenerationConfig::default());
197 }
198 if let Some(config) = &mut self.generation_config {
199 config.response_schema = Some(schema);
200 }
201 self
202 }
203
204 pub fn with_tool(mut self, tool: Tool) -> Self {
206 if self.tools.is_none() {
207 self.tools = Some(Vec::new());
208 }
209 if let Some(tools) = &mut self.tools {
210 tools.push(tool);
211 }
212 self
213 }
214
215 pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
217 let tool = Tool::new(function);
218 self = self.with_tool(tool);
219 self
220 }
221
222 pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
224 if self.tool_config.is_none() {
225 self.tool_config = Some(ToolConfig {
226 function_calling_config: Some(FunctionCallingConfig { mode }),
227 });
228 } else if let Some(tool_config) = &mut self.tool_config {
229 tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
230 }
231 self
232 }
233
234 pub async fn execute(self) -> Result<GenerationResponse> {
236 let request = GenerateContentRequest {
237 contents: self.contents,
238 generation_config: self.generation_config,
239 safety_settings: None,
240 tools: self.tools,
241 tool_config: self.tool_config,
242 };
243
244 self.client.generate_content_raw(request).await
245 }
246
247 pub async fn execute_stream(
249 self,
250 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
251 let request = GenerateContentRequest {
252 contents: self.contents,
253 generation_config: self.generation_config,
254 safety_settings: None,
255 tools: self.tools,
256 tool_config: self.tool_config,
257 };
258
259 self.client.generate_content_stream(request).await
260 }
261}
262
263struct GeminiClient {
265 http_client: Client,
266 api_key: String,
267 model: String,
268}
269
270impl GeminiClient {
271 fn new(api_key: impl Into<String>, model: String) -> Self {
273 Self {
274 http_client: Client::new(),
275 api_key: api_key.into(),
276 model,
277 }
278 }
279
280 async fn generate_content_raw(
282 &self,
283 request: GenerateContentRequest,
284 ) -> Result<GenerationResponse> {
285 let url = self.build_url("generateContent")?;
286
287 let response = self
288 .http_client
289 .post(url)
290 .json(&request)
291 .send()
292 .await?;
293
294 let status = response.status();
295 if !status.is_success() {
296 let error_text = response.text().await?;
297 return Err(Error::ApiError {
298 status_code: status.as_u16(),
299 message: error_text,
300 });
301 }
302
303 let response = response.json().await?;
304 Ok(response)
305 }
306
307 async fn generate_content_stream(
309 &self,
310 request: GenerateContentRequest,
311 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
312 let url = self.build_url("streamGenerateContent")?;
313
314 let response = self
315 .http_client
316 .post(url)
317 .json(&request)
318 .send()
319 .await?;
320
321 let status = response.status();
322 if !status.is_success() {
323 let error_text = response.text().await?;
324 return Err(Error::ApiError {
325 status_code: status.as_u16(),
326 message: error_text,
327 });
328 }
329
330 let stream = response
331 .bytes_stream()
332 .map(|result| {
333 match result {
334 Ok(bytes) => {
335 let text = String::from_utf8_lossy(&bytes);
336 let mut responses = Vec::new();
339 for line in text.lines() {
340 if let Some(json_str) = line.strip_prefix("data: ") {
341 if json_str == "[DONE]" {
342 continue;
343 }
344 match serde_json::from_str::<GenerationResponse>(json_str) {
345 Ok(response) => responses.push(Ok(response)),
346 Err(e) => responses.push(Err(Error::JsonError(e))),
347 }
348 }
349 }
350 futures::stream::iter(responses)
351 }
352 Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
353 }
354 })
355 .flatten();
356
357 Ok(Box::pin(stream))
358 }
359
360 fn build_url(&self, endpoint: &str) -> Result<Url> {
362 let url_str = format!(
365 "{}{}:{}?key={}",
366 BASE_URL, self.model, endpoint, self.api_key
367 );
368 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
369 }
370}
371
372#[derive(Clone)]
374pub struct Gemini {
375 client: Arc<GeminiClient>,
376}
377
378impl Gemini {
379 pub fn new(api_key: impl Into<String>) -> Self {
381 Self::with_model(api_key, DEFAULT_MODEL.to_string())
382 }
383
384 pub fn pro(api_key: impl Into<String>) -> Self {
386 Self::with_model(api_key, "models/gemini-2.0-pro".to_string())
387 }
388
389 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
391 let client = GeminiClient::new(api_key, model);
392 Self {
393 client: Arc::new(client),
394 }
395 }
396
397 pub fn generate_content(&self) -> ContentBuilder {
399 ContentBuilder::new(self.client.clone())
400 }
401}