use crate::error::LlmError;
use crate::stream::{ChatStream, ChatStreamEvent};
use crate::types::{ChatRequest, ChatResponse, FinishReason, MessageContent, ResponseMetadata};
use crate::utils::streaming::{SseEventConverter, StreamFactory};
use eventsource_stream::Event;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::config::XaiConfig;
use super::types::*;
use super::utils::*;
/// Converts parsed xAI streaming chunks into `ChatStreamEvent`s.
///
/// The `stream_started` flag sits behind an `Arc`, so every clone of a
/// converter shares the same flag with the instance it was cloned from.
#[derive(Clone)]
pub struct XaiEventConverter {
// Stored by `new` but not read by any method visible in this file, which is
// why the dead-code lint is silenced here.
#[allow(dead_code)]
config: XaiConfig,
// `false` until the first chunk is converted; flipped to `true` exactly once
// so that only that first chunk is preceded by a stream-start event.
stream_started: Arc<Mutex<bool>>,
}
impl XaiEventConverter {
    /// Creates a converter whose "stream started" flag is initially `false`.
    ///
    /// The flag lives behind an `Arc`, so clones of the returned value share it.
    pub fn new(config: XaiConfig) -> Self {
        Self {
            config,
            stream_started: Arc::new(Mutex::new(false)),
        }
    }

    /// Translates one parsed xAI chunk into zero or more [`ChatStreamEvent`]s.
    ///
    /// Events are appended in a fixed order: a stream-start event (only for the
    /// first chunk seen by this converter or its clones), then a content delta,
    /// then a thinking delta. Only the first entry of `event.choices` is
    /// inspected; empty deltas are skipped.
    async fn convert_xai_event_async(&self, event: XaiStreamChunk) -> Vec<ChatStreamEvent> {
        use crate::utils::streaming::EventBuilder;

        let mut builder = EventBuilder::new();

        if self.needs_stream_start().await {
            builder = builder.add_stream_start(self.create_stream_start_metadata(&event));
        }

        if let Some(text) = self.extract_content(&event) {
            let choice_index = self.extract_choice_index(&event);
            builder = builder.add_content_delta(text, choice_index);
        }

        if let Some(reasoning) = self.extract_thinking(&event) {
            builder = builder.add_thinking_delta(reasoning);
        }

        builder.build()
    }

    /// Returns `true` exactly once: on the first call made through this
    /// converter or any clone sharing its flag. Every later call returns `false`.
    async fn needs_stream_start(&self) -> bool {
        let mut started = self.stream_started.lock().await;
        // Set the flag and look at its previous value while the lock is held,
        // so two concurrent callers cannot both observe `false`.
        !std::mem::replace(&mut *started, true)
    }

    /// Returns the first choice's content delta, or `None` when there is no
    /// choice, no content, or the content is an empty string.
    fn extract_content(&self, event: &XaiStreamChunk) -> Option<String> {
        let delta = &event.choices.first()?.delta;
        delta
            .content
            .as_ref()
            .filter(|text| !text.is_empty())
            .cloned()
    }

    /// Returns the first choice's reasoning delta, or `None` when there is no
    /// choice, no reasoning content, or the reasoning content is empty.
    fn extract_thinking(&self, event: &XaiStreamChunk) -> Option<String> {
        let delta = &event.choices.first()?.delta;
        delta
            .reasoning_content
            .as_ref()
            .filter(|text| !text.is_empty())
            .cloned()
    }

    /// Returns the index of the first choice, or `None` when the chunk has no
    /// choices or the reported index cannot be represented as `usize`.
    fn extract_choice_index(&self, event: &XaiStreamChunk) -> Option<usize> {
        let choice = event.choices.first()?;
        // Checked conversion: an unchecked `as usize` cast would silently wrap
        // or truncate an out-of-range index coming from the server.
        usize::try_from(choice.index).ok()
    }

    /// Builds the metadata attached to the stream-start event from the chunk's
    /// `id`, `model` and `created` fields.
    ///
    /// `created` is interpreted as whole seconds since the Unix epoch. If it
    /// does not fit in `i64` or is outside the range chrono can represent, the
    /// current UTC time is used instead.
    fn create_stream_start_metadata(&self, event: &XaiStreamChunk) -> ResponseMetadata {
        // Checked conversion first: an unchecked `as i64` cast would wrap very
        // large values into a bogus (negative) timestamp instead of falling
        // back to "now".
        let created = i64::try_from(event.created)
            .ok()
            .and_then(|secs| chrono::DateTime::from_timestamp(secs, 0))
            .unwrap_or_else(chrono::Utc::now);

        ResponseMetadata {
            id: Some(event.id.clone()),
            model: Some(event.model.clone()),
            created: Some(created),
            provider: "xai".to_string(),
            request_id: None,
        }
    }
}
impl SseEventConverter for XaiEventConverter {
    /// Parses the SSE event's `data` payload as an `XaiStreamChunk` and converts
    /// it into stream events.
    ///
    /// On a successful parse every produced event is wrapped in `Ok`. If the
    /// payload is not valid chunk JSON, the result is a single
    /// `LlmError::ParseError` describing the failure.
    fn convert_event(
        &self,
        event: Event,
    ) -> Pin<Box<dyn Future<Output = Vec<Result<ChatStreamEvent, LlmError>>> + Send + Sync + '_>>
    {
        Box::pin(async move {
            // Bail out early with a one-element error vector on malformed JSON.
            let chunk = match serde_json::from_str::<XaiStreamChunk>(&event.data) {
                Ok(chunk) => chunk,
                Err(e) => {
                    let message = format!("Failed to parse xAI event: {e}");
                    return vec![Err(LlmError::ParseError(message))];
                }
            };

            let converted = self.convert_xai_event_async(chunk).await;
            let wrapped: Vec<Result<ChatStreamEvent, LlmError>> =
                converted.into_iter().map(Ok).collect();
            wrapped
        })
    }

    /// Produces the terminal event of the stream.
    ///
    /// Always returns `Some(Ok(..))` holding a `StreamEnd` whose response has
    /// empty text content, `FinishReason::Stop`, empty metadata, and every
    /// other field unset.
    // NOTE(review): presumably called by the stream driver once the SSE stream
    // finishes — confirm against the `SseEventConverter` trait documentation.
    fn handle_stream_end(&self) -> Option<Result<ChatStreamEvent, LlmError>> {
        let end_event = ChatStreamEvent::StreamEnd {
            response: ChatResponse {
                id: None,
                model: None,
                content: MessageContent::Text(String::new()),
                usage: None,
                finish_reason: Some(FinishReason::Stop),
                tool_calls: None,
                thinking: None,
                metadata: HashMap::new(),
            },
        };
        Some(Ok(end_event))
    }
}
/// Issues streaming chat requests to the xAI `/chat/completions` endpoint.
#[derive(Clone)]
pub struct XaiStreaming {
// Provider settings; `create_chat_stream` reads `base_url`, `api_key`,
// `http_config` and `common_params` from it.
config: XaiConfig,
// HTTP client used to build the streaming POST request.
http_client: reqwest::Client,
}
impl XaiStreaming {
    /// Bundles the provider configuration with the HTTP client used for
    /// streaming requests.
    pub const fn new(config: XaiConfig, http_client: reqwest::Client) -> Self {
        Self {
            config,
            http_client,
        }
    }

    /// Sends `request` to `{base_url}/chat/completions` with `"stream": true`
    /// and returns the resulting event stream.
    ///
    /// The request body is produced by `XaiChatCapability::build_chat_request_body`
    /// and the headers by `build_headers`; the response is decoded through an
    /// `XaiEventConverter`.
    ///
    /// # Errors
    ///
    /// Propagates any `LlmError` from building the request body, building the
    /// headers, or creating the event-source stream.
    pub async fn create_chat_stream(self, request: ChatRequest) -> Result<ChatStream, LlmError> {
        let Self {
            config,
            http_client,
        } = self;

        // Reuse the non-streaming capability so the body is built the same way.
        let capability = super::chat::XaiChatCapability::new(
            config.api_key.clone(),
            config.base_url.clone(),
            http_client.clone(),
            config.http_config.clone(),
            config.common_params.clone(),
        );
        let mut body = capability.build_chat_request_body(&request)?;
        body["stream"] = serde_json::Value::from(true);

        let headers = build_headers(&config.api_key, &config.http_config.headers)?;

        let endpoint = format!("{}/chat/completions", config.base_url);
        let pending_request = http_client.post(&endpoint).headers(headers).json(&body);

        // `config` is moved into the converter here; nothing reads it afterwards.
        let converter = XaiEventConverter::new(config);
        StreamFactory::create_eventsource_stream(pending_request, converter).await
    }
}