Skip to main content

ai_lib_core/drivers/
gemini.rs

//! Gemini Generate API driver — implements Google Gemini's request/response format conversion.
2//!
3//! Google Gemini generateContent API driver. Key differences:
4//! - Uses `contents` instead of `messages`, with `parts` instead of `content`.
5//! - Roles: `user` and `model` (not `assistant`). System uses `system_instruction`.
6//! - `generationConfig` wraps temperature, max_tokens (→ `maxOutputTokens`), etc.
7//! - Response: `candidates[0].content.parts[0].text`.
8//! - Streaming uses NDJSON with the same structure (each line is a full response).
9//! - API key is passed as `?key=` query parameter, not in headers.
10
11use async_trait::async_trait;
12use serde_json::Value;
13use std::collections::HashMap;
14
15use crate::error::Error;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::v2::manifest::ApiStyle;
18use crate::protocol::ProtocolError;
19use crate::types::events::StreamingEvent;
20use crate::types::message::{Message, MessageContent, MessageRole};
21
22use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
23
/// Google Gemini generateContent API driver.
///
/// Stateless adapter that translates the crate's provider-agnostic messages
/// into Gemini's `contents` / `generationConfig` wire format (see module docs
/// for the format differences).
#[derive(Debug)]
pub struct GeminiDriver {
    // Identifier returned by `ProviderDriver::provider_id` (e.g. "google").
    provider_id: String,
    // Capability list advertised via `supported_capabilities`.
    capabilities: Vec<Capability>,
}
30
31impl GeminiDriver {
32    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
33        Self {
34            provider_id: provider_id.into(),
35            capabilities,
36        }
37    }
38
39    /// Separate system instructions from conversation contents.
40    /// Gemini uses `system_instruction` as a top-level field.
41    fn split_messages(messages: &[Message]) -> (Option<Value>, Vec<Value>) {
42        let mut system_parts: Vec<String> = Vec::new();
43        let mut contents: Vec<Value> = Vec::new();
44
45        for m in messages {
46            match m.role {
47                MessageRole::System => {
48                    if let MessageContent::Text(ref s) = m.content {
49                        system_parts.push(s.clone());
50                    }
51                }
52                MessageRole::Tool => {
53                    // Gemini: function_response requires name (from original call) and response.
54                    // We use tool_call_id as name hint; response is the content.
55                    if let (Some(ref id), MessageContent::Text(ref s)) =
56                        (&m.tool_call_id, &m.content)
57                    {
58                        contents.push(serde_json::json!({
59                            "role": "user",
60                            "parts": [{ "functionResponse": { "name": id, "response": { "result": s } } }],
61                        }));
62                    }
63                }
64                _ => {
65                    let role = match m.role {
66                        MessageRole::User => "user",
67                        MessageRole::Assistant => "model",
68                        MessageRole::System => unreachable!(),
69                        MessageRole::Tool => unreachable!(),
70                    };
71                    let parts = Self::content_to_parts(&m.content);
72                    contents.push(serde_json::json!({
73                        "role": role,
74                        "parts": parts,
75                    }));
76                }
77            }
78        }
79
80        let system_instruction = if system_parts.is_empty() {
81            None
82        } else {
83            Some(serde_json::json!({
84                "parts": [{ "text": system_parts.join("\n\n") }]
85            }))
86        };
87
88        (system_instruction, contents)
89    }
90
91    /// Convert MessageContent to Gemini `parts` array.
92    fn content_to_parts(content: &MessageContent) -> Value {
93        match content {
94            MessageContent::Text(s) => {
95                serde_json::json!([{ "text": s }])
96            }
97            MessageContent::Blocks(_) => {
98                // For multimodal blocks, delegate to serde (needs further
99                // transformation for Gemini's inline_data format in Sprint 3).
100                serde_json::to_value(content).unwrap_or(Value::Null)
101            }
102        }
103    }
104}
105
106#[async_trait]
107impl ProviderDriver for GeminiDriver {
108    fn provider_id(&self) -> &str {
109        &self.provider_id
110    }
111
112    fn api_style(&self) -> ApiStyle {
113        ApiStyle::GeminiGenerate
114    }
115
116    fn build_request(
117        &self,
118        messages: &[Message],
119        _model: &str,
120        temperature: Option<f64>,
121        max_tokens: Option<u32>,
122        _stream: bool,
123        extra: Option<&Value>,
124    ) -> Result<DriverRequest, Error> {
125        let (system_instruction, contents) = Self::split_messages(messages);
126
127        let mut body = serde_json::json!({
128            "contents": contents,
129        });
130
131        if let Some(sys) = system_instruction {
132            body["system_instruction"] = sys;
133        }
134
135        // Gemini uses `generationConfig` for parameters
136        let mut gen_config = serde_json::json!({});
137        if let Some(t) = temperature {
138            gen_config["temperature"] = serde_json::json!(t);
139        }
140        if let Some(mt) = max_tokens {
141            gen_config["maxOutputTokens"] = serde_json::json!(mt);
142        }
143        if gen_config != serde_json::json!({}) {
144            body["generationConfig"] = gen_config;
145        }
146
147        if let Some(Value::Object(map)) = extra {
148            for (k, v) in map {
149                body[k] = v.clone();
150            }
151        }
152
153        Ok(DriverRequest {
154            url: String::new(), // URL includes model and :generateContent / :streamGenerateContent
155            method: "POST".into(),
156            headers: HashMap::new(),
157            body,
158            stream: _stream,
159        })
160    }
161
162    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
163        // Gemini: { candidates: [{ content: { parts: [{text: "..."}] }, finishReason }], usageMetadata }
164        let content = body
165            .pointer("/candidates/0/content/parts/0/text")
166            .and_then(|v| v.as_str())
167            .map(String::from);
168
169        let finish_reason = body
170            .pointer("/candidates/0/finishReason")
171            .and_then(|v| v.as_str())
172            .map(|r| match r {
173                "STOP" => "stop".to_string(),
174                "MAX_TOKENS" => "length".to_string(),
175                "SAFETY" => "content_filter".to_string(),
176                "RECITATION" => "content_filter".to_string(),
177                other => other.to_lowercase(),
178            });
179
180        let usage = body.get("usageMetadata").map(|u| UsageInfo {
181            prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0),
182            completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0),
183            total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0),
184            reasoning_tokens: None,
185            cache_read_tokens: None,
186            cache_creation_tokens: None,
187        });
188
189        // Gemini tool calls: functionCall parts
190        let tool_calls: Vec<Value> = body
191            .pointer("/candidates/0/content/parts")
192            .and_then(|p| p.as_array())
193            .map(|parts| {
194                parts
195                    .iter()
196                    .filter(|p| p.get("functionCall").is_some())
197                    .cloned()
198                    .collect()
199            })
200            .unwrap_or_default();
201
202        Ok(DriverResponse {
203            content,
204            finish_reason,
205            usage,
206            tool_calls,
207            raw: body.clone(),
208        })
209    }
210
211    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
212        if data.trim().is_empty() {
213            return Ok(None);
214        }
215
216        // Gemini streaming returns NDJSON — each line is a full generateContent response
217        let v: Value = serde_json::from_str(data).map_err(|e| {
218            Error::Protocol(ProtocolError::ValidationError(format!(
219                "Failed to parse Gemini stream: {}",
220                e
221            )))
222        })?;
223
224        // Check for error
225        if let Some(error) = v.get("error") {
226            return Ok(Some(StreamingEvent::StreamError {
227                error: error.clone(),
228                event_id: None,
229            }));
230        }
231
232        // Content delta
233        if let Some(text) = v
234            .pointer("/candidates/0/content/parts/0/text")
235            .and_then(|t| t.as_str())
236        {
237            if !text.is_empty() {
238                return Ok(Some(StreamingEvent::PartialContentDelta {
239                    content: text.to_string(),
240                    sequence_id: None,
241                }));
242            }
243        }
244
245        // Finish reason
246        if let Some(reason) = v
247            .pointer("/candidates/0/finishReason")
248            .and_then(|r| r.as_str())
249        {
250            if reason != "STOP" || v.pointer("/candidates/0/content/parts/0/text").is_none() {
251                return Ok(Some(StreamingEvent::StreamEnd {
252                    finish_reason: Some(match reason {
253                        "STOP" => "stop".to_string(),
254                        "MAX_TOKENS" => "length".to_string(),
255                        other => other.to_lowercase(),
256                    }),
257                }));
258            }
259        }
260
261        Ok(None)
262    }
263
264    fn supported_capabilities(&self) -> &[Capability] {
265        &self.capabilities
266    }
267
268    fn is_stream_done(&self, _data: &str) -> bool {
269        // Gemini uses NDJSON, stream ends when connection closes.
270        // Individual chunks may contain finishReason but the stream itself
271        // has no sentinel like OpenAI's [DONE].
272        false
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// System messages are lifted out of `contents` into `system_instruction`.
    #[test]
    fn test_gemini_system_instruction() {
        let conversation = vec![
            Message::system("Be concise."),
            Message::user("Explain Rust."),
        ];
        let (system, contents) = GeminiDriver::split_messages(&conversation);
        let system = system.expect("system instruction should be extracted");
        assert_eq!(system["parts"][0]["text"], "Be concise.");
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    /// `assistant` maps to Gemini's `model` role; `user` passes through.
    #[test]
    fn test_gemini_role_mapping() {
        let turns = vec![
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::user("How are you?"),
        ];
        let (_, contents) = GeminiDriver::split_messages(&turns);
        for (entry, expected) in contents.iter().zip(["user", "model", "user"]) {
            assert_eq!(entry["role"], expected);
        }
    }

    /// Temperature / max_tokens land inside `generationConfig` with Gemini names.
    #[test]
    fn test_gemini_build_request() {
        let driver = GeminiDriver::new("google", vec![Capability::Text]);
        let request = driver
            .build_request(&[Message::user("Hello")], "gemini-2.0-flash", Some(0.5), Some(2048), false, None)
            .unwrap();
        let config = &request.body["generationConfig"];
        assert_eq!(config["temperature"], 0.5);
        assert_eq!(config["maxOutputTokens"], 2048);
    }

    /// Text, finish reason, and usage counters are all extracted and normalized.
    #[test]
    fn test_gemini_parse_response() {
        let driver = GeminiDriver::new("google", vec![]);
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": "Hi!"}], "role": "model" },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });
        let parsed = driver.parse_response(&payload).unwrap();
        assert_eq!(parsed.content.as_deref(), Some("Hi!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 8);
    }

    /// An NDJSON chunk with text becomes a PartialContentDelta event.
    #[test]
    fn test_gemini_parse_stream_delta() {
        let driver = GeminiDriver::new("google", vec![]);
        let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"World"}],"role":"model"}}]}"#;
        match driver.parse_stream_event(chunk).unwrap() {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "World")
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    /// SAFETY is normalized to the cross-provider `content_filter` reason.
    #[test]
    fn test_gemini_finish_reason_normalization() {
        let driver = GeminiDriver::new("google", vec![]);
        let payload = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": ""}], "role": "model" },
                "finishReason": "SAFETY"
            }]
        });
        let parsed = driver.parse_response(&payload).unwrap();
        assert_eq!(parsed.finish_reason.as_deref(), Some("content_filter"));
    }
}