// daimon_provider_gemini — lib.rs

1//! Google Gemini model provider for the [Daimon](https://docs.rs/daimon) agent framework.
2//!
3//! Supports the Generative AI endpoint and Vertex AI, tool use, SSE streaming,
4//! configurable timeouts, retries with exponential backoff, and cached content.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use daimon_provider_gemini::Gemini;
10//! use daimon_core::Model;
11//!
12//! let model = Gemini::new("gemini-2.0-flash");
13//! ```
14
15use std::time::Duration;
16
17use reqwest::Client;
18use serde::{Deserialize, Serialize};
19
20mod embedding;
21
22#[cfg(feature = "pubsub")]
23pub mod pubsub;
24
25pub use embedding::GeminiEmbedding;
26
27#[cfg(feature = "pubsub")]
28pub use pubsub::PubSubBroker;
29
30use daimon_core::{
31    ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
32    StopReason, StreamEvent, ToolCall, ToolSpec, Usage,
33};
34
/// Public Generative Language API endpoint used when no base URL is configured.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Default retry budget for transient (HTTP 429 / 5xx) failures.
const DEFAULT_MAX_RETRIES: u32 = 3;
37
38fn build_client(timeout: Option<Duration>) -> Client {
39    let mut builder = Client::builder();
40    if let Some(t) = timeout {
41        builder = builder.timeout(t);
42    }
43    builder.build().expect("failed to build HTTP client")
44}
45
/// Google Gemini model provider.
///
/// Connects to the Gemini REST API. Supports both the public Generative AI
/// endpoint (default) and Vertex AI via `with_base_url()`. Authentication is
/// via API key (passed as `?key=` query parameter) or bearer token for Vertex AI.
#[derive(Debug)]
pub struct Gemini {
    /// Shared HTTP client; rebuilt by `with_timeout` so the timeout applies.
    client: Client,
    /// API key, or an OAuth2 access token when `use_bearer_token` is set.
    api_key: String,
    /// Model identifier, e.g. `gemini-2.0-flash`.
    model_id: String,
    /// API base URL; defaults to `DEFAULT_BASE_URL`.
    base_url: String,
    /// Optional HTTP timeout applied to every request.
    timeout: Option<Duration>,
    /// Retry budget for transient (429 / 5xx) errors in `generate`.
    max_retries: u32,
    /// When true, send `Authorization: Bearer <key>` instead of `?key=`.
    use_bearer_token: bool,
    /// Optional `cachedContents/<id>` resource attached to each request.
    cached_content: Option<String>,
}
62
63impl Gemini {
64    /// Create a new Gemini client, reading `GOOGLE_API_KEY` from the environment.
65    pub fn new(model_id: impl Into<String>) -> Self {
66        let api_key = std::env::var("GOOGLE_API_KEY").unwrap_or_default();
67        Self::with_api_key(model_id, api_key)
68    }
69
70    /// Create a new Gemini client with an explicit API key.
71    pub fn with_api_key(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
72        Self {
73            client: build_client(None),
74            api_key: api_key.into(),
75            model_id: model_id.into(),
76            base_url: DEFAULT_BASE_URL.to_string(),
77            timeout: None,
78            max_retries: DEFAULT_MAX_RETRIES,
79            use_bearer_token: false,
80            cached_content: None,
81        }
82    }
83
84    /// Set a custom base URL (e.g. for Vertex AI endpoints).
85    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
86        self.base_url = url.into();
87        self
88    }
89
90    /// Set an HTTP timeout for requests.
91    pub fn with_timeout(mut self, timeout: Duration) -> Self {
92        self.timeout = Some(timeout);
93        self.client = build_client(Some(timeout));
94        self
95    }
96
97    /// Set the maximum number of retries for transient errors.
98    pub fn with_max_retries(mut self, retries: u32) -> Self {
99        self.max_retries = retries;
100        self
101    }
102
103    /// Use `Authorization: Bearer <key>` instead of `?key=` query parameter.
104    ///
105    /// Required for Vertex AI endpoints where the key is an OAuth2 access token.
106    pub fn with_bearer_token(mut self) -> Self {
107        self.use_bearer_token = true;
108        self
109    }
110
111    /// Reference a previously-created cached content resource.
112    ///
113    /// The name should be in the format `cachedContents/<id>`, as returned
114    /// by the Gemini Caching API.
115    pub fn with_cached_content(mut self, name: impl Into<String>) -> Self {
116        self.cached_content = Some(name.into());
117        self
118    }
119
120    fn endpoint_url(&self, method: &str) -> String {
121        format!(
122            "{}/models/{}:{}",
123            self.base_url, self.model_id, method
124        )
125    }
126
127    fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
128        if self.use_bearer_token {
129            req.bearer_auth(&self.api_key)
130        } else {
131            req.query(&[("key", &self.api_key)])
132        }
133    }
134
135    fn build_request_body(&self, request: &ChatRequest) -> GeminiRequest {
136        let mut system_instruction = None;
137        let mut contents = Vec::new();
138
139        for msg in &request.messages {
140            match msg.role {
141                Role::System => {
142                    if let Some(text) = &msg.content {
143                        system_instruction = Some(GeminiContent {
144                            role: "user".to_string(),
145                            parts: vec![GeminiPart::Text {
146                                text: text.clone(),
147                            }],
148                        });
149                    }
150                }
151                Role::User => {
152                    if let Some(text) = &msg.content {
153                        contents.push(GeminiContent {
154                            role: "user".to_string(),
155                            parts: vec![GeminiPart::Text {
156                                text: text.clone(),
157                            }],
158                        });
159                    }
160                }
161                Role::Assistant => {
162                    if !msg.tool_calls.is_empty() {
163                        let parts = msg
164                            .tool_calls
165                            .iter()
166                            .map(|tc| GeminiPart::FunctionCall {
167                                function_call: GeminiFunctionCall {
168                                    name: tc.name.clone(),
169                                    args: tc.arguments.clone(),
170                                },
171                            })
172                            .collect();
173                        contents.push(GeminiContent {
174                            role: "model".to_string(),
175                            parts,
176                        });
177                    } else if let Some(text) = &msg.content {
178                        contents.push(GeminiContent {
179                            role: "model".to_string(),
180                            parts: vec![GeminiPart::Text {
181                                text: text.clone(),
182                            }],
183                        });
184                    }
185                }
186                Role::Tool => {
187                    let name = msg.tool_call_id.clone().unwrap_or_default();
188                    let content = msg.content.clone().unwrap_or_default();
189                    let response_value: serde_json::Value =
190                        serde_json::from_str(&content).unwrap_or_else(|_| {
191                            serde_json::json!({ "result": content })
192                        });
193                    contents.push(GeminiContent {
194                        role: "user".to_string(),
195                        parts: vec![GeminiPart::FunctionResponse {
196                            function_response: GeminiFunctionResponse {
197                                name,
198                                response: response_value,
199                            },
200                        }],
201                    });
202                }
203            }
204        }
205
206        let tools = if request.tools.is_empty() {
207            None
208        } else {
209            let declarations: Vec<GeminiFunctionDeclaration> =
210                request.tools.iter().map(Into::into).collect();
211            Some(vec![GeminiToolConfig {
212                function_declarations: declarations,
213            }])
214        };
215
216        let generation_config = Some(GeminiGenerationConfig {
217            temperature: request.temperature,
218            max_output_tokens: request.max_tokens,
219        });
220
221        GeminiRequest {
222            cached_content: self.cached_content.clone(),
223            system_instruction,
224            contents,
225            tools,
226            generation_config,
227        }
228    }
229}
230
231impl Model for Gemini {
232    #[tracing::instrument(skip_all, fields(model = %self.model_id))]
233    async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
234        let body = self.build_request_body(request);
235        let url = self.endpoint_url("generateContent");
236
237        for attempt in 0..=self.max_retries {
238            let req = self.client.post(&url).json(&body);
239            let req = self.apply_auth(req);
240
241            tracing::debug!(attempt, "sending Gemini generateContent request");
242            let response = req
243                .send()
244                .await
245                .map_err(|e| DaimonError::Model(format!("Gemini HTTP error: {e}")))?;
246            let status = response.status();
247
248            if status.is_success() {
249                let api_resp: GeminiResponse = response
250                    .json()
251                    .await
252                    .map_err(|e| DaimonError::Model(format!("Gemini response parse error: {e}")))?;
253                tracing::debug!("received successful Gemini response");
254                return parse_response(api_resp);
255            }
256
257            let text = response.text().await.unwrap_or_default();
258            let is_retryable = status.as_u16() == 429 || status.is_server_error();
259
260            if is_retryable && attempt < self.max_retries {
261                let delay_ms = 100 * 2u64.pow(attempt);
262                tracing::debug!(status = %status, attempt, delay_ms, "retryable error, backing off");
263                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
264            } else {
265                return Err(DaimonError::Model(format!(
266                    "Gemini API error ({status}): {text}"
267                )));
268            }
269        }
270
271        unreachable!("loop always returns or retries")
272    }
273
274    #[tracing::instrument(skip_all, fields(model = %self.model_id))]
275    async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
276        let body = self.build_request_body(request);
277        let url = self.endpoint_url("streamGenerateContent");
278
279        let req = self
280            .client
281            .post(&url)
282            .query(&[("alt", "sse")])
283            .json(&body);
284        let req = self.apply_auth(req);
285
286        tracing::debug!("sending Gemini streaming request");
287        let response = req
288            .send()
289            .await
290            .map_err(|e| DaimonError::Model(format!("Gemini HTTP error: {e}")))?;
291
292        if !response.status().is_success() {
293            let status = response.status();
294            let text = response.text().await.unwrap_or_default();
295            return Err(DaimonError::Model(format!(
296                "Gemini API error ({status}): {text}"
297            )));
298        }
299
300        tracing::debug!("Gemini stream established");
301        let byte_stream = response.bytes_stream();
302
303        let stream = async_stream::try_stream! {
304            use futures::StreamExt;
305
306            let mut buffer = String::new();
307            let mut stream = Box::pin(byte_stream);
308
309            while let Some(chunk) = stream.next().await {
310                let chunk = chunk.map_err(|e| DaimonError::Model(format!("Gemini stream error: {e}")))?;
311                buffer.push_str(&String::from_utf8_lossy(&chunk));
312
313                while let Some(line_end) = buffer.find('\n') {
314                    let line = buffer[..line_end].trim().to_string();
315                    buffer = buffer[line_end + 1..].to_string();
316
317                    if line.is_empty() {
318                        continue;
319                    }
320
321                    if let Some(data) = line.strip_prefix("data: ") {
322                        if let Ok(chunk_resp) = serde_json::from_str::<GeminiResponse>(data) {
323                            for candidate in &chunk_resp.candidates {
324                                for part in &candidate.content.parts {
325                                    match part {
326                                        GeminiResponsePart::Text { text } => {
327                                            if !text.is_empty() {
328                                                yield StreamEvent::TextDelta(text.clone());
329                                            }
330                                        }
331                                        GeminiResponsePart::FunctionCall { function_call } => {
332                                            let id = format!("gemini_{}", function_call.name);
333                                            yield StreamEvent::ToolCallStart {
334                                                id: id.clone(),
335                                                name: function_call.name.clone(),
336                                            };
337                                            let args = serde_json::to_string(&function_call.args)
338                                                .unwrap_or_default();
339                                            yield StreamEvent::ToolCallDelta {
340                                                id: id.clone(),
341                                                arguments_delta: args,
342                                            };
343                                            yield StreamEvent::ToolCallEnd { id };
344                                        }
345                                    }
346                                }
347                            }
348
349                            let is_done = chunk_resp.candidates.iter().any(|c| {
350                                c.finish_reason.as_deref() == Some("STOP")
351                                    || c.finish_reason.as_deref() == Some("MAX_TOKENS")
352                            });
353                            if is_done {
354                                yield StreamEvent::Done;
355                            }
356                        }
357                    }
358                }
359            }
360        };
361
362        Ok(Box::pin(stream))
363    }
364}
365
366fn parse_response(response: GeminiResponse) -> Result<ChatResponse> {
367    let candidate = response
368        .candidates
369        .into_iter()
370        .next()
371        .ok_or_else(|| DaimonError::Model("no candidates in Gemini response".into()))?;
372
373    let mut text_content = String::new();
374    let mut tool_calls = Vec::new();
375
376    for part in candidate.content.parts {
377        match part {
378            GeminiResponsePart::Text { text } => {
379                text_content.push_str(&text);
380            }
381            GeminiResponsePart::FunctionCall { function_call } => {
382                tool_calls.push(ToolCall {
383                    id: format!("gemini_{}", function_call.name),
384                    name: function_call.name,
385                    arguments: function_call.args,
386                });
387            }
388        }
389    }
390
391    let stop_reason = if !tool_calls.is_empty() {
392        StopReason::ToolUse
393    } else {
394        match candidate.finish_reason.as_deref() {
395            Some("MAX_TOKENS") => StopReason::MaxTokens,
396            _ => StopReason::EndTurn,
397        }
398    };
399
400    let message = if tool_calls.is_empty() {
401        Message::assistant(text_content)
402    } else {
403        Message {
404            role: Role::Assistant,
405            content: if text_content.is_empty() {
406                None
407            } else {
408                Some(text_content)
409            },
410            tool_calls,
411            tool_call_id: None,
412        }
413    };
414
415    Ok(ChatResponse {
416        message,
417        stop_reason,
418        usage: response.usage_metadata.map(|u| Usage {
419            input_tokens: u.prompt_token_count,
420            output_tokens: u.candidates_token_count,
421            cached_tokens: u.cached_content_token_count,
422        }),
423    })
424}
425
426// --- Gemini API types (request) ---
427
/// Request body for `generateContent` / `streamGenerateContent`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Optional `cachedContents/<id>` resource reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    cached_content: Option<String>,
    /// Collected system prompt, serialized as `systemInstruction`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    /// Conversation turns using Gemini's `user` / `model` roles.
    contents: Vec<GeminiContent>,
    /// Tool declarations; omitted when the request defines no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolConfig>>,
    /// Sampling/length settings; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
}

/// One conversation turn: a role plus its ordered parts.
#[derive(Serialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

/// A single content part. `untagged` means each variant serializes as just
/// its inner field: `{"text": ...}`, `{"functionCall": ...}`, or
/// `{"functionResponse": ...}`.
#[derive(Serialize)]
#[serde(untagged)]
enum GeminiPart {
    Text {
        text: String,
    },
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}

/// A prior assistant function call, replayed in conversation history.
#[derive(Serialize)]
struct GeminiFunctionCall {
    name: String,
    args: serde_json::Value,
}

/// A tool's result, sent back to the model as a `functionResponse` part.
#[derive(Serialize)]
struct GeminiFunctionResponse {
    name: String,
    response: serde_json::Value,
}

/// Wrapper carrying the function declarations for tool use.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// Description of one callable tool (name, description, JSON-schema params).
#[derive(Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

impl From<&ToolSpec> for GeminiFunctionDeclaration {
    /// Map the framework's tool spec onto Gemini's declaration shape.
    fn from(spec: &ToolSpec) -> Self {
        Self {
            name: spec.name.clone(),
            description: spec.description.clone(),
            parameters: spec.parameters.clone(),
        }
    }
}

/// Sampling / length knobs; each field is omitted from the JSON when unset.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
}
507
508// --- Gemini API types (response) ---
509
/// Response body from `generateContent`; also the shape of each SSE chunk
/// emitted by `streamGenerateContent`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    /// Generated candidates; defaults to empty when the field is absent.
    #[serde(default)]
    candidates: Vec<GeminiCandidate>,
    /// Token accounting; `None` when the field is absent.
    usage_metadata: Option<GeminiUsageMetadata>,
}

/// One generated candidate plus the reason generation stopped.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiResponseContent,
    /// e.g. "STOP" or "MAX_TOKENS" (the only values this crate inspects).
    finish_reason: Option<String>,
}

/// Candidate content: an ordered list of parts.
#[derive(Deserialize)]
struct GeminiResponseContent {
    #[serde(default)]
    parts: Vec<GeminiResponsePart>,
}

/// A response part. `untagged`: variants are tried in declaration order, so
/// FunctionCall is listed first to take priority over Text.
#[derive(Deserialize)]
#[serde(untagged)]
enum GeminiResponsePart {
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiResponseFunctionCall,
    },
    Text {
        text: String,
    },
}

/// A function call requested by the model.
#[derive(Deserialize)]
struct GeminiResponseFunctionCall {
    name: String,
    args: serde_json::Value,
}

/// Token usage counters; each defaults to 0 when absent.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    cached_content_token_count: u32,
}
559
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: base URL, retry budget, and query-parameter auth.
    #[test]
    fn test_gemini_new_default() {
        let model = Gemini::new("gemini-2.0-flash");
        assert_eq!(model.model_id, "gemini-2.0-flash");
        assert_eq!(model.base_url, DEFAULT_BASE_URL);
        assert_eq!(model.max_retries, DEFAULT_MAX_RETRIES);
        assert!(!model.use_bearer_token);
    }

    // Builder overrides the base URL (Vertex AI scenario).
    #[test]
    fn test_with_base_url() {
        let model = Gemini::new("gemini-pro").with_base_url("https://vertex.example.com");
        assert_eq!(model.base_url, "https://vertex.example.com");
    }

    // Builder records the timeout on the struct.
    #[test]
    fn test_with_timeout() {
        let model = Gemini::new("gemini-pro").with_timeout(Duration::from_secs(30));
        assert_eq!(model.timeout, Some(Duration::from_secs(30)));
    }

    // Builder overrides the retry budget.
    #[test]
    fn test_with_max_retries() {
        let model = Gemini::new("gemini-pro").with_max_retries(5);
        assert_eq!(model.max_retries, 5);
    }

    // Builder switches auth mode to a bearer token.
    #[test]
    fn test_with_bearer_token() {
        let model = Gemini::new("gemini-pro").with_bearer_token();
        assert!(model.use_bearer_token);
    }

    // URL layout: <base>/models/<model>:<method>.
    #[test]
    fn test_endpoint_url() {
        let model = Gemini::new("gemini-2.0-flash");
        assert_eq!(
            model.endpoint_url("generateContent"),
            "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
        );
    }

    // From<&ToolSpec> copies name/description into the declaration.
    #[test]
    fn test_tool_spec_conversion() {
        let spec = ToolSpec {
            name: "search".into(),
            description: "Web search".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let decl: GeminiFunctionDeclaration = (&spec).into();
        assert_eq!(decl.name, "search");
        assert_eq!(decl.description, "Web search");
    }

    // Text-only candidate maps to an assistant message with EndTurn and usage.
    #[test]
    fn test_parse_response_text() {
        let raw = GeminiResponse {
            candidates: vec![GeminiCandidate {
                content: GeminiResponseContent {
                    parts: vec![GeminiResponsePart::Text {
                        text: "Hello world".into(),
                    }],
                },
                finish_reason: Some("STOP".into()),
            }],
            usage_metadata: Some(GeminiUsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 5,
                cached_content_token_count: 0,
            }),
        };
        let resp = parse_response(raw).unwrap();
        assert_eq!(resp.text(), "Hello world");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(!resp.has_tool_calls());
        assert_eq!(resp.usage.unwrap().input_tokens, 10);
    }

    // A functionCall part yields a ToolUse stop reason even when the API
    // reported finish_reason "STOP".
    #[test]
    fn test_parse_response_function_call() {
        let raw = GeminiResponse {
            candidates: vec![GeminiCandidate {
                content: GeminiResponseContent {
                    parts: vec![GeminiResponsePart::FunctionCall {
                        function_call: GeminiResponseFunctionCall {
                            name: "calculator".into(),
                            args: serde_json::json!({"expr": "2+2"}),
                        },
                    }],
                },
                finish_reason: Some("STOP".into()),
            }],
            usage_metadata: None,
        };
        let resp = parse_response(raw).unwrap();
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls()[0].name, "calculator");
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    // An empty candidates list is an error, not a silent empty response.
    #[test]
    fn test_parse_response_no_candidates() {
        let raw = GeminiResponse {
            candidates: vec![],
            usage_metadata: None,
        };
        assert!(parse_response(raw).is_err());
    }

    // System messages become systemInstruction, not a contents entry.
    #[test]
    fn test_build_request_with_system_prompt() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![Message::system("Be helpful"), Message::user("Hello")],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
        };
        let body = model.build_request_body(&request);
        assert!(body.system_instruction.is_some());
        assert_eq!(body.contents.len(), 1);
        assert_eq!(
            body.generation_config.as_ref().unwrap().temperature,
            Some(0.7)
        );
    }

    // Tool specs are forwarded as function declarations.
    #[test]
    fn test_build_request_with_tools() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![Message::user("hi")],
            tools: vec![ToolSpec {
                name: "calc".into(),
                description: "Calculator".into(),
                parameters: serde_json::json!({"type": "object"}),
            }],
            temperature: None,
            max_tokens: None,
        };
        let body = model.build_request_body(&request);
        assert!(body.tools.is_some());
        assert_eq!(body.tools.unwrap()[0].function_declarations.len(), 1);
    }

    // A full tool round-trip (user -> assistant tool call -> tool result)
    // produces three contents entries.
    #[test]
    fn test_build_request_with_tool_results() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![
                Message::user("calc 2+2"),
                Message::assistant_with_tool_calls(vec![ToolCall {
                    id: "gemini_calc".into(),
                    name: "calc".into(),
                    arguments: serde_json::json!({"expr": "2+2"}),
                }]),
                Message::tool_result("calc", "4"),
            ],
            tools: vec![],
            temperature: None,
            max_tokens: None,
        };
        let body = model.build_request_body(&request);
        assert_eq!(body.contents.len(), 3);
    }

    // All builder methods chain and each setting sticks.
    #[test]
    fn test_builder_chain() {
        let model = Gemini::with_api_key("gemini-2.0-flash", "key")
            .with_base_url("https://custom.example.com")
            .with_timeout(Duration::from_secs(60))
            .with_max_retries(5)
            .with_bearer_token();

        assert_eq!(model.model_id, "gemini-2.0-flash");
        assert_eq!(model.base_url, "https://custom.example.com");
        assert_eq!(model.timeout, Some(Duration::from_secs(60)));
        assert_eq!(model.max_retries, 5);
        assert!(model.use_bearer_token);
    }
}
744}