1use std::collections::{HashMap, HashSet, VecDeque};
15
16use sha2::{Digest, Sha256};
17use tracing::{debug, info, warn};
18
19use punch_types::ToolCall;
20
/// Default for `GuardConfig::warn_threshold`: number of identical calls
/// (same tool name and input) at which a warning is emitted.
const DEFAULT_WARN_THRESHOLD: usize = 3;
/// Default for `GuardConfig::block_threshold`: number of identical calls
/// at which the call is blocked.
const DEFAULT_BLOCK_THRESHOLD: usize = 5;
/// Default for `GuardConfig::circuit_breaker_threshold`: iteration count at
/// which `LoopGuard::record_tool_calls` stops the loop.
const DEFAULT_CIRCUIT_BREAKER_THRESHOLD: usize = 30;

/// Maximum number of call hashes kept in `LoopGuard::recent_calls`; the
/// oldest entry is dropped once this size is reached.
const RECENT_CALLS_BUFFER_SIZE: usize = 30;

/// Factor applied to the warn and block thresholds for tools listed in
/// `POLL_TOOLS`.
const POLL_TOOL_THRESHOLD_MULTIPLIER: usize = 3;

/// Tool names that get the relaxed (multiplied) repetition thresholds.
const POLL_TOOLS: &[&str] = &["shell_exec"];
34
/// Graduated result of evaluating a single tool call.
///
/// Every non-`Allow` variant carries a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardVerdict {
    /// The call is below every threshold.
    Allow,
    /// The call reached the warn threshold; it still counts as allowed.
    Warn(String),
    /// The call reached the block threshold.
    Block(String),
    /// NOTE(review): not constructed anywhere in this file; presumably
    /// signals that the whole loop should stop — confirm with callers.
    CircuitBreak(String),
}

impl GuardVerdict {
    /// Returns `true` for `Allow` and `Warn`, `false` for `Block` and
    /// `CircuitBreak`.
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allow | Self::Warn(_) => true,
            Self::Block(_) | Self::CircuitBreak(_) => false,
        }
    }

    /// Returns `true` only for the `CircuitBreak` variant.
    pub fn is_circuit_break(&self) -> bool {
        match self {
            Self::CircuitBreak(_) => true,
            Self::Allow | Self::Warn(_) | Self::Block(_) => false,
        }
    }
}
60
61#[derive(Debug, Clone)]
63pub struct GuardConfig {
64 pub max_iterations: usize,
66 pub warn_threshold: usize,
68 pub block_threshold: usize,
70 pub circuit_breaker_threshold: usize,
72}
73
74impl Default for GuardConfig {
75 fn default() -> Self {
76 Self {
77 max_iterations: 50,
78 warn_threshold: DEFAULT_WARN_THRESHOLD,
79 block_threshold: DEFAULT_BLOCK_THRESHOLD,
80 circuit_breaker_threshold: DEFAULT_CIRCUIT_BREAKER_THRESHOLD,
81 }
82 }
83}
84
85#[derive(Debug)]
91pub struct LoopGuard {
92 config: GuardConfig,
93 current_iteration: usize,
95 call_counts: HashMap<String, usize>,
97 outcome_counts: HashMap<String, usize>,
99 blocked_outcomes: HashSet<String>,
101 recent_calls: VecDeque<String>,
103}
104
105impl LoopGuard {
106 pub fn new(max_iterations: usize, _repetition_threshold: usize) -> Self {
108 Self::with_config(GuardConfig {
109 max_iterations,
110 ..Default::default()
111 })
112 }
113
114 pub fn with_config(config: GuardConfig) -> Self {
116 Self {
117 config,
118 current_iteration: 0,
119 call_counts: HashMap::new(),
120 outcome_counts: HashMap::new(),
121 blocked_outcomes: HashSet::new(),
122 recent_calls: VecDeque::with_capacity(RECENT_CALLS_BUFFER_SIZE),
123 }
124 }
125
126 pub fn record_tool_calls(&mut self, tool_calls: &[ToolCall]) -> LoopGuardVerdict {
131 self.current_iteration += 1;
132
133 if self.current_iteration >= self.config.max_iterations {
135 return LoopGuardVerdict::Break(format!(
136 "maximum iterations reached ({}/{})",
137 self.current_iteration, self.config.max_iterations
138 ));
139 }
140
141 if self.current_iteration >= self.config.circuit_breaker_threshold {
143 return LoopGuardVerdict::Break(format!(
144 "circuit breaker triggered after {} iterations",
145 self.current_iteration
146 ));
147 }
148
149 for tc in tool_calls {
151 let call_hash = hash_call(tc);
152
153 if self.recent_calls.len() >= RECENT_CALLS_BUFFER_SIZE {
155 self.recent_calls.pop_front();
156 }
157 self.recent_calls.push_back(call_hash.clone());
158
159 let (warn_t, block_t) = self.effective_thresholds(&tc.name);
162
163 let count = self.call_counts.entry(call_hash.clone()).or_insert(0);
165 *count += 1;
166 let current_count = *count;
167
168 if current_count >= block_t {
169 return LoopGuardVerdict::Break(format!(
170 "tool '{}' blocked: {} identical calls (threshold: {})",
171 tc.name, current_count, block_t
172 ));
173 }
174
175 if current_count >= warn_t {
176 warn!(
177 tool = %tc.name,
178 count = current_count,
179 threshold = warn_t,
180 "repetitive tool call detected"
181 );
182 }
183 }
184
185 if let Some(reason) = self.detect_ping_pong() {
187 return LoopGuardVerdict::Break(reason);
188 }
189
190 LoopGuardVerdict::Continue
191 }
192
193 pub fn evaluate_call(&mut self, tool_call: &ToolCall) -> GuardVerdict {
198 let call_hash = hash_call(tool_call);
199
200 if self.recent_calls.len() >= RECENT_CALLS_BUFFER_SIZE {
202 self.recent_calls.pop_front();
203 }
204 self.recent_calls.push_back(call_hash.clone());
205
206 let (warn_t, block_t) = self.effective_thresholds(&tool_call.name);
208
209 let count = self.call_counts.entry(call_hash).or_insert(0);
211 *count += 1;
212 let current_count = *count;
213
214 if current_count >= block_t {
215 GuardVerdict::Block(format!(
216 "tool '{}' blocked after {} identical calls",
217 tool_call.name, current_count
218 ))
219 } else if current_count >= warn_t {
220 GuardVerdict::Warn(format!(
221 "tool '{}' called {} times with identical params (warn threshold: {})",
222 tool_call.name, current_count, warn_t
223 ))
224 } else {
225 GuardVerdict::Allow
226 }
227 }
228
229 pub fn record_outcome(&mut self, tool_call: &ToolCall, result: &str) {
234 let outcome_hash = hash_outcome(tool_call, result);
235 let count = self.outcome_counts.entry(outcome_hash.clone()).or_insert(0);
236 *count += 1;
237
238 if *count >= 2 {
239 debug!(
240 tool = %tool_call.name,
241 outcome_count = *count,
242 "auto-blocking repeated identical outcome"
243 );
244 self.blocked_outcomes.insert(outcome_hash);
245 }
246 }
247
248 pub fn is_outcome_blocked(&self, tool_call: &ToolCall, result: &str) -> bool {
250 let outcome_hash = hash_outcome(tool_call, result);
251 self.blocked_outcomes.contains(&outcome_hash)
252 }
253
254 pub fn record_iteration(&mut self) -> LoopGuardVerdict {
256 self.current_iteration += 1;
257
258 if self.current_iteration >= self.config.max_iterations {
259 return LoopGuardVerdict::Break(format!(
260 "maximum iterations reached ({}/{})",
261 self.current_iteration, self.config.max_iterations
262 ));
263 }
264
265 LoopGuardVerdict::Continue
266 }
267
268 pub fn iterations(&self) -> usize {
270 self.current_iteration
271 }
272
273 fn effective_thresholds(&self, tool_name: &str) -> (usize, usize) {
275 if POLL_TOOLS.contains(&tool_name) {
276 (
277 self.config.warn_threshold * POLL_TOOL_THRESHOLD_MULTIPLIER,
278 self.config.block_threshold * POLL_TOOL_THRESHOLD_MULTIPLIER,
279 )
280 } else {
281 (self.config.warn_threshold, self.config.block_threshold)
282 }
283 }
284
285 fn detect_ping_pong(&self) -> Option<String> {
289 let len = self.recent_calls.len();
290 if len < 4 {
291 return None;
292 }
293
294 let calls: Vec<&String> = self.recent_calls.iter().collect();
297
298 let check_len = len.min(6);
300 if check_len >= 4 {
301 let tail = &calls[len - check_len..];
302 let mut is_ping_pong = true;
303
304 for i in 2..tail.len() {
305 if tail[i] != tail[i - 2] {
306 is_ping_pong = false;
307 break;
308 }
309 }
310
311 if is_ping_pong && tail.len() >= 4 && tail[0] != tail[1] {
312 info!(
313 pattern_length = tail.len(),
314 "ping-pong pattern detected in tool calls"
315 );
316 return Some(format!(
317 "ping-pong pattern detected: alternating tool calls over {} iterations",
318 tail.len()
319 ));
320 }
321 }
322
323 None
324 }
325}
326
/// Result of recording an iteration: whether the loop may keep running.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopGuardVerdict {
    /// No limit was hit.
    Continue,
    /// A limit was hit; the payload is the human-readable reason.
    Break(String),
}
335
336fn hash_call(tool_call: &ToolCall) -> String {
338 let mut hasher = Sha256::new();
339 hasher.update(tool_call.name.as_bytes());
340 hasher.update(b"|");
341 hasher.update(tool_call.input.to_string().as_bytes());
342 let result = hasher.finalize();
343 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, result)
344}
345
346fn hash_outcome(tool_call: &ToolCall, result: &str) -> String {
348 let mut hasher = Sha256::new();
349 hasher.update(tool_call.name.as_bytes());
350 hasher.update(b"|");
351 hasher.update(tool_call.input.to_string().as_bytes());
352 hasher.update(b"|");
353 hasher.update(result.as_bytes());
354 let result_hash = hasher.finalize();
355 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, result_hash)
356}
357
/// Unit tests for the loop guard: iteration limits, repetition thresholds,
/// outcome tracking, ping-pong detection, and the hashing helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::ToolCall;

    /// Builds a `ToolCall` whose id is derived from the tool name.
    fn make_tool_call(name: &str, input: serde_json::Value) -> ToolCall {
        ToolCall {
            id: format!("call_{name}"),
            name: name.to_string(),
            input,
        }
    }

    #[test]
    fn test_max_iterations_enforcement() {
        let mut guard = LoopGuard::new(3, 5);

        assert_eq!(guard.record_iteration(), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_iteration(), LoopGuardVerdict::Continue);

        // The third iteration reaches the cap of 3.
        match guard.record_iteration() {
            LoopGuardVerdict::Break(reason) => {
                assert!(reason.contains("maximum iterations"));
            }
            LoopGuardVerdict::Continue => panic!("should have broken"),
        }
    }

    #[test]
    fn test_repetitive_pattern_detection() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 3,
            block_threshold: 5,
            circuit_breaker_threshold: 30,
        });

        let calls = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/tmp/foo.txt"}),
        )];

        // Calls 1-4 stay below the block threshold.
        assert_eq!(guard.record_tool_calls(&calls), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&calls), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&calls), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&calls), LoopGuardVerdict::Continue);

        // The fifth identical call reaches the block threshold.
        match guard.record_tool_calls(&calls) {
            LoopGuardVerdict::Break(reason) => {
                assert!(reason.contains("blocked"));
            }
            LoopGuardVerdict::Continue => panic!("should have blocked"),
        }
    }

    #[test]
    fn test_different_calls_no_repetition() {
        let mut guard = LoopGuard::new(50, 3);

        let calls_a = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/a.txt"}),
        )];
        let calls_b = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/b.txt"}),
        )];
        let calls_c = vec![make_tool_call(
            "file_write",
            serde_json::json!({"path": "/c.txt", "content": "hi"}),
        )];

        assert_eq!(
            guard.record_tool_calls(&calls_a),
            LoopGuardVerdict::Continue
        );
        assert_eq!(
            guard.record_tool_calls(&calls_b),
            LoopGuardVerdict::Continue
        );
        assert_eq!(
            guard.record_tool_calls(&calls_c),
            LoopGuardVerdict::Continue
        );
    }

    #[test]
    fn test_iteration_counter() {
        let mut guard = LoopGuard::new(100, 3);

        // Both entry points advance the same counter.
        let calls = vec![make_tool_call("test", serde_json::json!({}))];
        guard.record_tool_calls(&calls);
        guard.record_iteration();

        assert_eq!(guard.iterations(), 2);
    }

    #[test]
    fn test_graduated_response_warn() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 2,
            block_threshold: 5,
            circuit_breaker_threshold: 30,
        });

        let tc = make_tool_call("file_read", serde_json::json!({"path": "/test"}));

        assert_eq!(guard.evaluate_call(&tc), GuardVerdict::Allow);
        // The second identical call reaches the warn threshold.
        match guard.evaluate_call(&tc) {
            GuardVerdict::Warn(msg) => assert!(msg.contains("file_read")),
            other => panic!("expected Warn, got {:?}", other),
        }
    }

    #[test]
    fn test_graduated_response_block() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 2,
            block_threshold: 3,
            circuit_breaker_threshold: 30,
        });

        let tc = make_tool_call("file_read", serde_json::json!({"path": "/test"}));

        guard.evaluate_call(&tc);
        guard.evaluate_call(&tc);
        // The third identical call reaches the block threshold.
        match guard.evaluate_call(&tc) {
            GuardVerdict::Block(msg) => assert!(msg.contains("blocked")),
            other => panic!("expected Block, got {:?}", other),
        }
    }

    #[test]
    fn test_poll_tool_relaxed_thresholds() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 3,
            block_threshold: 5,
            circuit_breaker_threshold: 50,
        });

        let tc = make_tool_call("shell_exec", serde_json::json!({"command": "ls"}));

        // Poll tools warn at 3 * 3 = 9, so the first eight calls are allowed.
        for _ in 0..8 {
            assert_eq!(guard.evaluate_call(&tc), GuardVerdict::Allow);
        }
        match guard.evaluate_call(&tc) {
            GuardVerdict::Warn(_) => {}
            other => panic!("expected Warn at 9, got {:?}", other),
        }
    }

    #[test]
    fn test_outcome_tracking() {
        let mut guard = LoopGuard::new(50, 3);

        let tc = make_tool_call("file_read", serde_json::json!({"path": "/test"}));

        assert!(!guard.is_outcome_blocked(&tc, "file contents"));

        guard.record_outcome(&tc, "file contents");
        assert!(!guard.is_outcome_blocked(&tc, "file contents"));

        // The second identical outcome triggers the block.
        guard.record_outcome(&tc, "file contents");
        assert!(guard.is_outcome_blocked(&tc, "file contents"));

        // A different result for the same call is unaffected.
        assert!(!guard.is_outcome_blocked(&tc, "different contents"));
    }

    #[test]
    fn test_ping_pong_detection() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 10,
            block_threshold: 20,
            circuit_breaker_threshold: 50,
        });

        let call_a = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/a"}),
        )];
        let call_b = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/b"}),
        )];

        assert_eq!(guard.record_tool_calls(&call_a), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&call_b), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&call_a), LoopGuardVerdict::Continue);

        // The fourth call completes A-B-A-B.
        match guard.record_tool_calls(&call_b) {
            LoopGuardVerdict::Break(reason) => {
                assert!(reason.contains("ping-pong"));
            }
            LoopGuardVerdict::Continue => panic!("should have detected ping-pong"),
        }
    }

    #[test]
    fn test_circuit_breaker_threshold() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 100,
            warn_threshold: 50,
            block_threshold: 50,
            circuit_breaker_threshold: 5,
        });

        // Four distinct calls stay below the circuit-breaker threshold.
        for i in 0..4 {
            let calls = vec![make_tool_call(
                "file_read",
                serde_json::json!({"path": format!("/file_{}", i)}),
            )];
            assert_eq!(guard.record_tool_calls(&calls), LoopGuardVerdict::Continue);
        }

        // The fifth iteration trips the breaker.
        let calls = vec![make_tool_call(
            "file_read",
            serde_json::json!({"path": "/file_4"}),
        )];
        match guard.record_tool_calls(&calls) {
            LoopGuardVerdict::Break(reason) => {
                assert!(reason.contains("circuit breaker"));
            }
            LoopGuardVerdict::Continue => panic!("should have circuit broken"),
        }
    }

    #[test]
    fn test_guard_verdict_is_allowed() {
        assert!(GuardVerdict::Allow.is_allowed());
        assert!(GuardVerdict::Warn("test".into()).is_allowed());
        assert!(!GuardVerdict::Block("test".into()).is_allowed());
        assert!(!GuardVerdict::CircuitBreak("test".into()).is_allowed());
    }

    #[test]
    fn test_guard_verdict_is_circuit_break() {
        assert!(!GuardVerdict::Allow.is_circuit_break());
        assert!(!GuardVerdict::Warn("test".into()).is_circuit_break());
        assert!(!GuardVerdict::Block("test".into()).is_circuit_break());
        assert!(GuardVerdict::CircuitBreak("test".into()).is_circuit_break());
    }

    #[test]
    fn test_guard_config_default() {
        let config = GuardConfig::default();
        assert_eq!(config.max_iterations, 50);
        assert_eq!(config.warn_threshold, 3);
        assert_eq!(config.block_threshold, 5);
        assert_eq!(config.circuit_breaker_threshold, 30);
    }

    #[test]
    fn test_loop_guard_new() {
        let guard = LoopGuard::new(10, 5);
        assert_eq!(guard.iterations(), 0);
    }

    #[test]
    fn test_loop_guard_iterations_incremented_by_tool_calls() {
        let mut guard = LoopGuard::new(50, 5);
        let calls = vec![make_tool_call("test", serde_json::json!({}))];
        guard.record_tool_calls(&calls);
        assert_eq!(guard.iterations(), 1);
    }

    #[test]
    fn test_loop_guard_iterations_incremented_by_record_iteration() {
        let mut guard = LoopGuard::new(50, 5);
        guard.record_iteration();
        guard.record_iteration();
        guard.record_iteration();
        assert_eq!(guard.iterations(), 3);
    }

    #[test]
    fn test_evaluate_call_first_call_is_allow() {
        let mut guard = LoopGuard::new(50, 5);
        let tc = make_tool_call("file_read", serde_json::json!({"path": "/tmp/test"}));
        assert_eq!(guard.evaluate_call(&tc), GuardVerdict::Allow);
    }

    #[test]
    fn test_evaluate_call_different_params_no_warn() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 2,
            block_threshold: 5,
            circuit_breaker_threshold: 50,
        });

        let tc1 = make_tool_call("file_read", serde_json::json!({"path": "/a"}));
        let tc2 = make_tool_call("file_read", serde_json::json!({"path": "/b"}));

        assert_eq!(guard.evaluate_call(&tc1), GuardVerdict::Allow);
        assert_eq!(guard.evaluate_call(&tc2), GuardVerdict::Allow);
    }

    #[test]
    fn test_hash_call_deterministic() {
        let tc = make_tool_call("test", serde_json::json!({"key": "value"}));
        let h1 = hash_call(&tc);
        let h2 = hash_call(&tc);
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_hash_call_different_for_different_inputs() {
        let tc1 = make_tool_call("test", serde_json::json!({"key": "value1"}));
        let tc2 = make_tool_call("test", serde_json::json!({"key": "value2"}));
        assert_ne!(hash_call(&tc1), hash_call(&tc2));
    }

    #[test]
    fn test_hash_outcome_deterministic() {
        let tc = make_tool_call("test", serde_json::json!({}));
        let h1 = hash_outcome(&tc, "result");
        let h2 = hash_outcome(&tc, "result");
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_hash_outcome_different_for_different_results() {
        let tc = make_tool_call("test", serde_json::json!({}));
        assert_ne!(hash_outcome(&tc, "result1"), hash_outcome(&tc, "result2"));
    }

    #[test]
    fn test_outcome_blocked_only_after_two() {
        let mut guard = LoopGuard::new(50, 5);
        let tc = make_tool_call("test", serde_json::json!({}));

        assert!(!guard.is_outcome_blocked(&tc, "result"));
        guard.record_outcome(&tc, "result");
        assert!(!guard.is_outcome_blocked(&tc, "result"));
        guard.record_outcome(&tc, "result");
        assert!(guard.is_outcome_blocked(&tc, "result"));
    }

    #[test]
    fn test_no_ping_pong_with_three_calls() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 10,
            block_threshold: 20,
            circuit_breaker_threshold: 50,
        });

        let call_a = vec![make_tool_call("a", serde_json::json!({}))];
        let call_b = vec![make_tool_call("b", serde_json::json!({}))];

        // Three calls are fewer than the four the detector needs.
        assert_eq!(guard.record_tool_calls(&call_a), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&call_b), LoopGuardVerdict::Continue);
        assert_eq!(guard.record_tool_calls(&call_a), LoopGuardVerdict::Continue);
    }

    #[test]
    fn test_no_ping_pong_same_call_repeated() {
        let mut guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 10,
            block_threshold: 20,
            circuit_breaker_threshold: 50,
        });

        let call_a = vec![make_tool_call("a", serde_json::json!({}))];

        // A-A-A-A is repetition, not alternation: with thresholds this high
        // nothing else fires, so every call must be allowed to continue.
        for _ in 0..4 {
            assert_eq!(guard.record_tool_calls(&call_a), LoopGuardVerdict::Continue);
        }
    }

    #[test]
    fn test_loop_guard_verdict_continue_equality() {
        assert_eq!(LoopGuardVerdict::Continue, LoopGuardVerdict::Continue);
    }

    #[test]
    fn test_loop_guard_verdict_break_equality() {
        assert_eq!(
            LoopGuardVerdict::Break("reason".into()),
            LoopGuardVerdict::Break("reason".into())
        );
        assert_ne!(
            LoopGuardVerdict::Break("reason1".into()),
            LoopGuardVerdict::Break("reason2".into())
        );
    }

    #[test]
    fn test_effective_thresholds_normal_tool() {
        let guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 3,
            block_threshold: 5,
            circuit_breaker_threshold: 30,
        });
        let (warn, block) = guard.effective_thresholds("file_read");
        assert_eq!(warn, 3);
        assert_eq!(block, 5);
    }

    #[test]
    fn test_effective_thresholds_poll_tool() {
        let guard = LoopGuard::with_config(GuardConfig {
            max_iterations: 50,
            warn_threshold: 3,
            block_threshold: 5,
            circuit_breaker_threshold: 30,
        });
        // Poll tools get both thresholds multiplied by 3.
        let (warn, block) = guard.effective_thresholds("shell_exec");
        assert_eq!(warn, 9);
        assert_eq!(block, 15);
    }
}