Skip to main content

koda_core/
loop_guard.rs

1//! Loop detection and hard-cap user prompt for the inference loop.
2//!
3//! Tracks recent tool call fingerprints in a sliding window and flags
4//! when the same tool+args combination repeats too many times.
5//! When the hard iteration cap is hit, prompts the user interactively
6//! to continue or stop — falling back to stop in headless environments.
7
8use crate::providers::ToolCall;
9use std::collections::{HashMap, VecDeque};
10
/// Iteration budget the main inference loop uses by default before the hard cap applies.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;

/// Iteration budget for sub-agent loops.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;

/// Occurrences of a single fingerprint inside the window that count as a loop.
const REPEAT_THRESHOLD: usize = 3;

/// Capacity of the fingerprint window, counted in individual tool calls (not batches).
const WINDOW_SIZE: usize = 20;

/// Number of most-recent tool names retained for display in the hard-cap prompt.
const DISPLAY_RECENT: usize = 5;
25
26// ── Loop detection ────────────────────────────────────────────────
27
28/// Tracks repeated tool call patterns.
29#[derive(Default)]
30pub struct LoopDetector {
31    /// Sliding window of recent tool fingerprints.
32    window: VecDeque<String>,
33    /// Ring buffer of the last N tool names (for display only).
34    recent: VecDeque<String>,
35}
36
37impl LoopDetector {
38    pub fn new() -> Self {
39        Self {
40            window: VecDeque::new(),
41            recent: VecDeque::new(),
42        }
43    }
44
45    /// Record a batch of tool calls.
46    /// Returns `Some(repeated_fingerprint)` when a loop is detected.
47    pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
48        for tc in tool_calls {
49            let fp = fingerprint(&tc.function_name, &tc.arguments);
50
51            // Sliding window for loop detection ONLY tracks mutating tools.
52            // Repeating read-only operations is handled by stale-read optimization.
53            if is_mutating_tool(&tc.function_name) {
54                self.window.push_back(fp);
55                if self.window.len() > WINDOW_SIZE {
56                    self.window.pop_front();
57                }
58            }
59
60            // Ring buffer for display always tracks all tools
61            self.recent.push_back(tc.function_name.clone());
62            if self.recent.len() > DISPLAY_RECENT {
63                self.recent.pop_front();
64            }
65        }
66
67        self.check()
68    }
69
70    /// Recent tool names (most recent last), for display in the hard-cap prompt.
71    pub fn recent_names(&self) -> Vec<String> {
72        self.recent.iter().cloned().collect()
73    }
74
75    fn check(&self) -> Option<String> {
76        let mut counts: HashMap<&str, usize> = HashMap::new();
77        for fp in &self.window {
78            *counts.entry(fp.as_str()).or_insert(0) += 1;
79        }
80        counts
81            .into_iter()
82            .find(|(_, n)| *n >= REPEAT_THRESHOLD)
83            .map(|(fp, _)| fp.to_string())
84    }
85}
86
/// Stable fingerprint: tool name plus a bounded prefix of the raw argument string.
///
/// The prefix is at most 200 bytes of `args`. If byte 200 would fall in the
/// middle of a multi-byte UTF-8 character, the cut is moved back to the
/// previous character boundary so that slicing can never panic. For ASCII
/// arguments this is exactly the first 200 characters.
fn fingerprint(name: &str, args: &str) -> String {
    const MAX_PREFIX_BYTES: usize = 200;

    let mut end = args.len().min(MAX_PREFIX_BYTES);
    // Index 0 is always a char boundary, so this loop terminates.
    while !args.is_char_boundary(end) {
        end -= 1;
    }
    format!("{name}:{}", &args[..end])
}
92
/// Reports whether `name` is one of the tools treated as mutating.
///
/// Only these tools are fed into loop detection; anything else (for example
/// the read-only Read, List and Grep tools) is excluded so that exploration
/// is never flagged. Matching is exact and case-sensitive.
fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 7] = [
        "Bash",
        "Edit",
        "Write",
        "Delete",
        "MemoryWrite",
        "CreateAgent",
        "InvokeAgent",
    ];
    MUTATING_TOOLS.contains(&name)
}
101
102// ── Hard-cap prompt ───────────────────────────────────────────────
103
104/// Prompt the user when the hard iteration cap is hit.
105///
106/// Options for continuing after hitting the hard cap.
107#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
108#[serde(rename_all = "snake_case")]
109pub enum LoopContinuation {
110    Stop,
111    Continue50,
112    Continue200,
113}
114
115impl LoopContinuation {
116    /// Number of additional iterations granted.
117    pub fn extra_iterations(self) -> u32 {
118        match self {
119            Self::Stop => 0,
120            Self::Continue50 => 50,
121            Self::Continue200 => 200,
122        }
123    }
124}
125
126/// Returns the number of additional iterations granted (0 = stop).
127///
128/// The `prompt_fn` callback is responsible for asking the user (terminal,
129/// server, or headless). It receives `(cap, recent_tool_names)` and returns
130/// the user's choice.
131pub fn ask_continue_or_stop(
132    cap: u32,
133    recent_names: &[String],
134    prompt_fn: &dyn Fn(u32, &[String]) -> LoopContinuation,
135) -> u32 {
136    prompt_fn(cap, recent_names).extra_iterations()
137}
138
139// ── Tests ─────────────────────────────────────────────────────────
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolCall` with a fixed id and no thought signature.
    fn call(name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: name.into(),
            arguments: args.into(),
            thought_signature: None,
        }
    }

    #[test]
    fn no_loop_on_unique_calls() {
        let mut detector = LoopDetector::new();
        let distinct = [
            ("Edit", r#"{"path":"a.rs"}"#),
            ("Edit", r#"{"path":"b.rs"}"#),
            ("Bash", r#"{"cmd":"ls"}"#),
        ];
        for (name, args) in distinct {
            assert!(detector.record(&[call(name, args)]).is_none());
        }
    }

    #[test]
    fn detects_repeated_identical_call() {
        let mut detector = LoopDetector::new();
        let batch = [call("Edit", r#"{"path":"src/main.rs"}"#)];
        // The first two occurrences stay below the threshold.
        assert!(detector.record(&batch).is_none());
        assert!(detector.record(&batch).is_none());
        // The third occurrence reaches it.
        assert!(detector.record(&batch).is_some());
    }

    #[test]
    fn different_args_not_a_loop() {
        let mut detector = LoopDetector::new();
        for i in 0..10 {
            let args = format!("{{\"path\":\"file{i}.rs\"}}");
            let outcome = detector.record(&[call("Edit", &args)]);
            assert!(outcome.is_none());
        }
    }

    #[test]
    fn ignores_readonly_tools() {
        let mut detector = LoopDetector::new();
        let batch = [call("Read", r#"{"path":"src/main.rs"}"#)];
        for _ in 0..4 {
            assert!(detector.record(&batch).is_none());
        }
        // Four repetitions of a read-only tool never enter the window.
        assert!(detector.check().is_none());
    }

    #[test]
    fn recent_names_tracks_last_five() {
        let mut detector = LoopDetector::new();
        for i in 0..8 {
            let name = format!("Tool{i}");
            detector.record(&[call(&name, "{}")]);
        }
        let names = detector.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names.first().map(String::as_str), Some("Tool3"));
        assert_eq!(names.last().map(String::as_str), Some("Tool7"));
    }

    #[test]
    fn fingerprint_truncates_long_args() {
        let long_args = "x".repeat(500);
        let expected_len = "Bash:".len() + 200; // name + ":" + 200 chars
        assert_eq!(fingerprint("Bash", &long_args).len(), expected_len);
    }
}