1use std::collections::HashMap;
2use std::path::Path;
3use std::process::Stdio;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use serde_json::{json, Map, Value};
8use tokio::io::AsyncReadExt;
9use tokio::process::Command;
10
11use crate::error::{Error, Result};
12use crate::tools::{truncate_head, truncate_tail, Tool, ToolContext, ToolOutput, ToolRegistry};
13
14const MAX_OUTPUT_LINES: usize = 2000;
15const MAX_OUTPUT_BYTES: usize = 50 * 1024;
16
17#[derive(Debug, Clone, serde::Deserialize)]
19pub struct ShellToolDef {
20 pub name: String,
21 pub label: String,
22 pub description: String,
23 #[serde(default)]
24 pub readonly: bool,
25 #[serde(default)]
26 pub params: std::collections::HashMap<String, ShellParamDef>,
27 pub exec: ShellExecDef,
28}
29
30#[derive(Debug, Clone, serde::Deserialize)]
31pub struct ShellParamDef {
32 #[serde(rename = "type")]
33 pub param_type: String,
34 pub description: String,
35 #[serde(default)]
36 pub optional: bool,
37}
38
39#[derive(Debug, Clone, serde::Deserialize)]
40pub struct ShellExecDef {
41 pub command: String,
42 #[serde(default)]
43 pub args: Vec<String>,
44 #[serde(default = "default_timeout")]
45 pub timeout: u32,
46 #[serde(default = "default_truncate")]
47 pub truncate: String,
48 pub install_hint: Option<String>,
49}
50
51#[derive(Debug, Clone)]
52pub struct ShellTool {
53 def: ShellToolDef,
54}
55
56impl ShellTool {
57 fn new(def: ShellToolDef) -> Self {
58 Self { def }
59 }
60}
61
/// Serde default for `ShellExecDef::timeout`: 30 seconds.
fn default_timeout() -> u32 {
    30u32
}
/// Serde default for `ShellExecDef::truncate`: keep the head of the output.
fn default_truncate() -> String {
    String::from("head")
}
68
69#[async_trait]
70impl Tool for ShellTool {
71 fn name(&self) -> &str {
72 &self.def.name
73 }
74
75 fn label(&self) -> &str {
76 &self.def.label
77 }
78
79 fn description(&self) -> &str {
80 &self.def.description
81 }
82
83 fn parameters(&self) -> Value {
84 let mut properties = Map::new();
85 let mut required = Vec::new();
86
87 let mut param_names: Vec<_> = self.def.params.keys().cloned().collect();
88 param_names.sort();
89
90 for name in param_names {
91 if let Some(def) = self.def.params.get(&name) {
92 properties.insert(
93 name.clone(),
94 json!({
95 "type": def.param_type,
96 "description": def.description,
97 }),
98 );
99
100 if !def.optional {
101 required.push(Value::String(name));
102 }
103 }
104 }
105
106 json!({
107 "type": "object",
108 "properties": Value::Object(properties),
109 "required": Value::Array(required),
110 })
111 }
112
113 fn is_readonly(&self) -> bool {
114 self.def.readonly
115 }
116
117 async fn execute(&self, _call_id: &str, params: Value, ctx: ToolContext) -> Result<ToolOutput> {
118 if ctx.is_cancelled() {
119 return Ok(ToolOutput::error("Tool execution cancelled."));
120 }
121
122 let provided = params.as_object().cloned().unwrap_or_default();
123 validate_required_params(&self.def.params, &provided)?;
124
125 let mut args = Vec::with_capacity(self.def.exec.args.len());
126 for arg in &self.def.exec.args {
127 args.push(interpolate_arg(arg, &self.def.params, &provided)?);
128 }
129
130 let mut command = Command::new(&self.def.exec.command);
131 command
132 .args(&args)
133 .current_dir(&ctx.cwd)
134 .stdin(Stdio::null())
137 .stdout(Stdio::piped())
138 .stderr(Stdio::piped());
139
140 let mut child = match command.spawn() {
141 Ok(child) => child,
142 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
143 let mut message = format!(
144 "Command not found for shell tool '{}': {}",
145 self.def.name, self.def.exec.command
146 );
147 if let Some(hint) = &self.def.exec.install_hint {
148 message.push_str(&format!("\nInstall hint: {hint}"));
149 }
150 return Ok(ToolOutput::error(message));
151 }
152 Err(err) => {
153 return Err(Error::Tool(format!(
154 "failed to spawn shell tool '{}': {err}",
155 self.def.name
156 )));
157 }
158 };
159
160 let stdout = child
161 .stdout
162 .take()
163 .ok_or_else(|| Error::Tool("failed to capture stdout".into()))?;
164 let stderr = child
165 .stderr
166 .take()
167 .ok_or_else(|| Error::Tool("failed to capture stderr".into()))?;
168
169 let stdout_task = tokio::spawn(async move {
170 let mut reader = tokio::io::BufReader::new(stdout);
171 let mut buffer = Vec::new();
172 reader.read_to_end(&mut buffer).await.map(|_| buffer)
173 });
174 let stderr_task = tokio::spawn(async move {
175 let mut reader = tokio::io::BufReader::new(stderr);
176 let mut buffer = Vec::new();
177 reader.read_to_end(&mut buffer).await.map(|_| buffer)
178 });
179
180 let timeout = std::time::Duration::from_secs(self.def.exec.timeout as u64);
181 let (status, timed_out) = tokio::select! {
182 status = child.wait() => (status?, false),
183 _ = tokio::time::sleep(timeout) => {
184 let _ = child.kill().await;
185 let status = child.wait().await?;
186 (status, true)
187 }
188 };
189
190 let stdout_bytes = stdout_task
191 .await
192 .map_err(|err| Error::Tool(format!("stdout reader task failed: {err}")))??;
193 let stderr_bytes = stderr_task
194 .await
195 .map_err(|err| Error::Tool(format!("stderr reader task failed: {err}")))??;
196
197 let mut combined_output = String::new();
198 let stdout_text = String::from_utf8_lossy(&stdout_bytes);
199 let stderr_text = String::from_utf8_lossy(&stderr_bytes);
200
201 if !stdout_text.is_empty() {
202 combined_output.push_str(&stdout_text);
203 }
204 if !stderr_text.is_empty() {
205 if !combined_output.is_empty() && !combined_output.ends_with('\n') {
206 combined_output.push('\n');
207 }
208 combined_output.push_str(&stderr_text);
209 }
210
211 let truncation = match self.def.exec.truncate.as_str() {
212 "tail" => truncate_tail(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
213 _ => truncate_head(&combined_output, MAX_OUTPUT_LINES, MAX_OUTPUT_BYTES),
214 };
215
216 let mut result_text = truncation.content;
217 if truncation.truncated {
218 let note = format!(
219 "\n[Output truncated: showing {} of {} lines{}]",
220 truncation.output_lines,
221 truncation.total_lines,
222 truncation
223 .temp_file
224 .as_ref()
225 .map(|path| format!(". Full output saved to {}", path.display()))
226 .unwrap_or_default()
227 );
228 result_text.push_str(¬e);
229 }
230 if timed_out {
231 result_text.push_str(&format!(
232 "\n[Command timed out after {}s]",
233 self.def.exec.timeout
234 ));
235 }
236
237 Ok(ToolOutput {
238 content: vec![imp_llm::ContentBlock::Text { text: result_text }],
239 details: json!({
240 "exit_code": status.code().unwrap_or(-1),
241 "timed_out": timed_out,
242 "truncated": truncation.truncated,
243 }),
244 is_error: timed_out || !status.success(),
245 })
246 }
247}
248
249fn validate_required_params(
250 defs: &HashMap<String, ShellParamDef>,
251 provided: &Map<String, Value>,
252) -> Result<()> {
253 let mut missing = Vec::new();
254
255 for (name, def) in defs {
256 if !def.optional && provided.get(name).is_none_or(Value::is_null) {
257 missing.push(name.clone());
258 }
259 }
260
261 missing.sort();
262
263 if missing.is_empty() {
264 Ok(())
265 } else {
266 Err(Error::Tool(format!(
267 "missing required parameter(s): {}",
268 missing.join(", ")
269 )))
270 }
271}
272
273fn interpolate_arg(
274 template: &str,
275 defs: &HashMap<String, ShellParamDef>,
276 provided: &Map<String, Value>,
277) -> Result<String> {
278 let mut result = String::new();
279 let mut remaining = template;
280
281 while let Some(start) = remaining.find('{') {
282 result.push_str(&remaining[..start]);
283
284 let after_start = &remaining[start + 1..];
285 let end = after_start.find('}').ok_or_else(|| {
286 Error::Tool(format!(
287 "unclosed placeholder in shell tool argument: {template}"
288 ))
289 })?;
290
291 let placeholder = &after_start[..end];
292 result.push_str(&resolve_placeholder(placeholder, defs, provided)?);
293 remaining = &after_start[end + 1..];
294 }
295
296 result.push_str(remaining);
297 Ok(result)
298}
299
300fn resolve_placeholder(
301 placeholder: &str,
302 defs: &HashMap<String, ShellParamDef>,
303 provided: &Map<String, Value>,
304) -> Result<String> {
305 let (name, default) = placeholder
306 .split_once('|')
307 .map_or((placeholder, None), |(name, default)| (name, Some(default)));
308
309 if name.is_empty() {
310 return Err(Error::Tool(
311 "empty placeholder in shell tool argument".into(),
312 ));
313 }
314
315 if let Some(value) = provided.get(name).filter(|value| !value.is_null()) {
316 return stringify_param_value(name, value);
317 }
318
319 if let Some(default) = default {
320 return Ok(default.to_string());
321 }
322
323 if defs.get(name).is_some_and(|def| def.optional) {
324 return Ok(String::new());
325 }
326
327 Err(Error::Tool(format!(
328 "missing required parameter for placeholder: {name}"
329 )))
330}
331
332fn stringify_param_value(name: &str, value: &Value) -> Result<String> {
333 match value {
334 Value::String(value) => Ok(value.clone()),
335 Value::Number(value) => Ok(value.to_string()),
336 Value::Bool(value) => Ok(value.to_string()),
337 Value::Null => Ok(String::new()),
338 _ => Err(Error::Tool(format!(
339 "parameter '{name}' must be a string, number, or boolean"
340 ))),
341 }
342}
343
344pub fn load_shell_tools(dir: &Path, registry: &mut ToolRegistry) -> Result<()> {
346 if !dir.exists() {
347 return Ok(());
348 }
349
350 for entry in walkdir::WalkDir::new(dir) {
351 let entry = entry.map_err(|err| {
352 Error::Tool(format!(
353 "failed to walk shell tool directory {}: {err}",
354 dir.display()
355 ))
356 })?;
357 if !entry.file_type().is_file() {
358 continue;
359 }
360 if entry.path().extension().and_then(|ext| ext.to_str()) != Some("toml") {
361 continue;
362 }
363
364 let content = std::fs::read_to_string(entry.path())?;
365 match toml::from_str::<ShellToolDef>(&content) {
366 Ok(def) => registry.register(Arc::new(ShellTool::new(def))),
367 Err(_err) => {
368 }
370 }
371 }
372
373 Ok(())
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ui::NullInterface;
    use serde_json::json;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `dir` with the cancellation flag cleared.
    fn test_ctx(dir: &Path) -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: dir.to_path_buf(),
            cancelled: Arc::new(AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    /// Returns the text of the first content block, panicking if it is not text.
    fn first_text(output: &ToolOutput) -> String {
        match &output.content[0] {
            imp_llm::ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text output"),
        }
    }

    /// A `"string"`-typed parameter definition.
    fn string_param(description: &str, optional: bool) -> ShellParamDef {
        ShellParamDef {
            param_type: "string".into(),
            description: description.into(),
            optional,
        }
    }

    /// `printf "%s %s" {greeting|hello} {name}` with `name` required and `greeting` optional.
    fn greet_tool() -> ShellTool {
        ShellTool::new(ShellToolDef {
            name: "greet".into(),
            label: "Greet".into(),
            description: "Print a greeting".into(),
            readonly: true,
            params: HashMap::from([
                ("name".into(), string_param("Name to greet", false)),
                ("greeting".into(), string_param("Greeting text", true)),
            ]),
            exec: ShellExecDef {
                command: "printf".into(),
                args: vec!["%s %s".into(), "{greeting|hello}".into(), "{name}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        })
    }

    #[test]
    fn load_shell_tools_registers_valid_defs_and_skips_invalid_ones() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tools_dir = temp_dir.path().join("tools");
        std::fs::create_dir_all(tools_dir.join("nested")).unwrap();

        std::fs::write(
            tools_dir.join("nested").join("greet.toml"),
            r#"
name = "greet"
label = "Greet"
description = "Print a greeting"
readonly = true

[params.name]
type = "string"
description = "Name to greet"

[params.greeting]
type = "string"
description = "Greeting text"
optional = true

[exec]
command = "printf"
args = ["%s %s", "{greeting|hello}", "{name}"]
timeout = 5
truncate = "head"
"#,
        )
        .unwrap();

        std::fs::write(tools_dir.join("broken.toml"), "not = [valid").unwrap();

        let mut registry = ToolRegistry::new();
        load_shell_tools(&tools_dir, &mut registry).unwrap();

        let tool = registry.get("greet").expect("tool should be registered");
        assert_eq!(tool.name(), "greet");
        assert!(registry.get("broken").is_none());
    }

    #[test]
    fn parameters_schema_lists_properties_and_required_names() {
        let schema = greet_tool().parameters();
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["required"], json!(["name"]));
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["greeting"]["description"], "Greeting text");
    }

    #[tokio::test]
    async fn shell_tool_executes_with_param_interpolation() {
        let tool = greet_tool();

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute(
                "call-1",
                json!({ "name": "Asher" }),
                test_ctx(temp_dir.path()),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(first_text(&result), "hello Asher");
        assert_eq!(result.details["exit_code"], 0);
        assert_eq!(result.details["timed_out"], false);
    }

    #[tokio::test]
    async fn shell_tool_default_param_used_when_not_provided() {
        let tool = ShellTool::new(ShellToolDef {
            name: "echo_default".into(),
            label: "Echo Default".into(),
            description: "Echo with default".into(),
            readonly: true,
            params: HashMap::from([("msg".into(), string_param("Message", true))]),
            exec: ShellExecDef {
                command: "echo".into(),
                args: vec!["{msg|default_value}".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-3", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(first_text(&result).contains("default_value"));
    }

    #[test]
    fn shell_tool_required_param_missing_errors() {
        let defs = HashMap::from([("name".into(), string_param("Name", false))]);
        let provided = serde_json::Map::new();
        let result = validate_required_params(&defs, &provided);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("name"));
    }

    #[test]
    fn validate_required_params_treats_null_as_missing_and_ignores_optional() {
        let defs = HashMap::from([
            ("name".to_string(), string_param("Name", false)),
            ("greeting".to_string(), string_param("Greeting", true)),
        ]);

        let null_name = json!({ "name": null }).as_object().cloned().unwrap();
        assert!(validate_required_params(&defs, &null_name).is_err());

        let only_required = json!({ "name": "Asher" }).as_object().cloned().unwrap();
        assert!(validate_required_params(&defs, &only_required).is_ok());
    }

    #[test]
    fn interpolate_arg_handles_literals_defaults_and_errors() {
        let defs = HashMap::from([
            ("name".to_string(), string_param("Name", false)),
            ("flag".to_string(), string_param("Flag", true)),
        ]);
        let provided = json!({ "name": "Asher", "count": 3, "verbose": true, "list": [1, 2] })
            .as_object()
            .cloned()
            .unwrap();

        assert_eq!(interpolate_arg("plain", &defs, &provided).unwrap(), "plain");
        assert_eq!(interpolate_arg("a}b", &defs, &provided).unwrap(), "a}b");
        assert_eq!(
            interpolate_arg("--name={name}!", &defs, &provided).unwrap(),
            "--name=Asher!"
        );
        assert_eq!(
            interpolate_arg("{count}:{verbose}", &defs, &provided).unwrap(),
            "3:true"
        );
        assert_eq!(interpolate_arg("{flag|off}", &defs, &provided).unwrap(), "off");
        assert_eq!(interpolate_arg("[{flag}]", &defs, &provided).unwrap(), "[]");

        let null_flag = json!({ "flag": null }).as_object().cloned().unwrap();
        assert_eq!(
            interpolate_arg("{flag|fallback}", &defs, &null_flag).unwrap(),
            "fallback"
        );

        assert!(interpolate_arg("{name", &defs, &provided).is_err());
        assert!(interpolate_arg("{}", &defs, &provided).is_err());
        assert!(interpolate_arg("{list}", &defs, &provided).is_err());
        assert!(interpolate_arg("{unknown}", &defs, &provided).is_err());
    }

    #[tokio::test]
    async fn shell_tool_stderr_included_in_output() {
        let tool = ShellTool::new(ShellToolDef {
            name: "stderr_test".into(),
            label: "Stderr Test".into(),
            description: "Writes to stderr".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sh".into(),
                args: vec!["-c".into(), "echo stdout_msg; echo stderr_msg >&2".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-4", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(!result.is_error);
        let text = first_text(&result);
        assert!(text.contains("stdout_msg"));
        assert!(text.contains("stderr_msg"));
    }

    #[tokio::test]
    async fn shell_tool_nonzero_exit_is_reported_as_error() {
        let tool = ShellTool::new(ShellToolDef {
            name: "failing".into(),
            label: "Failing".into(),
            description: "Exits with a non-zero status".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sh".into(),
                args: vec!["-c".into(), "echo oops; exit 3".into()],
                timeout: 5,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-6", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(result.details["exit_code"], 3);
        assert_eq!(result.details["timed_out"], false);
        assert!(first_text(&result).contains("oops"));
    }

    #[tokio::test]
    async fn shell_tool_timeout() {
        let tool = ShellTool::new(ShellToolDef {
            name: "slow".into(),
            label: "Slow".into(),
            description: "Times out".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "sleep".into(),
                args: vec!["60".into()],
                timeout: 1,
                truncate: "head".into(),
                install_hint: None,
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-5", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        assert_eq!(result.details["timed_out"], true);
    }

    #[tokio::test]
    async fn shell_tool_reports_missing_commands_with_install_hint() {
        let tool = ShellTool::new(ShellToolDef {
            name: "missing".into(),
            label: "Missing".into(),
            description: "Missing command".into(),
            readonly: true,
            params: HashMap::new(),
            exec: ShellExecDef {
                command: "definitely-not-a-real-command".into(),
                args: Vec::new(),
                timeout: 5,
                truncate: "head".into(),
                install_hint: Some("brew install definitely-not-a-real-command".into()),
            },
        });

        let temp_dir = tempfile::tempdir().unwrap();
        let result = tool
            .execute("call-2", json!({}), test_ctx(temp_dir.path()))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = first_text(&result);
        assert!(text.contains("Command not found"));
        assert!(text.contains("Install hint"));
    }
}