use super::traits::*;
use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use serde::Deserialize;
use tokio::sync::mpsc;
use tracing::{debug, warn};
pub struct GoogleProvider;
#[async_trait]
impl StreamProvider for GoogleProvider {
async fn stream(
&self,
config: StreamConfig,
tx: mpsc::UnboundedSender<StreamEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> Result<Message, ProviderError> {
let model_config = config
.model_config
.as_ref()
.ok_or_else(|| ProviderError::Other("ModelConfig required".into()))?;
let base_url = &model_config.base_url;
let url = format!(
"{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
base_url, config.model, config.api_key
);
let body = build_request_body(&config);
debug!("Google GenAI request: model={}", config.model);
let client = reqwest::Client::new();
let mut request = client.post(&url).header("content-type", "application/json");
for (k, v) in &model_config.headers {
request = request.header(k, v);
}
let response = request
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(ProviderError::classify(
status.as_u16(),
&format!("Google API error {}: {}", status, body),
));
}
let mut content: Vec<Content> = Vec::new();
let mut usage = Usage::default();
let mut stop_reason = StopReason::Stop;
let _ = tx.send(StreamEvent::Start);
let mut stream = response.bytes_stream();
let mut buffer = String::new();
loop {
tokio::select! {
_ = cancel.cancelled() => {
return Err(ProviderError::Cancelled);
}
chunk = stream.next() => {
match chunk {
None => break,
Some(Err(e)) => {
warn!("Google stream error: {}", e);
break;
}
Some(Ok(bytes)) => {
buffer.push_str(&String::from_utf8_lossy(&bytes));
while let Some(pos) = buffer.find("\n\n").or_else(|| buffer.find("\r\n\r\n")) {
let sep_len = if buffer[pos..].starts_with("\r\n\r\n") { 4 } else { 2 };
let event_str = buffer[..pos].to_string();
buffer = buffer[pos + sep_len..].to_string();
let data = event_str
.lines()
.map(|l| l.trim_end_matches('\r'))
.find(|l| l.starts_with("data: "))
.map(|l| &l[6..])
.unwrap_or("");
if data.is_empty() {
continue;
}
let chunk: GoogleChunk = match serde_json::from_str(data) {
Ok(c) => c,
Err(e) => {
warn!("Failed to parse Google chunk: {}", e);
continue;
}
};
for candidate in &chunk.candidates.unwrap_or_default() {
if let Some(c) = &candidate.content {
for part in &c.parts {
if let Some(text) = &part.text {
if text.is_empty() {
continue;
}
let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
let idx = match text_idx {
Some(i) => i,
None => {
content.push(Content::Text { text: String::new() });
content.len() - 1
}
};
if let Some(Content::Text { text: t }) = content.get_mut(idx) {
t.push_str(text);
}
let _ = tx.send(StreamEvent::TextDelta {
content_index: idx,
delta: text.clone(),
});
}
if let Some(fc) = &part.function_call {
let id = fc.id.clone().unwrap_or_else(|| format!("google-fc-{}", content.len()));
let args = fc.args.clone().unwrap_or(serde_json::Value::Object(Default::default()));
let metadata = part.thought_signature.as_ref().map(|sig| {
serde_json::json!({"thought_signature": sig})
});
let idx = content.len();
content.push(Content::ToolCall {
id: id.clone(),
name: fc.name.clone(),
arguments: args,
provider_metadata: metadata,
});
let _ = tx.send(StreamEvent::ToolCallStart {
content_index: idx,
id,
name: fc.name.clone(),
});
let _ = tx.send(StreamEvent::ToolCallEnd { content_index: idx });
stop_reason = StopReason::ToolUse;
}
}
}
if let Some(reason) = &candidate.finish_reason {
if stop_reason != StopReason::ToolUse {
stop_reason = match reason.as_str() {
"STOP" => StopReason::Stop,
"MAX_TOKENS" | "RECITATION" => StopReason::Length,
_ => StopReason::Stop,
};
}
}
}
if let Some(u) = &chunk.usage_metadata {
usage.input = u.prompt_token_count.unwrap_or(0);
usage.output = u.candidates_token_count.unwrap_or(0);
usage.total_tokens = u.total_token_count.unwrap_or(0);
usage.cache_read = u.cached_content_token_count.unwrap_or(0);
}
}
}
}
}
}
}
let message = Message::Assistant {
content,
stop_reason,
model: config.model.clone(),
provider: model_config.provider.clone(),
usage,
timestamp: now_ms(),
error_message: None,
};
let _ = tx.send(StreamEvent::Done {
message: message.clone(),
});
Ok(message)
}
}
/// Builds the JSON body of a Gemini `streamGenerateContent` request from `config`.
///
/// * `User` messages become `"user"` turns and `Assistant` messages become `"model"`
///   turns. Their content is converted by `content_to_google_parts`.
/// * A run of consecutive `ToolResult` messages is merged into ONE `"user"` turn: all of
///   the run's `functionResponse` parts come first, followed by any images the results
///   carried. Gemini expects the responses to a model turn's (parallel) function calls
///   to arrive together in a single turn.
/// * A tool result's text blocks are joined with `"\n"` and sent as `response.result`.
///   Locally minted call ids (empty or `google-fc-*`) are not sent back.
/// * `systemInstruction`, `generationConfig` and `tools` are only included when set.
fn build_request_body(config: &StreamConfig) -> serde_json::Value {
    let mut contents: Vec<serde_json::Value> = Vec::new();
    // Parts of the tool-result turn currently being assembled (see `flush_tool_turn`).
    let mut tool_responses: Vec<serde_json::Value> = Vec::new();
    let mut tool_media: Vec<serde_json::Value> = Vec::new();

    for msg in &config.messages {
        match msg {
            Message::User { content, .. } => {
                flush_tool_turn(&mut contents, &mut tool_responses, &mut tool_media);
                contents.push(serde_json::json!({
                    "role": "user",
                    "parts": content_to_google_parts(content),
                }));
            }
            Message::Assistant { content, .. } => {
                flush_tool_turn(&mut contents, &mut tool_responses, &mut tool_media);
                contents.push(serde_json::json!({
                    "role": "model",
                    "parts": content_to_google_parts(content),
                }));
            }
            Message::ToolResult {
                tool_call_id,
                tool_name,
                content,
                ..
            } => {
                let text = content
                    .iter()
                    .filter_map(|c| match c {
                        Content::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                let mut fr = serde_json::json!({
                    "name": tool_name,
                    "response": {"result": text},
                });
                if !tool_call_id.is_empty() && !tool_call_id.starts_with("google-fc-") {
                    fr["id"] = serde_json::json!(tool_call_id);
                }
                tool_responses.push(serde_json::json!({"functionResponse": fr}));
                tool_media.extend(content.iter().filter_map(|c| match c {
                    Content::Image { data, mime_type } => Some(serde_json::json!({
                        "inlineData": {"mimeType": mime_type, "data": data},
                    })),
                    _ => None,
                }));
            }
        }
    }
    flush_tool_turn(&mut contents, &mut tool_responses, &mut tool_media);

    let mut body = serde_json::json!({
        "contents": contents,
    });
    if !config.system_prompt.is_empty() {
        body["systemInstruction"] = serde_json::json!({
            "parts": [{"text": config.system_prompt}],
        });
    }

    let mut generation_config = serde_json::Map::new();
    if let Some(max) = config.max_tokens {
        generation_config.insert("maxOutputTokens".into(), serde_json::json!(max));
    }
    if let Some(temp) = config.temperature {
        generation_config.insert("temperature".into(), serde_json::json!(temp));
    }
    if !generation_config.is_empty() {
        body["generationConfig"] = serde_json::Value::Object(generation_config);
    }

    if !config.tools.is_empty() {
        let declarations: Vec<serde_json::Value> = config
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "parameters": t.parameters,
                })
            })
            .collect();
        body["tools"] = serde_json::json!([{
            "functionDeclarations": declarations,
        }]);
    }
    body
}

/// Emits the pending tool-result parts (responses first, then media) as one `"user"`
/// turn and clears both buffers. Does nothing when no tool result is pending.
fn flush_tool_turn(
    contents: &mut Vec<serde_json::Value>,
    responses: &mut Vec<serde_json::Value>,
    media: &mut Vec<serde_json::Value>,
) {
    if responses.is_empty() && media.is_empty() {
        return;
    }
    let mut parts = std::mem::take(responses);
    parts.append(media);
    contents.push(serde_json::json!({
        "role": "user",
        "parts": parts,
    }));
}
/// Converts message content blocks into Gemini `parts`.
///
/// Empty text blocks and thinking blocks produce no part. Images become `inlineData`
/// parts. Tool calls become `functionCall` parts: ids that are empty or locally minted
/// (`google-fc-*`) are left out, and a stored `thought_signature` in the call's
/// `provider_metadata` is replayed as the part's `thoughtSignature`.
fn content_to_google_parts(content: &[Content]) -> Vec<serde_json::Value> {
    let mut out = Vec::with_capacity(content.len());
    for block in content {
        let part = match block {
            Content::Text { text } => {
                if text.is_empty() {
                    continue;
                }
                serde_json::json!({"text": text})
            }
            Content::Image { data, mime_type } => serde_json::json!({
                "inlineData": {"mimeType": mime_type, "data": data},
            }),
            Content::ToolCall {
                id,
                name,
                arguments,
                provider_metadata,
            } => {
                let mut call = serde_json::json!({"name": name, "args": arguments});
                let locally_minted = id.is_empty() || id.starts_with("google-fc-");
                if !locally_minted {
                    call["id"] = serde_json::json!(id);
                }
                let signature = provider_metadata
                    .as_ref()
                    .and_then(|meta| meta.get("thought_signature"))
                    .and_then(serde_json::Value::as_str);
                let mut wrapped = serde_json::json!({"functionCall": call});
                if let Some(sig) = signature {
                    wrapped["thoughtSignature"] = serde_json::json!(sig);
                }
                wrapped
            }
            Content::Thinking { .. } => continue,
        };
        out.push(part);
    }
    out
}
/// One SSE `data` payload from `streamGenerateContent`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleChunk {
    /// Candidates carried by this chunk, or `None` when the key is absent.
    #[serde(default)]
    candidates: Option<Vec<GoogleCandidate>>,
    /// Token counts (`usageMetadata` on the wire), when reported.
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}
/// A single candidate within a chunk.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    /// The content fragment for this candidate, if any.
    #[serde(default)]
    content: Option<GoogleContent>,
    /// Raw `finishReason` string (e.g. `"STOP"`, `"MAX_TOKENS"`), when present.
    #[serde(default)]
    finish_reason: Option<String>,
}
/// The `content` object of a candidate; only its parts are read.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    /// Content parts; an absent `parts` key yields an empty list.
    #[serde(default)]
    parts: Vec<GooglePart>,
}
/// One content part: text, a function call, or both, with an optional thought signature.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GooglePart {
    /// Text fragment; may be present but empty.
    #[serde(default)]
    text: Option<String>,
    /// Function call requested by the model (`functionCall` on the wire).
    #[serde(default)]
    function_call: Option<GoogleFunctionCall>,
    /// Opaque `thoughtSignature` that must be echoed back with the function call.
    #[serde(default)]
    thought_signature: Option<String>,
}
/// A function call emitted by the model. `name` is required; `args` and `id` are optional.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    /// Name of the declared function to invoke.
    name: String,
    /// JSON arguments for the call, if supplied.
    #[serde(default)]
    args: Option<serde_json::Value>,
    /// Server-assigned call id, if supplied.
    #[serde(default)]
    id: Option<String>,
}
/// Token accounting reported in `usageMetadata`; every count is optional.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    /// `promptTokenCount`: tokens in the request.
    #[serde(default)]
    prompt_token_count: Option<u64>,
    /// `candidatesTokenCount`: tokens in the generated candidates.
    #[serde(default)]
    candidates_token_count: Option<u64>,
    /// `totalTokenCount`: total reported by the API.
    #[serde(default)]
    total_token_count: Option<u64>,
    /// `cachedContentTokenCount`: tokens served from cached content.
    #[serde(default)]
    cached_content_token_count: Option<u64>,
}
// Unit tests for request-body construction, content-to-part conversion and chunk
// deserialization. They call the private helpers directly and never touch the network.
#[cfg(test)]
mod tests {
use super::*;
// Basic body shape: contents array, user role, system instruction and generation config.
#[test]
fn test_build_google_request() {
let config = StreamConfig {
model: "gemini-2.0-flash".into(),
system_prompt: "Be helpful".into(),
messages: vec![Message::user("Hello")],
tools: vec![],
thinking_level: ThinkingLevel::Off,
api_key: "test".into(),
max_tokens: Some(1024),
temperature: Some(0.7),
model_config: None,
cache_config: CacheConfig::default(),
};
let body = build_request_body(&config);
assert!(body["contents"].is_array());
assert_eq!(body["contents"][0]["role"], "user");
assert!(body["systemInstruction"].is_object());
assert_eq!(body["generationConfig"]["maxOutputTokens"], 1024);
// Compared with a tolerance because the temperature may not round-trip exactly.
let temp = body["generationConfig"]["temperature"].as_f64().unwrap();
assert!((temp - 0.7).abs() < 0.01);
}
// A single non-empty text block maps to a single `text` part.
#[test]
fn test_content_to_google_parts_text() {
let content = vec![Content::Text {
text: "hello".into(),
}];
let parts = content_to_google_parts(&content);
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["text"], "hello");
}
// Empty text blocks produce no part at all.
#[test]
fn test_content_to_google_parts_filters_empty_text() {
let content = vec![
Content::Text { text: "".into() },
Content::Text {
text: "hello".into(),
},
Content::Text { text: "".into() },
];
let parts = content_to_google_parts(&content);
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["text"], "hello");
}
// Tool calls are emitted as `functionCall` parts carrying the tool name.
#[test]
fn test_content_to_google_parts_tool_call() {
let content = vec![Content::ToolCall {
id: "tc-1".into(),
name: "bash".into(),
arguments: serde_json::json!({"command": "ls"}),
provider_metadata: None,
}];
let parts = content_to_google_parts(&content);
assert_eq!(parts[0]["functionCall"]["name"], "bash");
}
// Deserializing a chunk picks up the function call's name, id and args plus the
// part-level `thoughtSignature`.
#[test]
fn test_parse_chunk_with_function_call_and_thought_signature() {
let data = r#"{"candidates": [{"content": {"parts": [{"functionCall": {"name": "bash", "args": {"command": "echo hi"}, "id": "abc123"}, "thoughtSignature": "SIG_DATA"}], "role": "model"}, "finishReason": "STOP", "index": 0}], "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5, "totalTokenCount": 15}}"#;
let chunk: GoogleChunk = serde_json::from_str(data).unwrap();
let candidates = chunk.candidates.unwrap();
assert_eq!(candidates.len(), 1);
let parts = &candidates[0].content.as_ref().unwrap().parts;
assert_eq!(parts.len(), 1);
let fc = parts[0].function_call.as_ref().unwrap();
assert_eq!(fc.name, "bash");
assert_eq!(fc.id.as_deref(), Some("abc123"));
assert_eq!(fc.args.as_ref().unwrap()["command"], "echo hi");
assert_eq!(parts[0].thought_signature.as_deref(), Some("SIG_DATA"));
}
// An empty text part deserializes as `Some("")`, not `None`.
#[test]
fn test_parse_chunk_with_empty_text() {
let data = r#"{"candidates": [{"content": {"parts": [{"text": ""}], "role": "model"}, "index": 0}]}"#;
let chunk: GoogleChunk = serde_json::from_str(data).unwrap();
let candidates = chunk.candidates.unwrap();
let parts = &candidates[0].content.as_ref().unwrap().parts;
assert_eq!(parts[0].text.as_deref(), Some(""));
}
// CRLF-terminated SSE event. NOTE(review): this splits the event with an inline copy of
// separator/data-line logic rather than calling the provider, so it only checks that
// approach and that the resulting payload deserializes.
#[test]
fn test_parse_chunk_with_crlf_sse() {
let raw = "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"Blue\"}], \"role\": \"model\"}, \"finishReason\": \"STOP\", \"index\": 0}]}\r\n\r\n";
let pos = raw.find("\n\n").or_else(|| raw.find("\r\n\r\n"));
assert!(pos.is_some(), "Should find separator");
let sep_start = pos.unwrap();
let is_crlf = raw[sep_start..].starts_with("\r\n\r\n");
assert!(is_crlf, "Should detect \\r\\n\\r\\n separator");
let event_str = &raw[..sep_start];
let data = event_str
.lines()
.map(|l| l.trim_end_matches('\r'))
.find(|l| l.starts_with("data: "))
.map(|l| &l[6..])
.unwrap();
let chunk: GoogleChunk = serde_json::from_str(data).unwrap();
let candidates = chunk.candidates.unwrap();
let text = &candidates[0].content.as_ref().unwrap().parts[0].text;
assert_eq!(text.as_deref(), Some("Blue"));
}
// A real call id and a stored thought signature are both replayed to the API.
#[test]
fn test_thought_signature_round_trip() {
let content = vec![Content::ToolCall {
id: "abc123".into(),
name: "bash".into(),
arguments: serde_json::json!({"command": "echo hi"}),
provider_metadata: Some(serde_json::json!({"thought_signature": "SIG_DATA"})),
}];
let parts = content_to_google_parts(&content);
assert_eq!(parts.len(), 1);
assert_eq!(parts[0]["functionCall"]["name"], "bash");
assert_eq!(parts[0]["functionCall"]["id"], "abc123");
assert_eq!(parts[0]["functionCall"]["args"]["command"], "echo hi");
assert_eq!(parts[0]["thoughtSignature"], "SIG_DATA");
}
// A locally minted `google-fc-*` id is omitted, and no signature key appears without metadata.
#[test]
fn test_tool_call_without_thought_signature() {
let content = vec![Content::ToolCall {
id: "google-fc-0".into(),
name: "bash".into(),
arguments: serde_json::json!({"command": "ls"}),
provider_metadata: None,
}];
let parts = content_to_google_parts(&content);
assert!(parts[0]["functionCall"].get("id").is_none());
assert!(parts[0].get("thoughtSignature").is_none());
}
// A tool result becomes a `functionResponse` with name, real id and `response.result`.
#[test]
fn test_function_response_includes_id() {
let config = StreamConfig {
model: "gemini-2.5-flash".into(),
system_prompt: "".into(),
messages: vec![
Message::Assistant {
content: vec![Content::ToolCall {
id: "abc123".into(),
name: "bash".into(),
arguments: serde_json::json!({"command": "echo hi"}),
provider_metadata: None,
}],
stop_reason: StopReason::ToolUse,
model: "test".into(),
provider: "test".into(),
usage: Usage::default(),
timestamp: 0,
error_message: None,
},
Message::ToolResult {
tool_call_id: "abc123".into(),
tool_name: "bash".into(),
content: vec![Content::Text { text: "hi".into() }],
is_error: false,
timestamp: 0,
},
],
tools: vec![],
thinking_level: ThinkingLevel::Off,
api_key: "test".into(),
max_tokens: None,
temperature: None,
model_config: None,
cache_config: CacheConfig::default(),
};
let body = build_request_body(&config);
let msgs = body["contents"].as_array().unwrap();
let tool_result = &msgs[1]["parts"][0]["functionResponse"];
assert_eq!(tool_result["name"], "bash");
assert_eq!(tool_result["id"], "abc123");
assert_eq!(tool_result["response"]["result"], "hi");
}
// A tool result answering a locally minted `google-fc-*` id does not send that id.
#[test]
fn test_function_response_synthetic_id_omitted() {
let config = StreamConfig {
model: "gemini-2.5-flash".into(),
system_prompt: "".into(),
messages: vec![
Message::Assistant {
content: vec![Content::ToolCall {
id: "google-fc-0".into(),
name: "bash".into(),
arguments: serde_json::json!({"command": "ls"}),
provider_metadata: None,
}],
stop_reason: StopReason::ToolUse,
model: "test".into(),
provider: "test".into(),
usage: Usage::default(),
timestamp: 0,
error_message: None,
},
Message::ToolResult {
tool_call_id: "google-fc-0".into(),
tool_name: "bash".into(),
content: vec![Content::Text {
text: "output".into(),
}],
is_error: false,
timestamp: 0,
},
],
tools: vec![],
thinking_level: ThinkingLevel::Off,
api_key: "test".into(),
max_tokens: None,
temperature: None,
model_config: None,
cache_config: CacheConfig::default(),
};
let body = build_request_body(&config);
let msgs = body["contents"].as_array().unwrap();
let tool_result = &msgs[1]["parts"][0]["functionResponse"];
assert!(
tool_result.get("id").is_none(),
"Synthetic ID should not be included"
);
}
}