use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A condition reported by [`EscalationDetector::record_tool_call`] when a
/// tool call looks problematic.
///
/// Variant names are serialized in `snake_case`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EscalationTrigger {
    /// `tool` has failed `consecutive` times in a row, reaching the
    /// configured failure threshold.
    RepeatedFailure { tool: String, consecutive: u32 },
    /// `tool` was requested but is not in the caller-supplied list of known
    /// tools.
    HallucinatedTool { tool: String },
    /// `tool` was called with the same argument signature `occurrences`
    /// times within the recent-call window.
    LoopDetected { tool: String, occurrences: u32 },
}
impl std::fmt::Display for EscalationTrigger {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::RepeatedFailure { tool, consecutive } => {
write!(f, "repeated_failure(tool={}, n={})", tool, consecutive)
}
Self::HallucinatedTool { tool } => write!(f, "hallucinated_tool({})", tool),
Self::LoopDetected { tool, occurrences } => {
write!(f, "loop_detected(tool={}, n={})", tool, occurrences)
}
}
}
}
/// The response selected in [`EscalationConfig`] for when a trigger fires.
///
/// Variant names are serialized in `snake_case`. The detector in this file
/// only stores the chosen action; how each variant is carried out is up to
/// the code that reads the configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EscalationAction {
    /// NOTE(review): presumably halts the run — confirm against the consumer.
    Stop,
    /// NOTE(review): presumably feeds a warning (see
    /// `EscalationDetector::warning_message`) back to the model — confirm.
    InjectWarning,
    /// NOTE(review): presumably emits an event for a human reviewer — confirm.
    HumanReviewEvent,
}
/// Tunable thresholds for [`EscalationDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EscalationConfig {
    /// Number of consecutive failures of one tool at which
    /// `EscalationTrigger::RepeatedFailure` is reported.
    pub repeated_failure_threshold: u32,
    /// Maximum number of recent calls kept for loop detection.
    pub loop_detection_window: usize,
    /// Number of identical (tool, argument signature) calls inside the
    /// window at which `EscalationTrigger::LoopDetected` is reported.
    pub loop_occurrence_threshold: u32,
    /// Action associated with a fired trigger. The detector itself does not
    /// read this field.
    pub action: EscalationAction,
}
impl Default for EscalationConfig {
fn default() -> Self {
Self {
repeated_failure_threshold: 3,
loop_detection_window: 10,
loop_occurrence_threshold: 3,
action: EscalationAction::InjectWarning,
}
}
}
/// Stateful tracker that inspects each tool call and reports escalation
/// triggers.
///
/// `Debug` and `Clone` are derived so the detector can be logged and
/// snapshotted like the other public types in this file.
#[derive(Clone, Debug)]
pub struct EscalationDetector {
    /// Thresholds supplied at construction time.
    config: EscalationConfig,
    /// Current run of consecutive failures, keyed by tool name.
    consecutive_failures: HashMap<String, u32>,
    /// Most recent `(tool name, argument signature)` pairs, oldest first.
    recent_calls: Vec<(String, String)>,
}
impl EscalationDetector {
    /// Creates a detector with empty history that applies `config`.
    pub fn new(config: EscalationConfig) -> Self {
        Self {
            config,
            consecutive_failures: HashMap::new(),
            recent_calls: Vec::new(),
        }
    }

    /// Returns the configuration this detector was constructed with.
    pub fn config(&self) -> &EscalationConfig {
        &self.config
    }

    /// Records one tool call and returns the first escalation condition it
    /// trips, if any.
    ///
    /// Checks run in this order and the first hit returns immediately:
    ///
    /// 1. `HallucinatedTool` — `tool_name` is not in `known_tools`. Nothing
    ///    is recorded for such a call.
    /// 2. `RepeatedFailure` — `success` is `false` and the tool's run of
    ///    consecutive failures has reached `repeated_failure_threshold`.
    ///    A call that trips this check is not added to the recent-call
    ///    window. A successful call ends the tool's failure run.
    /// 3. `LoopDetected` — after adding the call to the window (and evicting
    ///    the oldest entry when the window is over capacity), the number of
    ///    entries with the same tool name and `arg_signature` has reached
    ///    `loop_occurrence_threshold`.
    ///
    /// Returns `None` when no check fires.
    pub fn record_tool_call(
        &mut self,
        tool_name: &str,
        success: bool,
        known_tools: &[&str],
        arg_signature: &str,
    ) -> Option<EscalationTrigger> {
        if !known_tools.contains(&tool_name) {
            return Some(EscalationTrigger::HallucinatedTool {
                tool: tool_name.to_string(),
            });
        }

        if success {
            // A missing entry means "zero consecutive failures", so dropping
            // the entry resets the run without allocating a key.
            self.consecutive_failures.remove(tool_name);
        } else {
            let streak = {
                let count = self
                    .consecutive_failures
                    .entry(tool_name.to_string())
                    .or_insert(0);
                // Saturate rather than overflow: the counter keeps growing
                // for as long as the caller keeps reporting failures.
                *count = count.saturating_add(1);
                *count
            };
            if streak >= self.config.repeated_failure_threshold {
                return Some(EscalationTrigger::RepeatedFailure {
                    tool: tool_name.to_string(),
                    consecutive: streak,
                });
            }
        }

        self.recent_calls
            .push((tool_name.to_string(), arg_signature.to_string()));
        if self.recent_calls.len() > self.config.loop_detection_window {
            // At most one entry is added per call, so one eviction is enough
            // to bring the window back to capacity.
            self.recent_calls.remove(0);
        }

        // Count in `u32` directly so no narrowing cast is needed.
        let occurrences = self
            .recent_calls
            .iter()
            .filter(|(tool, args)| tool.as_str() == tool_name && args.as_str() == arg_signature)
            .fold(0u32, |seen, _| seen.saturating_add(1));
        if occurrences >= self.config.loop_occurrence_threshold {
            return Some(EscalationTrigger::LoopDetected {
                tool: tool_name.to_string(),
                occurrences,
            });
        }
        None
    }

    /// Forgets all failure runs and the recent-call window; the
    /// configuration is kept.
    pub fn reset(&mut self) {
        self.consecutive_failures.clear();
        self.recent_calls.clear();
    }

    /// Builds the `[SYSTEM WARNING] ...` text that describes `trigger`.
    pub fn warning_message(trigger: &EscalationTrigger) -> String {
        match trigger {
            EscalationTrigger::RepeatedFailure { tool, consecutive } => format!(
                "[SYSTEM WARNING] Tool '{}' has failed {} consecutive times. \
                 Consider an alternative approach or stopping.",
                tool, consecutive
            ),
            EscalationTrigger::HallucinatedTool { tool } => format!(
                "[SYSTEM WARNING] Tool '{}' does not exist. \
                 Please use only the tools that are available to you.",
                tool
            ),
            EscalationTrigger::LoopDetected { tool, occurrences } => format!(
                "[SYSTEM WARNING] Possible circular reasoning detected: \
                 tool '{}' has been called with identical arguments {} times recently. \
                 Try a different approach.",
                tool, occurrences
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool names treated as known in every test below.
    const TOOLS: &[&str] = &["search", "calculator", "file_read"];

    /// Builds a detector from `config`.
    fn detector_with(config: EscalationConfig) -> EscalationDetector {
        EscalationDetector::new(config)
    }

    /// Successful calls to different tools raise nothing.
    #[test]
    fn test_no_trigger_on_normal_calls() {
        let mut detector = detector_with(EscalationConfig::default());

        let first = detector.record_tool_call("search", true, TOOLS, r#"{"q":"hello"}"#);
        assert!(first.is_none());

        let second = detector.record_tool_call("calculator", true, TOOLS, r#"{"expr":"1+1"}"#);
        assert!(second.is_none());
    }

    /// A tool outside the known list is reported as hallucinated.
    #[test]
    fn test_hallucinated_tool() {
        let mut detector = detector_with(EscalationConfig::default());

        let outcome = detector.record_tool_call("nonexistent_tool", true, TOOLS, "{}");

        assert!(matches!(
            outcome,
            Some(EscalationTrigger::HallucinatedTool { .. })
        ));
    }

    /// The second failure in a row trips a threshold of two.
    #[test]
    fn test_repeated_failure_trigger() {
        let mut detector = detector_with(EscalationConfig {
            repeated_failure_threshold: 2,
            ..EscalationConfig::default()
        });

        let first = detector.record_tool_call("search", false, TOOLS, "{}");
        assert!(first.is_none());

        let second = detector.record_tool_call("search", false, TOOLS, "{}");
        assert!(matches!(
            second,
            Some(EscalationTrigger::RepeatedFailure { consecutive: 2, .. })
        ));
    }

    /// A success in between restarts the failure count.
    #[test]
    fn test_success_resets_failure_count() {
        let mut detector = detector_with(EscalationConfig {
            repeated_failure_threshold: 2,
            ..EscalationConfig::default()
        });

        // Only the final outcome is checked; the earlier results are ignored.
        let _ = detector.record_tool_call("search", false, TOOLS, "{}");
        let _ = detector.record_tool_call("search", true, TOOLS, "{}");
        let _ = detector.record_tool_call("search", false, TOOLS, "{}");
        let last = detector.record_tool_call("search", false, TOOLS, "{}");

        assert!(matches!(
            last,
            Some(EscalationTrigger::RepeatedFailure { consecutive: 2, .. })
        ));
    }

    /// Two identical calls trip a loop threshold of two.
    #[test]
    fn test_loop_detection() {
        let mut detector = detector_with(EscalationConfig {
            loop_occurrence_threshold: 2,
            loop_detection_window: 5,
            ..EscalationConfig::default()
        });
        let same_args = r#"{"q":"same"}"#;

        let first = detector.record_tool_call("search", true, TOOLS, same_args);
        assert!(first.is_none());

        let second = detector.record_tool_call("search", true, TOOLS, same_args);
        assert!(matches!(
            second,
            Some(EscalationTrigger::LoopDetected { occurrences: 2, .. })
        ));
    }

    /// The same tool with different arguments is not a loop.
    #[test]
    fn test_different_args_no_loop() {
        let mut detector = detector_with(EscalationConfig {
            loop_occurrence_threshold: 2,
            ..EscalationConfig::default()
        });

        let _ = detector.record_tool_call("search", true, TOOLS, r#"{"q":"a"}"#);
        let outcome = detector.record_tool_call("search", true, TOOLS, r#"{"q":"b"}"#);

        assert!(outcome.is_none());
    }

    /// Every trigger variant yields some warning text.
    #[test]
    fn test_warning_messages_non_empty() {
        let samples = [
            EscalationTrigger::HallucinatedTool { tool: "x".into() },
            EscalationTrigger::RepeatedFailure {
                tool: "x".into(),
                consecutive: 3,
            },
            EscalationTrigger::LoopDetected {
                tool: "x".into(),
                occurrences: 3,
            },
        ];

        for sample in samples.iter() {
            let message = EscalationDetector::warning_message(sample);
            assert!(!message.is_empty());
        }
    }
}