use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Stream;
use futures::StreamExt;
use futures::future::BoxFuture;
use tokio::sync::oneshot;
use crate::codecs::BoxDeltaStream;
use crate::error::{Error, Result};
use crate::ir::{
ContentPart, ModelResponse, ModelWarning, ProviderEchoSnapshot, StopReason, Usage,
};
use crate::rate_limit::RateLimitSnapshot;
use crate::service::ModelStream;
/// One incremental event of a streamed model response.
///
/// `StreamAggregator::push` consumes these one at a time and
/// `StreamAggregator::finalize` folds what was collected into a
/// `ModelResponse`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum StreamDelta {
/// Response header carrying the response id, the model name and
/// response-level provider echoes.
Start {
id: String,
model: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
},
/// A fragment of text output. The aggregator concatenates consecutive
/// fragments into a single text part.
TextDelta {
text: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
},
/// A fragment of thinking output. The aggregator concatenates
/// consecutive fragments into a single thinking part.
ThinkingDelta {
text: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
},
/// Opens a tool-use block with the call id and the tool name.
ToolUseStart {
id: String,
name: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
},
/// A fragment of the open tool block's JSON arguments. Fragments are
/// concatenated and only parsed when the block is closed.
ToolUseInputDelta {
partial_json: String,
},
/// Closes the currently open tool-use block.
ToolUseStop,
/// Usage figures; the aggregator keeps the most recent one it was given.
Usage(Usage),
/// Rate-limit snapshot; the aggregator keeps the most recent one it was given.
RateLimit(RateLimitSnapshot),
/// A warning to carry through to the final response.
Warning(ModelWarning),
/// Terminal event carrying the reason the model stopped.
Stop {
stop_reason: StopReason,
},
}
/// A tool-use block that has been opened by `ToolUseStart` but not yet
/// closed by `ToolUseStop`.
struct PendingTool {
id: String,
name: String,
// Raw argument text concatenated from `ToolUseInputDelta` fragments;
// parsed as JSON only when the block is closed.
input_buffer: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
}
/// A run of thinking output still being accumulated; it becomes a
/// `ContentPart::Thinking` when flushed.
#[derive(Default)]
struct PendingThinking {
text: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
}
/// A run of text output still being accumulated; it becomes a
/// `ContentPart::Text` when flushed.
#[derive(Default)]
struct PendingText {
text: String,
provider_echoes: Vec<ProviderEchoSnapshot>,
}
/// Folds a sequence of [`StreamDelta`] values into a single `ModelResponse`.
///
/// Feed deltas with `push` and call `finalize` once the stream has ended.
#[derive(Default)]
pub struct StreamAggregator {
// Both stay empty until a `Start` delta has been recorded.
id: String,
model: String,
// Completed content parts, in the order they were closed.
parts: Vec<ContentPart>,
// Runs and blocks that are still receiving deltas.
open_text: Option<PendingText>,
open_thinking: Option<PendingThinking>,
pending_tool: Option<PendingTool>,
// Most recent values seen; `None` if the stream never supplied one.
usage: Option<Usage>,
rate_limit: Option<RateLimitSnapshot>,
// Set by the `Stop` delta; `Some` means the stream reached its terminal event.
stop_reason: Option<StopReason>,
warnings: Vec<ModelWarning>,
// Provider echoes attached to the `Start` delta.
response_echoes: Vec<ProviderEchoSnapshot>,
}
impl StreamAggregator {
pub fn new() -> Self {
Self::default()
}
pub fn push(&mut self, delta: StreamDelta) -> Result<()> {
match delta {
StreamDelta::Start {
id,
model,
provider_echoes,
} => {
if !self.id.is_empty() || !self.model.is_empty() {
return Err(Error::invalid_request(
"StreamAggregator: duplicate Start delta",
));
}
self.id = id;
self.model = model;
self.response_echoes.extend(provider_echoes);
}
StreamDelta::TextDelta {
text,
provider_echoes,
} => {
if self.pending_tool.is_some() {
return Err(Error::invalid_request(
"StreamAggregator: TextDelta during open tool_use block",
));
}
self.flush_thinking();
let pending = self.open_text.get_or_insert_with(PendingText::default);
pending.text.push_str(&text);
pending.provider_echoes.extend(provider_echoes);
}
StreamDelta::ThinkingDelta {
text,
provider_echoes,
} => {
if self.pending_tool.is_some() {
return Err(Error::invalid_request(
"StreamAggregator: ThinkingDelta during open tool_use block",
));
}
self.flush_text();
let pending = self
.open_thinking
.get_or_insert_with(PendingThinking::default);
pending.text.push_str(&text);
pending.provider_echoes.extend(provider_echoes);
}
StreamDelta::ToolUseStart {
id,
name,
provider_echoes,
} => {
if self.pending_tool.is_some() {
return Err(Error::invalid_request(
"StreamAggregator: ToolUseStart while another tool block is open",
));
}
self.flush_text();
self.flush_thinking();
self.pending_tool = Some(PendingTool {
id,
name,
input_buffer: String::new(),
provider_echoes,
});
}
StreamDelta::ToolUseInputDelta { partial_json } => {
let pending = self.pending_tool.as_mut().ok_or_else(|| {
Error::invalid_request(
"StreamAggregator: ToolUseInputDelta with no open tool block",
)
})?;
pending.input_buffer.push_str(&partial_json);
}
StreamDelta::ToolUseStop => self.close_tool_block()?,
StreamDelta::Usage(u) => self.usage = Some(u),
StreamDelta::RateLimit(r) => self.rate_limit = Some(r),
StreamDelta::Warning(w) => self.warnings.push(w),
StreamDelta::Stop { stop_reason } => {
if self.stop_reason.is_some() {
return Err(Error::invalid_request(
"StreamAggregator: duplicate Stop delta — terminal state already set",
));
}
self.stop_reason = Some(stop_reason);
}
}
Ok(())
}
pub const fn is_finished(&self) -> bool {
self.stop_reason.is_some()
}
pub fn finalize(mut self) -> Result<ModelResponse> {
if self.pending_tool.is_some() {
return Err(Error::invalid_request(
"StreamAggregator: stream ended with an open tool block",
));
}
let stop_reason = self.stop_reason.take().ok_or_else(|| {
Error::invalid_request("StreamAggregator: stream ended without Stop delta")
})?;
self.flush_text();
self.flush_thinking();
if self.usage.is_none() {
self.warnings.push(crate::ir::ModelWarning::LossyEncode {
field: "usage".to_owned(),
detail: "streaming response closed without Usage delta — cost will be zero"
.to_owned(),
});
}
Ok(ModelResponse {
id: self.id,
model: self.model,
stop_reason,
content: self.parts,
usage: self.usage.unwrap_or_default(),
rate_limit: self.rate_limit,
warnings: self.warnings,
provider_echoes: self.response_echoes,
})
}
fn close_tool_block(&mut self) -> Result<()> {
let pending = self.pending_tool.take().ok_or_else(|| {
Error::invalid_request("StreamAggregator: ToolUseStop with no open tool block")
})?;
let input: serde_json::Value = if pending.input_buffer.is_empty() {
serde_json::json!({})
} else {
serde_json::from_str(&pending.input_buffer).map_err(|e| {
Error::invalid_request(format!(
"StreamAggregator: ToolUse '{}' (id={}) arguments are not valid JSON: \
{e}; buffered={:?}",
pending.name,
pending.id,
truncate_for_diagnostic(&pending.input_buffer),
))
})?
};
self.parts.push(ContentPart::ToolUse {
id: pending.id,
name: pending.name,
input,
provider_echoes: pending.provider_echoes,
});
Ok(())
}
fn flush_text(&mut self) {
if let Some(pending) = self.open_text.take()
&& !(pending.text.is_empty() && pending.provider_echoes.is_empty())
{
self.parts.push(ContentPart::Text {
text: pending.text,
cache_control: None,
provider_echoes: pending.provider_echoes,
});
}
}
fn flush_thinking(&mut self) {
if let Some(pending) = self.open_thinking.take()
&& !(pending.text.is_empty() && pending.provider_echoes.is_empty())
{
self.parts.push(ContentPart::Thinking {
text: pending.text,
cache_control: None,
provider_echoes: pending.provider_echoes,
});
}
}
}
pub fn tap_aggregator(inner: BoxDeltaStream<'static>) -> ModelStream {
let (tx, rx) = oneshot::channel::<Result<ModelResponse>>();
let tap = AggregatorTap {
inner,
agg: StreamAggregator::new(),
completion: Some(tx),
terminated: false,
};
ModelStream {
stream: Box::pin(tap),
completion: Box::pin(async move {
match rx.await {
Ok(result) => result,
Err(_) => Err(Error::Cancelled),
}
}) as BoxFuture<'static, Result<ModelResponse>>,
}
}
/// Stream adapter that forwards deltas from `inner` while feeding a copy of
/// each one to a `StreamAggregator`, and reports the aggregated outcome
/// through a oneshot sender.
struct AggregatorTap {
inner: BoxDeltaStream<'static>,
agg: StreamAggregator,
// Taken (set to `None`) when the outcome is sent, so it is sent at most once.
completion: Option<oneshot::Sender<Result<ModelResponse>>>,
// Once set, `poll_next` returns `None` without polling `inner` again.
terminated: bool,
}
impl AggregatorTap {
    /// Sends `outcome` to the completion receiver, at most once.
    ///
    /// Later calls are no-ops because the sender has already been taken. A
    /// failed send is ignored.
    fn finalize(&mut self, outcome: Result<ModelResponse>) {
        let Some(sender) = self.completion.take() else {
            return;
        };
        let _ = sender.send(outcome);
    }
}
impl Stream for AggregatorTap {
    type Item = Result<StreamDelta>;

    /// Forwards the next item from the inner stream, mirroring it into the
    /// aggregator.
    ///
    /// The tap terminates, and the completion outcome is sent, when the
    /// inner stream ends, when it yields an error, when the aggregator
    /// rejects a delta, or right after a `Stop` delta has been forwarded.
    /// Once terminated, the inner stream is no longer polled.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.terminated {
            return Poll::Ready(None);
        }

        let next = match this.inner.poll_next_unpin(cx) {
            Poll::Ready(item) => item,
            Poll::Pending => return Poll::Pending,
        };

        match next {
            None => {
                // Inner stream ended: report whatever the aggregator makes of it.
                let outcome = std::mem::take(&mut this.agg).finalize();
                this.finalize(outcome);
                this.terminated = true;
                Poll::Ready(None)
            }
            Some(Err(upstream)) => {
                // The error goes to the consumer; a copy goes to the completion side.
                this.finalize(Err(clone_error(&upstream)));
                this.terminated = true;
                Poll::Ready(Some(Err(upstream)))
            }
            Some(Ok(delta)) => {
                let closes_stream = matches!(delta, StreamDelta::Stop { .. });
                match this.agg.push(delta.clone()) {
                    Err(rejected) => {
                        this.finalize(Err(clone_error(&rejected)));
                        this.terminated = true;
                        Poll::Ready(Some(Err(rejected)))
                    }
                    Ok(()) => {
                        if closes_stream {
                            let outcome = std::mem::take(&mut this.agg).finalize();
                            this.finalize(outcome);
                            this.terminated = true;
                        }
                        Poll::Ready(Some(Ok(delta)))
                    }
                }
            }
        }
    }
}
impl Drop for AggregatorTap {
fn drop(&mut self) {
if self.completion.is_some() {
self.finalize(Err(Error::Cancelled));
}
}
}
fn clone_error(e: &Error) -> Error {
use crate::error::ProviderErrorKind;
match e {
Error::InvalidRequest(msg) => Error::invalid_request(msg.clone()),
Error::Config(msg) => Error::config(msg.clone()),
Error::Provider {
kind,
message,
retry_after,
..
} => {
let cloned = match kind {
ProviderErrorKind::Network => Error::provider_network(message.clone()),
ProviderErrorKind::Tls => Error::provider_tls(message.clone()),
ProviderErrorKind::Dns => Error::provider_dns(message.clone()),
ProviderErrorKind::Http(status) => Error::provider_http(*status, message.clone()),
};
match retry_after {
Some(after) => cloned.with_retry_after(*after),
None => cloned,
}
}
Error::Auth(_) => Error::config("authentication failed (cloned for stream completion)"),
Error::Cancelled => Error::Cancelled,
Error::DeadlineExceeded => Error::DeadlineExceeded,
Error::Interrupted { kind, payload } => Error::Interrupted {
kind: kind.clone(),
payload: payload.clone(),
},
Error::Serde(_) => {
Error::invalid_request("output serialisation failed (cloned for stream completion)")
}
Error::UsageLimitExceeded(breach) => Error::UsageLimitExceeded(breach.clone()),
Error::ModelRetry { hint, attempt } => Error::ModelRetry {
hint: hint.clone(),
attempt: *attempt,
},
}
}
/// Shortens `s` for inclusion in an error message.
///
/// Strings of at most 256 bytes are returned unchanged. Longer strings are
/// cut at the last character boundary at or below byte 256 and suffixed
/// with `…`.
fn truncate_for_diagnostic(s: &str) -> String {
    const BUDGET: usize = 256;
    if s.len() <= BUDGET {
        return s.to_owned();
    }
    // Offset 0 is always a character boundary, so the search cannot come up empty.
    let cut = (0..=BUDGET)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0);
    format!("{}…", &s[..cut])
}