Skip to main content

aagt_core/agent/
history.rs

1//! Tool call history and similarity detection for loop prevention.
2
3use std::collections::{HashMap, HashSet};
4use tracing::debug;
5
/// A single recorded tool invocation: which tool was called and with what input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallRecord {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Input the tool was invoked with, stored verbatim.
    pub input: String,
}
12
13/// Tracks history of tool calls to detect repeating patterns
14#[derive(Debug, Default, Clone)]
15pub struct QueryHistory {
16    records: Vec<CallRecord>,
17    counts: HashMap<String, usize>,
18}
19
20impl QueryHistory {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    /// Add a call to the history
26    pub fn record(&mut self, tool_name: String, input: String) {
27        let count = self.counts.entry(tool_name.clone()).or_insert(0);
28        *count += 1;
29        self.records.push(CallRecord { tool_name, input });
30    }
31
32    /// Calculate Jaccard similarity between two strings
33    /// Based on word overlap
34    pub fn calculate_similarity(s1: &str, s2: &str) -> f64 {
35        let tokens1: HashSet<_> = s1
36            .split_whitespace()
37            .map(|s| s.to_lowercase().replace(|c: char| !c.is_alphanumeric(), ""))
38            .filter(|s| s.len() > 2)
39            .collect();
40
41        let tokens2: HashSet<_> = s2
42            .split_whitespace()
43            .map(|s| s.to_lowercase().replace(|c: char| !c.is_alphanumeric(), ""))
44            .filter(|s| s.len() > 2)
45            .collect();
46
47        if tokens1.is_empty() && tokens2.is_empty() {
48            return 1.0;
49        }
50
51        let intersection: HashSet<_> = tokens1.intersection(&tokens2).collect();
52        let union: HashSet<_> = tokens1.union(&tokens2).collect();
53
54        intersection.len() as f64 / union.len() as f64
55    }
56
57    /// Check if a tool call is too similar to previous calls of the same tool
58    /// Returns a message suggesting an alternative if similarity is high.
59    pub fn check_loop(&self, tool_name: &str, input: &str, threshold: f64) -> Option<String> {
60        for record in self.records.iter().rev() {
61            if record.tool_name == tool_name {
62                let similarity = Self::calculate_similarity(&record.input, input);
63                if similarity >= threshold {
64                    debug!(tool = %tool_name, similarity = %similarity, "Detected potential loop call");
65                    return Some(format!(
66                        "WARNING: This call to '{}' is {:.0}% similar to a previous call. \
67                        If the previous call did not yield the results you wanted, try a different approach, \
68                        different keywords, or a different tool instead of repeating the same action.",
69                        tool_name, similarity * 100.0
70                    ));
71                }
72            }
73        }
74        None
75    }
76
77    pub fn get_count(&self, tool_name: &str) -> usize {
78        *self.counts.get(tool_name).unwrap_or(&0)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_similarity() {
        let s1 = "Search for Apple stock price in 2024";
        let s2 = "Search for Apple stock price 2024";
        let sim = QueryHistory::calculate_similarity(s1, s2);
        assert!(sim > 0.8);

        let s3 = "Get latest news about Tesla";
        let sim2 = QueryHistory::calculate_similarity(s1, s3);
        assert!(sim2 < 0.2);
    }

    #[test]
    fn test_loop_detection() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "Apple stock".to_string());

        let result = history.check_loop("search", "Apple stock price", 0.6);
        assert!(result.is_some());

        let result2 = history.check_loop("search", "Tesla news", 0.6);
        assert!(result2.is_none());
    }

    #[test]
    fn test_identical_input_warning_text() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "Apple stock".to_string());

        let warning = history
            .check_loop("search", "Apple stock", 0.9)
            .expect("repeating the exact same input must be flagged");
        assert!(warning.contains("'search'"));
        assert!(warning.contains("100%"));
    }

    #[test]
    fn test_loop_detection_is_per_tool() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "Apple stock".to_string());

        // Same input sent to a different tool is not a repeat of the earlier call.
        assert!(history.check_loop("news", "Apple stock", 0.6).is_none());
    }

    #[test]
    fn test_get_count() {
        let mut history = QueryHistory::new();
        assert_eq!(history.get_count("search"), 0);

        history.record("search".to_string(), "Apple stock".to_string());
        history.record("search".to_string(), "Tesla news".to_string());
        history.record("news".to_string(), "Apple".to_string());

        assert_eq!(history.get_count("search"), 2);
        assert_eq!(history.get_count("news"), 1);
        assert_eq!(history.get_count("missing"), 0);
    }
}