use crate::error::LlmError;
use crate::providers::gemini::types::GeminiConfig;
use crate::stream::{ChatStream, ChatStreamEvent};
use crate::types::{ChatResponse, FinishReason, MessageContent, ResponseMetadata};
use crate::utils::streaming::{SseEventConverter, StreamFactory};
use serde::Deserialize;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
/// One JSON payload received on the Gemini SSE stream.
///
/// Wire field names are camelCase (`usageMetadata`); every field is optional
/// so that partial chunks still deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    /// Candidates carried by this chunk; the converter only inspects the first one.
    candidates: Option<Vec<GeminiCandidate>>,
    /// Token accounting sent as `usageMetadata`. Parsed but not read in this module.
    #[allow(dead_code)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
/// A single candidate inside a streamed chunk.
///
/// Wire field names are camelCase (`finishReason`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    /// Content payload of the candidate, if this chunk carries any.
    content: Option<GeminiContent>,
    /// Raw finish reason string (for example `"STOP"`); when present the
    /// converter emits a stream-end event.
    finish_reason: Option<String>,
}
/// Content of a candidate: an optional list of parts plus an optional role.
#[derive(Debug, Clone, Deserialize)]
struct GeminiContent {
    /// Parts making up the content; text is extracted from these.
    parts: Option<Vec<GeminiPart>>,
    /// Role attached to the content. Parsed but not read in this module.
    #[allow(dead_code)]
    role: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct GeminiPart {
text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[allow(dead_code)]
thought: Option<bool>,
}
/// Token usage counters sent as `usageMetadata`.
///
/// Wire field names are camelCase (`promptTokenCount`, `candidatesTokenCount`,
/// `totalTokenCount`, `thoughtsTokenCount`). The counters are parsed but not
/// read anywhere in this module, hence the blanket `dead_code` allowance.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
struct GeminiUsageMetadata {
    /// Value of `promptTokenCount`.
    prompt_token_count: Option<u32>,
    /// Value of `candidatesTokenCount`.
    candidates_token_count: Option<u32>,
    /// Value of `totalTokenCount`.
    total_token_count: Option<u32>,
    /// Value of `thoughtsTokenCount`.
    thoughts_token_count: Option<u32>,
}
/// Converts Gemini SSE payloads into [`ChatStreamEvent`]s.
///
/// The "stream started" flag lives behind an `Arc`, so clones of a converter
/// share it: the stream-start event is produced once across all clones.
#[derive(Clone)]
pub struct GeminiEventConverter {
    /// Provider configuration; the model name is copied into the stream-start metadata.
    config: GeminiConfig,
    /// `true` once the stream-start event has been handed out.
    stream_started: Arc<Mutex<bool>>,
}
impl GeminiEventConverter {
    /// Creates a converter that has not yet emitted its stream-start event.
    pub fn new(config: GeminiConfig) -> Self {
        Self {
            config,
            stream_started: Arc::new(Mutex::new(false)),
        }
    }

    /// Translates one parsed chunk into zero or more stream events.
    ///
    /// Events are appended in this order: stream start (first chunk only),
    /// content delta, thinking delta, stream end.
    async fn convert_gemini_response_async(
        &self,
        response: GeminiStreamResponse,
    ) -> Vec<ChatStreamEvent> {
        use crate::utils::streaming::EventBuilder;

        let mut builder = EventBuilder::new();

        if self.needs_stream_start().await {
            builder = builder.add_stream_start(self.create_stream_start_metadata());
        }
        if let Some(content) = self.extract_content(&response) {
            builder = builder.add_content_delta(content, None);
        }
        if let Some(thinking) = self.extract_thinking(&response) {
            builder = builder.add_thinking_delta(thinking);
        }
        if let Some(end_response) = self.extract_completion(&response) {
            builder = builder.add_stream_end(end_response);
        }

        builder.build()
    }

    /// Returns `true` exactly once: on the first call after construction.
    ///
    /// The flag is read and set while holding the lock, so concurrent callers
    /// cannot both observe "not started".
    async fn needs_stream_start(&self) -> bool {
        let mut started = self.stream_started.lock().await;
        let first_call = !*started;
        *started = true;
        first_call
    }

    /// Concatenates the text of every part of the first candidate whose
    /// `thought` flag (absent counts as `false`) equals `want_thought`.
    ///
    /// Returns `None` when the chunk has no candidate, content or parts, or
    /// when the selected parts carry no text at all.
    fn collect_part_text(response: &GeminiStreamResponse, want_thought: bool) -> Option<String> {
        let parts = response
            .candidates
            .as_deref()?
            .first()?
            .content
            .as_ref()?
            .parts
            .as_deref()?;

        let joined: String = parts
            .iter()
            .filter(|part| part.thought.unwrap_or(false) == want_thought)
            .filter_map(|part| part.text.as_deref())
            .collect();

        if joined.is_empty() {
            None
        } else {
            Some(joined)
        }
    }

    /// Extracts the answer text of the chunk.
    ///
    /// Parts flagged as thoughts are skipped so that thinking output is not
    /// duplicated as a content delta, and text from every non-thought part is
    /// kept rather than only the first part's.
    fn extract_content(&self, response: &GeminiStreamResponse) -> Option<String> {
        Self::collect_part_text(response, false)
    }

    /// Extracts the thinking text of the chunk: the text of all parts flagged
    /// with `thought == Some(true)`.
    fn extract_thinking(&self, response: &GeminiStreamResponse) -> Option<String> {
        Self::collect_part_text(response, true)
    }

    /// Builds the final [`ChatResponse`] when the first candidate carries a
    /// finish reason; returns `None` otherwise.
    ///
    /// Mapping: `STOP` -> `Stop`, `MAX_TOKENS` -> `Length`, `SAFETY` and
    /// `RECITATION` -> `ContentFilter`. Any other value falls back to `Stop`.
    /// The response carries only the finish reason; text was already
    /// delivered through deltas.
    fn extract_completion(&self, response: &GeminiStreamResponse) -> Option<ChatResponse> {
        let raw_reason = response
            .candidates
            .as_deref()?
            .first()?
            .finish_reason
            .as_deref()?;

        let finish_reason = match raw_reason {
            "STOP" => FinishReason::Stop,
            "MAX_TOKENS" => FinishReason::Length,
            "SAFETY" | "RECITATION" => FinishReason::ContentFilter,
            // Unrecognised reasons are reported as a normal stop.
            _ => FinishReason::Stop,
        };

        Some(ChatResponse {
            id: None,
            model: None,
            content: MessageContent::Text(String::new()),
            usage: None,
            finish_reason: Some(finish_reason),
            tool_calls: None,
            thinking: None,
            metadata: std::collections::HashMap::new(),
        })
    }

    /// Builds the metadata attached to the stream-start event: the configured
    /// model name, the current UTC time and the provider tag `"gemini"`.
    fn create_stream_start_metadata(&self) -> ResponseMetadata {
        ResponseMetadata {
            id: None,
            model: Some(self.config.model.clone()),
            created: Some(chrono::Utc::now()),
            provider: "gemini".to_string(),
            request_id: None,
        }
    }
}
impl SseEventConverter for GeminiEventConverter {
    /// Converts one SSE event into stream events.
    ///
    /// * Events whose data is empty or whitespace-only yield no events.
    /// * Data that is not valid JSON for a Gemini chunk yields a single
    ///   `LlmError::ParseError`.
    /// * Otherwise every event derived from the chunk is returned wrapped in `Ok`.
    fn convert_event(
        &self,
        event: eventsource_stream::Event,
    ) -> Pin<Box<dyn Future<Output = Vec<Result<ChatStreamEvent, LlmError>>> + Send + Sync + '_>>
    {
        Box::pin(async move {
            if event.data.trim().is_empty() {
                return Vec::new();
            }

            let parsed = match serde_json::from_str::<GeminiStreamResponse>(&event.data) {
                Ok(parsed) => parsed,
                Err(e) => {
                    return vec![Err(LlmError::ParseError(format!(
                        "Failed to parse Gemini SSE JSON: {e}"
                    )))];
                }
            };

            let events = self.convert_gemini_response_async(parsed).await;
            events.into_iter().map(Ok).collect()
        })
    }
}
/// Entry point for opening a streaming chat request against Gemini.
#[derive(Debug, Clone)]
pub struct GeminiStreaming {
    /// Configuration handed to the event converter; `new` fills it with defaults.
    config: GeminiConfig,
    /// HTTP client used to issue the streaming request.
    http_client: reqwest::Client,
}
impl GeminiStreaming {
    /// Creates a streaming helper around `http_client`.
    ///
    /// The configuration starts as `GeminiConfig::default()`, so the model
    /// name reported in stream-start metadata is the default one.
    pub fn new(http_client: reqwest::Client) -> Self {
        Self {
            config: GeminiConfig::default(),
            http_client,
        }
    }

    /// Opens a streaming chat request and returns the resulting event stream.
    ///
    /// The request is a JSON `POST` to `url` authenticated through the
    /// `x-goog-api-key` header. It is built once and handed to
    /// `StreamFactory::create_eventsource_stream`, which is responsible for
    /// sending it. No separate pre-flight request is issued, because that
    /// would make the API run (and bill) every generation twice.
    ///
    /// # Errors
    ///
    /// Returns whatever error `StreamFactory::create_eventsource_stream`
    /// reports.
    ///
    /// NOTE(review): transport failures and non-success HTTP statuses are
    /// expected to be surfaced by the stream factory; confirm that it maps
    /// them to the `LlmError` variants callers rely on.
    pub async fn create_chat_stream(
        self,
        url: String,
        api_key: String,
        request: crate::providers::gemini::types::GenerateContentRequest,
    ) -> Result<ChatStream, LlmError> {
        // Build the request before `api_key` is moved into the configuration.
        let request_builder = self
            .http_client
            .post(&url)
            .header("Content-Type", "application/json")
            .header("x-goog-api-key", &api_key)
            .json(&request);

        let mut config = self.config;
        config.api_key = api_key;
        let converter = GeminiEventConverter::new(config);

        StreamFactory::create_eventsource_stream(request_builder, converter).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::gemini::types::GeminiConfig;

    /// Configuration with a dummy key and the public v1beta base URL.
    fn create_test_config() -> GeminiConfig {
        GeminiConfig {
            api_key: "test-key".to_string(),
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            ..Default::default()
        }
    }

    /// Wraps a raw JSON payload in an SSE event with empty name and id.
    fn sse_event(payload: &str) -> eventsource_stream::Event {
        eventsource_stream::Event {
            event: "".to_string(),
            data: payload.to_string(),
            id: "".to_string(),
            retry: None,
        }
    }

    /// A chunk with a single text part produces a content delta with that text.
    #[tokio::test]
    async fn test_gemini_streaming_conversion() {
        let converter = GeminiEventConverter::new(create_test_config());
        let payload = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}"#;

        let result = converter.convert_event(sse_event(payload)).await;
        assert!(!result.is_empty());

        let found = result.iter().find_map(|item| match item {
            Ok(ChatStreamEvent::ContentDelta { delta, .. }) => Some(delta),
            _ => None,
        });
        match found {
            Some(delta) => assert_eq!(delta, "Hello"),
            None => panic!("Expected ContentDelta event in results: {:?}", result),
        }
    }

    /// A chunk carrying `finishReason: STOP` produces a stream-end event with `Stop`.
    #[tokio::test]
    async fn test_gemini_finish_reason() {
        let converter = GeminiEventConverter::new(create_test_config());
        let payload = r#"{"candidates":[{"finishReason":"STOP"}]}"#;

        let result = converter.convert_event(sse_event(payload)).await;
        assert!(!result.is_empty());

        let found = result.iter().find_map(|item| match item {
            Ok(ChatStreamEvent::StreamEnd { response }) => Some(response),
            _ => None,
        });
        match found {
            Some(response) => assert_eq!(response.finish_reason, Some(FinishReason::Stop)),
            None => panic!("Expected StreamEnd event in results: {:?}", result),
        }
    }
}