Skip to main content

pi_ai/providers/
google.rs

//! Google Generative AI provider (`google-generative-ai`).
//!
//! Targets the v1beta `generativelanguage.googleapis.com` endpoint with the
//! `streamGenerateContent` method. Emits the unified `AssistantMessageEvent`
//! protocol like the other providers.
//!
//! The request asks for `alt=sse`, so the server frames each streamed
//! `GenerateContentResponse` (a JSON object carrying `candidates`) as its own
//! SSE `data:` record, without `event:` names. That framing is what lets
//! `eventsource-stream` split the stream; without `alt=sse` the same method
//! returns one JSON array instead.
10
11use std::collections::BTreeMap;
12
13use async_stream::stream;
14use async_trait::async_trait;
15use eventsource_stream::Eventsource;
16use futures::StreamExt;
17use serde::Deserialize;
18use serde_json::{json, Value};
19
20use crate::error::{Error, Result};
21use crate::providers::Provider;
22use crate::retry::{classify_status, parse_retry_after, with_retry, Attempt, RetryConfig};
23use crate::stream::AssistantMessageEventStream;
24use crate::types::{
25    now_ms, AssistantMessage, AssistantMessageEvent, Content, Context, Message, Model, StopReason,
26    StreamOptions, Usage,
27};
28
29#[derive(Deserialize, Debug)]
30struct Chunk {
31    #[serde(default)]
32    candidates: Vec<Candidate>,
33    #[serde(default)]
34    usage_metadata: Option<UsageMetadata>,
35    #[serde(default)]
36    model_version: Option<String>,
37}
38
39#[derive(Deserialize, Debug)]
40struct Candidate {
41    #[serde(default)]
42    content: Option<CandidateContent>,
43    #[serde(default)]
44    finish_reason: Option<String>,
45}
46
47#[derive(Deserialize, Debug)]
48struct CandidateContent {
49    #[serde(default)]
50    parts: Vec<Part>,
51}
52
53#[derive(Deserialize, Debug)]
54struct Part {
55    #[serde(default)]
56    text: Option<String>,
57    #[serde(default)]
58    function_call: Option<FunctionCall>,
59}
60
61#[derive(Deserialize, Debug)]
62struct FunctionCall {
63    #[serde(default)]
64    name: String,
65    #[serde(default)]
66    args: Value,
67}
68
69#[derive(Deserialize, Debug, Default)]
70struct UsageMetadata {
71    #[serde(default)]
72    prompt_token_count: u64,
73    #[serde(default)]
74    candidates_token_count: u64,
75    #[serde(default)]
76    total_token_count: u64,
77}
78
79fn convert_messages(messages: &[Message]) -> Vec<Value> {
80    let mut out: Vec<Value> = Vec::new();
81    for m in messages {
82        match m {
83            Message::User { content, .. } => {
84                let parts: Vec<Value> = content
85                    .iter()
86                    .filter_map(|c| c.as_text().map(|t| json!({"text": t})))
87                    .collect();
88                out.push(json!({"role": "user", "parts": parts}));
89            }
90            Message::Assistant(a) => {
91                let mut parts: Vec<Value> = Vec::new();
92                for c in &a.content {
93                    match c {
94                        Content::Text { text } => parts.push(json!({"text": text})),
95                        Content::ToolCall {
96                            name, arguments, ..
97                        } => {
98                            parts.push(json!({
99                                "functionCall": {"name": name, "args": arguments}
100                            }));
101                        }
102                        _ => {}
103                    }
104                }
105                out.push(json!({"role": "model", "parts": parts}));
106            }
107            Message::ToolResult(tr) => {
108                let text = tr
109                    .content
110                    .iter()
111                    .filter_map(|c| c.as_text().map(|s| s.to_string()))
112                    .collect::<Vec<_>>()
113                    .join("");
114                out.push(json!({
115                    "role": "user",
116                    "parts": [{
117                        "functionResponse": {
118                            "name": tr.tool_name,
119                            "response": {"output": text, "is_error": tr.is_error}
120                        }
121                    }]
122                }));
123            }
124        }
125    }
126    out
127}
128
129fn build_body(context: &Context, options: &StreamOptions) -> Value {
130    let mut body = json!({
131        "contents": convert_messages(&context.messages),
132    });
133    if let Some(sp) = &context.system_prompt {
134        body["systemInstruction"] = json!({"role": "system", "parts": [{"text": sp}]});
135    }
136    if let Some(t) = options.temperature {
137        body["generationConfig"] = json!({"temperature": t});
138    }
139    if !context.tools.is_empty() {
140        let decls: Vec<Value> = context
141            .tools
142            .iter()
143            .map(|t| {
144                json!({
145                    "name": t.name,
146                    "description": t.description,
147                    "parameters": t.parameters,
148                })
149            })
150            .collect();
151        body["tools"] = json!([{"functionDeclarations": decls}]);
152    }
153    body
154}
155
156pub struct GoogleProvider {
157    client: reqwest::Client,
158}
159
160impl GoogleProvider {
161    pub fn new() -> Self {
162        Self {
163            client: reqwest::Client::new(),
164        }
165    }
166}
167
168impl Default for GoogleProvider {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
174#[async_trait]
175impl Provider for GoogleProvider {
176    async fn stream(
177        &self,
178        model: &Model,
179        context: &Context,
180        options: &StreamOptions,
181    ) -> Result<AssistantMessageEventStream> {
182        let api_key = options
183            .api_key
184            .clone()
185            .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
186            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
187            .ok_or_else(|| Error::MissingApiKey("google".into()))?;
188        let base_url = options
189            .base_url
190            .clone()
191            .unwrap_or_else(|| model.base_url.clone());
192        let url = format!(
193            "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
194            base_url.trim_end_matches('/'),
195            model.id,
196            api_key,
197        );
198        let body = build_body(context, options);
199        let cancel = options.cancel.clone();
200        let extra_headers: BTreeMap<String, String> = options.headers.clone();
201
202        let resp = with_retry(&RetryConfig::default(), cancel.as_ref(), |_| {
203            let client = self.client.clone();
204            let url = url.clone();
205            let body = body.clone();
206            let extra_headers = extra_headers.clone();
207            async move {
208                let mut req = client
209                    .post(&url)
210                    .header("accept", "text/event-stream")
211                    .header("content-type", "application/json");
212                for (k, v) in extra_headers {
213                    req = req.header(k, v);
214                }
215                let r = match req.json(&body).send().await {
216                    Ok(r) => r,
217                    Err(e) => {
218                        return if e.is_timeout() || e.is_connect() {
219                            Attempt::Retry {
220                                error: Error::Http(e),
221                                retry_after: None,
222                            }
223                        } else {
224                            Attempt::Fatal(Error::Http(e))
225                        };
226                    }
227                };
228                let status = r.status();
229                if status.is_success() {
230                    return Attempt::Ok(r);
231                }
232                let retry_after = r
233                    .headers()
234                    .get("retry-after")
235                    .and_then(|v| v.to_str().ok())
236                    .and_then(parse_retry_after);
237                let body_text = r.text().await.unwrap_or_default();
238                let err = Error::ProviderError {
239                    status: status.as_u16(),
240                    body: body_text,
241                };
242                match classify_status(status.as_u16()) {
243                    Some(_) => Attempt::Retry {
244                        error: err,
245                        retry_after,
246                    },
247                    None => Attempt::Fatal(err),
248                }
249            }
250        })
251        .await?;
252
253        let api = model.api.clone();
254        let provider = model.provider.clone();
255        let model_id = model.id.clone();
256        let cancel_for_stream = cancel.clone();
257
258        let s = stream! {
259            yield Ok(AssistantMessageEvent::Start);
260            let mut sse = resp.bytes_stream().eventsource();
261
262            let mut text_buf = String::new();
263            let mut text_started = false;
264            let mut text_index: usize = 0;
265            let mut tool_blocks: Vec<(String, String, Value)> = Vec::new();
266            let mut stop = StopReason::Stop;
267            let mut usage = Usage::default();
268            let mut response_model: Option<String> = None;
269
270            while let Some(ev) = sse.next().await {
271                if let Some(c) = &cancel_for_stream {
272                    if c.is_cancelled() { yield Err(Error::Cancelled); return; }
273                }
274                let ev = match ev {
275                    Ok(e) => e,
276                    Err(e) => { yield Err(Error::InvalidResponse(format!("sse: {e}"))); return; }
277                };
278                if ev.data.is_empty() { continue; }
279                let chunk: Chunk = match serde_json::from_str(&ev.data) {
280                    Ok(c) => c,
281                    Err(_) => continue,
282                };
283                if let Some(m) = chunk.model_version { response_model = Some(m); }
284                if let Some(u) = chunk.usage_metadata {
285                    usage.input = u.prompt_token_count;
286                    usage.output = u.candidates_token_count;
287                    usage.total_tokens = u.total_token_count;
288                }
289                for cand in chunk.candidates {
290                    if let Some(reason) = cand.finish_reason {
291                        stop = match reason.as_str() {
292                            "STOP" => StopReason::Stop,
293                            "MAX_TOKENS" => StopReason::Length,
294                            _ => StopReason::Stop,
295                        };
296                    }
297                    if let Some(content) = cand.content {
298                        for part in content.parts {
299                            if let Some(t) = part.text {
300                                if !t.is_empty() {
301                                    if !text_started {
302                                        text_started = true;
303                                        yield Ok(AssistantMessageEvent::TextStart { content_index: text_index });
304                                    }
305                                    text_buf.push_str(&t);
306                                    yield Ok(AssistantMessageEvent::TextDelta { content_index: text_index, delta: t });
307                                }
308                            }
309                            if let Some(fc) = part.function_call {
310                                let id = format!("call_{}", tool_blocks.len() + 1);
311                                let block_index = text_index + if text_started { 1 } else { 0 } + tool_blocks.len();
312                                yield Ok(AssistantMessageEvent::ToolCallStart {
313                                    content_index: block_index,
314                                    id: id.clone(),
315                                    name: fc.name.clone(),
316                                });
317                                yield Ok(AssistantMessageEvent::ToolCallEnd {
318                                    content_index: block_index,
319                                    id: id.clone(),
320                                    name: fc.name.clone(),
321                                    arguments: fc.args.clone(),
322                                });
323                                if fc.finish_reason_set_to_tool_use() { stop = StopReason::ToolUse; }
324                                tool_blocks.push((id, fc.name, fc.args));
325                            }
326                        }
327                    }
328                }
329            }
330
331            if text_started {
332                yield Ok(AssistantMessageEvent::TextEnd { content_index: text_index, content: text_buf.clone() });
333                text_index += 1;
334            }
335            if !tool_blocks.is_empty() && stop == StopReason::Stop {
336                stop = StopReason::ToolUse;
337            }
338            let mut out_content: Vec<Content> = Vec::new();
339            if text_started {
340                out_content.push(Content::Text { text: text_buf });
341            }
342            for (id, name, args) in tool_blocks {
343                out_content.push(Content::ToolCall { id, name, arguments: args });
344            }
345            let _ = text_index;
346            let message = AssistantMessage {
347                content: out_content,
348                api,
349                provider,
350                model: response_model.unwrap_or(model_id),
351                usage,
352                stop_reason: stop,
353                error_message: None,
354                timestamp: now_ms(),
355            };
356            yield Ok(AssistantMessageEvent::Done { reason: stop, message });
357        };
358
359        Ok(s.boxed())
360    }
361}
362
363// Helper marker — Gemini doesn't signal tool use in finish_reason; treat any
364// function_call as implying ToolUse if no other stop reason is reported.
365impl FunctionCall {
366    fn finish_reason_set_to_tool_use(&self) -> bool {
367        true
368    }
369}