Skip to main content

phi_core/provider/
google.rs

1//! Google Generative AI (Gemini) provider.
2//!
3//! Uses the `streamGenerateContent` endpoint with SSE streaming.
4//! API key is passed as a query parameter.
5/*
6ARCHITECTURE: GoogleProvider — Gemini's unique API shape
7
8Google's Generative AI API differs from OpenAI and Anthropic in several ways:
9
10URL & Auth:
11  Endpoint: POST .../v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}
12  Auth: API key in query parameter `key=`, NOT as a header.
13  `?alt=sse` enables SSE streaming (otherwise returns JSON in a single response).
14
15Request body:
16  - `contents` (not `messages`) — array of `{role, parts}` objects
17  - `role`: "user" | "model" (NOT "assistant")
18  - Tool results use `functionResponse` parts (not a separate message role)
19  - System instructions are separate: `systemInstruction: {parts: [{text: "..."}]}`
20  - Tools: `tools: [{functionDeclarations: [...]}]` (nested differently)
21
22Response events:
  Google wraps each SSE `data:` payload in a complete response-shaped JSON object:
    {"candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"}}]}
  Each payload carries only the NEW text since the previous one (a delta, not a
  cumulative snapshot) — we accumulate text by appending every chunk's parts.
  There is no explicit "start/delta/stop" protocol; the stream simply ends.
27
28Error handling:
29  Google returns non-streaming errors as HTTP 4xx with JSON body.
30  We check status before starting SSE parsing (see `if !response.status().is_success()`).
31
32RUST QUIRK: `.send().await` vs `EventSource::new(request)`
33  Most providers use `EventSource::new(request)` which handles the HTTP request
34  internally. Google's provider calls `.send().await` first to check the HTTP status
35  code before creating the SSE stream. This allows early error detection without
36  starting the SSE event loop.
37*/
38
39use super::traits::*;
40use crate::types::*;
41use async_trait::async_trait;
42use futures::StreamExt;
43use serde::Deserialize;
44use tokio::sync::mpsc;
45use tracing::{debug, warn};
46
/// Provider for Google's Generative AI (Gemini) streaming endpoint.
///
/// Unit struct — it holds no state. All behavior lives in the `StreamProvider`
/// impl; per-request settings arrive through the `StreamConfig` passed to `stream`.
pub struct GoogleProvider;
49
50#[async_trait]
51impl StreamProvider for GoogleProvider {
52    fn provider_id(&self) -> &str {
53        "google"
54    }
55
56    async fn stream(
57        &self,
58        config: StreamConfig, // REQUEST — api_key sent as `?key=` query param; uses `contents[]` not `messages[]`
59        tx: mpsc::UnboundedSender<StreamEvent>, // OBSERVER — two-phase: send() checks status first, then streams SSE
60        cancel: tokio_util::sync::CancellationToken, // ABORT — races against SSE stream
61    ) -> Result<Message, ProviderError> {
62        let model_config = &config.model_config;
63        // Resolve via CredentialProvider when set, else use the static `api_key`.
64        let api_key = model_config.resolve_api_key().await?;
65
66        let base_url = &model_config.base_url;
67        // Google embeds the API key as a query parameter (not a header like other providers)
68        let url = format!(
69            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
70            base_url, config.model_config.id, api_key
71        );
72
73        let body = build_request_body(&config);
74        debug!("Google GenAI request: model={}", config.model_config.id);
75
76        let client = reqwest::Client::new();
77        let mut request = client.post(&url).header("content-type", "application/json");
78
79        for (k, v) in &model_config.headers {
80            request = request.header(k, v);
81        }
82
83        /*
84        ARCHITECTURE: Two-phase request for early error detection
85
86        Unlike other providers (which pass the request to EventSource immediately),
87        we call `.send().await` first to check the HTTP status code before SSE parsing.
88
89        Why? Google returns non-streaming errors (invalid API key, bad model name,
90        context overflow) as plain HTTP 4xx responses with a JSON error body.
91        If we passed the request directly to EventSource, we'd get an SSE parse
92        error instead of a proper error classification.
93
94        Phase 1: send() → check status → error if 4xx/5xx
95        Phase 2: convert response to SSE stream (reqwest's `bytes_stream`)
96
97        RUST QUIRK: `.map_err(|e| ProviderError::Network(e.to_string()))?`
98          Converts `reqwest::Error` → `ProviderError::Network(String)` then propagates.
99          The `?` after `.await` chains error propagation through the `async fn`.
100        */
101        // Google streams JSON chunks separated by newlines, not SSE.
102        // With alt=sse, it does use SSE format.
103        let response = request
104            .json(&body)
105            .send()
106            .await
107            .map_err(|e| ProviderError::Network(e.to_string()))?;
108
109        if !response.status().is_success() {
110            let status = response.status();
111            let body = response.text().await.unwrap_or_default();
112            return Err(ProviderError::classify(
113                status.as_u16(),
114                &format!("Google API error {}: {}", status, body),
115            ));
116        }
117
118        let mut content: Vec<Content> = Vec::new();
119        let mut usage = Usage::default();
120        let mut stop_reason = StopReason::Stop;
121
122        let _ = tx.send(StreamEvent::Start);
123
124        // Parse SSE stream
125        let mut stream = response.bytes_stream();
126        let mut buffer = String::new();
127
128        loop {
129            tokio::select! {
130                _ = cancel.cancelled() => {
131                    return Err(ProviderError::Cancelled);
132                }
133                chunk = stream.next() => {
134                    match chunk {
135                        None => break,
136                        Some(Err(e)) => {
137                            warn!("Google stream error: {}", e);
138                            break;
139                        }
140                        Some(Ok(bytes)) => {
141                            buffer.push_str(&String::from_utf8_lossy(&bytes));
142
143                            // Process complete SSE events
144                            while let Some(pos) = buffer.find("\n\n") {
145                                let event_str = buffer[..pos].to_string();
146                                buffer = buffer[pos + 2..].to_string();
147
148                                // Parse SSE data line
149                                let data = event_str
150                                    .lines()
151                                    .find(|l| l.starts_with("data: "))
152                                    .map(|l| &l[6..])
153                                    .unwrap_or("");
154
155                                if data.is_empty() {
156                                    continue;
157                                }
158
159                                let chunk: GoogleChunk = match serde_json::from_str(data) {
160                                    Ok(c) => c,
161                                    Err(e) => {
162                                        debug!("Failed to parse Google chunk: {}", e);
163                                        continue;
164                                    }
165                                };
166
167                                // Process candidates
168                                for candidate in &chunk.candidates.unwrap_or_default() {
169                                    if let Some(c) = &candidate.content {
170                                        for part in &c.parts {
171                                            if let Some(text) = &part.text {
172                                                let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
173                                                let idx = match text_idx {
174                                                    Some(i) => i,
175                                                    None => {
176                                                        content.push(Content::Text { text: String::new() });
177                                                        content.len() - 1
178                                                    }
179                                                };
180                                                if let Some(Content::Text { text: t }) = content.get_mut(idx) {
181                                                    t.push_str(text);
182                                                }
183                                                let _ = tx.send(StreamEvent::TextDelta {
184                                                    content_index: idx,
185                                                    delta: text.clone(),
186                                                });
187                                            }
188                                            if let Some(fc) = &part.function_call {
189                                                let id = format!("google-fc-{}", content.len());
190                                                let args = fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default()));
191                                                let idx = content.len();
192                                                content.push(Content::ToolCall {
193                                                    id: id.clone(),
194                                                    name: fc.name.clone(),
195                                                    arguments: args,
196                                                });
197                                                let _ = tx.send(StreamEvent::ToolCallStart {
198                                                    content_index: idx,
199                                                    id,
200                                                    name: fc.name.clone(),
201                                                });
202                                                let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
203                                                stop_reason = StopReason::ToolUse;
204                                            }
205                                        }
206                                    }
207                                    if let Some(reason) = &candidate.finish_reason {
208                                        stop_reason = match reason.as_str() {
209                                            "STOP" => StopReason::Stop,
210                                            "MAX_TOKENS" | "RECITATION" => StopReason::Length,
211                                            _ => StopReason::Stop,
212                                        };
213                                    }
214                                }
215
216                                // Process usage
217                                if let Some(u) = &chunk.usage_metadata {
218                                    usage.input = u.prompt_token_count.unwrap_or(0);
219                                    usage.output = u.candidates_token_count.unwrap_or(0);
220                                    usage.total_tokens = u.total_token_count.unwrap_or(0);
221                                    usage.cache_read = u.cached_content_token_count.unwrap_or(0);
222                                }
223                            }
224                        }
225                    }
226                }
227            }
228        }
229
230        let message = Message::Assistant {
231            content,
232            stop_reason,
233            model: config.model_config.id.clone(),
234            provider: model_config.provider.clone(),
235            usage,
236            timestamp: now_ms(),
237            error_message: None,
238        };
239
240        let _ = tx.send(StreamEvent::Done {
241            message: message.clone(),
242        });
243        Ok(message)
244    }
245}
246
247fn build_request_body(config: &StreamConfig) -> serde_json::Value {
248    let mut contents: Vec<serde_json::Value> = Vec::new();
249
250    for msg in &config.messages {
251        match msg {
252            Message::User { content, .. } => {
253                let parts = content_to_google_parts(content);
254                contents.push(serde_json::json!({
255                    "role": "user",
256                    "parts": parts,
257                }));
258            }
259            Message::Assistant { content, .. } => {
260                let parts = content_to_google_parts(content);
261                contents.push(serde_json::json!({
262                    "role": "model",
263                    "parts": parts,
264                }));
265            }
266            Message::ToolResult {
267                tool_call_id: _,
268                tool_name,
269                content,
270                ..
271            } => {
272                let text = content
273                    .iter()
274                    .find_map(|c| match c {
275                        Content::Text { text } => Some(text.clone()),
276                        _ => None,
277                    })
278                    .unwrap_or_default();
279
280                let mut parts = vec![serde_json::json!({
281                    "functionResponse": {
282                        "name": tool_name,
283                        "response": {"result": text},
284                    }
285                })];
286
287                // Append image parts if present
288                for c in content {
289                    if let Content::Image { data, mime_type } = c {
290                        parts.push(serde_json::json!({
291                            "inlineData": {"mimeType": mime_type, "data": data},
292                        }));
293                    }
294                }
295
296                contents.push(serde_json::json!({
297                    "role": "user",
298                    "parts": parts,
299                }));
300            }
301        }
302    }
303
304    let mut body = serde_json::json!({
305        "contents": contents,
306    });
307
308    if !config.system_prompt.is_empty() {
309        body["systemInstruction"] = serde_json::json!({
310            "parts": [{"text": config.system_prompt}],
311        });
312    }
313
314    let mut generation_config = serde_json::json!({});
315    if let Some(max) = config.max_tokens {
316        generation_config["maxOutputTokens"] = serde_json::json!(max);
317    }
318    if let Some(temp) = config.temperature {
319        generation_config["temperature"] = serde_json::json!(temp);
320    }
321    // Structured-output wiring. Gemini uses `responseMimeType` + `responseSchema`
322    // inside `generationConfig`, NOT a top-level field. JsonSchema is forwarded
323    // verbatim; the API supports Draft 2020-12 with Gemini-specific extensions.
324    match &config.response_format {
325        ResponseFormat::Text => {} // default; omit
326        ResponseFormat::JsonObject => {
327            generation_config["responseMimeType"] = serde_json::json!("application/json");
328        }
329        ResponseFormat::JsonSchema { schema, .. } => {
330            generation_config["responseMimeType"] = serde_json::json!("application/json");
331            generation_config["responseSchema"] = schema.clone();
332        }
333    }
334    if generation_config != serde_json::json!({}) {
335        body["generationConfig"] = generation_config;
336    }
337
338    if !config.tools.is_empty() {
339        let declarations: Vec<serde_json::Value> = config
340            .tools
341            .iter()
342            .map(|t| {
343                serde_json::json!({
344                    "name": t.name,
345                    "description": t.description,
346                    "parameters": t.parameters,
347                })
348            })
349            .collect();
350        body["tools"] = serde_json::json!([{
351            "functionDeclarations": declarations,
352        }]);
353    }
354
355    body
356}
357
358fn content_to_google_parts(content: &[Content]) -> Vec<serde_json::Value> {
359    content
360        .iter()
361        .filter_map(|c| match c {
362            Content::Text { text } => Some(serde_json::json!({"text": text})),
363            Content::Image { data, mime_type } => Some(serde_json::json!({
364                "inlineData": {"mimeType": mime_type, "data": data},
365            })),
366            Content::ToolCall {
367                name, arguments, ..
368            } => Some(serde_json::json!({
369                "functionCall": {"name": name, "args": arguments},
370            })),
371            Content::Thinking { .. } => None,
372        })
373        .collect()
374}
375
376// Google API response types
377#[derive(Deserialize)]
378struct GoogleChunk {
379    #[serde(default)]
380    candidates: Option<Vec<GoogleCandidate>>,
381    #[serde(default, rename = "usageMetadata")]
382    usage_metadata: Option<GoogleUsageMetadata>,
383}
384
385#[derive(Deserialize)]
386struct GoogleCandidate {
387    #[serde(default)]
388    content: Option<GoogleContent>,
389    #[serde(default, rename = "finishReason")]
390    finish_reason: Option<String>,
391}
392
/// The `content` object of a candidate; only its `parts` array is read.
#[derive(Deserialize)]
struct GoogleContent {
    // Defaults to an empty list when the key is absent from the payload.
    #[serde(default)]
    parts: Vec<GooglePart>,
}
398
399#[derive(Deserialize)]
400struct GooglePart {
401    #[serde(default)]
402    text: Option<String>,
403    #[serde(default, rename = "functionCall")]
404    function_call: Option<GoogleFunctionCall>,
405}
406
/// A function call requested by the model: the function name (required) and
/// its JSON arguments, if any.
#[derive(Deserialize)]
struct GoogleFunctionCall {
    name: String,
    // `None` when the payload carries no `args`; the stream handler substitutes an empty object.
    #[serde(default)]
    args: Option<serde_json::Value>,
}
413
414#[derive(Deserialize)]
415struct GoogleUsageMetadata {
416    #[serde(default, rename = "promptTokenCount")]
417    prompt_token_count: Option<u64>,
418    #[serde(default, rename = "candidatesTokenCount")]
419    candidates_token_count: Option<u64>,
420    #[serde(default, rename = "totalTokenCount")]
421    total_token_count: Option<u64>,
422    #[serde(default, rename = "cachedContentTokenCount")]
423    cached_content_token_count: Option<u64>,
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal request maps the user message, system prompt and generation
    /// settings onto Gemini's `contents` / `systemInstruction` / `generationConfig`.
    #[test]
    fn test_build_google_request() {
        let model = crate::provider::ModelConfig::google(
            "gemini-2.0-flash",
            "Gemini Flash",
            "test",
        );
        let config = StreamConfig {
            model_config: model,
            system_prompt: "Be helpful".into(),
            messages: vec![Message::user("Hello")],
            tools: vec![],
            thinking_level: ThinkingLevel::Off,
            max_tokens: Some(1024),
            temperature: Some(0.7),
            cache_config: CacheConfig::default(),
            response_format: ResponseFormat::Text,
        };

        let request = build_request_body(&config);

        assert!(request["contents"].is_array());
        assert_eq!(request["contents"][0]["role"], "user");
        assert!(request["systemInstruction"].is_object());

        let generation = &request["generationConfig"];
        assert_eq!(generation["maxOutputTokens"], 1024);
        // Temperature may round-trip through a narrower float; compare with tolerance.
        let temperature = generation["temperature"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 0.01);
    }

    /// A text block becomes a single `{"text": ..}` part.
    #[test]
    fn test_content_to_google_parts_text() {
        let blocks = vec![Content::Text {
            text: "hello".into(),
        }];

        let converted = content_to_google_parts(&blocks);

        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0]["text"], "hello");
    }

    /// A tool call becomes a `functionCall` part carrying the tool name.
    #[test]
    fn test_content_to_google_parts_tool_call() {
        let blocks = vec![Content::ToolCall {
            id: "tc-1".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"command": "ls"}),
        }];

        let converted = content_to_google_parts(&blocks);

        assert_eq!(converted[0]["functionCall"]["name"], "bash");
    }
}