1use std::io;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use agent_client_protocol_schema::{
11 Content, ContentBlock, TextContent, ToolCallContent, ToolCallLocation, ToolCallUpdateFields,
12 ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::shell::{ShellBackend, ShellError, TerminalExitStatus, TerminalId};
16use defect_agent::tool::{
17 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18 ToolStream,
19};
20use defect_config::BashToolConfig;
21use futures::future::BoxFuture;
22use futures::stream;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25
/// Built-in timeout applied when a call omits `timeout_ms` (30 seconds).
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// Built-in ceiling for caller-supplied `timeout_ms` used by `BashTool::new` (10 minutes).
const MAX_TIMEOUT_MS: u64 = 10 * 60 * 1_000;
/// Maximum number of characters of the command shown in a tool-call title.
const TITLE_TRUNC: usize = 80;
29
30pub struct BashTool {
33 schema: ToolSchema,
34 default_timeout_ms: u64,
35 max_timeout_ms: u64,
36}
37
38impl BashTool {
39 pub fn new() -> Self {
40 Self::from_config(&BashToolConfig {
41 default_timeout_ms: DEFAULT_TIMEOUT_MS,
42 max_timeout_ms: MAX_TIMEOUT_MS,
43 })
44 }
45
46 pub fn from_config(config: &BashToolConfig) -> Self {
47 let default_timeout_ms = config.default_timeout_ms.max(1);
48 let max_timeout_ms = config.max_timeout_ms.max(default_timeout_ms);
49 Self {
50 schema: ToolSchema {
51 name: "bash".to_string(),
52 description: format!(
53 "Run a non-interactive shell command. \
54 Captures stdout and stderr (merged); returns combined output and \
55 exit code. Times out after `timeout_ms` (default {default_timeout_ms}; max {max_timeout_ms})."
56 ),
57 input_schema: json!({
58 "type": "object",
59 "properties": {
60 "command": {
61 "type": "string",
62 "description": "The shell command to execute (passed to `sh -c` on unix, `cmd /C` on windows)."
63 },
64 "workdir": {
65 "type": "string",
66 "description": "Optional working directory. Must resolve inside the session cwd; relative paths resolve against the session cwd. Defaults to the session cwd."
67 },
68 "timeout_ms": {
69 "type": "integer",
70 "minimum": 1,
71 "maximum": max_timeout_ms,
72 "description": format!(
73 "Per-call timeout in milliseconds. Defaults to {default_timeout_ms}."
74 )
75 }
76 },
77 "required": ["command"]
78 }),
79 },
80 default_timeout_ms,
81 max_timeout_ms,
82 }
83 }
84}
85
86impl Default for BashTool {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
/// Arguments accepted by the `bash` tool, deserialized from the call's JSON input.
#[derive(Debug, Deserialize)]
struct BashArgs {
    /// Command line handed to the shell backend.
    command: String,
    /// Optional working directory; validated by `resolve_workdir`.
    #[serde(default)]
    workdir: Option<String>,
    /// Optional per-call timeout in milliseconds; `run_bash` clamps it to the configured maximum.
    #[serde(default)]
    timeout_ms: Option<u64>,
}
100
/// Structured summary of a finished command, attached as the tool call's `raw_output`.
/// Field order determines the serialized key order, so keep it stable.
#[derive(Debug, Serialize)]
struct BashOutput {
    /// Exit code reported by the shell backend, if any.
    exit_code: Option<i32>,
    /// Signal reported by the backend for the exit, if any; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    signal: Option<String>,
    /// Whether the command was killed for exceeding its timeout.
    timed_out: bool,
    /// NOTE(review): despite the name, this only ever holds 0 or 1 (whether the
    /// backend reported truncated output), not a byte count.
    truncated_bytes: u64,
    /// Milliseconds from just before the terminal was created until output was collected.
    duration_ms: u64,
}
116
117impl Tool for BashTool {
118 fn schema(&self) -> &ToolSchema {
119 &self.schema
120 }
121
122 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
123 SafetyClass::Destructive
125 }
126
127 fn describe<'a>(
128 &'a self,
129 args: &'a serde_json::Value,
130 _ctx: ToolContext<'a>,
131 ) -> BoxFuture<'a, ToolCallDescription> {
132 Box::pin(async move {
133 let command = args
134 .get("command")
135 .and_then(|v| v.as_str())
136 .unwrap_or("")
137 .to_string();
138 let workdir = args
139 .get("workdir")
140 .and_then(|v| v.as_str())
141 .map(|s| s.to_string());
142
143 let title = format!("$ {}", truncate_title(&command));
144 let mut fields = ToolCallUpdateFields::default();
145 fields.title = Some(title);
146 fields.kind = Some(ToolKind::Execute);
147 if let Some(dir) = workdir {
148 fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(dir))]);
149 }
150 ToolCallDescription { fields }
151 })
152 }
153
154 fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
155 let cwd = ctx.cwd.to_path_buf();
156 let cancel = ctx.cancel.clone();
157 let shell = ctx.shell.clone();
158 let default_timeout_ms = self.default_timeout_ms;
159 let max_timeout_ms = self.max_timeout_ms;
160 let fut = async move {
161 run_bash(args, cwd, cancel, shell, default_timeout_ms, max_timeout_ms).await
162 };
163 let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
164 s
165 }
166}
167
168async fn run_bash(
171 args: serde_json::Value,
172 session_cwd: PathBuf,
173 cancel: tokio_util::sync::CancellationToken,
174 shell: Arc<dyn ShellBackend>,
175 default_timeout_ms: u64,
176 max_timeout_ms: u64,
177) -> ToolEvent {
178 let parsed: BashArgs = match serde_json::from_value(args) {
179 Ok(v) => v,
180 Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
181 };
182
183 let timeout = parsed
184 .timeout_ms
185 .unwrap_or(default_timeout_ms)
186 .min(max_timeout_ms);
187 if timeout == 0 {
188 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(io::Error::new(
189 io::ErrorKind::InvalidInput,
190 "timeout_ms must be > 0",
191 ))));
192 }
193
194 let workdir = match resolve_workdir(&session_cwd, parsed.workdir.as_deref()) {
195 Ok(p) => p,
196 Err(e) => return ToolEvent::Failed(e),
197 };
198
199 let started = std::time::Instant::now();
200
201 let terminal_id = match shell.create(parsed.command.clone(), workdir).await {
202 Ok(id) => id,
203 Err(err) => return ToolEvent::Failed(ToolError::Execution(BoxError::new(err))),
204 };
205
206 let result = run_command(shell.clone(), &terminal_id, &cancel, timeout, started).await;
207 let _ = shell.release(&terminal_id).await;
210 result
211}
212
213async fn run_command(
214 shell: Arc<dyn ShellBackend>,
215 terminal_id: &TerminalId,
216 cancel: &tokio_util::sync::CancellationToken,
217 timeout: u64,
218 started: std::time::Instant,
219) -> ToolEvent {
220 let mut timed_out = false;
221 let mut canceled = false;
222
223 let timeout_at = tokio::time::sleep(Duration::from_millis(timeout));
224 tokio::pin!(timeout_at);
225
226 let mut wait_fut: Pin<
237 Box<dyn futures::Future<Output = Result<TerminalExitStatus, ShellError>> + Send>,
238 > = {
239 let shell = shell.clone();
240 let id = terminal_id.clone();
241 Box::pin(async move { shell.wait_for_exit(&id).await })
242 };
243
244 let exit_status = tokio::select! {
245 biased;
246
247 _ = cancel.cancelled() => {
248 canceled = true;
249 None
250 }
251
252 _ = &mut timeout_at => {
253 timed_out = true;
254 None
255 }
256
257 result = &mut wait_fut => {
258 match result {
259 Ok(status) => Some(status),
260 Err(err) => {
261 return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
262 }
263 }
264 }
265 };
266
267 if canceled {
268 let _ = shell.kill(terminal_id).await;
274 tokio::spawn(async move {
275 let _ = wait_fut.await;
276 });
277 return ToolEvent::Failed(ToolError::Canceled);
278 }
279
280 let exit_status = match exit_status {
282 Some(status) => Some(status),
283 None => {
284 let _ = shell.kill(terminal_id).await;
285 wait_fut.await.ok()
286 }
287 };
288
289 let output = match shell.output(terminal_id).await {
290 Ok(o) => o,
291 Err(err) => {
292 return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
293 }
294 };
295
296 let duration_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
297
298 let (exit_code, signal_name) = match exit_status.as_ref() {
299 Some(s) => (s.exit_code, s.signal.clone()),
300 None => (None, None),
301 };
302
303 let mut text = output.text;
304 let truncated_bytes: u64 = if output.truncated { 1 } else { 0 };
305 if output.truncated {
306 if !text.is_empty() && !text.ends_with('\n') {
307 text.push('\n');
308 }
309 text.push_str("[output truncated]");
310 }
311 if timed_out {
312 if !text.is_empty() && !text.ends_with('\n') {
313 text.push('\n');
314 }
315 text.push_str(&format!("[timed out after {timeout}ms]"));
316 } else if let Some(sig) = signal_name.as_deref() {
317 if !text.is_empty() && !text.ends_with('\n') {
318 text.push('\n');
319 }
320 text.push_str(&format!("[killed by signal: {sig}]"));
321 } else if let Some(code) = exit_code
322 && code != 0
323 {
324 if !text.is_empty() && !text.ends_with('\n') {
325 text.push('\n');
326 }
327 text.push_str(&format!("[exit code: {code}]"));
328 }
329
330 let raw_output = serde_json::to_value(BashOutput {
331 exit_code,
332 signal: signal_name,
333 timed_out,
334 truncated_bytes,
335 duration_ms,
336 })
337 .unwrap_or(serde_json::Value::Null);
338
339 let mut fields = ToolCallUpdateFields::default();
340 fields.content = Some(vec![ToolCallContent::Content(Content::new(
341 ContentBlock::Text(TextContent::new(text)),
342 ))]);
343 fields.raw_output = Some(raw_output);
344 ToolEvent::Completed(fields)
345}
346
347fn resolve_workdir(session_cwd: &Path, requested: Option<&str>) -> Result<PathBuf, ToolError> {
349 let target = match requested {
350 None => session_cwd.to_path_buf(),
351 Some(s) => {
352 let p = Path::new(s);
353 if p.is_absolute() {
354 p.to_path_buf()
355 } else {
356 session_cwd.join(p)
357 }
358 }
359 };
360
361 let canon_target =
362 std::fs::canonicalize(&target).map_err(|e| ToolError::InvalidArgs(BoxError::new(e)))?;
363 let canon_cwd =
364 std::fs::canonicalize(session_cwd).unwrap_or_else(|_| session_cwd.to_path_buf());
365
366 if !canon_target.starts_with(&canon_cwd) {
367 return Err(ToolError::InvalidArgs(BoxError::new(io::Error::new(
368 io::ErrorKind::PermissionDenied,
369 format!(
370 "workdir {} escapes session cwd {}",
371 canon_target.display(),
372 canon_cwd.display()
373 ),
374 ))));
375 }
376
377 Ok(canon_target)
378}
379
380fn truncate_title(s: &str) -> String {
381 if s.chars().count() <= TITLE_TRUNC {
382 return s.to_string();
383 }
384 let truncated: String = s.chars().take(TITLE_TRUNC).collect();
385 format!("{truncated}…")
386}
387
// Unit tests live in a separate `tests` module file (declared without a body here).
#[cfg(test)]
mod tests;