1use std::collections::HashMap;
2use std::path::Path;
3use std::process::Stdio;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::{json, Map, Value};
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10
11use crate::error::{Error, Result};
12use crate::tools::{truncate_head, truncate_tail, Tool, ToolContext, ToolOutput, ToolRegistry};
13
14const MAX_OUTPUT_LINES: usize = 2000;
15const MAX_OUTPUT_BYTES: usize = 50 * 1024;
16
17#[derive(Debug, Clone, serde::Deserialize)]
19pub struct ShellToolDef {
20 pub name: String,
21 pub label: String,
22 pub description: String,
23 #[serde(default)]
24 pub readonly: bool,
25 #[serde(default)]
26 pub params: std::collections::HashMap<String, ShellParamDef>,
27 pub exec: ShellExecDef,
28}
29
30#[derive(Debug, Clone, serde::Deserialize)]
31pub struct ShellParamDef {
32 #[serde(rename = "type")]
33 pub param_type: String,
34 pub description: String,
35 #[serde(default)]
36 pub optional: bool,
37}
38
39#[derive(Debug, Clone, serde::Deserialize)]
40pub struct ShellExecDef {
41 pub command: String,
42 #[serde(default)]
43 pub args: Vec<String>,
44 #[serde(default = "default_timeout")]
45 pub timeout: u32,
46 #[serde(default = "default_truncate")]
47 pub truncate: String,
48 pub install_hint: Option<String>,
49}
50
51#[derive(Debug, Clone)]
52pub struct ShellTool {
53 def: ShellToolDef,
54}
55
56impl ShellTool {
57 fn new(def: ShellToolDef) -> Self {
58 Self { def }
59 }
60}
61
/// Serde default for `ShellExecDef::timeout`, in seconds.
fn default_timeout() -> u32 {
    const DEFAULT_TIMEOUT_SECS: u32 = 30;
    DEFAULT_TIMEOUT_SECS
}

/// Serde default for `ShellExecDef::truncate`.
fn default_truncate() -> String {
    String::from("head")
}
68
69#[async_trait]
70impl Tool for ShellTool {
71 fn name(&self) -> &str {
72 &self.def.name
73 }
74
75 fn label(&self) -> &str {
76 &self.def.label
77 }
78
79 fn description(&self) -> &str {
80 &self.def.description
81 }
82
83 fn parameters(&self) -> Value {
84 let mut properties = Map::new();
85 let mut required = Vec::new();
86
87 let mut param_names: Vec<_> = self.def.params.keys().cloned().collect();
88 param_names.sort();
89
90 for name in param_names {
91 if let Some(def) = self.def.params.get(&name) {
92 properties.insert(
93 name.clone(),
94 json!({
95 "type": def.param_type,
96 "description": def.description,
97 }),
98 );
99
100 if !def.optional {
101 required.push(Value::String(name));
102 }
103 }
104 }
105
106 json!({
107 "type": "object",
108 "properties": Value::Object(properties),
109 "required": Value::Array(required),
110 })
111 }
112
113 fn is_readonly(&self) -> bool {
114 self.def.readonly
115 }
116
117 async fn execute(&self, _call_id: &str, params: Value, ctx: ToolContext) -> Result<ToolOutput> {
118 if ctx.is_cancelled() {
119 return Ok(ToolOutput::error("Tool execution cancelled."));
120 }
121
122 let provided = params.as_object().cloned().unwrap_or_default();
123 validate_required_params(&self.def.params, &provided)?;
124
125 let mut args = Vec::with_capacity(self.def.exec.args.len());
126 for arg in &self.def.exec.args {
127 args.push(interpolate_arg(arg, &self.def.params, &provided)?);
128 }
129
130 let mut command = Command::new(&self.def.exec.command);
131 command
132 .args(&args)
133 .current_dir(&ctx.cwd)
134 .stdin(Stdio::null())
137 .stdout(Stdio::piped())
138 .stderr(Stdio::piped());
139
140 let mut child = match command.spawn() {
141 Ok(child) => child,
142 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
143 let mut message = format!(
144 "Command not found for shell tool '{}': {}",
145 self.def.name, self.def.exec.command
146 );
147 if let Some(hint) = &self.def.exec.install_hint {
148 message.push_str(&format!("\nInstall hint: {hint}"));
149 }
150 return Ok(ToolOutput::error(message));
151 }
152 Err(err) => {
153 return Err(Error::Tool(format!(
154 "failed to spawn shell tool '{}': {err}",
155 self.def.name
156 )));
157 }
158 };
159
160 let stdout = child
161 .stdout
162 .take()
163 .ok_or_else(|| Error::Tool("failed to capture stdout".into()))?;
164 let stderr = child
165 .stderr
166 .take()
167 .ok_or_else(|| Error::Tool("failed to capture stderr".into()))?;
168
169 let stdout_task = tokio::spawn(async move {
170 let mut reader = tokio::io::BufReader::new(stdout);
171 let mut buffer = Vec::new();
172 reader.read_to_end(&mut buffer).await.map(|_| buffer)
173 });
174 let stderr_task = tokio::spawn(async move {
175 let mut reader = tokio::io::BufReader::new(stderr);
176 let mut buffer = Vec::new();
177 reader.read_to_end(&mut buffer).await.map(|_| buffer)
178 });
179
180 let timeout = std::time::Duration::from_secs(self.def.exec.timeout as u64);
181 let (status, timed_out) = tokio::select! {
182 status = child.wait() => (status?, false),
183 _ = tokio::time::sleep(timeout) => {
184 let _ = child.kill().await;
185 let status = child.wait().await?;
186 (status, true)
187 }
188 };
189
190 let stdout_bytes = stdout_task
191 .await
192 .map_err(|err| Error::Tool(format!("stdout reader task failed: {err}")))??;
193 let stderr_bytes = stderr_task
194 .await
195 .map_err(|err| Error::Tool(format!("stderr reader task failed: {err}")))??;
196
197 let mut combined_output = String::new();
198 let stdout_text = String::from_utf8_lossy(&stdout_bytes);
199 let stderr_text = String::from_utf8_lossy(&stderr_bytes);
200
201 if !stdout_text.is_empty() {
202 combined_output.push_str(&stdout_text);
203 }
204 if !stderr_text.is_empty() {
205 if !combined_output.is_empty() && !combined_output.ends_with('\n') {
206 combined_output.push('\n');
207 }
208 combined_output.push_str(&stderr_text);
209 }
210
211 let truncation = match self.def.exec.truncate.as_str() {
212 "tail" => truncate_tail(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
213 _ => truncate_head(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
214 };
215
216 let mut result_text = truncation.content;
217 if truncation.truncated {
218 let note = format!(
219 "\n[Output truncated: showing {} of {} lines{}]",
220 truncation.output_lines,
221 truncation.total_lines,
222 truncation
223 .temp_file
224 .as_ref()
225 .map(|path| format!(". Full output saved to {}", path.display()))
226 .unwrap_or_default()
227 );
228 result_text.push_str(¬e);
229 }
230 if timed_out {
231 result_text.push_str(&format!(
232 "\n[Command timed out after {}s]",
233 self.def.exec.timeout
234 ));
235 }
236
237 Ok(ToolOutput {
238 content: vec![imp_llm::ContentBlock::Text { text: result_text }],
239 details: json!({
240 "exit_code": status.code().unwrap_or(-1),
241 "timed_out": timed_out,
242 "truncated": truncation.truncated,
243 }),
244 is_error: timed_out || !status.success(),
245 })
246 }
247}
248
249fn validate_required_params(
250 defs: &HashMap<String, ShellParamDef>,
251 provided: &Map<String, Value>,
252) -> Result<()> {
253 let mut missing = Vec::new();
254
255 for (name, def) in defs {
256 if !def.optional && provided.get(name).is_none_or(Value::is_null) {
257 missing.push(name.clone());
258 }
259 }
260
261 missing.sort();
262
263 if missing.is_empty() {
264 Ok(())
265 } else {
266 Err(Error::Tool(format!(
267 "missing required parameter(s): {}",
268 missing.join(", ")
269 )))
270 }
271}
272
273fn interpolate_arg(
274 template: &str,
275 defs: &HashMap<String, ShellParamDef>,
276 provided: &Map<String, Value>,
277) -> Result<String> {
278 let mut result = String::new();
279 let mut remaining = template;
280
281 while let Some(start) = remaining.find('{') {
282 result.push_str(&remaining[..start]);
283
284 let after_start = &remaining[start + 1..];
285 let end = after_start.find('}').ok_or_else(|| {
286 Error::Tool(format!(
287 "unclosed placeholder in shell tool argument: {template}"
288 ))
289 })?;
290
291 let placeholder = &after_start[..end];
292 result.push_str(&resolve_placeholder(placeholder, defs, provided)?);
293 remaining = &after_start[end + 1..];
294 }
295
296 result.push_str(remaining);
297 Ok(result)
298}
299
300fn resolve_placeholder(
301 placeholder: &str,
302 defs: &HashMap<String, ShellParamDef>,
303 provided: &Map<String, Value>,
304) -> Result<String> {
305 let (name, default) = placeholder
306 .split_once('|')
307 .map_or((placeholder, None), |(name, default)| (name, Some(default)));
308
309 if name.is_empty() {
310 return Err(Error::Tool(
311 "empty placeholder in shell tool argument".into(),
312 ));
313 }
314
315 if let Some(value) = provided.get(name).filter(|value| !value.is_null()) {
316 return stringify_param_value(name, value);
317 }
318
319 if let Some(default) = default {
320 return Ok(default.to_string());
321 }
322
323 if defs.get(name).is_some_and(|def| def.optional) {
324 return Ok(String::new());
325 }
326
327 Err(Error::Tool(format!(
328 "missing required parameter for placeholder: {name}"
329 )))
330}
331
332fn stringify_param_value(name: &str, value: &Value) -> Result<String> {
333 match value {
334 Value::String(value) => Ok(value.clone()),
335 Value::Number(value) => Ok(value.to_string()),
336 Value::Bool(value) => Ok(value.to_string()),
337 Value::Null => Ok(String::new()),
338 _ => Err(Error::Tool(format!(
339 "parameter '{name}' must be a string, number, or boolean"
340 ))),
341 }
342}
343
344pub fn load_shell_tools(dir: &Path, registry: &mut ToolRegistry) -> Result<()> {
346 if !dir.exists() {
347 return Ok(());
348 }
349
350 for entry in walkdir::WalkDir::new(dir) {
351 let entry = entry.map_err(|err| {
352 Error::Tool(format!(
353 "failed to walk shell tool directory {}: {err}",
354 dir.display()
355 ))
356 })?;
357 if !entry.file_type().is_file() {
358 continue;
359 }
360 if entry.path().extension().and_then(|ext| ext.to_str()) != Some("toml") {
361 continue;
362 }
363
364 let content = std::fs::read_to_string(entry.path())?;
365 match toml::from_str::<ShellToolDef>(&content) {
366 Ok(def) => registry.register(Arc::new(ShellTool::new(def))),
367 Err(_err) => {
368 }
370 }
371 }
372
373 Ok(())
374}
375
#[cfg(test)]
mod tests {
    //! Tests for shell tool loading, parameter interpolation, and command execution.

    use super::*;
    use crate::ui::NullInterface;
    use serde_json::json;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir` with an un-set cancellation flag and default
    /// values for the remaining fields, for driving `ShellTool::execute` directly.
    fn test_ctx(dir: &Path) -> ToolContext {
        // The receivers are dropped when this helper returns; no test reads from them.
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    // A valid definition in a nested directory is found by the recursive walk, while an
    // unparsable file is skipped without making `load_shell_tools` fail.
    #[test]
    fn load_shell_tools_registers_valid_defs_and_skips_invalid_ones() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tools_dir = temp_dir.path().join("tools");
        std::fs::create_dir_all(tools_dir.join("nested")).unwrap();

        std::fs::write(
            tools_dir.join("nested").join("greet.toml"),
            r#"
name = "greet"
label = "Greet"
description = "Print a greeting"
readonly = true

[params.name]
type = "string"
description = "Name to greet"

[params.greeting]
type = "string"
description = "Greeting text"
optional = true

[exec]
command = "printf"
args = ["%s %s", "{greeting|hello}", "{name}"]
timeout = 5
truncate = "head"
"#,
        )
        .unwrap();

        std::fs::write(tools_dir.join("broken.toml"), "not = [valid").unwrap();

        let mut registry = ToolRegistry::new();
        load_shell_tools(&tools_dir, &mut registry).unwrap();

        let tool = registry.get("greet").expect("tool should be registered");
        assert_eq!(tool.name(), "greet");
        assert!(registry.get("broken").is_none());
    }

    // `{greeting|hello}` falls back to its inline default and `{name}` is taken from the
    // call's params. The format "%s %s" has no trailing newline, so the output is exact.
    #[tokio::test]
    async fn shell_tool_executes_with_param_interpolation() {
        let tool = ShellTool::new(ShellToolDef {
            name: "greet".into(),
            label: "Greet".into(),
            description: "Print a greeting".into(),
            readonly: true,
            params: HashMap::from([
                (
                    "name".into(),
                    ShellParamDef {
                        param_type: "string".into(),
                        description: "Name to greet".into(),
                        optional: false,
                    },
                ),
                (
                    "greeting".into(),
                    ShellParamDef {
                        param_type: "string".into(),
                        description: "Greeting text".into(),
                        optional: true,
                    },
                ),
            ]),
            exec: ShellExecDef {
                command: "printf".into(),
                args: vec!["%s %s".into(), "{greeting|hello}".into(), "{name}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute(
                "call-1",
                json!({ "name": "Asher" }),
                test_ctx(temp_dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert_eq!(text, "hello Asher");
        assert_eq!(result.details["exit_code"], 0);
        assert_eq!(result.details["timed_out"], false);
    }

    // An optional parameter that is not supplied expands to its inline default.
    #[tokio::test]
    async fn shell_tool_default_param_used_when_not_provided() {
        let tool = ShellTool::new(ShellToolDef {
            name: "echo_default".into(),
            label: "Echo Default".into(),
            description: "Echo with default".into(),
            readonly: true,
            params: HashMap::from([(
                "msg".into(),
                ShellParamDef {
                    param_type: "string".into(),
                    description: "Message".into(),
                    optional: true,
                },
            )]),
            exec: ShellExecDef {
                command: "echo".into(),
                args: vec!["{msg|default_value}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-3", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("default_value"));
    }

    // The validation error names the missing required parameter.
    #[test]
    fn shell_tool_required_param_missing_errors() {
        let defs = HashMap::from([(
            "name".into(),
            ShellParamDef {
                param_type: "string".into(),
                description: "Name".into(),
                optional: false,
            },
        )]);
        let provided = serde_json::Map::new();
        let result = validate_required_params(&defs, &provided);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("name"));
    }

    // Both streams appear in the output, and writing to stderr alone does not mark the
    // result as an error when the command exits successfully.
    #[tokio::test]
    async fn shell_tool_stderr_included_in_output() {
        let tool = ShellTool::new(ShellToolDef {
            name: "stderr_test".into(),
            label: "Stderr Test".into(),
            description: "Writes to stderr".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sh".into(),
                args: vec!["-c".into(), "echo stdout_msg; echo stderr_msg >&2".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-4", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("stdout_msg"));
        assert!(text.contains("stderr_msg"));
    }

    // A command that outlives its 1s timeout is killed and reported as a timed-out error.
    #[tokio::test]
    async fn shell_tool_timeout() {
        let tool = ShellTool::new(ShellToolDef {
            name: "slow".into(),
            label: "Slow".into(),
            description: "Times out".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sleep".into(),
                args: vec!["60".into()],
                timeout: 1,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-5", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(result.details["timed_out"], true);
    }

    // A nonexistent executable yields an error output (not an `Err`) that includes the
    // configured install hint.
    #[tokio::test]
    async fn shell_tool_reports_missing_commands_with_install_hint() {
        let tool = ShellTool::new(ShellToolDef {
            name: "missing".into(),
            label: "Missing".into(),
            description: "Missing command".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "definitely-not-a-real-command".into(),
                args: Vec::new(),
                timeout: 5,
                truncate: "head".into(),
                install_hint: Some("brew install definitely-not-a-real-command".into()),
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-2", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = match &result.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        };
        assert!(text.contains("Command not found"));
        assert!(text.contains("Install hint"));
    }
}
650}