// gemini_rust/client.rs

use std::pin::Pin;
use std::sync::Arc;

use futures::stream::Stream;
use futures_util::StreamExt;
use reqwest::{Client, Response};
use url::Url;

use crate::{
    models::{
        Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
        GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
    },
    tools::{FunctionDeclaration, Tool},
    Error, Result,
};

/// Root of the Gemini REST API (`v1beta`). Model paths are appended to it verbatim,
/// so the trailing slash is load-bearing.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model resource name used when the caller does not pick one explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";

19/// Builder for content generation requests
20pub struct ContentBuilder {
21    client: Arc<GeminiClient>,
22    pub contents: Vec<Content>,
23    generation_config: Option<GenerationConfig>,
24    tools: Option<Vec<Tool>>,
25    tool_config: Option<ToolConfig>,
26    system_instruction: Option<Content>,
27}
28
29impl ContentBuilder {
30    /// Create a new content builder
31    fn new(client: Arc<GeminiClient>) -> Self {
32        Self {
33            client,
34            contents: Vec::new(),
35            generation_config: None,
36            tools: None,
37            tool_config: None,
38            system_instruction: None,
39        }
40    }
41
42    /// Add a system prompt to the request
43    pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
44        // Create a Content with text parts specifically for system_instruction field
45        self.with_system_instruction(text)
46    }
47
48    /// Set the system instruction directly (matching the API format in the curl example)
49    pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
50        // Create a Content with text parts specifically for system_instruction field
51        let content = Content::text(text);
52        self.system_instruction = Some(content);
53        self
54    }
55
56    /// Add a user message to the request
57    pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
58        let message = Message::user(text);
59        let content = message.content;
60        self.contents.push(content);
61        self
62    }
63
64    /// Add a model message to the request
65    pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
66        let message = Message::model(text);
67        let content = message.content;
68        self.contents.push(content);
69        self
70    }
71
72    /// Add a function response to the request using a JSON value
73    pub fn with_function_response(
74        mut self,
75        name: impl Into<String>,
76        response: serde_json::Value,
77    ) -> Self {
78        let content = Content::function_response_json(name, response).with_role(Role::User);
79        self.contents.push(content);
80        self
81    }
82
83    /// Add a function response to the request using a JSON string
84    pub fn with_function_response_str(
85        mut self,
86        name: impl Into<String>,
87        response: impl Into<String>,
88    ) -> std::result::Result<Self, serde_json::Error> {
89        let response_str = response.into();
90        let json = serde_json::from_str(&response_str)?;
91        let content = Content::function_response_json(name, json).with_role(Role::User);
92        self.contents.push(content);
93        Ok(self)
94    }
95
96    /// Add a message to the request
97    pub fn with_message(mut self, message: Message) -> Self {
98        let content = message.content.clone();
99        match &content.role {
100            Some(role) => {
101                let role_clone = role.clone();
102                self.contents.push(content.with_role(role_clone));
103            }
104            None => {
105                self.contents.push(content.with_role(message.role));
106            }
107        }
108        self
109    }
110
111    /// Add multiple messages to the request
112    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
113        for message in messages {
114            self = self.with_message(message);
115        }
116        self
117    }
118
119    /// Set the generation config for the request
120    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
121        self.generation_config = Some(config);
122        self
123    }
124
125    /// Set the temperature for the request
126    pub fn with_temperature(mut self, temperature: f32) -> Self {
127        if self.generation_config.is_none() {
128            self.generation_config = Some(GenerationConfig::default());
129        }
130        if let Some(config) = &mut self.generation_config {
131            config.temperature = Some(temperature);
132        }
133        self
134    }
135
136    /// Set the top-p value for the request
137    pub fn with_top_p(mut self, top_p: f32) -> Self {
138        if self.generation_config.is_none() {
139            self.generation_config = Some(GenerationConfig::default());
140        }
141        if let Some(config) = &mut self.generation_config {
142            config.top_p = Some(top_p);
143        }
144        self
145    }
146
147    /// Set the top-k value for the request
148    pub fn with_top_k(mut self, top_k: i32) -> Self {
149        if self.generation_config.is_none() {
150            self.generation_config = Some(GenerationConfig::default());
151        }
152        if let Some(config) = &mut self.generation_config {
153            config.top_k = Some(top_k);
154        }
155        self
156    }
157
158    /// Set the maximum output tokens for the request
159    pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
160        if self.generation_config.is_none() {
161            self.generation_config = Some(GenerationConfig::default());
162        }
163        if let Some(config) = &mut self.generation_config {
164            config.max_output_tokens = Some(max_output_tokens);
165        }
166        self
167    }
168
169    /// Set the candidate count for the request
170    pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
171        if self.generation_config.is_none() {
172            self.generation_config = Some(GenerationConfig::default());
173        }
174        if let Some(config) = &mut self.generation_config {
175            config.candidate_count = Some(candidate_count);
176        }
177        self
178    }
179
180    /// Set the stop sequences for the request
181    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
182        if self.generation_config.is_none() {
183            self.generation_config = Some(GenerationConfig::default());
184        }
185        if let Some(config) = &mut self.generation_config {
186            config.stop_sequences = Some(stop_sequences);
187        }
188        self
189    }
190
191    /// Set the response mime type for the request
192    pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
193        if self.generation_config.is_none() {
194            self.generation_config = Some(GenerationConfig::default());
195        }
196        if let Some(config) = &mut self.generation_config {
197            config.response_mime_type = Some(mime_type.into());
198        }
199        self
200    }
201
202    /// Set the response schema for structured output
203    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
204        if self.generation_config.is_none() {
205            self.generation_config = Some(GenerationConfig::default());
206        }
207        if let Some(config) = &mut self.generation_config {
208            config.response_schema = Some(schema);
209        }
210        self
211    }
212
213    /// Add a tool to the request
214    pub fn with_tool(mut self, tool: Tool) -> Self {
215        if self.tools.is_none() {
216            self.tools = Some(Vec::new());
217        }
218        if let Some(tools) = &mut self.tools {
219            tools.push(tool);
220        }
221        self
222    }
223
224    /// Add a function declaration as a tool
225    pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
226        let tool = Tool::new(function);
227        self = self.with_tool(tool);
228        self
229    }
230
231    /// Set the function calling mode for the request
232    pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
233        if self.tool_config.is_none() {
234            self.tool_config = Some(ToolConfig {
235                function_calling_config: Some(FunctionCallingConfig { mode }),
236            });
237        } else if let Some(tool_config) = &mut self.tool_config {
238            tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
239        }
240        self
241    }
242
243    /// Execute the request
244    pub async fn execute(self) -> Result<GenerationResponse> {
245        let request = GenerateContentRequest {
246            contents: self.contents,
247            generation_config: self.generation_config,
248            safety_settings: None,
249            tools: self.tools,
250            tool_config: self.tool_config,
251            system_instruction: self.system_instruction,
252        };
253
254        self.client.generate_content_raw(request).await
255    }
256
257    /// Execute the request with streaming
258    pub async fn execute_stream(
259        self,
260    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
261        let request = GenerateContentRequest {
262            contents: self.contents,
263            generation_config: self.generation_config,
264            safety_settings: None,
265            tools: self.tools,
266            tool_config: self.tool_config,
267            system_instruction: self.system_instruction,
268        };
269
270        self.client.generate_content_stream(request).await
271    }
272}
273
274/// Internal client for making requests to the Gemini API
275struct GeminiClient {
276    http_client: Client,
277    api_key: String,
278    model: String,
279}
280
281impl GeminiClient {
282    /// Create a new client
283    fn new(api_key: impl Into<String>, model: String) -> Self {
284        Self {
285            http_client: Client::new(),
286            api_key: api_key.into(),
287            model,
288        }
289    }
290
291    /// Generate content
292    async fn generate_content_raw(
293        &self,
294        request: GenerateContentRequest,
295    ) -> Result<GenerationResponse> {
296        let url = self.build_url("generateContent")?;
297
298        let response = self.http_client.post(url).json(&request).send().await?;
299
300        let status = response.status();
301        if !status.is_success() {
302            let error_text = response.text().await?;
303            return Err(Error::ApiError {
304                status_code: status.as_u16(),
305                message: error_text,
306            });
307        }
308
309        let response = response.json().await?;
310        Ok(response)
311    }
312
313    /// Generate content with streaming
314    async fn generate_content_stream(
315        &self,
316        request: GenerateContentRequest,
317    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
318        let url = self.build_url("streamGenerateContent")?;
319
320        let response = self.http_client.post(url).json(&request).send().await?;
321
322        let status = response.status();
323        if !status.is_success() {
324            let error_text = response.text().await?;
325            return Err(Error::ApiError {
326                status_code: status.as_u16(),
327                message: error_text,
328            });
329        }
330
331        let stream = response
332            .bytes_stream()
333            .map(|result| {
334                match result {
335                    Ok(bytes) => {
336                        let text = String::from_utf8_lossy(&bytes);
337                        // The stream returns each chunk as a separate JSON object
338                        // Each line that starts with "data: " contains a JSON object
339                        let mut responses = Vec::new();
340                        for line in text.lines() {
341                            if let Some(json_str) = line.strip_prefix("data: ") {
342                                if json_str == "[DONE]" {
343                                    continue;
344                                }
345                                match serde_json::from_str::<GenerationResponse>(json_str) {
346                                    Ok(response) => responses.push(Ok(response)),
347                                    Err(e) => responses.push(Err(Error::JsonError(e))),
348                                }
349                            }
350                        }
351                        futures::stream::iter(responses)
352                    }
353                    Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
354                }
355            })
356            .flatten();
357
358        Ok(Box::pin(stream))
359    }
360
361    /// Build a URL for the API
362    fn build_url(&self, endpoint: &str) -> Result<Url> {
363        // All Gemini API endpoints now use the format with colon:
364        // "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=$API_KEY"
365        let url_str = format!(
366            "{}{}:{}?key={}",
367            BASE_URL, self.model, endpoint, self.api_key
368        );
369        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
370    }
371}
372
373/// Client for the Gemini API
374#[derive(Clone)]
375pub struct Gemini {
376    client: Arc<GeminiClient>,
377}
378
379impl Gemini {
380    /// Create a new client with the specified API key
381    pub fn new(api_key: impl Into<String>) -> Self {
382        Self::with_model(api_key, DEFAULT_MODEL.to_string())
383    }
384
385    /// Create a new client for the Gemini Pro model
386    pub fn pro(api_key: impl Into<String>) -> Self {
387        Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
388    }
389
390    /// Create a new client with the specified API key and model
391    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
392        let client = GeminiClient::new(api_key, model);
393        Self {
394            client: Arc::new(client),
395        }
396    }
397
398    /// Start building a content generation request
399    pub fn generate_content(&self) -> ContentBuilder {
400        ContentBuilder::new(self.client.clone())
401    }
402}