Skip to main content

imp_core/tools/
shell.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::process::Stdio;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::{json, Map, Value};
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10
11use crate::error::{Error, Result};
12use crate::tools::{truncate_head, truncate_tail, Tool, ToolContext, ToolOutput, ToolRegistry};
13
/// Line limit handed to `truncate_head` / `truncate_tail` for a tool's combined output.
const MAX_OUTPUT_LINES: usize = 2_000;
/// Byte limit (50 KiB) handed to `truncate_head` / `truncate_tail` for a tool's combined output.
const MAX_OUTPUT_BYTES: usize = 1024 * 50;
16
17/// TOML-defined shell tool definition.
18#[derive(Debug, Clone, serde::Deserialize)]
19pub struct ShellToolDef {
20    pub name: String,
21    pub label: String,
22    pub description: String,
23    #[serde(default)]
24    pub readonly: bool,
25    #[serde(default)]
26    pub params: std::collections::HashMap<String, ShellParamDef>,
27    pub exec: ShellExecDef,
28}
29
30#[derive(Debug, Clone, serde::Deserialize)]
31pub struct ShellParamDef {
32    #[serde(rename = "type")]
33    pub param_type: String,
34    pub description: String,
35    #[serde(default)]
36    pub optional: bool,
37}
38
39#[derive(Debug, Clone, serde::Deserialize)]
40pub struct ShellExecDef {
41    pub command: String,
42    #[serde(default)]
43    pub args: Vec<String>,
44    #[serde(default = "default_timeout")]
45    pub timeout: u32,
46    #[serde(default = "default_truncate")]
47    pub truncate: String,
48    pub install_hint: Option<String>,
49}
50
51#[derive(Debug, Clone)]
52pub struct ShellTool {
53    def: ShellToolDef,
54}
55
56impl ShellTool {
57    fn new(def: ShellToolDef) -> Self {
58        Self { def }
59    }
60}
61
/// Serde default for `ShellExecDef::timeout`, in seconds.
fn default_timeout() -> u32 {
    30
}

/// Serde default for `ShellExecDef::truncate`.
fn default_truncate() -> String {
    String::from("head")
}
68
69#[async_trait]
70impl Tool for ShellTool {
71    fn name(&self) -> &str {
72        &self.def.name
73    }
74
75    fn label(&self) -> &str {
76        &self.def.label
77    }
78
79    fn description(&self) -> &str {
80        &self.def.description
81    }
82
83    fn parameters(&self) -> Value {
84        let mut properties = Map::new();
85        let mut required = Vec::new();
86
87        let mut param_names: Vec<_> = self.def.params.keys().cloned().collect();
88        param_names.sort();
89
90        for name in param_names {
91            if let Some(def) = self.def.params.get(&name) {
92                properties.insert(
93                    name.clone(),
94                    json!({
95                        "type": def.param_type,
96                        "description": def.description,
97                    }),
98                );
99
100                if !def.optional {
101                    required.push(Value::String(name));
102                }
103            }
104        }
105
106        json!({
107            "type": "object",
108            "properties": Value::Object(properties),
109            "required": Value::Array(required),
110        })
111    }
112
113    fn is_readonly(&self) -> bool {
114        self.def.readonly
115    }
116
117    async fn execute(&self, _call_id: &str, params: Value, ctx: ToolContext) -> Result<ToolOutput> {
118        if ctx.is_cancelled() {
119            return Ok(ToolOutput::error("Tool execution cancelled."));
120        }
121
122        let provided = params.as_object().cloned().unwrap_or_default();
123        validate_required_params(&self.def.params, &provided)?;
124
125        let mut args = Vec::with_capacity(self.def.exec.args.len());
126        for arg in &self.def.exec.args {
127            args.push(interpolate_arg(arg, &self.def.params, &provided)?);
128        }
129
130        let mut command = Command::new(&self.def.exec.command);
131        command
132            .args(&args)
133            .current_dir(&ctx.cwd)
134            // Shell tools are also non-interactive; avoid inheriting the TUI's
135            // stdin so terminal escape sequences cannot leak into child reads.
136            .stdin(Stdio::null())
137            .stdout(Stdio::piped())
138            .stderr(Stdio::piped());
139
140        let mut child = match command.spawn() {
141            Ok(child) => child,
142            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
143                let mut message = format!(
144                    "Command not found for shell tool '{}': {}",
145                    self.def.name, self.def.exec.command
146                );
147                if let Some(hint) = &self.def.exec.install_hint {
148                    message.push_str(&format!("\nInstall hint: {hint}"));
149                }
150                return Ok(ToolOutput::error(message));
151            }
152            Err(err) => {
153                return Err(Error::Tool(format!(
154                    "failed to spawn shell tool '{}': {err}",
155                    self.def.name
156                )));
157            }
158        };
159
160        let stdout = child
161            .stdout
162            .take()
163            .ok_or_else(|| Error::Tool("failed to capture stdout".into()))?;
164        let stderr = child
165            .stderr
166            .take()
167            .ok_or_else(|| Error::Tool("failed to capture stderr".into()))?;
168
169        let stdout_task = tokio::spawn(async move {
170            let mut reader = tokio::io::BufReader::new(stdout);
171            let mut buffer = Vec::new();
172            reader.read_to_end(&mut buffer).await.map(|_| buffer)
173        });
174        let stderr_task = tokio::spawn(async move {
175            let mut reader = tokio::io::BufReader::new(stderr);
176            let mut buffer = Vec::new();
177            reader.read_to_end(&mut buffer).await.map(|_| buffer)
178        });
179
180        let timeout = std::time::Duration::from_secs(self.def.exec.timeout as u64);
181        let (status, timed_out) = tokio::select! {
182            status = child.wait() => (status?, false),
183            _ = tokio::time::sleep(timeout) => {
184                let _ = child.kill().await;
185                let status = child.wait().await?;
186                (status, true)
187            }
188        };
189
190        let stdout_bytes = stdout_task
191            .await
192            .map_err(|err| Error::Tool(format!("stdout reader task failed: {err}")))??;
193        let stderr_bytes = stderr_task
194            .await
195            .map_err(|err| Error::Tool(format!("stderr reader task failed: {err}")))??;
196
197        let mut combined_output = String::new();
198        let stdout_text = String::from_utf8_lossy(&stdout_bytes);
199        let stderr_text = String::from_utf8_lossy(&stderr_bytes);
200
201        if !stdout_text.is_empty() {
202            combined_output.push_str(&stdout_text);
203        }
204        if !stderr_text.is_empty() {
205            if !combined_output.is_empty() && !combined_output.ends_with('\n') {
206                combined_output.push('\n');
207            }
208            combined_output.push_str(&stderr_text);
209        }
210
211        let truncation = match self.def.exec.truncate.as_str() {
212            "tail" => truncate_tail(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
213            _ => truncate_head(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
214        };
215
216        let mut result_text = truncation.content;
217        if truncation.truncated {
218            let note = format!(
219                "\n[Output truncated: showing {} of {} lines{}]",
220                truncation.output_lines,
221                truncation.total_lines,
222                truncation
223                    .temp_file
224                    .as_ref()
225                    .map(|path| format!(". Full output saved to {}", path.display()))
226                    .unwrap_or_default()
227            );
228            result_text.push_str(&note);
229        }
230        if timed_out {
231            result_text.push_str(&format!(
232                "\n[Command timed out after {}s]",
233                self.def.exec.timeout
234            ));
235        }
236
237        Ok(ToolOutput {
238            content: vec![imp_llm::ContentBlock::Text { text: result_text }],
239            details: json!({
240                "exit_code": status.code().unwrap_or(-1),
241                "timed_out": timed_out,
242                "truncated": truncation.truncated,
243            }),
244            is_error: timed_out || !status.success(),
245        })
246    }
247}
248
249fn validate_required_params(
250    defs: &HashMap<String, ShellParamDef>,
251    provided: &Map<String, Value>,
252) -> Result<()> {
253    let mut missing = Vec::new();
254
255    for (name, def) in defs {
256        if !def.optional && provided.get(name).is_none_or(Value::is_null) {
257            missing.push(name.clone());
258        }
259    }
260
261    missing.sort();
262
263    if missing.is_empty() {
264        Ok(())
265    } else {
266        Err(Error::Tool(format!(
267            "missing required parameter(s): {}",
268            missing.join(", ")
269        )))
270    }
271}
272
273fn interpolate_arg(
274    template: &str,
275    defs: &HashMap<String, ShellParamDef>,
276    provided: &Map<String, Value>,
277) -> Result<String> {
278    let mut result = String::new();
279    let mut remaining = template;
280
281    while let Some(start) = remaining.find('{') {
282        result.push_str(&remaining[..start]);
283
284        let after_start = &remaining[start + 1..];
285        let end = after_start.find('}').ok_or_else(|| {
286            Error::Tool(format!(
287                "unclosed placeholder in shell tool argument: {template}"
288            ))
289        })?;
290
291        let placeholder = &after_start[..end];
292        result.push_str(&resolve_placeholder(placeholder, defs, provided)?);
293        remaining = &after_start[end + 1..];
294    }
295
296    result.push_str(remaining);
297    Ok(result)
298}
299
300fn resolve_placeholder(
301    placeholder: &str,
302    defs: &HashMap<String, ShellParamDef>,
303    provided: &Map<String, Value>,
304) -> Result<String> {
305    let (name, default) = placeholder
306        .split_once('|')
307        .map_or((placeholder, None), |(name, default)| (name, Some(default)));
308
309    if name.is_empty() {
310        return Err(Error::Tool(
311            "empty placeholder in shell tool argument".into(),
312        ));
313    }
314
315    if let Some(value) = provided.get(name).filter(|value| !value.is_null()) {
316        return stringify_param_value(name, value);
317    }
318
319    if let Some(default) = default {
320        return Ok(default.to_string());
321    }
322
323    if defs.get(name).is_some_and(|def| def.optional) {
324        return Ok(String::new());
325    }
326
327    Err(Error::Tool(format!(
328        "missing required parameter for placeholder: {name}"
329    )))
330}
331
332fn stringify_param_value(name: &str, value: &Value) -> Result<String> {
333    match value {
334        Value::String(value) => Ok(value.clone()),
335        Value::Number(value) => Ok(value.to_string()),
336        Value::Bool(value) => Ok(value.to_string()),
337        Value::Null => Ok(String::new()),
338        _ => Err(Error::Tool(format!(
339            "parameter '{name}' must be a string, number, or boolean"
340        ))),
341    }
342}
343
344/// Load shell tools from a directory of TOML definitions.
345pub fn load_shell_tools(dir: &Path, registry: &mut ToolRegistry) -> Result<()> {
346    if !dir.exists() {
347        return Ok(());
348    }
349
350    for entry in walkdir::WalkDir::new(dir) {
351        let entry = entry.map_err(|err| {
352            Error::Tool(format!(
353                "failed to walk shell tool directory {}: {err}",
354                dir.display()
355            ))
356        })?;
357        if !entry.file_type().is_file() {
358            continue;
359        }
360        if entry.path().extension().and_then(|ext| ext.to_str()) != Some("toml") {
361            continue;
362        }
363
364        let content = std::fs::read_to_string(entry.path())?;
365        match toml::from_str::<ShellToolDef>(&content) {
366            Ok(def) => registry.register(Arc::new(ShellTool::new(def))),
367            Err(_err) => {
368                // Keep shell tool discovery side-effect free for embedded callers.
369            }
370        }
371    }
372
373    Ok(())
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ui::NullInterface;
    use serde_json::json;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    /// Definition written to disk by the loader test.
    const GREET_TOML: &str = r#"
name = "greet"
label = "Greet"
description = "Print a greeting"
readonly = true

[params.name]
type = "string"
description = "Name to greet"

[params.greeting]
type = "string"
description = "Greeting text"
optional = true

[exec]
command = "printf"
args = ["%s %s", "{greeting|hello}", "{name}"]
timeout = 5
truncate = "head"
"#;

    /// Builds a `ToolContext` rooted at `dir`. The channel receivers are dropped
    /// on return, and every other field gets a fresh default value.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (update_tx, _update_rx) = tokio::sync::mpsc::channel(16);
        let (command_tx, _command_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(AtomicBool::new(false)),
            update_tx,
            command_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    /// Shorthand for a parameter definition.
    fn param_def(param_type: &str, description: &str, optional: bool) -> ShellParamDef {
        ShellParamDef {
            param_type: param_type.into(),
            description: description.into(),
            optional,
        }
    }

    /// Shorthand for an exec definition with head truncation and no install hint.
    fn exec_def(command: &str, args: &[&str], timeout: u32) -> ShellExecDef {
        ShellExecDef {
            command: command.into(),
            args: args.iter().map(|arg| arg.to_string()).collect(),
            timeout,
            truncate: "head".into(),
            install_hint: None,
        }
    }

    /// Builds a read-only shell tool from its parts.
    fn readonly_tool(
        name: &str,
        label: &str,
        description: &str,
        params: HashMap<String, ShellParamDef>,
        exec: ShellExecDef,
    ) -> ShellTool {
        ShellTool::new(ShellToolDef {
            name: name.into(),
            label: label.into(),
            description: description.into(),
            readonly: true,
            params,
            exec,
        })
    }

    /// Returns the text of the first content block; panics on any other kind.
    fn first_text(output: &ToolOutput) -> String {
        match &output.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        }
    }

    #[test]
    fn load_shell_tools_registers_valid_defs_and_skips_invalid_ones() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tools_dir = temp_dir.path().join("tools");
        let nested_dir = tools_dir.join("nested");
        std::fs::create_dir_all(&nested_dir).unwrap();

        std::fs::write(nested_dir.join("greet.toml"), GREET_TOML).unwrap();
        std::fs::write(tools_dir.join("broken.toml"), "not = [valid").unwrap();

        let mut registry = ToolRegistry::new();
        load_shell_tools(&tools_dir, &mut registry).unwrap();

        let greet = registry.get("greet").expect("tool should be registered");
        assert_eq!(greet.name(), "greet");
        assert!(registry.get("broken").is_none());
    }

    #[tokio::test]
    async fn shell_tool_executes_with_param_interpolation() {
        let tool = readonly_tool(
            "greet",
            "Greet",
            "Print a greeting",
            HashMap::from([
                ("name".into(), param_def("string", "Name to greet", false)),
                ("greeting".into(), param_def("string", "Greeting text", true)),
            ]),
            exec_def("printf", &["%s %s", "{greeting|hello}", "{name}"], 5),
        );

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-1", json!({ "name": "Asher" }), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(first_text(&result), "hello Asher");
        assert_eq!(result.details["exit_code"], 0);
        assert_eq!(result.details["timed_out"], false);
    }

    #[tokio::test]
    async fn shell_tool_default_param_used_when_not_provided() {
        let tool = readonly_tool(
            "echo_default",
            "Echo Default",
            "Echo with default",
            HashMap::from([("msg".into(), param_def("string", "Message", true))]),
            exec_def("echo", &["{msg|default_value}"], 5),
        );

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-3", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(first_text(&result).contains("default_value"));
    }

    #[test]
    fn shell_tool_required_param_missing_errors() {
        let defs = HashMap::from([("name".into(), param_def("string", "Name", false))]);
        let provided = serde_json::Map::new();

        let result = validate_required_params(&defs, &provided);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("name"));
    }

    #[tokio::test]
    async fn shell_tool_stderr_included_in_output() {
        let tool = readonly_tool(
            "stderr_test",
            "Stderr Test",
            "Writes to stderr",
            HashMap::new(),
            exec_def("sh", &["-c", "echo stdout_msg; echo stderr_msg >&2"], 5),
        );

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-4", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = first_text(&result);
        assert!(text.contains("stdout_msg"));
        assert!(text.contains("stderr_msg"));
    }

    #[tokio::test]
    async fn shell_tool_timeout() {
        let tool = readonly_tool(
            "slow",
            "Slow",
            "Times out",
            HashMap::new(),
            exec_def("sleep", &["60"], 1),
        );

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-5", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(result.details["timed_out"], true);
    }

    #[tokio::test]
    async fn shell_tool_reports_missing_commands_with_install_hint() {
        let tool = readonly_tool(
            "missing",
            "Missing",
            "Missing command",
            HashMap::new(),
            ShellExecDef {
                install_hint: Some("brew install definitely-not-a-real-command".into()),
                ..exec_def("definitely-not-a-real-command", &[], 5)
            },
        );

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-2", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = first_text(&result);
        assert!(text.contains("Command not found"));
        assert!(text.contains("Install hint"));
    }
}