// aagt_core/agent/history.rs
use std::collections::{HashMap, HashSet};

use tracing::debug;

/// A single remembered tool invocation: which tool was called and with what input.
#[derive(Clone, Debug)]
pub struct CallRecord {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Raw input the tool was called with; later inputs to the same tool are
    /// compared against it when checking for repeated calls.
    pub input: String,
}

13#[derive(Debug, Default, Clone)]
15pub struct QueryHistory {
16 records: Vec<CallRecord>,
17 counts: HashMap<String, usize>,
18}
19
20impl QueryHistory {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn record(&mut self, tool_name: String, input: String) {
27 let count = self.counts.entry(tool_name.clone()).or_insert(0);
28 *count += 1;
29 self.records.push(CallRecord { tool_name, input });
30 }
31
32 pub fn calculate_similarity(s1: &str, s2: &str) -> f64 {
35 let tokens1: HashSet<_> = s1
36 .split_whitespace()
37 .map(|s| s.to_lowercase().replace(|c: char| !c.is_alphanumeric(), ""))
38 .filter(|s| s.len() > 2)
39 .collect();
40
41 let tokens2: HashSet<_> = s2
42 .split_whitespace()
43 .map(|s| s.to_lowercase().replace(|c: char| !c.is_alphanumeric(), ""))
44 .filter(|s| s.len() > 2)
45 .collect();
46
47 if tokens1.is_empty() && tokens2.is_empty() {
48 return 1.0;
49 }
50
51 let intersection: HashSet<_> = tokens1.intersection(&tokens2).collect();
52 let union: HashSet<_> = tokens1.union(&tokens2).collect();
53
54 intersection.len() as f64 / union.len() as f64
55 }
56
57 pub fn check_loop(&self, tool_name: &str, input: &str, threshold: f64) -> Option<String> {
60 for record in self.records.iter().rev() {
61 if record.tool_name == tool_name {
62 let similarity = Self::calculate_similarity(&record.input, input);
63 if similarity >= threshold {
64 debug!(tool = %tool_name, similarity = %similarity, "Detected potential loop call");
65 return Some(format!(
66 "WARNING: This call to '{}' is {:.0}% similar to a previous call. \
67 If the previous call did not yield the results you wanted, try a different approach, \
68 different keywords, or a different tool instead of repeating the same action.",
69 tool_name, similarity * 100.0
70 ));
71 }
72 }
73 }
74 None
75 }
76
77 pub fn get_count(&self, tool_name: &str) -> usize {
78 *self.counts.get(tool_name).unwrap_or(&0)
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_similarity() {
        // Differs only by the short word "in", which tokenization drops.
        let s1 = "Search for Apple stock price in 2024";
        let s2 = "Search for Apple stock price 2024";
        let sim = QueryHistory::calculate_similarity(s1, s2);
        assert!(sim > 0.8);

        // No words in common.
        let s3 = "Get latest news about Tesla";
        let sim2 = QueryHistory::calculate_similarity(s1, s3);
        assert!(sim2 < 0.2);
    }

    #[test]
    fn test_similarity_bounds() {
        // Case and punctuation are ignored.
        assert_eq!(QueryHistory::calculate_similarity("Apple stock", "apple STOCK!"), 1.0);
        assert_eq!(QueryHistory::calculate_similarity("Apple stock", "Tesla news"), 0.0);
        assert_eq!(QueryHistory::calculate_similarity("", ""), 1.0);
    }

    #[test]
    fn test_loop_detection() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "Apple stock".to_string());

        let result = history.check_loop("search", "Apple stock price", 0.6);
        assert!(result.is_some());
        // {apple, stock} vs {apple, stock, price} => 2/3, rendered as 67%.
        assert!(result.unwrap().contains("'search' is 67% similar"));

        let result2 = history.check_loop("search", "Tesla news", 0.6);
        assert!(result2.is_none());
    }

    #[test]
    fn test_loop_detection_is_per_tool() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "Apple stock".to_string());
        assert!(history.check_loop("fetch", "Apple stock", 0.6).is_none());

        // An empty history never reports a loop, even with a zero threshold.
        assert!(QueryHistory::new().check_loop("search", "Apple stock", 0.0).is_none());
    }

    #[test]
    fn test_get_count() {
        let mut history = QueryHistory::new();
        history.record("search".to_string(), "a".to_string());
        history.record("search".to_string(), "b".to_string());
        history.record("fetch".to_string(), "c".to_string());
        assert_eq!(history.get_count("search"), 2);
        assert_eq!(history.get_count("fetch"), 1);
        assert_eq!(history.get_count("unknown"), 0);
    }
}