gemini_rust/
client.rs

1use crate::{
2    models::{
3        Content, FunctionCallingConfig, FunctionCallingMode, GenerateContentRequest,
4        GenerationConfig, GenerationResponse, Message, Role, ToolConfig,
5    },
6    tools::{FunctionDeclaration, Tool},
7    Error, Result,
8};
9use reqwest::Client;
10use std::sync::Arc;
11use url::Url;
12use futures_util::StreamExt;
13use futures::stream::Stream;
14use std::pin::Pin;
15
/// Model resource used when a client is created without naming a model explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
/// Root of the Gemini REST API (v1beta). The trailing `/` matters: model resource names
/// are appended to it verbatim when endpoint URLs are built.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
18
19/// Builder for content generation requests
20pub struct ContentBuilder {
21    client: Arc<GeminiClient>,
22    pub contents: Vec<Content>,
23    generation_config: Option<GenerationConfig>,
24    tools: Option<Vec<Tool>>,
25    tool_config: Option<ToolConfig>,
26}
27
28impl ContentBuilder {
29    /// Create a new content builder
30    fn new(client: Arc<GeminiClient>) -> Self {
31        Self {
32            client,
33            contents: Vec::new(),
34            generation_config: None,
35            tools: None,
36            tool_config: None,
37        }
38    }
39
40    /// Add a system prompt to the request
41    pub fn with_system_prompt(mut self, text: impl Into<String>) -> Self {
42        let content = Content::text(text).with_role(Role::System);
43        self.contents.push(content);
44        self
45    }
46
47    /// Add a user message to the request
48    pub fn with_user_message(mut self, text: impl Into<String>) -> Self {
49        let content = Content::text(text).with_role(Role::User);
50        self.contents.push(content);
51        self
52    }
53
54    /// Add a model message to the request
55    pub fn with_model_message(mut self, text: impl Into<String>) -> Self {
56        let content = Content::text(text).with_role(Role::Model);
57        self.contents.push(content);
58        self
59    }
60
61    /// Add a function response to the request using a JSON value
62    pub fn with_function_response(
63        mut self,
64        name: impl Into<String>,
65        response: serde_json::Value,
66    ) -> Self {
67        let content = Content::function_response_json(name, response)
68            .with_role(Role::Function);
69        self.contents.push(content);
70        self
71    }
72    
73    /// Add a function response to the request using a JSON string
74    pub fn with_function_response_str(
75        mut self,
76        name: impl Into<String>,
77        response: impl Into<String>,
78    ) -> std::result::Result<Self, serde_json::Error> {
79        let response_str = response.into();
80        let json = serde_json::from_str(&response_str)?;
81        let content = Content::function_response_json(name, json)
82            .with_role(Role::Function);
83        self.contents.push(content);
84        Ok(self)
85    }
86
87    /// Add a message to the request
88    pub fn with_message(mut self, message: Message) -> Self {
89        let content = message.content.clone();
90        match &content.role {
91            Some(role) => {
92                let role_clone = role.clone();
93                self.contents.push(content.with_role(role_clone));
94            }
95            None => {
96                self.contents.push(content.with_role(message.role));
97            }
98        }
99        self
100    }
101
102    /// Add multiple messages to the request
103    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
104        for message in messages {
105            self = self.with_message(message);
106        }
107        self
108    }
109
110    /// Set the generation config for the request
111    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
112        self.generation_config = Some(config);
113        self
114    }
115    
116    /// Set the temperature for the request
117    pub fn with_temperature(mut self, temperature: f32) -> Self {
118        if self.generation_config.is_none() {
119            self.generation_config = Some(GenerationConfig::default());
120        }
121        if let Some(config) = &mut self.generation_config {
122            config.temperature = Some(temperature);
123        }
124        self
125    }
126    
127    /// Set the top-p value for the request
128    pub fn with_top_p(mut self, top_p: f32) -> Self {
129        if self.generation_config.is_none() {
130            self.generation_config = Some(GenerationConfig::default());
131        }
132        if let Some(config) = &mut self.generation_config {
133            config.top_p = Some(top_p);
134        }
135        self
136    }
137    
138    /// Set the top-k value for the request
139    pub fn with_top_k(mut self, top_k: i32) -> Self {
140        if self.generation_config.is_none() {
141            self.generation_config = Some(GenerationConfig::default());
142        }
143        if let Some(config) = &mut self.generation_config {
144            config.top_k = Some(top_k);
145        }
146        self
147    }
148    
149    /// Set the maximum output tokens for the request
150    pub fn with_max_output_tokens(mut self, max_output_tokens: i32) -> Self {
151        if self.generation_config.is_none() {
152            self.generation_config = Some(GenerationConfig::default());
153        }
154        if let Some(config) = &mut self.generation_config {
155            config.max_output_tokens = Some(max_output_tokens);
156        }
157        self
158    }
159    
160    /// Set the candidate count for the request
161    pub fn with_candidate_count(mut self, candidate_count: i32) -> Self {
162        if self.generation_config.is_none() {
163            self.generation_config = Some(GenerationConfig::default());
164        }
165        if let Some(config) = &mut self.generation_config {
166            config.candidate_count = Some(candidate_count);
167        }
168        self
169    }
170    
171    /// Set the stop sequences for the request
172    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
173        if self.generation_config.is_none() {
174            self.generation_config = Some(GenerationConfig::default());
175        }
176        if let Some(config) = &mut self.generation_config {
177            config.stop_sequences = Some(stop_sequences);
178        }
179        self
180    }
181    
182    /// Set the response mime type for the request
183    pub fn with_response_mime_type(mut self, mime_type: impl Into<String>) -> Self {
184        if self.generation_config.is_none() {
185            self.generation_config = Some(GenerationConfig::default());
186        }
187        if let Some(config) = &mut self.generation_config {
188            config.response_mime_type = Some(mime_type.into());
189        }
190        self
191    }
192    
193    /// Set the response schema for structured output
194    pub fn with_response_schema(mut self, schema: serde_json::Value) -> Self {
195        if self.generation_config.is_none() {
196            self.generation_config = Some(GenerationConfig::default());
197        }
198        if let Some(config) = &mut self.generation_config {
199            config.response_schema = Some(schema);
200        }
201        self
202    }
203
204    /// Add a tool to the request
205    pub fn with_tool(mut self, tool: Tool) -> Self {
206        if self.tools.is_none() {
207            self.tools = Some(Vec::new());
208        }
209        if let Some(tools) = &mut self.tools {
210            tools.push(tool);
211        }
212        self
213    }
214
215    /// Add a function declaration as a tool
216    pub fn with_function(mut self, function: FunctionDeclaration) -> Self {
217        let tool = Tool::new(function);
218        self = self.with_tool(tool);
219        self
220    }
221
222    /// Set the function calling mode for the request
223    pub fn with_function_calling_mode(mut self, mode: FunctionCallingMode) -> Self {
224        if self.tool_config.is_none() {
225            self.tool_config = Some(ToolConfig {
226                function_calling_config: Some(FunctionCallingConfig { mode }),
227            });
228        } else if let Some(tool_config) = &mut self.tool_config {
229            tool_config.function_calling_config = Some(FunctionCallingConfig { mode });
230        }
231        self
232    }
233
234    /// Execute the request
235    pub async fn execute(self) -> Result<GenerationResponse> {
236        let request = GenerateContentRequest {
237            contents: self.contents,
238            generation_config: self.generation_config,
239            safety_settings: None,
240            tools: self.tools,
241            tool_config: self.tool_config,
242        };
243
244        self.client.generate_content_raw(request).await
245    }
246
247    /// Execute the request with streaming
248    pub async fn execute_stream(
249        self,
250    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
251        let request = GenerateContentRequest {
252            contents: self.contents,
253            generation_config: self.generation_config,
254            safety_settings: None,
255            tools: self.tools,
256            tool_config: self.tool_config,
257        };
258
259        self.client.generate_content_stream(request).await
260    }
261}
262
263/// Internal client for making requests to the Gemini API
264struct GeminiClient {
265    http_client: Client,
266    api_key: String,
267    model: String,
268}
269
270impl GeminiClient {
271    /// Create a new client
272    fn new(api_key: impl Into<String>, model: String) -> Self {
273        Self {
274            http_client: Client::new(),
275            api_key: api_key.into(),
276            model,
277        }
278    }
279
280    /// Generate content
281    async fn generate_content_raw(
282        &self,
283        request: GenerateContentRequest,
284    ) -> Result<GenerationResponse> {
285        let url = self.build_url("generateContent")?;
286
287        let response = self
288            .http_client
289            .post(url)
290            .json(&request)
291            .send()
292            .await?;
293
294        let status = response.status();
295        if !status.is_success() {
296            let error_text = response.text().await?;
297            return Err(Error::ApiError {
298                status_code: status.as_u16(),
299                message: error_text,
300            });
301        }
302
303        let response = response.json().await?;
304        Ok(response)
305    }
306
307    /// Generate content with streaming
308    async fn generate_content_stream(
309        &self,
310        request: GenerateContentRequest,
311    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
312        let url = self.build_url("streamGenerateContent")?;
313
314        let response = self
315            .http_client
316            .post(url)
317            .json(&request)
318            .send()
319            .await?;
320
321        let status = response.status();
322        if !status.is_success() {
323            let error_text = response.text().await?;
324            return Err(Error::ApiError {
325                status_code: status.as_u16(),
326                message: error_text,
327            });
328        }
329
330        let stream = response
331            .bytes_stream()
332            .map(|result| {
333                match result {
334                    Ok(bytes) => {
335                        let text = String::from_utf8_lossy(&bytes);
336                        // The stream returns each chunk as a separate JSON object
337                        // Each line that starts with "data: " contains a JSON object
338                        let mut responses = Vec::new();
339                        for line in text.lines() {
340                            if let Some(json_str) = line.strip_prefix("data: ") {
341                                if json_str == "[DONE]" {
342                                    continue;
343                                }
344                                match serde_json::from_str::<GenerationResponse>(json_str) {
345                                    Ok(response) => responses.push(Ok(response)),
346                                    Err(e) => responses.push(Err(Error::JsonError(e))),
347                                }
348                            }
349                        }
350                        futures::stream::iter(responses)
351                    }
352                    Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
353                }
354            })
355            .flatten();
356
357        Ok(Box::pin(stream))
358    }
359
360    /// Build a URL for the API
361    fn build_url(&self, endpoint: &str) -> Result<Url> {
362        // All Gemini API endpoints now use the format with colon:
363        // "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=$API_KEY"
364        let url_str = format!(
365            "{}{}:{}?key={}",
366            BASE_URL, self.model, endpoint, self.api_key
367        );
368        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
369    }
370}
371
372/// Client for the Gemini API
373#[derive(Clone)]
374pub struct Gemini {
375    client: Arc<GeminiClient>,
376}
377
378impl Gemini {
379    /// Create a new client with the specified API key
380    pub fn new(api_key: impl Into<String>) -> Self {
381        Self::with_model(api_key, DEFAULT_MODEL.to_string())
382    }
383    
384    /// Create a new client for the Gemini Pro model
385    pub fn pro(api_key: impl Into<String>) -> Self {
386        Self::with_model(api_key, "models/gemini-2.0-pro".to_string())
387    }
388
389    /// Create a new client with the specified API key and model
390    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
391        let client = GeminiClient::new(api_key, model);
392        Self {
393            client: Arc::new(client),
394        }
395    }
396
397    /// Start building a content generation request
398    pub fn generate_content(&self) -> ContentBuilder {
399        ContentBuilder::new(self.client.clone())
400    }
401}