1use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14
15use crate::error::Error;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::v2::manifest::ApiStyle;
18use crate::protocol::ProtocolError;
19use crate::types::events::StreamingEvent;
20use crate::types::message::{Message, MessageContent, MessageRole};
21
22use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
23
24#[derive(Debug)]
26pub struct GeminiDriver {
27 provider_id: String,
28 capabilities: Vec<Capability>,
29}
30
31impl GeminiDriver {
32 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
33 Self {
34 provider_id: provider_id.into(),
35 capabilities,
36 }
37 }
38
39 fn split_messages(messages: &[Message]) -> (Option<Value>, Vec<Value>) {
42 let mut system_parts: Vec<String> = Vec::new();
43 let mut contents: Vec<Value> = Vec::new();
44
45 for m in messages {
46 match m.role {
47 MessageRole::System => {
48 if let MessageContent::Text(ref s) = m.content {
49 system_parts.push(s.clone());
50 }
51 }
52 _ => {
53 let role = match m.role {
54 MessageRole::User => "user",
55 MessageRole::Assistant => "model",
56 MessageRole::System => unreachable!(),
57 };
58 let parts = Self::content_to_parts(&m.content);
59 contents.push(serde_json::json!({
60 "role": role,
61 "parts": parts,
62 }));
63 }
64 }
65 }
66
67 let system_instruction = if system_parts.is_empty() {
68 None
69 } else {
70 Some(serde_json::json!({
71 "parts": [{ "text": system_parts.join("\n\n") }]
72 }))
73 };
74
75 (system_instruction, contents)
76 }
77
78 fn content_to_parts(content: &MessageContent) -> Value {
80 match content {
81 MessageContent::Text(s) => {
82 serde_json::json!([{ "text": s }])
83 }
84 MessageContent::Blocks(_) => {
85 serde_json::to_value(content).unwrap_or(Value::Null)
88 }
89 }
90 }
91}
92
93#[async_trait]
94impl ProviderDriver for GeminiDriver {
95 fn provider_id(&self) -> &str {
96 &self.provider_id
97 }
98
99 fn api_style(&self) -> ApiStyle {
100 ApiStyle::GeminiGenerate
101 }
102
103 fn build_request(
104 &self,
105 messages: &[Message],
106 _model: &str,
107 temperature: Option<f64>,
108 max_tokens: Option<u32>,
109 _stream: bool,
110 extra: Option<&Value>,
111 ) -> Result<DriverRequest, Error> {
112 let (system_instruction, contents) = Self::split_messages(messages);
113
114 let mut body = serde_json::json!({
115 "contents": contents,
116 });
117
118 if let Some(sys) = system_instruction {
119 body["system_instruction"] = sys;
120 }
121
122 let mut gen_config = serde_json::json!({});
124 if let Some(t) = temperature {
125 gen_config["temperature"] = serde_json::json!(t);
126 }
127 if let Some(mt) = max_tokens {
128 gen_config["maxOutputTokens"] = serde_json::json!(mt);
129 }
130 if gen_config != serde_json::json!({}) {
131 body["generationConfig"] = gen_config;
132 }
133
134 if let Some(ext) = extra {
135 if let Value::Object(map) = ext {
136 for (k, v) in map {
137 body[k] = v.clone();
138 }
139 }
140 }
141
142 Ok(DriverRequest {
143 url: String::new(), method: "POST".into(),
145 headers: HashMap::new(),
146 body,
147 stream: _stream,
148 })
149 }
150
151 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
152 let content = body
154 .pointer("/candidates/0/content/parts/0/text")
155 .and_then(|v| v.as_str())
156 .map(String::from);
157
158 let finish_reason = body
159 .pointer("/candidates/0/finishReason")
160 .and_then(|v| v.as_str())
161 .map(|r| match r {
162 "STOP" => "stop".to_string(),
163 "MAX_TOKENS" => "length".to_string(),
164 "SAFETY" => "content_filter".to_string(),
165 "RECITATION" => "content_filter".to_string(),
166 other => other.to_lowercase(),
167 });
168
169 let usage = body.get("usageMetadata").map(|u| UsageInfo {
170 prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0),
171 completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0),
172 total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0),
173 });
174
175 let tool_calls: Vec<Value> = body
177 .pointer("/candidates/0/content/parts")
178 .and_then(|p| p.as_array())
179 .map(|parts| {
180 parts
181 .iter()
182 .filter(|p| p.get("functionCall").is_some())
183 .cloned()
184 .collect()
185 })
186 .unwrap_or_default();
187
188 Ok(DriverResponse {
189 content,
190 finish_reason,
191 usage,
192 tool_calls,
193 raw: body.clone(),
194 })
195 }
196
197 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
198 if data.trim().is_empty() {
199 return Ok(None);
200 }
201
202 let v: Value = serde_json::from_str(data).map_err(|e| {
204 Error::Protocol(ProtocolError::ValidationError(format!(
205 "Failed to parse Gemini stream: {}",
206 e
207 )))
208 })?;
209
210 if let Some(error) = v.get("error") {
212 return Ok(Some(StreamingEvent::StreamError {
213 error: error.clone(),
214 event_id: None,
215 }));
216 }
217
218 if let Some(text) = v.pointer("/candidates/0/content/parts/0/text").and_then(|t| t.as_str())
220 {
221 if !text.is_empty() {
222 return Ok(Some(StreamingEvent::PartialContentDelta {
223 content: text.to_string(),
224 sequence_id: None,
225 }));
226 }
227 }
228
229 if let Some(reason) = v
231 .pointer("/candidates/0/finishReason")
232 .and_then(|r| r.as_str())
233 {
234 if reason != "STOP" || v.pointer("/candidates/0/content/parts/0/text").is_none() {
235 return Ok(Some(StreamingEvent::StreamEnd {
236 finish_reason: Some(match reason {
237 "STOP" => "stop".to_string(),
238 "MAX_TOKENS" => "length".to_string(),
239 other => other.to_lowercase(),
240 }),
241 }));
242 }
243 }
244
245 Ok(None)
246 }
247
248 fn supported_capabilities(&self) -> &[Capability] {
249 &self.capabilities
250 }
251
252 fn is_stream_done(&self, _data: &str) -> bool {
253 false
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Driver with no advertised capabilities, sufficient for parsing tests.
    fn bare_driver() -> GeminiDriver {
        GeminiDriver::new("google", vec![])
    }

    /// A system message is lifted into `system_instruction` and removed from `contents`.
    #[test]
    fn test_gemini_system_instruction() {
        let transcript = [Message::system("Be concise."), Message::user("Explain Rust.")];
        let (system, contents) = GeminiDriver::split_messages(&transcript);

        let system = system.expect("a system message should yield system_instruction");
        assert_eq!(system["parts"][0]["text"].as_str(), Some("Be concise."));
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    /// Assistant turns are renamed to Gemini's "model"; user turns keep "user".
    #[test]
    fn test_gemini_role_mapping() {
        let transcript = [
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::user("How are you?"),
        ];
        let (_, contents) = GeminiDriver::split_messages(&transcript);

        let roles: Vec<Option<&str>> = contents.iter().map(|turn| turn["role"].as_str()).collect();
        assert_eq!(roles, [Some("user"), Some("model"), Some("user")]);
    }

    /// Sampling parameters land under `generationConfig` with Gemini's key names.
    #[test]
    fn test_gemini_build_request() {
        let driver = GeminiDriver::new("google", vec![Capability::Text]);
        let transcript = [Message::user("Hello")];

        let request = driver
            .build_request(&transcript, "gemini-2.0-flash", Some(0.5), Some(2048), false, None)
            .unwrap();

        let config = &request.body["generationConfig"];
        assert_eq!(config["temperature"], 0.5);
        assert_eq!(config["maxOutputTokens"], 2048);
    }

    /// Text, finish reason and usage are extracted from a complete response.
    #[test]
    fn test_gemini_parse_response() {
        let body = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": "Hi!"}], "role": "model" },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });

        let parsed = bare_driver().parse_response(&body).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hi!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        let usage = parsed.usage.expect("usageMetadata should be parsed");
        assert_eq!(usage.total_tokens, 8);
    }

    /// A chunk with text and no finish reason becomes a content delta.
    #[test]
    fn test_gemini_parse_stream_delta() {
        let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"World"}],"role":"model"}}]}"#;

        let delta = match bare_driver().parse_stream_event(chunk).unwrap() {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => content,
            _ => panic!("Expected PartialContentDelta"),
        };
        assert_eq!(delta, "World");
    }

    /// Gemini's SAFETY finish reason is normalized to "content_filter".
    #[test]
    fn test_gemini_finish_reason_normalization() {
        let body = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": ""}], "role": "model" },
                "finishReason": "SAFETY"
            }]
        });

        let parsed = bare_driver().parse_response(&body).unwrap();
        assert_eq!(parsed.finish_reason.as_deref(), Some("content_filter"));
    }
}
349}