Skip to main content

imp_tui/
turn_tracker.rs

1//! Per-turn activity tracker that accumulates statistics while an agent turn
2//! is in progress. Feeds into progress indicators and post-turn summaries.
3
4use std::collections::{BTreeSet, HashMap};
5use std::time::{Duration, Instant};
6
/// Shortens a path that lives under a macOS (`/Users/<name>/`) or Linux
/// (`/home/<name>/`) home directory into a `~`-relative form for display.
///
/// The user-name component is dropped whichever user it names, so
/// `/home/alice/src/x.rs` becomes `~/src/x.rs` and a bare home directory such
/// as `/Users/alice` becomes `~`. Any other path is returned unchanged.
fn abbreviate_home_path(path: &str) -> String {
    let home_relative = ["/Users/", "/home/"]
        .iter()
        .find_map(|root| path.strip_prefix(*root));

    match home_relative {
        None => path.to_string(),
        // `rest` is "<user>" or "<user>/<suffix>"; keep only the suffix.
        Some(rest) => match rest.find('/') {
            Some(slash) => format!("~/{}", &rest[slash + 1..]),
            None => String::from("~"),
        },
    }
}
18
/// Tracks tool calls, file accesses, and command runs during a single agent turn.
///
/// Designed to be reset at `AgentStart` and queried at any point during or
/// after the turn to drive progress indicators and summaries.
#[derive(Debug, Clone)]
pub struct TurnTracker {
    /// When the turn (or the most recent clock restart) began.
    pub started_at: Instant,
    /// Number of tool calls recorded as started.
    pub tool_calls_started: u32,
    /// Number of tool calls recorded as finished.
    pub tool_calls_completed: u32,
    /// Number of finished tool calls that reported an error.
    pub tool_errors: u32,
    /// Unique paths that were read (via the `read` tool), with home
    /// directories abbreviated by `abbreviate_home_path`.
    pub files_read: BTreeSet<String>,
    /// Unique paths that were written or edited (edit / multi_edit / write),
    /// with home directories abbreviated.
    pub files_written: BTreeSet<String>,
    /// Unique paths that were created by the `write` tool (a subset of
    /// `files_written`).
    pub files_created: BTreeSet<String>,
    /// Bash commands that were executed, truncated to their first 80
    /// characters (not bytes) each, in execution order.
    pub commands_run: Vec<String>,
    /// Search-like activity: `grep` / `find` / `probe_search` / `probe_extract`
    /// tool calls, plus `bash` commands whose command line starts with
    /// `grep`, `find`, or `ls`.
    pub searches: u32,

    /// Maps tool_call_id → tool_name for calls that have started but not yet
    /// ended, since `ToolExecutionEnd` only carries the id.
    pending: HashMap<String, String>,
}
43
44impl TurnTracker {
45    /// Create a fresh tracker with the clock started now.
46    pub fn new() -> Self {
47        Self {
48            started_at: Instant::now(),
49            tool_calls_started: 0,
50            tool_calls_completed: 0,
51            tool_errors: 0,
52            files_read: BTreeSet::new(),
53            files_written: BTreeSet::new(),
54            files_created: BTreeSet::new(),
55            commands_run: Vec::new(),
56            searches: 0,
57            pending: HashMap::new(),
58        }
59    }
60
61    pub fn start_now(&mut self) {
62        self.started_at = Instant::now();
63    }
64
65    pub fn clear_counts(&mut self) {
66        self.tool_calls_started = 0;
67        self.tool_calls_completed = 0;
68        self.tool_errors = 0;
69        self.files_read.clear();
70        self.files_written.clear();
71        self.files_created.clear();
72        self.commands_run.clear();
73        self.searches = 0;
74        self.pending.clear();
75    }
76
77    /// Reset all counters and restart the clock. Called at `AgentStart`.
78    pub fn reset(&mut self) {
79        self.started_at = Instant::now();
80        self.tool_calls_started = 0;
81        self.tool_calls_completed = 0;
82        self.tool_errors = 0;
83        self.files_read.clear();
84        self.files_written.clear();
85        self.files_created.clear();
86        self.commands_run.clear();
87        self.searches = 0;
88        self.pending.clear();
89    }
90
91    /// Wall-clock time since the turn started (or last reset).
92    pub fn elapsed(&self) -> Duration {
93        self.started_at.elapsed()
94    }
95
96    /// Called at `ToolExecutionStart`. Records the tool call and classifies it.
97    ///
98    /// The `tool_call_id` is stored so that `record_tool_end` can look up the
99    /// name when the result arrives.
100    pub fn record_tool_start(&mut self, tool_call_id: &str, name: &str, args: &serde_json::Value) {
101        self.tool_calls_started += 1;
102        self.pending
103            .insert(tool_call_id.to_string(), name.to_string());
104        self.classify(name, args);
105    }
106
107    /// Called at `ToolExecutionEnd`. Increments completed / error counters.
108    pub fn record_tool_end(&mut self, tool_call_id: &str, is_error: bool) {
109        self.pending.remove(tool_call_id);
110        self.tool_calls_completed += 1;
111        if is_error {
112            self.tool_errors += 1;
113        }
114    }
115
116    // ── Private helpers ──────────────────────────────────────────────────────
117
118    fn classify(&mut self, name: &str, args: &serde_json::Value) {
119        match name {
120            "read" => {
121                if let Some(path) = args["path"].as_str() {
122                    self.files_read.insert(abbreviate_home_path(path));
123                }
124            }
125            "edit" | "multi_edit" => {
126                if let Some(path) = args["path"].as_str() {
127                    self.files_written.insert(abbreviate_home_path(path));
128                }
129            }
130            "write" => {
131                if let Some(path) = args["path"].as_str() {
132                    let path = abbreviate_home_path(path);
133                    self.files_written.insert(path.clone());
134                    self.files_created.insert(path);
135                }
136            }
137            "bash" => {
138                if let Some(cmd) = args["command"].as_str() {
139                    let truncated = cmd.chars().take(80).collect::<String>();
140                    self.commands_run.push(truncated);
141                    if cmd.trim_start().starts_with("grep ")
142                        || cmd.trim_start().starts_with("find ")
143                        || cmd.trim_start() == "find"
144                        || cmd.trim_start().starts_with("ls ")
145                        || cmd.trim_start() == "ls"
146                    {
147                        self.searches += 1;
148                    }
149                }
150            }
151            "grep" | "find" | "probe_search" | "probe_extract" => {
152                self.searches += 1;
153            }
154            _ => {
155                // Counted via tool_calls_started — no further classification needed.
156            }
157        }
158    }
159}
160
161impl Default for TurnTracker {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn classifies_read_and_write_tools() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/Users/test/foo.txt"}));
        tracker.record_tool_start("id-2", "write", &json!({"path": "/Users/test/bar.txt"}));
        tracker.record_tool_start("id-3", "edit", &json!({"path": "/Users/test/baz.txt"}));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.contains("~/foo.txt"));
        assert!(tracker.files_written.contains("~/bar.txt"));
        assert!(tracker.files_created.contains("~/bar.txt"));
        assert!(tracker.files_written.contains("~/baz.txt"));
        // edit should NOT go into files_created
        assert!(!tracker.files_created.contains("~/baz.txt"));

        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", false);
        tracker.record_tool_end("id-3", true);

        assert_eq!(tracker.tool_calls_completed, 3);
        assert_eq!(tracker.tool_errors, 1);
    }

    #[test]
    fn classifies_bash_and_search_tools() {
        let mut tracker = TurnTracker::new();

        let long_cmd = "a".repeat(120);
        tracker.record_tool_start("id-bash", "bash", &json!({"command": long_cmd}));
        tracker.record_tool_start("id-bash-grep", "bash", &json!({"command": "grep foo ."}));
        tracker.record_tool_start("id-grep", "grep", &json!({"pattern": "foo"}));
        tracker.record_tool_start("id-find", "find", &json!({"pattern": "*.rs"}));
        tracker.record_tool_start(
            "id-probe",
            "probe_search",
            &json!({"query": "error handling"}),
        );
        tracker.record_tool_start("id-probe2", "probe_extract", &json!({"targets": []}));

        assert_eq!(tracker.commands_run.len(), 2);
        // Command should be truncated to 80 chars
        assert_eq!(tracker.commands_run[0].len(), 80);
        // bash grep + grep, find, probe_search, probe_extract = 5 search calls
        assert_eq!(tracker.searches, 5);
    }

    #[test]
    fn bash_search_detection_matches_whole_program_name() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("a", "bash", &json!({"command": "ls"}));
        tracker.record_tool_start("b", "bash", &json!({"command": "  find . -name '*.rs'"}));
        // Shares a prefix with `ls` but is a different program.
        tracker.record_tool_start("c", "bash", &json!({"command": "lsof -i :8080"}));
        tracker.record_tool_start("d", "bash", &json!({"command": "cargo test"}));

        assert_eq!(tracker.commands_run.len(), 4);
        assert_eq!(tracker.searches, 2);
    }

    #[test]
    fn bash_command_truncation_counts_characters_not_bytes() {
        let mut tracker = TurnTracker::new();

        let cmd = "é".repeat(100);
        tracker.record_tool_start("id", "bash", &json!({"command": cmd}));

        assert_eq!(tracker.commands_run[0].chars().count(), 80);
    }

    #[test]
    fn missing_or_non_string_args_are_ignored() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("a", "read", &json!({}));
        tracker.record_tool_start("b", "write", &json!({"path": 42}));
        tracker.record_tool_start("c", "bash", &json!(null));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.files_created.is_empty());
        assert!(tracker.commands_run.is_empty());
    }

    #[test]
    fn abbreviates_home_directories_only() {
        assert_eq!(abbreviate_home_path("/home/alice/src/main.rs"), "~/src/main.rs");
        assert_eq!(abbreviate_home_path("/Users/bob"), "~");
        assert_eq!(abbreviate_home_path("/tmp/a.txt"), "/tmp/a.txt");
        assert_eq!(abbreviate_home_path("relative/path"), "relative/path");
    }

    #[test]
    fn reset_clears_all_state() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/tmp/a.txt"}));
        tracker.record_tool_start("id-2", "bash", &json!({"command": "ls"}));
        tracker.record_tool_start("id-3", "write", &json!({"path": "/tmp/b.txt"}));
        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", true);
        // id-3 is deliberately left in flight so `pending` is non-empty.

        tracker.reset();

        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.tool_calls_completed, 0);
        assert_eq!(tracker.tool_errors, 0);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.files_created.is_empty());
        assert!(tracker.commands_run.is_empty());
        assert_eq!(tracker.searches, 0);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn clear_counts_keeps_the_clock() {
        let mut tracker = TurnTracker::new();
        let started = tracker.started_at;

        tracker.record_tool_start("id-1", "read", &json!({"path": "/tmp/a.txt"}));
        tracker.clear_counts();

        assert_eq!(tracker.started_at, started);
        assert_eq!(tracker.tool_calls_started, 0);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn default_is_empty() {
        let tracker = TurnTracker::default();

        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.tool_calls_completed, 0);
        assert_eq!(tracker.searches, 0);
        assert!(tracker.commands_run.is_empty());
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn deduplicates_file_paths() {
        let mut tracker = TurnTracker::new();

        for i in 0..5 {
            tracker.record_tool_start(
                &format!("id-{i}"),
                "read",
                &json!({"path": "/Users/test/same.txt"}),
            );
        }

        // BTreeSet deduplicates
        assert_eq!(tracker.files_read.len(), 1);
    }

    #[test]
    fn elapsed_increases_over_time() {
        let tracker = TurnTracker::new();
        let d1 = tracker.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let d2 = tracker.elapsed();
        assert!(d2 > d1);
    }
}