Skip to main content

codetether_agent/session/helper/
stream.rs

//! Stream-completion collection with incremental event forwarding.
//!
//! [`collect_stream_completion_with_events`] drains a provider stream,
//! accumulates text and tool-call deltas into a final
//! [`CompletionResponse`](crate::provider::CompletionResponse), and (optionally)
//! forwards incremental [`SessionEvent::TextChunk`] snapshots to a UI layer.
//!
//! ## Snapshot truncation
//!
//! Each text chunk forwarded over `event_tx` is a **full snapshot** of the
//! accumulated assistant text so far. For extremely long replies the total
//! number of bytes copied and sent would grow as O(n²); to bound the worst
//! case, the text of each snapshot is capped at [`MAX_STREAM_SNAPSHOT_BYTES`]
//! and followed by a trailing `" …[truncated]"` marker.
//! The full text is still returned in the final
//! [`CompletionResponse`](crate::provider::CompletionResponse); only the
//! streamed previews are truncated.
16
17use super::super::SessionEvent;
18use crate::provider::{ContentPart, FinishReason, Message, Role, StreamChunk, Usage};
19use anyhow::Result;
20use futures::StreamExt;
21use futures::stream::BoxStream;
22use std::collections::HashMap;
23
/// Byte cap applied to the text of each streamed
/// [`SessionEvent::TextChunk`] snapshot.
///
/// While the accumulated reply is at or below this size, previews carry the
/// whole text. Once it grows past the cap, each preview carries the text cut
/// to this cap followed by a short truncation marker, so a truncated event is
/// slightly larger than the cap itself. This keeps the cost of every preview
/// bounded no matter how long a runaway provider keeps streaming. The final
/// response is **not** truncated.
///
/// # Examples
///
/// ```rust
/// use codetether_agent::session::helper::stream::MAX_STREAM_SNAPSHOT_BYTES;
/// assert_eq!(MAX_STREAM_SNAPSHOT_BYTES, 256 * 1024);
/// ```
pub const MAX_STREAM_SNAPSHOT_BYTES: usize = 256 * 1024;
36
/// In-progress tool call assembled from streamed chunks.
///
/// One accumulator exists per distinct tool-call id seen on the stream; its
/// `arguments` buffer grows as argument deltas for that id arrive.
#[derive(Default)]
struct ToolAccumulator {
    /// Provider-assigned identifier of the tool call.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Concatenation of every argument delta received so far.
    arguments: String,
}
43
44/// Collect a streaming completion into a [`CompletionResponse`](crate::provider::CompletionResponse),
45/// optionally forwarding incremental events.
46///
47/// Reads [`StreamChunk`]s from `stream`, accumulates assistant text and
48/// tool-call argument deltas keyed by tool-call id, and tracks the final
49/// [`FinishReason`] and [`Usage`]. When `event_tx` is `Some`, each text delta
50/// triggers a [`SessionEvent::TextChunk`] carrying the full accumulated text
51/// up to that point — truncated to [`MAX_STREAM_SNAPSHOT_BYTES`] with a
52/// `" …[truncated]"` suffix when exceeded.
53///
54/// # Arguments
55///
56/// * `stream` — Boxed async stream of [`StreamChunk`]s from a provider.
57/// * `event_tx` — Optional channel for UI preview events; pass `None` for
58///   headless/non-interactive callers.
59///
60/// # Returns
61///
62/// A fully materialized [`CompletionResponse`](crate::provider::CompletionResponse)
63/// containing the complete assistant text and any accumulated tool calls.
64///
65/// # Errors
66///
67/// Returns [`anyhow::Error`] if the stream yields a terminal error chunk or if
68/// response assembly fails.
69///
70/// # Examples
71///
72/// ```rust,no_run
73/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
74/// use codetether_agent::session::helper::stream::collect_stream_completion_with_events;
75/// use futures::stream;
76///
77/// // In practice the stream comes from a Provider::stream() call.
78/// let s = Box::pin(stream::empty());
79/// let response = collect_stream_completion_with_events(s, None).await.unwrap();
80/// // `response` is a CompletionResponse; inspect it as needed.
81/// let _ = response;
82/// # });
83/// ```
84pub async fn collect_stream_completion_with_events(
85    mut stream: BoxStream<'static, StreamChunk>,
86    event_tx: Option<&tokio::sync::mpsc::Sender<SessionEvent>>,
87) -> Result<crate::provider::CompletionResponse> {
88    let mut text = String::new();
89    let mut tools = Vec::<ToolAccumulator>::new();
90    let mut tool_index_by_id = HashMap::<String, usize>::new();
91    let mut usage = Usage::default();
92
93    while let Some(chunk) = stream.next().await {
94        match chunk {
95            StreamChunk::Text(delta) => {
96                if delta.is_empty() {
97                    continue;
98                }
99                text.push_str(&delta);
100                if let Some(tx) = event_tx {
101                    let to_send = if text.len() > MAX_STREAM_SNAPSHOT_BYTES {
102                        let mut t =
103                            crate::util::truncate_bytes_safe(&text, MAX_STREAM_SNAPSHOT_BYTES)
104                                .to_string();
105                        t.push_str(" …[truncated]");
106                        t
107                    } else {
108                        text.clone()
109                    };
110                    let _ = tx.send(SessionEvent::TextChunk(to_send)).await;
111                }
112            }
113            StreamChunk::ToolCallStart { id, name } => {
114                let next_idx = tools.len();
115                let idx = *tool_index_by_id.entry(id.clone()).or_insert(next_idx);
116                if idx == next_idx {
117                    tools.push(ToolAccumulator {
118                        id,
119                        name,
120                        arguments: String::new(),
121                    });
122                } else if tools[idx].name == "tool" {
123                    tools[idx].name = name;
124                }
125            }
126            StreamChunk::ToolCallDelta {
127                id,
128                arguments_delta,
129            } => {
130                if let Some(idx) = tool_index_by_id.get(&id).copied() {
131                    tools[idx].arguments.push_str(&arguments_delta);
132                } else {
133                    let idx = tools.len();
134                    tool_index_by_id.insert(id.clone(), idx);
135                    tools.push(ToolAccumulator {
136                        id,
137                        name: "tool".to_string(),
138                        arguments: arguments_delta,
139                    });
140                }
141            }
142            StreamChunk::ToolCallEnd { .. } => {}
143            StreamChunk::Done { usage: done_usage } => {
144                if let Some(done_usage) = done_usage {
145                    usage = done_usage;
146                }
147            }
148            StreamChunk::Error(message) => anyhow::bail!(message),
149        }
150    }
151
152    let mut content = Vec::new();
153    if !text.is_empty() {
154        content.push(ContentPart::Text { text });
155    }
156    for tool in tools {
157        content.push(ContentPart::ToolCall {
158            id: tool.id,
159            name: tool.name,
160            arguments: tool.arguments,
161            thought_signature: None,
162        });
163    }
164
165    let finish_reason = if content
166        .iter()
167        .any(|part| matches!(part, ContentPart::ToolCall { .. }))
168    {
169        FinishReason::ToolCalls
170    } else {
171        FinishReason::Stop
172    };
173
174    Ok(crate::provider::CompletionResponse {
175        message: Message {
176            role: Role::Assistant,
177            content,
178        },
179        usage,
180        finish_reason,
181    })
182}