//! Gemini chat model provider (`synaptic_gemini/chat_model.rs`).
//!
//! Translates the framework's provider-agnostic chat types into Google's
//! Gemini `generateContent` / `streamGenerateContent` REST wire format.

use std::sync::Arc;

use async_trait::async_trait;
use serde_json::{json, Value};
use synaptic_core::{
    AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
    TokenUsage, ToolCall, ToolChoice, ToolDefinition,
};
use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};

/// Configuration for a [`GeminiChatModel`].
///
/// `top_p` and `stop` are optional and, when unset, are omitted from the
/// request body so the API's own defaults apply.
#[derive(Debug, Clone)]
pub struct GeminiConfig {
    /// API key; currently sent as the `key` query parameter on every request.
    pub api_key: String,
    /// Model identifier, e.g. `"gemini-1.5-pro"` — NOTE(review): exact set of
    /// accepted ids not visible here; confirm against the Gemini model list.
    pub model: String,
    /// Base URL, defaults to `https://generativelanguage.googleapis.com`.
    pub base_url: String,
    /// Nucleus-sampling parameter, mapped to `generationConfig.topP`.
    pub top_p: Option<f64>,
    /// Stop sequences, mapped to `generationConfig.stopSequences`.
    pub stop: Option<Vec<String>>,
}

20impl GeminiConfig {
21    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
22        Self {
23            api_key: api_key.into(),
24            model: model.into(),
25            base_url: "https://generativelanguage.googleapis.com".to_string(),
26            top_p: None,
27            stop: None,
28        }
29    }
30
31    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
32        self.base_url = url.into();
33        self
34    }
35
36    pub fn with_top_p(mut self, top_p: f64) -> Self {
37        self.top_p = Some(top_p);
38        self
39    }
40
41    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
42        self.stop = Some(stop);
43        self
44    }
45}
46
47pub struct GeminiChatModel {
48    config: GeminiConfig,
49    backend: Arc<dyn ProviderBackend>,
50}
51
52impl GeminiChatModel {
53    pub fn new(config: GeminiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
54        Self { config, backend }
55    }
56
57    fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
58        let mut system_text: Option<String> = None;
59        let mut contents: Vec<Value> = Vec::new();
60
61        for msg in &request.messages {
62            match msg {
63                Message::System { content, .. } => {
64                    system_text = Some(content.clone());
65                }
66                Message::Human { content, .. } => {
67                    contents.push(json!({
68                        "role": "user",
69                        "parts": [{"text": content}],
70                    }));
71                }
72                Message::AI {
73                    content,
74                    tool_calls,
75                    ..
76                } => {
77                    let mut parts: Vec<Value> = Vec::new();
78                    if !content.is_empty() {
79                        parts.push(json!({"text": content}));
80                    }
81                    for tc in tool_calls {
82                        parts.push(json!({
83                            "functionCall": {
84                                "name": tc.name,
85                                "args": tc.arguments,
86                            }
87                        }));
88                    }
89                    contents.push(json!({
90                        "role": "model",
91                        "parts": parts,
92                    }));
93                }
94                Message::Tool {
95                    content,
96                    tool_call_id: _,
97                    ..
98                } => {
99                    let result: Value =
100                        serde_json::from_str(content).unwrap_or(json!({"result": content}));
101                    contents.push(json!({
102                        "role": "user",
103                        "parts": [{
104                            "functionResponse": {
105                                "name": "tool",
106                                "response": result,
107                            }
108                        }],
109                    }));
110                }
111                Message::Chat { content, .. } => {
112                    contents.push(json!({
113                        "role": "user",
114                        "parts": [{"text": content}],
115                    }));
116                }
117                Message::Remove { .. } => {}
118            }
119        }
120
121        let mut body = json!({
122            "contents": contents,
123        });
124
125        if let Some(system) = system_text {
126            body["system_instruction"] = json!({
127                "parts": [{"text": system}],
128            });
129        }
130
131        {
132            let mut gen_config = json!({});
133            let mut has_gen_config = false;
134            if let Some(top_p) = self.config.top_p {
135                gen_config["topP"] = json!(top_p);
136                has_gen_config = true;
137            }
138            if let Some(ref stop) = self.config.stop {
139                gen_config["stopSequences"] = json!(stop);
140                has_gen_config = true;
141            }
142            if has_gen_config {
143                body["generationConfig"] = gen_config;
144            }
145        }
146
147        if !request.tools.is_empty() {
148            body["tools"] = json!([{
149                "functionDeclarations": request.tools.iter().map(tool_def_to_gemini).collect::<Vec<_>>(),
150            }]);
151        }
152        if let Some(ref choice) = request.tool_choice {
153            body["tool_config"] = match choice {
154                ToolChoice::Auto => json!({"functionCallingConfig": {"mode": "AUTO"}}),
155                ToolChoice::Required => json!({"functionCallingConfig": {"mode": "ANY"}}),
156                ToolChoice::None => json!({"functionCallingConfig": {"mode": "NONE"}}),
157                ToolChoice::Specific(name) => json!({
158                    "functionCallingConfig": {
159                        "mode": "ANY",
160                        "allowedFunctionNames": [name]
161                    }
162                }),
163            };
164        }
165
166        let method = if stream {
167            "streamGenerateContent"
168        } else {
169            "generateContent"
170        };
171
172        let mut url = format!(
173            "{}/v1beta/models/{}:{}?key={}",
174            self.config.base_url, self.config.model, method, self.config.api_key
175        );
176        if stream {
177            url.push_str("&alt=sse");
178        }
179
180        ProviderRequest {
181            url,
182            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
183            body,
184        }
185    }
186}
187
188fn tool_def_to_gemini(def: &ToolDefinition) -> Value {
189    json!({
190        "name": def.name,
191        "description": def.description,
192        "parameters": def.parameters,
193    })
194}
195
196fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
197    check_error_status(resp)?;
198
199    let parts = resp.body["candidates"][0]["content"]["parts"]
200        .as_array()
201        .cloned()
202        .unwrap_or_default();
203
204    let mut text = String::new();
205    let mut tool_calls = Vec::new();
206
207    for part in &parts {
208        if let Some(t) = part["text"].as_str() {
209            text.push_str(t);
210        }
211        if let Some(fc) = part.get("functionCall") {
212            if let Some(name) = fc["name"].as_str() {
213                tool_calls.push(ToolCall {
214                    id: format!("gemini-{}", tool_calls.len()),
215                    name: name.to_string(),
216                    arguments: fc["args"].clone(),
217                });
218            }
219        }
220    }
221
222    let usage = parse_usage(&resp.body["usageMetadata"]);
223
224    let message = if tool_calls.is_empty() {
225        Message::ai(text)
226    } else {
227        Message::ai_with_tool_calls(text, tool_calls)
228    };
229
230    Ok(ChatResponse { message, usage })
231}
232
233fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
234    if resp.status == 429 {
235        let msg = resp.body["error"]["message"]
236            .as_str()
237            .unwrap_or("rate limited")
238            .to_string();
239        return Err(SynapticError::RateLimit(msg));
240    }
241    if resp.status >= 400 {
242        let msg = resp.body["error"]["message"]
243            .as_str()
244            .unwrap_or("unknown API error")
245            .to_string();
246        return Err(SynapticError::Model(format!(
247            "Gemini API error ({}): {}",
248            resp.status, msg
249        )));
250    }
251    Ok(())
252}
253
254fn parse_usage(usage: &Value) -> Option<TokenUsage> {
255    if usage.is_null() {
256        return None;
257    }
258    Some(TokenUsage {
259        input_tokens: usage["promptTokenCount"].as_u64().unwrap_or(0) as u32,
260        output_tokens: usage["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
261        total_tokens: usage["totalTokenCount"].as_u64().unwrap_or(0) as u32,
262        input_details: None,
263        output_details: None,
264    })
265}
266
267fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
268    let v: Value = serde_json::from_str(data).ok()?;
269    let parts = v["candidates"][0]["content"]["parts"]
270        .as_array()
271        .cloned()
272        .unwrap_or_default();
273
274    let mut content = String::new();
275    let mut tool_calls = Vec::new();
276
277    for part in &parts {
278        if let Some(t) = part["text"].as_str() {
279            content.push_str(t);
280        }
281        if let Some(fc) = part.get("functionCall") {
282            if let Some(name) = fc["name"].as_str() {
283                tool_calls.push(ToolCall {
284                    id: format!("gemini-{}", tool_calls.len()),
285                    name: name.to_string(),
286                    arguments: fc["args"].clone(),
287                });
288            }
289        }
290    }
291
292    let usage = parse_usage(&v["usageMetadata"]);
293
294    Some(AIMessageChunk {
295        content,
296        tool_calls,
297        usage,
298        ..Default::default()
299    })
300}
301
302#[async_trait]
303impl ChatModel for GeminiChatModel {
304    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
305        let provider_req = self.build_request(&request, false);
306        let resp = self.backend.send(provider_req).await?;
307        parse_response(&resp)
308    }
309
310    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
311        Box::pin(async_stream::stream! {
312            let provider_req = self.build_request(&request, true);
313            let byte_stream = self.backend.send_stream(provider_req).await;
314
315            let byte_stream = match byte_stream {
316                Ok(s) => s,
317                Err(e) => {
318                    yield Err(e);
319                    return;
320                }
321            };
322
323            use eventsource_stream::Eventsource;
324            use futures::StreamExt;
325
326            let mut event_stream = byte_stream
327                .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
328                .eventsource();
329
330            while let Some(event) = event_stream.next().await {
331                match event {
332                    Ok(ev) => {
333                        if let Some(chunk) = parse_stream_chunk(&ev.data) {
334                            yield Ok(chunk);
335                        }
336                    }
337                    Err(e) => {
338                        yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
339                        break;
340                    }
341                }
342            }
343        })
344    }
345}