1use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14
15use crate::error::Error;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::v2::manifest::ApiStyle;
18use crate::protocol::ProtocolError;
19use crate::types::events::StreamingEvent;
20use crate::types::message::{Message, MessageContent, MessageRole};
21
22use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
23
24#[derive(Debug)]
26pub struct GeminiDriver {
27 provider_id: String,
28 capabilities: Vec<Capability>,
29}
30
31impl GeminiDriver {
32 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
33 Self {
34 provider_id: provider_id.into(),
35 capabilities,
36 }
37 }
38
39 fn split_messages(messages: &[Message]) -> (Option<Value>, Vec<Value>) {
42 let mut system_parts: Vec<String> = Vec::new();
43 let mut contents: Vec<Value> = Vec::new();
44
45 for m in messages {
46 match m.role {
47 MessageRole::System => {
48 if let MessageContent::Text(ref s) = m.content {
49 system_parts.push(s.clone());
50 }
51 }
52 MessageRole::Tool => {
53 if let (Some(ref id), MessageContent::Text(ref s)) =
56 (&m.tool_call_id, &m.content)
57 {
58 contents.push(serde_json::json!({
59 "role": "user",
60 "parts": [{ "functionResponse": { "name": id, "response": { "result": s } } }],
61 }));
62 }
63 }
64 _ => {
65 let role = match m.role {
66 MessageRole::User => "user",
67 MessageRole::Assistant => "model",
68 MessageRole::System => unreachable!(),
69 MessageRole::Tool => unreachable!(),
70 };
71 let parts = Self::content_to_parts(&m.content);
72 contents.push(serde_json::json!({
73 "role": role,
74 "parts": parts,
75 }));
76 }
77 }
78 }
79
80 let system_instruction = if system_parts.is_empty() {
81 None
82 } else {
83 Some(serde_json::json!({
84 "parts": [{ "text": system_parts.join("\n\n") }]
85 }))
86 };
87
88 (system_instruction, contents)
89 }
90
91 fn content_to_parts(content: &MessageContent) -> Value {
93 match content {
94 MessageContent::Text(s) => {
95 serde_json::json!([{ "text": s }])
96 }
97 MessageContent::Blocks(_) => {
98 serde_json::to_value(content).unwrap_or(Value::Null)
101 }
102 }
103 }
104}
105
106#[async_trait]
107impl ProviderDriver for GeminiDriver {
108 fn provider_id(&self) -> &str {
109 &self.provider_id
110 }
111
112 fn api_style(&self) -> ApiStyle {
113 ApiStyle::GeminiGenerate
114 }
115
116 fn build_request(
117 &self,
118 messages: &[Message],
119 _model: &str,
120 temperature: Option<f64>,
121 max_tokens: Option<u32>,
122 _stream: bool,
123 extra: Option<&Value>,
124 ) -> Result<DriverRequest, Error> {
125 let (system_instruction, contents) = Self::split_messages(messages);
126
127 let mut body = serde_json::json!({
128 "contents": contents,
129 });
130
131 if let Some(sys) = system_instruction {
132 body["system_instruction"] = sys;
133 }
134
135 let mut gen_config = serde_json::json!({});
137 if let Some(t) = temperature {
138 gen_config["temperature"] = serde_json::json!(t);
139 }
140 if let Some(mt) = max_tokens {
141 gen_config["maxOutputTokens"] = serde_json::json!(mt);
142 }
143 if gen_config != serde_json::json!({}) {
144 body["generationConfig"] = gen_config;
145 }
146
147 if let Some(Value::Object(map)) = extra {
148 for (k, v) in map {
149 body[k] = v.clone();
150 }
151 }
152
153 Ok(DriverRequest {
154 url: String::new(), method: "POST".into(),
156 headers: HashMap::new(),
157 body,
158 stream: _stream,
159 })
160 }
161
162 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
163 let content = body
165 .pointer("/candidates/0/content/parts/0/text")
166 .and_then(|v| v.as_str())
167 .map(String::from);
168
169 let finish_reason = body
170 .pointer("/candidates/0/finishReason")
171 .and_then(|v| v.as_str())
172 .map(|r| match r {
173 "STOP" => "stop".to_string(),
174 "MAX_TOKENS" => "length".to_string(),
175 "SAFETY" => "content_filter".to_string(),
176 "RECITATION" => "content_filter".to_string(),
177 other => other.to_lowercase(),
178 });
179
180 let usage = body.get("usageMetadata").map(|u| UsageInfo {
181 prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0),
182 completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0),
183 total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0),
184 reasoning_tokens: None,
185 cache_read_tokens: None,
186 cache_creation_tokens: None,
187 });
188
189 let tool_calls: Vec<Value> = body
191 .pointer("/candidates/0/content/parts")
192 .and_then(|p| p.as_array())
193 .map(|parts| {
194 parts
195 .iter()
196 .filter(|p| p.get("functionCall").is_some())
197 .cloned()
198 .collect()
199 })
200 .unwrap_or_default();
201
202 Ok(DriverResponse {
203 content,
204 finish_reason,
205 usage,
206 tool_calls,
207 raw: body.clone(),
208 })
209 }
210
211 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
212 if data.trim().is_empty() {
213 return Ok(None);
214 }
215
216 let v: Value = serde_json::from_str(data).map_err(|e| {
218 Error::Protocol(ProtocolError::ValidationError(format!(
219 "Failed to parse Gemini stream: {}",
220 e
221 )))
222 })?;
223
224 if let Some(error) = v.get("error") {
226 return Ok(Some(StreamingEvent::StreamError {
227 error: error.clone(),
228 event_id: None,
229 }));
230 }
231
232 if let Some(text) = v
234 .pointer("/candidates/0/content/parts/0/text")
235 .and_then(|t| t.as_str())
236 {
237 if !text.is_empty() {
238 return Ok(Some(StreamingEvent::PartialContentDelta {
239 content: text.to_string(),
240 sequence_id: None,
241 }));
242 }
243 }
244
245 if let Some(reason) = v
247 .pointer("/candidates/0/finishReason")
248 .and_then(|r| r.as_str())
249 {
250 if reason != "STOP" || v.pointer("/candidates/0/content/parts/0/text").is_none() {
251 return Ok(Some(StreamingEvent::StreamEnd {
252 finish_reason: Some(match reason {
253 "STOP" => "stop".to_string(),
254 "MAX_TOKENS" => "length".to_string(),
255 other => other.to_lowercase(),
256 }),
257 }));
258 }
259 }
260
261 Ok(None)
262 }
263
264 fn supported_capabilities(&self) -> &[Capability] {
265 &self.capabilities
266 }
267
268 fn is_stream_done(&self, _data: &str) -> bool {
269 false
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A system message is lifted into the system instruction and kept out of
    /// the conversational turns.
    #[test]
    fn test_gemini_system_instruction() {
        let conversation = [
            Message::system("Be concise."),
            Message::user("Explain Rust."),
        ];
        let (system_instruction, turns) = GeminiDriver::split_messages(&conversation);

        assert!(system_instruction.is_some());
        let system_instruction = system_instruction.unwrap();
        assert_eq!(
            system_instruction["parts"][0]["text"].as_str().unwrap(),
            "Be concise."
        );
        assert_eq!(turns.len(), 1);
        assert_eq!(turns[0]["role"], "user");
    }

    /// User turns keep the `user` role; assistant turns are sent as `model`.
    #[test]
    fn test_gemini_role_mapping() {
        let conversation = [
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::user("How are you?"),
        ];
        let (_, turns) = GeminiDriver::split_messages(&conversation);

        let expected_roles = ["user", "model", "user"];
        for (index, role) in expected_roles.iter().enumerate() {
            assert_eq!(turns[index]["role"], *role);
        }
    }

    /// Temperature and token limit land in `generationConfig`.
    #[test]
    fn test_gemini_build_request() {
        let driver = GeminiDriver::new("google", vec![Capability::Text]);
        let conversation = vec![Message::user("Hello")];

        let request = driver
            .build_request(
                &conversation,
                "gemini-2.0-flash",
                Some(0.5),
                Some(2048),
                false,
                None,
            )
            .unwrap();

        let generation_config = &request.body["generationConfig"];
        assert_eq!(generation_config["temperature"], 0.5);
        assert_eq!(generation_config["maxOutputTokens"], 2048);
    }

    /// Text, finish reason and usage are extracted from a complete response.
    #[test]
    fn test_gemini_parse_response() {
        let driver = GeminiDriver::new("google", vec![]);
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": "Hi!"}], "role": "model" },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hi!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 8);
    }

    /// A streamed chunk with text becomes a content delta.
    #[test]
    fn test_gemini_parse_stream_delta() {
        let driver = GeminiDriver::new("google", vec![]);
        let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"World"}],"role":"model"}}]}"#;

        let parsed = driver.parse_stream_event(chunk).unwrap();

        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = parsed {
            assert_eq!(content, "World");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// `SAFETY` is reported as the normalized `content_filter` reason.
    #[test]
    fn test_gemini_finish_reason_normalization() {
        let driver = GeminiDriver::new("google", vec![]);
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": ""}], "role": "model" },
                "finishReason": "SAFETY"
            }]
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.finish_reason.as_deref(), Some("content_filter"));
    }
}