1use std::collections::{BTreeSet, HashMap};
5use std::time::{Duration, Instant};
6
/// Abbreviates a path that lives under a conventional home-directory root, for display.
///
/// `/Users/<name>/<rest>` (macOS) and `/home/<name>/<rest>` (Linux) become `~/<rest>`, and a
/// bare `/Users/<name>` or `/home/<name>` becomes `~`. Any user's home is abbreviated, not only
/// the current user's. Paths outside those roots, or whose user component is empty
/// (e.g. `/home/` or `/home//x`), are returned unchanged.
fn abbreviate_home_path(path: &str) -> String {
    for root in ["/Users/", "/home/"] {
        if let Some(rest) = path.strip_prefix(root) {
            let (user, suffix) = match rest.split_once('/') {
                Some((user, suffix)) => (user, Some(suffix)),
                None => (rest, None),
            };
            // `/home/` or `/home//x` names no user, so it is not a home directory. The roots
            // are disjoint, so no other root can match once this one did.
            if user.is_empty() {
                break;
            }
            return match suffix {
                Some(suffix) => format!("~/{suffix}"),
                None => "~".to_string(),
            };
        }
    }
    path.to_string()
}
18
/// Accumulates statistics about the tool calls made during a turn.
///
/// Counters and collections are cleared by `clear_counts`; `reset` additionally restarts
/// the clock.
#[derive(Debug)]
pub struct TurnTracker {
    /// When the current turn began; set by `new`, `start_now` and `reset`.
    pub started_at: Instant,
    /// Number of `record_tool_start` calls since the last clear.
    pub tool_calls_started: u32,
    /// Number of `record_tool_end` calls since the last clear. Ends are counted even for ids
    /// that were never started.
    pub tool_calls_completed: u32,
    /// How many of the completed calls reported an error.
    pub tool_errors: u32,
    /// Home-abbreviated paths passed to the `read` tool (deduplicated, sorted).
    pub files_read: BTreeSet<String>,
    /// Home-abbreviated paths targeted by `edit`, `multi_edit` or `write` (deduplicated, sorted).
    pub files_written: BTreeSet<String>,
    /// Home-abbreviated paths targeted by `write`. Whether the file already existed is not
    /// checked, so overwritten files appear here too.
    pub files_created: BTreeSet<String>,
    /// Commands passed to `bash`, truncated to their first 80 characters, in call order
    /// (duplicates kept).
    pub commands_run: Vec<String>,
    /// Count of search-like calls: the `grep`, `find`, `probe_search` and `probe_extract` tools,
    /// plus `bash` commands that invoke `grep`, `find` or `ls`.
    pub searches: u32,

    /// Calls started but not yet ended, keyed by tool-call id; the value is the tool name.
    pending: HashMap<String, String>,
}
43
44impl TurnTracker {
45 pub fn new() -> Self {
47 Self {
48 started_at: Instant::now(),
49 tool_calls_started: 0,
50 tool_calls_completed: 0,
51 tool_errors: 0,
52 files_read: BTreeSet::new(),
53 files_written: BTreeSet::new(),
54 files_created: BTreeSet::new(),
55 commands_run: Vec::new(),
56 searches: 0,
57 pending: HashMap::new(),
58 }
59 }
60
61 pub fn start_now(&mut self) {
62 self.started_at = Instant::now();
63 }
64
65 pub fn clear_counts(&mut self) {
66 self.tool_calls_started = 0;
67 self.tool_calls_completed = 0;
68 self.tool_errors = 0;
69 self.files_read.clear();
70 self.files_written.clear();
71 self.files_created.clear();
72 self.commands_run.clear();
73 self.searches = 0;
74 self.pending.clear();
75 }
76
77 pub fn reset(&mut self) {
79 self.started_at = Instant::now();
80 self.tool_calls_started = 0;
81 self.tool_calls_completed = 0;
82 self.tool_errors = 0;
83 self.files_read.clear();
84 self.files_written.clear();
85 self.files_created.clear();
86 self.commands_run.clear();
87 self.searches = 0;
88 self.pending.clear();
89 }
90
91 pub fn elapsed(&self) -> Duration {
93 self.started_at.elapsed()
94 }
95
96 pub fn record_tool_start(&mut self, tool_call_id: &str, name: &str, args: &serde_json::Value) {
101 self.tool_calls_started += 1;
102 self.pending
103 .insert(tool_call_id.to_string(), name.to_string());
104 self.classify(name, args);
105 }
106
107 pub fn record_tool_end(&mut self, tool_call_id: &str, is_error: bool) {
109 self.pending.remove(tool_call_id);
110 self.tool_calls_completed += 1;
111 if is_error {
112 self.tool_errors += 1;
113 }
114 }
115
116 fn classify(&mut self, name: &str, args: &serde_json::Value) {
119 match name {
120 "read" => {
121 if let Some(path) = args["path"].as_str() {
122 self.files_read.insert(abbreviate_home_path(path));
123 }
124 }
125 "edit" | "multi_edit" => {
126 if let Some(path) = args["path"].as_str() {
127 self.files_written.insert(abbreviate_home_path(path));
128 }
129 }
130 "write" => {
131 if let Some(path) = args["path"].as_str() {
132 let path = abbreviate_home_path(path);
133 self.files_written.insert(path.clone());
134 self.files_created.insert(path);
135 }
136 }
137 "bash" => {
138 if let Some(cmd) = args["command"].as_str() {
139 let truncated = cmd.chars().take(80).collect::<String>();
140 self.commands_run.push(truncated);
141 if cmd.trim_start().starts_with("grep ")
142 || cmd.trim_start().starts_with("find ")
143 || cmd.trim_start() == "find"
144 || cmd.trim_start().starts_with("ls ")
145 || cmd.trim_start() == "ls"
146 {
147 self.searches += 1;
148 }
149 }
150 }
151 "grep" | "find" | "probe_search" | "probe_extract" => {
152 self.searches += 1;
153 }
154 _ => {
155 }
157 }
158 }
159}
160
161impl Default for TurnTracker {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
#[cfg(test)]
mod tests {
    // Unit tests for path abbreviation and TurnTracker classification/bookkeeping.
    use super::*;
    use serde_json::json;

    #[test]
    fn abbreviates_home_directories() {
        assert_eq!(abbreviate_home_path("/Users/test/foo.txt"), "~/foo.txt");
        assert_eq!(abbreviate_home_path("/home/alice/src/lib.rs"), "~/src/lib.rs");
        assert_eq!(abbreviate_home_path("/home/alice"), "~");
        assert_eq!(abbreviate_home_path("/tmp/a.txt"), "/tmp/a.txt");
        // Only a full `/home/` component matches, not a longer directory name.
        assert_eq!(abbreviate_home_path("/homework/x"), "/homework/x");
    }

    #[test]
    fn classifies_read_and_write_tools() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/Users/test/foo.txt"}));
        tracker.record_tool_start("id-2", "write", &json!({"path": "/Users/test/bar.txt"}));
        tracker.record_tool_start("id-3", "edit", &json!({"path": "/Users/test/baz.txt"}));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.contains("~/foo.txt"));
        assert!(tracker.files_written.contains("~/bar.txt"));
        assert!(tracker.files_created.contains("~/bar.txt"));
        assert!(tracker.files_written.contains("~/baz.txt"));
        assert!(!tracker.files_created.contains("~/baz.txt"));

        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", false);
        tracker.record_tool_end("id-3", true);

        assert_eq!(tracker.tool_calls_completed, 3);
        assert_eq!(tracker.tool_errors, 1);
    }

    #[test]
    fn classifies_bash_and_search_tools() {
        let mut tracker = TurnTracker::new();

        let long_cmd = "a".repeat(120);
        tracker.record_tool_start("id-bash", "bash", &json!({"command": long_cmd}));
        tracker.record_tool_start("id-bash-grep", "bash", &json!({"command": "grep foo ."}));
        tracker.record_tool_start("id-grep", "grep", &json!({"pattern": "foo"}));
        tracker.record_tool_start("id-find", "find", &json!({"pattern": "*.rs"}));
        tracker.record_tool_start(
            "id-probe",
            "probe_search",
            &json!({"query": "error handling"}),
        );
        tracker.record_tool_start("id-probe2", "probe_extract", &json!({"targets": []}));

        assert_eq!(tracker.commands_run.len(), 2);
        // Long commands are truncated to 80 characters.
        assert_eq!(tracker.commands_run[0].len(), 80);
        // One bash `grep` plus the four dedicated search tools; the long command is not a search.
        assert_eq!(tracker.searches, 5);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/tmp/a.txt"}));
        tracker.record_tool_start("id-2", "bash", &json!({"command": "ls"}));
        tracker.record_tool_start("id-3", "write", &json!({"path": "/tmp/b.txt"}));
        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", true);
        // id-3 is left in flight so there is pending state for reset to clear.
        assert!(!tracker.pending.is_empty());

        let before = tracker.started_at;
        tracker.reset();

        assert!(tracker.started_at >= before);
        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.tool_calls_completed, 0);
        assert_eq!(tracker.tool_errors, 0);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.files_created.is_empty());
        assert!(tracker.commands_run.is_empty());
        assert_eq!(tracker.searches, 0);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn clear_counts_keeps_start_time() {
        let mut tracker = TurnTracker::new();
        let started = tracker.started_at;

        tracker.record_tool_start("id-1", "grep", &json!({"pattern": "x"}));
        tracker.clear_counts();

        assert_eq!(tracker.started_at, started);
        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.searches, 0);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn deduplicates_file_paths() {
        let mut tracker = TurnTracker::new();

        for i in 0..5 {
            tracker.record_tool_start(
                &format!("id-{i}"),
                "read",
                &json!({"path": "/Users/test/same.txt"}),
            );
        }

        assert_eq!(tracker.files_read.len(), 1);
    }

    #[test]
    fn elapsed_increases_over_time() {
        let tracker = TurnTracker::new();
        let d1 = tracker.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let d2 = tracker.elapsed();
        assert!(d2 > d1);
    }
}