oxi_ai/providers/
google.rs1use async_trait::async_trait;
4use futures::stream::StreamExt;
5use futures::Stream;
6use reqwest::Client;
7use std::pin::Pin;
8
9use super::google_shared::{
10 build_request_body, convert_messages, convert_tools, create_error_message, parse_google_events,
11};
12use super::openai::split_complete_lines;
13use super::shared_client;
14use super::{Provider, ProviderError, ProviderEvent, StreamOptions};
15use crate::{Api, Context, Model, StopReason};
16
17#[derive(Clone)]
19pub struct GoogleProvider {
20 client: &'static Client,
21 api_key: Option<String>,
22}
23
24impl GoogleProvider {
25 pub fn new() -> Self {
29 Self {
30 client: shared_client(),
31 api_key: None,
32 }
33 }
34}
35
36impl Default for GoogleProvider {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[async_trait]
43impl Provider for GoogleProvider {
44 async fn stream(
45 &self,
46 model: &Model,
47 context: &Context,
48 options: Option<StreamOptions>,
49 ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
50 let options = options.unwrap_or_default();
51
52 let api_key = options
54 .api_key
55 .as_ref()
56 .or(self.api_key.as_ref())
57 .ok_or_else(|| ProviderError::MissingApiKey)?;
58
59 let model_id = &model.id;
61 let url = format!(
62 "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse",
63 model_id
64 );
65
66 let contents = convert_messages(context)?;
68
69 let tools_json = convert_tools(&context.tools, false);
71
72 let mut body = build_request_body(
74 &contents,
75 context.system_prompt.as_deref(),
76 tools_json.as_ref(),
77 options.temperature,
78 options.max_tokens,
79 );
80
81 if model.reasoning {
85 let google_opts = options
86 .provider_options
87 .as_ref()
88 .and_then(|po| po.google.as_ref());
89
90 let mut thinking_config = serde_json::json!({});
91
92 thinking_config["includeThoughts"] = serde_json::json!(true);
94
95 if let Some(opts) = google_opts {
96 if let Some(ref level) = opts.thinking_level {
97 thinking_config["thinkingLevel"] = serde_json::json!(level);
98 }
99 if let Some(budget) = opts.thinking_budget {
100 thinking_config["thinkingBudget"] = serde_json::json!(budget);
101 }
102 } else if let Some(ref level) = options.thinking_level {
103 if let Some(effort) = level.as_str() {
105 thinking_config["thinkingLevel"] = serde_json::json!(effort);
106 }
107 }
108
109 if let Some(gc) = body.get_mut("generationConfig") {
111 if let serde_json::Value::Object(map) = gc {
112 map.insert("thinkingConfig".to_string(), thinking_config);
113 }
114 } else {
115 body["generationConfig"] = serde_json::json!({
116 "thinkingConfig": thinking_config,
117 });
118 }
119 }
120
121 let response = self
123 .client
124 .post(&url)
125 .header("x-goog-api-key", api_key)
126 .header("Content-Type", "application/json")
127 .json(&body)
128 .send()
129 .await
130 .map_err(ProviderError::RequestFailed)?;
131
132 if !response.status().is_success() {
133 let status = response.status();
134 let body: String = response.text().await.unwrap_or_default();
135 return Err(ProviderError::HttpError(status.as_u16(), body));
136 }
137
138 let model_name = model.id.clone();
142
143 let stream = response
144 .bytes_stream()
145 .scan(
146 Vec::new(), move |pending_bytes, chunk: Result<bytes::Bytes, reqwest::Error>| {
148 let events = match chunk {
149 Ok(bytes) => {
150 let mut combined =
151 Vec::with_capacity(pending_bytes.len() + bytes.len());
152 combined.extend_from_slice(pending_bytes);
153 combined.extend_from_slice(&bytes);
154 let (text, trailing) = split_complete_lines(&combined);
155 *pending_bytes = trailing;
156 parse_google_events(
157 &text,
158 Api::GoogleGenerativeAi,
159 "google",
160 &model_name,
161 )
162 }
163 Err(e) => vec![ProviderEvent::Error {
164 reason: StopReason::Error,
165 error: create_error_message(
166 Api::GoogleGenerativeAi,
167 "google",
168 &e.to_string(),
169 ),
170 }],
171 };
172 async move { Some(futures::stream::iter(events)) }
173 },
174 )
175 .flatten();
176
177 Ok(Box::pin(stream))
178 }
179
180 fn name(&self) -> &str {
181 "google"
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Message;

    #[test]
    fn test_google_provider_name() {
        assert_eq!(GoogleProvider::new().name(), "google");
    }

    #[test]
    fn test_build_google_contents_with_text() {
        let mut ctx = Context::new();
        ctx.add_message(Message::user("Hello, world!"));

        let contents = convert_messages(&ctx).unwrap();
        assert_eq!(contents.len(), 1);

        let first = &contents[0];
        assert_eq!(first["role"], "user");
        assert_eq!(first["parts"][0]["text"], "Hello, world!");
    }

    #[test]
    fn test_build_google_tools() {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city name"
                }
            },
            "required": ["location"]
        });
        let tools = vec![crate::Tool::new(
            "get_weather",
            "Get weather for a location",
            parameters,
        )];

        let tools_json = convert_tools(&tools, false).unwrap();
        let declarations = tools_json[0]["functionDeclarations"]
            .as_array()
            .unwrap();

        assert_eq!(declarations.len(), 1);
        assert_eq!(declarations[0]["name"], "get_weather");
    }

    #[test]
    fn test_parse_google_events_basic_text() {
        let events = parse_google_events(
            r#"data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}"#,
            Api::GoogleGenerativeAi,
            "google",
            "gemini-1.5-pro",
        );
        assert!(!events.is_empty());
    }

    #[test]
    fn test_create_error_message() {
        let message =
            create_error_message(Api::GoogleGenerativeAi, "google", "Something went wrong");

        assert_eq!(message.provider, "google");
        assert_eq!(message.api, Api::GoogleGenerativeAi);
        assert_eq!(message.stop_reason, StopReason::Error);
    }
}