use crate::adapter::AdapterKind;
use crate::adapter::adapters::support::{StreamerCapturedData, StreamerOptions};
use crate::adapter::inter_stream::{InterStreamEnd, InterStreamEvent};
use crate::adapter::openai::OpenAIAdapter;
use crate::chat::{ChatOptionsSet, ToolCall};
use crate::webc::{Event, EventSourceStream};
use crate::{Error, ModelIden, Result};
use serde_json::Value;
use std::pin::Pin;
use std::task::{Context, Poll};
use value_ext::JsonValueExt;
/// Adapts an OpenAI-compatible server-sent-event stream into `InterStreamEvent`s.
pub struct OpenAIStreamer {
    /// Source of raw SSE events (`Open`, `Message`, errors).
    inner: EventSourceStream,
    /// Per-request settings: the model identity and which pieces of the response to capture.
    options: StreamerOptions,
    /// Data accumulated across chunks (text, reasoning, tool calls, usage), drained into the
    /// final end event.
    captured_data: StreamerCapturedData,
    /// Set once the `[DONE]` sentinel has been handled; every later poll yields `None`.
    done: bool,
}
impl OpenAIStreamer {
    /// Creates a streamer over `inner` for the given model and chat options.
    pub fn new(inner: EventSourceStream, model_iden: ModelIden, options_set: ChatOptionsSet<'_, '_>) -> Self {
        Self {
            inner,
            done: false,
            options: StreamerOptions::new(model_iden, options_set),
            captured_data: Default::default(),
        }
    }

    /// Records one streamed tool-call delta and returns the tool call to report for it.
    ///
    /// * `index` - position of the tool call inside the assistant message.
    /// * `call_id` / `fn_name` - values carried by this chunk; `fn_name` may be empty on
    ///   continuation chunks, in which case the stored id/name are left untouched.
    /// * `arguments` - the next fragment of the arguments string.
    ///
    /// When tool-call capture is disabled, the returned `ToolCall` contains only this chunk's
    /// data. Otherwise the fragment is appended to the call stored at `index`, and a clone of
    /// the accumulated call is returned.
    fn capture_tool_call(&mut self, index: usize, call_id: String, fn_name: String, arguments: String) -> ToolCall {
        if !self.options.capture_tool_calls {
            return ToolCall {
                call_id,
                fn_name,
                fn_arguments: Value::String(arguments),
                thought_signatures: None,
            };
        }

        let calls = self.captured_data.tool_calls.get_or_insert_with(Vec::new);

        if let Some(existing_call) = calls.get_mut(index) {
            // Append in place rather than re-formatting the whole string for every fragment.
            if let Value::String(accumulated) = &mut existing_call.fn_arguments {
                accumulated.push_str(&arguments);
            }
            if !fn_name.is_empty() {
                existing_call.call_id = call_id;
                existing_call.fn_name = fn_name;
            }
            return existing_call.clone();
        }

        // If this index skips ahead, pad the gap with empty placeholders rather than copies
        // of this call: when a skipped index arrives later, its fragments then start from an
        // empty string instead of being appended onto another call's arguments.
        let next_slot = calls.len();
        calls.extend((next_slot..index).map(|slot| ToolCall {
            call_id: format!("call_{slot}"),
            fn_name: String::new(),
            fn_arguments: Value::String(String::new()),
            thought_signatures: None,
        }));

        let tool_call = ToolCall {
            call_id,
            fn_name,
            fn_arguments: Value::String(arguments),
            thought_signatures: None,
        };
        calls.push(tool_call.clone());
        tool_call
    }
}
impl futures::Stream for OpenAIStreamer {
type Item = Result<InterStreamEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
while let Poll::Ready(event) = Pin::new(&mut self.inner).poll_next(cx) {
match event {
Some(Ok(Event::Open)) => return Poll::Ready(Some(Ok(InterStreamEvent::Start))),
Some(Ok(Event::Message(message))) => {
if message.data == "[DONE]" {
self.done = true;
let captured_usage = if self.options.capture_usage {
self.captured_data.usage.take()
} else {
None
};
let captured_tool_calls = if let Some(tools_calls) = self.captured_data.tool_calls.take() {
let tools_calls: Vec<ToolCall> = tools_calls
.into_iter()
.map(|tool_call| {
let ToolCall {
call_id,
fn_name,
fn_arguments,
..
} = tool_call;
let fn_arguments = match fn_arguments {
Value::String(fn_arguments_string) => {
match serde_json::from_str::<Value>(&fn_arguments_string) {
Ok(fn_arguments) => fn_arguments,
Err(_) => Value::String(fn_arguments_string),
}
}
_ => fn_arguments,
};
ToolCall {
call_id,
fn_name,
fn_arguments,
thought_signatures: None,
}
})
.collect();
Some(tools_calls)
} else {
None
};
let inter_stream_end = InterStreamEnd {
captured_usage,
captured_text_content: self.captured_data.content.take(),
captured_reasoning_content: self.captured_data.reasoning_content.take(),
captured_tool_calls,
captured_thought_signatures: None,
};
return Poll::Ready(Some(Ok(InterStreamEvent::End(inter_stream_end))));
}
let mut message_data: Value =
serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
})?;
let first_choice: Option<Value> = message_data.x_take("/choices/0").ok();
let adapter_kind = self.options.model_iden.adapter_kind;
if let Some(mut first_choice) = first_choice {
if let Ok(_finish_reason) = first_choice.x_take::<String>("finish_reason") {
if let Ok(delta_tool_calls) = first_choice.x_take::<Value>("/delta/tool_calls")
&& delta_tool_calls != Value::Null
{
if let Some(delta_tool_calls) = delta_tool_calls.as_array() {
for tool_call_obj_val in delta_tool_calls {
let mut tool_call_obj = tool_call_obj_val.clone();
if let (Ok(index), Ok(mut function)) = (
tool_call_obj.x_take::<u32>("index"),
tool_call_obj.x_take::<Value>("function"),
) {
let call_id = tool_call_obj
.x_take::<String>("id")
.unwrap_or_else(|_| format!("call_{index}"));
let fn_name = function.x_take::<String>("name").unwrap_or_default();
let arguments = function.x_take::<String>("arguments").unwrap_or_default();
self.capture_tool_call(index as usize, call_id, fn_name, arguments);
}
}
}
}
if self.options.capture_usage {
match adapter_kind {
AdapterKind::Groq => {
let usage = message_data
.x_take("/x_groq/usage")
.map(|v| OpenAIAdapter::into_usage(adapter_kind, v))
.unwrap_or_default(); self.captured_data.usage = Some(usage)
}
AdapterKind::DeepSeek
| AdapterKind::Zai
| AdapterKind::Fireworks
| AdapterKind::Together => {
let usage = message_data
.x_take("usage")
.map(|v| OpenAIAdapter::into_usage(adapter_kind, v))
.unwrap_or_default();
self.captured_data.usage = Some(usage)
}
_ => (), }
}
continue;
}
else if let Ok(delta_tool_calls) = first_choice.x_take::<Value>("/delta/tool_calls")
&& delta_tool_calls != Value::Null
{
if let Some(delta_tool_calls) = delta_tool_calls.as_array()
&& let Some(tool_call_obj_val) = delta_tool_calls.get(0)
{
let mut tool_call_obj = tool_call_obj_val.clone();
if let (Ok(index), Ok(mut function)) = (
tool_call_obj.x_take::<u32>("index"),
tool_call_obj.x_take::<Value>("function"),
) {
let call_id = tool_call_obj
.x_take::<String>("id")
.unwrap_or_else(|_| format!("call_{index}"));
let fn_name = function.x_take::<String>("name").unwrap_or_default();
let arguments = function.x_take::<String>("arguments").unwrap_or_default();
let tool_call = self.capture_tool_call(index as usize, call_id, fn_name, arguments);
return Poll::Ready(Some(Ok(InterStreamEvent::ToolCallChunk(tool_call))));
}
}
continue;
}
else {
let content = first_choice.x_take::<Option<String>>("/delta/content").ok().flatten();
let reasoning_content = first_choice
.x_take::<Option<String>>("/delta/reasoning_content")
.ok()
.flatten()
.or_else(|| first_choice.x_take::<Option<String>>("/delta/reasoning").ok().flatten());
if let Some(content) = content
&& !content.is_empty()
{
if self.options.capture_content {
match self.captured_data.content {
Some(ref mut c) => c.push_str(&content),
None => self.captured_data.content = Some(content.clone()),
}
}
return Poll::Ready(Some(Ok(InterStreamEvent::Chunk(content))));
} else if let Some(reasoning_content) = reasoning_content
&& !reasoning_content.is_empty()
{
if self.options.capture_reasoning_content {
match self.captured_data.reasoning_content {
Some(ref mut c) => c.push_str(&reasoning_content),
None => self.captured_data.reasoning_content = Some(reasoning_content.clone()),
}
}
return Poll::Ready(Some(Ok(InterStreamEvent::ReasoningChunk(reasoning_content))));
}
tracing::warn!("EMPTY CHOICE CONTENT");
}
}
else {
if !matches!(adapter_kind, AdapterKind::Groq)
&& !matches!(adapter_kind, AdapterKind::DeepSeek)
&& self.captured_data.usage.is_none() && self.options.capture_usage
{
let usage = message_data
.x_take("usage")
.map(|v| OpenAIAdapter::into_usage(adapter_kind, v))
.unwrap_or_default();
self.captured_data.usage = Some(usage);
}
}
}
Some(Err(err)) => {
tracing::error!("Error: {}", err);
return Poll::Ready(Some(Err(Error::WebStream {
model_iden: self.options.model_iden.clone(),
cause: err.to_string(),
error: err,
})));
}
None => {
return Poll::Ready(None);
}
}
}
Poll::Pending
}
}