use super::AgentResult;
use crate::llm::{Message, TokenUsage};
use crate::verification::VerificationReport;
use serde_json::Value;
use std::time::Instant;
/// Maximum number of recent tool-call signatures retained by
/// `ExecutionLoopState` for duplicate-call detection; once exceeded, the
/// oldest entries are dropped first.
const RECENT_TOOL_SIGNATURE_LIMIT: usize = 8;
/// Mutable bookkeeping for a single run of the agent's execution loop.
///
/// Holds the running transcript and token usage, the counters used to enforce
/// turn / parse-retry / continuation limits, a bounded window of recent
/// tool-call signatures used for duplicate-call detection, and the start time
/// used for wall-clock timeout checks. Consumed by `finish` to build the final
/// `AgentResult`.
pub(super) struct ExecutionLoopState {
    /// Conversation transcript, seeded with a copy of the caller's history.
    pub(super) messages: Vec<Message>,
    /// Sum of token usage recorded via `record_usage`.
    pub(super) total_usage: TokenUsage,
    /// Number of tool calls recorded via `record_tool_call`.
    pub(super) tool_calls_count: usize,
    /// Verification reports handed back in the final result
    /// (presumably appended by the caller; nothing in this file pushes to it).
    pub(super) verification_reports: Vec<VerificationReport>,
    /// Current turn number; 0 before the first `next_turn` call.
    turn: usize,
    /// Consecutive malformed-tool-argument errors since the last reset.
    parse_error_count: u32,
    /// Continuation prompts injected so far.
    continuation_count: u32,
    /// Newest-last window of `name:args => ok|error` entries, capped at
    /// `RECENT_TOOL_SIGNATURE_LIMIT`.
    recent_tool_signatures: Vec<String>,
    /// When this state was created; the reference point for timeouts.
    execution_start: Instant,
}
/// Result of recording one malformed-tool-arguments error.
///
/// Produced by `ExecutionLoopState::record_parse_error`.
#[derive(Debug)]
pub(super) struct ParseErrorOutcome {
    /// Error text to report back for the failed call (`"Error: <parse error>"`).
    pub(super) output: String,
    /// Number of consecutive parse errors, including this one.
    pub(super) count: u32,
    /// Set once the retry budget is exhausted; the loop should stop with this message.
    pub(super) fatal_message: Option<String>,
}
impl ExecutionLoopState {
    /// Separator between a call's `name:args` key and its `ok`/`error` outcome
    /// in `recent_tool_signatures` entries.
    const OUTCOME_SEPARATOR: &'static str = " => ";

    /// Creates a fresh loop state seeded with a copy of `history`.
    ///
    /// All counters start at zero and the execution clock starts now.
    pub(super) fn new(history: &[Message]) -> Self {
        Self {
            messages: history.to_vec(),
            total_usage: TokenUsage::default(),
            tool_calls_count: 0,
            verification_reports: Vec::new(),
            turn: 0,
            parse_error_count: 0,
            continuation_count: 0,
            recent_tool_signatures: Vec::new(),
            execution_start: Instant::now(),
        }
    }

    /// Advances the turn counter and returns the new turn number (1 on the first call).
    pub(super) fn next_turn(&mut self) -> usize {
        self.turn += 1;
        self.turn
    }

    /// Number of continuation prompts injected so far.
    pub(super) fn continuation_count(&self) -> u32 {
        self.continuation_count
    }

    /// Returns a timeout message if more than `max_time_ms` milliseconds have
    /// elapsed since this state was created.
    ///
    /// Returns `None` when no limit is configured or the limit has not been exceeded.
    pub(super) fn check_execution_timeout(&self, max_time_ms: Option<u64>) -> Option<String> {
        let max_time_ms = max_time_ms?;
        let elapsed_ms = self.elapsed_ms();
        if elapsed_ms <= max_time_ms {
            return None;
        }
        Some(format!(
            "Execution timeout after {} seconds (limit: {} seconds). Completed {} turns.",
            elapsed_ms / 1000,
            max_time_ms / 1000,
            // The current turn is still in progress, so it is not counted as completed.
            self.turn.saturating_sub(1)
        ))
    }

    /// Milliseconds elapsed since this state was created, saturating at `u64::MAX`.
    pub(super) fn elapsed_ms(&self) -> u64 {
        // `as_millis` is u128; clamp instead of silently truncating with `as`.
        self.execution_start
            .elapsed()
            .as_millis()
            .min(u128::from(u64::MAX)) as u64
    }

    /// Returns an error message once the turn counter exceeds `max_tool_rounds`.
    pub(super) fn turn_limit_error(&self, max_tool_rounds: usize) -> Option<String> {
        (self.turn > max_tool_rounds)
            .then(|| format!("Max tool rounds ({}) exceeded", max_tool_rounds))
    }

    /// Adds one response's token usage to the running total.
    pub(super) fn record_usage(&mut self, usage: &TokenUsage) {
        self.total_usage.prompt_tokens += usage.prompt_tokens;
        self.total_usage.completion_tokens += usage.completion_tokens;
        self.total_usage.total_tokens += usage.total_tokens;
    }

    /// Increments the executed tool-call counter.
    pub(super) fn record_tool_call(&mut self) {
        self.tool_calls_count += 1;
    }

    /// Detects a repeated tool call.
    ///
    /// Counts how many entries in the recent-signature window (see
    /// `remember_tool_signature`) were calls to `tool_name` with exactly `args`.
    /// Both successful and failed calls count. If that count reaches `threshold`,
    /// returns the count together with an abort message; otherwise `None`.
    pub(super) fn duplicate_tool_call(
        &self,
        tool_name: &str,
        args: &Value,
        threshold: u32,
    ) -> Option<(usize, String)> {
        let signature = Self::tool_signature(tool_name, args);
        // Compare the `name:args` part exactly. A plain prefix test would also
        // match calls whose serialized args merely extend these ones
        // (e.g. args `1` vs `12`). Splitting from the right keeps this correct
        // even when the args themselves contain the separator text.
        let duplicate_count = self
            .recent_tool_signatures
            .iter()
            .filter(|entry| {
                matches!(
                    entry.rsplit_once(Self::OUTCOME_SEPARATOR),
                    Some((base, _)) if base == signature
                )
            })
            .count();
        if duplicate_count < threshold as usize {
            return None;
        }
        Some((
            duplicate_count,
            format!(
                "Tool '{}' has been called {} times with identical arguments. \
                 Aborting to prevent infinite loop. Consider modifying your approach.",
                tool_name, duplicate_count
            ),
        ))
    }

    /// Records one more consecutive malformed-tool-arguments error.
    ///
    /// Returns the error text to report back (`"Error: <parse_error>"`), the
    /// running consecutive count, and a fatal message once the count exceeds
    /// `max_parse_retries`.
    pub(super) fn record_parse_error(
        &mut self,
        parse_error: &str,
        max_parse_retries: u32,
    ) -> ParseErrorOutcome {
        self.parse_error_count = self.parse_error_count.saturating_add(1);
        let output = format!("Error: {}", parse_error);
        let fatal_message = (self.parse_error_count > max_parse_retries).then(|| {
            format!(
                "LLM produced malformed tool arguments {} time(s) in a row \
                 (max_parse_retries={}); giving up",
                self.parse_error_count, max_parse_retries
            )
        });
        ParseErrorOutcome {
            output,
            count: self.parse_error_count,
            fatal_message,
        }
    }

    /// Clears the consecutive parse-error counter.
    pub(super) fn reset_parse_errors(&mut self) {
        self.parse_error_count = 0;
    }

    /// Returns a copy of the recent-signature window, oldest entry first.
    pub(super) fn recent_tool_signatures(&self) -> Vec<String> {
        self.recent_tool_signatures.clone()
    }

    /// Appends `name:args => ok|error` for a completed tool call to the
    /// recent-signature window.
    ///
    /// Evicts the oldest entries so that at most `RECENT_TOOL_SIGNATURE_LIMIT`
    /// remain.
    pub(super) fn remember_tool_signature(
        &mut self,
        tool_name: &str,
        args: &Value,
        is_error: bool,
    ) {
        let outcome = if is_error { "error" } else { "ok" };
        // Reuse `tool_signature` so stored keys always match what
        // `duplicate_tool_call` compares against.
        let entry = format!(
            "{}{}{}",
            Self::tool_signature(tool_name, args),
            Self::OUTCOME_SEPARATOR,
            outcome
        );
        self.recent_tool_signatures.push(entry);
        if self.recent_tool_signatures.len() > RECENT_TOOL_SIGNATURE_LIMIT {
            let overflow = self.recent_tool_signatures.len() - RECENT_TOOL_SIGNATURE_LIMIT;
            self.recent_tool_signatures.drain(..overflow);
        }
    }

    /// Decides whether to inject a continuation prompt, consuming one unit of
    /// the continuation budget when it does.
    ///
    /// Returns `true` only when:
    /// - continuation is `enabled`,
    /// - fewer than `max_continuation_turns` continuations have been used,
    /// - the current turn is below `max_tool_rounds`, and
    /// - the reply `looks_incomplete`.
    pub(super) fn should_inject_continuation(
        &mut self,
        looks_incomplete: bool,
        enabled: bool,
        max_continuation_turns: u32,
        max_tool_rounds: usize,
    ) -> bool {
        if enabled
            && self.continuation_count < max_continuation_turns
            && self.turn < max_tool_rounds
            && looks_incomplete
        {
            self.continuation_count += 1;
            return true;
        }
        false
    }

    /// Consumes the state and packages it, together with the final `text`,
    /// into an `AgentResult`.
    pub(super) fn finish(self, text: String) -> AgentResult {
        AgentResult {
            text,
            messages: self.messages,
            usage: self.total_usage,
            tool_calls_count: self.tool_calls_count,
            verification_reports: self.verification_reports,
        }
    }

    /// Builds the `name:args` key identifying a tool call, with `args`
    /// serialized as compact JSON (an empty string if serialization fails).
    fn tool_signature(tool_name: &str, args: &Value) -> String {
        format!(
            "{}:{}",
            tool_name,
            serde_json::to_string(args).unwrap_or_default()
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A successful call and a failed call with identical arguments both count
    /// toward the duplicate threshold.
    #[test]
    fn duplicate_tool_call_uses_recent_success_and_error_signatures() {
        let mut loop_state = ExecutionLoopState::new(&[]);
        let path_args = json!({"path":"src/lib.rs"});

        loop_state.remember_tool_signature("read_file", &path_args, false);
        loop_state.remember_tool_signature("read_file", &path_args, true);

        let (count, message) = loop_state
            .duplicate_tool_call("read_file", &path_args, 2)
            .expect("two identical calls should reach a threshold of 2");
        assert_eq!(count, 2);
        assert!(message.contains("read_file"));
    }

    /// With a retry budget of 1, the first error is tolerated and the second is fatal.
    #[test]
    fn parse_error_becomes_fatal_after_retry_budget() {
        let mut loop_state = ExecutionLoopState::new(&[]);

        let within_budget = loop_state.record_parse_error("bad json", 1);
        assert_eq!(within_budget.count, 1);
        assert!(within_budget.fatal_message.is_none());

        let over_budget = loop_state.record_parse_error("bad json", 1);
        assert_eq!(over_budget.count, 2);
        let fatal = over_budget
            .fatal_message
            .expect("second consecutive error exceeds a budget of 1");
        assert!(fatal.contains("max_parse_retries=1"));
    }

    /// Continuations are granted only for incomplete replies and only while budget remains.
    #[test]
    fn continuation_budget_is_consumed_only_when_incomplete() {
        let mut loop_state = ExecutionLoopState::new(&[]);
        loop_state.next_turn();

        // The first incomplete reply uses up the single allowed continuation.
        assert!(loop_state.should_inject_continuation(true, true, 1, 5));
        // The budget of 1 is now exhausted.
        assert!(!loop_state.should_inject_continuation(true, true, 1, 5));
        // Budget remains (1 < 2), but the reply looks complete.
        assert!(!loop_state.should_inject_continuation(false, true, 2, 5));
    }
}