use std::pin::Pin;
use std::sync::Arc;
use futures::stream::StreamExt;
use tokio::sync::mpsc;
use crate::backends::{ChatChunk, FinishReason};
use crate::cancel_token::CancellationFlag;
use crate::flow_dispatcher::{DispatchError, DispatchCtx};
use crate::flow_execution_event::FlowExecutionEvent;
use crate::stream_effect::BackpressurePolicy;
use crate::stream_runtime::{Stream as PolicyStream, StreamError};
use crate::tool_trait::{ToolChunk, ToolFinishReason, ToolStream};
/// Outcome of draining one tool stream into `StepToken` events.
///
/// Produced by `unified_stream_handler`; counters that come from the
/// backpressure machinery stay at zero on the direct (no-policy) path.
#[derive(Debug, Clone, Default)]
pub struct ToolStreamSummary {
    /// Number of non-empty deltas forwarded as `StepToken` events; also the
    /// 1-based `token_index` of the last emitted token.
    pub tokens_emitted: u64,
    /// Lowercase hex SHA-256 of `accumulated`.
    pub output_hash_hex: String,
    /// Concatenation of every delta that was forwarded.
    pub accumulated: String,
    /// Chunks pulled from the source (direct path) or the policy stream's
    /// `items_pushed` metric (policy path).
    pub chunks_pushed: u64,
    /// Chunks handed to the consumer loop.
    pub chunks_delivered: u64,
    /// Policy-path `drop_oldest_hits` metric.
    pub chunks_dropped: u64,
    /// Policy-path `degrade_quality_hits` metric.
    pub chunks_degraded: u64,
    /// Policy-path `pause_upstream_blocks` metric.
    pub pause_upstream_blocks: u64,
    /// Policy-path `fail_overflows` metric.
    pub fail_overflows: u64,
    /// `false` once an error terminator, cancellation, or overflow is seen.
    pub success: bool,
    /// Error text from an error terminator or an overflow description.
    pub terminator_message: Option<String>,
    /// Set when the run was cancelled (flag or `Cancelled` terminator).
    pub cancelled: bool,
}

impl ToolStreamSummary {
    /// Returns `true` only for a successful, uncancelled run with no error text.
    pub fn is_clean_stop(&self) -> bool {
        let has_error_text = self.terminator_message.is_some();
        !(has_error_text || self.cancelled || !self.success)
    }
}
pub fn chat_chunk_to_tool_chunk(chunk: ChatChunk) -> ToolChunk {
let finish_reason = chunk.finish_reason.map(|fr| match fr {
FinishReason::Stop => ToolFinishReason::Stop,
FinishReason::Length => ToolFinishReason::Error {
message: "length limit exceeded".to_string(),
},
FinishReason::ToolUse => ToolFinishReason::Stop,
FinishReason::SafetyBreach => ToolFinishReason::Error {
message: "safety classifier blocked output".to_string(),
},
FinishReason::Other(s) => ToolFinishReason::Error {
message: format!("unknown finish reason: {s}"),
},
});
ToolChunk {
delta: chunk.delta,
finish_reason,
timestamp_ms: crate::flow_execution_event::now_ms(),
}
}
/// Wraps an in-memory list of chunks as a [`ToolStream`] that yields them in
/// order and then ends.
pub fn unified_stream_from_chunks(chunks: Vec<ToolChunk>) -> ToolStream {
    let replay = futures::stream::iter(chunks);
    Box::pin(replay)
}
/// Drains `source`, forwarding each non-empty delta to `tx` as a `StepToken`
/// event tagged with `step_name`, and returns a [`ToolStreamSummary`].
///
/// If `cancel` is already set, nothing is pulled from `source`. The summary is
/// marked cancelled and carries the hash of the empty output. With a
/// `policy`, chunks go through a backpressure-managed stream; without one they
/// are forwarded directly.
///
/// # Errors
///
/// Propagates the errors of the selected drain, e.g.
/// [`DispatchError::ChannelClosed`] when `tx` has no receiver.
pub async fn unified_stream_handler(
    source: ToolStream,
    policy: Option<BackpressurePolicy>,
    cancel: &CancellationFlag,
    tx: &mpsc::UnboundedSender<FlowExecutionEvent>,
    step_name: &str,
) -> Result<ToolStreamSummary, DispatchError> {
    if cancel.is_cancelled() {
        let cancelled_before_start = ToolStreamSummary {
            output_hash_hex: sha256_hex(""),
            cancelled: true,
            success: false,
            ..Default::default()
        };
        return Ok(cancelled_before_start);
    }

    match policy {
        None => unified_drain_direct(source, cancel, tx, step_name).await,
        Some(p) => unified_drain_with_policy(source, p, cancel, tx, step_name).await,
    }
}
async fn unified_drain_direct(
mut source: Pin<Box<dyn futures::Stream<Item = ToolChunk> + Send>>,
cancel: &CancellationFlag,
tx: &mpsc::UnboundedSender<FlowExecutionEvent>,
step_name: &str,
) -> Result<ToolStreamSummary, DispatchError> {
let mut summary = ToolStreamSummary {
success: true,
..Default::default()
};
while let Some(chunk) = source.next().await {
summary.chunks_pushed += 1;
if cancel.is_cancelled() {
summary.cancelled = true;
summary.success = false;
break;
}
summary.chunks_delivered += 1;
if !chunk.delta.is_empty() {
summary.tokens_emitted += 1;
summary.accumulated.push_str(&chunk.delta);
tx.send(FlowExecutionEvent::StepToken {
step_name: step_name.to_string(),
content: chunk.delta.clone(),
token_index: summary.tokens_emitted,
timestamp_ms: crate::flow_execution_event::now_ms(),
})
.map_err(|_| DispatchError::ChannelClosed)?;
}
if let Some(reason) = chunk.finish_reason {
handle_terminator(reason, &mut summary);
break;
}
}
summary.output_hash_hex = sha256_hex(&summary.accumulated);
Ok(summary)
}
/// Drains `source` through a [`PolicyStream`] governed by `policy`, forwarding
/// every non-empty delta to `tx` as a `StepToken` event.
///
/// A spawned producer task pulls from `source` and pushes into the policy
/// stream; this function pops from it on the caller's task. After both sides
/// finish, the stream's metrics are copied into the returned summary.
///
/// Both sides check the cancellation flag. If the producer observes it first,
/// the consumer only sees the stream close. That case is still reported as
/// `cancelled` (unless a terminator was already processed), so a truncated run
/// is never reported as a clean stop.
///
/// # Errors
///
/// * [`DispatchError::ChannelClosed`] if a `StepToken` cannot be sent. Before
///   returning, the policy stream is closed and the producer task aborted, so
///   no task keeps running against `source`.
/// * [`DispatchError::BackendError`] if the producer task fails to join.
async fn unified_drain_with_policy(
    source: ToolStream,
    policy: BackpressurePolicy,
    cancel: &CancellationFlag,
    tx: &mpsc::UnboundedSender<FlowExecutionEvent>,
    step_name: &str,
) -> Result<ToolStreamSummary, DispatchError> {
    use crate::stream_effect::BackpressureAnnotation;
    use crate::stream_effect_dispatcher::DEFAULT_STREAM_BUFFER_CAPACITY;
    use std::sync::atomic::Ordering;

    let annotation = BackpressureAnnotation {
        policy,
        options: Vec::new(),
    };
    let policy_stream: PolicyStream<ToolChunk> = match policy {
        // Identity degrader: chunk content is passed through unchanged.
        BackpressurePolicy::DegradeQuality => PolicyStream::with_degrader(
            DEFAULT_STREAM_BUFFER_CAPACITY,
            annotation,
            Arc::new(|c| c),
        ),
        BackpressurePolicy::DropOldest
        | BackpressurePolicy::PauseUpstream
        | BackpressurePolicy::Fail => PolicyStream::new(DEFAULT_STREAM_BUFFER_CAPACITY, annotation),
    };

    // Producer task: returns (first push error, stopped-because-cancelled).
    let producer_stream = policy_stream.clone();
    let producer_cancel = cancel.clone();
    let producer = tokio::spawn(async move {
        let mut source = source;
        let mut failed: Option<StreamError> = None;
        let mut stopped_by_cancel = false;
        while let Some(chunk) = source.next().await {
            if producer_cancel.is_cancelled() {
                stopped_by_cancel = true;
                break;
            }
            if let Err(e) = producer_stream.push(chunk).await {
                failed = Some(e);
                break;
            }
        }
        producer_stream.close().await;
        (failed, stopped_by_cancel)
    });

    let mut summary = ToolStreamSummary {
        success: true,
        ..Default::default()
    };
    let mut saw_terminator = false;
    while let Some(chunk) = policy_stream.pop().await {
        if cancel.is_cancelled() {
            summary.cancelled = true;
            summary.success = false;
            policy_stream.close().await;
            break;
        }
        if !chunk.delta.is_empty() {
            summary.tokens_emitted += 1;
            summary.accumulated.push_str(&chunk.delta);
            let sent = tx.send(FlowExecutionEvent::StepToken {
                step_name: step_name.to_string(),
                content: chunk.delta.clone(),
                token_index: summary.tokens_emitted,
                timestamp_ms: crate::flow_execution_event::now_ms(),
            });
            if sent.is_err() {
                // We are about to stop consuming. Don't leave the producer
                // running against `source` (its `push` may block, e.g. under
                // PauseUpstream, once nobody pops).
                policy_stream.close().await;
                producer.abort();
                return Err(DispatchError::ChannelClosed);
            }
        }
        if let Some(reason) = chunk.finish_reason {
            saw_terminator = true;
            handle_terminator(reason, &mut summary);
            policy_stream.close().await;
            break;
        }
    }

    let (producer_failed, producer_cancelled) =
        producer.await.map_err(|e| DispatchError::BackendError {
            name: "unified_stream:producer".to_string(),
            message: format!("producer task join failed: {e}"),
        })?;

    // The producer stopped pulling because of cancellation before any
    // terminator reached us: the output is truncated, so report it as such.
    if producer_cancelled && !saw_terminator {
        summary.cancelled = true;
        summary.success = false;
    }

    let metrics = policy_stream.metrics.as_ref();
    summary.chunks_pushed = metrics.items_pushed.load(Ordering::Relaxed);
    summary.chunks_delivered = metrics.items_delivered.load(Ordering::Relaxed);
    summary.chunks_dropped = metrics.drop_oldest_hits.load(Ordering::Relaxed);
    summary.chunks_degraded = metrics.degrade_quality_hits.load(Ordering::Relaxed);
    summary.pause_upstream_blocks = metrics.pause_upstream_blocks.load(Ordering::Relaxed);
    summary.fail_overflows = metrics.fail_overflows.load(Ordering::Relaxed);

    // Only overflow is surfaced. NOTE(review): other push errors are ignored
    // here, presumably because they arise from pushing after the consumer
    // closed the stream — confirm against `StreamError`.
    if let Some(StreamError::Overflow { policy: p, .. }) = producer_failed {
        summary.success = false;
        summary.terminator_message = Some(format!(
            "stream overflow under policy {p}: producer hit capacity \
             ({} chunks pushed before overflow)",
            summary.chunks_pushed
        ));
    }

    summary.output_hash_hex = sha256_hex(&summary.accumulated);
    Ok(summary)
}
fn handle_terminator(reason: ToolFinishReason, summary: &mut ToolStreamSummary) {
match reason {
ToolFinishReason::Stop => { }
ToolFinishReason::Error { message } => {
summary.success = false;
summary.terminator_message = Some(message);
}
ToolFinishReason::Cancelled => {
summary.success = false;
summary.cancelled = true;
}
}
}
/// Returns the lowercase hexadecimal SHA-256 digest of `input`'s UTF-8 bytes.
fn sha256_hex(input: &str) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let digest = Sha256::digest(input.as_bytes());
    digest
        .iter()
        .fold(String::with_capacity(digest.len() * 2), |mut hex, byte| {
            // Writing into a String cannot fail.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chat_chunk_to_tool_chunk_preserves_delta() {
        let chat = ChatChunk {
            delta: "hello world".to_string(),
            finish_reason: None,
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        assert_eq!(tool.delta, "hello world");
        assert_eq!(tool.finish_reason, None);
    }

    #[test]
    fn chat_chunk_to_tool_chunk_maps_stop_to_stop() {
        let chat = ChatChunk {
            delta: "".to_string(),
            finish_reason: Some(FinishReason::Stop),
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        assert_eq!(tool.finish_reason, Some(ToolFinishReason::Stop));
    }

    #[test]
    fn chat_chunk_to_tool_chunk_maps_length_to_error() {
        let chat = ChatChunk {
            delta: "".to_string(),
            finish_reason: Some(FinishReason::Length),
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        match tool.finish_reason {
            Some(ToolFinishReason::Error { message }) => {
                assert!(message.contains("length"));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn chat_chunk_to_tool_chunk_maps_tooluse_to_stop() {
        let chat = ChatChunk {
            delta: "".to_string(),
            finish_reason: Some(FinishReason::ToolUse),
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        assert_eq!(tool.finish_reason, Some(ToolFinishReason::Stop));
    }

    #[test]
    fn chat_chunk_to_tool_chunk_maps_safety_to_error() {
        let chat = ChatChunk {
            delta: "".to_string(),
            finish_reason: Some(FinishReason::SafetyBreach),
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        match tool.finish_reason {
            Some(ToolFinishReason::Error { message }) => {
                assert!(message.contains("safety"));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn chat_chunk_to_tool_chunk_maps_other_to_error_carrying_raw() {
        let chat = ChatChunk {
            delta: "".to_string(),
            finish_reason: Some(FinishReason::Other("custom-reason".to_string())),
            usage: None,
        };
        let tool = chat_chunk_to_tool_chunk(chat);
        match tool.finish_reason {
            Some(ToolFinishReason::Error { message }) => {
                assert!(message.contains("custom-reason"));
            }
            other => panic!("expected Error, got {other:?}"),
        }
    }

    #[test]
    fn unified_stream_from_chunks_yields_inputs_in_order() {
        let chunks = vec![
            ToolChunk::intermediate("a"),
            ToolChunk::intermediate("b"),
            ToolChunk::terminator("c", ToolFinishReason::Stop),
        ];
        let stream = unified_stream_from_chunks(chunks);
        let collected = futures::executor::block_on(async {
            let mut s = stream;
            let mut out = Vec::new();
            while let Some(c) = s.next().await {
                out.push(c);
            }
            out
        });
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[0].delta, "a");
        assert_eq!(collected[1].delta, "b");
        assert_eq!(collected[2].delta, "c");
        assert!(collected[2].is_terminator());
    }

    #[test]
    fn unified_stream_from_chunks_empty_vec_closes_immediately() {
        let stream = unified_stream_from_chunks(Vec::new());
        let collected = futures::executor::block_on(async {
            let mut s = stream;
            let mut out = Vec::new();
            while let Some(c) = s.next().await {
                out.push(c);
            }
            out
        });
        assert!(collected.is_empty());
    }

    #[test]
    fn tool_stream_summary_is_clean_stop_truthy_default_plus_success() {
        let s = ToolStreamSummary {
            success: true,
            cancelled: false,
            terminator_message: None,
            ..Default::default()
        };
        assert!(s.is_clean_stop());
    }

    #[test]
    fn tool_stream_summary_is_clean_stop_false_when_cancelled() {
        let s = ToolStreamSummary {
            success: false,
            cancelled: true,
            ..Default::default()
        };
        assert!(!s.is_clean_stop());
    }

    #[test]
    fn tool_stream_summary_is_clean_stop_false_when_terminator_error() {
        let s = ToolStreamSummary {
            success: false,
            terminator_message: Some("upstream failed".to_string()),
            ..Default::default()
        };
        assert!(!s.is_clean_stop());
    }

    #[tokio::test]
    async fn unified_handler_direct_drain_emits_step_tokens_in_order() {
        let chunks = vec![
            ToolChunk::intermediate("hello"),
            ToolChunk::intermediate(" "),
            ToolChunk::intermediate("world"),
            ToolChunk::terminator("", ToolFinishReason::Stop),
        ];
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(source, None, &cancel, &tx, "TestStep")
            .await
            .expect("ok");
        assert!(summary.success);
        assert_eq!(summary.tokens_emitted, 3);
        assert_eq!(summary.accumulated, "hello world");
        assert_eq!(summary.chunks_dropped, 0);
        assert_eq!(summary.chunks_degraded, 0);
        let mut events = Vec::new();
        while let Ok(ev) = rx.try_recv() {
            events.push(ev);
        }
        let tokens: Vec<String> = events
            .iter()
            .filter_map(|e| match e {
                FlowExecutionEvent::StepToken { content, .. } => Some(content.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(tokens, vec!["hello", " ", "world"]);
    }

    #[tokio::test]
    async fn unified_handler_direct_drain_empty_source_hashes_empty_output() {
        let source = unified_stream_from_chunks(Vec::new());
        let cancel = CancellationFlag::new();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(source, None, &cancel, &tx, "Empty")
            .await
            .expect("ok");
        assert_eq!(summary.chunks_pushed, 0);
        assert_eq!(summary.tokens_emitted, 0);
        assert!(summary.accumulated.is_empty());
        assert_eq!(summary.output_hash_hex, sha256_hex(""));
    }

    #[tokio::test]
    async fn unified_handler_with_drop_oldest_policy_under_burst_drops_chunks() {
        let n: u64 = 200;
        let chunks: Vec<ToolChunk> = (0..n)
            .map(|i| ToolChunk::intermediate(format!("c{i}")))
            .chain(std::iter::once(ToolChunk::terminator(
                "",
                ToolFinishReason::Stop,
            )))
            .collect();
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(
            source,
            Some(BackpressurePolicy::DropOldest),
            &cancel,
            &tx,
            "DropTest",
        )
        .await
        .expect("ok");
        assert!(summary.success, "DropOldest should never fail");
        assert!(summary.chunks_pushed >= n);
        assert!(summary.chunks_delivered <= summary.chunks_pushed);
    }

    #[tokio::test]
    async fn unified_handler_with_pause_upstream_delivers_every_token() {
        let chunks = vec![
            ToolChunk::intermediate("a"),
            ToolChunk::intermediate("b"),
            ToolChunk::terminator("c", ToolFinishReason::Stop),
        ];
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(
            source,
            Some(BackpressurePolicy::PauseUpstream),
            &cancel,
            &tx,
            "PauseTest",
        )
        .await
        .expect("ok");
        assert!(summary.is_clean_stop());
        assert_eq!(summary.accumulated, "abc");
        assert_eq!(summary.tokens_emitted, 3);
        assert_eq!(summary.chunks_dropped, 0);
    }

    #[tokio::test]
    async fn unified_handler_with_fail_policy_under_overflow_surfaces_error() {
        let n: u64 = 10_000;
        let chunks: Vec<ToolChunk> = (0..n)
            .map(|i| ToolChunk::intermediate(format!("c{i}")))
            .collect();
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        // Keep the receiver alive. Dropping it makes the very first StepToken
        // send fail with ChannelClosed, so the overflow checks would never run.
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(
            source,
            Some(BackpressurePolicy::Fail),
            &cancel,
            &tx,
            "FailTest",
        )
        .await
        .expect("receiver is alive, so the handler must not fail");
        // Whether the buffer overflows depends on task scheduling. When it
        // does, the failure must be surfaced in the summary.
        if summary.fail_overflows > 0 {
            assert!(!summary.success);
            assert!(summary.terminator_message.is_some());
        }
    }

    #[tokio::test]
    async fn unified_handler_pre_cancel_short_circuits() {
        let chunks = vec![
            ToolChunk::intermediate("a"),
            ToolChunk::intermediate("b"),
            ToolChunk::terminator("", ToolFinishReason::Stop),
        ];
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        cancel.cancel();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(source, None, &cancel, &tx, "PreCancel")
            .await
            .expect("ok");
        assert!(summary.cancelled);
        assert!(!summary.success);
        assert_eq!(summary.output_hash_hex, sha256_hex(""));
    }

    #[tokio::test]
    async fn unified_handler_error_terminator_surfaces_in_summary() {
        let chunks = vec![
            ToolChunk::intermediate("partial"),
            ToolChunk::terminator(
                "",
                ToolFinishReason::Error {
                    message: "upstream failed".to_string(),
                },
            ),
        ];
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(source, None, &cancel, &tx, "ErrTerm")
            .await
            .expect("ok");
        assert!(!summary.success);
        assert!(!summary.cancelled);
        assert_eq!(summary.terminator_message.as_deref(), Some("upstream failed"));
        assert_eq!(summary.tokens_emitted, 1);
    }

    #[tokio::test]
    async fn unified_handler_cancelled_terminator_surfaces_in_summary() {
        let chunks = vec![
            ToolChunk::intermediate("partial"),
            ToolChunk::terminator("", ToolFinishReason::Cancelled),
        ];
        let source = unified_stream_from_chunks(chunks);
        let cancel = CancellationFlag::new();
        let (tx, _rx) = mpsc::unbounded_channel();
        let summary = unified_stream_handler(source, None, &cancel, &tx, "CancTerm")
            .await
            .expect("ok");
        assert!(summary.cancelled);
        assert!(!summary.success);
    }

    #[test]
    fn sha256_hex_matches_canonical_for_empty_string() {
        assert_eq!(
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn sha256_hex_matches_canonical_for_abc() {
        assert_eq!(
            sha256_hex("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}