// llm_stack/tool/loop_stream.rs

//! Streaming tool loop implementation.
//!
//! Returns a [`LoopStream`] — a unified stream of [`LoopEvent`]s that
//! includes both LLM streaming events (text deltas, tool call fragments)
//! and loop-level lifecycle events (iteration boundaries, tool execution
//! progress). Terminates with [`LoopEvent::Done`] carrying the final
//! [`ToolLoopResult`].
//!
//! Between iterations (when executing tools), the stream emits
//! `ToolExecutionStart` and `ToolExecutionEnd` events. No LLM deltas
//! are emitted during this phase.

13use std::collections::VecDeque;
14use std::sync::Arc;
15
16use futures::StreamExt;
17
18use crate::chat::{ChatResponse, ContentBlock, StopReason, ToolCall};
19use crate::error::LlmError;
20use crate::provider::{ChatParams, DynProvider};
21use crate::stream::{ChatStream, StreamEvent};
22use crate::usage::Usage;
23
24use super::LoopDepth;
25use super::ToolRegistry;
26use super::config::{LoopEvent, LoopStream, ToolLoopConfig};
27use super::loop_core::{IterationOutcome, LoopCore, StartOutcome};
28
/// Streaming variant of [`tool_loop`](super::tool_loop).
///
/// Yields [`LoopEvent`]s from each iteration. LLM streaming events
/// (text deltas, tool call fragments) are interleaved with loop-level
/// events (iteration start, tool execution start/end). The stream
/// terminates with [`LoopEvent::Done`] carrying the final
/// [`ToolLoopResult`](super::ToolLoopResult).
///
/// # Depth Tracking
///
/// If `Ctx` implements [`LoopDepth`], nested calls are tracked automatically.
/// When `config.max_depth` is set and the context's depth exceeds the limit,
/// yields `Err(LlmError::MaxDepthExceeded)` immediately.
///
/// Uses `Arc` for provider, registry, and context since they must outlive
/// the returned stream.
#[allow(clippy::needless_pass_by_value)] // ctx Arc is consumed into LoopCore
pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
    provider: Arc<dyn DynProvider>,
    registry: Arc<ToolRegistry<Ctx>>,
    params: ChatParams,
    config: ToolLoopConfig,
    ctx: Arc<Ctx>,
) -> LoopStream {
    // Iteration bookkeeping (history, limits, event buffer) lives in LoopCore,
    // shared with the non-streaming variant.
    let core = LoopCore::new(params, config, &*ctx);

    let state = UnfoldState {
        core,
        provider,
        registry,
        phase: StreamPhase::StartIteration,
        current_text: String::new(),
        current_tool_calls: Vec::new(),
        current_usage: Usage::default(),
        pending_events: VecDeque::new(),
    };

    // Unfold a state machine: each closure invocation yields exactly one
    // stream item (or `None` to terminate). The inner `loop` spins through
    // phases that produce no item of their own.
    let stream = futures::stream::unfold(state, |mut state| async move {
        loop {
            // First, drain any pending events (from LoopCore's event buffer).
            // Buffered events always take priority so their FIFO order is
            // preserved relative to freshly produced items.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((event, state));
            }

            // Take the phase out (leaving `Done` as a placeholder) so the
            // `Streaming` arm can move the owned `ChatStream` out of it.
            // Every arm either installs the next phase or deliberately lets
            // the `Done` placeholder stand.
            match std::mem::replace(&mut state.phase, StreamPhase::Done) {
                StreamPhase::Done => return None,

                StreamPhase::StartIteration => {
                    match state.core.start_iteration(&*state.provider).await {
                        StartOutcome::Stream(s) => {
                            // Reset per-iteration accumulators before streaming.
                            state.current_text.clear();
                            state.current_tool_calls.clear();
                            state.current_usage = Usage::default();
                            // Drain IterationStart event from core
                            state.load_core_events();
                            state.phase = StreamPhase::Streaming(s);
                        }
                        StartOutcome::Terminal(outcome) => {
                            // Drain any events (e.g., Done from finish())
                            state.load_core_events();
                            if let Some(event) = outcome_to_error(*outcome) {
                                state.phase = StreamPhase::Done;
                                // Push error, then let pending_events drain
                                state.pending_events.push_back(event);
                            }
                            // Continue loop to drain pending_events
                        }
                    }
                }

                StreamPhase::Streaming(mut stream) => match stream.next().await {
                    Some(Ok(event)) => {
                        // Accumulate for finish_iteration
                        if let StreamEvent::TextDelta(ref t) = event {
                            state.current_text.push_str(t);
                        }
                        if let StreamEvent::ToolCallComplete { ref call, .. } = event {
                            state.current_tool_calls.push(call.clone());
                        }
                        if let StreamEvent::Usage(ref u) = event {
                            state.current_usage += u;
                        }

                        // Check for Done BEFORE translating — translation
                        // consumes `event` and maps Done to None.
                        let is_done = matches!(&event, StreamEvent::Done { .. });
                        let loop_event = translate_stream_event(event);

                        if is_done {
                            // Provider stream done — move to tool execution
                            state.phase = StreamPhase::ExecutingTools;
                        } else {
                            state.phase = StreamPhase::Streaming(stream);
                        }

                        // Don't forward provider-level Done — it's not the loop being done
                        if let Some(le) = loop_event {
                            return Some((Ok(le), state));
                        }
                        // If we filtered out Done, continue loop
                    }
                    Some(Err(e)) => {
                        // A provider error terminates the whole loop stream;
                        // phase stays `Done` for the next poll.
                        state.phase = StreamPhase::Done;
                        return Some((Err(e), state));
                    }
                    None => {
                        // Stream exhausted without Done — clean end
                        // NOTE(review): this path drops any accumulated
                        // text/tool calls and the stream ends WITHOUT a
                        // `LoopEvent::Done`. If a provider stream may
                        // legitimately end without emitting `Done`, consider
                        // routing through `ExecutingTools` so the iteration
                        // is finished properly — confirm the provider-stream
                        // contract before changing.
                        return None;
                    }
                },

                StreamPhase::ExecutingTools => {
                    // Rebuild a ChatResponse from accumulated deltas so the
                    // shared (non-streaming) core logic can execute tools
                    // against it. `take` resets usage for the next iteration.
                    let response = build_response(
                        &state.current_text,
                        &state.current_tool_calls,
                        std::mem::take(&mut state.current_usage),
                    );
                    let outcome = state.core.finish_iteration(response, &state.registry).await;

                    // Drain tool execution events + possible Done from core
                    state.load_core_events();

                    match outcome {
                        IterationOutcome::ToolsExecuted { .. } => {
                            state.phase = StreamPhase::StartIteration;
                        }
                        IterationOutcome::Completed(_) => {
                            // Done event already in pending_events from finish()
                            state.phase = StreamPhase::Done;
                        }
                        IterationOutcome::Error(data) => {
                            state.phase = StreamPhase::Done;
                            state.pending_events.push_back(Err(data.error));
                        }
                    }
                    // Continue loop to drain pending_events
                }
            }
        }
    });

    Box::pin(stream)
}
170
/// Phases of the streaming state machine.
enum StreamPhase {
    /// Ask `LoopCore` to begin the next iteration (opens a provider stream).
    StartIteration,
    /// Actively polling the provider's event stream for this iteration.
    Streaming(ChatStream),
    /// Provider finished its response; run requested tools via the core.
    ExecutingTools,
    /// Terminal — the unfold returns `None` and the stream ends.
    Done,
}
178
/// State carried through the unfold.
struct UnfoldState<Ctx: LoopDepth + Send + Sync + 'static> {
    /// Iteration bookkeeping shared with the non-streaming loop.
    core: LoopCore<Ctx>,
    /// LLM backend used to open each iteration's stream.
    provider: Arc<dyn DynProvider>,
    /// Tool registry passed to `finish_iteration` for tool execution.
    registry: Arc<ToolRegistry<Ctx>>,
    /// Current position in the state machine.
    phase: StreamPhase,
    /// Text accumulated from `TextDelta` events in the current iteration.
    current_text: String,
    /// Completed tool calls accumulated in the current iteration.
    current_tool_calls: Vec<ToolCall>,
    /// Usage accumulated from `Usage` events in the current iteration.
    current_usage: Usage,
    /// Events waiting to be yielded (FIFO).
    pending_events: VecDeque<Result<LoopEvent, LlmError>>,
}
191
192impl<Ctx: LoopDepth + Send + Sync + 'static> UnfoldState<Ctx> {
193    /// Drain events from `LoopCore`'s buffer into our pending queue (FIFO).
194    fn load_core_events(&mut self) {
195        for event in self.core.drain_events() {
196            self.pending_events.push_back(Ok(event));
197        }
198    }
199}
200
201/// Translate a provider `StreamEvent` into a `LoopEvent`.
202///
203/// Returns `None` for `StreamEvent::Done` — the provider's "done" is not
204/// the loop's "done". The loop continues with tool execution.
205fn translate_stream_event(event: StreamEvent) -> Option<LoopEvent> {
206    match event {
207        StreamEvent::TextDelta(t) => Some(LoopEvent::TextDelta(t)),
208        StreamEvent::ReasoningDelta(t) => Some(LoopEvent::ReasoningDelta(t)),
209        StreamEvent::ToolCallStart { index, id, name } => {
210            Some(LoopEvent::ToolCallStart { index, id, name })
211        }
212        StreamEvent::ToolCallDelta { index, json_chunk } => {
213            Some(LoopEvent::ToolCallDelta { index, json_chunk })
214        }
215        StreamEvent::ToolCallComplete { index, call } => {
216            Some(LoopEvent::ToolCallComplete { index, call })
217        }
218        StreamEvent::Usage(u) => Some(LoopEvent::Usage(u)),
219        StreamEvent::Done { .. } => None, // Filtered — not the loop's done
220    }
221}
222
223/// Build a `ChatResponse` from accumulated stream data.
224fn build_response(text: &str, tool_calls: &[ToolCall], usage: Usage) -> ChatResponse {
225    let mut content = Vec::new();
226    if !text.is_empty() {
227        content.push(ContentBlock::Text(text.to_owned()));
228    }
229    for call in tool_calls {
230        content.push(ContentBlock::ToolCall(call.clone()));
231    }
232
233    let stop_reason = if tool_calls.is_empty() {
234        StopReason::EndTurn
235    } else {
236        StopReason::ToolUse
237    };
238
239    ChatResponse {
240        content,
241        usage,
242        stop_reason,
243        model: String::new(),
244        metadata: std::collections::HashMap::new(),
245    }
246}
247
248/// Convert a terminal `IterationOutcome` into an error event.
249///
250/// `Completed` outcomes are NOT converted to errors — they produce
251/// `LoopEvent::Done` via the core's event buffer. Only `Error` outcomes
252/// become `Err` items in the stream.
253fn outcome_to_error(outcome: IterationOutcome) -> Option<Result<LoopEvent, LlmError>> {
254    match outcome {
255        IterationOutcome::Error(data) => Some(Err(data.error)),
256        // Completed outcomes push Done into the core's event buffer
257        IterationOutcome::Completed(_) | IterationOutcome::ToolsExecuted { .. } => None,
258    }
259}