Skip to main content

imp_core/tools/
shell.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::process::Stdio;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::{json, Map, Value};
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10
11use crate::error::{Error, Result};
12use crate::tools::{truncate_head, truncate_tail, Tool, ToolContext, ToolOutput, ToolRegistry};
13
/// Line limit passed to the truncation helpers for a shell tool's combined output.
const MAX_OUTPUT_LINES: usize = 2_000;
/// Byte limit (50 KiB) passed to the truncation helpers alongside `MAX_OUTPUT_LINES`.
const MAX_OUTPUT_BYTES: usize = 50 * 1_024;
16
17/// TOML-defined shell tool definition.
18#[derive(Debug, Clone, serde::Deserialize)]
19pub struct ShellToolDef {
20    pub name: String,
21    pub label: String,
22    pub description: String,
23    #[serde(default)]
24    pub readonly: bool,
25    #[serde(default)]
26    pub params: std::collections::HashMap<String, ShellParamDef>,
27    pub exec: ShellExecDef,
28}
29
/// One `[params.<name>]` entry of a shell tool definition.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ShellParamDef {
    /// Emitted verbatim as the property's `"type"` in the tool's parameter schema.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Emitted verbatim as the property's `"description"` in the parameter schema.
    pub description: String,
    /// When `true`, the parameter is left out of the schema's `required` list and
    /// may be absent at call time; defaults to `false` (required).
    #[serde(default)]
    pub optional: bool,
}
38
/// The `[exec]` table of a shell tool: what program to run and how to present its output.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ShellExecDef {
    /// Program to spawn. It is executed directly (no shell wrapper); if it cannot
    /// be found, the error output includes `install_hint` when set.
    pub command: String,
    /// Argument templates; `{name}` / `{name|default}` placeholders are filled in
    /// from the call parameters. Defaults to no arguments.
    #[serde(default)]
    pub args: Vec<String>,
    /// Seconds to wait before killing the child; defaults to 30.
    #[serde(default = "default_timeout")]
    pub timeout: u32,
    /// Truncation strategy: `"tail"` selects `truncate_tail`, any other value
    /// selects `truncate_head`. Defaults to `"head"`.
    #[serde(default = "default_truncate")]
    pub truncate: String,
    /// Optional hint appended to the "command not found" error message.
    pub install_hint: Option<String>,
}
50
51#[derive(Debug, Clone)]
52pub struct ShellTool {
53    def: ShellToolDef,
54}
55
56impl ShellTool {
57    fn new(def: ShellToolDef) -> Self {
58        Self { def }
59    }
60}
61
/// Serde default for `ShellExecDef::timeout`, in seconds.
fn default_timeout() -> u32 {
    const DEFAULT_TIMEOUT_SECS: u32 = 30;
    DEFAULT_TIMEOUT_SECS
}

/// Serde default for `ShellExecDef::truncate`; selects `truncate_head`.
fn default_truncate() -> String {
    String::from("head")
}
68
69#[async_trait]
70impl Tool for ShellTool {
71    fn name(&self) -> &str {
72        &self.def.name
73    }
74
75    fn label(&self) -> &str {
76        &self.def.label
77    }
78
79    fn description(&self) -> &str {
80        &self.def.description
81    }
82
83    fn parameters(&self) -> Value {
84        let mut properties = Map::new();
85        let mut required = Vec::new();
86
87        let mut param_names: Vec<_> = self.def.params.keys().cloned().collect();
88        param_names.sort();
89
90        for name in param_names {
91            if let Some(def) = self.def.params.get(&name) {
92                properties.insert(
93                    name.clone(),
94                    json!({
95                        "type": def.param_type,
96                        "description": def.description,
97                    }),
98                );
99
100                if !def.optional {
101                    required.push(Value::String(name));
102                }
103            }
104        }
105
106        json!({
107            "type": "object",
108            "properties": Value::Object(properties),
109            "required": Value::Array(required),
110        })
111    }
112
113    fn is_readonly(&self) -> bool {
114        self.def.readonly
115    }
116
117    async fn execute(&self, _call_id: &str, params: Value, ctx: ToolContext) -> Result<ToolOutput> {
118        if ctx.is_cancelled() {
119            return Ok(ToolOutput::error("Tool execution cancelled."));
120        }
121
122        let provided = params.as_object().cloned().unwrap_or_default();
123        validate_required_params(&self.def.params, &provided)?;
124
125        let mut args = Vec::with_capacity(self.def.exec.args.len());
126        for arg in &self.def.exec.args {
127            args.push(interpolate_arg(arg, &self.def.params, &provided)?);
128        }
129
130        let mut command = Command::new(&self.def.exec.command);
131        command
132            .args(&args)
133            .current_dir(&ctx.cwd)
134            // Shell tools are also non-interactive; avoid inheriting the TUI's
135            // stdin so terminal escape sequences cannot leak into child reads.
136            .stdin(Stdio::null())
137            .stdout(Stdio::piped())
138            .stderr(Stdio::piped());
139
140        let mut child = match command.spawn() {
141            Ok(child) => child,
142            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
143                let mut message = format!(
144                    "Command not found for shell tool '{}': {}",
145                    self.def.name, self.def.exec.command
146                );
147                if let Some(hint) = &self.def.exec.install_hint {
148                    message.push_str(&format!("\nInstall hint: {hint}"));
149                }
150                return Ok(ToolOutput::error(message));
151            }
152            Err(err) => {
153                return Err(Error::Tool(format!(
154                    "failed to spawn shell tool '{}': {err}",
155                    self.def.name
156                )));
157            }
158        };
159
160        let stdout = child
161            .stdout
162            .take()
163            .ok_or_else(|| Error::Tool("failed to capture stdout".into()))?;
164        let stderr = child
165            .stderr
166            .take()
167            .ok_or_else(|| Error::Tool("failed to capture stderr".into()))?;
168
169        let stdout_task = tokio::spawn(async move {
170            let mut reader = tokio::io::BufReader::new(stdout);
171            let mut buffer = Vec::new();
172            reader.read_to_end(&mut buffer).await.map(|_| buffer)
173        });
174        let stderr_task = tokio::spawn(async move {
175            let mut reader = tokio::io::BufReader::new(stderr);
176            let mut buffer = Vec::new();
177            reader.read_to_end(&mut buffer).await.map(|_| buffer)
178        });
179
180        let timeout = std::time::Duration::from_secs(self.def.exec.timeout as u64);
181        let (status, timed_out) = tokio::select! {
182            status = child.wait() => (status?, false),
183            _ = tokio::time::sleep(timeout) => {
184                let _ = child.kill().await;
185                let status = child.wait().await?;
186                (status, true)
187            }
188        };
189
190        let stdout_bytes = stdout_task
191            .await
192            .map_err(|err| Error::Tool(format!("stdout reader task failed: {err}")))??;
193        let stderr_bytes = stderr_task
194            .await
195            .map_err(|err| Error::Tool(format!("stderr reader task failed: {err}")))??;
196
197        let mut combined_output = String::new();
198        let stdout_text = String::from_utf8_lossy(&stdout_bytes);
199        let stderr_text = String::from_utf8_lossy(&stderr_bytes);
200
201        if !stdout_text.is_empty() {
202            combined_output.push_str(&stdout_text);
203        }
204        if !stderr_text.is_empty() {
205            if !combined_output.is_empty() && !combined_output.ends_with('\n') {
206                combined_output.push('\n');
207            }
208            combined_output.push_str(&stderr_text);
209        }
210
211        let truncation = match self.def.exec.truncate.as_str() {
212            "tail" => truncate_tail(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
213            _ => truncate_head(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
214        };
215
216        let mut result_text = truncation.content;
217        if truncation.truncated {
218            let note = format!(
219                "\n[Output truncated: showing {} of {} lines{}]",
220                truncation.output_lines,
221                truncation.total_lines,
222                truncation
223                    .temp_file
224                    .as_ref()
225                    .map(|path| format!(". Full output saved to {}", path.display()))
226                    .unwrap_or_default()
227            );
228            result_text.push_str(&note);
229        }
230        if timed_out {
231            result_text.push_str(&format!(
232                "\n[Command timed out after {}s]",
233                self.def.exec.timeout
234            ));
235        }
236
237        Ok(ToolOutput {
238            content: vec![imp_llm::ContentBlock::Text { text: result_text }],
239            details: json!({
240                "exit_code": status.code().unwrap_or(-1),
241                "timed_out": timed_out,
242                "truncated": truncation.truncated,
243            }),
244            is_error: timed_out || !status.success(),
245        })
246    }
247}
248
249fn validate_required_params(
250    defs: &HashMap<String, ShellParamDef>,
251    provided: &Map<String, Value>,
252) -> Result<()> {
253    let mut missing = Vec::new();
254
255    for (name, def) in defs {
256        if !def.optional && provided.get(name).is_none_or(Value::is_null) {
257            missing.push(name.clone());
258        }
259    }
260
261    missing.sort();
262
263    if missing.is_empty() {
264        Ok(())
265    } else {
266        Err(Error::Tool(format!(
267            "missing required parameter(s): {}",
268            missing.join(", ")
269        )))
270    }
271}
272
273fn interpolate_arg(
274    template: &str,
275    defs: &HashMap<String, ShellParamDef>,
276    provided: &Map<String, Value>,
277) -> Result<String> {
278    let mut result = String::new();
279    let mut remaining = template;
280
281    while let Some(start) = remaining.find('{') {
282        result.push_str(&remaining[..start]);
283
284        let after_start = &remaining[start + 1..];
285        let end = after_start.find('}').ok_or_else(|| {
286            Error::Tool(format!(
287                "unclosed placeholder in shell tool argument: {template}"
288            ))
289        })?;
290
291        let placeholder = &after_start[..end];
292        result.push_str(&resolve_placeholder(placeholder, defs, provided)?);
293        remaining = &after_start[end + 1..];
294    }
295
296    result.push_str(remaining);
297    Ok(result)
298}
299
300fn resolve_placeholder(
301    placeholder: &str,
302    defs: &HashMap<String, ShellParamDef>,
303    provided: &Map<String, Value>,
304) -> Result<String> {
305    let (name, default) = placeholder
306        .split_once('|')
307        .map_or((placeholder, None), |(name, default)| (name, Some(default)));
308
309    if name.is_empty() {
310        return Err(Error::Tool(
311            "empty placeholder in shell tool argument".into(),
312        ));
313    }
314
315    if let Some(value) = provided.get(name).filter(|value| !value.is_null()) {
316        return stringify_param_value(name, value);
317    }
318
319    if let Some(default) = default {
320        return Ok(default.to_string());
321    }
322
323    if defs.get(name).is_some_and(|def| def.optional) {
324        return Ok(String::new());
325    }
326
327    Err(Error::Tool(format!(
328        "missing required parameter for placeholder: {name}"
329    )))
330}
331
332fn stringify_param_value(name: &str, value: &Value) -> Result<String> {
333    match value {
334        Value::String(value) => Ok(value.clone()),
335        Value::Number(value) => Ok(value.to_string()),
336        Value::Bool(value) => Ok(value.to_string()),
337        Value::Null => Ok(String::new()),
338        _ => Err(Error::Tool(format!(
339            "parameter '{name}' must be a string, number, or boolean"
340        ))),
341    }
342}
343
344/// Load shell tools from a directory of TOML definitions.
345pub fn load_shell_tools(dir: &Path, registry: &mut ToolRegistry) -> Result<()> {
346    if !dir.exists() {
347        return Ok(());
348    }
349
350    for entry in walkdir::WalkDir::new(dir) {
351        let entry = entry.map_err(|err| {
352            Error::Tool(format!(
353                "failed to walk shell tool directory {}: {err}",
354                dir.display()
355            ))
356        })?;
357        if !entry.file_type().is_file() {
358            continue;
359        }
360        if entry.path().extension().and_then(|ext| ext.to_str()) != Some("toml") {
361            continue;
362        }
363
364        let content = std::fs::read_to_string(entry.path())?;
365        match toml::from_str::<ShellToolDef>(&content) {
366            Ok(def) => registry.register(Arc::new(ShellTool::new(def))),
367            Err(_err) => {
368                // Keep shell tool discovery side-effect free for embedded callers.
369            }
370        }
371    }
372
373    Ok(())
374}
375
// Tests for TOML-defined shell tools: loading, argument interpolation, output
// capture, timeouts, and missing-command reporting.
//
// NOTE(review): several tests spawn `printf`, `echo`, `sh`, and `sleep`, so they
// assume a Unix-like environment with those programs on PATH.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ui::NullInterface;
    use serde_json::json;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    /// Builds a minimal, non-cancelled `ToolContext` whose cwd is `dir`.
    ///
    /// The channel receivers are bound to `_rx` / `_cmd_rx` and dropped when this
    /// function returns; these tests do not observe anything sent on them.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    // A valid definition in a nested directory is registered; an unparsable TOML
    // file is skipped without failing the whole load.
    #[test]
    fn load_shell_tools_registers_valid_defs_and_skips_invalid_ones() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tools_dir = temp_dir.path().join("tools");
        std::fs::create_dir_all(tools_dir.join("nested")).unwrap();

        std::fs::write(
            tools_dir.join("nested").join("greet.toml"),
            r#"
name = "greet"
label = "Greet"
description = "Print a greeting"
readonly = true

[params.name]
type = "string"
description = "Name to greet"

[params.greeting]
type = "string"
description = "Greeting text"
optional = true

[exec]
command = "printf"
args = ["%s %s", "{greeting|hello}", "{name}"]
timeout = 5
truncate = "head"
"#,
        )
        .unwrap();

        std::fs::write(tools_dir.join("broken.toml"), "not = [valid").unwrap();

        let mut registry = ToolRegistry::new();
        load_shell_tools(&tools_dir, &mut registry).unwrap();

        let tool = registry.get("greet").expect("tool should be registered");
        assert_eq!(tool.name(), "greet");
        assert!(registry.get("broken").is_none());
    }

    // A provided parameter and an inline `{name|default}` fallback are both
    // substituted into argv; `printf` emits no trailing newline.
    #[tokio::test]
    async fn shell_tool_executes_with_param_interpolation() {
        let tool = ShellTool::new(ShellToolDef {
            name: "greet".into(),
            label: "Greet".into(),
            description: "Print a greeting".into(),
            readonly: true,
            params: HashMap::from([
                (
                    "name".into(),
                    ShellParamDef {
                        param_type: "string".into(),
                        description: "Name to greet".into(),
                        optional: false,
                    },
                ),
                (
                    "greeting".into(),
                    ShellParamDef {
                        param_type: "string".into(),
                        description: "Greeting text".into(),
                        optional: true,
                    },
                ),
            ]),
            exec: ShellExecDef {
                command: "printf".into(),
                args: vec!["%s %s".into(), "{greeting|hello}".into(), "{name}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute(
                "call-1",
                json!({ "name": "Asher" }),
                test_ctx(temp_dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert_eq!(text, "hello Asher");
        assert_eq!(result.details["exit_code"], 0);
        assert_eq!(result.details["timed_out"], false);
    }

    // An optional parameter absent from the call falls back to the placeholder's
    // inline default.
    #[tokio::test]
    async fn shell_tool_default_param_used_when_not_provided() {
        let tool = ShellTool::new(ShellToolDef {
            name: "echo_default".into(),
            label: "Echo Default".into(),
            description: "Echo with default".into(),
            readonly: true,
            params: HashMap::from([(
                "msg".into(),
                ShellParamDef {
                    param_type: "string".into(),
                    description: "Message".into(),
                    optional: true,
                },
            )]),
            exec: ShellExecDef {
                command: "echo".into(),
                args: vec!["{msg|default_value}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-3", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("default_value"));
    }

    // validate_required_params rejects a call missing a non-optional parameter
    // and names that parameter in the error.
    #[test]
    fn shell_tool_required_param_missing_errors() {
        let defs = HashMap::from([(
            "name".into(),
            ShellParamDef {
                param_type: "string".into(),
                description: "Name".into(),
                optional: false,
            },
        )]);
        let provided = serde_json::Map::new();
        let result = validate_required_params(&defs, &provided);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("name"));
    }

    // Output written to stderr is included alongside stdout in the tool result.
    #[tokio::test]
    async fn shell_tool_stderr_included_in_output() {
        let tool = ShellTool::new(ShellToolDef {
            name: "stderr_test".into(),
            label: "Stderr Test".into(),
            description: "Writes to stderr".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sh".into(),
                args: vec!["-c".into(), "echo stdout_msg; echo stderr_msg >&2".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-4", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("stdout_msg"));
        assert!(text.contains("stderr_msg"));
    }

    // A command outliving `exec.timeout` (1s here) is killed and reported as a
    // timed-out error; this test takes roughly one second.
    #[tokio::test]
    async fn shell_tool_timeout() {
        let tool = ShellTool::new(ShellToolDef {
            name: "slow".into(),
            label: "Slow".into(),
            description: "Times out".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sleep".into(),
                args: vec!["60".into()],
                timeout: 1,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-5", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(result.details["timed_out"], true);
    }

    // A missing executable produces an error output (not an Err) that includes
    // the configured install hint.
    #[tokio::test]
    async fn shell_tool_reports_missing_commands_with_install_hint() {
        let tool = ShellTool::new(ShellToolDef {
            name: "missing".into(),
            label: "Missing".into(),
            description: "Missing command".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "definitely-not-a-real-command".into(),
                args: Vec::new(),
                timeout: 5,
                truncate: "head".into(),
                install_hint: Some("brew install definitely-not-a-real-command".into()),
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-2", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("Command not found"));
        assert!(text.contains("Install hint"));
    }
}