Skip to main content

zeph_tools/
executor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
/// Data for rendering file diffs in the TUI.
///
/// NOTE(review): presumably carries the complete before/after file contents
/// rather than a patch; nothing in this module computes the diff itself.
#[derive(Debug, Clone)]
pub struct DiffData {
    /// Path of the file the diff refers to.
    pub file_path: String,
    /// File contents before the change.
    pub old_content: String,
    /// File contents after the change.
    pub new_content: String,
}
13
/// Structured tool invocation from LLM.
///
/// Unlike the free-form text handed to `ToolExecutor::execute`, this names the
/// target tool explicitly and carries its arguments as a JSON object.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Identifier of the tool to invoke. Executors resolve
    /// `ToolExecutor::execute_tool_call` to `None` for ids they do not handle.
    pub tool_id: String,
    /// Arguments as a JSON object; decode into a typed struct with
    /// `deserialize_params`.
    pub params: serde_json::Map<String, serde_json::Value>,
}
20
/// Cumulative filter statistics for a single tool execution.
///
/// The character counts feed `savings_pct` and `estimated_tokens_saved`; the
/// line counts and `command` are rendered by `format_inline`.
#[derive(Debug, Clone, Default)]
pub struct FilterStats {
    /// Size of the output before filtering (unit chosen by the producer;
    /// presumably characters — confirm whether bytes are passed instead).
    pub raw_chars: usize,
    /// Size of the output after filtering, in the same unit as `raw_chars`.
    pub filtered_chars: usize,
    /// Number of lines before filtering.
    pub raw_lines: usize,
    /// Number of lines after filtering.
    pub filtered_lines: usize,
    /// Confidence reported by the filter, if any.
    pub confidence: Option<crate::FilterConfidence>,
    /// The command whose output was filtered, if known.
    pub command: Option<String>,
    /// NOTE(review): presumably indices of the raw lines the filter kept —
    /// confirm 0- vs 1-based indexing with the producer.
    pub kept_lines: Vec<usize>,
}
32
33impl FilterStats {
34    #[must_use]
35    #[allow(clippy::cast_precision_loss)]
36    pub fn savings_pct(&self) -> f64 {
37        if self.raw_chars == 0 {
38            return 0.0;
39        }
40        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
41    }
42
43    #[must_use]
44    pub fn estimated_tokens_saved(&self) -> usize {
45        self.raw_chars.saturating_sub(self.filtered_chars) / 4
46    }
47
48    #[must_use]
49    pub fn format_inline(&self, tool_name: &str) -> String {
50        let cmd_label = self
51            .command
52            .as_deref()
53            .map(|c| {
54                let trimmed = c.trim();
55                if trimmed.len() > 60 {
56                    format!(" `{}…`", &trimmed[..57])
57                } else {
58                    format!(" `{trimmed}`")
59                }
60            })
61            .unwrap_or_default();
62        format!(
63            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
64            self.raw_lines,
65            self.filtered_lines,
66            self.savings_pct()
67        )
68    }
69}
70
/// Structured result from tool execution.
///
/// Its `Display` implementation prints only `summary`.
#[derive(Debug, Clone)]
pub struct ToolOutput {
    /// Name of the tool that produced this output.
    pub tool_name: String,
    /// Human-readable result text; this is what `Display` renders.
    pub summary: String,
    /// NOTE(review): presumably the number of fenced code blocks that were
    /// executed (cf. `extract_fenced_blocks`) — confirm with executors.
    pub blocks_executed: u32,
    /// Output-filtering statistics, when filtering was applied.
    pub filter_stats: Option<FilterStats>,
    /// File diff for the TUI, when the tool produced one.
    pub diff: Option<DiffData>,
    /// Whether this tool already streamed its output via `ToolEvent` channel.
    pub streamed: bool,
    /// Terminal ID when the tool was executed via IDE terminal (ACP terminal/* protocol).
    pub terminal_id: Option<String>,
    /// File paths touched by this tool call, for IDE follow-along (e.g. `ToolCallLocation`).
    pub locations: Option<Vec<String>>,
}
86
87impl fmt::Display for ToolOutput {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.write_str(&self.summary)
90    }
91}
92
/// Size threshold applied by [`truncate_tool_output`].
///
/// Despite the name, the limit is compared against the output's UTF-8 byte
/// length (`str::len`), not its character count; for ASCII the two coincide.
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Truncate tool output that exceeds [`MAX_TOOL_OUTPUT_CHARS`] using a head+tail split.
///
/// Output within the limit is returned unchanged. Longer output keeps about
/// `MAX_TOOL_OUTPUT_CHARS / 2` bytes from each end and replaces the middle
/// with a notice. Both cut points are snapped to UTF-8 character boundaries
/// (head cut rounds down, tail cut rounds up), so no character is split. The
/// notice reports the number of dropped characters.
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    if output.len() <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_string();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;
    // output.len() > 2 * half here, so head_end < tail_start and the middle
    // slice below is non-empty and valid.
    let head_end = floor_char_boundary_at(output, half);
    let tail_start = ceil_char_boundary_at(output, output.len() - half);
    let head = &output[..head_end];
    let tail = &output[tail_start..];
    // Count dropped characters (not bytes) so the notice matches its own
    // wording for non-ASCII output as well.
    let truncated = output[head_end..tail_start].chars().count();

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}

/// Largest char boundary of `s` that is `<= index` (index clamped to `s.len()`).
///
/// Equivalent to `str::floor_char_boundary`, which is only stable on recent
/// toolchains.
fn floor_char_boundary_at(s: &str, index: usize) -> usize {
    let mut i = index.min(s.len());
    // Index 0 is always a boundary, so this terminates.
    while !s.is_char_boundary(i) {
        i -= 1;
    }
    i
}

/// Smallest char boundary of `s` that is `>= index` (index clamped to `s.len()`).
///
/// Equivalent to `str::ceil_char_boundary`, which is only stable on recent
/// toolchains.
fn ceil_char_boundary_at(s: &str, index: usize) -> usize {
    let mut i = index.min(s.len());
    // `s.len()` is always a boundary, so this terminates.
    while !s.is_char_boundary(i) {
        i += 1;
    }
    i
}
113
/// Event emitted during tool execution for real-time UI updates.
///
/// Every variant carries the `tool_name` and `command` it refers to.
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// Start-of-execution event for `command`.
    Started {
        tool_name: String,
        command: String,
    },
    /// A piece of output for `command`, sent while it is still running.
    OutputChunk {
        tool_name: String,
        command: String,
        chunk: String,
    },
    /// End-of-execution event.
    ///
    /// NOTE(review): `output` is presumably the final (possibly filtered)
    /// output, with `filter_stats` describing that filtering — confirm with
    /// the emitting executor.
    Completed {
        tool_name: String,
        command: String,
        output: String,
        success: bool,
        filter_stats: Option<FilterStats>,
        diff: Option<DiffData>,
    },
}

/// Sending half of an unbounded `tokio` mpsc channel carrying [`ToolEvent`]s.
///
/// Because the channel is unbounded, sending never waits for the receiver;
/// undelivered events queue in memory until consumed.
pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
137
/// Errors that can occur during tool execution.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute
/// (generated by `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The command was rejected by a policy check.
    #[error("command blocked by policy: {command}")]
    Blocked { command: String },

    /// A path lies outside what the sandbox allows.
    #[error("path not allowed by sandbox: {path}")]
    SandboxViolation { path: String },

    /// The command needs user approval first. After approval, callers re-run
    /// it via `ToolExecutor::execute_confirmed`.
    #[error("command requires confirmation: {command}")]
    ConfirmationRequired { command: String },

    /// Execution exceeded `timeout_secs` seconds.
    #[error("command timed out after {timeout_secs}s")]
    Timeout { timeout_secs: u64 },

    /// Execution was cancelled before completion.
    #[error("operation cancelled")]
    Cancelled,

    /// Tool-call parameters did not match the expected shape (produced e.g.
    /// by `deserialize_params`, with serde's message).
    #[error("invalid tool parameters: {message}")]
    InvalidParams { message: String },

    /// Underlying I/O failure. `#[from]` lets `?` convert a
    /// `std::io::Error` into this variant directly.
    #[error("execution failed: {0}")]
    Execution(#[from] std::io::Error),
}
162
163/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
164///
165/// # Errors
166///
167/// Returns `ToolError::InvalidParams` when deserialization fails.
168pub fn deserialize_params<T: serde::de::DeserializeOwned>(
169    params: &serde_json::Map<String, serde_json::Value>,
170) -> Result<T, ToolError> {
171    let obj = serde_json::Value::Object(params.clone());
172    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
173        message: e.to_string(),
174    })
175}
176
/// Async trait for tool execution backends (shell, future MCP, A2A).
///
/// Accepts the full LLM response and returns an optional output.
/// Returns `None` when no tool invocation is detected in the response.
///
/// Only `execute` must be implemented; every other method has a default.
/// Methods are declared as returning `impl Future + Send`, so every
/// implementation's future must be `Send`. For dynamic dispatch, see
/// `ErasedToolExecutor` and `DynExecutor`.
pub trait ToolExecutor: Send + Sync {
    /// Scan the free-form LLM `response` for invocations this executor handles
    /// and run them.
    ///
    /// Resolves to `Ok(None)` when nothing in `response` applies to this executor.
    fn execute(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;

    /// Execute bypassing confirmation checks (called after user approves).
    /// Default: delegates to `execute`.
    ///
    /// Because the default simply calls `execute`, any confirmation check that
    /// `execute` performs still applies. Executors that can fail with
    /// `ToolError::ConfirmationRequired` should override this.
    fn execute_confirmed(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        self.execute(response)
    }

    /// Return tool definitions this executor can handle.
    ///
    /// Default: an empty list.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        vec![]
    }

    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
    ///
    /// Default: resolves immediately to `Ok(None)`, i.e. the executor handles no
    /// structured calls.
    fn execute_tool_call(
        &self,
        _call: &ToolCall,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        std::future::ready(Ok(None))
    }

    /// Inject environment variables for the currently active skill. No-op by default.
    ///
    /// Takes `&self`, so implementations that store the variables need interior
    /// mutability. NOTE(review): presumably `None` clears a previously injected
    /// environment — confirm with implementors.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
212
/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
///
/// Implemented automatically for every `T: ToolExecutor` by a blanket impl in
/// this module (no `'static` bound is required by the impl itself). A
/// `'static` requirement only arises from the default object lifetime when a
/// value is turned into a trait object.
/// Use `Box<dyn ErasedToolExecutor>` or `Arc<dyn ErasedToolExecutor>` when
/// dynamic dispatch is required.
///
/// Each method mirrors the `ToolExecutor` method of the same name. It returns
/// a pinned, boxed `Send` future that may borrow `self` and the arguments for `'a`.
pub trait ErasedToolExecutor: Send + Sync {
    /// Boxed counterpart of `ToolExecutor::execute`.
    fn execute_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Boxed counterpart of `ToolExecutor::execute_confirmed`.
    fn execute_confirmed_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Counterpart of `ToolExecutor::tool_definitions`.
    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;

    /// Boxed counterpart of `ToolExecutor::execute_tool_call`.
    fn execute_tool_call_erased<'a>(
        &'a self,
        call: &'a ToolCall,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Inject environment variables for the currently active skill. No-op by default.
    ///
    /// The blanket impl for `T: ToolExecutor` overrides this to forward to
    /// `ToolExecutor::set_skill_env`. The no-op default therefore only applies
    /// to types that implement this trait directly.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
238
239impl<T: ToolExecutor> ErasedToolExecutor for T {
240    fn execute_erased<'a>(
241        &'a self,
242        response: &'a str,
243    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
244    {
245        Box::pin(self.execute(response))
246    }
247
248    fn execute_confirmed_erased<'a>(
249        &'a self,
250        response: &'a str,
251    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
252    {
253        Box::pin(self.execute_confirmed(response))
254    }
255
256    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
257        self.tool_definitions()
258    }
259
260    fn execute_tool_call_erased<'a>(
261        &'a self,
262        call: &'a ToolCall,
263    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
264    {
265        Box::pin(self.execute_tool_call(call))
266    }
267
268    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
269        ToolExecutor::set_skill_env(self, env);
270    }
271}
272
/// Wraps `Arc<dyn ErasedToolExecutor>` so it can be used as a concrete `ToolExecutor`.
///
/// Enables dynamic composition of tool executors at runtime without static type chains.
/// The field is public, so any `'static` executor can be wrapped directly,
/// e.g. `DynExecutor(Arc::new(my_executor))`. Because the inner executor sits
/// behind an `Arc`, several wrappers can share one instance.
pub struct DynExecutor(pub std::sync::Arc<dyn ErasedToolExecutor>);
277
278impl ToolExecutor for DynExecutor {
279    fn execute(
280        &self,
281        response: &str,
282    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
283        // Clone data to satisfy the 'static-ish bound: erased futures must not borrow self.
284        let inner = std::sync::Arc::clone(&self.0);
285        let response = response.to_owned();
286        async move { inner.execute_erased(&response).await }
287    }
288
289    fn execute_confirmed(
290        &self,
291        response: &str,
292    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
293        let inner = std::sync::Arc::clone(&self.0);
294        let response = response.to_owned();
295        async move { inner.execute_confirmed_erased(&response).await }
296    }
297
298    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
299        self.0.tool_definitions_erased()
300    }
301
302    fn execute_tool_call(
303        &self,
304        call: &ToolCall,
305    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
306        let inner = std::sync::Arc::clone(&self.0);
307        let call = call.clone();
308        async move { inner.execute_tool_call_erased(&call).await }
309    }
310
311    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
312        ErasedToolExecutor::set_skill_env(self.0.as_ref(), env);
313    }
314}
315
/// Extract fenced code blocks with the given language marker from text.
///
/// Scans left to right for an opening `` ```{lang} `` fence and the first
/// `` ``` `` after it, collecting the enclosed text with surrounding whitespace
/// trimmed. Blocks come back in document order and borrow from `text`. An
/// opening fence with no closing fence stops the scan.
///
/// Note: the opening fence is matched as a plain prefix, so `lang = "sh"` also
/// matches a `` ```shell `` fence (yielding its body with a leading `ell`).
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    let opening = format!("```{lang}");
    let mut found = Vec::new();
    let mut cursor = text;

    while let Some((_, after_open)) = cursor.split_once(opening.as_str()) {
        let Some((body, after_close)) = after_open.split_once("```") else {
            break;
        };
        found.push(body.trim());
        cursor = after_close;
    }

    found
}
338
// Unit tests for the executor types: display strings, parameter decoding,
// output truncation, filter statistics, and trait default/forwarding behavior.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Display of ToolOutput and ToolError ---

    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    // --- deserialize_params ---

    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    // Unknown keys are accepted because `P` does not use `deny_unknown_fields`.
    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    // --- truncate_tool_output (ASCII inputs only) ---

    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    // Boundary: output of exactly the limit is returned unchanged.
    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    // Minimal executor implementing only the required `execute`, used to
    // exercise the trait's default methods.
    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    // --- FilterStats ---

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    // raw_chars == 0 hits the early return instead of dividing by zero.
    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    // No `command` set, so no command label appears in the output.
    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    // DynExecutor tests

    // Executor returning a fixed output from both `execute` and
    // `execute_tool_call`, so forwarding through `DynExecutor` is observable.
    struct FixedExecutor {
        tool_id: &'static str,
        output: &'static str,
    }

    impl ToolExecutor for FixedExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
            }))
        }

        fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
            vec![]
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
            }))
        }
    }

    #[tokio::test]
    async fn dyn_executor_execute_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "hello",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute("```bash\necho hello\n```").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "hello");
    }

    // FixedExecutor does not override `execute_confirmed`, so this goes
    // through the trait default (which calls `execute`).
    #[tokio::test]
    async fn dyn_executor_execute_confirmed_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "confirmed",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute_confirmed("...").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "confirmed");
    }

    #[test]
    fn dyn_executor_tool_definitions_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "my_tool",
            output: "",
        });
        let exec = DynExecutor(inner);
        // FixedExecutor returns empty definitions; verify delegation occurs without panic.
        let defs = exec.tool_definitions();
        assert!(defs.is_empty());
    }

    #[tokio::test]
    async fn dyn_executor_execute_tool_call_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "tool_call_result",
        });
        let exec = DynExecutor(inner);
        let call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "tool_call_result");
    }
}