use crate::error::LlmError;
use crate::providers::openai::config::OpenAiConfig;
use crate::stream::{ChatStream, ChatStreamEvent};
use crate::types::{ChatResponse, FinishReason, MessageContent, ResponseMetadata, Usage};
use crate::utils::streaming::{SseEventConverter, StreamFactory};
use eventsource_stream::Event;
use serde::Deserialize;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
/// One parsed SSE `data:` chunk from the OpenAI chat-completions streaming endpoint.
///
/// Every field is an `Option`, so chunks that omit any of them still deserialize.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiStreamEvent {
    /// Response id; copied into the `StreamStart` metadata for the first chunk.
    id: Option<String>,
    /// Model name; copied into the `StreamStart` metadata for the first chunk.
    model: Option<String>,
    /// Per-choice deltas. The converter only ever inspects the first entry.
    choices: Option<Vec<OpenAiStreamChoice>>,
    /// Token usage, if this chunk carries it. `create_chat_stream` requests it via
    /// `stream_options.include_usage`.
    usage: Option<OpenAiStreamUsage>,
}
/// A single choice entry inside a streaming chunk.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiStreamChoice {
    /// Choice index; forwarded alongside content and tool-call deltas.
    index: Option<usize>,
    /// Incremental message fields for this choice.
    delta: Option<OpenAiStreamDelta>,
    // Parsed but never read in this file: `handle_stream_end` always reports
    // `FinishReason::Stop`, hence the `dead_code` allowance.
    #[allow(dead_code)]
    finish_reason: Option<String>,
}
/// Incremental message fields carried by one streaming choice.
///
/// `Deserialize` is implemented by hand (not derived) so that reasoning text can be
/// collected from whichever of several provider-specific keys is present.
#[derive(Debug, Clone)]
// `role` is populated during deserialization but never read in this file.
#[allow(dead_code)]
struct OpenAiStreamDelta {
    /// Message role, when the chunk includes one.
    role: Option<String>,
    /// Text content fragment.
    content: Option<String>,
    /// Tool-call fragments; `None` if absent or if they fail to parse.
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
    /// Reasoning text taken from `reasoning_content`, `thinking` or `reasoning`.
    thinking: Option<String>,
}
impl<'de> serde::Deserialize<'de> for OpenAiStreamDelta {
    /// Builds a delta from an arbitrary JSON value without failing on unexpected shapes.
    ///
    /// `role` and `content` are kept only when they are JSON strings. `tool_calls` is kept
    /// only when it parses as a list of tool-call fragments; any parse failure becomes
    /// `None` instead of an error. Reasoning text is looked up across several keys via
    /// `extract_thinking_from_multiple_fields`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = serde_json::Value::deserialize(deserializer)?;
        let string_field = |key: &str| {
            raw.get(key)
                .and_then(serde_json::Value::as_str)
                .map(str::to_owned)
        };
        Ok(Self {
            thinking: extract_thinking_from_multiple_fields(&raw),
            role: string_field("role"),
            content: string_field("content"),
            tool_calls: raw
                .get("tool_calls")
                .cloned()
                .and_then(|calls| serde_json::from_value(calls).ok()),
        })
    }
}
pub(crate) fn extract_thinking_from_multiple_fields(value: &serde_json::Value) -> Option<String> {
let field_names = ["reasoning_content", "thinking", "reasoning"];
for field_name in &field_names {
if let Some(thinking_value) = value
.get(field_name)
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
{
return Some(thinking_value.to_string());
}
}
None
}
/// One tool-call fragment inside a streaming delta.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiToolCallDelta {
    // Position of the tool call within the delta; parsed but not read in this file.
    #[allow(dead_code)]
    index: Option<usize>,
    /// Tool-call id, when present on this fragment.
    id: Option<String>,
    /// Function name and/or argument text for this fragment.
    function: Option<OpenAiFunctionCallDelta>,
}
/// Function portion of a tool-call fragment.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiFunctionCallDelta {
    /// Function name, when included in this fragment.
    name: Option<String>,
    /// Argument text carried by this fragment. Presumably a partial JSON string that the
    /// consumer concatenates across chunks; confirm against the OpenAI chunk format.
    arguments: Option<String>,
}
/// Token accounting attached to a streaming chunk; mapped to [`Usage`] by
/// `OpenAiEventConverter::extract_usage`.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiStreamUsage {
    /// Prompt tokens; treated as 0 when absent.
    prompt_tokens: Option<u32>,
    /// Completion tokens; treated as 0 when absent.
    completion_tokens: Option<u32>,
    /// Total tokens; treated as 0 when absent.
    total_tokens: Option<u32>,
    /// Optional breakdown providing `reasoning_tokens`.
    completion_tokens_details: Option<OpenAiCompletionTokensDetails>,
    /// Optional breakdown providing `cached_tokens`.
    prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
}
/// Completion-side token breakdown.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiCompletionTokensDetails {
    /// Reasoning tokens; copied to `Usage::reasoning_tokens`.
    reasoning_tokens: Option<u32>,
}
/// Prompt-side token breakdown.
#[derive(Debug, Clone, Deserialize)]
struct OpenAiPromptTokensDetails {
    /// Cached prompt tokens; copied to `Usage::cached_tokens`.
    cached_tokens: Option<u32>,
}
/// Converts OpenAI SSE chunks into unified [`ChatStreamEvent`]s.
///
/// The `stream_started` flag sits behind an `Arc`, so clones of a converter share it and
/// together emit at most one `StreamStart`.
#[derive(Clone)]
pub struct OpenAiEventConverter {
    // Stored for future use; no method in this file reads it, hence the allowance.
    #[allow(dead_code)]
    config: OpenAiConfig,
    /// Becomes `true` once the first chunk has produced a `StreamStart` event.
    stream_started: Arc<Mutex<bool>>,
}
impl OpenAiEventConverter {
    /// Creates a converter for one streaming response.
    ///
    /// The "stream started" flag is created fresh here; clones of the returned converter
    /// share it through the `Arc`.
    pub fn new(config: OpenAiConfig) -> Self {
        Self {
            config,
            stream_started: Arc::new(Mutex::new(false)),
        }
    }

    /// Converts one parsed chunk into unified stream events.
    ///
    /// Events are emitted in this order: `StreamStart` (first chunk only), content delta,
    /// tool-call delta, thinking delta, usage update. Only the first choice of the chunk is
    /// inspected.
    async fn convert_openai_event_async(&self, event: OpenAiStreamEvent) -> Vec<ChatStreamEvent> {
        use crate::utils::streaming::EventBuilder;
        let mut builder = EventBuilder::new();
        if self.needs_stream_start().await {
            let metadata = self.create_stream_start_metadata(&event);
            builder = builder.add_stream_start(metadata);
        }
        let choice_index = self.extract_choice_index(&event);
        if let Some(content) = self.extract_content(&event) {
            builder = builder.add_content_delta(content, choice_index);
        }
        if let Some((id, name, args)) = self.extract_tool_call(&event) {
            builder = builder.add_tool_call_delta(id, name, args, choice_index);
        }
        if let Some(thinking) = self.extract_thinking(&event) {
            builder = builder.add_thinking_delta(thinking);
        }
        if let Some(usage) = self.extract_usage(&event) {
            builder = builder.add_usage_update(usage);
        }
        builder.build()
    }

    /// Returns `true` for the very first call on this converter (or any clone sharing its
    /// flag) and `false` for every call after that.
    async fn needs_stream_start(&self) -> bool {
        let mut started = self.stream_started.lock().await;
        if !*started {
            *started = true;
            true
        } else {
            false
        }
    }

    /// Returns the first choice's text fragment, ignoring missing or empty content.
    fn extract_content(&self, event: &OpenAiStreamEvent) -> Option<String> {
        event
            .choices
            .as_ref()?
            .first()?
            .delta
            .as_ref()?
            .content
            .as_ref()
            .filter(|content| !content.is_empty())
            .cloned()
    }

    /// Extracts `(id, name, arguments)` from the first tool-call fragment of the first choice.
    ///
    /// Returns `None` when the chunk has no tool-call fragment or the fragment has no
    /// `function` object.
    ///
    /// OpenAI sends the tool-call `id` (and name) only on the first fragment of each call.
    /// Later fragments carry just argument text. Those continuation fragments are returned
    /// with an empty `id` rather than dropped. Previously every fragment without an `id`
    /// was discarded, so streamed arguments never reached the consumer. Consumers must
    /// append empty-id fragments to the most recently started tool call.
    fn extract_tool_call(
        &self,
        event: &OpenAiStreamEvent,
    ) -> Option<(String, Option<String>, Option<String>)> {
        let choice = event.choices.as_ref()?.first()?;
        let tool_call = choice.delta.as_ref()?.tool_calls.as_ref()?.first()?;
        let function = tool_call.function.as_ref()?;
        let id = tool_call.id.clone().unwrap_or_default();
        Some((id, function.name.clone(), function.arguments.clone()))
    }

    /// Returns the first choice's reasoning fragment, ignoring missing or empty text.
    fn extract_thinking(&self, event: &OpenAiStreamEvent) -> Option<String> {
        event
            .choices
            .as_ref()?
            .first()?
            .delta
            .as_ref()?
            .thinking
            .as_ref()
            .filter(|thinking| !thinking.is_empty())
            .cloned()
    }

    /// Maps the chunk's usage block (if any) to [`Usage`]. Missing counters become 0;
    /// missing cached/reasoning breakdowns stay `None`.
    fn extract_usage(&self, event: &OpenAiStreamEvent) -> Option<Usage> {
        event.usage.as_ref().map(|usage| Usage {
            prompt_tokens: usage.prompt_tokens.unwrap_or(0),
            completion_tokens: usage.completion_tokens.unwrap_or(0),
            total_tokens: usage.total_tokens.unwrap_or(0),
            cached_tokens: usage
                .prompt_tokens_details
                .as_ref()
                .and_then(|details| details.cached_tokens),
            reasoning_tokens: usage
                .completion_tokens_details
                .as_ref()
                .and_then(|details| details.reasoning_tokens),
        })
    }

    /// Returns the `index` of the chunk's first choice, if present.
    fn extract_choice_index(&self, event: &OpenAiStreamEvent) -> Option<usize> {
        event.choices.as_ref()?.first()?.index
    }

    /// Builds the metadata attached to the `StreamStart` event.
    ///
    /// `created` is the local time at which the first chunk was converted. It is not the
    /// server-side timestamp, which `OpenAiStreamEvent` does not deserialize.
    fn create_stream_start_metadata(&self, event: &OpenAiStreamEvent) -> ResponseMetadata {
        ResponseMetadata {
            id: event.id.clone(),
            model: event.model.clone(),
            created: Some(chrono::Utc::now()),
            provider: "openai".to_string(),
            request_id: None,
        }
    }
}
impl SseEventConverter for OpenAiEventConverter {
    /// Parses one SSE `data:` payload as an OpenAI chat-completion chunk.
    ///
    /// Returns the unified events produced by the chunk. If the payload is not a valid
    /// chunk, returns a single `LlmError::ParseError` instead. The OpenAI `[DONE]` end
    /// sentinel and blank payloads are not JSON, so they yield no events rather than a
    /// spurious parse error.
    fn convert_event(
        &self,
        event: Event,
    ) -> Pin<Box<dyn Future<Output = Vec<Result<ChatStreamEvent, LlmError>>> + Send + Sync + '_>>
    {
        Box::pin(async move {
            let data = event.data.trim();
            // The stream is terminated by `data: [DONE]`; stream end is reported separately
            // via `handle_stream_end`, so the marker itself produces nothing here.
            if data.is_empty() || data == "[DONE]" {
                return Vec::new();
            }
            match serde_json::from_str::<OpenAiStreamEvent>(data) {
                Ok(openai_event) => self
                    .convert_openai_event_async(openai_event)
                    .await
                    .into_iter()
                    .map(Ok)
                    .collect(),
                Err(e) => {
                    vec![Err(LlmError::ParseError(format!(
                        "Failed to parse OpenAI event: {e}"
                    )))]
                }
            }
        })
    }

    /// Produces the final `StreamEnd` event.
    ///
    /// The response is a fixed placeholder: empty text content, no id/model/usage/tool
    /// calls, and `FinishReason::Stop`. The finish reason actually sent in the stream is not
    /// tracked (`OpenAiStreamChoice::finish_reason` is never read).
    /// NOTE(review): consumers needing length/tool-call finish reasons or aggregated
    /// content must reconstruct them from the preceding delta events.
    fn handle_stream_end(&self) -> Option<Result<ChatStreamEvent, LlmError>> {
        let response = ChatResponse {
            id: None,
            model: None,
            content: MessageContent::Text("".to_string()),
            usage: None,
            finish_reason: Some(FinishReason::Stop),
            tool_calls: None,
            thinking: None,
            metadata: HashMap::new(),
        };
        Some(Ok(ChatStreamEvent::StreamEnd { response }))
    }
}
/// Issues streaming chat-completion requests against the configured OpenAI `base_url`.
#[derive(Clone)]
pub struct OpenAiStreaming {
    /// Provider configuration (credentials, base URL, headers, common parameters).
    config: OpenAiConfig,
    /// HTTP client used both for body construction and for the streaming request.
    http_client: reqwest::Client,
}
impl OpenAiStreaming {
    /// Bundles the provider configuration with the HTTP client used for streaming.
    pub fn new(config: OpenAiConfig, http_client: reqwest::Client) -> Self {
        Self {
            config,
            http_client,
        }
    }

    /// Starts a streaming chat completion against `{base_url}/chat/completions`.
    ///
    /// The body comes from `OpenAiChatCapability::build_chat_request_body` and is then
    /// switched into streaming mode with usage reporting (`stream: true`,
    /// `stream_options.include_usage: true`).
    ///
    /// # Errors
    /// - Propagates any error from building the request body.
    /// - Returns `LlmError::HttpError` when a configured header name or value is invalid.
    /// - Otherwise forwards the result of `StreamFactory::create_eventsource_stream`.
    pub async fn create_chat_stream(
        self,
        request: crate::types::ChatRequest,
    ) -> Result<ChatStream, LlmError> {
        use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

        let Self {
            config,
            http_client,
        } = self;

        let capability = super::chat::OpenAiChatCapability::new(
            config.api_key.clone(),
            config.base_url.clone(),
            http_client.clone(),
            config.organization.clone(),
            config.project.clone(),
            config.http_config.clone(),
            config.common_params.clone(),
        );
        let mut body = capability.build_chat_request_body(&request)?;
        body["stream"] = serde_json::Value::Bool(true);
        body["stream_options"] = serde_json::json!({ "include_usage": true });

        // `insert` (not `append`) so a repeated header name overwrites the earlier value.
        let headers = config.get_headers().into_iter().try_fold(
            HeaderMap::new(),
            |mut acc, (key, value)| {
                let name = HeaderName::from_bytes(key.as_bytes())
                    .map_err(|e| LlmError::HttpError(format!("Invalid header name: {e}")))?;
                let val = HeaderValue::from_str(&value)
                    .map_err(|e| LlmError::HttpError(format!("Invalid header value: {e}")))?;
                acc.insert(name, val);
                Ok::<HeaderMap, LlmError>(acc)
            },
        )?;

        let endpoint = format!("{}/chat/completions", config.base_url);
        let request_builder = http_client.post(&endpoint).headers(headers).json(&body);
        StreamFactory::create_eventsource_stream(request_builder, OpenAiEventConverter::new(config))
            .await
    }
}