Skip to main content

oxi_ai/providers/
google.rs

1//! Google Generative AI provider (Gemini API)
2
3use async_trait::async_trait;
4use futures::stream::StreamExt;
5use futures::Stream;
6use reqwest::Client;
7use std::pin::Pin;
8
9use super::google_shared::{
10    build_request_body, convert_messages, convert_tools, create_error_message, parse_google_events,
11};
12use super::openai::split_complete_lines;
13use super::shared_client;
14use super::{Provider, ProviderError, ProviderEvent, StreamOptions};
15use crate::{Api, Context, Model, StopReason};
16
17/// Google Generative AI provider
18#[derive(Clone)]
19pub struct GoogleProvider {
20    client: &'static Client,
21    api_key: Option<String>,
22}
23
24impl GoogleProvider {
25    /// Create a new Google provider without an API key.
26    ///
27    /// API keys are resolved at request time via auth.json or StreamOptions.
28    pub fn new() -> Self {
29        Self {
30            client: shared_client(),
31            api_key: None,
32        }
33    }
34}
35
36impl Default for GoogleProvider {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42#[async_trait]
43impl Provider for GoogleProvider {
44    async fn stream(
45        &self,
46        model: &Model,
47        context: &Context,
48        options: Option<StreamOptions>,
49    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
50        let options = options.unwrap_or_default();
51
52        // Get API key
53        let api_key = options
54            .api_key
55            .as_ref()
56            .or(self.api_key.as_ref())
57            .ok_or_else(|| ProviderError::MissingApiKey)?;
58
59        // Build the request URL (without key - uses header instead for security)
60        let model_id = &model.id;
61        let url = format!(
62            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse",
63            model_id
64        );
65
66        // Build contents using shared conversion
67        let contents = convert_messages(context)?;
68
69        // Build tools using shared conversion
70        let tools_json = convert_tools(&context.tools, false);
71
72        // Build request body using shared helper
73        let mut body = build_request_body(
74            &contents,
75            context.system_prompt.as_deref(),
76            tools_json.as_ref(),
77            options.temperature,
78            options.max_tokens,
79        );
80
81        // ── Google thinking config (via ProviderOptions) ────────────────
82        // When the model supports reasoning, apply thinkingConfig from
83        // provider_options.google. Mirrors opencode's Gemini thinking support.
84        if model.reasoning {
85            let google_opts = options
86                .provider_options
87                .as_ref()
88                .and_then(|po| po.google.as_ref());
89
90            let mut thinking_config = serde_json::json!({});
91
92            // Include thoughts (always true for reasoning models)
93            thinking_config["includeThoughts"] = serde_json::json!(true);
94
95            if let Some(opts) = google_opts {
96                if let Some(ref level) = opts.thinking_level {
97                    thinking_config["thinkingLevel"] = serde_json::json!(level);
98                }
99                if let Some(budget) = opts.thinking_budget {
100                    thinking_config["thinkingBudget"] = serde_json::json!(budget);
101                }
102            } else if let Some(ref level) = options.thinking_level {
103                // Fallback: derive from thinking_level
104                if let Some(effort) = level.as_str() {
105                    thinking_config["thinkingLevel"] = serde_json::json!(effort);
106                }
107            }
108
109            // Merge into generationConfig
110            if let Some(gc) = body.get_mut("generationConfig") {
111                if let serde_json::Value::Object(map) = gc {
112                    map.insert("thinkingConfig".to_string(), thinking_config);
113                }
114            } else {
115                body["generationConfig"] = serde_json::json!({
116                    "thinkingConfig": thinking_config,
117                });
118            }
119        }
120
121        // Make request with API key in header (not URL query param)
122        let response = self
123            .client
124            .post(&url)
125            .header("x-goog-api-key", api_key)
126            .header("Content-Type", "application/json")
127            .json(&body)
128            .send()
129            .await
130            .map_err(ProviderError::RequestFailed)?;
131
132        if !response.status().is_success() {
133            let status = response.status();
134            let body: String = response.text().await.unwrap_or_default();
135            return Err(ProviderError::HttpError(status.as_u16(), body));
136        }
137
138        // Create event stream — use split_complete_lines (like OpenAI provider)
139        // to handle UTF-8 boundaries safely.  Google SSE lines can be split
140        // across HTTP chunks at arbitrary byte boundaries.
141        let model_name = model.id.clone();
142
143        let stream = response
144            .bytes_stream()
145            .scan(
146                Vec::new(), // pending_bytes
147                move |pending_bytes, chunk: Result<bytes::Bytes, reqwest::Error>| {
148                    let events = match chunk {
149                        Ok(bytes) => {
150                            let mut combined =
151                                Vec::with_capacity(pending_bytes.len() + bytes.len());
152                            combined.extend_from_slice(pending_bytes);
153                            combined.extend_from_slice(&bytes);
154                            let (text, trailing) = split_complete_lines(&combined);
155                            *pending_bytes = trailing;
156                            parse_google_events(
157                                &text,
158                                Api::GoogleGenerativeAi,
159                                "google",
160                                &model_name,
161                            )
162                        }
163                        Err(e) => vec![ProviderEvent::Error {
164                            reason: StopReason::Error,
165                            error: create_error_message(
166                                Api::GoogleGenerativeAi,
167                                "google",
168                                &e.to_string(),
169                            ),
170                        }],
171                    };
172                    async move { Some(futures::stream::iter(events)) }
173                },
174            )
175            .flatten();
176
177        Ok(Box::pin(stream))
178    }
179
180    fn name(&self) -> &str {
181        "google"
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, Message};

    /// The provider reports `"google"` as its name.
    #[test]
    fn test_google_provider_name() {
        assert_eq!(GoogleProvider::new().name(), "google");
    }

    /// A single user text message converts to one `user` content holding one
    /// text part.
    #[test]
    fn test_build_google_contents_with_text() {
        let mut context = Context::new();
        context.add_message(Message::user("Hello, world!"));

        let converted = convert_messages(&context).unwrap();
        assert_eq!(converted.len(), 1);

        let first = &converted[0];
        assert_eq!(first["role"], "user");
        assert_eq!(first["parts"][0]["text"], "Hello, world!");
    }

    /// One tool converts to a single entry under `functionDeclarations`.
    #[test]
    fn test_build_google_tools() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city name"
                }
            },
            "required": ["location"]
        });
        let tools = vec![crate::Tool::new(
            "get_weather",
            "Get weather for a location",
            schema,
        )];

        let converted = convert_tools(&tools, false).unwrap();
        let declarations = converted[0]["functionDeclarations"].as_array().unwrap();
        assert_eq!(declarations.len(), 1);
        assert_eq!(declarations[0]["name"], "get_weather");
    }

    /// A single SSE `data:` line carrying a text part yields at least one event.
    #[test]
    fn test_parse_google_events_basic_text() {
        let line = r#"data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}"#;
        let parsed = parse_google_events(
            line,
            Api::GoogleGenerativeAi,
            "google",
            "gemini-1.5-pro",
        );
        assert!(!parsed.is_empty());
    }

    /// Error messages carry the API, the provider name and an `Error` stop reason.
    #[test]
    fn test_create_error_message() {
        let message = create_error_message(
            Api::GoogleGenerativeAi,
            "google",
            "Something went wrong",
        );
        assert_eq!(message.stop_reason, StopReason::Error);
        assert_eq!(message.api, Api::GoogleGenerativeAi);
        assert_eq!(message.provider, "google");
    }
}
249}