codetether_agent/session/helper/stream.rs
1//! Stream-completion collection with incremental event forwarding.
2//!
3//! [`collect_stream_completion_with_events`] drains a provider stream,
4//! accumulates text and tool-call deltas into a final
5//! [`CompletionResponse`](crate::provider::CompletionResponse), and (optionally)
6//! forwards incremental [`SessionEvent::TextChunk`] snapshots to a UI layer.
7//!
8//! ## Snapshot truncation
9//!
10//! Each text chunk forwarded over `event_tx` is a **full snapshot** of the
11//! accumulated assistant text so far. For extremely long replies this would be
12//! O(n²) in memory; to bound the worst case the snapshot is capped at
13//! [`MAX_STREAM_SNAPSHOT_BYTES`] with a trailing `" …[truncated]"` marker.
14//! The full text is still returned in the final [`CompletionResponse`]; only
15//! the streamed previews are truncated.
16
17use super::super::SessionEvent;
18use crate::provider::{ContentPart, FinishReason, Message, Role, StreamChunk, Usage};
19use anyhow::Result;
20use futures::StreamExt;
21use futures::stream::BoxStream;
22use std::collections::HashMap;
23
/// Upper bound, in bytes, on the text carried by one
/// [`SessionEvent::TextChunk`] preview.
///
/// Every preview is a snapshot of the whole reply received so far. Without a
/// cap, a very long stream would allocate and forward ever-larger copies. With
/// the cap, the cost of each event stays bounded by 256 KiB. Only previews are
/// capped; the final response always carries the complete text.
///
/// # Examples
///
/// ```rust
/// use codetether_agent::session::helper::stream::MAX_STREAM_SNAPSHOT_BYTES;
/// assert_eq!(MAX_STREAM_SNAPSHOT_BYTES, 256 * 1024);
/// ```
pub const MAX_STREAM_SNAPSHOT_BYTES: usize = 256 << 10;
36
/// In-progress state for a single tool call while a provider stream is
/// being drained.
#[derive(Default)]
struct ToolAccumulator {
    /// Tool-call id reported by the provider; deltas are matched on it.
    id: String,
    /// Tool name. The stream collector stores the placeholder `"tool"` here
    /// when argument deltas arrive before the call's start chunk.
    name: String,
    /// Argument fragments concatenated in arrival order.
    arguments: String,
}
43
44/// Collect a streaming completion into a [`CompletionResponse`](crate::provider::CompletionResponse),
45/// optionally forwarding incremental events.
46///
47/// Reads [`StreamChunk`]s from `stream`, accumulates assistant text and
48/// tool-call argument deltas keyed by tool-call id, and tracks the final
49/// [`FinishReason`] and [`Usage`]. When `event_tx` is `Some`, each text delta
50/// triggers a [`SessionEvent::TextChunk`] carrying the full accumulated text
51/// up to that point — truncated to [`MAX_STREAM_SNAPSHOT_BYTES`] with a
52/// `" …[truncated]"` suffix when exceeded.
53///
54/// # Arguments
55///
56/// * `stream` — Boxed async stream of [`StreamChunk`]s from a provider.
57/// * `event_tx` — Optional channel for UI preview events; pass `None` for
58/// headless/non-interactive callers.
59///
60/// # Returns
61///
62/// A fully materialized [`CompletionResponse`](crate::provider::CompletionResponse)
63/// containing the complete assistant text and any accumulated tool calls.
64///
65/// # Errors
66///
67/// Returns [`anyhow::Error`] if the stream yields a terminal error chunk or if
68/// response assembly fails.
69///
70/// # Examples
71///
72/// ```rust,no_run
73/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
74/// use codetether_agent::session::helper::stream::collect_stream_completion_with_events;
75/// use futures::stream;
76///
77/// // In practice the stream comes from a Provider::stream() call.
78/// let s = Box::pin(stream::empty());
79/// let response = collect_stream_completion_with_events(s, None).await.unwrap();
80/// // `response` is a CompletionResponse; inspect it as needed.
81/// let _ = response;
82/// # });
83/// ```
84pub async fn collect_stream_completion_with_events(
85 mut stream: BoxStream<'static, StreamChunk>,
86 event_tx: Option<&tokio::sync::mpsc::Sender<SessionEvent>>,
87) -> Result<crate::provider::CompletionResponse> {
88 let mut text = String::new();
89 let mut tools = Vec::<ToolAccumulator>::new();
90 let mut tool_index_by_id = HashMap::<String, usize>::new();
91 let mut usage = Usage::default();
92
93 while let Some(chunk) = stream.next().await {
94 match chunk {
95 StreamChunk::Text(delta) => {
96 if delta.is_empty() {
97 continue;
98 }
99 text.push_str(&delta);
100 if let Some(tx) = event_tx {
101 let to_send = if text.len() > MAX_STREAM_SNAPSHOT_BYTES {
102 let mut t =
103 crate::util::truncate_bytes_safe(&text, MAX_STREAM_SNAPSHOT_BYTES)
104 .to_string();
105 t.push_str(" …[truncated]");
106 t
107 } else {
108 text.clone()
109 };
110 let _ = tx.send(SessionEvent::TextChunk(to_send)).await;
111 }
112 }
113 StreamChunk::ToolCallStart { id, name } => {
114 let next_idx = tools.len();
115 let idx = *tool_index_by_id.entry(id.clone()).or_insert(next_idx);
116 if idx == next_idx {
117 tools.push(ToolAccumulator {
118 id,
119 name,
120 arguments: String::new(),
121 });
122 } else if tools[idx].name == "tool" {
123 tools[idx].name = name;
124 }
125 }
126 StreamChunk::ToolCallDelta {
127 id,
128 arguments_delta,
129 } => {
130 if let Some(idx) = tool_index_by_id.get(&id).copied() {
131 tools[idx].arguments.push_str(&arguments_delta);
132 } else {
133 let idx = tools.len();
134 tool_index_by_id.insert(id.clone(), idx);
135 tools.push(ToolAccumulator {
136 id,
137 name: "tool".to_string(),
138 arguments: arguments_delta,
139 });
140 }
141 }
142 StreamChunk::ToolCallEnd { .. } => {}
143 StreamChunk::Done { usage: done_usage } => {
144 if let Some(done_usage) = done_usage {
145 usage = done_usage;
146 }
147 }
148 StreamChunk::Error(message) => anyhow::bail!(message),
149 }
150 }
151
152 let mut content = Vec::new();
153 if !text.is_empty() {
154 content.push(ContentPart::Text { text });
155 }
156 for tool in tools {
157 content.push(ContentPart::ToolCall {
158 id: tool.id,
159 name: tool.name,
160 arguments: tool.arguments,
161 thought_signature: None,
162 });
163 }
164
165 let finish_reason = if content
166 .iter()
167 .any(|part| matches!(part, ContentPart::ToolCall { .. }))
168 {
169 FinishReason::ToolCalls
170 } else {
171 FinishReason::Stop
172 };
173
174 Ok(crate::provider::CompletionResponse {
175 message: Message {
176 role: Role::Assistant,
177 content,
178 },
179 usage,
180 finish_reason,
181 })
182}