// mermaid_cli/providers/model/gemini.rs

//! Gemini provider — wraps `models::adapters::gemini::GeminiAdapter`.
//!
//! Google's Gemini family uses a different wire format from the
//! OpenAI-compatible providers (`:streamGenerateContent?alt=sse` plus a
//! protobuf-ish JSON shape). The adapter handles all of that; this
//! wrapper just forwards requests to it and relays its stream events.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::domain::ChatRequest;
13use crate::models::adapters::gemini::GeminiAdapter;
14use crate::models::{
15    Model, ModelConfig, ModelError, ReasoningChunk, Result, StreamCallback,
16    StreamEvent as ModelStreamEvent,
17};
18
19use super::super::capabilities::Capabilities;
20use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
21use super::ModelProvider;
22
23pub struct GeminiProvider {
24    adapter: GeminiAdapter,
25    capabilities: Capabilities,
26}
27
28impl GeminiProvider {
29    pub fn new(api_key: String, model_name: String, base_url: String) -> Result<Self> {
30        let adapter = GeminiAdapter::new(api_key, model_name, base_url)?;
31        let capabilities = Capabilities::from_legacy(adapter.capabilities());
32        Ok(Self {
33            adapter,
34            capabilities,
35        })
36    }
37}
38
39#[async_trait]
40impl ModelProvider for GeminiProvider {
41    fn capabilities(&self) -> &Capabilities {
42        &self.capabilities
43    }
44
45    async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
46        let config = build_model_config(&request);
47        let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
48        let callback = forward_callback(relay_tx);
49        let chat_fut = self
50            .adapter
51            .chat(&request.messages, &config, Some(callback));
52
53        let response = tokio::select! {
54            biased;
55            _ = ctx.token.cancelled() => {
56                return Err(ModelError::Cancelled);
57            },
58            r = chat_fut => r?,
59        };
60
61        let usage = response.usage.clone();
62        let _ = ctx
63            .sink
64            .send(StreamEvent::Done {
65                usage: usage.clone(),
66                thinking_signature: None,
67            })
68            .await;
69
70        Ok(FinalResponse {
71            usage,
72            thinking_signature: None,
73            tool_calls: response.tool_calls.unwrap_or_default(),
74        })
75    }
76}
77
78fn build_model_config(request: &ChatRequest) -> ModelConfig {
79    ModelConfig {
80        model: request.model_id.clone(),
81        temperature: request.temperature,
82        max_tokens: request.max_tokens,
83        reasoning: request.reasoning,
84        system_prompt: Some(request.system_prompt.clone()),
85        dynamic_system_suffix: request.instructions.clone(),
86        tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
87        ..Default::default()
88    }
89}
90
91fn forward_callback(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
92    Arc::new(move |event: ModelStreamEvent| {
93        let mapped = match event {
94            ModelStreamEvent::Text(s) => StreamEvent::Text(s),
95            ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
96                text: chunk.text,
97                signature: chunk.signature,
98            }),
99            ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
100            ModelStreamEvent::Done { tokens } => StreamEvent::Done {
101                usage: if tokens > 0 {
102                    Some(crate::models::TokenUsage::provider(0, tokens, tokens))
103                } else {
104                    None
105                },
106                thinking_signature: None,
107            },
108        };
109        let _ = sink.send(mapped);
110    })
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every field that `build_model_config` copies out of the request,
    /// so a dropped or mis-wired mapping fails here rather than at runtime.
    #[test]
    fn build_model_config_maps_fields() {
        let req = ChatRequest {
            model_id: "gemini/gemini-3.1-pro-preview".to_string(),
            messages: vec![],
            system_prompt: "sys".to_string(),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::High,
            temperature: 0.5,
            max_tokens: 4096,
            tools: vec![],
        };
        let cfg = build_model_config(&req);

        assert_eq!(cfg.model, "gemini/gemini-3.1-pro-preview");
        // 0.5 is exactly representable, so comparing floats for equality is safe here.
        assert_eq!(cfg.temperature, 0.5);
        assert_eq!(cfg.max_tokens, 4096);
        assert_eq!(cfg.reasoning, crate::models::ReasoningLevel::High);
        // The system prompt is always wrapped in `Some`, even though the
        // request field itself is not optional.
        assert_eq!(cfg.system_prompt.as_deref(), Some("sys"));
        assert!(cfg.dynamic_system_suffix.is_none());
        assert!(cfg.tools.is_empty());
    }
}