Skip to main content

wisp/components/
tool_call_statuses.rs

1use acp_utils::notifications::SubAgentProgressParams;
2use agent_client_protocol as acp;
3use std::collections::HashMap;
4use std::time::Instant;
5
6use crate::components::sub_agent_tracker::SubAgentTracker;
7use crate::components::tool_call_status_view::{ToolCallStatus, compute_diff_preview, render_tool_tree};
8use crate::components::tracked_tool_call::{TrackedToolCall, raw_input_fragment, upsert_tracked_tool_call};
9use tui::{Line, ViewContext};
10
11/// Tracks active tool calls and produces status lines for the frame.
12#[derive(Clone)]
13pub struct ToolCallStatuses {
14    /// Ordered list of tool call IDs (insertion order)
15    tool_order: Vec<String>,
16    /// Tool call info by ID
17    tool_calls: HashMap<String, TrackedToolCall>,
18    /// Sub-agent states keyed by parent tool call ID
19    sub_agents: SubAgentTracker,
20    /// Animation tick for the spinner on running tool calls
21    tick: u16,
22}
23
/// Aggregate progress snapshot returned by [`ToolCallStatuses::progress`].
///
/// A plain value type: it is `Copy`, comparable and debug-printable so callers can cache,
/// diff and log snapshots freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToolProgress {
    /// `true` while any tracked tool call, or any sub-agent tool call, is still running.
    pub running_any: bool,
    /// Top-level tool calls (those without sub-agent state) that are no longer `Running`.
    pub completed_top_level: usize,
    /// All top-level tool calls (those without sub-agent state) currently tracked.
    pub total_top_level: usize,
}
29
30impl ToolCallStatuses {
31    pub fn new() -> Self {
32        Self { tool_order: Vec::new(), tool_calls: HashMap::new(), sub_agents: SubAgentTracker::default(), tick: 0 }
33    }
34
35    pub fn progress(&self) -> ToolProgress {
36        let running_any = self.any_running_including_subagents();
37        let (completed_top_level, total_top_level) = self.top_level_counts();
38        ToolProgress { running_any, completed_top_level, total_top_level }
39    }
40
41    /// Advance the animation state. Call this on tick events.
42    pub fn on_tick(&mut self, _now: Instant) {
43        if self.progress().running_any {
44            self.tick = self.tick.wrapping_add(1);
45        }
46    }
47
48    /// Handle a new tool call from ACP `SessionUpdate::ToolCall`.
49    pub fn on_tool_call(&mut self, tool_call: &acp::ToolCall) {
50        let id = tool_call.tool_call_id.0.to_string();
51        let arguments = tool_call.raw_input.as_ref().map(raw_input_fragment).unwrap_or_default();
52
53        let tracked = upsert_tracked_tool_call(
54            &mut self.tool_order,
55            &mut self.tool_calls,
56            &id,
57            &tool_call.title,
58            arguments.clone(),
59        );
60        tracked.update_name(&tool_call.title);
61        tracked.arguments = arguments;
62        tracked.status = ToolCallStatus::Running;
63    }
64
65    /// Handle a tool call update from ACP `SessionUpdate::ToolCallUpdate`.
66    pub fn on_tool_call_update(&mut self, update: &acp::ToolCallUpdate) {
67        let id = update.tool_call_id.0.to_string();
68
69        if let Some(tc) = self.tool_calls.get_mut(&id) {
70            if let Some(title) = &update.fields.title {
71                tc.update_name(title);
72            }
73            if let Some(raw_input) = &update.fields.raw_input {
74                tc.append_arguments(&raw_input_fragment(raw_input));
75            }
76            if let Some(meta) = &update.meta
77                && let Some(dv) = meta.get("display_value").and_then(|v| v.as_str())
78            {
79                tc.display_value = Some(dv.to_string());
80            }
81            if let Some(content) = &update.fields.content {
82                for item in content {
83                    if let acp::ToolCallContent::Diff(diff) = item {
84                        tc.diff_preview = Some(compute_diff_preview(diff));
85                    }
86                }
87            }
88            if let Some(status) = update.fields.status {
89                tc.apply_status(status);
90            }
91        }
92    }
93
94    pub fn finalize_running(&mut self, cancelled: bool) {
95        let terminal_status =
96            if cancelled { ToolCallStatus::Error("cancelled".to_string()) } else { ToolCallStatus::Success };
97
98        for tool_call in self.tool_calls.values_mut() {
99            if matches!(tool_call.status, ToolCallStatus::Running) {
100                tool_call.status = terminal_status.clone();
101            }
102        }
103
104        self.sub_agents.finalize_running(cancelled);
105    }
106
107    pub fn has_tool(&self, id: &str) -> bool {
108        self.tool_calls.contains_key(id)
109    }
110
111    #[cfg(test)]
112    pub fn is_tool_running(&self, id: &str) -> bool {
113        self.tool_calls.get(id).is_some_and(|tc| matches!(tc.status, ToolCallStatus::Running))
114    }
115
116    /// Handle a sub-agent progress notification.
117    pub fn on_sub_agent_progress(&mut self, notification: &SubAgentProgressParams) {
118        self.sub_agents.on_progress(notification);
119    }
120
121    #[cfg(test)]
122    pub fn remove_tool(&mut self, id: &str) {
123        self.tool_calls.remove(id);
124        self.tool_order.retain(|tool_id| tool_id != id);
125        self.sub_agents.remove(id);
126    }
127
128    pub fn render_tool(&self, id: &str, context: &ViewContext) -> Vec<Line> {
129        render_tool_tree(id, &self.tool_calls, &self.sub_agents, self.tick, context)
130    }
131
132    /// Clear all tracked tool calls (e.g., after pushing to scrollback).
133    pub fn clear(&mut self) {
134        self.tool_order.clear();
135        self.tool_calls.clear();
136        self.sub_agents.clear();
137    }
138
139    fn top_level_counts(&self) -> (usize, usize) {
140        let total = self.tool_order.iter().filter(|id| !self.sub_agents.has_sub_agents(id)).count();
141        let completed = self
142            .tool_order
143            .iter()
144            .filter(|id| !self.sub_agents.has_sub_agents(id))
145            .filter_map(|id| self.tool_calls.get(id))
146            .filter(|tc| !matches!(tc.status, ToolCallStatus::Running))
147            .count();
148        (completed, total)
149    }
150
151    fn any_running_including_subagents(&self) -> bool {
152        self.tool_calls.values().any(|tc| matches!(tc.status, ToolCallStatus::Running)) || self.sub_agents.any_running()
153    }
154}
155
156impl Default for ToolCallStatuses {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use acp_utils::notifications::{SubAgentEvent, SubAgentProgressParams};
    use tui::{DiffLine, DiffPreview, DiffTag, SplitDiffCell, SplitDiffRow};

    /// Sub-agent `ToolCall` event for a `grep` request with empty arguments.
    const GREP_TOOL_CALL_EVENT: &str =
        r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{}"},"model_name":"m"}}"#;

    /// An 80x24 view context for rendering assertions.
    fn ctx() -> ViewContext {
        ViewContext::new((80, 24))
    }

    /// Build an ACP tool call, attaching `raw_input` parsed from JSON text when given.
    fn make_tool_call(id: &str, title: &str, raw_input: Option<&str>) -> acp::ToolCall {
        let call = acp::ToolCall::new(id.to_string(), title);
        match raw_input {
            Some(json) => call.raw_input(serde_json::from_str::<serde_json::Value>(json).unwrap()),
            None => call,
        }
    }

    /// Build an ACP tool call update that only carries a status.
    fn make_tool_call_update(id: &str, status: acp::ToolCallStatus) -> acp::ToolCallUpdate {
        let fields = acp::ToolCallUpdateFields::new().status(status);
        acp::ToolCallUpdate::new(id.to_string(), fields)
    }

    /// Build a sub-agent progress notification whose task ID equals the agent name.
    fn make_sub_agent_notification(parent_tool_id: &str, agent_name: &str, event_json: &str) -> SubAgentProgressParams {
        make_sub_agent_notification_with_task_id(parent_tool_id, agent_name, agent_name, event_json)
    }

    /// Build a sub-agent progress notification by deserializing hand-written JSON.
    fn make_sub_agent_notification_with_task_id(
        parent_tool_id: &str,
        task_id: &str,
        agent_name: &str,
        event_json: &str,
    ) -> SubAgentProgressParams {
        let payload = format!(
            r#"{{"parent_tool_id":"{parent_tool_id}","task_id":"{task_id}","agent_name":"{agent_name}","event":{event_json}}}"#,
        );
        serde_json::from_str(&payload).unwrap()
    }

    /// Tracker holding `spawn_subagent` call `parent-1` plus a running `grep` from sub-agent `explorer`.
    fn tracker_with_running_sub_agent() -> ToolCallStatuses {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_tool_call(&make_tool_call("parent-1", "spawn_subagent", None));
        tracker.on_sub_agent_progress(&make_sub_agent_notification("parent-1", "explorer", GREP_TOOL_CALL_EVENT));
        tracker
    }

    /// A unified-view diff line.
    fn diff_line(tag: DiffTag, content: &str) -> DiffLine {
        DiffLine { tag, content: content.to_string() }
    }

    /// A split-view diff cell on line 1.
    fn diff_cell(tag: DiffTag, content: &str) -> SplitDiffCell {
        SplitDiffCell { tag, content: content.to_string(), line_number: Some(1) }
    }

    #[test]
    fn progress_reports_sub_agent_running_tools() {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_tool_call(&make_tool_call("parent-1", "spawn_subagent", None));
        tracker.on_tool_call_update(&make_tool_call_update("parent-1", acp::ToolCallStatus::Completed));
        tracker.on_sub_agent_progress(&make_sub_agent_notification("parent-1", "explorer", GREP_TOOL_CALL_EVENT));

        // The parent call has completed, but its sub-agent's grep is still in flight.
        assert!(tracker.progress().running_any);
    }

    #[test]
    fn remove_tool_cleans_up_sub_agent_state() {
        let mut tracker = tracker_with_running_sub_agent();

        tracker.remove_tool("parent-1");

        assert!(!tracker.progress().running_any);
        assert!(tracker.render_tool("parent-1", &ctx()).is_empty());
    }

    #[test]
    fn clear_removes_sub_agent_state() {
        let mut tracker = tracker_with_running_sub_agent();

        tracker.clear();

        assert!(!tracker.progress().running_any);
    }

    #[test]
    fn deserialize_tool_call_event() {
        let params = make_sub_agent_notification(
            "p1",
            "explorer",
            r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#,
        );
        assert!(matches!(params.event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        let params = make_sub_agent_notification(
            "p1",
            "explorer",
            r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"updated\"}"},"model_name":"m"}}"#,
        );
        assert!(matches!(params.event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let params = make_sub_agent_notification(
            "p1",
            "explorer",
            r#"{"ToolResult":{"result":{"id":"c1","name":"grep","arguments":"{}","result":"ok"},"model_name":"m"}}"#,
        );
        assert!(matches!(params.event, SubAgentEvent::ToolResult { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let params = make_sub_agent_notification("p1", "explorer", r#""Done""#);
        assert!(matches!(params.event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let params = make_sub_agent_notification("p1", "explorer", r#""Other""#);
        assert!(matches!(params.event, SubAgentEvent::Other));
    }

    #[test]
    fn test_diff_preview_rendered_on_success() {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_tool_call(&make_tool_call("tool-1", "Edit", None));

        let preview = DiffPreview {
            lines: vec![diff_line(DiffTag::Removed, "old line"), diff_line(DiffTag::Added, "new line")],
            rows: vec![SplitDiffRow {
                left: Some(diff_cell(DiffTag::Removed, "old line")),
                right: Some(diff_cell(DiffTag::Added, "new line")),
            }],
            lang_hint: "rs".to_string(),
            start_line: Some(1),
        };
        let tracked = tracker.tool_calls.get_mut("tool-1").unwrap();
        tracked.status = ToolCallStatus::Success;
        tracked.diff_preview = Some(preview);

        let rendered = tracker.render_tool("tool-1", &ctx());
        assert!(rendered.len() > 1);
        let all_text: String = rendered.iter().map(tui::Line::plain_text).collect();
        assert!(all_text.contains("old line"), "Expected removed line: {all_text}");
        assert!(all_text.contains("new line"), "Expected added line: {all_text}");
    }

    #[test]
    fn test_diff_preview_not_rendered_while_running() {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_tool_call(&make_tool_call("tool-1", "Edit", None));

        let tracked = tracker.tool_calls.get_mut("tool-1").unwrap();
        tracked.diff_preview = Some(DiffPreview {
            lines: vec![diff_line(DiffTag::Added, "new line")],
            rows: vec![SplitDiffRow { left: None, right: Some(diff_cell(DiffTag::Added, "new line")) }],
            lang_hint: "rs".to_string(),
            start_line: Some(1),
        });

        let rendered = tracker.render_tool("tool-1", &ctx());
        assert_eq!(rendered.len(), 1, "Should only have status line while running");
    }

    #[test]
    fn finalize_running_marks_top_level_tools_terminal() {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_tool_call(&make_tool_call("tool-1", "Read", None));

        tracker.finalize_running(false);

        assert!(!tracker.is_tool_running("tool-1"));
        assert!(!tracker.progress().running_any);
        let rendered = tracker.render_tool("tool-1", &ctx());
        assert!(rendered[0].plain_text().contains('✓'));
    }

    #[test]
    fn finalize_running_marks_sub_agent_tools_terminal() {
        let mut tracker = ToolCallStatuses::new();
        tracker.on_sub_agent_progress(&make_sub_agent_notification("parent-1", "explorer", GREP_TOOL_CALL_EVENT));
        assert!(tracker.progress().running_any);

        tracker.finalize_running(true);

        assert!(!tracker.progress().running_any);
    }
}