use crate::llm::provider::{LLMError, Result};
use crate::llm::types::LLMChunk;
use bamboo_domain::{FunctionCall, ToolCall};
use serde_json::Value;
/// Mutable state carried across the SSE events of one Gemini streaming response.
#[derive(Default)]
pub struct GeminiStreamState {
    // Monotonic counter backing `generate_tool_id`.
    next_tool_id: usize,
    /// Set once any thinking/thought-flavored part has been observed.
    pub observed_thinking_signal: bool,
    /// Number of thinking parts observed so far.
    pub thinking_parts_count: usize,
    /// Total characters of thinking text observed so far.
    pub thinking_text_chars: usize,
}
impl GeminiStreamState {
    /// Mint the next sequential tool-call id: "gemini_0", "gemini_1", ...
    fn generate_tool_id(&mut self) -> String {
        let seq = self.next_tool_id;
        self.next_tool_id = seq + 1;
        format!("gemini_{}", seq)
    }
}
/// Parse one Gemini SSE `data:` payload into at most one [`LLMChunk`].
///
/// Returns `Ok(None)` for payloads that carry nothing to emit: blank data,
/// an empty candidate list, a chunk without `content`/`parts`, or an empty
/// text part. `[DONE]` maps to `LLMChunk::Done`.
///
/// Side effects on `state`: thinking counters are updated when a thought
/// part is seen, and sequential tool-call ids are minted for function calls.
///
/// # Errors
/// `LLMError::Stream` for malformed JSON, a missing `candidates` array, a
/// `functionCall` without a name, or unserializable args; `LLMError::Api`
/// when the payload carries a top-level `error` object.
///
/// NOTE(review): only the first candidate and the first part are inspected;
/// assumes the streaming API emits one part per chunk — confirm whether
/// multi-part chunks can occur.
pub fn parse_gemini_sse_event(
    state: &mut GeminiStreamState,
    _event_type: &str,
    data: &str,
) -> Result<Option<LLMChunk>> {
    let data = data.trim();
    if data.is_empty() {
        return Ok(None);
    }
    // Explicit end-of-stream sentinel.
    if data == "[DONE]" {
        return Ok(Some(LLMChunk::Done));
    }
    let value: Value = serde_json::from_str(data).map_err(|e| {
        LLMError::Stream(format!("Failed to parse Gemini SSE data: {}: {}", e, data))
    })?;
    // API-level failures arrive as a top-level `error` object.
    if let Some(error) = value.get("error") {
        let error_msg = error
            .get("message")
            .and_then(|m| m.as_str())
            .unwrap_or("Unknown Gemini API error");
        return Err(LLMError::Api(error_msg.to_string()));
    }
    let candidates = value
        .get("candidates")
        .and_then(|c| c.as_array())
        .ok_or_else(|| {
            LLMError::Stream(format!("Missing candidates in Gemini response: {}", data))
        })?;
    // `.first()` replaces the previous `is_empty()` check + `[0]` indexing,
    // removing the panic path. A dead, empty `finishReason` check that
    // followed here has been removed.
    let candidate = match candidates.first() {
        Some(c) => c,
        None => return Ok(None),
    };
    // A finishReason-only chunk has no content/parts and yields no chunk.
    let parts = match candidate
        .get("content")
        .and_then(|c| c.get("parts"))
        .and_then(|p| p.as_array())
    {
        Some(p) => p,
        None => return Ok(None),
    };
    let part = match parts.first() {
        Some(p) => p,
        None => return Ok(None),
    };
    // A part counts as "thinking" output if it carries a truthy `thought`
    // flag, a `thoughtSignature`, or a `thinking` field.
    let is_thinking_part = part
        .get("thought")
        .and_then(|value| value.as_bool())
        .unwrap_or(false)
        || part.get("thoughtSignature").is_some()
        || part.get("thinking").is_some();
    if is_thinking_part {
        state.observed_thinking_signal = true;
        // Saturating arithmetic: counters must never wrap on long streams.
        state.thinking_parts_count = state.thinking_parts_count.saturating_add(1);
        let text_len = part
            .get("text")
            .and_then(|value| value.as_str())
            .map(str::len)
            .unwrap_or(0);
        state.thinking_text_chars = state.thinking_text_chars.saturating_add(text_len);
    }
    // Text parts: thinking text becomes a ReasoningToken, regular text a Token.
    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
        if text.is_empty() {
            return Ok(None);
        }
        let chunk = if is_thinking_part {
            LLMChunk::ReasoningToken(text.to_string())
        } else {
            LLMChunk::Token(text.to_string())
        };
        return Ok(Some(chunk));
    }
    // Function-call parts: forward name + JSON-serialized args with a fresh id.
    if let Some(function_call) = part.get("functionCall") {
        let name = function_call
            .get("name")
            .and_then(|n| n.as_str())
            .ok_or_else(|| {
                LLMError::Stream(format!(
                    "Missing function name in Gemini response: {}",
                    data
                ))
            })?;
        // Absent args are treated as an empty object.
        let args = function_call
            .get("args")
            .cloned()
            .unwrap_or_else(|| Value::Object(serde_json::Map::new()));
        let args_str = serde_json::to_string(&args)
            .map_err(|e| LLMError::Stream(format!("Failed to serialize function args: {}", e)))?;
        let tool_id = state.generate_tool_id();
        return Ok(Some(LLMChunk::ToolCalls(vec![ToolCall {
            id: tool_id,
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: args_str,
            },
        }])));
    }
    // Unrecognized part shapes are skipped rather than treated as errors.
    Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `data`, unwrapping both the `Result` and the `Option` layers.
    fn must_parse(state: &mut GeminiStreamState, data: &str) -> LLMChunk {
        parse_gemini_sse_event(state, "", data)
            .unwrap()
            .expect("chunk")
    }

    /// Extract the text from a `Token` chunk, panicking on any other variant.
    fn token_text(chunk: LLMChunk) -> String {
        match chunk {
            LLMChunk::Token(text) => text,
            other => panic!("expected LLMChunk::Token, got {:?}", other),
        }
    }

    /// Extract the first call id from a `ToolCalls` chunk.
    fn tool_call_id(chunk: LLMChunk) -> String {
        match chunk {
            LLMChunk::ToolCalls(calls) => calls[0].id.clone(),
            other => panic!("expected LLMChunk::ToolCalls, got {:?}", other),
        }
    }

    #[test]
    fn parse_text_chunk() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]}"#;
        assert_eq!(token_text(must_parse(&mut state, data)), "Hello");
    }

    #[test]
    fn parse_thought_text_chunk_emits_reasoning_token() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"candidates":[{"content":{"parts":[{"thought":true,"text":"Thinking..."}],"role":"model"}}]}"#;
        let LLMChunk::ReasoningToken(text) = must_parse(&mut state, data) else {
            panic!("expected LLMChunk::ReasoningToken");
        };
        assert_eq!(text, "Thinking...");
        assert!(state.observed_thinking_signal);
        assert_eq!(state.thinking_parts_count, 1);
    }

    #[test]
    fn parse_empty_data_returns_none() {
        let mut state = GeminiStreamState::default();
        assert!(parse_gemini_sse_event(&mut state, "", "")
            .unwrap()
            .is_none());
    }

    #[test]
    fn parse_done_signal() {
        let mut state = GeminiStreamState::default();
        assert!(matches!(must_parse(&mut state, "[DONE]"), LLMChunk::Done));
    }

    #[test]
    fn parse_function_call() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"search","args":{"q":"test"}}}],"role":"model"}}]}"#;
        let LLMChunk::ToolCalls(calls) = must_parse(&mut state, data) else {
            panic!("expected LLMChunk::ToolCalls");
        };
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "search");
        assert_eq!(calls[0].function.arguments, r#"{"q":"test"}"#);
        assert!(calls[0].id.starts_with("gemini_"));
    }

    #[test]
    fn parse_empty_candidates_returns_none() {
        let mut state = GeminiStreamState::default();
        let chunk = parse_gemini_sse_event(&mut state, "", r#"{"candidates":[]}"#).unwrap();
        assert!(chunk.is_none());
    }

    #[test]
    fn parse_missing_content_returns_none() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"candidates":[{"finishReason":"STOP"}]}"#;
        assert!(parse_gemini_sse_event(&mut state, "", data)
            .unwrap()
            .is_none());
    }

    #[test]
    fn parse_error_response() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"error":{"message":"API key invalid","code":401}}"#;
        let err = parse_gemini_sse_event(&mut state, "", data).unwrap_err();
        assert!(err.to_string().contains("API key invalid"));
    }

    #[test]
    fn parse_invalid_json() {
        let mut state = GeminiStreamState::default();
        assert!(parse_gemini_sse_event(&mut state, "", "{invalid json}").is_err());
    }

    #[test]
    fn parse_multipart_text_accumulates() {
        let mut state = GeminiStreamState::default();
        let data1 = r#"{"candidates":[{"content":{"parts":[{"text":"Hello "}],"role":"model"}}]}"#;
        let data2 = r#"{"candidates":[{"content":{"parts":[{"text":"world!"}],"role":"model"}}]}"#;
        assert_eq!(token_text(must_parse(&mut state, data1)), "Hello ");
        assert_eq!(token_text(must_parse(&mut state, data2)), "world!");
    }

    #[test]
    fn parse_function_call_with_empty_args() {
        let mut state = GeminiStreamState::default();
        let data = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_time","args":{}}}],"role":"model"}}]}"#;
        let LLMChunk::ToolCalls(calls) = must_parse(&mut state, data) else {
            panic!("expected LLMChunk::ToolCalls");
        };
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_time");
        assert_eq!(calls[0].function.arguments, "{}");
    }

    #[test]
    fn parse_whitespace_data_is_trimmed() {
        let mut state = GeminiStreamState::default();
        assert!(matches!(must_parse(&mut state, "  [DONE]  "), LLMChunk::Done));
    }

    #[test]
    fn state_generates_unique_tool_ids() {
        let mut state = GeminiStreamState::default();
        let ids = [
            state.generate_tool_id(),
            state.generate_tool_id(),
            state.generate_tool_id(),
        ];
        assert_ne!(ids[0], ids[1]);
        assert_ne!(ids[1], ids[2]);
        for id in &ids {
            assert!(id.starts_with("gemini_"));
        }
    }

    #[test]
    fn multiple_function_calls_get_unique_ids() {
        let mut state = GeminiStreamState::default();
        let data1 = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"search","args":{}}}],"role":"model"}}]}"#;
        let data2 = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"read","args":{}}}],"role":"model"}}]}"#;
        let id1 = tool_call_id(must_parse(&mut state, data1));
        let id2 = tool_call_id(must_parse(&mut state, data2));
        assert_ne!(id1, id2);
    }
}