// synaptic_models/gemini.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapseError,
7    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Configuration for the Google Gemini chat API.
#[derive(Debug, Clone)]
pub struct GeminiConfig {
    /// API key; sent to the endpoint as the `key` query parameter.
    pub api_key: String,
    /// Model identifier inserted into the request path (e.g. a `gemini-*` model name).
    pub model: String,
    /// Base URL of the Generative Language API.
    pub base_url: String,
    /// Optional nucleus-sampling value; emitted as `topP` in `generationConfig`.
    pub top_p: Option<f64>,
    /// Optional stop sequences; emitted as `stopSequences` in `generationConfig`.
    pub stop: Option<Vec<String>>,
}
20
21impl GeminiConfig {
22    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
23        Self {
24            api_key: api_key.into(),
25            model: model.into(),
26            base_url: "https://generativelanguage.googleapis.com".to_string(),
27            top_p: None,
28            stop: None,
29        }
30    }
31
32    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
33        self.base_url = url.into();
34        self
35    }
36
37    pub fn with_top_p(mut self, top_p: f64) -> Self {
38        self.top_p = Some(top_p);
39        self
40    }
41
42    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
43        self.stop = Some(stop);
44        self
45    }
46}
47
/// Gemini-backed implementation of the `ChatModel` trait; HTTP transport is
/// delegated to the injected `ProviderBackend`.
pub struct GeminiChatModel {
    // Endpoint, credentials and sampling options used to build each request.
    config: GeminiConfig,
    // Transport that actually sends the assembled requests.
    backend: Arc<dyn ProviderBackend>,
}
52
53impl GeminiChatModel {
54    pub fn new(config: GeminiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
55        Self { config, backend }
56    }
57
58    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
59        let mut system_text: Option<String> = None;
60        let mut contents: Vec<Value> = Vec::new();
61
62        for msg in &request.messages {
63            match msg {
64                Message::System { content, .. } => {
65                    system_text = Some(content.clone());
66                }
67                Message::Human { content, .. } => {
68                    contents.push(json!({
69                        "role": "user",
70                        "parts": [{"text": content}],
71                    }));
72                }
73                Message::AI {
74                    content,
75                    tool_calls,
76                    ..
77                } => {
78                    let mut parts: Vec<Value> = Vec::new();
79                    if !content.is_empty() {
80                        parts.push(json!({"text": content}));
81                    }
82                    for tc in tool_calls {
83                        parts.push(json!({
84                            "functionCall": {
85                                "name": tc.name,
86                                "args": tc.arguments,
87                            }
88                        }));
89                    }
90                    contents.push(json!({
91                        "role": "model",
92                        "parts": parts,
93                    }));
94                }
95                Message::Tool {
96                    content,
97                    tool_call_id: _,
98                    ..
99                } => {
100                    // Gemini uses functionResponse in parts
101                    let result: Value =
102                        serde_json::from_str(content).unwrap_or(json!({"result": content}));
103                    contents.push(json!({
104                        "role": "user",
105                        "parts": [{
106                            "functionResponse": {
107                                "name": "tool",
108                                "response": result,
109                            }
110                        }],
111                    }));
112                }
113                Message::Chat { content, .. } => {
114                    contents.push(json!({
115                        "role": "user",
116                        "parts": [{"text": content}],
117                    }));
118                }
119                Message::Remove { .. } => { /* skip Remove messages */ }
120            }
121        }
122
123        let mut body = json!({
124            "contents": contents,
125        });
126
127        if let Some(system) = system_text {
128            body["system_instruction"] = json!({
129                "parts": [{"text": system}],
130            });
131        }
132
133        {
134            let mut gen_config = json!({});
135            let mut has_gen_config = false;
136            if let Some(top_p) = self.config.top_p {
137                gen_config["topP"] = json!(top_p);
138                has_gen_config = true;
139            }
140            if let Some(ref stop) = self.config.stop {
141                gen_config["stopSequences"] = json!(stop);
142                has_gen_config = true;
143            }
144            if has_gen_config {
145                body["generationConfig"] = gen_config;
146            }
147        }
148
149        if !request.tools.is_empty() {
150            body["tools"] = json!([{
151                "functionDeclarations": request.tools.iter().map(tool_def_to_gemini).collect::<Vec<_>>(),
152            }]);
153        }
154        if let Some(ref choice) = request.tool_choice {
155            body["tool_config"] = match choice {
156                ToolChoice::Auto => json!({"functionCallingConfig": {"mode": "AUTO"}}),
157                ToolChoice::Required => json!({"functionCallingConfig": {"mode": "ANY"}}),
158                ToolChoice::None => json!({"functionCallingConfig": {"mode": "NONE"}}),
159                ToolChoice::Specific(name) => json!({
160                    "functionCallingConfig": {
161                        "mode": "ANY",
162                        "allowedFunctionNames": [name]
163                    }
164                }),
165            };
166        }
167
168        let method = if stream {
169            "streamGenerateContent"
170        } else {
171            "generateContent"
172        };
173
174        let mut url = format!(
175            "{}/v1beta/models/{}:{}?key={}",
176            self.config.base_url, self.config.model, method, self.config.api_key
177        );
178        if stream {
179            url.push_str("&alt=sse");
180        }
181
182        ProviderRequest {
183            url,
184            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
185            body,
186        }
187    }
188}
189
190fn tool_def_to_gemini(def: &ToolDefinition) -> Value {
191    json!({
192        "name": def.name,
193        "description": def.description,
194        "parameters": def.parameters,
195    })
196}
197
198fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapseError> {
199    check_error_status(resp)?;
200
201    let parts = resp.body["candidates"][0]["content"]["parts"]
202        .as_array()
203        .cloned()
204        .unwrap_or_default();
205
206    let mut text = String::new();
207    let mut tool_calls = Vec::new();
208
209    for part in &parts {
210        if let Some(t) = part["text"].as_str() {
211            text.push_str(t);
212        }
213        if let Some(fc) = part.get("functionCall") {
214            if let Some(name) = fc["name"].as_str() {
215                tool_calls.push(ToolCall {
216                    id: format!("gemini-{}", tool_calls.len()),
217                    name: name.to_string(),
218                    arguments: fc["args"].clone(),
219                });
220            }
221        }
222    }
223
224    let usage = parse_usage(&resp.body["usageMetadata"]);
225
226    let message = if tool_calls.is_empty() {
227        Message::ai(text)
228    } else {
229        Message::ai_with_tool_calls(text, tool_calls)
230    };
231
232    Ok(ChatResponse { message, usage })
233}
234
235fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapseError> {
236    if resp.status == 429 {
237        let msg = resp.body["error"]["message"]
238            .as_str()
239            .unwrap_or("rate limited")
240            .to_string();
241        return Err(SynapseError::RateLimit(msg));
242    }
243    if resp.status >= 400 {
244        let msg = resp.body["error"]["message"]
245            .as_str()
246            .unwrap_or("unknown API error")
247            .to_string();
248        return Err(SynapseError::Model(format!(
249            "Gemini API error ({}): {}",
250            resp.status, msg
251        )));
252    }
253    Ok(())
254}
255
256fn parse_usage(usage: &Value) -> Option<TokenUsage> {
257    if usage.is_null() {
258        return None;
259    }
260    Some(TokenUsage {
261        input_tokens: usage["promptTokenCount"].as_u64().unwrap_or(0) as u32,
262        output_tokens: usage["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
263        total_tokens: usage["totalTokenCount"].as_u64().unwrap_or(0) as u32,
264        input_details: None,
265        output_details: None,
266    })
267}
268
269fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
270    let v: Value = serde_json::from_str(data).ok()?;
271    let parts = v["candidates"][0]["content"]["parts"]
272        .as_array()
273        .cloned()
274        .unwrap_or_default();
275
276    let mut content = String::new();
277    let mut tool_calls = Vec::new();
278
279    for part in &parts {
280        if let Some(t) = part["text"].as_str() {
281            content.push_str(t);
282        }
283        if let Some(fc) = part.get("functionCall") {
284            if let Some(name) = fc["name"].as_str() {
285                tool_calls.push(ToolCall {
286                    id: format!("gemini-{}", tool_calls.len()),
287                    name: name.to_string(),
288                    arguments: fc["args"].clone(),
289                });
290            }
291        }
292    }
293
294    let usage = parse_usage(&v["usageMetadata"]);
295
296    Some(AIMessageChunk {
297        content,
298        tool_calls,
299        usage,
300        ..Default::default()
301    })
302}
303
#[async_trait]
impl ChatModel for GeminiChatModel {
    /// Sends a single (non-streaming) chat request and parses the reply.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapseError> {
        let provider_req = self.build_request(&request, false);
        let resp = self.backend.send(provider_req).await?;
        parse_response(&resp)
    }

    /// Streams a chat response as it is generated.
    ///
    /// The backend's byte stream is decoded as server-sent events (the
    /// request URL was built with `alt=sse`); each event's `data` payload is
    /// parsed into an `AIMessageChunk`. Transport and SSE-decoding failures
    /// are yielded as error items; an SSE error also ends the stream.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        Box::pin(async_stream::stream! {
            let provider_req = self.build_request(&request, true);
            let byte_stream = self.backend.send_stream(provider_req).await;

            // Connection-time failure: yield a single error and stop.
            let byte_stream = match byte_stream {
                Ok(s) => s,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            use eventsource_stream::Eventsource;
            use futures::StreamExt;

            // Eventsource needs an io-style error type, so backend errors are
            // re-wrapped before SSE framing is applied.
            let mut event_stream = byte_stream
                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
                .eventsource();

            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(ev) => {
                        // Payloads that do not parse as JSON are skipped.
                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
                            yield Ok(chunk);
                        }
                    }
                    Err(e) => {
                        yield Err(SynapseError::Model(format!("SSE parse error: {e}")));
                        break;
                    }
                }
            }
        })
    }
}