Skip to main content

zeph_tools/
executor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
/// Data for rendering file diffs in the TUI.
///
/// Attached to tool results via [`ToolOutput::diff`] and to completion events via
/// [`ToolEvent::Completed`].
#[derive(Debug, Clone)]
pub struct DiffData {
    /// Path of the file being diffed.
    pub file_path: String,
    /// Content before the change.
    pub old_content: String,
    /// Content after the change.
    pub new_content: String,
}
13
/// Structured tool invocation from LLM.
///
/// Dispatched through [`ToolExecutor::execute_tool_call`]. `params` can be decoded
/// into a typed struct with [`deserialize_params`].
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Identifier of the requested tool; executors that do not handle it return `Ok(None)`.
    pub tool_id: String,
    /// JSON object of named arguments supplied with the call.
    pub params: serde_json::Map<String, serde_json::Value>,
}
20
/// Cumulative filter statistics for a single tool execution.
///
/// Compares the raw tool output with what remained after filtering. Derived
/// metrics (`savings_pct`, `estimated_tokens_saved`, `format_inline`) are
/// computed from these fields.
#[derive(Debug, Clone, Default)]
pub struct FilterStats {
    /// Length of the unfiltered output (denominator of `savings_pct`).
    pub raw_chars: usize,
    /// Length of the output after filtering.
    pub filtered_chars: usize,
    /// Line count of the unfiltered output.
    pub raw_lines: usize,
    /// Line count of the output after filtering.
    pub filtered_lines: usize,
    /// Confidence reported by the filter, if any — NOTE(review): semantics defined by
    /// `crate::FilterConfidence`; confirm there.
    pub confidence: Option<crate::FilterConfidence>,
    /// Command whose output was filtered; shown as a label by `format_inline`.
    pub command: Option<String>,
    /// Presumably the indices of lines the filter kept — TODO confirm the indexing base.
    pub kept_lines: Vec<usize>,
}
32
33impl FilterStats {
34    #[must_use]
35    #[allow(clippy::cast_precision_loss)]
36    pub fn savings_pct(&self) -> f64 {
37        if self.raw_chars == 0 {
38            return 0.0;
39        }
40        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
41    }
42
43    #[must_use]
44    pub fn estimated_tokens_saved(&self) -> usize {
45        self.raw_chars.saturating_sub(self.filtered_chars) / 4
46    }
47
48    #[must_use]
49    pub fn format_inline(&self, tool_name: &str) -> String {
50        let cmd_label = self
51            .command
52            .as_deref()
53            .map(|c| {
54                let trimmed = c.trim();
55                if trimmed.len() > 60 {
56                    format!(" `{}…`", &trimmed[..57])
57                } else {
58                    format!(" `{trimmed}`")
59                }
60            })
61            .unwrap_or_default();
62        format!(
63            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
64            self.raw_lines,
65            self.filtered_lines,
66            self.savings_pct()
67        )
68    }
69}
70
/// Structured result from tool execution.
///
/// Formatting with `Display` writes only `summary`.
#[derive(Debug, Clone)]
pub struct ToolOutput {
    /// Name of the tool that produced this result.
    pub tool_name: String,
    /// Result text; this is what the `Display` impl prints.
    pub summary: String,
    /// Number of blocks executed — NOTE(review): presumably fenced blocks from the
    /// LLM response; confirm against executors.
    pub blocks_executed: u32,
    /// Output-filtering statistics, when available.
    pub filter_stats: Option<FilterStats>,
    /// File diff for the TUI to render, when available.
    pub diff: Option<DiffData>,
    /// Whether this tool already streamed its output via `ToolEvent` channel.
    pub streamed: bool,
    /// Terminal ID when the tool was executed via IDE terminal (ACP terminal/* protocol).
    pub terminal_id: Option<String>,
    /// File paths touched by this tool call, for IDE follow-along (e.g. `ToolCallLocation`).
    pub locations: Option<Vec<String>>,
    /// Structured tool response payload for ACP intermediate `tool_call_update` notifications.
    pub raw_response: Option<serde_json::Value>,
}
88
89impl fmt::Display for ToolOutput {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.write_str(&self.summary)
92    }
93}
94
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Truncate tool output that exceeds `MAX_TOOL_OUTPUT_CHARS` using head+tail split.
///
/// Output of at most `MAX_TOOL_OUTPUT_CHARS` characters is returned unchanged.
/// Longer output keeps its first and last `MAX_TOOL_OUTPUT_CHARS / 2` characters.
/// The middle is replaced by a notice stating how many characters were omitted.
///
/// All counts are in `char`s (Unicode scalar values), matching the constant's name
/// and the notice text. The split therefore never falls inside a multi-byte
/// character, and the reported count is accurate for non-ASCII output.
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    let total_chars = output.chars().count();
    if total_chars <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_string();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;
    // Byte offset where the `n`-th char starts (or the end of the string).
    let byte_offset_of = |n: usize| {
        output
            .char_indices()
            .nth(n)
            .map_or(output.len(), |(idx, _)| idx)
    };
    // total_chars > MAX >= 2 * half, so the tail starts strictly after the head ends.
    let head_end = byte_offset_of(half);
    let tail_start = byte_offset_of(total_chars - half);
    let head = &output[..head_end];
    let tail = &output[tail_start..];
    let truncated = total_chars - 2 * half;

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}
115
/// Event emitted during tool execution for real-time UI updates.
///
/// Delivered over a [`ToolEventTx`] channel.
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// A tool began running `command`.
    Started {
        tool_name: String,
        command: String,
    },
    /// An incremental piece of output produced while `command` runs.
    OutputChunk {
        tool_name: String,
        command: String,
        chunk: String,
    },
    /// The tool finished running `command`.
    Completed {
        tool_name: String,
        command: String,
        /// Final output text.
        output: String,
        /// Whether the run is considered successful.
        success: bool,
        /// Output-filtering statistics, when available.
        filter_stats: Option<FilterStats>,
        /// File diff for the TUI to render, when available.
        diff: Option<DiffData>,
    },
}

/// Sending half of the unbounded Tokio channel that carries [`ToolEvent`]s.
pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
139
/// Errors that can occur during tool execution.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The command was rejected by a policy check.
    #[error("command blocked by policy: {command}")]
    Blocked { command: String },

    /// `path` lies outside what the sandbox permits.
    #[error("path not allowed by sandbox: {path}")]
    SandboxViolation { path: String },

    /// The command needs user approval before running; after approval the
    /// caller can re-run it via [`ToolExecutor::execute_confirmed`].
    #[error("command requires confirmation: {command}")]
    ConfirmationRequired { command: String },

    /// Execution exceeded its time limit of `timeout_secs` seconds.
    #[error("command timed out after {timeout_secs}s")]
    Timeout { timeout_secs: u64 },

    /// Execution was cancelled before completing.
    #[error("operation cancelled")]
    Cancelled,

    /// Tool-call arguments could not be decoded (see [`deserialize_params`]).
    #[error("invalid tool parameters: {message}")]
    InvalidParams { message: String },

    /// Underlying I/O failure; `#[from]` lets `?` convert `std::io::Error` directly.
    #[error("execution failed: {0}")]
    Execution(#[from] std::io::Error),
}
164
165/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
166///
167/// # Errors
168///
169/// Returns `ToolError::InvalidParams` when deserialization fails.
170pub fn deserialize_params<T: serde::de::DeserializeOwned>(
171    params: &serde_json::Map<String, serde_json::Value>,
172) -> Result<T, ToolError> {
173    let obj = serde_json::Value::Object(params.clone());
174    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
175        message: e.to_string(),
176    })
177}
178
/// Async trait for tool execution backends (shell, future MCP, A2A).
///
/// Accepts the full LLM response and returns an optional output.
/// Returns `None` when no tool invocation is detected in the response.
///
/// Only [`execute`](Self::execute) must be implemented; every other method has a
/// default. Because the methods return `impl Future`, this trait is not object-safe.
/// Use [`ErasedToolExecutor`] or [`DynExecutor`] when dynamic dispatch is needed.
pub trait ToolExecutor: Send + Sync {
    /// Detect and run a tool invocation embedded in the free-form LLM `response`.
    ///
    /// Resolves to `Ok(None)` when `response` contains no invocation this executor handles.
    ///
    /// # Errors
    ///
    /// Implementation-defined; see the [`ToolError`] variants.
    fn execute(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;

    /// Execute bypassing confirmation checks (called after user approves).
    /// Default: delegates to `execute`.
    fn execute_confirmed(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        self.execute(response)
    }

    /// Return tool definitions this executor can handle.
    /// Default: none.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        vec![]
    }

    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
    /// Default: immediately resolves to `Ok(None)`.
    fn execute_tool_call(
        &self,
        _call: &ToolCall,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        std::future::ready(Ok(None))
    }

    /// Inject environment variables for the currently active skill. No-op by default.
    ///
    /// Takes `&self`, so an implementation that stores `env` needs interior mutability.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
214
/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
///
/// Implemented automatically for every `T: ToolExecutor` via a blanket impl.
/// Use `Box<dyn ErasedToolExecutor>` when dynamic dispatch is required.
/// Each `*_erased` method mirrors the [`ToolExecutor`] method of the same name,
/// returning a pinned, boxed `Send` future that borrows `self` and the argument for `'a`.
pub trait ErasedToolExecutor: Send + Sync {
    /// Boxed form of [`ToolExecutor::execute`].
    fn execute_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Boxed form of [`ToolExecutor::execute_confirmed`].
    fn execute_confirmed_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Same as [`ToolExecutor::tool_definitions`].
    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;

    /// Boxed form of [`ToolExecutor::execute_tool_call`].
    fn execute_tool_call_erased<'a>(
        &'a self,
        call: &'a ToolCall,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Inject environment variables for the currently active skill. No-op by default.
    /// (The blanket impl for `ToolExecutor` types forwards to [`ToolExecutor::set_skill_env`].)
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
240
241impl<T: ToolExecutor> ErasedToolExecutor for T {
242    fn execute_erased<'a>(
243        &'a self,
244        response: &'a str,
245    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
246    {
247        Box::pin(self.execute(response))
248    }
249
250    fn execute_confirmed_erased<'a>(
251        &'a self,
252        response: &'a str,
253    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
254    {
255        Box::pin(self.execute_confirmed(response))
256    }
257
258    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
259        self.tool_definitions()
260    }
261
262    fn execute_tool_call_erased<'a>(
263        &'a self,
264        call: &'a ToolCall,
265    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
266    {
267        Box::pin(self.execute_tool_call(call))
268    }
269
270    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
271        ToolExecutor::set_skill_env(self, env);
272    }
273}
274
275/// Wraps `Arc<dyn ErasedToolExecutor>` so it can be used as a concrete `ToolExecutor`.
276///
277/// Enables dynamic composition of tool executors at runtime without static type chains.
278pub struct DynExecutor(pub std::sync::Arc<dyn ErasedToolExecutor>);
279
280impl ToolExecutor for DynExecutor {
281    fn execute(
282        &self,
283        response: &str,
284    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
285        // Clone data to satisfy the 'static-ish bound: erased futures must not borrow self.
286        let inner = std::sync::Arc::clone(&self.0);
287        let response = response.to_owned();
288        async move { inner.execute_erased(&response).await }
289    }
290
291    fn execute_confirmed(
292        &self,
293        response: &str,
294    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
295        let inner = std::sync::Arc::clone(&self.0);
296        let response = response.to_owned();
297        async move { inner.execute_confirmed_erased(&response).await }
298    }
299
300    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
301        self.0.tool_definitions_erased()
302    }
303
304    fn execute_tool_call(
305        &self,
306        call: &ToolCall,
307    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
308        let inner = std::sync::Arc::clone(&self.0);
309        let call = call.clone();
310        async move { inner.execute_tool_call_erased(&call).await }
311    }
312
313    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
314        ErasedToolExecutor::set_skill_env(self.0.as_ref(), env);
315    }
316}
317
/// Extract fenced code blocks with the given language marker from text.
///
/// Searches for `` ```{lang} `` … `` ``` `` pairs, returning trimmed content.
///
/// For a non-empty `lang`, the tag must be the complete first word of the fence's
/// info string: `"py"` matches `` ```py `` but not `` ```python ``. An empty `lang`
/// matches every opening fence. An opening fence with no closing `` ``` `` ends the
/// scan; blocks found before it are still returned.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    const FENCE: &str = "```";

    let marker = format!("```{lang}");
    let mut blocks = Vec::new();
    let mut rest = text;

    while let Some(start) = rest.find(&marker) {
        let after = &rest[start + marker.len()..];

        // Without this check `lang = "c"` would also match "```cpp" and return
        // "pp\n…" as the block body.
        let tag_is_complete = lang.is_empty()
            || match after.chars().next() {
                None => true,
                Some(next) => next.is_whitespace(),
            };
        if !tag_is_complete {
            rest = after;
            continue;
        }

        let Some(end) = after.find(FENCE) else {
            break;
        };
        blocks.push(after[..end].trim());
        rest = &after[end + FENCE.len()..];
    }

    blocks
}
340
// Unit tests for the executor data types, helpers and trait plumbing.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Display impls ---

    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    // --- deserialize_params ---

    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    // --- truncate_tool_output ---

    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    // --- ToolExecutor default methods ---

    /// Executor implementing only the required `execute`, so the trait defaults are exercised.
    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    // --- FilterStats ---

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    // DynExecutor tests

    /// Executor that returns the same fixed `ToolOutput` from both `execute` and
    /// `execute_tool_call`, so delegation through `DynExecutor` can be observed.
    struct FixedExecutor {
        tool_id: &'static str,
        output: &'static str,
    }

    impl ToolExecutor for FixedExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
            }))
        }

        fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
            vec![]
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
            }))
        }
    }

    #[tokio::test]
    async fn dyn_executor_execute_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "hello",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute("```bash\necho hello\n```").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "hello");
    }

    #[tokio::test]
    async fn dyn_executor_execute_confirmed_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "confirmed",
        });
        let exec = DynExecutor(inner);
        // FixedExecutor does not override `execute_confirmed`, so this also covers the default.
        let result = exec.execute_confirmed("...").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "confirmed");
    }

    #[test]
    fn dyn_executor_tool_definitions_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "my_tool",
            output: "",
        });
        let exec = DynExecutor(inner);
        // FixedExecutor returns empty definitions; verify delegation occurs without panic.
        let defs = exec.tool_definitions();
        assert!(defs.is_empty());
    }

    #[tokio::test]
    async fn dyn_executor_execute_tool_call_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "tool_call_result",
        });
        let exec = DynExecutor(inner);
        let call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "tool_call_result");
    }
}