1use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14
15use crate::error::Error;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::v2::manifest::ApiStyle;
18use crate::protocol::ProtocolError;
19use crate::types::events::StreamingEvent;
20use crate::types::message::{Message, MessageContent, MessageRole};
21
22use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
23
24#[derive(Debug)]
26pub struct GeminiDriver {
27 provider_id: String,
28 capabilities: Vec<Capability>,
29}
30
31impl GeminiDriver {
32 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
33 Self {
34 provider_id: provider_id.into(),
35 capabilities,
36 }
37 }
38
39 fn split_messages(messages: &[Message]) -> (Option<Value>, Vec<Value>) {
42 let mut system_parts: Vec<String> = Vec::new();
43 let mut contents: Vec<Value> = Vec::new();
44
45 for m in messages {
46 match m.role {
47 MessageRole::System => {
48 if let MessageContent::Text(ref s) = m.content {
49 system_parts.push(s.clone());
50 }
51 }
52 MessageRole::Tool => {
53 if let (Some(ref id), MessageContent::Text(ref s)) =
56 (&m.tool_call_id, &m.content)
57 {
58 contents.push(serde_json::json!({
59 "role": "user",
60 "parts": [{ "functionResponse": { "name": id, "response": { "result": s } } }],
61 }));
62 }
63 }
64 _ => {
65 let role = match m.role {
66 MessageRole::User => "user",
67 MessageRole::Assistant => "model",
68 MessageRole::System => unreachable!(),
69 MessageRole::Tool => unreachable!(),
70 };
71 let parts = Self::content_to_parts(&m.content);
72 contents.push(serde_json::json!({
73 "role": role,
74 "parts": parts,
75 }));
76 }
77 }
78 }
79
80 let system_instruction = if system_parts.is_empty() {
81 None
82 } else {
83 Some(serde_json::json!({
84 "parts": [{ "text": system_parts.join("\n\n") }]
85 }))
86 };
87
88 (system_instruction, contents)
89 }
90
91 fn content_to_parts(content: &MessageContent) -> Value {
93 match content {
94 MessageContent::Text(s) => {
95 serde_json::json!([{ "text": s }])
96 }
97 MessageContent::Blocks(_) => {
98 serde_json::to_value(content).unwrap_or(Value::Null)
101 }
102 }
103 }
104}
105
106#[async_trait]
107impl ProviderDriver for GeminiDriver {
108 fn provider_id(&self) -> &str {
109 &self.provider_id
110 }
111
112 fn api_style(&self) -> ApiStyle {
113 ApiStyle::GeminiGenerate
114 }
115
116 fn build_request(
117 &self,
118 messages: &[Message],
119 _model: &str,
120 temperature: Option<f64>,
121 max_tokens: Option<u32>,
122 _stream: bool,
123 extra: Option<&Value>,
124 ) -> Result<DriverRequest, Error> {
125 let (system_instruction, contents) = Self::split_messages(messages);
126
127 let mut body = serde_json::json!({
128 "contents": contents,
129 });
130
131 if let Some(sys) = system_instruction {
132 body["system_instruction"] = sys;
133 }
134
135 let mut gen_config = serde_json::json!({});
137 if let Some(t) = temperature {
138 gen_config["temperature"] = serde_json::json!(t);
139 }
140 if let Some(mt) = max_tokens {
141 gen_config["maxOutputTokens"] = serde_json::json!(mt);
142 }
143 if gen_config != serde_json::json!({}) {
144 body["generationConfig"] = gen_config;
145 }
146
147 if let Some(ext) = extra {
148 if let Value::Object(map) = ext {
149 for (k, v) in map {
150 body[k] = v.clone();
151 }
152 }
153 }
154
155 Ok(DriverRequest {
156 url: String::new(), method: "POST".into(),
158 headers: HashMap::new(),
159 body,
160 stream: _stream,
161 })
162 }
163
164 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
165 let content = body
167 .pointer("/candidates/0/content/parts/0/text")
168 .and_then(|v| v.as_str())
169 .map(String::from);
170
171 let finish_reason = body
172 .pointer("/candidates/0/finishReason")
173 .and_then(|v| v.as_str())
174 .map(|r| match r {
175 "STOP" => "stop".to_string(),
176 "MAX_TOKENS" => "length".to_string(),
177 "SAFETY" => "content_filter".to_string(),
178 "RECITATION" => "content_filter".to_string(),
179 other => other.to_lowercase(),
180 });
181
182 let usage = body.get("usageMetadata").map(|u| UsageInfo {
183 prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0),
184 completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0),
185 total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0),
186 });
187
188 let tool_calls: Vec<Value> = body
190 .pointer("/candidates/0/content/parts")
191 .and_then(|p| p.as_array())
192 .map(|parts| {
193 parts
194 .iter()
195 .filter(|p| p.get("functionCall").is_some())
196 .cloned()
197 .collect()
198 })
199 .unwrap_or_default();
200
201 Ok(DriverResponse {
202 content,
203 finish_reason,
204 usage,
205 tool_calls,
206 raw: body.clone(),
207 })
208 }
209
210 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
211 if data.trim().is_empty() {
212 return Ok(None);
213 }
214
215 let v: Value = serde_json::from_str(data).map_err(|e| {
217 Error::Protocol(ProtocolError::ValidationError(format!(
218 "Failed to parse Gemini stream: {}",
219 e
220 )))
221 })?;
222
223 if let Some(error) = v.get("error") {
225 return Ok(Some(StreamingEvent::StreamError {
226 error: error.clone(),
227 event_id: None,
228 }));
229 }
230
231 if let Some(text) = v.pointer("/candidates/0/content/parts/0/text").and_then(|t| t.as_str())
233 {
234 if !text.is_empty() {
235 return Ok(Some(StreamingEvent::PartialContentDelta {
236 content: text.to_string(),
237 sequence_id: None,
238 }));
239 }
240 }
241
242 if let Some(reason) = v
244 .pointer("/candidates/0/finishReason")
245 .and_then(|r| r.as_str())
246 {
247 if reason != "STOP" || v.pointer("/candidates/0/content/parts/0/text").is_none() {
248 return Ok(Some(StreamingEvent::StreamEnd {
249 finish_reason: Some(match reason {
250 "STOP" => "stop".to_string(),
251 "MAX_TOKENS" => "length".to_string(),
252 other => other.to_lowercase(),
253 }),
254 }));
255 }
256 }
257
258 Ok(None)
259 }
260
261 fn supported_capabilities(&self) -> &[Capability] {
262 &self.capabilities
263 }
264
265 fn is_stream_done(&self, _data: &str) -> bool {
266 false
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Driver with an empty capability list, for tests that only exercise parsing.
    fn bare_driver() -> GeminiDriver {
        GeminiDriver::new("google", vec![])
    }

    /// A system message is lifted into the system instruction and removed from the turns.
    #[test]
    fn test_gemini_system_instruction() {
        let conversation = [
            Message::system("Be concise."),
            Message::user("Explain Rust."),
        ];

        let (system_instruction, turns) = GeminiDriver::split_messages(&conversation);

        assert!(system_instruction.is_some());
        let instruction = system_instruction.unwrap();
        assert_eq!(
            instruction["parts"][0]["text"].as_str().unwrap(),
            "Be concise."
        );
        assert_eq!(turns.len(), 1);
        assert_eq!(turns[0]["role"], "user");
    }

    /// User turns stay `user`; assistant turns are renamed to `model`.
    #[test]
    fn test_gemini_role_mapping() {
        let conversation = [
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::user("How are you?"),
        ];

        let (_, turns) = GeminiDriver::split_messages(&conversation);

        let expected_roles = ["user", "model", "user"];
        for (index, role) in expected_roles.iter().enumerate() {
            assert_eq!(turns[index]["role"], *role);
        }
    }

    /// Temperature and token limit land in `generationConfig`.
    #[test]
    fn test_gemini_build_request() {
        let driver = GeminiDriver::new("google", vec![Capability::Text]);
        let conversation = vec![Message::user("Hello")];

        let request = driver
            .build_request(
                &conversation,
                "gemini-2.0-flash",
                Some(0.5),
                Some(2048),
                false,
                None,
            )
            .unwrap();

        let generation_config = &request.body["generationConfig"];
        assert_eq!(generation_config["temperature"], 0.5);
        assert_eq!(generation_config["maxOutputTokens"], 2048);
    }

    /// Text, finish reason and usage counters are extracted from a full response.
    #[test]
    fn test_gemini_parse_response() {
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": "Hi!"}], "role": "model" },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });

        let parsed = bare_driver().parse_response(&payload).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hi!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 8);
    }

    /// A streamed chunk carrying text becomes a partial content delta.
    #[test]
    fn test_gemini_parse_stream_delta() {
        let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"World"}],"role":"model"}}]}"#;

        let event = bare_driver().parse_stream_event(chunk).unwrap();

        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = event {
            assert_eq!(content, "World");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// `SAFETY` is normalized to `content_filter` in non-streaming responses.
    #[test]
    fn test_gemini_finish_reason_normalization() {
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": ""}], "role": "model" },
                "finishReason": "SAFETY"
            }]
        });

        let parsed = bare_driver().parse_response(&payload).unwrap();

        assert_eq!(parsed.finish_reason.as_deref(), Some("content_filter"));
    }
}