Skip to main content

zeph_tools/
executor.rs

1use std::fmt;
2
/// Payload the TUI uses to render a file diff.
///
/// Holds both sides of the diff as whole owned strings; no hunks are
/// precomputed here.
#[derive(Debug, Clone)]
pub struct DiffData {
    /// Path of the file the diff refers to.
    pub file_path: String,
    /// The "old" side of the diff.
    pub old_content: String,
    /// The "new" side of the diff.
    pub new_content: String,
}
10
11/// Structured tool invocation from LLM.
12#[derive(Debug, Clone)]
13pub struct ToolCall {
14    pub tool_id: String,
15    pub params: serde_json::Map<String, serde_json::Value>,
16}
17
18/// Cumulative filter statistics for a single tool execution.
19#[derive(Debug, Clone, Default)]
20pub struct FilterStats {
21    pub raw_chars: usize,
22    pub filtered_chars: usize,
23    pub raw_lines: usize,
24    pub filtered_lines: usize,
25    pub confidence: Option<crate::FilterConfidence>,
26    pub command: Option<String>,
27    pub kept_lines: Vec<usize>,
28}
29
30impl FilterStats {
31    #[must_use]
32    #[allow(clippy::cast_precision_loss)]
33    pub fn savings_pct(&self) -> f64 {
34        if self.raw_chars == 0 {
35            return 0.0;
36        }
37        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
38    }
39
40    #[must_use]
41    pub fn estimated_tokens_saved(&self) -> usize {
42        self.raw_chars.saturating_sub(self.filtered_chars) / 4
43    }
44
45    #[must_use]
46    pub fn format_inline(&self, tool_name: &str) -> String {
47        let cmd_label = self
48            .command
49            .as_deref()
50            .map(|c| {
51                let trimmed = c.trim();
52                if trimmed.len() > 60 {
53                    format!(" `{}…`", &trimmed[..57])
54                } else {
55                    format!(" `{trimmed}`")
56                }
57            })
58            .unwrap_or_default();
59        format!(
60            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
61            self.raw_lines,
62            self.filtered_lines,
63            self.savings_pct()
64        )
65    }
66}
67
68/// Structured result from tool execution.
69#[derive(Debug, Clone)]
70pub struct ToolOutput {
71    pub tool_name: String,
72    pub summary: String,
73    pub blocks_executed: u32,
74    pub filter_stats: Option<FilterStats>,
75    pub diff: Option<DiffData>,
76    /// Whether this tool already streamed its output via `ToolEvent` channel.
77    pub streamed: bool,
78}
79
80impl fmt::Display for ToolOutput {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        f.write_str(&self.summary)
83    }
84}
85
/// Size limit above which [`truncate_tool_output`] shortens tool output.
///
/// Despite the name, the limit is compared against the UTF-8 byte length
/// (`str::len`). Non-ASCII output is therefore truncated at fewer characters
/// than this value.
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Truncate tool output that exceeds [`MAX_TOOL_OUTPUT_CHARS`] using a head+tail split.
///
/// Output whose byte length is within the limit is returned unchanged. Longer
/// output keeps about `MAX_TOOL_OUTPUT_CHARS / 2` bytes from each end and
/// replaces the middle with a notice. Both cut points are moved onto UTF-8
/// character boundaries, so no character is split. The notice reports how many
/// characters were dropped.
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    if output.len() <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_string();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;

    // Head ends at the last char boundary at or before `half`.
    // Terminates because index 0 is always a boundary.
    let mut head_end = half;
    while !output.is_char_boundary(head_end) {
        head_end -= 1;
    }
    // Tail starts at the first char boundary at or after `len - half`.
    // Terminates because `len` is always a boundary.
    // Since `len > 2 * half`, this is strictly after `head_end`.
    let mut tail_start = output.len() - half;
    while !output.is_char_boundary(tail_start) {
        tail_start += 1;
    }

    let head = &output[..head_end];
    let tail = &output[tail_start..];
    // Count dropped characters, not bytes, so the number matches the notice's wording.
    let truncated = output[head_end..tail_start].chars().count();

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}
106
107/// Event emitted during tool execution for real-time UI updates.
108#[derive(Debug, Clone)]
109pub enum ToolEvent {
110    Started {
111        tool_name: String,
112        command: String,
113    },
114    OutputChunk {
115        tool_name: String,
116        command: String,
117        chunk: String,
118    },
119    Completed {
120        tool_name: String,
121        command: String,
122        output: String,
123        success: bool,
124        filter_stats: Option<FilterStats>,
125        diff: Option<DiffData>,
126    },
127}
128
129pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
130
131/// Errors that can occur during tool execution.
132#[derive(Debug, thiserror::Error)]
133pub enum ToolError {
134    #[error("command blocked by policy: {command}")]
135    Blocked { command: String },
136
137    #[error("path not allowed by sandbox: {path}")]
138    SandboxViolation { path: String },
139
140    #[error("command requires confirmation: {command}")]
141    ConfirmationRequired { command: String },
142
143    #[error("command timed out after {timeout_secs}s")]
144    Timeout { timeout_secs: u64 },
145
146    #[error("operation cancelled")]
147    Cancelled,
148
149    #[error("invalid tool parameters: {message}")]
150    InvalidParams { message: String },
151
152    #[error("execution failed: {0}")]
153    Execution(#[from] std::io::Error),
154}
155
156/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
157///
158/// # Errors
159///
160/// Returns `ToolError::InvalidParams` when deserialization fails.
161pub fn deserialize_params<T: serde::de::DeserializeOwned>(
162    params: &serde_json::Map<String, serde_json::Value>,
163) -> Result<T, ToolError> {
164    let obj = serde_json::Value::Object(params.clone());
165    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
166        message: e.to_string(),
167    })
168}
169
170/// Async trait for tool execution backends (shell, future MCP, A2A).
171///
172/// Accepts the full LLM response and returns an optional output.
173/// Returns `None` when no tool invocation is detected in the response.
174pub trait ToolExecutor: Send + Sync {
175    fn execute(
176        &self,
177        response: &str,
178    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;
179
180    /// Execute bypassing confirmation checks (called after user approves).
181    /// Default: delegates to `execute`.
182    fn execute_confirmed(
183        &self,
184        response: &str,
185    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
186        self.execute(response)
187    }
188
189    /// Return tool definitions this executor can handle.
190    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
191        vec![]
192    }
193
194    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
195    fn execute_tool_call(
196        &self,
197        _call: &ToolCall,
198    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
199        std::future::ready(Ok(None))
200    }
201
202    /// Inject environment variables for the currently active skill. No-op by default.
203    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
204}
205
206/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
207///
208/// Implemented automatically for all `T: ToolExecutor + 'static`.
209/// Use `Box<dyn ErasedToolExecutor>` when dynamic dispatch is required.
210pub trait ErasedToolExecutor: Send + Sync {
211    fn execute_erased<'a>(
212        &'a self,
213        response: &'a str,
214    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
215
216    fn execute_confirmed_erased<'a>(
217        &'a self,
218        response: &'a str,
219    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
220
221    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;
222
223    fn execute_tool_call_erased<'a>(
224        &'a self,
225        call: &'a ToolCall,
226    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
227
228    /// Inject environment variables for the currently active skill. No-op by default.
229    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
230}
231
232impl<T: ToolExecutor> ErasedToolExecutor for T {
233    fn execute_erased<'a>(
234        &'a self,
235        response: &'a str,
236    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
237    {
238        Box::pin(self.execute(response))
239    }
240
241    fn execute_confirmed_erased<'a>(
242        &'a self,
243        response: &'a str,
244    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
245    {
246        Box::pin(self.execute_confirmed(response))
247    }
248
249    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
250        self.tool_definitions()
251    }
252
253    fn execute_tool_call_erased<'a>(
254        &'a self,
255        call: &'a ToolCall,
256    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
257    {
258        Box::pin(self.execute_tool_call(call))
259    }
260
261    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
262        ToolExecutor::set_skill_env(self, env);
263    }
264}
265
/// Extract fenced code blocks with the given language marker from text.
///
/// Searches for `` ```{lang} `` … `` ``` `` pairs and returns each block's
/// content, trimmed, in order of appearance. An opening fence without a closing
/// fence ends the search.
///
/// When `lang` is non-empty, the tag must match exactly. The character right
/// after `lang` must not continue a language identifier (a letter, digit, `-`,
/// `_`, `+`, `#` or `.`). So a `` ```shell `` fence is not taken as a
/// `` ```sh `` block; such fences are skipped together with their content.
///
/// With an empty `lang`, every fence matches, and any info string stays part of
/// the returned content.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    let marker = format!("```{lang}");
    let mut blocks = Vec::new();
    let mut rest = text;

    while let Some(start) = rest.find(&marker) {
        let after = &rest[start + marker.len()..];
        let Some(end) = after.find("```") else {
            break;
        };
        // Reject fences whose tag merely starts with `lang` (e.g. "```shell" for
        // "sh"). Otherwise the tag's tail would leak into the extracted content.
        let tag_ends_here = lang.is_empty()
            || !matches!(
                after.chars().next(),
                Some(c) if c.is_alphanumeric() || matches!(c, '-' | '_' | '+' | '#' | '.')
            );
        if tag_ends_here {
            blocks.push(after[..end].trim());
        }
        // Resume after the closing fence either way, so a rejected block's body
        // is not rescanned for openers.
        rest = &after[end + 3..];
    }

    blocks
}
288
// Unit tests for this module. They cover:
// - `ToolOutput` display and the `ToolError` messages;
// - `deserialize_params` and `truncate_tool_output`;
// - the default `ToolExecutor::execute_tool_call`;
// - the `FilterStats` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // `Display` for `ToolOutput` emits only the `summary` field.
    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    // The next tests pin the exact `#[error(...)]` message of each `ToolError` variant.
    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    // Happy path: every field present with the expected JSON type.
    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    // A missing required field is reported as `ToolError::InvalidParams`.
    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    // A JSON type mismatch (string where u32 is expected) is also `InvalidParams`.
    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    // An empty map is valid when every field is an `Option`.
    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    // Keys that the target struct does not declare are ignored.
    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    // `Execution` wraps an `io::Error` and includes its message.
    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    // Inputs at or below `MAX_TOOL_OUTPUT_CHARS` pass through unchanged.
    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    // Over the limit, the output shrinks and carries the truncation notice.
    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    // Minimal executor that implements only the required `execute`.
    // Used to exercise the trait's default methods.
    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    // The default `execute_tool_call` handles no tool ids and resolves to `Ok(None)`.
    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    // 1000 -> 200 chars is an 80% saving.
    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    // Zero raw chars yields 0% rather than dividing by zero.
    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    // Line counts come from `*_lines`; the percentage comes from the `*_chars` fields.
    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }
}