use std::collections::HashMap;
use std::time::Instant;
use futures::StreamExt;
use crate::chat::{ChatMessage, ChatResponse, ContentBlock, StopReason, ToolCall, ToolResult};
use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider};
use crate::stream::{ChatStream, StreamEvent};
use crate::usage::Usage;
use super::LoopDepth;
use super::ToolRegistry;
use super::approval::approve_calls;
use super::config::{
LoopEvent, StopContext, StopDecision, TerminationReason, ToolLoopConfig, ToolLoopResult,
};
use super::execution::execute_with_events;
use super::loop_detection::{IterationSnapshot, LoopDetectionState, handle_loop_detection};
use super::loop_resumable::LoopCommand;
/// What a single pass through the tool loop produced.
pub(crate) enum IterationOutcome {
/// Tool calls from the model response were run and their results were
/// appended to the conversation; the loop can continue.
ToolsExecuted {
/// The tool calls taken from the model response for this iteration.
tool_calls: Vec<ToolCall>,
/// The tool results after optional post-processing.
results: Vec<ToolResult>,
/// The part of the response that `partition_content` did not return as tool calls.
assistant_content: Vec<ContentBlock>,
/// Iteration counter value at the time the tools were executed.
iteration: u32,
/// Usage accumulated over the loop so far.
total_usage: Usage,
},
/// The loop ended with a termination reason.
Completed(CompletedData),
/// The loop ended because of an error.
Error(ErrorData),
}
/// Payload of [`IterationOutcome::Completed`].
pub(crate) struct CompletedData {
/// The response the loop finished with (may be an empty response).
pub response: ChatResponse,
/// Why the loop stopped.
pub termination_reason: TerminationReason,
/// Iteration counter value when the loop finished.
pub iterations: u32,
/// Usage accumulated over the whole loop.
pub total_usage: Usage,
}
/// Payload of [`IterationOutcome::Error`].
pub(crate) struct ErrorData {
/// The error that ended the loop.
pub error: LlmError,
/// Iteration counter value when the error was recorded.
pub iterations: u32,
/// Usage accumulated up to the error.
pub total_usage: Usage,
}
/// Result of [`LoopCore::start_iteration`].
pub(crate) enum StartOutcome {
/// The provider stream was opened; the caller is expected to consume it.
Stream(ChatStream),
/// The iteration ended before a stream was produced (boxed outcome).
Terminal(Box<IterationOutcome>),
}
/// Mutable state of one tool-loop run, advanced one iteration at a time.
pub(crate) struct LoopCore<Ctx: LoopDepth + Send + Sync + 'static> {
/// Request parameters; `params.messages` grows as the loop appends
/// assistant and tool-result messages.
pub(crate) params: ChatParams,
/// Loop configuration (limits, hooks, masking, result post-processing).
config: ToolLoopConfig,
/// Context handed to tool execution, created with the caller's depth plus one.
nested_ctx: Ctx,
/// Usage summed over every response passed to `finish_iteration`.
total_usage: Usage,
/// Number of iterations started so far.
iterations: u32,
/// Running count of tool results returned by tool execution.
tool_calls_executed: usize,
/// Results of the most recent tool execution, exposed to the `stop_when` hook.
last_tool_results: Vec<ToolResult>,
/// State passed to `handle_loop_detection` on each iteration.
loop_state: LoopDetectionState,
/// Creation time of this value; the timeout is measured from here.
start_time: Instant,
/// Set once the loop has produced a terminal outcome.
finished: bool,
/// Command stored by `resume`, consumed at the start of the next iteration.
pending_command: Option<LoopCommand>,
/// Result recorded when the loop finished.
final_result: Option<ToolLoopResult>,
/// Depth-limit violation detected in `new`, reported on the first iteration.
depth_error: Option<LlmError>,
/// Events buffered until `drain_events` is called.
events: Vec<LoopEvent>,
/// Bookkeeping for tool-result messages, used by observation masking.
tool_result_meta: Vec<ToolResultMeta>,
}
/// Records where a tool-result message was stored so it can be masked later.
struct ToolResultMeta {
/// Index of the message in `params.messages` at the time it was pushed.
message_index: usize,
/// Iteration during which the tool result was produced.
iteration: u32,
/// Whether the message content has already been replaced by a placeholder.
masked: bool,
}
impl<Ctx: LoopDepth + Send + Sync + 'static> LoopCore<Ctx> {
/// Creates the state for a fresh tool-loop run.
///
/// The depth limit is checked here, but a violation is only stored in
/// `depth_error`; it is reported by the first precondition check rather
/// than by this constructor. The nested context is derived from `ctx`
/// with its depth increased by one.
pub(crate) fn new(params: ChatParams, config: ToolLoopConfig, ctx: &Ctx) -> Self {
    let current_depth = ctx.loop_depth();

    // A limit is violated as soon as the caller's depth reaches it.
    let depth_error = match config.max_depth {
        Some(limit) if current_depth >= limit => Some(LlmError::MaxDepthExceeded {
            current: current_depth,
            limit,
        }),
        _ => None,
    };

    Self {
        nested_ctx: ctx.with_depth(current_depth + 1),
        start_time: Instant::now(),
        params,
        config,
        depth_error,
        total_usage: Usage::default(),
        loop_state: LoopDetectionState::default(),
        iterations: 0,
        tool_calls_executed: 0,
        finished: false,
        pending_command: None,
        final_result: None,
        last_tool_results: Vec::new(),
        events: Vec::new(),
        tool_result_meta: Vec::new(),
    }
}
/// Begins a new iteration and opens the provider stream.
///
/// Steps, in order: run the precondition checks, bump the iteration
/// counter, buffer an `IterationStart` event, enforce `max_iterations`,
/// mask old observations, then request a stream from the provider.
/// Any early exit or provider error is returned as `StartOutcome::Terminal`.
pub(crate) async fn start_iteration(&mut self, provider: &dyn DynProvider) -> StartOutcome {
    if let Some(early) = self.check_preconditions() {
        return StartOutcome::Terminal(Box::new(early));
    }

    self.iterations += 1;
    let iteration = self.iterations;
    let message_count = self.params.messages.len();
    self.events.push(LoopEvent::IterationStart {
        iteration,
        message_count,
    });

    // The limit is checked after the increment, before any provider call.
    let limit = self.config.max_iterations;
    if iteration > limit {
        let outcome = self.finish(
            ChatResponse::empty(),
            TerminationReason::MaxIterations { limit },
        );
        return StartOutcome::Terminal(Box::new(outcome));
    }

    self.mask_old_observations();

    let opened = provider.stream_boxed(&self.params).await;
    match opened {
        Err(err) => StartOutcome::Terminal(Box::new(self.finish_error(err))),
        Ok(stream) => StartOutcome::Stream(stream),
    }
}
/// Completes an iteration from a fully collected response.
///
/// Adds the response usage to the running total, evaluates the
/// termination checks and, if none of them ends the loop, executes the
/// requested tools.
pub(crate) async fn finish_iteration(
    &mut self,
    response: ChatResponse,
    registry: &ToolRegistry<Ctx>,
) -> IterationOutcome {
    self.total_usage += &response.usage;

    // Scope the borrowed call list so `response` can be moved afterwards.
    let terminal = {
        let requested: Vec<&ToolCall> = response.tool_calls();
        self.check_termination(&response, &requested)
    };

    match terminal {
        Some(outcome) => outcome,
        None => self.execute_tools(registry, response).await,
    }
}
/// Runs one complete iteration: start, collect the stream, finish.
///
/// A terminal result from `start_iteration` is returned as is; a stream
/// error is turned into an error outcome via `finish_error`.
pub(crate) async fn do_iteration(
    &mut self,
    provider: &dyn DynProvider,
    registry: &ToolRegistry<Ctx>,
) -> IterationOutcome {
    let stream = match self.start_iteration(provider).await {
        StartOutcome::Terminal(boxed) => return *boxed,
        StartOutcome::Stream(stream) => stream,
    };

    match collect_stream(stream).await {
        Err(err) => self.finish_error(err),
        Ok(collected) => self.finish_iteration(collected, registry).await,
    }
}
/// Returns all buffered events and leaves the internal buffer empty.
pub(crate) fn drain_events(&mut self) -> Vec<LoopEvent> {
std::mem::take(&mut self.events)
}
/// Checks everything that can end the loop before an iteration starts.
///
/// Order of checks: stored depth error, already-finished state, pending
/// resume command (which may inject messages or stop the loop), and
/// finally the timeout. Returns `None` when the iteration may proceed.
fn check_preconditions(&mut self) -> Option<IterationOutcome> {
    if let Some(depth_err) = self.depth_error.take() {
        return Some(self.finish_error(depth_err));
    }

    if self.finished {
        return Some(self.make_terminal_outcome());
    }

    // The command is consumed here, so it applies to exactly one iteration.
    match self.pending_command.take() {
        None | Some(LoopCommand::Continue) => {}
        Some(LoopCommand::InjectMessages(extra)) => self.params.messages.extend(extra),
        Some(LoopCommand::Stop(reason)) => {
            return Some(self.finish(
                ChatResponse::empty(),
                TerminationReason::StopCondition { reason },
            ));
        }
    }

    // No configured timeout means there is nothing left to check.
    let limit = self.config.timeout?;
    if self.start_time.elapsed() < limit {
        return None;
    }
    Some(self.finish(ChatResponse::empty(), TerminationReason::Timeout { limit }))
}
/// Decides whether the loop ends after receiving `response`.
///
/// Checks, in order: the optional `stop_when` hook, natural completion
/// (no tool calls, or a stop reason other than `ToolUse`), the iteration
/// limit, and loop detection. Returns `None` when tools should be executed.
fn check_termination(
    &mut self,
    response: &ChatResponse,
    call_refs: &[&ToolCall],
) -> Option<IterationOutcome> {
    // Ask the user-supplied stop hook first; no hook behaves like `Continue`.
    let decision = match &self.config.stop_when {
        Some(stop_fn) => {
            let ctx = StopContext {
                iteration: self.iterations,
                response,
                total_usage: &self.total_usage,
                tool_calls_executed: self.tool_calls_executed,
                last_tool_results: &self.last_tool_results,
            };
            stop_fn(&ctx)
        }
        None => StopDecision::Continue,
    };
    let requested_stop = match decision {
        StopDecision::Continue => None,
        StopDecision::Stop => Some(None),
        StopDecision::StopWithReason(why) => Some(Some(why)),
    };
    if let Some(reason) = requested_stop {
        return Some(self.finish(
            response.clone(),
            TerminationReason::StopCondition { reason },
        ));
    }

    let wants_tools = !call_refs.is_empty() && response.stop_reason == StopReason::ToolUse;
    if !wants_tools {
        return Some(self.finish(response.clone(), TerminationReason::Complete));
    }

    let limit = self.config.max_iterations;
    if self.iterations > limit {
        return Some(self.finish(
            response.clone(),
            TerminationReason::MaxIterations { limit },
        ));
    }

    let snapshot = IterationSnapshot {
        response,
        call_refs,
        iterations: self.iterations,
        total_usage: &self.total_usage,
        config: &self.config,
    };
    let detected = handle_loop_detection(
        &mut self.loop_state,
        &snapshot,
        &mut self.params.messages,
        &mut self.events,
    )?;
    Some(self.finish(detected.response, detected.termination_reason))
}
async fn execute_tools(
&mut self,
registry: &ToolRegistry<Ctx>,
response: ChatResponse,
) -> IterationOutcome {
let (calls, other_content) = response.partition_content();
let outcome_calls = calls.clone();
let mut msg_content = other_content.clone();
msg_content.extend(calls.iter().map(|c| ContentBlock::ToolCall(c.clone())));
self.params.messages.push(ChatMessage {
role: crate::chat::ChatRole::Assistant,
content: msg_content,
});
let (approved_calls, denied_results) = approve_calls(calls, &self.config);
let exec_result = execute_with_events(
registry,
approved_calls,
denied_results,
self.config.parallel_tool_execution,
&self.nested_ctx,
)
.await;
self.events.extend(exec_result.events);
let mut results = exec_result.results;
self.tool_calls_executed += results.len();
self.postprocess_results(&mut results, &outcome_calls).await;
self.last_tool_results.clone_from(&results);
for result in &results {
let idx = self.params.messages.len();
self.params
.messages
.push(ChatMessage::tool_result_full(result.clone()));
self.tool_result_meta.push(ToolResultMeta {
message_index: idx,
iteration: self.iterations,
masked: false,
});
}
IterationOutcome::ToolsExecuted {
tool_calls: outcome_calls,
results,
assistant_content: other_content,
iteration: self.iterations,
total_usage: self.total_usage.clone(),
}
}
/// Applies the configured post-processing stages to successful tool results.
///
/// For each non-error result the stages run in a fixed order, each one
/// operating on the content left by the previous stage:
/// 1. `result_processor` — may replace the content,
/// 2. `result_extractor` — used when the estimated token count exceeds
///    its extraction threshold,
/// 3. `result_cacher` — used when the estimated token count exceeds its
///    inline threshold; the content is replaced by the cache summary.
///
/// A `LoopEvent` is buffered for every stage that changed a result.
/// Error results are left untouched. `calls` is only used to map a
/// result's `tool_call_id` back to a tool name (`"unknown"` if absent).
async fn postprocess_results(&mut self, results: &mut [ToolResult], calls: &[ToolCall]) {
    let has_extractor = self.config.result_extractor.is_some();
    let has_any_stage = has_extractor
        || self.config.result_processor.is_some()
        || self.config.result_cacher.is_some();
    if !has_any_stage {
        return;
    }

    let call_id_to_name: HashMap<&str, &str> = calls
        .iter()
        .map(|c| (c.id.as_str(), c.name.as_str()))
        .collect();

    // Only the extractor reads the user query, so the history is scanned
    // (and the text cloned) only when an extractor is configured.
    let user_query: String = if has_extractor {
        self.params
            .messages
            .iter()
            .rev()
            .filter(|m| m.role == crate::chat::ChatRole::User)
            .find_map(|m| {
                m.content.iter().find_map(|b| match b {
                    ContentBlock::Text(t) => Some(t.clone()),
                    _ => None,
                })
            })
            .unwrap_or_default()
    } else {
        String::new()
    };

    for result in results.iter_mut() {
        // Error results are passed through unchanged.
        if result.is_error {
            continue;
        }
        let tool_name = call_id_to_name
            .get(result.tool_call_id.as_str())
            .copied()
            .unwrap_or("unknown");

        if let Some(ref processor) = self.config.result_processor {
            let processed = processor.process(tool_name, &result.content);
            if processed.was_processed {
                self.events.push(LoopEvent::ToolResultProcessed {
                    tool_name: tool_name.to_string(),
                    original_tokens: processed.original_tokens_est,
                    processed_tokens: processed.processed_tokens_est,
                });
                result.content = processed.content;
            }
        }

        if let Some(ref extractor) = self.config.result_extractor {
            let tokens = crate::context::estimate_tokens(&result.content);
            if tokens > extractor.extraction_threshold() {
                if let Some(extracted) = extractor
                    .extract(tool_name, &result.content, &user_query)
                    .await
                {
                    self.events.push(LoopEvent::ToolResultExtracted {
                        tool_name: tool_name.to_string(),
                        original_tokens: extracted.original_tokens_est,
                        extracted_tokens: extracted.extracted_tokens_est,
                    });
                    result.content = extracted.content;
                }
            }
        }

        if let Some(ref cacher) = self.config.result_cacher {
            let tokens = crate::context::estimate_tokens(&result.content);
            if tokens > cacher.inline_threshold() {
                if let Some(cached) = cacher.cache(tool_name, &result.content) {
                    self.events.push(LoopEvent::ToolResultCached {
                        tool_name: tool_name.to_string(),
                        original_tokens: cached.original_tokens_est,
                        summary_tokens: cached.summary_tokens_est,
                    });
                    result.content = cached.summary;
                }
            }
        }
    }
}
/// Marks the loop as finished and produces the `Completed` outcome.
///
/// The result is stored in `final_result` and also buffered as a
/// `LoopEvent::Done` event.
fn finish(
    &mut self,
    response: ChatResponse,
    termination_reason: TerminationReason,
) -> IterationOutcome {
    self.finished = true;

    let iterations = self.iterations;
    let total_usage = self.total_usage.clone();

    let snapshot = ToolLoopResult {
        response: response.clone(),
        iterations,
        total_usage: total_usage.clone(),
        termination_reason: termination_reason.clone(),
    };
    self.events.push(LoopEvent::Done(snapshot.clone()));
    self.final_result = Some(snapshot);

    IterationOutcome::Completed(CompletedData {
        response,
        termination_reason,
        iterations,
        total_usage,
    })
}
/// Marks the loop as finished because of `error` and returns the `Error` outcome.
///
/// The stored `final_result` holds an empty response with
/// `TerminationReason::Complete`; no `Done` event is buffered here.
pub(crate) fn finish_error(&mut self, error: LlmError) -> IterationOutcome {
    let iterations = self.iterations;
    let total_usage = self.total_usage.clone();

    self.finished = true;
    self.final_result = Some(ToolLoopResult {
        response: ChatResponse::empty(),
        iterations,
        total_usage: total_usage.clone(),
        termination_reason: TerminationReason::Complete,
    });

    IterationOutcome::Error(ErrorData {
        error,
        iterations,
        total_usage,
    })
}
/// Rebuilds a `Completed` outcome for a loop that has already finished.
///
/// Uses the stored final result when there is one; otherwise falls back
/// to an empty response with `TerminationReason::Complete` and the
/// current counters.
fn make_terminal_outcome(&self) -> IterationOutcome {
    let data = match &self.final_result {
        Some(stored) => CompletedData {
            response: stored.response.clone(),
            termination_reason: stored.termination_reason.clone(),
            iterations: stored.iterations,
            total_usage: stored.total_usage.clone(),
        },
        None => CompletedData {
            response: ChatResponse::empty(),
            termination_reason: TerminationReason::Complete,
            iterations: self.iterations,
            total_usage: self.total_usage.clone(),
        },
    };
    IterationOutcome::Completed(data)
}
/// Replaces old or force-masked tool results in the history with a short placeholder.
///
/// Does nothing unless masking is configured. A tracked tool result is
/// masked when it is not yet masked and either
/// - its iteration is at or below `iterations - max_iterations_to_keep`, or
/// - its iteration is listed in `force_mask_iterations`
///   (a lock that cannot be acquired is treated as "no forced iterations").
///
/// Error results and results whose estimated size is below
/// `min_tokens_to_mask` are kept. If anything was masked, a single
/// `ObservationsMasked` event with the count and estimated token savings
/// is buffered.
///
/// Stored message indices can become stale because the history is
/// mutable through `messages_mut`; an index that no longer exists, or no
/// longer points at a tool result, is skipped instead of panicking.
fn mask_old_observations(&mut self) {
    let Some(masking_config) = self.config.masking else {
        return;
    };
    let force_mask = self
        .config
        .force_mask_iterations
        .as_ref()
        .and_then(|fm| fm.lock().ok())
        .map(|set| set.clone());
    let has_force_masks = force_mask.as_ref().is_some_and(|s| !s.is_empty());
    if !has_force_masks && self.iterations <= masking_config.max_iterations_to_keep {
        return;
    }
    let cutoff = self
        .iterations
        .saturating_sub(masking_config.max_iterations_to_keep);
    let mut masked_count: usize = 0;
    let mut tokens_saved: u32 = 0;
    for meta in &mut self.tool_result_meta {
        if meta.masked {
            continue;
        }
        let is_old = meta.iteration <= cutoff;
        let is_forced = force_mask
            .as_ref()
            .is_some_and(|s| s.contains(&meta.iteration));
        if !is_old && !is_forced {
            continue;
        }
        // Checked lookup: the caller may have shortened the history.
        let Some(msg) = self.params.messages.get(meta.message_index) else {
            continue;
        };
        let Some(ContentBlock::ToolResult(tr)) = msg.content.first() else {
            continue;
        };
        if tr.is_error {
            continue;
        }
        let content_tokens = crate::context::estimate_tokens(&tr.content);
        if content_tokens < masking_config.min_tokens_to_mask {
            continue;
        }
        let tool_call_id = tr.tool_call_id.clone();
        let placeholder = format!(
            "[Masked — tool result from iteration {iter}, ~{content_tokens} tokens. \
            Use result_cache tool if available, or re-invoke tool.]",
            iter = meta.iteration,
        );
        let placeholder_tokens = crate::context::estimate_tokens(&placeholder);
        // In bounds: the same index was just resolved by `get` above.
        self.params.messages[meta.message_index] = ChatMessage::tool_result_full(ToolResult {
            tool_call_id,
            content: placeholder,
            is_error: false,
        });
        meta.masked = true;
        masked_count += 1;
        // Saturate rather than overflow; the value is only an estimate.
        tokens_saved =
            tokens_saved.saturating_add(content_tokens.saturating_sub(placeholder_tokens));
    }
    if masked_count > 0 {
        self.events.push(LoopEvent::ObservationsMasked {
            masked_count,
            tokens_saved,
        });
    }
}
/// Stores `command` to be applied at the start of the next iteration.
///
/// Ignored once the loop has finished; a previously stored command is
/// replaced.
pub(crate) fn resume(&mut self, command: LoopCommand) {
    if self.finished {
        return;
    }
    self.pending_command = Some(command);
}
/// Read-only view of the conversation history held in `params.messages`.
pub(crate) fn messages(&self) -> &[ChatMessage] {
&self.params.messages
}
/// Mutable access to the conversation history held in `params.messages`.
///
/// NOTE(review): observation masking stores message indices in
/// `tool_result_meta`; removing or reordering messages through this
/// handle leaves those indices stale — confirm callers only append.
pub(crate) fn messages_mut(&mut self) -> &mut Vec<ChatMessage> {
&mut self.params.messages
}
/// Usage accumulated over all responses processed so far.
pub(crate) fn total_usage(&self) -> &Usage {
&self.total_usage
}
/// Number of iterations started so far.
pub(crate) fn iterations(&self) -> u32 {
self.iterations
}
/// Whether the loop has already produced a terminal outcome.
pub(crate) fn is_finished(&self) -> bool {
self.finished
}
/// Consumes the loop state and returns its final result.
///
/// When no result was recorded, an empty response with
/// `TerminationReason::Complete` and the current counters is returned.
pub(crate) fn into_result(self) -> ToolLoopResult {
    match self.final_result {
        Some(recorded) => recorded,
        None => ToolLoopResult {
            response: ChatResponse::empty(),
            iterations: self.iterations,
            total_usage: self.total_usage,
            termination_reason: TerminationReason::Complete,
        },
    }
}
}
pub(crate) async fn collect_stream(mut stream: ChatStream) -> Result<ChatResponse, LlmError> {
let mut text = String::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
let mut usage = Usage::default();
let mut stop_reason = StopReason::EndTurn;
while let Some(event) = stream.next().await {
match event? {
StreamEvent::TextDelta(t) => text.push_str(&t),
StreamEvent::ToolCallComplete { call, .. } => tool_calls.push(call),
StreamEvent::Usage(u) => usage += &u,
StreamEvent::Done { stop_reason: sr } => stop_reason = sr,
_ => {}
}
}
let mut content = Vec::new();
if !text.is_empty() {
content.push(ContentBlock::Text(text));
}
for call in tool_calls {
content.push(ContentBlock::ToolCall(call));
}
Ok(ChatResponse {
content,
usage,
stop_reason,
model: String::new(),
metadata: HashMap::new(),
})
}
impl ChatMessage {
    /// Wraps `result` in a message with the `Tool` role whose only content
    /// block is that tool result.
    pub fn tool_result_full(result: ToolResult) -> Self {
        let role = crate::chat::ChatRole::Tool;
        let content = vec![ContentBlock::ToolResult(result)];
        Self { role, content }
    }
}
/// Compact debug output: counters and flags only, never the message
/// history, configuration or buffered event contents.
impl<Ctx: LoopDepth + Send + Sync + 'static> std::fmt::Debug for LoopCore<Ctx> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_pending_command = self.pending_command.is_some();
        let buffered_events = self.events.len();

        let mut builder = f.debug_struct("LoopCore");
        builder.field("iterations", &self.iterations);
        builder.field("tool_calls_executed", &self.tool_calls_executed);
        builder.field("finished", &self.finished);
        builder.field("has_pending_command", &has_pending_command);
        builder.field("buffered_events", &buffered_events);
        builder.finish_non_exhaustive()
    }
}