use crate::error::LlmError;
use crate::stream::{ChatStream, ChatStreamEvent};
use crate::types::{ChatRequest, ResponseMetadata, Usage};
use crate::types::{ChatResponse, FinishReason, MessageContent};
use crate::utils::streaming::{SseEventConverter, StreamFactory};
use eventsource_stream::Event;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::config::GroqConfig;
use super::types::*;
use super::utils::*;
/// Converts Groq server-sent events into unified [`ChatStreamEvent`]s.
///
/// The `stream_started` flag lives behind an `Arc`, so every clone of a
/// converter shares it. `StreamStart` is therefore emitted at most once across
/// all clones.
#[derive(Clone)]
pub struct GroqEventConverter {
    /// Flipped to `true` when the first chunk is converted; guards `StreamStart`.
    stream_started: Arc<Mutex<bool>>,
    /// Provider configuration. It is not read by the converter today, hence the
    /// `dead_code` allowance (presumably kept for future use — TODO confirm).
    #[allow(dead_code)]
    config: GroqConfig,
}
impl GroqEventConverter {
    /// Builds a converter whose "stream started" flag is initially unset.
    pub fn new(config: GroqConfig) -> Self {
        let stream_started = Arc::new(Mutex::new(false));
        Self {
            config,
            stream_started,
        }
    }

    /// Maps one parsed Groq chunk to the unified events it implies.
    ///
    /// These events are emitted in this order, each only when applicable:
    /// 1. `StreamStart`, for the first chunk seen by this converter or any clone;
    /// 2. a content delta, when the first choice has non-empty `delta.content`;
    /// 3. a usage update, when the chunk carries a `usage` object.
    async fn convert_groq_response_async(
        &self,
        response: GroqChatStreamChunk,
    ) -> Vec<ChatStreamEvent> {
        use crate::utils::streaming::EventBuilder;

        let builder = EventBuilder::new();

        let builder = if self.needs_stream_start().await {
            builder.add_stream_start(self.create_stream_start_metadata(&response))
        } else {
            builder
        };

        let builder = match self.extract_content(&response) {
            Some(text) => builder.add_content_delta(text, self.extract_choice_index(&response)),
            None => builder,
        };

        let builder = match self.extract_usage(&response) {
            Some(usage) => builder.add_usage_update(usage),
            None => builder,
        };

        builder.build()
    }

    /// Returns `true` only for the first caller and `false` on every later call.
    ///
    /// The shared flag is set to `true` as a side effect. The old value is swapped
    /// out under the lock, so concurrent callers cannot both observe "not started".
    async fn needs_stream_start(&self) -> bool {
        let mut already_started = self.stream_started.lock().await;
        !std::mem::replace(&mut *already_started, true)
    }

    /// Returns the first choice's delta text.
    ///
    /// Returns `None` when there are no choices or when the text is absent or empty.
    fn extract_content(&self, response: &GroqChatStreamChunk) -> Option<String> {
        let first_choice = response.choices.first()?;
        match first_choice.delta.content.as_deref() {
            Some(text) if !text.is_empty() => Some(text.to_owned()),
            _ => None,
        }
    }

    /// Returns the `index` of the first choice, or `None` if there are no choices.
    fn extract_choice_index(&self, response: &GroqChatStreamChunk) -> Option<usize> {
        response.choices.first().map(|choice| choice.index as usize)
    }

    /// Converts the chunk's usage block, if present, into the unified [`Usage`].
    ///
    /// Missing token counts become `0`. Cached and reasoning token counts are
    /// always reported as `None`.
    fn extract_usage(&self, response: &GroqChatStreamChunk) -> Option<Usage> {
        let raw = response.usage.as_ref()?;
        Some(Usage {
            prompt_tokens: raw.prompt_tokens.unwrap_or_default(),
            completion_tokens: raw.completion_tokens.unwrap_or_default(),
            total_tokens: raw.total_tokens.unwrap_or_default(),
            cached_tokens: None,
            reasoning_tokens: None,
        })
    }

    /// Builds the metadata attached to the `StreamStart` event.
    ///
    /// `created` is the local clock time when the chunk is processed, not a
    /// server-supplied timestamp.
    fn create_stream_start_metadata(&self, response: &GroqChatStreamChunk) -> ResponseMetadata {
        let received_at = chrono::Utc::now();
        ResponseMetadata {
            provider: String::from("groq"),
            id: Some(response.id.clone()),
            model: Some(response.model.clone()),
            created: Some(received_at),
            request_id: None,
        }
    }
}
impl SseEventConverter for GroqEventConverter {
    /// Deserializes one SSE event's `data` as a [`GroqChatStreamChunk`] and
    /// converts it into unified stream events.
    ///
    /// If the payload fails to deserialize, the result is a single
    /// `LlmError::ParseError` item and no events.
    fn convert_event(
        &self,
        event: Event,
    ) -> Pin<Box<dyn Future<Output = Vec<Result<ChatStreamEvent, LlmError>>> + Send + Sync + '_>>
    {
        Box::pin(async move {
            let chunk = match serde_json::from_str::<GroqChatStreamChunk>(&event.data) {
                Ok(chunk) => chunk,
                Err(e) => {
                    return vec![Err(LlmError::ParseError(format!(
                        "Failed to parse Groq event: {e}"
                    )))];
                }
            };

            let converted = self.convert_groq_response_async(chunk).await;
            converted.into_iter().map(Ok::<_, LlmError>).collect()
        })
    }

    /// Produces the terminal `StreamEnd` event.
    ///
    /// The synthesized response has empty text and no id, model, usage, tool
    /// calls or thinking. It always reports `FinishReason::Stop`.
    ///
    /// NOTE(review): the finish reason actually sent by Groq is never read here,
    /// so length- or tool-call-terminated streams still report `Stop`. Consider
    /// tracking it during conversion.
    fn handle_stream_end(&self) -> Option<Result<ChatStreamEvent, LlmError>> {
        let empty_response = ChatResponse {
            id: None,
            model: None,
            content: MessageContent::Text(String::new()),
            usage: None,
            finish_reason: Some(FinishReason::Stop),
            tool_calls: None,
            thinking: None,
            metadata: Default::default(),
        };
        let end_event = ChatStreamEvent::StreamEnd {
            response: empty_response,
        };
        Some(Ok(end_event))
    }
}
/// Client for Groq streaming chat completions.
///
/// Holds the provider configuration and the HTTP client used to send the
/// streaming request.
#[derive(Clone)]
pub struct GroqStreaming {
    /// Provider settings: base URL, API key, HTTP options and common parameters.
    config: GroqConfig,
    /// Client used to issue the `POST {base_url}/chat/completions` request.
    http_client: reqwest::Client,
}
impl GroqStreaming {
    /// Creates a Groq streaming client from its configuration and an HTTP client.
    pub fn new(config: GroqConfig, http_client: reqwest::Client) -> Self {
        Self {
            config,
            http_client,
        }
    }

    /// Sends a streaming chat-completion request and returns the resulting event stream.
    ///
    /// The JSON body comes from `GroqChatCapability::build_chat_request_body`. This
    /// method then sets `stream` to `true` and `stream_options.include_usage` to `true`.
    /// Incoming SSE events are translated by a fresh [`GroqEventConverter`].
    ///
    /// # Errors
    ///
    /// Propagates any [`LlmError`] from building the request body,
    /// `validate_groq_params`, `build_headers`, or
    /// `StreamFactory::create_eventsource_stream`.
    pub async fn create_chat_stream(&self, request: ChatRequest) -> Result<ChatStream, LlmError> {
        // Strip trailing slashes so a base URL configured as ".../v1/" yields
        // ".../v1/chat/completions" rather than ".../v1//chat/completions".
        let base_url = self.config.base_url.trim_end_matches('/');
        let url = format!("{base_url}/chat/completions");

        // Reuse the non-streaming capability's body builder so both paths map
        // request parameters the same way.
        let chat_capability = super::chat::GroqChatCapability::new(
            self.config.api_key.clone(),
            self.config.base_url.clone(),
            self.http_client.clone(),
            self.config.http_config.clone(),
            self.config.common_params.clone(),
        );
        let mut request_body = chat_capability.build_chat_request_body(&request)?;
        request_body["stream"] = serde_json::Value::Bool(true);
        request_body["stream_options"] = serde_json::json!({
            "include_usage": true
        });

        // Validate after the streaming fields are added so they are checked too.
        validate_groq_params(&request_body)?;

        let headers = build_headers(&self.config.api_key, &self.config.http_config.headers)?;
        let request_builder = self
            .http_client
            .post(&url)
            .headers(headers)
            .json(&request_body);

        let converter = GroqEventConverter::new(self.config.clone());
        StreamFactory::create_eventsource_stream(request_builder, converter).await
    }
}