1use super::traits::*;
40use crate::types::*;
41use async_trait::async_trait;
42use futures::StreamExt;
43use serde::Deserialize;
44use tokio::sync::mpsc;
45use tracing::{debug, warn};
46
/// Streaming provider for Google's Gemini models.
///
/// Implements `StreamProvider` by calling the `streamGenerateContent`
/// endpoint (with `alt=sse`) under the model's configured `base_url`.
/// The type carries no state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct GoogleProvider;
49
50#[async_trait]
51impl StreamProvider for GoogleProvider {
52 fn provider_id(&self) -> &str {
53 "google"
54 }
55
56 async fn stream(
57 &self,
58 config: StreamConfig, tx: mpsc::UnboundedSender<StreamEvent>, cancel: tokio_util::sync::CancellationToken, ) -> Result<Message, ProviderError> {
62 let model_config = &config.model_config;
63 let api_key = model_config.resolve_api_key().await?;
65
66 let base_url = &model_config.base_url;
67 let url = format!(
69 "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
70 base_url, config.model_config.id, api_key
71 );
72
73 let body = build_request_body(&config);
74 debug!("Google GenAI request: model={}", config.model_config.id);
75
76 let client = reqwest::Client::new();
77 let mut request = client.post(&url).header("content-type", "application/json");
78
79 for (k, v) in &model_config.headers {
80 request = request.header(k, v);
81 }
82
83 let response = request
104 .json(&body)
105 .send()
106 .await
107 .map_err(|e| ProviderError::Network(e.to_string()))?;
108
109 if !response.status().is_success() {
110 let status = response.status();
111 let body = response.text().await.unwrap_or_default();
112 return Err(ProviderError::classify(
113 status.as_u16(),
114 &format!("Google API error {}: {}", status, body),
115 ));
116 }
117
118 let mut content: Vec<Content> = Vec::new();
119 let mut usage = Usage::default();
120 let mut stop_reason = StopReason::Stop;
121
122 let _ = tx.send(StreamEvent::Start);
123
124 let mut stream = response.bytes_stream();
126 let mut buffer = String::new();
127
128 loop {
129 tokio::select! {
130 _ = cancel.cancelled() => {
131 return Err(ProviderError::Cancelled);
132 }
133 chunk = stream.next() => {
134 match chunk {
135 None => break,
136 Some(Err(e)) => {
137 warn!("Google stream error: {}", e);
138 break;
139 }
140 Some(Ok(bytes)) => {
141 buffer.push_str(&String::from_utf8_lossy(&bytes));
142
143 while let Some(pos) = buffer.find("\n\n") {
145 let event_str = buffer[..pos].to_string();
146 buffer = buffer[pos + 2..].to_string();
147
148 let data = event_str
150 .lines()
151 .find(|l| l.starts_with("data: "))
152 .map(|l| &l[6..])
153 .unwrap_or("");
154
155 if data.is_empty() {
156 continue;
157 }
158
159 let chunk: GoogleChunk = match serde_json::from_str(data) {
160 Ok(c) => c,
161 Err(e) => {
162 debug!("Failed to parse Google chunk: {}", e);
163 continue;
164 }
165 };
166
167 for candidate in &chunk.candidates.unwrap_or_default() {
169 if let Some(c) = &candidate.content {
170 for part in &c.parts {
171 if let Some(text) = &part.text {
172 let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
173 let idx = match text_idx {
174 Some(i) => i,
175 None => {
176 content.push(Content::Text { text: String::new() });
177 content.len() - 1
178 }
179 };
180 if let Some(Content::Text { text: t }) = content.get_mut(idx) {
181 t.push_str(text);
182 }
183 let _ = tx.send(StreamEvent::TextDelta {
184 content_index: idx,
185 delta: text.clone(),
186 });
187 }
188 if let Some(fc) = &part.function_call {
189 let id = format!("google-fc-{}", content.len());
190 let args = fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default()));
191 let idx = content.len();
192 content.push(Content::ToolCall {
193 id: id.clone(),
194 name: fc.name.clone(),
195 arguments: args,
196 });
197 let _ = tx.send(StreamEvent::ToolCallStart {
198 content_index: idx,
199 id,
200 name: fc.name.clone(),
201 });
202 let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
203 stop_reason = StopReason::ToolUse;
204 }
205 }
206 }
207 if let Some(reason) = &candidate.finish_reason {
208 stop_reason = match reason.as_str() {
209 "STOP" => StopReason::Stop,
210 "MAX_TOKENS" | "RECITATION" => StopReason::Length,
211 _ => StopReason::Stop,
212 };
213 }
214 }
215
216 if let Some(u) = &chunk.usage_metadata {
218 usage.input = u.prompt_token_count.unwrap_or(0);
219 usage.output = u.candidates_token_count.unwrap_or(0);
220 usage.total_tokens = u.total_token_count.unwrap_or(0);
221 usage.cache_read = u.cached_content_token_count.unwrap_or(0);
222 }
223 }
224 }
225 }
226 }
227 }
228 }
229
230 let message = Message::Assistant {
231 content,
232 stop_reason,
233 model: config.model_config.id.clone(),
234 provider: model_config.provider.clone(),
235 usage,
236 timestamp: now_ms(),
237 error_message: None,
238 };
239
240 let _ = tx.send(StreamEvent::Done {
241 message: message.clone(),
242 });
243 Ok(message)
244 }
245}
246
247fn build_request_body(config: &StreamConfig) -> serde_json::Value {
248 let mut contents: Vec<serde_json::Value> = Vec::new();
249
250 for msg in &config.messages {
251 match msg {
252 Message::User { content, .. } => {
253 let parts = content_to_google_parts(content);
254 contents.push(serde_json::json!({
255 "role": "user",
256 "parts": parts,
257 }));
258 }
259 Message::Assistant { content, .. } => {
260 let parts = content_to_google_parts(content);
261 contents.push(serde_json::json!({
262 "role": "model",
263 "parts": parts,
264 }));
265 }
266 Message::ToolResult {
267 tool_call_id: _,
268 tool_name,
269 content,
270 ..
271 } => {
272 let text = content
273 .iter()
274 .find_map(|c| match c {
275 Content::Text { text } => Some(text.clone()),
276 _ => None,
277 })
278 .unwrap_or_default();
279
280 let mut parts = vec![serde_json::json!({
281 "functionResponse": {
282 "name": tool_name,
283 "response": {"result": text},
284 }
285 })];
286
287 for c in content {
289 if let Content::Image { data, mime_type } = c {
290 parts.push(serde_json::json!({
291 "inlineData": {"mimeType": mime_type, "data": data},
292 }));
293 }
294 }
295
296 contents.push(serde_json::json!({
297 "role": "user",
298 "parts": parts,
299 }));
300 }
301 }
302 }
303
304 let mut body = serde_json::json!({
305 "contents": contents,
306 });
307
308 if !config.system_prompt.is_empty() {
309 body["systemInstruction"] = serde_json::json!({
310 "parts": [{"text": config.system_prompt}],
311 });
312 }
313
314 let mut generation_config = serde_json::json!({});
315 if let Some(max) = config.max_tokens {
316 generation_config["maxOutputTokens"] = serde_json::json!(max);
317 }
318 if let Some(temp) = config.temperature {
319 generation_config["temperature"] = serde_json::json!(temp);
320 }
321 match &config.response_format {
325 ResponseFormat::Text => {} ResponseFormat::JsonObject => {
327 generation_config["responseMimeType"] = serde_json::json!("application/json");
328 }
329 ResponseFormat::JsonSchema { schema, .. } => {
330 generation_config["responseMimeType"] = serde_json::json!("application/json");
331 generation_config["responseSchema"] = schema.clone();
332 }
333 }
334 if generation_config != serde_json::json!({}) {
335 body["generationConfig"] = generation_config;
336 }
337
338 if !config.tools.is_empty() {
339 let declarations: Vec<serde_json::Value> = config
340 .tools
341 .iter()
342 .map(|t| {
343 serde_json::json!({
344 "name": t.name,
345 "description": t.description,
346 "parameters": t.parameters,
347 })
348 })
349 .collect();
350 body["tools"] = serde_json::json!([{
351 "functionDeclarations": declarations,
352 }]);
353 }
354
355 body
356}
357
358fn content_to_google_parts(content: &[Content]) -> Vec<serde_json::Value> {
359 content
360 .iter()
361 .filter_map(|c| match c {
362 Content::Text { text } => Some(serde_json::json!({"text": text})),
363 Content::Image { data, mime_type } => Some(serde_json::json!({
364 "inlineData": {"mimeType": mime_type, "data": data},
365 })),
366 Content::ToolCall {
367 name, arguments, ..
368 } => Some(serde_json::json!({
369 "functionCall": {"name": name, "args": arguments},
370 })),
371 Content::Thinking { .. } => None,
372 })
373 .collect()
374}
375
376#[derive(Deserialize)]
378struct GoogleChunk {
379 #[serde(default)]
380 candidates: Option<Vec<GoogleCandidate>>,
381 #[serde(default, rename = "usageMetadata")]
382 usage_metadata: Option<GoogleUsageMetadata>,
383}
384
385#[derive(Deserialize)]
386struct GoogleCandidate {
387 #[serde(default)]
388 content: Option<GoogleContent>,
389 #[serde(default, rename = "finishReason")]
390 finish_reason: Option<String>,
391}
392
393#[derive(Deserialize)]
394struct GoogleContent {
395 #[serde(default)]
396 parts: Vec<GooglePart>,
397}
398
399#[derive(Deserialize)]
400struct GooglePart {
401 #[serde(default)]
402 text: Option<String>,
403 #[serde(default, rename = "functionCall")]
404 function_call: Option<GoogleFunctionCall>,
405}
406
407#[derive(Deserialize)]
408struct GoogleFunctionCall {
409 name: String,
410 #[serde(default)]
411 args: Option<serde_json::Value>,
412}
413
414#[derive(Deserialize)]
415struct GoogleUsageMetadata {
416 #[serde(default, rename = "promptTokenCount")]
417 prompt_token_count: Option<u64>,
418 #[serde(default, rename = "candidatesTokenCount")]
419 candidates_token_count: Option<u64>,
420 #[serde(default, rename = "totalTokenCount")]
421 total_token_count: Option<u64>,
422 #[serde(default, rename = "cachedContentTokenCount")]
423 cached_content_token_count: Option<u64>,
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal Gemini Flash config with one user message and sampling options.
    fn sample_config() -> StreamConfig {
        StreamConfig {
            model_config: crate::provider::ModelConfig::google(
                "gemini-2.0-flash",
                "Gemini Flash",
                "test",
            ),
            system_prompt: "Be helpful".into(),
            messages: vec![Message::user("Hello")],
            tools: vec![],
            thinking_level: ThinkingLevel::Off,
            max_tokens: Some(1024),
            temperature: Some(0.7),
            cache_config: CacheConfig::default(),
            response_format: ResponseFormat::Text,
        }
    }

    #[test]
    fn test_build_google_request() {
        let body = build_request_body(&sample_config());

        let contents = &body["contents"];
        assert!(contents.is_array());
        assert_eq!(contents[0]["role"], "user");
        assert!(body["systemInstruction"].is_object());

        let generation = &body["generationConfig"];
        assert_eq!(generation["maxOutputTokens"], 1024);
        let temperature = generation["temperature"]
            .as_f64()
            .expect("temperature should serialize as a number");
        assert!((temperature - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_content_to_google_parts_text() {
        let parts = content_to_google_parts(&[Content::Text {
            text: "hello".into(),
        }]);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "hello");
    }

    #[test]
    fn test_content_to_google_parts_tool_call() {
        let call = Content::ToolCall {
            id: "tc-1".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"command": "ls"}),
        };
        let parts = content_to_google_parts(std::slice::from_ref(&call));
        assert_eq!(parts[0]["functionCall"]["name"], "bash");
    }
}