1use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use futures::StreamExt;
7
8use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
9use crate::error::LlmError;
10use crate::provider::{ChatParams, DynProvider};
11use crate::stream::{ChatStream, StreamEvent};
12use crate::usage::Usage;
13
14use super::LoopDepth;
15use super::ToolError;
16use super::ToolRegistry;
17use super::approval::approve_calls;
18use super::config::{StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopEvent};
19use super::execution::execute_with_events;
20use super::loop_detection::{IterationSnapshot, LoopDetectionState, handle_loop_detection};
21use super::loop_sync::emit_event;
22
23#[allow(clippy::needless_pass_by_value)] pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
45 provider: Arc<dyn DynProvider>,
46 registry: Arc<ToolRegistry<Ctx>>,
47 params: ChatParams,
48 config: ToolLoopConfig,
49 ctx: Arc<Ctx>,
50) -> ChatStream {
51 let current_depth = ctx.loop_depth();
53 if let Some(max_depth) = config.max_depth {
54 if current_depth >= max_depth {
55 return Box::pin(futures::stream::once(async move {
57 Err(LlmError::MaxDepthExceeded {
58 current: current_depth,
59 limit: max_depth,
60 })
61 }));
62 }
63 }
64
65 let nested_ctx = Arc::new(ctx.with_depth(current_depth + 1));
67
68 let stream = futures::stream::unfold(
69 ToolLoopStreamState::new(provider, registry, params, config, nested_ctx),
70 |mut state| async move {
71 loop {
72 match std::mem::replace(&mut state.phase, StreamPhase::Done) {
73 StreamPhase::Done => return None,
74 StreamPhase::StartIteration => match phase_start_iteration(&mut state).await {
75 PhaseResult::Yield(event, next) => {
76 state.phase = next;
77 return Some((event, state));
78 }
79 PhaseResult::Continue(next) => state.phase = next,
80 },
81 StreamPhase::Streaming(stream) => {
82 match phase_streaming(&mut state, stream).await {
83 PhaseResult::Yield(event, next) => {
84 state.phase = next;
85 return Some((event, state));
86 }
87 PhaseResult::Continue(next) => state.phase = next,
88 }
89 }
90 StreamPhase::ExecutingTools => {
91 state.phase = phase_executing_tools(&mut state).await;
92 }
93 }
94 }
95 },
96 );
97 Box::pin(stream)
98}
99
100enum PhaseResult {
102 Yield(Result<StreamEvent, LlmError>, StreamPhase),
104 Continue(StreamPhase),
106}
107
108async fn phase_start_iteration<Ctx: LoopDepth + Send + Sync + 'static>(
110 state: &mut ToolLoopStreamState<Ctx>,
111) -> PhaseResult {
112 if let Some(limit) = state.timeout_limit {
114 if state.start_time.elapsed() >= limit {
115 let err = LlmError::ToolExecution {
116 tool_name: String::new(),
117 source: Box::new(ToolError::new(format!(
118 "Tool loop exceeded timeout of {limit:?}",
119 ))),
120 };
121 return PhaseResult::Yield(Err(err), StreamPhase::Done);
122 }
123 }
124
125 state.iterations += 1;
126
127 let iterations = state.iterations;
128 let msg_count = state.params.messages.len();
129 emit_event(&state.config, || ToolLoopEvent::IterationStart {
130 iteration: iterations,
131 message_count: msg_count,
132 });
133
134 if state.iterations > state.config.max_iterations {
135 let err = LlmError::ToolExecution {
136 tool_name: String::new(),
137 source: Box::new(ToolError::new(format!(
138 "Tool loop exceeded {} iterations",
139 state.config.max_iterations,
140 ))),
141 };
142 return PhaseResult::Yield(Err(err), StreamPhase::Done);
143 }
144
145 match state.provider.stream_boxed(&state.params).await {
146 Ok(s) => {
147 state.current_tool_calls.clear();
148 state.current_text.clear();
149 PhaseResult::Continue(StreamPhase::Streaming(s))
150 }
151 Err(e) => PhaseResult::Yield(Err(e), StreamPhase::Done),
152 }
153}
154
155async fn phase_streaming<Ctx: LoopDepth + Send + Sync + 'static>(
157 state: &mut ToolLoopStreamState<Ctx>,
158 mut stream: ChatStream,
159) -> PhaseResult {
160 match stream.next().await {
161 Some(Ok(event)) => {
162 if let StreamEvent::TextDelta(ref text) = event {
163 state.current_text.push_str(text);
164 }
165 if let StreamEvent::ToolCallComplete { ref call, .. } = event {
166 state.current_tool_calls.push(call.clone());
167 }
168 if let StreamEvent::Usage(ref u) = event {
169 state.total_usage += u;
170 }
171 if let StreamEvent::Done { stop_reason } = &event {
172 let iterations = state.iterations;
173 let has_tool_calls = !state.current_tool_calls.is_empty();
174 let text_length = state.current_text.len();
175 emit_event(&state.config, || ToolLoopEvent::LlmResponseReceived {
176 iteration: iterations,
177 has_tool_calls,
178 text_length,
179 });
180
181 if let Some(ref stop_fn) = state.config.stop_when {
183 let response = build_response_from_stream_state(state, *stop_reason);
185 let ctx = StopContext {
186 iteration: state.iterations,
187 response: &response,
188 total_usage: &state.total_usage,
189 tool_calls_executed: state.tool_calls_executed,
190 last_tool_results: &state.last_tool_results,
191 };
192 match stop_fn(&ctx) {
193 StopDecision::Continue => {}
194 StopDecision::Stop | StopDecision::StopWithReason(_) => {
195 return PhaseResult::Yield(Ok(event), StreamPhase::Done);
197 }
198 }
199 }
200
201 if *stop_reason == StopReason::ToolUse && !state.current_tool_calls.is_empty() {
202 let response = build_response_from_stream_state(state, *stop_reason);
204 let call_refs: Vec<&ToolCall> = state.current_tool_calls.iter().collect();
205 let snap = IterationSnapshot {
206 response: &response,
207 call_refs: &call_refs,
208 iterations: state.iterations,
209 total_usage: &state.total_usage,
210 tool_calls_executed: state.tool_calls_executed,
211 last_tool_results: &state.last_tool_results,
212 config: &state.config,
213 };
214 if let Some(result) = handle_loop_detection(
215 &mut state.loop_state,
216 &snap,
217 &mut state.params.messages,
218 ) {
219 let err = match result.termination_reason {
221 TerminationReason::LoopDetected {
222 ref tool_name,
223 count,
224 } => LlmError::ToolExecution {
225 tool_name: tool_name.clone(),
226 source: Box::new(ToolError::new(format!(
227 "Tool loop detected: '{tool_name}' called {count} \
228 consecutive times with identical arguments"
229 ))),
230 },
231 _ => LlmError::ToolExecution {
232 tool_name: String::new(),
233 source: Box::new(ToolError::new("Unexpected termination")),
234 },
235 };
236 return PhaseResult::Yield(Err(err), StreamPhase::Done);
237 }
238 return PhaseResult::Yield(Ok(event), StreamPhase::ExecutingTools);
240 }
241 }
242 PhaseResult::Yield(Ok(event), StreamPhase::Streaming(stream))
243 }
244 Some(Err(e)) => PhaseResult::Yield(Err(e), StreamPhase::Done),
245 None => PhaseResult::Continue(StreamPhase::Done),
247 }
248}
249
250fn build_response_from_stream_state<Ctx: LoopDepth + Send + Sync + 'static>(
252 state: &ToolLoopStreamState<Ctx>,
253 stop_reason: StopReason,
254) -> ChatResponse {
255 let mut content = Vec::new();
256 if !state.current_text.is_empty() {
257 content.push(ContentBlock::Text(state.current_text.clone()));
258 }
259 for call in &state.current_tool_calls {
260 content.push(ContentBlock::ToolCall(call.clone()));
261 }
262
263 ChatResponse {
264 content,
265 usage: state.total_usage.clone(),
266 stop_reason,
267 model: String::new(), metadata: std::collections::HashMap::new(),
269 }
270}
271
272async fn phase_executing_tools<Ctx: LoopDepth + Send + Sync + 'static>(
274 state: &mut ToolLoopStreamState<Ctx>,
275) -> StreamPhase {
276 let calls = std::mem::take(&mut state.current_tool_calls);
278 let assistant_calls: Vec<ContentBlock> =
279 calls.iter().cloned().map(ContentBlock::ToolCall).collect();
280 let (approved, denied) = approve_calls(calls, &state.config);
281
282 let results = execute_with_events(
283 &state.registry,
284 approved,
285 denied,
286 state.config.parallel_tool_execution,
287 &state.config,
288 &state.ctx,
289 )
290 .await;
291
292 state.tool_calls_executed += results.len();
294 state.last_tool_results.clone_from(&results);
295
296 let mut assistant_content: Vec<ContentBlock> = Vec::new();
297 if !state.current_text.is_empty() {
298 assistant_content.push(ContentBlock::Text(std::mem::take(&mut state.current_text)));
299 }
300 assistant_content.extend(assistant_calls);
301 state.params.messages.push(ChatMessage {
302 role: crate::chat::ChatRole::Assistant,
303 content: assistant_content,
304 });
305 for result in results {
306 state
307 .params
308 .messages
309 .push(ChatMessage::tool_result_full(result));
310 }
311
312 StreamPhase::StartIteration
313}
314
315struct ToolLoopStreamState<Ctx: LoopDepth + Send + Sync + 'static> {
317 provider: Arc<dyn DynProvider>,
318 registry: Arc<ToolRegistry<Ctx>>,
319 params: ChatParams,
320 config: ToolLoopConfig,
321 ctx: Arc<Ctx>,
322 iterations: u32,
323 total_usage: Usage,
324 tool_calls_executed: usize,
325 last_tool_results: Vec<ToolResult>,
326 current_tool_calls: Vec<ToolCall>,
327 current_text: String,
328 phase: StreamPhase,
329 loop_state: LoopDetectionState,
330 start_time: Instant,
332 timeout_limit: Option<Duration>,
334}
335
336enum StreamPhase {
337 StartIteration,
338 Streaming(ChatStream),
339 ExecutingTools,
340 Done,
342}
343
344impl<Ctx: LoopDepth + Send + Sync + 'static> ToolLoopStreamState<Ctx> {
345 fn new(
346 provider: Arc<dyn DynProvider>,
347 registry: Arc<ToolRegistry<Ctx>>,
348 params: ChatParams,
349 config: ToolLoopConfig,
350 ctx: Arc<Ctx>,
351 ) -> Self {
352 let timeout_limit = config.timeout;
353 Self {
354 provider,
355 registry,
356 params,
357 config,
358 ctx,
359 iterations: 0,
360 total_usage: Usage::default(),
361 tool_calls_executed: 0,
362 last_tool_results: Vec::new(),
363 current_tool_calls: Vec::new(),
364 current_text: String::new(),
365 phase: StreamPhase::StartIteration,
366 loop_state: LoopDetectionState::default(),
367 start_time: Instant::now(),
368 timeout_limit,
369 }
370 }
371}