use sha2::{Digest, Sha256};
use std::collections::HashMap;
use crate::config::LoopGuardConfig;
/// Verdict produced by [`LoopGuard::check`] and [`LoopGuard::record_outcome`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopGuardAction {
/// No repetition threshold was reached (also returned by `check` when the
/// guard is disabled or the batch is empty).
Allow,
/// A warn-level threshold was reached or a ping-pong pattern was detected.
Warn {
/// Human-readable description of what was detected.
reason: String,
/// Back-off hint in milliseconds. `check` fills this only for poll-like
/// calls; every other warning carries `None`.
suggested_delay_ms: Option<u64>,
},
/// A block-level threshold was reached.
Block { reason: String },
/// The running repetition total reached `global_circuit_breaker`.
CircuitBreak { total_repetitions: u32 },
}
/// Counters accumulated by a [`LoopGuard`] and exposed through [`LoopGuard::stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LoopGuardStats {
/// Number of `check` calls that got past the enabled / non-empty guard.
pub total_checks: u64,
/// Warnings issued by `check` and `record_outcome`, plus one per ping-pong
/// detection (counted even when that detection trips the circuit breaker).
pub warnings: u64,
/// `Block` results returned by `check` for repeated calls.
pub blocks: u64,
/// Number of times a `CircuitBreak` action was returned.
pub circuit_breaks: u64,
/// Number of period-2 / period-3 patterns detected in the call sequence.
pub ping_pong_detections: u64,
/// Number of times an outcome count reached `outcome_block_threshold`
/// (counted before the circuit breaker is consulted).
pub outcome_blocks: u64,
}
/// Borrowed signature of a single tool call: the pieces that are hashed and
/// pattern-matched by the guard.
#[derive(Debug, Clone, Copy)]
pub struct ToolCallSig<'a> {
/// Tool name.
pub name: &'a str,
/// Raw argument text; it is passed through `normalize_args` before hashing.
pub arguments: &'a str,
}
/// Tracks repeated tool-call batches and repeated outcomes, and decides
/// whether to allow, warn, block, or circuit-break.
#[derive(Debug, Clone)]
pub struct LoopGuard {
// Thresholds and switches read on every check.
config: LoopGuardConfig,
// Occurrences per call-batch hash; rebuilt from `call_sequence` on pruning.
call_counts: HashMap<String, u32>,
// Occurrences per outcome hash; rebuilt from `outcome_sequence` on pruning.
outcome_counts: HashMap<String, u32>,
// Call-batch hashes in arrival order (pruned once longer than `window_size`).
call_sequence: Vec<String>,
// Outcome hashes in arrival order (pruned once longer than `window_size`).
outcome_sequence: Vec<String>,
// Running total compared against `config.global_circuit_breaker`.
total_repetitions: u32,
// Counters returned by `stats()`.
stats: LoopGuardStats,
}
/// Lowercase substrings that mark a call as poll-like. `is_poll_command`
/// looks for them in the lowercased tool name and argument text.
const POLL_PATTERNS: &[&str] = &[
"status",
"poll",
"wait",
"docker ps",
"kubectl get",
"git status",
];
impl LoopGuard {
    /// Creates a guard with empty call/outcome history and zeroed statistics.
    pub fn new(config: LoopGuardConfig) -> Self {
        Self {
            config,
            call_counts: HashMap::new(),
            outcome_counts: HashMap::new(),
            call_sequence: Vec::new(),
            outcome_sequence: Vec::new(),
            total_repetitions: 0,
            stats: LoopGuardStats::default(),
        }
    }

    /// Evaluates one batch of tool calls before it is executed.
    ///
    /// The batch is hashed, recorded in the sliding window, and compared
    /// against the warn/block thresholds (both multiplied by
    /// `poll_multiplier` for poll-like batches). Ping-pong detection runs
    /// before the per-hash thresholds are applied.
    ///
    /// Returns [`LoopGuardAction::Allow`] when the guard is disabled, the
    /// batch is empty, or no threshold has been reached; otherwise the most
    /// severe applicable action.
    pub fn check(&mut self, calls: &[ToolCallSig<'_>]) -> LoopGuardAction {
        if !self.config.enabled || calls.is_empty() {
            return LoopGuardAction::Allow;
        }
        self.stats.total_checks += 1;

        let is_poll = is_poll_command(calls);
        let window = self.config.window_size as usize;
        let count = Self::record_occurrence(
            &mut self.call_sequence,
            &mut self.call_counts,
            hash_call_batch(calls),
            window,
        );

        let multiplier = if is_poll {
            self.config.poll_multiplier
        } else {
            1
        };
        // Saturate so that a large configured threshold/multiplier cannot
        // overflow (panic in debug, wrap to a tiny threshold in release).
        let warn_at = self.config.warn_threshold.saturating_mul(multiplier);
        let block_at = self.config.block_threshold.saturating_mul(multiplier);

        if let Some(action) = self.check_ping_pong() {
            return action;
        }
        if count < warn_at {
            return LoopGuardAction::Allow;
        }

        // The circuit breaker outranks both Block and Warn.
        if let Some(breaker) = self.register_repetition() {
            return breaker;
        }
        if count >= block_at {
            self.stats.blocks += 1;
            return LoopGuardAction::Block {
                reason: format!(
                    "tool call repeated {} times (threshold {})",
                    count, block_at
                ),
            };
        }

        // Only poll-like calls get a back-off suggestion.
        let suggested_delay_ms = is_poll.then(|| backoff_delay(count, warn_at));
        self.stats.warnings += 1;
        LoopGuardAction::Warn {
            reason: format!(
                "tool call repeated {} times (warn threshold {})",
                count, warn_at
            ),
            suggested_delay_ms,
        }
    }

    /// Records the result of an executed tool call and reports whether the
    /// same (name, params, result prefix) outcome keeps recurring.
    ///
    /// Returns `None` when the guard is disabled or the outcome count is
    /// below `outcome_warn_threshold`; otherwise `Some` of a `Warn`, `Block`,
    /// or `CircuitBreak` action.
    pub fn record_outcome(
        &mut self,
        name: &str,
        params: &str,
        result_prefix: &str,
    ) -> Option<LoopGuardAction> {
        if !self.config.enabled {
            return None;
        }

        let window = self.config.window_size as usize;
        let count = Self::record_occurrence(
            &mut self.outcome_sequence,
            &mut self.outcome_counts,
            hash_outcome(name, params, result_prefix),
            window,
        );

        if count >= self.config.outcome_block_threshold {
            self.stats.outcome_blocks += 1;
            if let Some(breaker) = self.register_repetition() {
                return Some(breaker);
            }
            return Some(LoopGuardAction::Block {
                reason: format!(
                    "identical outcome repeated {} times (threshold {})",
                    count, self.config.outcome_block_threshold
                ),
            });
        }

        if count >= self.config.outcome_warn_threshold {
            self.stats.warnings += 1;
            return Some(LoopGuardAction::Warn {
                reason: format!(
                    "identical outcome repeated {} times (warn threshold {})",
                    count, self.config.outcome_warn_threshold
                ),
                suggested_delay_ms: None,
            });
        }

        None
    }

    /// Returns the counters accumulated so far.
    pub fn stats(&self) -> &LoopGuardStats {
        &self.stats
    }

    /// Appends `hash` to `sequence`, prunes the window if it grew past
    /// `window` entries (`0` disables pruning), and returns how many times
    /// `hash` is now counted.
    fn record_occurrence(
        sequence: &mut Vec<String>,
        counts: &mut HashMap<String, u32>,
        hash: String,
        window: usize,
    ) -> u32 {
        sequence.push(hash.clone());
        let pruned = window > 0 && sequence.len() > window;
        if pruned {
            Self::prune_window(sequence, counts, window);
        }

        let slot = counts.entry(hash).or_insert(0);
        // Pruning rebuilds the counts from the retained sequence, which
        // already contains the entry pushed above; incrementing again would
        // count the current occurrence twice. The `== 0` case covers a window
        // so small that pruning retained nothing.
        if !pruned || *slot == 0 {
            *slot += 1;
        }
        *slot
    }

    /// Keeps only the newest `keep / 2` entries of `sequence` and rebuilds
    /// `counts` from what is left.
    fn prune_window(sequence: &mut Vec<String>, counts: &mut HashMap<String, u32>, keep: usize) {
        let retained = keep / 2;
        let drain_count = sequence.len().saturating_sub(retained);
        sequence.drain(..drain_count);

        counts.clear();
        for hash in sequence.iter() {
            *counts.entry(hash.clone()).or_insert(0) += 1;
        }
    }

    /// Adds one to the running repetition total and returns a `CircuitBreak`
    /// action once the total reaches `global_circuit_breaker`.
    fn register_repetition(&mut self) -> Option<LoopGuardAction> {
        self.total_repetitions = self.total_repetitions.saturating_add(1);
        if self.total_repetitions < self.config.global_circuit_breaker {
            return None;
        }
        self.stats.circuit_breaks += 1;
        Some(LoopGuardAction::CircuitBreak {
            total_repetitions: self.total_repetitions,
        })
    }

    /// Looks for an alternating pattern of period 2, then period 3, at the
    /// tail of the call sequence. Disabled when `ping_pong_min_repeats` is 0.
    fn check_ping_pong(&mut self) -> Option<LoopGuardAction> {
        let min_repeats = self.config.ping_pong_min_repeats as usize;
        if min_repeats == 0 {
            return None;
        }

        // Period 2 takes precedence over period 3; `has_periodic_pattern`
        // performs the length check itself.
        let period = [2usize, 3]
            .iter()
            .copied()
            .find(|&p| Self::has_periodic_pattern(&self.call_sequence, p, min_repeats))?;

        self.stats.ping_pong_detections += 1;
        self.stats.warnings += 1;
        if let Some(breaker) = self.register_repetition() {
            return Some(breaker);
        }
        Some(LoopGuardAction::Warn {
            reason: format!(
                "ping-pong pattern (period {}) detected over {} cycles",
                period, min_repeats
            ),
            suggested_delay_ms: None,
        })
    }

    /// Returns `true` when the last `period * min_repeats` entries of `seq`
    /// consist of the same `period`-long pattern repeated `min_repeats`
    /// times, and that pattern is not a single hash repeated.
    ///
    /// Returns `false` for a zero `period` or `min_repeats`, or when `seq`
    /// is too short.
    fn has_periodic_pattern(seq: &[String], period: usize, min_repeats: usize) -> bool {
        if period == 0 || min_repeats == 0 {
            return false;
        }
        // `checked_mul` guards against overflow for very large configured
        // repeat counts on narrow `usize` targets.
        let needed = match period.checked_mul(min_repeats) {
            Some(n) if n <= seq.len() => n,
            _ => return false,
        };

        let tail = &seq[seq.len() - needed..];
        let (pattern, rest) = tail.split_at(period);

        // A uniform pattern is plain repetition, handled by the per-hash
        // thresholds rather than ping-pong detection.
        if pattern.iter().all(|h| h == &pattern[0]) {
            return false;
        }

        rest.chunks_exact(period).all(|cycle| cycle == pattern)
    }
}
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// The input is returned unchanged when it already fits.
pub(crate) fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Walk backwards from the byte limit to the nearest boundary; index 0 is
    // always a boundary, so the search cannot fail.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
/// Produces a hex-encoded SHA-256 digest identifying a batch of tool calls.
///
/// Each call contributes its name, a newline, its normalised arguments, and
/// a `\n--\n` separator, in batch order.
fn hash_call_batch(batch: &[ToolCallSig<'_>]) -> String {
    let digest = batch
        .iter()
        .fold(Sha256::new(), |mut hasher, call| {
            hasher.update(call.name.as_bytes());
            hasher.update(b"\n");
            hasher.update(normalize_args(call.arguments).as_bytes());
            hasher.update(b"\n--\n");
            hasher
        })
        .finalize();
    hex::encode(digest)
}
/// Produces a hex-encoded SHA-256 digest identifying a tool-call outcome.
///
/// `params` is cut to at most 2048 bytes and then normalised; `result_prefix`
/// is cut to at most 1000 bytes. Both cuts respect UTF-8 boundaries.
fn hash_outcome(name: &str, params: &str, result_prefix: &str) -> String {
    // Prepare both bounded inputs up front, then feed the hasher in order.
    let normalized_params = normalize_args(truncate_utf8(params, 2048));
    let bounded_result = truncate_utf8(result_prefix, 1000);

    let mut hasher = Sha256::new();
    hasher.update(name.as_bytes());
    hasher.update(b"\n");
    hasher.update(normalized_params.as_bytes());
    hasher.update(b"\n--\n");
    hasher.update(bounded_result.as_bytes());
    hex::encode(hasher.finalize())
}
/// Canonicalises an argument string for hashing.
///
/// Text that parses as JSON is re-serialised through `serde_json::Value`;
/// anything else is returned with surrounding whitespace trimmed.
fn normalize_args(raw: &str) -> String {
    serde_json::from_str::<serde_json::Value>(raw)
        .map(|value| value.to_string())
        .unwrap_or_else(|_| raw.trim().to_string())
}
/// Reports whether any call in the batch looks like a polling command, i.e.
/// its lowercased name or arguments contain one of `POLL_PATTERNS`.
fn is_poll_command(calls: &[ToolCallSig<'_>]) -> bool {
    calls.iter().any(|call| {
        let name = call.name.to_lowercase();
        let args = call.arguments.to_lowercase();
        POLL_PATTERNS
            .iter()
            .any(|needle| name.contains(*needle) || args.contains(*needle))
    })
}
/// Computes the suggested back-off, in milliseconds, for a poll-like call.
///
/// The delay starts at one second when `count == warn_at`, doubles with each
/// further repetition, and is capped at 30 seconds. A `count` below
/// `warn_at` is treated like `count == warn_at`.
fn backoff_delay(count: u32, warn_at: u32) -> u64 {
    const MAX_DELAY_SECS: u64 = 30;

    let exponent = count.saturating_sub(warn_at);
    // `checked_shl` yields `None` once the shift reaches the bit width.
    // `wrapping_shl` would instead reduce the shift modulo 64, so exponents
    // 64..=68 fell back to 1-16 seconds instead of staying at the cap.
    let delay_secs = 1u64
        .checked_shl(exponent)
        .map_or(MAX_DELAY_SECS, |secs| secs.min(MAX_DELAY_SECS));
    delay_secs * 1000
}
// Unit tests for the loop guard: threshold escalation, ping-pong detection,
// outcome tracking, UTF-8 truncation, and sliding-window pruning.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for building a borrowed call signature.
fn sig<'a>(name: &'a str, args: &'a str) -> ToolCallSig<'a> {
ToolCallSig {
name,
arguments: args,
}
}
// Baseline configuration; individual tests override fields via struct update.
fn default_config() -> LoopGuardConfig {
LoopGuardConfig::default()
}
// A single call is allowed and counted as one check.
#[test]
fn test_allow_first_call() {
let mut guard = LoopGuard::new(default_config());
let action = guard.check(&[sig("web_search", r#"{"q":"rust"}"#)]);
assert_eq!(action, LoopGuardAction::Allow);
assert_eq!(guard.stats().total_checks, 1);
}
// With the default warn threshold (3), the third identical call warns.
#[test]
fn test_warn_on_repeated_calls() {
let mut guard = LoopGuard::new(default_config());
let call = [sig("web_search", r#"{"q":"rust"}"#)];
assert_eq!(guard.check(&call), LoopGuardAction::Allow);
assert_eq!(guard.check(&call), LoopGuardAction::Allow);
match guard.check(&call) {
LoopGuardAction::Warn { reason, .. } => {
assert!(reason.contains("repeated 3 times"));
}
other => panic!("expected Warn, got {other:?}"),
}
assert_eq!(guard.stats().warnings, 1);
}
// With the default block threshold (5), the fifth identical call is blocked.
#[test]
fn test_block_after_threshold() {
let mut guard = LoopGuard::new(default_config());
let call = [sig("shell", r#"{"command":"ls"}"#)];
for _ in 0..4 {
guard.check(&call);
}
match guard.check(&call) {
LoopGuardAction::Block { reason } => {
assert!(reason.contains("repeated 5 times"));
}
other => panic!("expected Block, got {other:?}"),
}
assert_eq!(guard.stats().blocks, 1);
}
// Repetitions are summed across different tools: two tools warn once each,
// and the third tool's repeat brings the total to the breaker limit of 3.
#[test]
fn test_circuit_breaker() {
let config = LoopGuardConfig {
global_circuit_breaker: 3,
warn_threshold: 2,
block_threshold: 100, ..default_config()
};
let mut guard = LoopGuard::new(config);
for i in 0..2 {
let name = format!("tool_{i}");
let call = [sig(&name, r#"{}"#)];
guard.check(&call); guard.check(&call); }
assert_eq!(guard.stats().warnings, 2);
assert_eq!(guard.stats().circuit_breaks, 0);
let call = [sig("tool_2", r#"{}"#)];
guard.check(&call); let action = guard.check(&call); assert!(
matches!(action, LoopGuardAction::CircuitBreak { .. }),
"expected CircuitBreak, got {action:?}"
);
assert_eq!(guard.stats().circuit_breaks, 1);
}
// A, B, A, B with min_repeats = 2 is reported as a period-2 ping-pong on
// the fourth call; warn_threshold is raised so only ping-pong can fire.
#[test]
fn test_ping_pong_detection_period2() {
let config = LoopGuardConfig {
ping_pong_min_repeats: 2,
warn_threshold: 100, ..default_config()
};
let mut guard = LoopGuard::new(config);
let call_a = [sig("read_file", r#"{"path":"a.txt"}"#)];
let call_b = [sig("write_file", r#"{"path":"b.txt"}"#)];
guard.check(&call_a);
guard.check(&call_b);
guard.check(&call_a);
match guard.check(&call_b) {
LoopGuardAction::Warn { reason, .. } => {
assert!(reason.contains("ping-pong"));
assert!(reason.contains("period 2"));
}
other => panic!("expected Warn with ping-pong, got {other:?}"),
}
assert_eq!(guard.stats().ping_pong_detections, 1);
}
// A, B, C, A, B, C with min_repeats = 2 is reported as period 3 on the
// sixth call.
#[test]
fn test_ping_pong_detection_period3() {
let config = LoopGuardConfig {
ping_pong_min_repeats: 2,
warn_threshold: 100,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call_a = [sig("tool_a", r#"{"x":1}"#)];
let call_b = [sig("tool_b", r#"{"x":2}"#)];
let call_c = [sig("tool_c", r#"{"x":3}"#)];
guard.check(&call_a);
guard.check(&call_b);
guard.check(&call_c);
guard.check(&call_a);
guard.check(&call_b);
match guard.check(&call_c) {
LoopGuardAction::Warn { reason, .. } => {
assert!(reason.contains("ping-pong"));
assert!(reason.contains("period 3"));
}
other => panic!("expected Warn with ping-pong period 3, got {other:?}"),
}
assert_eq!(guard.stats().ping_pong_detections, 1);
}
// Default outcome thresholds (warn 2, block 3): first identical outcome is
// silent, the second warns, the third blocks.
#[test]
fn test_outcome_aware_blocking() {
let mut guard = LoopGuard::new(default_config());
let r = guard.record_outcome("shell", r#"{"cmd":"ls"}"#, "file1.txt\nfile2.txt");
assert!(r.is_none());
let r = guard.record_outcome("shell", r#"{"cmd":"ls"}"#, "file1.txt\nfile2.txt");
match r {
Some(LoopGuardAction::Warn { reason, .. }) => {
assert!(reason.contains("identical outcome"));
}
other => panic!("expected Warn, got {other:?}"),
}
let r = guard.record_outcome("shell", r#"{"cmd":"ls"}"#, "file1.txt\nfile2.txt");
match r {
Some(LoopGuardAction::Block { reason }) => {
assert!(reason.contains("identical outcome"));
}
other => panic!("expected Block, got {other:?}"),
}
assert_eq!(guard.stats().outcome_blocks, 1);
}
// "git status" is poll-like, so the warn threshold becomes 2 * 3 = 6.
#[test]
fn test_poll_relaxation() {
let config = LoopGuardConfig {
warn_threshold: 2,
block_threshold: 4,
poll_multiplier: 3,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call = [sig("shell", r#"{"command":"git status"}"#)];
for _ in 0..5 {
assert_eq!(guard.check(&call), LoopGuardAction::Allow);
}
match guard.check(&call) {
LoopGuardAction::Warn { reason, .. } => {
assert!(reason.contains("repeated 6 times"));
}
other => panic!("expected Warn at call 6, got {other:?}"),
}
}
// Escalation order for one repeated call: Allow, Warn, Warn, Block.
#[test]
fn test_graduated_response() {
let config = LoopGuardConfig {
warn_threshold: 2,
block_threshold: 4,
global_circuit_breaker: 100,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call = [sig("tool", r#"{}"#)];
assert_eq!(guard.check(&call), LoopGuardAction::Allow); assert!(matches!(guard.check(&call), LoopGuardAction::Warn { .. })); assert!(matches!(guard.check(&call), LoopGuardAction::Warn { .. })); assert!(matches!(guard.check(&call), LoopGuardAction::Block { .. })); }
// Poll-like warnings carry a delay that doubles per repetition: 1s, 2s, 4s.
// poll_multiplier is 1 so the thresholds are not relaxed.
#[test]
fn test_backoff_schedule() {
let config = LoopGuardConfig {
warn_threshold: 2,
block_threshold: 100,
poll_multiplier: 1, ..default_config()
};
let mut guard = LoopGuard::new(config);
let call = [sig("shell", r#"{"command":"docker ps"}"#)];
guard.check(&call);
match guard.check(&call) {
LoopGuardAction::Warn {
suggested_delay_ms, ..
} => {
assert_eq!(suggested_delay_ms, Some(1000));
}
other => panic!("expected Warn, got {other:?}"),
}
match guard.check(&call) {
LoopGuardAction::Warn {
suggested_delay_ms, ..
} => {
assert_eq!(suggested_delay_ms, Some(2000));
}
other => panic!("expected Warn, got {other:?}"),
}
match guard.check(&call) {
LoopGuardAction::Warn {
suggested_delay_ms, ..
} => {
assert_eq!(suggested_delay_ms, Some(4000));
}
other => panic!("expected Warn, got {other:?}"),
}
}
// With warn_threshold 1 and block_threshold 3, calls 1-2 warn and calls 3-4
// block. The three identical outcomes then add one warning and one
// outcome block under the default outcome thresholds.
#[test]
fn test_stats_tracking() {
let config = LoopGuardConfig {
warn_threshold: 1,
block_threshold: 3,
global_circuit_breaker: 100,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call = [sig("tool", r#"{}"#)];
guard.check(&call); guard.check(&call); guard.check(&call); guard.check(&call);
let s = guard.stats();
assert_eq!(s.total_checks, 4);
assert_eq!(s.warnings, 2);
assert_eq!(s.blocks, 2);
assert_eq!(s.circuit_breaks, 0);
guard.record_outcome("t", "{}", "same");
guard.record_outcome("t", "{}", "same");
guard.record_outcome("t", "{}", "same");
let s = guard.stats();
assert_eq!(s.outcome_blocks, 1);
assert!(s.warnings >= 3); }
// A disabled guard allows everything and does not even count checks.
#[test]
fn test_disabled_guard_allows_all() {
let config = LoopGuardConfig {
enabled: false,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call = [sig("tool", r#"{}"#)];
for _ in 0..100 {
assert_eq!(guard.check(&call), LoopGuardAction::Allow);
}
assert!(guard.record_outcome("t", "{}", "x").is_none());
assert_eq!(guard.stats().total_checks, 0);
}
// Four distinct calls must not be mistaken for a periodic pattern.
#[test]
fn test_no_false_ping_pong_on_varied_calls() {
let config = LoopGuardConfig {
ping_pong_min_repeats: 2,
warn_threshold: 100,
..default_config()
};
let mut guard = LoopGuard::new(config);
guard.check(&[sig("tool_a", r#"{"x":1}"#)]);
guard.check(&[sig("tool_b", r#"{"x":2}"#)]);
guard.check(&[sig("tool_c", r#"{"x":3}"#)]);
let action = guard.check(&[sig("tool_d", r#"{"x":4}"#)]);
assert_eq!(action, LoopGuardAction::Allow);
assert_eq!(guard.stats().ping_pong_detections, 0);
}
// The back-off delay is capped at 30 seconds for large repetition counts.
#[test]
fn test_backoff_delay_caps_at_30s() {
assert_eq!(backoff_delay(100, 2), 30_000);
assert_eq!(backoff_delay(50, 2), 30_000);
assert_eq!(backoff_delay(3, 2), 2_000); assert_eq!(backoff_delay(2, 2), 1_000); }
// Pins the default configuration values the other tests rely on.
#[test]
fn test_config_defaults() {
let config = LoopGuardConfig::default();
assert!(config.enabled);
assert_eq!(config.warn_threshold, 3);
assert_eq!(config.block_threshold, 5);
assert_eq!(config.global_circuit_breaker, 30);
assert_eq!(config.ping_pong_min_repeats, 3);
assert_eq!(config.poll_multiplier, 3);
assert_eq!(config.outcome_warn_threshold, 2);
assert_eq!(config.outcome_block_threshold, 3);
assert_eq!(config.window_size, 200);
}
// ASCII input: every byte offset is a character boundary.
#[test]
fn test_truncate_utf8_ascii() {
assert_eq!(truncate_utf8("hello", 3), "hel");
assert_eq!(truncate_utf8("hello", 10), "hello");
assert_eq!(truncate_utf8("hello", 5), "hello");
}
// Each character here is 3 bytes in UTF-8, so limits of 7 and 8 fall inside
// the third character and must round down to 6 bytes.
#[test]
fn test_truncate_utf8_multibyte() {
let s = "\u{4f60}\u{597d}\u{4e16}\u{754c}"; assert_eq!(truncate_utf8(s, 6), "\u{4f60}\u{597d}"); assert_eq!(truncate_utf8(s, 7), "\u{4f60}\u{597d}");
assert_eq!(truncate_utf8(s, 8), "\u{4f60}\u{597d}");
assert_eq!(truncate_utf8(s, 9), "\u{4f60}\u{597d}\u{4e16}");
}
// Empty input stays empty for any limit.
#[test]
fn test_truncate_utf8_empty() {
assert_eq!(truncate_utf8("", 0), "");
assert_eq!(truncate_utf8("", 10), "");
}
// The leading emoji is 4 bytes, so a 3-byte limit leaves nothing.
#[test]
fn test_truncate_utf8_emoji() {
let s = "\u{1f600}abc"; assert_eq!(truncate_utf8(s, 3), ""); assert_eq!(truncate_utf8(s, 4), "\u{1f600}");
assert_eq!(truncate_utf8(s, 5), "\u{1f600}a");
}
// ping_pong_min_repeats = 0 disables detection entirely, even for a long
// strictly alternating sequence.
#[test]
fn test_ping_pong_min_repeats_zero_no_panic() {
let config = LoopGuardConfig {
ping_pong_min_repeats: 0,
warn_threshold: 100,
block_threshold: 200,
global_circuit_breaker: 1000,
..default_config()
};
let mut guard = LoopGuard::new(config);
let call_a = [sig("read_file", r#"{"path":"a.txt"}"#)];
let call_b = [sig("write_file", r#"{"path":"b.txt"}"#)];
for _ in 0..10 {
assert_eq!(guard.check(&call_a), LoopGuardAction::Allow);
assert_eq!(guard.check(&call_b), LoopGuardAction::Allow);
}
assert_eq!(guard.stats().ping_pong_detections, 0);
}
// A zero period is rejected rather than dividing the tail into nothing.
#[test]
fn test_has_periodic_pattern_zero_period() {
let seq: Vec<String> = vec!["a".into(), "b".into(), "a".into(), "b".into()];
assert!(!LoopGuard::has_periodic_pattern(&seq, 0, 2));
}
// Zero required repeats is rejected as well.
#[test]
fn test_has_periodic_pattern_zero_min_repeats() {
let seq: Vec<String> = vec!["a".into(), "b".into(), "a".into(), "b".into()];
assert!(!LoopGuard::has_periodic_pattern(&seq, 2, 0));
}
// Both inputs exceed the truncation limits used by hash_outcome (2048 and
// 1000 bytes) and consist solely of multi-byte characters.
#[test]
fn test_hash_outcome_multibyte_no_panic() {
let long_params = "\u{4f60}\u{597d}".repeat(1500); let long_result = "\u{1f600}".repeat(500); let h = hash_outcome("tool", &long_params, &long_result);
assert!(!h.is_empty());
}
// 300 unique calls/outcomes with window_size 100: sequences and count maps
// must stay within the window.
#[test]
fn test_window_pruning_prevents_unbounded_growth() {
let config = LoopGuardConfig {
window_size: 100,
warn_threshold: 200, block_threshold: 300,
global_circuit_breaker: 1000,
..default_config()
};
let mut guard = LoopGuard::new(config);
for i in 0..300 {
let name = format!("tool_{i}");
guard.check(&[sig(&name, r#"{}"#)]);
guard.record_outcome(&name, r#"{}"#, &format!("result_{i}"));
}
assert!(
guard.call_sequence.len() <= 100,
"call_sequence len {} exceeds window_size 100",
guard.call_sequence.len()
);
assert!(
guard.outcome_sequence.len() <= 100,
"outcome_sequence len {} exceeds window_size 100",
guard.outcome_sequence.len()
);
assert!(
guard.call_counts.len() <= 100,
"call_counts len {} exceeds window_size 100",
guard.call_counts.len()
);
assert!(
guard.outcome_counts.len() <= 100,
"outcome_counts len {} exceeds window_size 100",
guard.outcome_counts.len()
);
}
// The target call recurs only once every 50 calls; pruning must keep its
// count low enough that it is never blocked.
#[test]
fn test_window_pruning_resets_false_positives() {
let config = LoopGuardConfig {
window_size: 100,
warn_threshold: 3,
block_threshold: 5,
global_circuit_breaker: 1000,
..default_config()
};
let mut guard = LoopGuard::new(config);
let target_call = [sig("target_tool", r#"{"q":"test"}"#)];
for batch in 0..10 {
let action = guard.check(&target_call);
if batch >= 3 {
assert!(
!matches!(action, LoopGuardAction::Block { .. }),
"target_tool should not be blocked at batch {batch}"
);
}
// 49 unique filler calls between consecutive target calls.
for j in 0..49 {
let filler = format!("filler_{}_{}", batch, j);
guard.check(&[sig(&filler, r#"{}"#)]);
}
}
}
// window_size = 0 turns pruning off, so every call stays in the sequence.
#[test]
fn test_window_size_zero_disables_pruning() {
let config = LoopGuardConfig {
window_size: 0,
warn_threshold: 200,
block_threshold: 300,
global_circuit_breaker: 1000,
..default_config()
};
let mut guard = LoopGuard::new(config);
for i in 0..150 {
let name = format!("tool_{i}");
guard.check(&[sig(&name, r#"{}"#)]);
}
assert_eq!(guard.call_sequence.len(), 150);
}
// 25 unique outcomes with window_size 20: the outcome sequence is pruned
// and no outcome block is recorded.
#[test]
fn test_outcome_window_pruning() {
let config = LoopGuardConfig {
window_size: 20,
outcome_warn_threshold: 5,
outcome_block_threshold: 10,
global_circuit_breaker: 1000,
..default_config()
};
let mut guard = LoopGuard::new(config);
for i in 0..25 {
guard.record_outcome("tool", &format!(r#"{{"i":{i}}}"#), &format!("res_{i}"));
}
assert!(
guard.outcome_sequence.len() <= 20,
"outcome_sequence len {} exceeds window 20",
guard.outcome_sequence.len()
);
assert_eq!(guard.stats().outcome_blocks, 0);
}
}