gemini_rust/
client.rs

1use crate::{
2    models::{
3        Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
4        GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
5    },
6    tools::{FunctionDeclaration, Tool},
7    Error, Result,
8};
9use futures::stream::Stream;
10use futures_util::StreamExt;
11use reqwest::Client;
12use std::pin::Pin;
13use std::sync::Arc;
14use url::Url;
15
/// Root of the Gemini REST API (`v1beta`); model resource paths are appended to it.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model resource name used when a client is created without an explicit model.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
18
19/// Builder for content generation requests
20pub struct ContentBuilder {
21    client: Arc<GeminiClient>,
22    pub contents: Vec<Content>,
23    generation_config: Option<GenerationConfig>,
24    tools: Option<Vec<Tool>>,
25    tool_config: Option<ToolConfig>,
26    system_instruction: Option<Content>,
27}
28
29impl ContentBuilder {
30    /// Create a new content builder
31    fn new(client: Arc<GeminiClient>) -> Self {
32        Self {
33            client,
34            contents: Vec::new(),
35            generation_config: None,
36            tools: None,
37            tool_config: None,
38            system_instruction: None,
39        }
40    }
41
42    /// Add a system prompt to the request
43    pub fn with_system_prompt(self, text: impl Into<String>) -> Self {
44        // Create a Content with text parts specifically for system_instruction field
45        self.with_system_instruction(text)
46    }
47
48    /// Set the system instruction directly (matching the API format in the curl example)
49    pub fn with_system_instruction(mut self, text: impl Into<String>) -> Self {
50        // Create a Content with text parts specifically for system_instruction field
51        let content = Content::text(text);
52        self.system_instruction = Some(content);
53        self
54    }
55
56    /// Add a user message to the request
57    pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
58        let content = Content::text(text).with_role(Role::User);
59        self.contents.push(content);
60        self
61    }
62
63    /// Add a model message to the request
64    pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
65        let content = Content::text(text).with_role(Role::Model);
66        self.contents.push(content);
67        self
68    }
69
70    /// Add a function response to the request using a JSON value
71    pub fn with_function_response(
72        mut self,
73        name: impl Into<String>,
74        response: serde_json::Value,
75    ) -> Self {
76        let content = Content::function_response_json(name, response).with_role(Role::Function);
77        self.contents.push(content);
78        self
79    }
80
81    /// Add a function response to the request using a JSON string
82    pub fn with_function_response_str(
83        mut self,
84        name: impl Into<String>,
85        response: impl Into<String>,
86    ) -> std::result::Result<Self, serde_json::Error> {
87        let response_str = response.into();
88        let json = serde_json::from_str(&response_str)?;
89        let content = Content::function_response_json(name, json).with_role(Role::Function);
90        self.contents.push(content);
91        Ok(self)
92    }
93
94    /// Add a message to the request
95    pub fn with_message(mut self, message: Message) -> Self {
96        let content = message.content.clone();
97        match &content.role {
98            Some(role) => {
99                let role_clone = role.clone();
100                self.contents.push(content.with_role(role_clone));
101            }
102            None => {
103                self.contents.push(content.with_role(message.role));
104            }
105        }
106        self
107    }
108
109    /// Add multiple messages to the request
110    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
111        for message in messages {
112            self = self.with_message(message);
113        }
114        self
115    }
116
117    /// Set the generation config for the request
118    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
119        self.generation_config = Some(config);
120        self
121    }
122
123    /// Set the temperature for the request
124    pub fn with_temperature(mut self, temperature: f32) -> Self {
125        if self.generation_config.is_none() {
126            self.generation_config = Some(GenerationConfig::default());
127        }
128        if let Some(config) = &mut self.generation_config {
129            config.temperature = Some(temperature);
130        }
131        self
132    }
133
134    /// Set the top-p value for the request
135    pub fn with_top_p(mut self, top_p: f32) -> Self {
136        if self.generation_config.is_none() {
137            self.generation_config = Some(GenerationConfig::default());
138        }
139        if let Some(config) = &mut self.generation_config {
140            config.top_p = Some(top_p);
141        }
142        self
143    }
144
145    /// Set the top-k value for the request
146    pub fn with_top_k(mut self, top_k: i32) -> Self {
147        if self.generation_config.is_none() {
148            self.generation_config = Some(GenerationConfig::default());
149        }
150        if let Some(config) = &mut self.generation_config {
151            config.top_k = Some(top_k);
152        }
153        self
154    }
155
156    /// Set the maximum output tokens for the request
157    pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
158        if self.generation_config.is_none() {
159            self.generation_config = Some(GenerationConfig::default());
160        }
161        if let Some(config) = &mut self.generation_config {
162            config.max_output_tokens = Some(max_output_tokens);
163        }
164        self
165    }
166
167    /// Set the candidate count for the request
168    pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
169        if self.generation_config.is_none() {
170            self.generation_config = Some(GenerationConfig::default());
171        }
172        if let Some(config) = &mut self.generation_config {
173            config.candidate_count = Some(candidate_count);
174        }
175        self
176    }
177
178    /// Set the stop sequences for the request
179    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
180        if self.generation_config.is_none() {
181            self.generation_config = Some(GenerationConfig::default());
182        }
183        if let Some(config) = &mut self.generation_config {
184            config.stop_sequences = Some(stop_sequences);
185        }
186        self
187    }
188
189    /// Set the response mime type for the request
190    pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
191        if self.generation_config.is_none() {
192            self.generation_config = Some(GenerationConfig::default());
193        }
194        if let Some(config) = &mut self.generation_config {
195            config.response_mime_type = Some(mime_type.into());
196        }
197        self
198    }
199
200    /// Set the response schema for structured output
201    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
202        if self.generation_config.is_none() {
203            self.generation_config = Some(GenerationConfig::default());
204        }
205        if let Some(config) = &mut self.generation_config {
206            config.response_schema = Some(schema);
207        }
208        self
209    }
210
211    /// Add a tool to the request
212    pub fn with_tool(mut self, tool: Tool) -> Self {
213        if self.tools.is_none() {
214            self.tools = Some(Vec::new());
215        }
216        if let Some(tools) = &mut self.tools {
217            tools.push(tool);
218        }
219        self
220    }
221
222    /// Add a function declaration as a tool
223    pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
224        let tool = Tool::new(function);
225        self = self.with_tool(tool);
226        self
227    }
228
229    /// Set the function calling mode for the request
230    pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
231        if self.tool_config.is_none() {
232            self.tool_config = Some(ToolConfig {
233                function_calling_config: Some(FunctionCallingConfig { mode }),
234            });
235        } else if let Some(tool_config) = &mut self.tool_config {
236            tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
237        }
238        self
239    }
240
241    /// Execute the request
242    pub async fn execute(self) -> Result<GenerationResponse> {
243        let request = GenerateContentRequest {
244            contents: self.contents,
245            generation_config: self.generation_config,
246            safety_settings: None,
247            tools: self.tools,
248            tool_config: self.tool_config,
249            system_instruction: self.system_instruction,
250        };
251
252        self.client.generate_content_raw(request).await
253    }
254
255    /// Execute the request with streaming
256    pub async fn execute_stream(
257        self,
258    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
259        let request = GenerateContentRequest {
260            contents: self.contents,
261            generation_config: self.generation_config,
262            safety_settings: None,
263            tools: self.tools,
264            tool_config: self.tool_config,
265            system_instruction: self.system_instruction,
266        };
267
268        self.client.generate_content_stream(request).await
269    }
270}
271
272/// Internal client for making requests to the Gemini API
273struct GeminiClient {
274    http_client: Client,
275    api_key: String,
276    model: String,
277}
278
279impl GeminiClient {
280    /// Create a new client
281    fn new(api_key: impl Into<String>, model: String) -> Self {
282        Self {
283            http_client: Client::new(),
284            api_key: api_key.into(),
285            model,
286        }
287    }
288
289    /// Generate content
290    async fn generate_content_raw(
291        &self,
292        request: GenerateContentRequest,
293    ) -> Result<GenerationResponse> {
294        let url = self.build_url("generateContent")?;
295
296        let response = self.http_client.post(url).json(&request).send().await?;
297
298        let status = response.status();
299        if !status.is_success() {
300            let error_text = response.text().await?;
301            return Err(Error::ApiError {
302                status_code: status.as_u16(),
303                message: error_text,
304            });
305        }
306
307        let response = response.json().await?;
308        Ok(response)
309    }
310
311    /// Generate content with streaming
312    async fn generate_content_stream(
313        &self,
314        request: GenerateContentRequest,
315    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
316        let url = self.build_url("streamGenerateContent")?;
317
318        let response = self.http_client.post(url).json(&request).send().await?;
319
320        let status = response.status();
321        if !status.is_success() {
322            let error_text = response.text().await?;
323            return Err(Error::ApiError {
324                status_code: status.as_u16(),
325                message: error_text,
326            });
327        }
328
329        let stream = response
330            .bytes_stream()
331            .map(|result| {
332                match result {
333                    Ok(bytes) => {
334                        let text = String::from_utf8_lossy(&bytes);
335                        // The stream returns each chunk as a separate JSON object
336                        // Each line that starts with "data: " contains a JSON object
337                        let mut responses = Vec::new();
338                        for line in text.lines() {
339                            if let Some(json_str) = line.strip_prefix("data: ") {
340                                if json_str == "[DONE]" {
341                                    continue;
342                                }
343                                match serde_json::from_str::<GenerationResponse>(json_str) {
344                                    Ok(response) => responses.push(Ok(response)),
345                                    Err(e) => responses.push(Err(Error::JsonError(e))),
346                                }
347                            }
348                        }
349                        futures::stream::iter(responses)
350                    }
351                    Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
352                }
353            })
354            .flatten();
355
356        Ok(Box::pin(stream))
357    }
358
359    /// Build a URL for the API
360    fn build_url(&self, endpoint: &str) -> Result<Url> {
361        // All Gemini API endpoints now use the format with colon:
362        // "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=$API_KEY"
363        let url_str = format!(
364            "{}{}:{}?key={}",
365            BASE_URL, self.model, endpoint, self.api_key
366        );
367        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
368    }
369}
370
371/// Client for the Gemini API
372#[derive(Clone)]
373pub struct Gemini {
374    client: Arc<GeminiClient>,
375}
376
377impl Gemini {
378    /// Create a new client with the specified API key
379    pub fn new(api_key: impl Into<String>) -> Self {
380        Self::with_model(api_key, DEFAULT_MODEL.to_string())
381    }
382
383    /// Create a new client for the Gemini Pro model
384    pub fn pro(api_key: impl Into<String>) -> Self {
385        Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
386    }
387
388    /// Create a new client with the specified API key and model
389    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
390        let client = GeminiClient::new(api_key, model);
391        Self {
392            client: Arc::new(client),
393        }
394    }
395
396    /// Start building a content generation request
397    pub fn generate_content(&self) -> ContentBuilder {
398        ContentBuilder::new(self.client.clone())
399    }
400}