// ai_lib_rust/drivers/gemini.rs
//! Google Gemini generateContent API driver — implements the request/response
//! format conversions specific to the Gemini API.
//!
//! Key differences from OpenAI-style APIs:
//! - Uses `contents` instead of `messages`, with `parts` instead of `content`.
//! - Roles: `user` and `model` (not `assistant`). System uses `system_instruction`.
//! - `generationConfig` wraps temperature, max_tokens (→ `maxOutputTokens`), etc.
//! - Response: `candidates[0].content.parts[0].text`.
//! - Streaming uses NDJSON with the same structure (each line is a full response).
//! - API key is passed as `?key=` query parameter, not in headers.

use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;

use crate::error::Error;
use crate::protocol::v2::capabilities::Capability;
use crate::protocol::v2::manifest::ApiStyle;
use crate::protocol::ProtocolError;
use crate::types::events::StreamingEvent;
use crate::types::message::{Message, MessageContent, MessageRole};

use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};

24/// Google Gemini generateContent API driver.
25#[derive(Debug)]
26pub struct GeminiDriver {
27    provider_id: String,
28    capabilities: Vec<Capability>,
29}
30
31impl GeminiDriver {
32    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
33        Self {
34            provider_id: provider_id.into(),
35            capabilities,
36        }
37    }
38
39    /// Separate system instructions from conversation contents.
40    /// Gemini uses `system_instruction` as a top-level field.
41    fn split_messages(messages: &[Message]) -> (Option<Value>, Vec<Value>) {
42        let mut system_parts: Vec<String> = Vec::new();
43        let mut contents: Vec<Value> = Vec::new();
44
45        for m in messages {
46            match m.role {
47                MessageRole::System => {
48                    if let MessageContent::Text(ref s) = m.content {
49                        system_parts.push(s.clone());
50                    }
51                }
52                _ => {
53                    let role = match m.role {
54                        MessageRole::User => "user",
55                        MessageRole::Assistant => "model",
56                        MessageRole::System => unreachable!(),
57                    };
58                    let parts = Self::content_to_parts(&m.content);
59                    contents.push(serde_json::json!({
60                        "role": role,
61                        "parts": parts,
62                    }));
63                }
64            }
65        }
66
67        let system_instruction = if system_parts.is_empty() {
68            None
69        } else {
70            Some(serde_json::json!({
71                "parts": [{ "text": system_parts.join("\n\n") }]
72            }))
73        };
74
75        (system_instruction, contents)
76    }
77
78    /// Convert MessageContent to Gemini `parts` array.
79    fn content_to_parts(content: &MessageContent) -> Value {
80        match content {
81            MessageContent::Text(s) => {
82                serde_json::json!([{ "text": s }])
83            }
84            MessageContent::Blocks(_) => {
85                // For multimodal blocks, delegate to serde (needs further
86                // transformation for Gemini's inline_data format in Sprint 3).
87                serde_json::to_value(content).unwrap_or(Value::Null)
88            }
89        }
90    }
91}
92
93#[async_trait]
94impl ProviderDriver for GeminiDriver {
95    fn provider_id(&self) -> &str {
96        &self.provider_id
97    }
98
99    fn api_style(&self) -> ApiStyle {
100        ApiStyle::GeminiGenerate
101    }
102
103    fn build_request(
104        &self,
105        messages: &[Message],
106        _model: &str,
107        temperature: Option<f64>,
108        max_tokens: Option<u32>,
109        _stream: bool,
110        extra: Option<&Value>,
111    ) -> Result<DriverRequest, Error> {
112        let (system_instruction, contents) = Self::split_messages(messages);
113
114        let mut body = serde_json::json!({
115            "contents": contents,
116        });
117
118        if let Some(sys) = system_instruction {
119            body["system_instruction"] = sys;
120        }
121
122        // Gemini uses `generationConfig` for parameters
123        let mut gen_config = serde_json::json!({});
124        if let Some(t) = temperature {
125            gen_config["temperature"] = serde_json::json!(t);
126        }
127        if let Some(mt) = max_tokens {
128            gen_config["maxOutputTokens"] = serde_json::json!(mt);
129        }
130        if gen_config != serde_json::json!({}) {
131            body["generationConfig"] = gen_config;
132        }
133
134        if let Some(ext) = extra {
135            if let Value::Object(map) = ext {
136                for (k, v) in map {
137                    body[k] = v.clone();
138                }
139            }
140        }
141
142        Ok(DriverRequest {
143            url: String::new(), // URL includes model and :generateContent / :streamGenerateContent
144            method: "POST".into(),
145            headers: HashMap::new(),
146            body,
147            stream: _stream,
148        })
149    }
150
151    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
152        // Gemini: { candidates: [{ content: { parts: [{text: "..."}] }, finishReason }], usageMetadata }
153        let content = body
154            .pointer("/candidates/0/content/parts/0/text")
155            .and_then(|v| v.as_str())
156            .map(String::from);
157
158        let finish_reason = body
159            .pointer("/candidates/0/finishReason")
160            .and_then(|v| v.as_str())
161            .map(|r| match r {
162                "STOP" => "stop".to_string(),
163                "MAX_TOKENS" => "length".to_string(),
164                "SAFETY" => "content_filter".to_string(),
165                "RECITATION" => "content_filter".to_string(),
166                other => other.to_lowercase(),
167            });
168
169        let usage = body.get("usageMetadata").map(|u| UsageInfo {
170            prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0),
171            completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0),
172            total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0),
173        });
174
175        // Gemini tool calls: functionCall parts
176        let tool_calls: Vec<Value> = body
177            .pointer("/candidates/0/content/parts")
178            .and_then(|p| p.as_array())
179            .map(|parts| {
180                parts
181                    .iter()
182                    .filter(|p| p.get("functionCall").is_some())
183                    .cloned()
184                    .collect()
185            })
186            .unwrap_or_default();
187
188        Ok(DriverResponse {
189            content,
190            finish_reason,
191            usage,
192            tool_calls,
193            raw: body.clone(),
194        })
195    }
196
197    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
198        if data.trim().is_empty() {
199            return Ok(None);
200        }
201
202        // Gemini streaming returns NDJSON — each line is a full generateContent response
203        let v: Value = serde_json::from_str(data).map_err(|e| {
204            Error::Protocol(ProtocolError::ValidationError(format!(
205                "Failed to parse Gemini stream: {}",
206                e
207            )))
208        })?;
209
210        // Check for error
211        if let Some(error) = v.get("error") {
212            return Ok(Some(StreamingEvent::StreamError {
213                error: error.clone(),
214                event_id: None,
215            }));
216        }
217
218        // Content delta
219        if let Some(text) = v.pointer("/candidates/0/content/parts/0/text").and_then(|t| t.as_str())
220        {
221            if !text.is_empty() {
222                return Ok(Some(StreamingEvent::PartialContentDelta {
223                    content: text.to_string(),
224                    sequence_id: None,
225                }));
226            }
227        }
228
229        // Finish reason
230        if let Some(reason) = v
231            .pointer("/candidates/0/finishReason")
232            .and_then(|r| r.as_str())
233        {
234            if reason != "STOP" || v.pointer("/candidates/0/content/parts/0/text").is_none() {
235                return Ok(Some(StreamingEvent::StreamEnd {
236                    finish_reason: Some(match reason {
237                        "STOP" => "stop".to_string(),
238                        "MAX_TOKENS" => "length".to_string(),
239                        other => other.to_lowercase(),
240                    }),
241                }));
242            }
243        }
244
245        Ok(None)
246    }
247
248    fn supported_capabilities(&self) -> &[Capability] {
249        &self.capabilities
250    }
251
252    fn is_stream_done(&self, _data: &str) -> bool {
253        // Gemini uses NDJSON, stream ends when connection closes.
254        // Individual chunks may contain finishReason but the stream itself
255        // has no sentinel like OpenAI's [DONE].
256        false
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// System messages become `system_instruction`; only non-system messages
    /// land in `contents`.
    #[test]
    fn test_gemini_system_instruction() {
        let msgs = vec![
            Message::system("Be concise."),
            Message::user("Explain Rust."),
        ];
        let (sys, contents) = GeminiDriver::split_messages(&msgs);
        let sys = sys.expect("system message should produce system_instruction");
        assert_eq!(sys["parts"][0]["text"].as_str().unwrap(), "Be concise.");
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    /// `assistant` maps to Gemini's `model` role; `user` stays `user`.
    #[test]
    fn test_gemini_role_mapping() {
        let msgs = vec![
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::user("How are you?"),
        ];
        let (_, contents) = GeminiDriver::split_messages(&msgs);
        let roles: Vec<_> = contents.iter().map(|c| c["role"].clone()).collect();
        assert_eq!(roles, vec!["user", "model", "user"]);
    }

    /// temperature / max_tokens are wrapped in `generationConfig` with
    /// Gemini's field names.
    #[test]
    fn test_gemini_build_request() {
        let driver = GeminiDriver::new("google", vec![Capability::Text]);
        let messages = vec![Message::user("Hello")];
        let req = driver
            .build_request(&messages, "gemini-2.0-flash", Some(0.5), Some(2048), false, None)
            .unwrap();
        let config = &req.body["generationConfig"];
        assert_eq!(config["temperature"], 0.5);
        assert_eq!(config["maxOutputTokens"], 2048);
    }

    /// Text, normalized finish reason, and usage counts are extracted from a
    /// full generateContent response.
    #[test]
    fn test_gemini_parse_response() {
        let driver = GeminiDriver::new("google", vec![]);
        let body = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": "Hi!"}], "role": "model" },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Hi!"));
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.unwrap().total_tokens, 8);
    }

    /// A streaming NDJSON line with text becomes a PartialContentDelta.
    #[test]
    fn test_gemini_parse_stream_delta() {
        let driver = GeminiDriver::new("google", vec![]);
        let line = r#"{"candidates":[{"content":{"parts":[{"text":"World"}],"role":"model"}}]}"#;
        match driver.parse_stream_event(line).unwrap() {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "World");
            }
            other => panic!("Expected PartialContentDelta, got {:?}", other),
        }
    }

    /// SAFETY is normalized to the cross-provider `content_filter` reason.
    #[test]
    fn test_gemini_finish_reason_normalization() {
        let driver = GeminiDriver::new("google", vec![]);
        let body = serde_json::json!({
            "candidates": [{
                "content": { "parts": [{"text": ""}], "role": "model" },
                "finishReason": "SAFETY"
            }]
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.finish_reason.as_deref(), Some("content_filter"));
    }
}