Skip to main content

koda_core/
loop_guard.rs

1//! Loop detection and hard-cap user prompt for the inference loop.
2//!
3//! Tracks recent tool call fingerprints in a sliding window and flags
4//! when the same tool+args combination repeats too many times.
5//! When the hard iteration cap is hit, prompts the user interactively
6//! to continue or stop — falling back to stop in headless environments.
7
8use crate::providers::ToolCall;
9use std::collections::{HashMap, VecDeque};
10
/// Default hard cap on iterations of the main inference loop.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;

/// Hard cap on iterations of sub-agent loops.
// NOTE(review): typed `usize`, unlike `MAX_ITERATIONS_DEFAULT` (`u32`); callers
// presumably compare it against a `usize` counter — confirm before unifying.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;

/// How many times the same fingerprint must appear in the sliding window to flag a loop.
const REPEAT_THRESHOLD: usize = 3;

/// Sliding window size, counted in individual mutating tool calls (not batches).
const WINDOW_SIZE: usize = 20;

/// How many recent tool names to show in the hard-cap prompt.
const DISPLAY_RECENT: usize = 5;
25
26// ── Loop detection ────────────────────────────────────────────────
27
28/// Tracks repeated tool call patterns.
29#[derive(Default)]
30pub struct LoopDetector {
31    /// Sliding window of recent tool fingerprints.
32    window: VecDeque<String>,
33    /// Ring buffer of the last N tool names (for display only).
34    recent: VecDeque<String>,
35}
36
37impl LoopDetector {
38    /// Create a new loop detector with empty history.
39    pub fn new() -> Self {
40        Self {
41            window: VecDeque::new(),
42            recent: VecDeque::new(),
43        }
44    }
45
46    /// Record a batch of tool calls.
47    /// Returns `Some(repeated_fingerprint)` when a loop is detected.
48    pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
49        for tc in tool_calls {
50            let fp = fingerprint(&tc.function_name, &tc.arguments);
51
52            // Sliding window for loop detection ONLY tracks mutating tools.
53            // Repeating read-only operations is handled by stale-read optimization.
54            if crate::tools::is_mutating_tool(&tc.function_name) {
55                self.window.push_back(fp);
56                if self.window.len() > WINDOW_SIZE {
57                    self.window.pop_front();
58                }
59            }
60
61            // Ring buffer for display always tracks all tools
62            self.recent.push_back(tc.function_name.clone());
63            if self.recent.len() > DISPLAY_RECENT {
64                self.recent.pop_front();
65            }
66        }
67
68        self.check()
69    }
70
71    /// Recent tool names (most recent last), for display in the hard-cap prompt.
72    pub fn recent_names(&self) -> Vec<String> {
73        self.recent.iter().cloned().collect()
74    }
75
76    fn check(&self) -> Option<String> {
77        let mut counts: HashMap<&str, usize> = HashMap::new();
78        for fp in &self.window {
79            *counts.entry(fp.as_str()).or_insert(0) += 1;
80        }
81        counts
82            .into_iter()
83            .find(|(_, n)| *n >= REPEAT_THRESHOLD)
84            .map(|(fp, _)| fp.to_string())
85    }
86}
87
/// Stable fingerprint: tool name + first 200 chars of args.
///
/// The argument prefix is cut on a `char` boundary, so multi-byte UTF-8 in `args`
/// (e.g. non-ASCII text inside JSON arguments) can never cause a slicing panic.
fn fingerprint(name: &str, args: &str) -> String {
    /// Maximum number of argument characters included in a fingerprint.
    const MAX_ARG_CHARS: usize = 200;

    // Byte offset of the first char past the limit, or the full length if `args` is shorter.
    let end = args
        .char_indices()
        .nth(MAX_ARG_CHARS)
        .map_or(args.len(), |(idx, _)| idx);
    let prefix = &args[..end];
    format!("{name}:{prefix}")
}
93
94// ── Hard-cap prompt ───────────────────────────────────────────────
95
96/// Prompt the user when the hard iteration cap is hit.
97///
98/// Options for continuing after hitting the hard cap.
99#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
100#[serde(rename_all = "snake_case")]
101pub enum LoopContinuation {
102    /// Stop the inference loop.
103    Stop,
104    /// Continue for 50 more iterations.
105    Continue50,
106    /// Continue for 200 more iterations.
107    Continue200,
108}
109
110impl LoopContinuation {
111    /// Number of additional iterations granted.
112    pub fn extra_iterations(self) -> u32 {
113        match self {
114            Self::Stop => 0,
115            Self::Continue50 => 50,
116            Self::Continue200 => 200,
117        }
118    }
119}
120
121// ── Tests ─────────────────────────────────────────────────────────
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolCall` with a fixed id and no thought signature.
    fn tool_call(function_name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: function_name.into(),
            arguments: arguments.into(),
            thought_signature: None,
        }
    }

    #[test]
    fn no_loop_on_unique_calls() {
        let mut detector = LoopDetector::new();
        let calls = [
            tool_call("Edit", "{\"path\":\"a.rs\"}"),
            tool_call("Edit", "{\"path\":\"b.rs\"}"),
            tool_call("Bash", "{\"cmd\":\"ls\"}"),
        ];
        // Each call is recorded as its own single-element batch.
        for c in &calls {
            assert!(detector.record(std::slice::from_ref(c)).is_none());
        }
    }

    #[test]
    fn detects_repeated_identical_call() {
        let mut detector = LoopDetector::new();
        let edit = tool_call("Edit", "{\"path\":\"src/main.rs\"}");
        let batch = std::slice::from_ref(&edit);
        // The first two repetitions stay below the threshold...
        for _ in 0..2 {
            assert!(detector.record(batch).is_none());
        }
        // ...and the third repetition trips it.
        assert!(detector.record(batch).is_some());
    }

    #[test]
    fn different_args_not_a_loop() {
        let mut detector = LoopDetector::new();
        (0..10)
            .map(|i| tool_call("Edit", &format!("{{\"path\":\"file{i}.rs\"}}")))
            .for_each(|c| assert!(detector.record(&[c]).is_none()));
    }

    #[test]
    fn ignores_readonly_tools() {
        let mut detector = LoopDetector::new();
        let read = tool_call("Read", "{\"path\":\"src/main.rs\"}");
        for _ in 0..4 {
            assert!(detector.record(std::slice::from_ref(&read)).is_none());
        }
        // Four repetitions still don't trigger, because Read is ignored by the window.
        assert!(detector.check().is_none());
    }

    #[test]
    fn recent_names_tracks_last_five() {
        let mut detector = LoopDetector::new();
        for name in (0..8).map(|i| format!("Tool{i}")) {
            let _ = detector.record(&[tool_call(&name, "{}")]);
        }
        let names = detector.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names.first().map(String::as_str), Some("Tool3"));
        assert_eq!(names.last().map(String::as_str), Some("Tool7"));
    }

    #[test]
    fn fingerprint_truncates_long_args() {
        let fp = fingerprint("Bash", &"x".repeat(500));
        // Expected: "Bash" + ":" + a 200-char prefix of the arguments.
        let expected_len = "Bash:".len() + 200;
        assert_eq!(fp.len(), expected_len);
    }
}