mermaid_cli/providers/model/
gemini.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::domain::ChatRequest;
13use crate::models::adapters::gemini::GeminiAdapter;
14use crate::models::{
15 Model, ModelConfig, ModelError, ReasoningChunk, Result, StreamCallback,
16 StreamEvent as ModelStreamEvent,
17};
18
19use super::super::capabilities::Capabilities;
20use super::super::ctx::{FinalResponse, StreamContext, StreamEvent};
21use super::ModelProvider;
22
23pub struct GeminiProvider {
24 adapter: GeminiAdapter,
25 capabilities: Capabilities,
26}
27
28impl GeminiProvider {
29 pub fn new(api_key: String, model_name: String, base_url: String) -> Result<Self> {
30 let adapter = GeminiAdapter::new(api_key, model_name, base_url)?;
31 let capabilities = Capabilities::from_legacy(adapter.capabilities());
32 Ok(Self {
33 adapter,
34 capabilities,
35 })
36 }
37}
38
39#[async_trait]
40impl ModelProvider for GeminiProvider {
41 fn capabilities(&self) -> &Capabilities {
42 &self.capabilities
43 }
44
45 async fn chat(&self, request: ChatRequest, ctx: StreamContext) -> Result<FinalResponse> {
46 let config = build_model_config(&request);
47 let relay_tx = super::stream_bridge::ordered_relay(ctx.sink.clone());
48 let callback = forward_callback(relay_tx);
49 let chat_fut = self
50 .adapter
51 .chat(&request.messages, &config, Some(callback));
52
53 let response = tokio::select! {
54 biased;
55 _ = ctx.token.cancelled() => {
56 return Err(ModelError::Cancelled);
57 },
58 r = chat_fut => r?,
59 };
60
61 let usage = response.usage.clone();
62 let _ = ctx
63 .sink
64 .send(StreamEvent::Done {
65 usage: usage.clone(),
66 thinking_signature: None,
67 })
68 .await;
69
70 Ok(FinalResponse {
71 usage,
72 thinking_signature: None,
73 tool_calls: response.tool_calls.unwrap_or_default(),
74 })
75 }
76}
77
78fn build_model_config(request: &ChatRequest) -> ModelConfig {
79 ModelConfig {
80 model: request.model_id.clone(),
81 temperature: request.temperature,
82 max_tokens: request.max_tokens,
83 reasoning: request.reasoning,
84 system_prompt: Some(request.system_prompt.clone()),
85 dynamic_system_suffix: request.instructions.clone(),
86 tools: request.tools.iter().map(|t| t.to_openai_json()).collect(),
87 ..Default::default()
88 }
89}
90
91fn forward_callback(sink: tokio::sync::mpsc::UnboundedSender<StreamEvent>) -> StreamCallback {
92 Arc::new(move |event: ModelStreamEvent| {
93 let mapped = match event {
94 ModelStreamEvent::Text(s) => StreamEvent::Text(s),
95 ModelStreamEvent::Reasoning(chunk) => StreamEvent::Reasoning(ReasoningChunk {
96 text: chunk.text,
97 signature: chunk.signature,
98 }),
99 ModelStreamEvent::ToolCall(tc) => StreamEvent::ToolCall(tc),
100 ModelStreamEvent::Done { tokens } => StreamEvent::Done {
101 usage: if tokens > 0 {
102 Some(crate::models::TokenUsage::provider(0, tokens, tokens))
103 } else {
104 None
105 },
106 thinking_signature: None,
107 },
108 };
109 let _ = sink.send(mapped);
110 })
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a baseline request; callers override fields as needed.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model_id: "gemini/gemini-3.1-pro-preview".to_string(),
            messages: vec![],
            system_prompt: "sys".to_string(),
            instructions: None,
            reasoning: crate::models::ReasoningLevel::High,
            temperature: 0.5,
            max_tokens: 4096,
            tools: vec![],
        }
    }

    #[test]
    fn build_model_config_maps_fields() {
        let req = sample_request();
        let cfg = build_model_config(&req);
        assert_eq!(cfg.model, "gemini/gemini-3.1-pro-preview");
        assert_eq!(cfg.system_prompt.as_deref(), Some("sys"));
        assert_eq!(cfg.reasoning, crate::models::ReasoningLevel::High);
        assert_eq!(cfg.temperature, 0.5);
        assert_eq!(cfg.max_tokens, 4096);
        assert!(cfg.dynamic_system_suffix.is_none());
    }

    #[test]
    fn build_model_config_maps_instructions_to_dynamic_suffix() {
        let req = ChatRequest {
            instructions: Some("suffix".to_string()),
            ..sample_request()
        };
        let cfg = build_model_config(&req);
        assert_eq!(cfg.dynamic_system_suffix.as_deref(), Some("suffix"));
    }
}