1use std::collections::HashMap;
24use std::time::Instant;
25
26use futures::StreamExt;
27
28use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
29use crate::error::LlmError;
30use crate::provider::{ChatParams, DynProvider};
31use crate::stream::{ChatStream, StreamEvent};
32use crate::usage::Usage;
33
34use super::LoopDepth;
35use super::ToolRegistry;
36use super::approval::approve_calls;
37use super::config::{
38 LoopEvent, StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopResult,
39};
40use super::execution::execute_with_events;
41use super::loop_detection::{IterationSnapshot, LoopDetectionState, handle_loop_detection};
42use super::loop_resumable::LoopCommand;
43
/// Result of finishing one iteration of the tool loop.
///
/// Returned by `LoopCore::finish_iteration` and `LoopCore::do_iteration`, and
/// carried (boxed) by `StartOutcome::Terminal`. Only `ToolsExecuted` leaves the
/// loop unfinished; `Completed` and `Error` are produced by paths that mark the
/// loop as finished.
pub(crate) enum IterationOutcome {
    /// The model requested tools and they were executed; the loop has not
    /// finished and another iteration can be started.
    ToolsExecuted {
        /// Tool calls taken from the model's response.
        tool_calls: Vec<ToolCall>,
        /// Results returned by tool execution. NOTE(review): presumably this also
        /// includes results for calls denied by approval, since the denied
        /// results are passed into `execute_with_events`; confirm.
        results: Vec<ToolResult>,
        /// The response's content blocks other than tool calls.
        assistant_content: Vec<ContentBlock>,
        /// 1-based number of the iteration that produced these calls.
        iteration: u32,
        /// Usage accumulated across all iterations so far.
        total_usage: Usage,
    },
    /// The loop terminated without an error.
    Completed(CompletedData),
    /// The loop terminated because of an error.
    Error(ErrorData),
}
60
/// Payload of `IterationOutcome::Completed`.
pub(crate) struct CompletedData {
    /// The response reported as final. Terminations decided in
    /// `LoopCore::start_iteration` (stop command, timeout, iteration limit)
    /// carry `ChatResponse::empty()` because no model response exists yet.
    pub response: ChatResponse,
    /// Why the loop stopped.
    pub termination_reason: TerminationReason,
    /// Number of iterations started when the loop stopped.
    pub iterations: u32,
    /// Usage accumulated across all iterations.
    pub total_usage: Usage,
}
68
/// Payload of `IterationOutcome::Error`.
pub(crate) struct ErrorData {
    /// The error that ended the loop. Within `LoopCore` this is a depth-limit
    /// error, a failed provider request, or a failed stream collection.
    pub error: LlmError,
    /// Number of iterations started when the error occurred.
    pub iterations: u32,
    /// Usage accumulated before the error.
    pub total_usage: Usage,
}
75
/// Result of `LoopCore::start_iteration`.
pub(crate) enum StartOutcome {
    /// The provider accepted the request and returned a response stream.
    ///
    /// The caller consumes it (for example with `collect_stream`) and passes
    /// the assembled response to `LoopCore::finish_iteration`.
    Stream(ChatStream),
    /// No stream was opened. One of three things happened:
    /// - the loop had already finished, and this replays its recorded result;
    /// - a precondition or the iteration limit terminated it;
    /// - the provider request failed.
    ///
    /// Boxed, presumably to keep this enum small (TODO confirm).
    Terminal(Box<IterationOutcome>),
}
90
/// Per-run state of the tool loop.
///
/// Owns the conversation (`params`), the loop configuration and all run
/// counters. A run advances one iteration per `start_iteration` +
/// `finish_iteration` pair (or per `do_iteration`). Once an outcome other than
/// `IterationOutcome::ToolsExecuted` has been produced, the core is finished.
pub(crate) struct LoopCore<Ctx: LoopDepth + Send + Sync + 'static> {
    /// Request parameters. `params.messages` grows with each assistant turn
    /// and tool result appended by the loop.
    pub(crate) params: ChatParams,
    /// Limits (`max_iterations`, `max_depth`, `timeout`), the optional
    /// `stop_when` callback, and tool-execution settings.
    config: ToolLoopConfig,
    /// Context handed to tool execution; its depth is one above the context
    /// this core was created with.
    nested_ctx: Ctx,
    /// Usage summed over every response passed to `finish_iteration`.
    total_usage: Usage,
    /// Number of iterations started (incremented in `start_iteration`).
    iterations: u32,
    /// Running count of tool results produced by tool execution.
    tool_calls_executed: usize,
    /// Results from the most recent tool execution, exposed to `stop_when`.
    last_tool_results: Vec<ToolResult>,
    /// Loop-detection state carried across iterations.
    loop_state: LoopDetectionState,
    /// Creation time; `config.timeout` is measured from here.
    start_time: Instant,
    /// Set once the loop has terminated, normally or with an error.
    finished: bool,
    /// Command queued by `resume`, applied at the next `start_iteration`.
    pending_command: Option<LoopCommand>,
    /// Result recorded when the loop finished.
    final_result: Option<ToolLoopResult>,
    /// Depth-limit error detected at construction; reported once, by the next
    /// `start_iteration`.
    depth_error: Option<LlmError>,
    /// Events buffered until `drain_events` is called.
    events: Vec<LoopEvent>,
}
125
impl<Ctx: LoopDepth + Send + Sync + 'static> LoopCore<Ctx> {
    /// Creates the state for one tool-loop run.
    ///
    /// The depth limit is checked here, but a violation is not returned from
    /// this function. If `config.max_depth` is set and `ctx`'s depth is already
    /// at or above it, a `LlmError::MaxDepthExceeded` is stored instead, and the
    /// first `start_iteration` reports it.
    ///
    /// Tool execution receives `ctx.with_depth(current_depth + 1)`, presumably
    /// so that loops started from inside tools see the increased depth.
    ///
    /// The timeout clock starts now, not at the first iteration.
    pub(crate) fn new(params: ChatParams, config: ToolLoopConfig, ctx: &Ctx) -> Self {
        let current_depth = ctx.loop_depth();
        let depth_error = config.max_depth.and_then(|max_depth| {
            if current_depth >= max_depth {
                Some(LlmError::MaxDepthExceeded {
                    current: current_depth,
                    limit: max_depth,
                })
            } else {
                None
            }
        });

        let nested_ctx = ctx.with_depth(current_depth + 1);

        Self {
            params,
            config,
            nested_ctx,
            total_usage: Usage::default(),
            iterations: 0,
            tool_calls_executed: 0,
            last_tool_results: Vec::new(),
            loop_state: LoopDetectionState::default(),
            start_time: Instant::now(),
            finished: false,
            pending_command: None,
            final_result: None,
            depth_error,
            events: Vec::new(),
        }
    }

    /// Begins the next iteration, up to and including opening the provider stream.
    ///
    /// Steps, in order:
    /// 1. Run `check_preconditions`: the depth error, the already-finished
    ///    state, any command queued by `resume`, and the timeout.
    /// 2. Increment the iteration counter and record an `IterationStart` event.
    /// 3. Enforce `config.max_iterations`.
    /// 4. Send the current params to `provider` as a streaming request.
    ///
    /// Returns `StartOutcome::Stream` on success. Otherwise returns
    /// `StartOutcome::Terminal` with the outcome that ended (or, if already
    /// finished, previously ended) the loop.
    pub(crate) async fn start_iteration(&mut self, provider: &dyn DynProvider) -> StartOutcome {
        if let Some(outcome) = self.check_preconditions() {
            return StartOutcome::Terminal(Box::new(outcome));
        }

        self.iterations += 1;

        // NOTE(review): the event is pushed before the limit check below, so the
        // iteration that trips `max_iterations` still emits an IterationStart
        // (numbered max_iterations + 1) followed by Done. Confirm intended.
        self.events.push(LoopEvent::IterationStart {
            iteration: self.iterations,
            message_count: self.params.messages.len(),
        });

        // `iterations` is 1-based here, so exactly `max_iterations` requests run.
        if self.iterations > self.config.max_iterations {
            return StartOutcome::Terminal(Box::new(self.finish(
                ChatResponse::empty(),
                TerminationReason::MaxIterations {
                    limit: self.config.max_iterations,
                },
            )));
        }

        match provider.stream_boxed(&self.params).await {
            Ok(stream) => StartOutcome::Stream(stream),
            Err(e) => StartOutcome::Terminal(Box::new(self.finish_error(e))),
        }
    }

    /// Completes an iteration given the fully collected model `response`.
    ///
    /// First adds the response's usage to the running total. Then it either
    /// terminates the loop (see `check_termination`), or executes the
    /// requested tools and appends the assistant turn and the tool results to
    /// the conversation.
    pub(crate) async fn finish_iteration(
        &mut self,
        response: ChatResponse,
        registry: &ToolRegistry<Ctx>,
    ) -> IterationOutcome {
        self.total_usage += &response.usage;

        let call_refs: Vec<&ToolCall> = response.tool_calls();
        if let Some(outcome) = self.check_termination(&response, &call_refs) {
            return outcome;
        }

        self.execute_tools(registry, response).await
    }

    /// Runs one complete iteration in three steps:
    /// 1. `start_iteration`;
    /// 2. collect the stream into a response with `collect_stream`;
    /// 3. `finish_iteration`.
    ///
    /// A failure while collecting the stream ends the loop via `finish_error`.
    pub(crate) async fn do_iteration(
        &mut self,
        provider: &dyn DynProvider,
        registry: &ToolRegistry<Ctx>,
    ) -> IterationOutcome {
        let stream = match self.start_iteration(provider).await {
            StartOutcome::Stream(s) => s,
            StartOutcome::Terminal(outcome) => return *outcome,
        };

        let response = collect_stream(stream).await;
        match response {
            Ok(resp) => self.finish_iteration(resp, registry).await,
            Err(e) => self.finish_error(e),
        }
    }

    /// Returns the events buffered since the last drain, leaving the buffer empty.
    pub(crate) fn drain_events(&mut self) -> Vec<LoopEvent> {
        std::mem::take(&mut self.events)
    }

    /// Checks whether the loop may proceed with a new iteration.
    ///
    /// Returns `Some(outcome)` when it must not proceed. Checks, in order:
    /// 1. The stored depth error. It is taken, so it is reported only once;
    ///    after that the loop is finished.
    /// 2. An already-finished loop: its recorded outcome is replayed.
    /// 3. A command queued by `resume`:
    ///    - `Continue` does nothing;
    ///    - `InjectMessages` appends to the conversation and proceeds;
    ///    - `Stop` finishes with an empty response.
    /// 4. `config.timeout`, measured from construction.
    fn check_preconditions(&mut self) -> Option<IterationOutcome> {
        if let Some(error) = self.depth_error.take() {
            return Some(self.finish_error(error));
        }

        if self.finished {
            return Some(self.make_terminal_outcome());
        }

        if let Some(command) = self.pending_command.take() {
            match command {
                LoopCommand::Continue => {}
                LoopCommand::InjectMessages(messages) => {
                    self.params.messages.extend(messages);
                }
                LoopCommand::Stop(reason) => {
                    return Some(self.finish(
                        ChatResponse::empty(),
                        TerminationReason::StopCondition { reason },
                    ));
                }
            }
        }

        if let Some(limit) = self.config.timeout {
            if self.start_time.elapsed() >= limit {
                return Some(
                    self.finish(ChatResponse::empty(), TerminationReason::Timeout { limit }),
                );
            }
        }

        None
    }

    /// Decides whether `response` ends the loop.
    ///
    /// Returns `Some(outcome)` to terminate, or `None` if the requested tools
    /// should be executed. Checks, in order:
    /// 1. The user `stop_when` callback.
    /// 2. Natural completion: no tool calls, or a stop reason other than `ToolUse`.
    /// 3. The iteration limit.
    /// 4. Loop detection. `handle_loop_detection` may also append messages
    ///    and events.
    ///
    /// Every termination here keeps the model's response, except when loop
    /// detection supplies its own.
    fn check_termination(
        &mut self,
        response: &ChatResponse,
        call_refs: &[&ToolCall],
    ) -> Option<IterationOutcome> {
        if let Some(ref stop_fn) = self.config.stop_when {
            let ctx = StopContext {
                iteration: self.iterations,
                response,
                total_usage: &self.total_usage,
                tool_calls_executed: self.tool_calls_executed,
                last_tool_results: &self.last_tool_results,
            };
            match stop_fn(&ctx) {
                StopDecision::Continue => {}
                StopDecision::Stop => {
                    return Some(self.finish(
                        response.clone(),
                        TerminationReason::StopCondition { reason: None },
                    ));
                }
                StopDecision::StopWithReason(reason) => {
                    return Some(self.finish(
                        response.clone(),
                        TerminationReason::StopCondition {
                            reason: Some(reason),
                        },
                    ));
                }
            }
        }

        if call_refs.is_empty() || response.stop_reason != StopReason::ToolUse {
            return Some(self.finish(response.clone(), TerminationReason::Complete));
        }

        // NOTE(review): when this follows a successful `start_iteration`, that call
        // already returned MaxIterations for `iterations > max_iterations`, so this
        // branch appears unreachable in the normal flow; it looks defensive. Confirm.
        if self.iterations > self.config.max_iterations {
            return Some(self.finish(
                response.clone(),
                TerminationReason::MaxIterations {
                    limit: self.config.max_iterations,
                },
            ));
        }

        let snap = IterationSnapshot {
            response,
            call_refs,
            iterations: self.iterations,
            total_usage: &self.total_usage,
            config: &self.config,
        };
        if let Some(result) = handle_loop_detection(
            &mut self.loop_state,
            &snap,
            &mut self.params.messages,
            &mut self.events,
        ) {
            return Some(self.finish(result.response, result.termination_reason));
        }

        None
    }

    /// Executes the tool calls in `response` and records them in the conversation.
    ///
    /// Steps:
    /// 1. Append an assistant message: the response's non-tool content followed
    ///    by its tool calls.
    /// 2. Split the calls into approved calls and denied results with
    ///    `approve_calls`.
    /// 3. Run `execute_with_events` using the nested (depth + 1) context.
    /// 4. Record the execution events, update the counters, and append one
    ///    tool-role message per result.
    ///
    /// Always returns `IterationOutcome::ToolsExecuted`; the loop stays unfinished.
    async fn execute_tools(
        &mut self,
        registry: &ToolRegistry<Ctx>,
        response: ChatResponse,
    ) -> IterationOutcome {
        let (calls, other_content) = response.partition_content();

        // `calls` is consumed by `approve_calls` below; keep a copy for the outcome.
        let outcome_calls = calls.clone();

        let mut msg_content = other_content.clone();
        msg_content.extend(calls.iter().map(|c| ContentBlock::ToolCall(c.clone())));
        self.params.messages.push(ChatMessage {
            role: crate::chat::ChatRole::Assistant,
            content: msg_content,
        });

        let (approved_calls, denied_results) = approve_calls(calls, &self.config);
        let exec_result = execute_with_events(
            registry,
            approved_calls,
            denied_results,
            self.config.parallel_tool_execution,
            &self.nested_ctx,
        )
        .await;

        self.events.extend(exec_result.events);

        let results = exec_result.results;
        self.tool_calls_executed += results.len();
        self.last_tool_results.clone_from(&results);

        for result in &results {
            self.params
                .messages
                .push(ChatMessage::tool_result_full(result.clone()));
        }

        IterationOutcome::ToolsExecuted {
            tool_calls: outcome_calls,
            results,
            assistant_content: other_content,
            iteration: self.iterations,
            total_usage: self.total_usage.clone(),
        }
    }

    /// Marks the loop finished with a non-error termination.
    ///
    /// Records the result (later returned by `into_result`, or replayed by
    /// `make_terminal_outcome`) and pushes a `LoopEvent::Done` event.
    fn finish(
        &mut self,
        response: ChatResponse,
        termination_reason: TerminationReason,
    ) -> IterationOutcome {
        self.finished = true;
        let usage = self.total_usage.clone();
        let result = ToolLoopResult {
            response: response.clone(),
            iterations: self.iterations,
            total_usage: usage.clone(),
            termination_reason: termination_reason.clone(),
        };
        self.final_result = Some(result.clone());
        self.events.push(LoopEvent::Done(result));

        IterationOutcome::Completed(CompletedData {
            response,
            termination_reason,
            iterations: self.iterations,
            total_usage: usage,
        })
    }

    /// Marks the loop finished because of `error` and returns it as `IterationOutcome::Error`.
    ///
    /// NOTE(review): this records an empty response with `TerminationReason::Complete`
    /// and pushes no `Done` event. As a result, later replays via
    /// `make_terminal_outcome`, and `into_result`, report a normal completion
    /// rather than the error. Confirm intended.
    pub(crate) fn finish_error(&mut self, error: LlmError) -> IterationOutcome {
        self.finished = true;
        let usage = self.total_usage.clone();
        self.final_result = Some(ToolLoopResult {
            response: ChatResponse::empty(),
            iterations: self.iterations,
            total_usage: usage.clone(),
            termination_reason: TerminationReason::Complete,
        });
        IterationOutcome::Error(ErrorData {
            error,
            iterations: self.iterations,
            total_usage: usage,
        })
    }

    /// Rebuilds a `Completed` outcome from the recorded final result.
    ///
    /// If no result was recorded, synthesizes an empty `Complete` outcome from
    /// the current counters. Used when an iteration is requested after the loop
    /// has finished.
    fn make_terminal_outcome(&self) -> IterationOutcome {
        if let Some(ref result) = self.final_result {
            IterationOutcome::Completed(CompletedData {
                response: result.response.clone(),
                termination_reason: result.termination_reason.clone(),
                iterations: result.iterations,
                total_usage: result.total_usage.clone(),
            })
        } else {
            IterationOutcome::Completed(CompletedData {
                response: ChatResponse::empty(),
                termination_reason: TerminationReason::Complete,
                iterations: self.iterations,
                total_usage: self.total_usage.clone(),
            })
        }
    }

    /// Queues `command` for the next `start_iteration`.
    ///
    /// Ignored once the loop has finished. A later call replaces a queued
    /// command that has not been consumed yet.
    pub(crate) fn resume(&mut self, command: LoopCommand) {
        if !self.finished {
            self.pending_command = Some(command);
        }
    }

    /// The conversation so far, including messages appended by the loop.
    pub(crate) fn messages(&self) -> &[ChatMessage] {
        &self.params.messages
    }

    /// Mutable access to the conversation that will be sent on the next iteration.
    pub(crate) fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
        &mut self.params.messages
    }

    /// Usage accumulated across all iterations so far.
    pub(crate) fn total_usage(&self) -> &Usage {
        &self.total_usage
    }

    /// Number of iterations started so far.
    pub(crate) fn iterations(&self) -> u32 {
        self.iterations
    }

    /// Whether the loop has terminated, normally or with an error.
    pub(crate) fn is_finished(&self) -> bool {
        self.finished
    }

    /// Consumes the core and returns its final result.
    ///
    /// If the loop never finished, returns an empty `Complete` result built
    /// from the current counters.
    pub(crate) fn into_result(self) -> ToolLoopResult {
        self.final_result.unwrap_or_else(|| ToolLoopResult {
            response: ChatResponse::empty(),
            iterations: self.iterations,
            total_usage: self.total_usage,
            termination_reason: TerminationReason::Complete,
        })
    }
}
548
549pub(crate) async fn collect_stream(mut stream: ChatStream) -> Result<ChatResponse, LlmError> {
557 let mut text = String::new();
558 let mut tool_calls: Vec<ToolCall> = Vec::new();
559 let mut usage = Usage::default();
560 let mut stop_reason = StopReason::EndTurn;
561
562 while let Some(event) = stream.next().await {
563 match event? {
564 StreamEvent::TextDelta(t) => text.push_str(&t),
565 StreamEvent::ToolCallComplete { call, .. } => tool_calls.push(call),
566 StreamEvent::Usage(u) => usage += &u,
567 StreamEvent::Done { stop_reason: sr } => stop_reason = sr,
568 _ => {}
570 }
571 }
572
573 let mut content = Vec::new();
574 if !text.is_empty() {
575 content.push(ContentBlock::Text(text));
576 }
577 for call in tool_calls {
578 content.push(ContentBlock::ToolCall(call));
579 }
580
581 Ok(ChatResponse {
582 content,
583 usage,
584 stop_reason,
585 model: String::new(),
586 metadata: HashMap::new(),
587 })
588}
589
590impl ChatMessage {
593 pub fn tool_result_full(result: ToolResult) -> Self {
595 Self {
596 role: crate::chat::ChatRole::Tool,
597 content: vec![ContentBlock::ToolResult(result)],
598 }
599 }
600}
601
602impl<Ctx: LoopDepth + Send + Sync + 'static> std::fmt::Debug for LoopCore<Ctx> {
603 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
604 f.debug_struct("LoopCore")
605 .field("iterations", &self.iterations)
606 .field("tool_calls_executed", &self.tool_calls_executed)
607 .field("finished", &self.finished)
608 .field("has_pending_command", &self.pending_command.is_some())
609 .field("buffered_events", &self.events.len())
610 .finish_non_exhaustive()
611 }
612}