use std::collections::HashMap;
/// Minimum length of the trailing run of same-named tool calls that
/// [`ToolStatsAccumulator::detected_loop`] reports as a loop.
pub const TOOL_LOOP_THRESHOLD: usize = 5;
/// One tool invocation recorded by [`ToolStatsAccumulator::record_call`].
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
/// Name of the invoked tool. Loop detection compares calls by this name only;
/// arguments are ignored.
pub tool_name: String,
/// Caller-supplied serialized arguments, stored as given and never parsed by
/// the accumulator (presumably JSON, per the field name — confirm with callers).
pub args_json: Option<String>,
/// Zero-based position of this call among the calls recorded since the
/// accumulator was created or last reset with `reset_turn`.
pub turn_local_index: usize,
}
/// Outcome of one tool invocation recorded by
/// [`ToolStatsAccumulator::record_result`].
#[derive(Debug, Clone)]
pub struct ToolResultRecord {
/// Name of the tool that produced this result.
pub tool_name: String,
/// Whether the invocation succeeded; `failure_count` counts records where
/// this is `false`.
pub success: bool,
/// Caller-reported duration, summed by `total_duration_ms`. Presumably in
/// milliseconds per the field name; it is not measured or validated here.
pub duration_ms: u64,
/// Optional caller-supplied error text, stored as given; none of the
/// accumulator's aggregates inspect it. NOTE(review): presumably only set when
/// `success` is false — not enforced.
pub error_summary: Option<String>,
/// Zero-based position of this result among the results recorded since the
/// accumulator was created or last reset. Counted independently of call indices.
pub turn_local_index: usize,
}
/// Per-turn accumulator of tool calls and tool results.
///
/// Calls and results are kept in two separate lists, each in recording order,
/// and are not paired with each other. Both lists are cleared by `reset_turn`.
#[derive(Debug, Clone, Default)]
pub struct ToolStatsAccumulator {
// Calls recorded this turn, oldest first.
calls: Vec<ToolCallRecord>,
// Results recorded this turn, oldest first.
results: Vec<ToolResultRecord>,
}
impl ToolStatsAccumulator {
pub fn new() -> Self {
Self::default()
}
pub fn reset_turn(&mut self) {
self.calls.clear();
self.results.clear();
}
pub fn record_call(&mut self, tool_name: String, args_json: Option<String>) {
let turn_local_index = self.calls.len();
self.calls.push(ToolCallRecord {
tool_name,
args_json,
turn_local_index,
});
}
pub fn record_result(
&mut self,
tool_name: String,
success: bool,
duration_ms: u64,
error_summary: Option<String>,
) {
let turn_local_index = self.results.len();
self.results.push(ToolResultRecord {
tool_name,
success,
duration_ms,
error_summary,
turn_local_index,
});
}
pub fn total_calls(&self) -> usize {
self.calls.len()
}
pub fn total_results(&self) -> usize {
self.results.len()
}
pub fn counts_by_tool(&self) -> HashMap<String, usize> {
let mut counts: HashMap<String, usize> = HashMap::new();
for call in &self.calls {
*counts.entry(call.tool_name.clone()).or_insert(0) += 1;
}
counts
}
pub fn detected_loop(&self) -> Option<(String, usize)> {
if self.calls.len() < TOOL_LOOP_THRESHOLD {
return None;
}
let last_tool = &self.calls.last()?.tool_name;
let mut consecutive = 0usize;
for call in self.calls.iter().rev() {
if &call.tool_name == last_tool {
consecutive += 1;
} else {
break;
}
}
if consecutive >= TOOL_LOOP_THRESHOLD {
Some((last_tool.clone(), consecutive))
} else {
None
}
}
pub fn total_duration_ms(&self) -> u64 {
self.results.iter().map(|r| r.duration_ms).sum()
}
pub fn failure_count(&self) -> usize {
self.results.iter().filter(|r| !r.success).count()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `ToolStatsAccumulator`. Run lengths are expressed relative
    // to `TOOL_LOOP_THRESHOLD` instead of hard-coded counts, so the suite keeps
    // probing the intended boundaries if the threshold is retuned.
    use super::*;

    /// Records `n` argument-less calls to `tool`.
    fn record_n(acc: &mut ToolStatsAccumulator, tool: &str, n: usize) {
        for _ in 0..n {
            acc.record_call(tool.into(), None);
        }
    }

    #[test]
    fn empty_has_no_loop() {
        let acc = ToolStatsAccumulator::new();
        assert!(acc.detected_loop().is_none());
        assert_eq!(acc.total_calls(), 0);
    }

    #[test]
    fn empty_accumulator_has_zero_aggregates() {
        let acc = ToolStatsAccumulator::new();
        assert_eq!(acc.total_results(), 0);
        assert_eq!(acc.total_duration_ms(), 0);
        assert_eq!(acc.failure_count(), 0);
        assert!(acc.counts_by_tool().is_empty());
    }

    #[test]
    fn below_threshold_has_no_loop() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD - 1);
        assert!(acc.detected_loop().is_none());
    }

    #[test]
    fn threshold_same_tool_detects_loop() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD);
        let (tool, count) = acc.detected_loop().expect("loop should fire");
        assert_eq!(tool, "search");
        assert_eq!(count, TOOL_LOOP_THRESHOLD);
    }

    #[test]
    fn interleaved_tools_break_loop() {
        // Each side of the interruption is one short of the threshold, so only
        // the interrupting call keeps the combined same-tool count from firing.
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD - 1);
        acc.record_call("db.query".into(), None);
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD - 1);
        assert!(acc.detected_loop().is_none());
    }

    #[test]
    fn reset_turn_clears_state() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD);
        acc.record_result("search".into(), false, 10, Some("boom".into()));
        assert!(acc.detected_loop().is_some());
        acc.reset_turn();
        assert!(acc.detected_loop().is_none());
        assert_eq!(acc.total_calls(), 0);
        // Results are part of the turn too and must be cleared alongside calls.
        assert_eq!(acc.total_results(), 0);
        assert_eq!(acc.failure_count(), 0);
        assert_eq!(acc.total_duration_ms(), 0);
    }

    #[test]
    fn turn_local_indices_are_sequential_and_restart_after_reset() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", 3);
        acc.record_result("search".into(), true, 1, None);
        let call_indices: Vec<usize> = acc.calls.iter().map(|c| c.turn_local_index).collect();
        assert_eq!(call_indices, vec![0, 1, 2]);
        // Result indices are counted independently of call indices.
        assert_eq!(acc.results[0].turn_local_index, 0);
        acc.reset_turn();
        acc.record_call("search".into(), None);
        assert_eq!(acc.calls[0].turn_local_index, 0);
    }

    #[test]
    fn counts_by_tool_aggregates_correctly() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", 2);
        acc.record_call("db.query".into(), None);
        let counts = acc.counts_by_tool();
        assert_eq!(counts.len(), 2);
        assert_eq!(counts.get("search"), Some(&2));
        assert_eq!(counts.get("db.query"), Some(&1));
    }

    #[test]
    fn duration_and_failure_counters() {
        let mut acc = ToolStatsAccumulator::new();
        acc.record_result("search".into(), true, 100, None);
        acc.record_result("db.query".into(), false, 250, Some("timeout".into()));
        acc.record_result("search".into(), true, 50, None);
        assert_eq!(acc.total_duration_ms(), 400);
        assert_eq!(acc.failure_count(), 1);
        assert_eq!(acc.total_results(), 3);
    }

    #[test]
    fn loop_fires_on_tail_run_after_mixed_prefix() {
        let mut acc = ToolStatsAccumulator::new();
        acc.record_call("explore".into(), None);
        acc.record_call("db.query".into(), None);
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD);
        let (tool, count) = acc
            .detected_loop()
            .expect("tail same-tool run should fire regardless of prefix");
        assert_eq!(tool, "search");
        assert_eq!(count, TOOL_LOOP_THRESHOLD);
    }

    #[test]
    fn loop_fires_on_same_tool_with_diverging_args() {
        let mut acc = ToolStatsAccumulator::new();
        for i in 0..TOOL_LOOP_THRESHOLD {
            acc.record_call(
                "search_orders".into(),
                Some(format!("{{\"user_id\":42,\"attempt\":{i}}}")),
            );
        }
        let (tool, count) = acc
            .detected_loop()
            .expect("same name + varying args is still a loop");
        assert_eq!(tool, "search_orders");
        assert_eq!(count, TOOL_LOOP_THRESHOLD);
    }

    #[test]
    fn over_threshold_returns_actual_count_not_threshold() {
        let depth = TOOL_LOOP_THRESHOLD + 3;
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", depth);
        let (_, count) = acc.detected_loop().expect("well over threshold");
        assert_eq!(count, depth, "count reflects actual run depth, not threshold");
    }

    #[test]
    fn resolved_loop_followed_by_short_tail_does_not_fire() {
        let mut acc = ToolStatsAccumulator::new();
        record_n(&mut acc, "search", TOOL_LOOP_THRESHOLD);
        // The longest tail that is still below the threshold.
        record_n(&mut acc, "db.query", TOOL_LOOP_THRESHOLD - 1);
        assert!(
            acc.detected_loop().is_none(),
            "historical loops that resolved are explicitly out-of-scope"
        );
    }

    #[test]
    fn alternating_high_volume_tools_do_not_fire() {
        let mut acc = ToolStatsAccumulator::new();
        for i in 0..(4 * TOOL_LOOP_THRESHOLD) {
            let tool = if i % 2 == 0 { "search" } else { "db.query" };
            acc.record_call(tool.into(), None);
        }
        assert!(
            acc.detected_loop().is_none(),
            "alternating A/B is not a same-tool consecutive loop by design"
        );
    }

    #[test]
    fn exact_threshold_boundary_both_sides() {
        let mut under = ToolStatsAccumulator::new();
        record_n(&mut under, "t", TOOL_LOOP_THRESHOLD - 1);
        assert!(
            under.detected_loop().is_none(),
            "THRESHOLD - 1 must not fire"
        );
        let mut at = ToolStatsAccumulator::new();
        record_n(&mut at, "t", TOOL_LOOP_THRESHOLD);
        assert!(at.detected_loop().is_some(), "THRESHOLD must fire");
    }
}