1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9use synaptic_models::{ProviderBackend, ProviderRequest, ProviderResponse};
10
/// Connection and generation settings for the Gemini chat model.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct GeminiConfig {
    /// API key sent with every request.
    pub api_key: String,
    /// Model identifier placed in the request URL path.
    pub model: String,
    /// Scheme and host of the API endpoint.
    pub base_url: String,
    /// Optional nucleus-sampling value, sent as `generationConfig.topP`.
    pub top_p: Option<f64>,
    /// Optional stop sequences, sent as `generationConfig.stopSequences`.
    pub stop: Option<Vec<String>>,
}

impl std::fmt::Debug for GeminiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiConfig")
            // The key is a secret: print a placeholder instead of its value.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("top_p", &self.top_p)
            .field("stop", &self.stop)
            .finish()
    }
}

impl GeminiConfig {
    /// Creates a configuration for `model` authenticated with `api_key`.
    ///
    /// The base URL defaults to the public Google endpoint and no optional
    /// generation settings are set.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: "https://generativelanguage.googleapis.com".to_string(),
            top_p: None,
            stop: None,
        }
    }

    /// Replaces the base URL (for example to target a proxy or a mock server).
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the `topP` generation parameter.
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the stop sequences.
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
46
47pub struct GeminiChatModel {
48 config: GeminiConfig,
49 backend: Arc<dyn ProviderBackend>,
50}
51
52impl GeminiChatModel {
53 pub fn new(config: GeminiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
54 Self { config, backend }
55 }
56
57 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
58 let mut system_text: Option<String> = None;
59 let mut contents: Vec<Value> = Vec::new();
60
61 for msg in &request.messages {
62 match msg {
63 Message::System { content, .. } => {
64 system_text = Some(content.clone());
65 }
66 Message::Human { content, .. } => {
67 contents.push(json!({
68 "role": "user",
69 "parts": [{"text": content}],
70 }));
71 }
72 Message::AI {
73 content,
74 tool_calls,
75 ..
76 } => {
77 let mut parts: Vec<Value> = Vec::new();
78 if !content.is_empty() {
79 parts.push(json!({"text": content}));
80 }
81 for tc in tool_calls {
82 parts.push(json!({
83 "functionCall": {
84 "name": tc.name,
85 "args": tc.arguments,
86 }
87 }));
88 }
89 contents.push(json!({
90 "role": "model",
91 "parts": parts,
92 }));
93 }
94 Message::Tool {
95 content,
96 tool_call_id: _,
97 ..
98 } => {
99 let result: Value =
100 serde_json::from_str(content).unwrap_or(json!({"result": content}));
101 contents.push(json!({
102 "role": "user",
103 "parts": [{
104 "functionResponse": {
105 "name": "tool",
106 "response": result,
107 }
108 }],
109 }));
110 }
111 Message::Chat { content, .. } => {
112 contents.push(json!({
113 "role": "user",
114 "parts": [{"text": content}],
115 }));
116 }
117 Message::Remove { .. } => {}
118 }
119 }
120
121 let mut body = json!({
122 "contents": contents,
123 });
124
125 if let Some(system) = system_text {
126 body["system_instruction"] = json!({
127 "parts": [{"text": system}],
128 });
129 }
130
131 {
132 let mut gen_config = json!({});
133 let mut has_gen_config = false;
134 if let Some(top_p) = self.config.top_p {
135 gen_config["topP"] = json!(top_p);
136 has_gen_config = true;
137 }
138 if let Some(ref stop) = self.config.stop {
139 gen_config["stopSequences"] = json!(stop);
140 has_gen_config = true;
141 }
142 if has_gen_config {
143 body["generationConfig"] = gen_config;
144 }
145 }
146
147 if !request.tools.is_empty() {
148 body["tools"] = json!([{
149 "functionDeclarations": request.tools.iter().map(tool_def_to_gemini).collect::<Vec<_>>(),
150 }]);
151 }
152 if let Some(ref choice) = request.tool_choice {
153 body["tool_config"] = match choice {
154 ToolChoice::Auto => json!({"functionCallingConfig": {"mode": "AUTO"}}),
155 ToolChoice::Required => json!({"functionCallingConfig": {"mode": "ANY"}}),
156 ToolChoice::None => json!({"functionCallingConfig": {"mode": "NONE"}}),
157 ToolChoice::Specific(name) => json!({
158 "functionCallingConfig": {
159 "mode": "ANY",
160 "allowedFunctionNames": [name]
161 }
162 }),
163 };
164 }
165
166 let method = if stream {
167 "streamGenerateContent"
168 } else {
169 "generateContent"
170 };
171
172 let mut url = format!(
173 "{}/v1beta/models/{}:{}?key={}",
174 self.config.base_url, self.config.model, method, self.config.api_key
175 );
176 if stream {
177 url.push_str("&alt=sse");
178 }
179
180 ProviderRequest {
181 url,
182 headers: vec![("Content-Type".to_string(), "application/json".to_string())],
183 body,
184 }
185 }
186}
187
188fn tool_def_to_gemini(def: &ToolDefinition) -> Value {
189 json!({
190 "name": def.name,
191 "description": def.description,
192 "parameters": def.parameters,
193 })
194}
195
196fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
197 check_error_status(resp)?;
198
199 let parts = resp.body["candidates"][0]["content"]["parts"]
200 .as_array()
201 .cloned()
202 .unwrap_or_default();
203
204 let mut text = String::new();
205 let mut tool_calls = Vec::new();
206
207 for part in &parts {
208 if let Some(t) = part["text"].as_str() {
209 text.push_str(t);
210 }
211 if let Some(fc) = part.get("functionCall") {
212 if let Some(name) = fc["name"].as_str() {
213 tool_calls.push(ToolCall {
214 id: format!("gemini-{}", tool_calls.len()),
215 name: name.to_string(),
216 arguments: fc["args"].clone(),
217 });
218 }
219 }
220 }
221
222 let usage = parse_usage(&resp.body["usageMetadata"]);
223
224 let message = if tool_calls.is_empty() {
225 Message::ai(text)
226 } else {
227 Message::ai_with_tool_calls(text, tool_calls)
228 };
229
230 Ok(ChatResponse { message, usage })
231}
232
233fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
234 if resp.status == 429 {
235 let msg = resp.body["error"]["message"]
236 .as_str()
237 .unwrap_or("rate limited")
238 .to_string();
239 return Err(SynapticError::RateLimit(msg));
240 }
241 if resp.status >= 400 {
242 let msg = resp.body["error"]["message"]
243 .as_str()
244 .unwrap_or("unknown API error")
245 .to_string();
246 return Err(SynapticError::Model(format!(
247 "Gemini API error ({}): {}",
248 resp.status, msg
249 )));
250 }
251 Ok(())
252}
253
254fn parse_usage(usage: &Value) -> Option<TokenUsage> {
255 if usage.is_null() {
256 return None;
257 }
258 Some(TokenUsage {
259 input_tokens: usage["promptTokenCount"].as_u64().unwrap_or(0) as u32,
260 output_tokens: usage["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
261 total_tokens: usage["totalTokenCount"].as_u64().unwrap_or(0) as u32,
262 input_details: None,
263 output_details: None,
264 })
265}
266
267fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
268 let v: Value = serde_json::from_str(data).ok()?;
269 let parts = v["candidates"][0]["content"]["parts"]
270 .as_array()
271 .cloned()
272 .unwrap_or_default();
273
274 let mut content = String::new();
275 let mut tool_calls = Vec::new();
276
277 for part in &parts {
278 if let Some(t) = part["text"].as_str() {
279 content.push_str(t);
280 }
281 if let Some(fc) = part.get("functionCall") {
282 if let Some(name) = fc["name"].as_str() {
283 tool_calls.push(ToolCall {
284 id: format!("gemini-{}", tool_calls.len()),
285 name: name.to_string(),
286 arguments: fc["args"].clone(),
287 });
288 }
289 }
290 }
291
292 let usage = parse_usage(&v["usageMetadata"]);
293
294 Some(AIMessageChunk {
295 content,
296 tool_calls,
297 usage,
298 ..Default::default()
299 })
300}
301
302#[async_trait]
303impl ChatModel for GeminiChatModel {
304 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
305 let provider_req = self.build_request(&request, false);
306 let resp = self.backend.send(provider_req).await?;
307 parse_response(&resp)
308 }
309
310 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
311 Box::pin(async_stream::stream! {
312 let provider_req = self.build_request(&request, true);
313 let byte_stream = self.backend.send_stream(provider_req).await;
314
315 let byte_stream = match byte_stream {
316 Ok(s) => s,
317 Err(e) => {
318 yield Err(e);
319 return;
320 }
321 };
322
323 use eventsource_stream::Eventsource;
324 use futures::StreamExt;
325
326 let mut event_stream = byte_stream
327 .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
328 .eventsource();
329
330 while let Some(event) = event_stream.next().await {
331 match event {
332 Ok(ev) => {
333 if let Some(chunk) = parse_stream_chunk(&ev.data) {
334 yield Ok(chunk);
335 }
336 }
337 Err(e) => {
338 yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
339 break;
340 }
341 }
342 }
343 })
344 }
345}