Skip to main content

llm_stack/tool/
loop_stream.rs

1//! Streaming tool loop implementation.
2
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use futures::StreamExt;
7
8use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
9use crate::error::LlmError;
10use crate::provider::{ChatParams, DynProvider};
11use crate::stream::{ChatStream, StreamEvent};
12use crate::usage::Usage;
13
14use super::LoopDepth;
15use super::ToolError;
16use super::ToolRegistry;
17use super::approval::approve_calls;
18use super::config::{StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopEvent};
19use super::execution::execute_with_events;
20use super::loop_detection::{LoopDetectionState, handle_loop_detection};
21use super::loop_sync::emit_event;
22
23/// Streaming variant of [`tool_loop`](super::tool_loop).
24///
25/// Yields [`StreamEvent`]s from each iteration. Between iterations
26/// (when executing tools), no events are emitted. The final
27/// [`StreamEvent::Done`] carries the stop reason from the last
28/// iteration.
29///
30/// # Depth Tracking
31///
32/// If `Ctx` implements [`LoopDepth`], nested calls are tracked automatically.
33/// When `config.max_depth` is set and the context's depth exceeds the limit,
34/// yields `Err(LlmError::MaxDepthExceeded)` immediately.
35///
36/// # Events
37///
38/// If `config.on_event` is set, the callback will be invoked with
39/// [`ToolLoopEvent`]s at key points during execution, same as [`tool_loop`](super::tool_loop).
40///
41/// Uses `Arc` for provider, registry, and context since they must outlive
42/// the returned stream.
43#[allow(clippy::needless_pass_by_value)] // ctx is consumed to create nested_ctx
44pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
45    provider: Arc<dyn DynProvider>,
46    registry: Arc<ToolRegistry<Ctx>>,
47    params: ChatParams,
48    config: ToolLoopConfig,
49    ctx: Arc<Ctx>,
50) -> ChatStream {
51    // Check depth limit at entry
52    let current_depth = ctx.loop_depth();
53    if let Some(max_depth) = config.max_depth {
54        if current_depth >= max_depth {
55            // Return a stream that immediately yields the depth error
56            return Box::pin(futures::stream::once(async move {
57                Err(LlmError::MaxDepthExceeded {
58                    current: current_depth,
59                    limit: max_depth,
60                })
61            }));
62        }
63    }
64
65    // Create nested context with incremented depth
66    let nested_ctx = Arc::new(ctx.with_depth(current_depth + 1));
67
68    let stream = futures::stream::unfold(
69        ToolLoopStreamState::new(provider, registry, params, config, nested_ctx),
70        |mut state| async move {
71            loop {
72                match std::mem::replace(&mut state.phase, StreamPhase::Done) {
73                    StreamPhase::Done => return None,
74                    StreamPhase::StartIteration => match phase_start_iteration(&mut state).await {
75                        PhaseResult::Yield(event, next) => {
76                            state.phase = next;
77                            return Some((event, state));
78                        }
79                        PhaseResult::Continue(next) => state.phase = next,
80                    },
81                    StreamPhase::Streaming(stream) => {
82                        match phase_streaming(&mut state, stream).await {
83                            PhaseResult::Yield(event, next) => {
84                                state.phase = next;
85                                return Some((event, state));
86                            }
87                            PhaseResult::Continue(next) => state.phase = next,
88                        }
89                    }
90                    StreamPhase::ExecutingTools => {
91                        state.phase = phase_executing_tools(&mut state).await;
92                    }
93                }
94            }
95        },
96    );
97    Box::pin(stream)
98}
99
100/// Result of processing a stream phase.
101enum PhaseResult {
102    /// Yield an event and transition to the next phase.
103    Yield(Result<StreamEvent, LlmError>, StreamPhase),
104    /// Transition to the next phase without yielding.
105    Continue(StreamPhase),
106}
107
108/// Handle the `StartIteration` phase: emit event, check limits, start LLM stream.
109async fn phase_start_iteration<Ctx: LoopDepth + Send + Sync + 'static>(
110    state: &mut ToolLoopStreamState<Ctx>,
111) -> PhaseResult {
112    // Check timeout at start of each iteration
113    if let Some(limit) = state.timeout_limit {
114        if state.start_time.elapsed() >= limit {
115            let err = LlmError::ToolExecution {
116                tool_name: String::new(),
117                source: Box::new(ToolError::new(format!(
118                    "Tool loop exceeded timeout of {limit:?}",
119                ))),
120            };
121            return PhaseResult::Yield(Err(err), StreamPhase::Done);
122        }
123    }
124
125    state.iterations += 1;
126
127    let iterations = state.iterations;
128    let msg_count = state.params.messages.len();
129    emit_event(&state.config, || ToolLoopEvent::IterationStart {
130        iteration: iterations,
131        message_count: msg_count,
132    });
133
134    if state.iterations > state.config.max_iterations {
135        let err = LlmError::ToolExecution {
136            tool_name: String::new(),
137            source: Box::new(ToolError::new(format!(
138                "Tool loop exceeded {} iterations",
139                state.config.max_iterations,
140            ))),
141        };
142        return PhaseResult::Yield(Err(err), StreamPhase::Done);
143    }
144
145    match state.provider.stream_boxed(&state.params).await {
146        Ok(s) => {
147            state.current_tool_calls.clear();
148            state.current_text.clear();
149            PhaseResult::Continue(StreamPhase::Streaming(s))
150        }
151        Err(e) => PhaseResult::Yield(Err(e), StreamPhase::Done),
152    }
153}
154
155/// Handle the `Streaming` phase: pull events from the LLM stream.
156async fn phase_streaming<Ctx: LoopDepth + Send + Sync + 'static>(
157    state: &mut ToolLoopStreamState<Ctx>,
158    mut stream: ChatStream,
159) -> PhaseResult {
160    match stream.next().await {
161        Some(Ok(event)) => {
162            if let StreamEvent::TextDelta(ref text) = event {
163                state.current_text.push_str(text);
164            }
165            if let StreamEvent::ToolCallComplete { ref call, .. } = event {
166                state.current_tool_calls.push(call.clone());
167            }
168            if let StreamEvent::Usage(ref u) = event {
169                state.total_usage += u;
170            }
171            if let StreamEvent::Done { stop_reason } = &event {
172                let iterations = state.iterations;
173                let has_tool_calls = !state.current_tool_calls.is_empty();
174                let text_length = state.current_text.len();
175                emit_event(&state.config, || ToolLoopEvent::LlmResponseReceived {
176                    iteration: iterations,
177                    has_tool_calls,
178                    text_length,
179                });
180
181                // Check stop condition before deciding next phase
182                if let Some(ref stop_fn) = state.config.stop_when {
183                    // Construct a ChatResponse from accumulated state for the stop condition
184                    let response = build_response_from_stream_state(state, *stop_reason);
185                    let ctx = StopContext {
186                        iteration: state.iterations,
187                        response: &response,
188                        total_usage: &state.total_usage,
189                        tool_calls_executed: state.tool_calls_executed,
190                        last_tool_results: &state.last_tool_results,
191                    };
192                    match stop_fn(&ctx) {
193                        StopDecision::Continue => {}
194                        StopDecision::Stop | StopDecision::StopWithReason(_) => {
195                            // Stop early - yield Done and terminate
196                            return PhaseResult::Yield(Ok(event), StreamPhase::Done);
197                        }
198                    }
199                }
200
201                if *stop_reason == StopReason::ToolUse && !state.current_tool_calls.is_empty() {
202                    // Check for loop detection before executing tools
203                    let response = build_response_from_stream_state(state, *stop_reason);
204                    if let Some(result) = handle_loop_detection(
205                        &mut state.loop_state,
206                        &state.current_tool_calls,
207                        state.config.loop_detection.as_ref(),
208                        &state.config,
209                        &mut state.params.messages,
210                        &response,
211                        state.iterations,
212                        &state.total_usage,
213                    ) {
214                        // Convert termination reason to error for streaming
215                        let err = match result.termination_reason {
216                            TerminationReason::LoopDetected {
217                                ref tool_name,
218                                count,
219                            } => LlmError::ToolExecution {
220                                tool_name: tool_name.clone(),
221                                source: Box::new(ToolError::new(format!(
222                                    "Tool loop detected: '{tool_name}' called {count} \
223                                         consecutive times with identical arguments"
224                                ))),
225                            },
226                            _ => LlmError::ToolExecution {
227                                tool_name: String::new(),
228                                source: Box::new(ToolError::new("Unexpected termination")),
229                            },
230                        };
231                        return PhaseResult::Yield(Err(err), StreamPhase::Done);
232                    }
233                    // Yield the Done event, then transition to ExecutingTools
234                    return PhaseResult::Yield(Ok(event), StreamPhase::ExecutingTools);
235                }
236            }
237            PhaseResult::Yield(Ok(event), StreamPhase::Streaming(stream))
238        }
239        Some(Err(e)) => PhaseResult::Yield(Err(e), StreamPhase::Done),
240        // Stream exhausted — this is the clean termination path after Done event
241        None => PhaseResult::Continue(StreamPhase::Done),
242    }
243}
244
245/// Build a `ChatResponse` from accumulated stream state (for stop condition checks).
246fn build_response_from_stream_state<Ctx: LoopDepth + Send + Sync + 'static>(
247    state: &ToolLoopStreamState<Ctx>,
248    stop_reason: StopReason,
249) -> ChatResponse {
250    let mut content = Vec::new();
251    if !state.current_text.is_empty() {
252        content.push(ContentBlock::Text(state.current_text.clone()));
253    }
254    for call in &state.current_tool_calls {
255        content.push(ContentBlock::ToolCall(call.clone()));
256    }
257
258    ChatResponse {
259        content,
260        usage: state.total_usage.clone(),
261        stop_reason,
262        model: String::new(), // Not available in stream state
263        metadata: std::collections::HashMap::new(),
264    }
265}
266
267/// Handle the `ExecutingTools` phase: run tools, update messages, return next phase.
268async fn phase_executing_tools<Ctx: LoopDepth + Send + Sync + 'static>(
269    state: &mut ToolLoopStreamState<Ctx>,
270) -> StreamPhase {
271    let (approved, denied) = approve_calls(&state.current_tool_calls, &state.config);
272
273    let results = execute_with_events(
274        &state.registry,
275        &approved,
276        denied,
277        state.config.parallel_tool_execution,
278        &state.config,
279        &state.ctx,
280    )
281    .await;
282
283    // Track executed tool calls for stop condition
284    state.tool_calls_executed += results.len();
285    state.last_tool_results.clone_from(&results);
286
287    let mut assistant_content: Vec<ContentBlock> = Vec::new();
288    if !state.current_text.is_empty() {
289        assistant_content.push(ContentBlock::Text(std::mem::take(&mut state.current_text)));
290    }
291    assistant_content.extend(
292        state
293            .current_tool_calls
294            .drain(..)
295            .map(ContentBlock::ToolCall),
296    );
297    state.params.messages.push(ChatMessage {
298        role: crate::chat::ChatRole::Assistant,
299        content: assistant_content,
300    });
301    for result in results {
302        state
303            .params
304            .messages
305            .push(ChatMessage::tool_result_full(result));
306    }
307
308    StreamPhase::StartIteration
309}
310
311/// Internal state for the streaming tool loop.
312struct ToolLoopStreamState<Ctx: LoopDepth + Send + Sync + 'static> {
313    provider: Arc<dyn DynProvider>,
314    registry: Arc<ToolRegistry<Ctx>>,
315    params: ChatParams,
316    config: ToolLoopConfig,
317    ctx: Arc<Ctx>,
318    iterations: u32,
319    total_usage: Usage,
320    tool_calls_executed: usize,
321    last_tool_results: Vec<ToolResult>,
322    current_tool_calls: Vec<ToolCall>,
323    current_text: String,
324    phase: StreamPhase,
325    loop_state: LoopDetectionState,
326    /// Start time for timeout tracking.
327    start_time: Instant,
328    /// Cached timeout limit from config.
329    timeout_limit: Option<Duration>,
330}
331
332enum StreamPhase {
333    StartIteration,
334    Streaming(ChatStream),
335    ExecutingTools,
336    /// Terminal state — unfold returns `None` on next poll.
337    Done,
338}
339
340impl<Ctx: LoopDepth + Send + Sync + 'static> ToolLoopStreamState<Ctx> {
341    fn new(
342        provider: Arc<dyn DynProvider>,
343        registry: Arc<ToolRegistry<Ctx>>,
344        params: ChatParams,
345        config: ToolLoopConfig,
346        ctx: Arc<Ctx>,
347    ) -> Self {
348        let timeout_limit = config.timeout;
349        Self {
350            provider,
351            registry,
352            params,
353            config,
354            ctx,
355            iterations: 0,
356            total_usage: Usage::default(),
357            tool_calls_executed: 0,
358            last_tool_results: Vec::new(),
359            current_tool_calls: Vec::new(),
360            current_text: String::new(),
361            phase: StreamPhase::StartIteration,
362            loop_state: LoopDetectionState::default(),
363            start_time: Instant::now(),
364            timeout_limit,
365        }
366    }
367}