ares/agents/
loop_detector.rs1use std::collections::VecDeque;
8use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10
/// Tuning knobs for [`LoopDetector`].
#[derive(Debug, Clone)]
pub struct LoopDetectorConfig {
    /// How many recent output fingerprints are remembered for comparison.
    pub window_size: usize,
    /// Number of earlier matching outputs (within the window) at which a loop is reported.
    pub repeat_threshold: usize,
    /// Outputs shorter than this many bytes are never fingerprinted or reported as a loop.
    pub min_output_len: usize,
}
21
22impl Default for LoopDetectorConfig {
23 fn default() -> Self {
24 Self {
25 window_size: 10,
26 repeat_threshold: 3,
27 min_output_len: 20,
28 }
29 }
30}
31
32#[derive(Clone, Debug)]
34pub struct LoopDetector {
35 config: LoopDetectorConfig,
36 output_hashes: VecDeque<u64>,
38 total_outputs: usize,
40 loops_detected: usize,
42}
43
44#[derive(Clone, Debug, PartialEq)]
46pub enum LoopStatus {
47 Ok,
49 LoopDetected {
51 repeats: usize,
53 action: LoopAction,
55 },
56}
57
/// Escalation level attached to [`LoopStatus::LoopDetected`], ordered from the
/// lowest to the highest `repeats` count that selects it.
#[derive(Debug, Clone, PartialEq)]
pub enum LoopAction {
    /// Mildest level: used when neither stronger action applies
    /// (typically `repeats == repeat_threshold`).
    InjectWarning,
    /// Used when `repeats >= repeat_threshold + 1` but below `2 * repeat_threshold`.
    ForceAlternative,
    /// Used when `repeats >= 2 * repeat_threshold`.
    HaltAgent,
}
68
69impl LoopDetector {
70 pub fn new() -> Self {
72 Self::with_config(LoopDetectorConfig::default())
73 }
74
75 pub fn with_config(config: LoopDetectorConfig) -> Self {
77 Self {
78 output_hashes: VecDeque::with_capacity(config.window_size),
79 config,
80 total_outputs: 0,
81 loops_detected: 0,
82 }
83 }
84
85 fn hash_output(output: &str) -> u64 {
87 let normalized = output
88 .chars()
89 .take(500)
90 .filter(|c| !c.is_whitespace())
91 .collect::<String>()
92 .to_lowercase();
93 let mut hasher = DefaultHasher::new();
94 normalized.hash(&mut hasher);
95 hasher.finish()
96 }
97
98 pub fn check(&mut self, output: &str) -> LoopStatus {
100 self.total_outputs += 1;
101
102 if output.len() < self.config.min_output_len {
104 return LoopStatus::Ok;
105 }
106
107 let hash = Self::hash_output(output);
108
109 let repeats = self.output_hashes.iter().filter(|&&h| h == hash).count();
111
112 if self.output_hashes.len() >= self.config.window_size {
114 self.output_hashes.pop_front();
115 }
116 self.output_hashes.push_back(hash);
117
118 if repeats >= self.config.repeat_threshold {
119 self.loops_detected += 1;
120 let action = if repeats >= self.config.repeat_threshold * 2 {
121 LoopAction::HaltAgent
122 } else if repeats >= self.config.repeat_threshold + 1 {
123 LoopAction::ForceAlternative
124 } else {
125 LoopAction::InjectWarning
126 };
127 LoopStatus::LoopDetected { repeats, action }
128 } else {
129 LoopStatus::Ok
130 }
131 }
132
133 pub fn reset(&mut self) {
135 self.output_hashes.clear();
136 self.total_outputs = 0;
137 }
138
139 pub fn stats(&self) -> (usize, usize) {
141 (self.total_outputs, self.loops_detected)
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_loop() {
        let mut detector = LoopDetector::new();
        assert_eq!(detector.check("Hello, how can I help?"), LoopStatus::Ok);
        assert_eq!(detector.check("I can assist with that."), LoopStatus::Ok);
        assert_eq!(detector.check("Here's what I found."), LoopStatus::Ok);
    }

    #[test]
    fn test_loop_detected() {
        let mut detector = LoopDetector::new();
        let repeated = "I'm sorry, I cannot help with that request at this time.";
        // With the default threshold of 3, the first three copies are not yet a loop.
        for _ in 0..3 {
            assert_eq!(detector.check(repeated), LoopStatus::Ok);
        }
        // The fourth copy matches exactly three earlier ones.
        assert_eq!(
            detector.check(repeated),
            LoopStatus::LoopDetected {
                repeats: 3,
                action: LoopAction::InjectWarning,
            }
        );
    }

    #[test]
    fn test_short_output_ignored() {
        let mut detector = LoopDetector::new();
        for _ in 0..4 {
            assert_eq!(detector.check("ok"), LoopStatus::Ok);
        }
        // Skipped outputs still count toward the total, but never as loops.
        assert_eq!(detector.stats(), (4, 0));
    }

    #[test]
    fn test_escalation() {
        let mut detector = LoopDetector::with_config(LoopDetectorConfig {
            window_size: 20,
            repeat_threshold: 2,
            min_output_len: 10,
        });
        let repeated = "This is a repeated response that keeps coming back.";
        assert_eq!(detector.check(repeated), LoopStatus::Ok);
        assert_eq!(detector.check(repeated), LoopStatus::Ok);
        // repeats == threshold -> warning.
        assert_eq!(
            detector.check(repeated),
            LoopStatus::LoopDetected {
                repeats: 2,
                action: LoopAction::InjectWarning,
            }
        );
        // repeats == threshold + 1 -> force an alternative.
        assert_eq!(
            detector.check(repeated),
            LoopStatus::LoopDetected {
                repeats: 3,
                action: LoopAction::ForceAlternative,
            }
        );
        // repeats == 2 * threshold -> halt.
        assert_eq!(
            detector.check(repeated),
            LoopStatus::LoopDetected {
                repeats: 4,
                action: LoopAction::HaltAgent,
            }
        );
    }

    #[test]
    fn test_reset() {
        let mut detector = LoopDetector::new();
        let repeated = "A repeated output that should trigger detection.";
        for _ in 0..3 {
            detector.check(repeated);
        }
        detector.reset();
        assert_eq!(detector.check(repeated), LoopStatus::Ok);
        assert_eq!(detector.stats(), (1, 0));
    }

    #[test]
    fn test_stats() {
        let mut detector = LoopDetector::new();
        detector.check("First unique output here and now.");
        detector.check("Second unique output here and now.");
        let (total, loops) = detector.stats();
        assert_eq!(total, 2);
        assert_eq!(loops, 0);
    }

    #[test]
    fn test_whitespace_normalization() {
        let mut detector = LoopDetector::with_config(LoopDetectorConfig {
            repeat_threshold: 2,
            ..Default::default()
        });
        detector.check("Hello world, how are you doing today?");
        detector.check("Hello world, how are you doing today?");
        assert!(
            matches!(
                detector.check("Hello\n\tworld,\thow are you doing today?"),
                LoopStatus::LoopDetected { .. }
            ),
            "whitespace-normalized duplicates should match"
        );
    }
}
242}