Skip to main content

koda_core/
loop_guard.rs

//! Loop detection and hard-cap user prompt for the inference loop.
//!
//! Tracks fingerprints of recent mutating tool calls in a sliding window and
//! flags when the same tool+args combination repeats too many times.
//! When the hard iteration cap is hit, asks the user — through a
//! caller-supplied prompt callback — whether to continue or stop; headless
//! callbacks are expected to fall back to stopping.
7
8use crate::providers::ToolCall;
9use std::collections::{HashMap, VecDeque};
10
/// Default hard cap on iterations of the main inference loop.
pub const MAX_ITERATIONS_DEFAULT: u32 = 200;

/// Hard cap on iterations of a sub-agent loop.
pub const MAX_SUB_AGENT_ITERATIONS: usize = 20;

/// Capacity of the loop-detection sliding window, counted in individual
/// tool calls rather than batches.
const WINDOW_SIZE: usize = 20;

/// Number of occurrences of a single fingerprint inside the window that
/// is treated as a loop.
const REPEAT_THRESHOLD: usize = 3;

/// Number of most recent tool names retained for the hard-cap prompt.
const DISPLAY_RECENT: usize = 5;
25
26// ── Loop detection ────────────────────────────────────────────────
27
/// Tracks repeated tool call patterns.
///
/// Holds two queues: a sliding window of tool-call fingerprints used for
/// loop detection, and a short list of recent tool names kept only for
/// display in the hard-cap prompt.
#[derive(Debug, Default)]
pub struct LoopDetector {
    /// Sliding window of recent tool fingerprints (oldest at the front).
    window: VecDeque<String>,
    /// Ring buffer of the last N tool names (for display only, oldest first).
    recent: VecDeque<String>,
}
36
37impl LoopDetector {
38    pub fn new() -> Self {
39        Self {
40            window: VecDeque::new(),
41            recent: VecDeque::new(),
42        }
43    }
44
45    /// Record a batch of tool calls.
46    /// Returns `Some(repeated_fingerprint)` when a loop is detected.
47    pub fn record(&mut self, tool_calls: &[ToolCall]) -> Option<String> {
48        for tc in tool_calls {
49            let fp = fingerprint(&tc.function_name, &tc.arguments);
50
51            // Sliding window for loop detection ONLY tracks mutating tools.
52            // Repeating read-only operations is handled by stale-read optimization.
53            if crate::tools::is_mutating_tool(&tc.function_name) {
54                self.window.push_back(fp);
55                if self.window.len() > WINDOW_SIZE {
56                    self.window.pop_front();
57                }
58            }
59
60            // Ring buffer for display always tracks all tools
61            self.recent.push_back(tc.function_name.clone());
62            if self.recent.len() > DISPLAY_RECENT {
63                self.recent.pop_front();
64            }
65        }
66
67        self.check()
68    }
69
70    /// Recent tool names (most recent last), for display in the hard-cap prompt.
71    pub fn recent_names(&self) -> Vec<String> {
72        self.recent.iter().cloned().collect()
73    }
74
75    fn check(&self) -> Option<String> {
76        let mut counts: HashMap<&str, usize> = HashMap::new();
77        for fp in &self.window {
78            *counts.entry(fp.as_str()).or_insert(0) += 1;
79        }
80        counts
81            .into_iter()
82            .find(|(_, n)| *n >= REPEAT_THRESHOLD)
83            .map(|(fp, _)| fp.to_string())
84    }
85}
86
/// Stable fingerprint: the tool name, `:`, then at most the first 200 bytes
/// of `args`.
///
/// The cut point is moved back to the nearest UTF-8 character boundary.
/// Arguments containing multi-byte characters (e.g. non-ASCII paths or file
/// contents) therefore cannot make the slice panic. When byte offset 200, or
/// the end of a shorter string, is already a boundary (always true for
/// ASCII), the result is exactly the first 200 bytes, as before.
fn fingerprint(name: &str, args: &str) -> String {
    const MAX_ARG_BYTES: usize = 200;
    let mut end = args.len().min(MAX_ARG_BYTES);
    // Offset 0 is always a char boundary, so this loop terminates.
    while !args.is_char_boundary(end) {
        end -= 1;
    }
    let prefix = &args[..end];
    format!("{name}:{prefix}")
}
92
93// ── Hard-cap prompt ───────────────────────────────────────────────
94
95/// Prompt the user when the hard iteration cap is hit.
96///
97/// Options for continuing after hitting the hard cap.
98#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum LoopContinuation {
101    Stop,
102    Continue50,
103    Continue200,
104}
105
106impl LoopContinuation {
107    /// Number of additional iterations granted.
108    pub fn extra_iterations(self) -> u32 {
109        match self {
110            Self::Stop => 0,
111            Self::Continue50 => 50,
112            Self::Continue200 => 200,
113        }
114    }
115}
116
117/// Returns the number of additional iterations granted (0 = stop).
118///
119/// The `prompt_fn` callback is responsible for asking the user (terminal,
120/// server, or headless). It receives `(cap, recent_tool_names)` and returns
121/// the user's choice.
122pub fn ask_continue_or_stop(
123    cap: u32,
124    recent_names: &[String],
125    prompt_fn: &dyn Fn(u32, &[String]) -> LoopContinuation,
126) -> u32 {
127    prompt_fn(cap, recent_names).extra_iterations()
128}
129
130// ── Tests ─────────────────────────────────────────────────────────
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolCall` with a fixed id and no thought signature.
    fn call(name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: "x".into(),
            function_name: name.into(),
            arguments: args.into(),
            thought_signature: None,
        }
    }

    #[test]
    fn no_loop_on_unique_calls() {
        let mut detector = LoopDetector::new();
        let calls = [
            call("Edit", "{\"path\":\"a.rs\"}"),
            call("Edit", "{\"path\":\"b.rs\"}"),
            call("Bash", "{\"cmd\":\"ls\"}"),
        ];
        for tc in &calls {
            assert!(detector.record(std::slice::from_ref(tc)).is_none());
        }
    }

    #[test]
    fn detects_repeated_identical_call() {
        let mut detector = LoopDetector::new();
        let tc = call("Edit", "{\"path\":\"src/main.rs\"}");
        let batch = std::slice::from_ref(&tc);
        // The first two repetitions stay below the threshold...
        for _ in 0..2 {
            assert!(detector.record(batch).is_none());
        }
        // ...and the third one trips it.
        assert!(detector.record(batch).is_some());
    }

    #[test]
    fn different_args_not_a_loop() {
        let mut detector = LoopDetector::new();
        for args in (0..10).map(|i| format!("{{\"path\":\"file{i}.rs\"}}")) {
            assert!(detector.record(&[call("Edit", &args)]).is_none());
        }
    }

    #[test]
    fn ignores_readonly_tools() {
        let mut detector = LoopDetector::new();
        let tc = call("Read", "{\"path\":\"src/main.rs\"}");
        for _ in 0..4 {
            assert!(detector.record(std::slice::from_ref(&tc)).is_none());
        }
        // Four identical reads still don't count: Read never enters the window.
        assert!(detector.check().is_none());
    }

    #[test]
    fn recent_names_tracks_last_five() {
        let mut detector = LoopDetector::new();
        (0..8).for_each(|i| {
            detector.record(&[call(&format!("Tool{i}"), "{}")]);
        });
        let names = detector.recent_names();
        assert_eq!(names.len(), 5);
        assert_eq!(names.first().map(String::as_str), Some("Tool3"));
        assert_eq!(names.last().map(String::as_str), Some("Tool7"));
    }

    #[test]
    fn fingerprint_truncates_long_args() {
        let fp = fingerprint("Bash", &"x".repeat(500));
        // "Bash:" followed by exactly 200 argument bytes.
        assert_eq!(fp.len(), "Bash:".len() + 200);
    }
}
204}