1use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use futures::StreamExt;
7
8use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
9use crate::error::LlmError;
10use crate::provider::{ChatParams, DynProvider};
11use crate::stream::{ChatStream, StreamEvent};
12use crate::usage::Usage;
13
14use super::LoopDepth;
15use super::ToolError;
16use super::ToolRegistry;
17use super::approval::approve_calls;
18use super::config::{StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopEvent};
19use super::execution::execute_with_events;
20use super::loop_detection::{LoopDetectionState, handle_loop_detection};
21use super::loop_sync::emit_event;
22
23#[allow(clippy::needless_pass_by_value)] pub fn tool_loop_stream<Ctx: LoopDepth + Send + Sync + 'static>(
45 provider: Arc<dyn DynProvider>,
46 registry: Arc<ToolRegistry<Ctx>>,
47 params: ChatParams,
48 config: ToolLoopConfig,
49 ctx: Arc<Ctx>,
50) -> ChatStream {
51 let current_depth = ctx.loop_depth();
53 if let Some(max_depth) = config.max_depth {
54 if current_depth >= max_depth {
55 return Box::pin(futures::stream::once(async move {
57 Err(LlmError::MaxDepthExceeded {
58 current: current_depth,
59 limit: max_depth,
60 })
61 }));
62 }
63 }
64
65 let nested_ctx = Arc::new(ctx.with_depth(current_depth + 1));
67
68 let stream = futures::stream::unfold(
69 ToolLoopStreamState::new(provider, registry, params, config, nested_ctx),
70 |mut state| async move {
71 loop {
72 match std::mem::replace(&mut state.phase, StreamPhase::Done) {
73 StreamPhase::Done => return None,
74 StreamPhase::StartIteration => match phase_start_iteration(&mut state).await {
75 PhaseResult::Yield(event, next) => {
76 state.phase = next;
77 return Some((event, state));
78 }
79 PhaseResult::Continue(next) => state.phase = next,
80 },
81 StreamPhase::Streaming(stream) => {
82 match phase_streaming(&mut state, stream).await {
83 PhaseResult::Yield(event, next) => {
84 state.phase = next;
85 return Some((event, state));
86 }
87 PhaseResult::Continue(next) => state.phase = next,
88 }
89 }
90 StreamPhase::ExecutingTools => {
91 state.phase = phase_executing_tools(&mut state).await;
92 }
93 }
94 }
95 },
96 );
97 Box::pin(stream)
98}
99
100enum PhaseResult {
102 Yield(Result<StreamEvent, LlmError>, StreamPhase),
104 Continue(StreamPhase),
106}
107
108async fn phase_start_iteration<Ctx: LoopDepth + Send + Sync + 'static>(
110 state: &mut ToolLoopStreamState<Ctx>,
111) -> PhaseResult {
112 if let Some(limit) = state.timeout_limit {
114 if state.start_time.elapsed() >= limit {
115 let err = LlmError::ToolExecution {
116 tool_name: String::new(),
117 source: Box::new(ToolError::new(format!(
118 "Tool loop exceeded timeout of {limit:?}",
119 ))),
120 };
121 return PhaseResult::Yield(Err(err), StreamPhase::Done);
122 }
123 }
124
125 state.iterations += 1;
126
127 let iterations = state.iterations;
128 let msg_count = state.params.messages.len();
129 emit_event(&state.config, || ToolLoopEvent::IterationStart {
130 iteration: iterations,
131 message_count: msg_count,
132 });
133
134 if state.iterations > state.config.max_iterations {
135 let err = LlmError::ToolExecution {
136 tool_name: String::new(),
137 source: Box::new(ToolError::new(format!(
138 "Tool loop exceeded {} iterations",
139 state.config.max_iterations,
140 ))),
141 };
142 return PhaseResult::Yield(Err(err), StreamPhase::Done);
143 }
144
145 match state.provider.stream_boxed(&state.params).await {
146 Ok(s) => {
147 state.current_tool_calls.clear();
148 state.current_text.clear();
149 PhaseResult::Continue(StreamPhase::Streaming(s))
150 }
151 Err(e) => PhaseResult::Yield(Err(e), StreamPhase::Done),
152 }
153}
154
155async fn phase_streaming<Ctx: LoopDepth + Send + Sync + 'static>(
157 state: &mut ToolLoopStreamState<Ctx>,
158 mut stream: ChatStream,
159) -> PhaseResult {
160 match stream.next().await {
161 Some(Ok(event)) => {
162 if let StreamEvent::TextDelta(ref text) = event {
163 state.current_text.push_str(text);
164 }
165 if let StreamEvent::ToolCallComplete { ref call, .. } = event {
166 state.current_tool_calls.push(call.clone());
167 }
168 if let StreamEvent::Usage(ref u) = event {
169 state.total_usage += u;
170 }
171 if let StreamEvent::Done { stop_reason } = &event {
172 let iterations = state.iterations;
173 let has_tool_calls = !state.current_tool_calls.is_empty();
174 let text_length = state.current_text.len();
175 emit_event(&state.config, || ToolLoopEvent::LlmResponseReceived {
176 iteration: iterations,
177 has_tool_calls,
178 text_length,
179 });
180
181 if let Some(ref stop_fn) = state.config.stop_when {
183 let response = build_response_from_stream_state(state, *stop_reason);
185 let ctx = StopContext {
186 iteration: state.iterations,
187 response: &response,
188 total_usage: &state.total_usage,
189 tool_calls_executed: state.tool_calls_executed,
190 last_tool_results: &state.last_tool_results,
191 };
192 match stop_fn(&ctx) {
193 StopDecision::Continue => {}
194 StopDecision::Stop | StopDecision::StopWithReason(_) => {
195 return PhaseResult::Yield(Ok(event), StreamPhase::Done);
197 }
198 }
199 }
200
201 if *stop_reason == StopReason::ToolUse && !state.current_tool_calls.is_empty() {
202 let response = build_response_from_stream_state(state, *stop_reason);
204 if let Some(result) = handle_loop_detection(
205 &mut state.loop_state,
206 &state.current_tool_calls,
207 state.config.loop_detection.as_ref(),
208 &state.config,
209 &mut state.params.messages,
210 &response,
211 state.iterations,
212 &state.total_usage,
213 ) {
214 let err = match result.termination_reason {
216 TerminationReason::LoopDetected {
217 ref tool_name,
218 count,
219 } => LlmError::ToolExecution {
220 tool_name: tool_name.clone(),
221 source: Box::new(ToolError::new(format!(
222 "Tool loop detected: '{tool_name}' called {count} \
223 consecutive times with identical arguments"
224 ))),
225 },
226 _ => LlmError::ToolExecution {
227 tool_name: String::new(),
228 source: Box::new(ToolError::new("Unexpected termination")),
229 },
230 };
231 return PhaseResult::Yield(Err(err), StreamPhase::Done);
232 }
233 return PhaseResult::Yield(Ok(event), StreamPhase::ExecutingTools);
235 }
236 }
237 PhaseResult::Yield(Ok(event), StreamPhase::Streaming(stream))
238 }
239 Some(Err(e)) => PhaseResult::Yield(Err(e), StreamPhase::Done),
240 None => PhaseResult::Continue(StreamPhase::Done),
242 }
243}
244
245fn build_response_from_stream_state<Ctx: LoopDepth + Send + Sync + 'static>(
247 state: &ToolLoopStreamState<Ctx>,
248 stop_reason: StopReason,
249) -> ChatResponse {
250 let mut content = Vec::new();
251 if !state.current_text.is_empty() {
252 content.push(ContentBlock::Text(state.current_text.clone()));
253 }
254 for call in &state.current_tool_calls {
255 content.push(ContentBlock::ToolCall(call.clone()));
256 }
257
258 ChatResponse {
259 content,
260 usage: state.total_usage.clone(),
261 stop_reason,
262 model: String::new(), metadata: std::collections::HashMap::new(),
264 }
265}
266
267async fn phase_executing_tools<Ctx: LoopDepth + Send + Sync + 'static>(
269 state: &mut ToolLoopStreamState<Ctx>,
270) -> StreamPhase {
271 let (approved, denied) = approve_calls(&state.current_tool_calls, &state.config);
272
273 let results = execute_with_events(
274 &state.registry,
275 &approved,
276 denied,
277 state.config.parallel_tool_execution,
278 &state.config,
279 &state.ctx,
280 )
281 .await;
282
283 state.tool_calls_executed += results.len();
285 state.last_tool_results.clone_from(&results);
286
287 let mut assistant_content: Vec<ContentBlock> = Vec::new();
288 if !state.current_text.is_empty() {
289 assistant_content.push(ContentBlock::Text(std::mem::take(&mut state.current_text)));
290 }
291 assistant_content.extend(
292 state
293 .current_tool_calls
294 .drain(..)
295 .map(ContentBlock::ToolCall),
296 );
297 state.params.messages.push(ChatMessage {
298 role: crate::chat::ChatRole::Assistant,
299 content: assistant_content,
300 });
301 for result in results {
302 state
303 .params
304 .messages
305 .push(ChatMessage::tool_result_full(result));
306 }
307
308 StreamPhase::StartIteration
309}
310
311struct ToolLoopStreamState<Ctx: LoopDepth + Send + Sync + 'static> {
313 provider: Arc<dyn DynProvider>,
314 registry: Arc<ToolRegistry<Ctx>>,
315 params: ChatParams,
316 config: ToolLoopConfig,
317 ctx: Arc<Ctx>,
318 iterations: u32,
319 total_usage: Usage,
320 tool_calls_executed: usize,
321 last_tool_results: Vec<ToolResult>,
322 current_tool_calls: Vec<ToolCall>,
323 current_text: String,
324 phase: StreamPhase,
325 loop_state: LoopDetectionState,
326 start_time: Instant,
328 timeout_limit: Option<Duration>,
330}
331
332enum StreamPhase {
333 StartIteration,
334 Streaming(ChatStream),
335 ExecutingTools,
336 Done,
338}
339
340impl<Ctx: LoopDepth + Send + Sync + 'static> ToolLoopStreamState<Ctx> {
341 fn new(
342 provider: Arc<dyn DynProvider>,
343 registry: Arc<ToolRegistry<Ctx>>,
344 params: ChatParams,
345 config: ToolLoopConfig,
346 ctx: Arc<Ctx>,
347 ) -> Self {
348 let timeout_limit = config.timeout;
349 Self {
350 provider,
351 registry,
352 params,
353 config,
354 ctx,
355 iterations: 0,
356 total_usage: Usage::default(),
357 tool_calls_executed: 0,
358 last_tool_results: Vec::new(),
359 current_tool_calls: Vec::new(),
360 current_text: String::new(),
361 phase: StreamPhase::StartIteration,
362 loop_state: LoopDetectionState::default(),
363 start_time: Instant::now(),
364 timeout_limit,
365 }
366 }
367}