use std::collections::BTreeMap;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::stream::BoxStream;
use futures_util::{Stream, StreamExt};
use crate::error::OpenRouterError;
use crate::types::completion::{
CompletionsResponse, FinishReason, FunctionCall, PartialToolCall, ReasoningDetail,
ResponseUsage, ToolCall,
};
/// Events yielded by [`ToolAwareStream`] while it consumes a stream of
/// completion chunks.
#[derive(Debug)]
pub enum StreamEvent {
/// Content text from one choice of a chunk; only emitted when the text is
/// non-empty.
ContentDelta(String),
/// Reasoning text from one choice of a chunk; only emitted when the text is
/// non-empty.
ReasoningDelta(String),
/// Reasoning details from one choice of a chunk; only emitted when the
/// list is non-empty.
ReasoningDetailsDelta(Vec<ReasoningDetail>),
/// Emitted once, after the inner stream has ended; it is the last event
/// before the stream itself ends.
Done {
/// Tool calls reassembled from the partial fragments seen across all
/// chunks, in ascending fragment-index order. Entries that never
/// received an id or a function name are omitted.
tool_calls: Vec<ToolCall>,
/// Most recent finish reason reported by any choice, if one was seen.
finish_reason: Option<FinishReason>,
/// Most recent usage record carried by a chunk, if one was seen.
usage: Option<ResponseUsage>,
/// `id` of the last chunk received (empty if no chunk arrived).
id: String,
/// `model` of the last chunk received (empty if no chunk arrived).
model: String,
},
/// An error item produced by the inner stream, forwarded as an event.
/// Emitting it does not mark the stream as finished.
Error(OpenRouterError),
}
/// Collects the pieces of a single tool call as they arrive spread over
/// several streamed chunks.
#[derive(Debug, Clone, Default)]
struct ToolCallAccumulator {
/// Latest call id carried by any fragment, if one has been seen.
id: Option<String>,
/// Latest call type carried by any fragment, if one has been seen.
type_: Option<String>,
/// Latest function name carried by any fragment, if one has been seen.
name: Option<String>,
/// Argument text, built by concatenating every fragment's arguments in
/// arrival order.
arguments: String,
}
impl ToolCallAccumulator {
fn merge(&mut self, partial: &PartialToolCall) {
if let Some(id) = &partial.id {
self.id = Some(id.clone());
}
if let Some(type_) = &partial.type_ {
self.type_ = Some(type_.clone());
}
if let Some(func) = &partial.function {
if let Some(name) = &func.name {
self.name = Some(name.clone());
}
if let Some(args) = &func.arguments {
self.arguments.push_str(args);
}
}
}
fn into_tool_call(self) -> Option<ToolCall> {
Some(ToolCall {
id: self.id?,
type_: self.type_.unwrap_or_else(|| "function".to_string()),
function: FunctionCall {
name: self.name?,
arguments: self.arguments,
},
index: None,
})
}
}
/// Stream adapter that turns raw completion chunks into [`StreamEvent`]s and
/// reassembles tool calls that arrive as partial fragments.
pub struct ToolAwareStream {
/// The wrapped stream of completion chunks (or errors).
inner: BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>,
/// One accumulator per tool call, keyed by the fragment's `index`
/// (a fragment without an index is filed under 0).
tool_accumulators: BTreeMap<u32, ToolCallAccumulator>,
/// Events already produced but not yet handed to the consumer; they are
/// yielded from the front, one per poll.
pending_events: Vec<StreamEvent>,
/// `id` of the most recently processed chunk.
last_id: String,
/// `model` of the most recently processed chunk.
last_model: String,
/// Most recent usage record seen; chunks without usage do not clear it.
last_usage: Option<ResponseUsage>,
/// Most recent finish reason reported by any choice.
last_finish_reason: Option<FinishReason>,
/// Set once the inner stream has ended and the `Done` event was queued;
/// after the queue drains the stream yields `None`.
finished: bool,
}
impl ToolAwareStream {
    /// Wraps `inner`, starting with no accumulated tool calls, no queued
    /// events, and empty id/model/usage/finish-reason state.
    pub fn new(inner: BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>) -> Self {
        Self {
            inner,
            tool_accumulators: BTreeMap::new(),
            pending_events: Vec::new(),
            last_id: String::new(),
            last_model: String::new(),
            last_usage: None,
            last_finish_reason: None,
            finished: false,
        }
    }

    /// Absorbs one chunk: records its id/model/usage, queues delta events for
    /// every choice, and merges partial tool-call fragments into their
    /// accumulators.
    ///
    /// A chunk that carries only tool-call fragments (or empty deltas) queues
    /// no event at all.
    fn process_chunk(&mut self, response: CompletionsResponse) {
        self.last_id.clone_from(&response.id);
        self.last_model.clone_from(&response.model);
        // A chunk without usage must not erase usage reported by an earlier one.
        if response.usage.is_some() {
            self.last_usage = response.usage;
        }

        for choice in &response.choices {
            if let Some(reason) = choice.finish_reason() {
                self.last_finish_reason = Some(reason.clone());
            }

            // Empty deltas are dropped so consumers never see no-op events.
            if let Some(content) = choice.content() {
                if !content.is_empty() {
                    self.pending_events
                        .push(StreamEvent::ContentDelta(content.to_string()));
                }
            }
            if let Some(reasoning) = choice.reasoning() {
                if !reasoning.is_empty() {
                    self.pending_events
                        .push(StreamEvent::ReasoningDelta(reasoning.to_string()));
                }
            }
            if let Some(details) = choice.reasoning_details() {
                if !details.is_empty() {
                    self.pending_events
                        .push(StreamEvent::ReasoningDetailsDelta(details.to_vec()));
                }
            }

            if let Some(partial_tool_calls) = choice.partial_tool_calls() {
                for partial in partial_tool_calls {
                    // Fragments without an explicit index are treated as call 0.
                    let idx = partial.index.unwrap_or(0);
                    self.tool_accumulators
                        .entry(idx)
                        .or_default()
                        .merge(partial);
                }
            }
        }
    }

    /// Queues the terminal [`StreamEvent::Done`] event and marks the stream
    /// as finished.
    ///
    /// Tool calls are emitted in ascending index order; accumulators that
    /// never received an id or a function name are skipped. The recorded
    /// id, model, usage and finish reason are moved into the event.
    fn finalize(&mut self) {
        // The accumulators are not needed after this point, so move them out
        // and consume them by value instead of cloning every argument string.
        let tool_calls: Vec<ToolCall> = std::mem::take(&mut self.tool_accumulators)
            .into_iter()
            .filter_map(|(_, acc)| acc.into_tool_call())
            .collect();

        self.pending_events.push(StreamEvent::Done {
            tool_calls,
            finish_reason: self.last_finish_reason.take(),
            usage: self.last_usage.take(),
            id: std::mem::take(&mut self.last_id),
            model: std::mem::take(&mut self.last_model),
        });
        self.finished = true;
    }
}
impl Stream for ToolAwareStream {
type Item = StreamEvent;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if !self.pending_events.is_empty() {
return Poll::Ready(Some(self.pending_events.remove(0)));
}
if self.finished {
return Poll::Ready(None);
}
match self.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(response))) => {
self.process_chunk(response);
if !self.pending_events.is_empty() {
Poll::Ready(Some(self.pending_events.remove(0)))
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::Error(e))),
Poll::Ready(None) => {
if !self.finished {
self.finalize();
if !self.pending_events.is_empty() {
Poll::Ready(Some(self.pending_events.remove(0)))
} else {
Poll::Ready(None)
}
} else {
Poll::Ready(None)
}
}
Poll::Pending => Poll::Pending,
}
}
}