1use std::collections::{BTreeSet, HashMap};
5use std::time::{Duration, Instant};
6
/// Shortens a path inside a user's home directory to a `~`-relative form.
///
/// Paths of the form `/Users/<name>/…` or `/home/<name>/…` become `~/…`.
/// The bare home directory itself (`/Users/<name>` or `/home/<name>`) becomes `~`.
/// Any other path is returned unchanged. This includes a path whose `<name>`
/// segment is empty, such as `/Users/` or `/home//etc`.
///
/// The `<name>` segment is not compared against the current user, so every
/// user's home directory is abbreviated to `~`.
fn abbreviate_home_path(path: &str) -> String {
    for prefix in ["/Users/", "/home/"] {
        if let Some(rest) = path.strip_prefix(prefix) {
            let (user, suffix) = match rest.split_once('/') {
                Some((user, suffix)) => (user, Some(suffix)),
                None => (rest, None),
            };
            // `/Users/` or `/home//x` has no user segment, so it is not a home directory.
            if user.is_empty() {
                return path.to_string();
            }
            return match suffix {
                Some(suffix) => format!("~/{suffix}"),
                None => "~".to_string(),
            };
        }
    }
    path.to_string()
}
18
/// Statistics about the tool calls made during a single turn.
///
/// The counters and the file/command collections are filled in by
/// [`TurnTracker::record_tool_start`] and [`TurnTracker::record_tool_end`].
/// They are cleared by [`TurnTracker::reset`].
#[derive(Debug, Clone)]
pub struct TurnTracker {
    /// Moment the current turn began (set on construction and on reset).
    pub started_at: Instant,
    /// Number of tool-call start events recorded.
    pub tool_calls_started: u32,
    /// Number of tool-call end events recorded.
    pub tool_calls_completed: u32,
    /// Number of tool-call end events that were reported as errors.
    pub tool_errors: u32,
    /// Home-abbreviated paths passed to the `read` tool, deduplicated and sorted.
    pub files_read: BTreeSet<String>,
    /// Home-abbreviated paths passed to the `edit`, `multi_edit` or `write` tools.
    pub files_written: BTreeSet<String>,
    /// Home-abbreviated paths passed to the `write` tool.
    pub files_created: BTreeSet<String>,
    /// `bash` commands in call order, each truncated to its first 80 characters.
    pub commands_run: Vec<String>,
    /// Number of search-like tool calls, including search-like `bash` commands.
    pub searches: u32,

    /// Tool calls started but not yet ended: call id -> tool name.
    pending: HashMap<String, String>,
}
43
44impl TurnTracker {
45 pub fn new() -> Self {
47 Self {
48 started_at: Instant::now(),
49 tool_calls_started: 0,
50 tool_calls_completed: 0,
51 tool_errors: 0,
52 files_read: BTreeSet::new(),
53 files_written: BTreeSet::new(),
54 files_created: BTreeSet::new(),
55 commands_run: Vec::new(),
56 searches: 0,
57 pending: HashMap::new(),
58 }
59 }
60
61 pub fn reset(&mut self) {
63 self.started_at = Instant::now();
64 self.tool_calls_started = 0;
65 self.tool_calls_completed = 0;
66 self.tool_errors = 0;
67 self.files_read.clear();
68 self.files_written.clear();
69 self.files_created.clear();
70 self.commands_run.clear();
71 self.searches = 0;
72 self.pending.clear();
73 }
74
75 pub fn elapsed(&self) -> Duration {
77 self.started_at.elapsed()
78 }
79
80 pub fn record_tool_start(&mut self, tool_call_id: &str, name: &str, args: &serde_json::Value) {
85 self.tool_calls_started += 1;
86 self.pending
87 .insert(tool_call_id.to_string(), name.to_string());
88 self.classify(name, args);
89 }
90
91 pub fn record_tool_end(&mut self, tool_call_id: &str, is_error: bool) {
93 self.pending.remove(tool_call_id);
94 self.tool_calls_completed += 1;
95 if is_error {
96 self.tool_errors += 1;
97 }
98 }
99
100 fn classify(&mut self, name: &str, args: &serde_json::Value) {
103 match name {
104 "read" => {
105 if let Some(path) = args["path"].as_str() {
106 self.files_read.insert(abbreviate_home_path(path));
107 }
108 }
109 "edit" | "multi_edit" => {
110 if let Some(path) = args["path"].as_str() {
111 self.files_written.insert(abbreviate_home_path(path));
112 }
113 }
114 "write" => {
115 if let Some(path) = args["path"].as_str() {
116 let path = abbreviate_home_path(path);
117 self.files_written.insert(path.clone());
118 self.files_created.insert(path);
119 }
120 }
121 "bash" => {
122 if let Some(cmd) = args["command"].as_str() {
123 let truncated = cmd.chars().take(80).collect::<String>();
124 self.commands_run.push(truncated);
125 if cmd.trim_start().starts_with("grep ")
126 || cmd.trim_start().starts_with("find ")
127 || cmd.trim_start() == "find"
128 || cmd.trim_start().starts_with("ls ")
129 || cmd.trim_start() == "ls"
130 {
131 self.searches += 1;
132 }
133 }
134 }
135 "grep" | "find" | "probe_search" | "probe_extract" => {
136 self.searches += 1;
137 }
138 _ => {
139 }
141 }
142 }
143}
144
145impl Default for TurnTracker {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn abbreviates_home_directories_only() {
        assert_eq!(abbreviate_home_path("/Users/test/foo.txt"), "~/foo.txt");
        assert_eq!(abbreviate_home_path("/home/alice/src/lib.rs"), "~/src/lib.rs");
        assert_eq!(abbreviate_home_path("/home/alice"), "~");
        assert_eq!(abbreviate_home_path("/tmp/a.txt"), "/tmp/a.txt");
        assert_eq!(abbreviate_home_path("/homework/x"), "/homework/x");
    }

    #[test]
    fn classifies_read_and_write_tools() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/Users/test/foo.txt"}));
        tracker.record_tool_start("id-2", "write", &json!({"path": "/Users/test/bar.txt"}));
        tracker.record_tool_start("id-3", "edit", &json!({"path": "/Users/test/baz.txt"}));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.contains("~/foo.txt"));
        assert!(tracker.files_written.contains("~/bar.txt"));
        assert!(tracker.files_created.contains("~/bar.txt"));
        assert!(tracker.files_written.contains("~/baz.txt"));
        assert!(!tracker.files_created.contains("~/baz.txt"));
        assert_eq!(tracker.pending.len(), 3);

        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", false);
        tracker.record_tool_end("id-3", true);

        assert_eq!(tracker.tool_calls_completed, 3);
        assert_eq!(tracker.tool_errors, 1);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn ignores_missing_or_non_string_arguments() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({}));
        tracker.record_tool_start("id-2", "write", &json!({"path": 42}));
        tracker.record_tool_start("id-3", "bash", &json!({"cmd": "ls"}));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.commands_run.is_empty());
        assert_eq!(tracker.searches, 0);
    }

    #[test]
    fn classifies_bash_and_search_tools() {
        let mut tracker = TurnTracker::new();

        let long_cmd = "a".repeat(120);
        tracker.record_tool_start("id-bash", "bash", &json!({"command": long_cmd}));
        tracker.record_tool_start("id-bash-grep", "bash", &json!({"command": "grep foo ."}));
        tracker.record_tool_start("id-grep", "grep", &json!({"pattern": "foo"}));
        tracker.record_tool_start("id-find", "find", &json!({"pattern": "*.rs"}));
        tracker.record_tool_start(
            "id-probe",
            "probe_search",
            &json!({"query": "error handling"}),
        );
        tracker.record_tool_start("id-probe2", "probe_extract", &json!({"targets": []}));

        assert_eq!(tracker.commands_run.len(), 2);
        assert_eq!(tracker.commands_run[0].len(), 80);
        assert_eq!(tracker.searches, 5);
    }

    #[test]
    fn bash_search_detection_uses_program_name() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("a", "bash", &json!({"command": "cargo build"}));
        tracker.record_tool_start("b", "bash", &json!({"command": "  ls -la"}));
        tracker.record_tool_start("c", "bash", &json!({"command": "find"}));
        tracker.record_tool_start("d", "bash", &json!({"command": "lsof -i"}));

        assert_eq!(tracker.commands_run.len(), 4);
        assert_eq!(tracker.searches, 2);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/tmp/a.txt"}));
        tracker.record_tool_start("id-2", "bash", &json!({"command": "ls"}));
        tracker.record_tool_start("id-3", "write", &json!({"path": "/tmp/b.txt"}));
        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", true);
        // id-3 is deliberately left pending so reset has to clear it.

        assert_eq!(tracker.searches, 1);
        assert_eq!(tracker.pending.len(), 1);
        assert!(!tracker.files_written.is_empty());
        assert!(!tracker.files_created.is_empty());

        tracker.reset();

        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.tool_calls_completed, 0);
        assert_eq!(tracker.tool_errors, 0);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.files_created.is_empty());
        assert!(tracker.commands_run.is_empty());
        assert_eq!(tracker.searches, 0);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn deduplicates_file_paths() {
        let mut tracker = TurnTracker::new();

        for i in 0..5 {
            tracker.record_tool_start(
                &format!("id-{i}"),
                "read",
                &json!({"path": "/Users/test/same.txt"}),
            );
        }

        assert_eq!(tracker.files_read.len(), 1);
    }

    #[test]
    fn elapsed_increases_over_time() {
        let tracker = TurnTracker::new();
        let d1 = tracker.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let d2 = tracker.elapsed();
        assert!(d2 > d1);
    }
}