Skip to main content

imp_tui/
turn_tracker.rs

//! Per-turn activity tracker that accumulates statistics while an agent turn
//! is in progress. Feeds into progress indicators and post-turn summaries.
3
4use std::collections::{BTreeSet, HashMap};
5use std::time::{Duration, Instant};
6
/// Shortens a path under `/Users/<name>/` or `/home/<name>/` to `~/…` for display.
///
/// The `<name>` component is dropped whichever user it names. A bare home
/// directory (`/home/<name>` with no further component) becomes `~`. Any
/// path without one of those prefixes is returned unchanged.
fn abbreviate_home_path(path: &str) -> String {
    let home_relative = ["/Users/", "/home/"]
        .iter()
        .find_map(|prefix| path.strip_prefix(*prefix));

    match home_relative {
        None => path.to_owned(),
        Some(rest) => match rest.split_once('/') {
            // `rest` is "<name>/<suffix>"; keep only what follows the user name.
            Some((_user, suffix)) => format!("~/{suffix}"),
            None => String::from("~"),
        },
    }
}
18
/// Tracks tool calls, file accesses, and command runs during a single agent turn.
///
/// Designed to be reset at `AgentStart` and queried at any point during or
/// after the turn to drive progress indicators and summaries.
#[derive(Debug, Clone)]
pub struct TurnTracker {
    /// Moment the turn started (set by `new` and by `reset`).
    pub started_at: Instant,
    /// Count of tool calls recorded via `record_tool_start`.
    pub tool_calls_started: u32,
    /// Count of tool calls recorded as finished via `record_tool_end`.
    pub tool_calls_completed: u32,
    /// Count of finished tool calls whose end was reported with `is_error = true`.
    pub tool_errors: u32,
    /// Unique paths that were read (via the `read` tool), with a leading
    /// `/Users/<name>` or `/home/<name>` abbreviated to `~`.
    pub files_read: BTreeSet<String>,
    /// Unique paths that were written or edited (edit / multi_edit / write),
    /// abbreviated the same way as `files_read`.
    pub files_written: BTreeSet<String>,
    /// Unique paths passed to the `write` tool (also present in `files_written`).
    pub files_created: BTreeSet<String>,
    /// Bash commands that were executed, each truncated to its first 80 chars.
    pub commands_run: Vec<String>,
    /// Number of search-like tool calls: the `grep` / `find` / `probe_search` /
    /// `probe_extract` tools, plus `bash` commands that begin with `grep`,
    /// `find` or `ls`.
    pub searches: u32,

    /// In-flight tool calls, keyed tool_call_id → tool_name. Entries are
    /// inserted by `record_tool_start` and removed by `record_tool_end`, since
    /// `ToolExecutionEnd` only carries the id.
    pending: HashMap<String, String>,
}
43
44impl TurnTracker {
45    /// Create a fresh tracker with the clock started now.
46    pub fn new() -> Self {
47        Self {
48            started_at: Instant::now(),
49            tool_calls_started: 0,
50            tool_calls_completed: 0,
51            tool_errors: 0,
52            files_read: BTreeSet::new(),
53            files_written: BTreeSet::new(),
54            files_created: BTreeSet::new(),
55            commands_run: Vec::new(),
56            searches: 0,
57            pending: HashMap::new(),
58        }
59    }
60
61    /// Reset all counters and restart the clock. Called at `AgentStart`.
62    pub fn reset(&mut self) {
63        self.started_at = Instant::now();
64        self.tool_calls_started = 0;
65        self.tool_calls_completed = 0;
66        self.tool_errors = 0;
67        self.files_read.clear();
68        self.files_written.clear();
69        self.files_created.clear();
70        self.commands_run.clear();
71        self.searches = 0;
72        self.pending.clear();
73    }
74
75    /// Wall-clock time since the turn started (or last reset).
76    pub fn elapsed(&self) -> Duration {
77        self.started_at.elapsed()
78    }
79
80    /// Called at `ToolExecutionStart`. Records the tool call and classifies it.
81    ///
82    /// The `tool_call_id` is stored so that `record_tool_end` can look up the
83    /// name when the result arrives.
84    pub fn record_tool_start(&mut self, tool_call_id: &str, name: &str, args: &serde_json::Value) {
85        self.tool_calls_started += 1;
86        self.pending
87            .insert(tool_call_id.to_string(), name.to_string());
88        self.classify(name, args);
89    }
90
91    /// Called at `ToolExecutionEnd`. Increments completed / error counters.
92    pub fn record_tool_end(&mut self, tool_call_id: &str, is_error: bool) {
93        self.pending.remove(tool_call_id);
94        self.tool_calls_completed += 1;
95        if is_error {
96            self.tool_errors += 1;
97        }
98    }
99
100    // ── Private helpers ──────────────────────────────────────────────────────
101
102    fn classify(&mut self, name: &str, args: &serde_json::Value) {
103        match name {
104            "read" => {
105                if let Some(path) = args["path"].as_str() {
106                    self.files_read.insert(abbreviate_home_path(path));
107                }
108            }
109            "edit" | "multi_edit" => {
110                if let Some(path) = args["path"].as_str() {
111                    self.files_written.insert(abbreviate_home_path(path));
112                }
113            }
114            "write" => {
115                if let Some(path) = args["path"].as_str() {
116                    let path = abbreviate_home_path(path);
117                    self.files_written.insert(path.clone());
118                    self.files_created.insert(path);
119                }
120            }
121            "bash" => {
122                if let Some(cmd) = args["command"].as_str() {
123                    let truncated = cmd.chars().take(80).collect::<String>();
124                    self.commands_run.push(truncated);
125                    if cmd.trim_start().starts_with("grep ")
126                        || cmd.trim_start().starts_with("find ")
127                        || cmd.trim_start() == "find"
128                        || cmd.trim_start().starts_with("ls ")
129                        || cmd.trim_start() == "ls"
130                    {
131                        self.searches += 1;
132                    }
133                }
134            }
135            "grep" | "find" | "probe_search" | "probe_extract" => {
136                self.searches += 1;
137            }
138            _ => {
139                // Counted via tool_calls_started — no further classification needed.
140            }
141        }
142    }
143}
144
145impl Default for TurnTracker {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn abbreviates_home_directories() {
        assert_eq!(abbreviate_home_path("/Users/test/foo.txt"), "~/foo.txt");
        assert_eq!(abbreviate_home_path("/home/alice/src/lib.rs"), "~/src/lib.rs");
        // A bare home directory collapses to just "~".
        assert_eq!(abbreviate_home_path("/home/alice"), "~");
        // Paths outside a home directory are left alone.
        assert_eq!(abbreviate_home_path("/tmp/a.txt"), "/tmp/a.txt");
    }

    #[test]
    fn classifies_read_and_write_tools() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/Users/test/foo.txt"}));
        tracker.record_tool_start("id-2", "write", &json!({"path": "/Users/test/bar.txt"}));
        tracker.record_tool_start("id-3", "edit", &json!({"path": "/Users/test/baz.txt"}));

        assert_eq!(tracker.tool_calls_started, 3);
        assert!(tracker.files_read.contains("~/foo.txt"));
        assert!(tracker.files_written.contains("~/bar.txt"));
        assert!(tracker.files_created.contains("~/bar.txt"));
        assert!(tracker.files_written.contains("~/baz.txt"));
        // edit should NOT go into files_created
        assert!(!tracker.files_created.contains("~/baz.txt"));

        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", false);
        tracker.record_tool_end("id-3", true);

        assert_eq!(tracker.tool_calls_completed, 3);
        assert_eq!(tracker.tool_errors, 1);
    }

    #[test]
    fn classifies_bash_and_search_tools() {
        let mut tracker = TurnTracker::new();

        let long_cmd = "a".repeat(120);
        tracker.record_tool_start("id-bash", "bash", &json!({"command": long_cmd}));
        tracker.record_tool_start("id-bash-grep", "bash", &json!({"command": "grep foo ."}));
        tracker.record_tool_start("id-grep", "grep", &json!({"pattern": "foo"}));
        tracker.record_tool_start("id-find", "find", &json!({"pattern": "*.rs"}));
        tracker.record_tool_start(
            "id-probe",
            "probe_search",
            &json!({"query": "error handling"}),
        );
        tracker.record_tool_start("id-probe2", "probe_extract", &json!({"targets": []}));

        assert_eq!(tracker.tool_calls_started, 6);
        assert_eq!(tracker.commands_run.len(), 2);
        // Command should be truncated to 80 chars
        assert_eq!(tracker.commands_run[0].len(), 80);
        // bash grep + grep, find, probe_search, probe_extract = 5 search calls
        assert_eq!(tracker.searches, 5);
    }

    #[test]
    fn reset_clears_all_state() {
        let mut tracker = TurnTracker::new();

        tracker.record_tool_start("id-1", "read", &json!({"path": "/tmp/a.txt"}));
        tracker.record_tool_start("id-2", "bash", &json!({"command": "ls"}));
        tracker.record_tool_start("id-3", "write", &json!({"path": "/tmp/b.txt"}));
        tracker.record_tool_end("id-1", false);
        tracker.record_tool_end("id-2", true);
        // id-3 is deliberately left in flight so `pending` is non-empty before reset.

        // Sanity-check that every field really holds state before resetting.
        assert_eq!(tracker.searches, 1);
        assert!(!tracker.files_read.is_empty());
        assert!(!tracker.files_written.is_empty());
        assert!(!tracker.files_created.is_empty());
        assert!(!tracker.commands_run.is_empty());
        assert!(!tracker.pending.is_empty());

        tracker.reset();

        assert_eq!(tracker.tool_calls_started, 0);
        assert_eq!(tracker.tool_calls_completed, 0);
        assert_eq!(tracker.tool_errors, 0);
        assert!(tracker.files_read.is_empty());
        assert!(tracker.files_written.is_empty());
        assert!(tracker.files_created.is_empty());
        assert!(tracker.commands_run.is_empty());
        assert_eq!(tracker.searches, 0);
        assert!(tracker.pending.is_empty());
    }

    #[test]
    fn deduplicates_file_paths() {
        let mut tracker = TurnTracker::new();

        for i in 0..5 {
            tracker.record_tool_start(
                &format!("id-{i}"),
                "read",
                &json!({"path": "/Users/test/same.txt"}),
            );
        }

        // BTreeSet deduplicates
        assert_eq!(tracker.files_read.len(), 1);
    }

    #[test]
    fn elapsed_increases_over_time() {
        let tracker = TurnTracker::new();
        let d1 = tracker.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let d2 = tracker.elapsed();
        assert!(d2 > d1);
    }
}