Skip to main content

defect_tools/
bash.rs

1//! Bash built-in tool: runs a non-interactive shell command, merges stdout/stderr,
2//! returns a single frame.
3
4use std::io;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use agent_client_protocol_schema::{
11    Content, ContentBlock, TextContent, ToolCallContent, ToolCallLocation, ToolCallUpdateFields,
12    ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::shell::{ShellBackend, ShellError, TerminalExitStatus, TerminalId};
16use defect_agent::tool::{
17    SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18    ToolStream,
19};
20use defect_config::BashToolConfig;
21use futures::future::BoxFuture;
22use futures::stream;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25
/// Timeout in milliseconds that [`BashTool::new`] configures for calls that omit
/// `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Upper bound in milliseconds that [`BashTool::new`] configures for per-call timeouts.
const MAX_TIMEOUT_MS: u64 = 600_000;
/// Maximum number of characters of the command that `truncate_title` keeps before
/// appending an ellipsis.
const TITLE_TRUNC: usize = 80;
29
/// Built-in bash tool. Holds no mutable state — only the advertised schema and the
/// timeout limits fixed at construction — so a singleton `Arc::new(BashTool::new())`
/// suffices.
pub struct BashTool {
    /// Schema handed out by `Tool::schema`; built once in `from_config`.
    schema: ToolSchema,
    /// Timeout in milliseconds applied when a call omits `timeout_ms`.
    default_timeout_ms: u64,
    /// Upper bound in milliseconds; per-call timeouts are clamped to it.
    max_timeout_ms: u64,
}
37
38impl BashTool {
39    pub fn new() -> Self {
40        Self::from_config(&BashToolConfig {
41            default_timeout_ms: DEFAULT_TIMEOUT_MS,
42            max_timeout_ms: MAX_TIMEOUT_MS,
43        })
44    }
45
46    pub fn from_config(config: &BashToolConfig) -> Self {
47        let default_timeout_ms = config.default_timeout_ms.max(1);
48        let max_timeout_ms = config.max_timeout_ms.max(default_timeout_ms);
49        Self {
50            schema: ToolSchema {
51                name: "bash".to_string(),
52                description: format!(
53                    "Run a non-interactive shell command. \
54                     Captures stdout and stderr (merged); returns combined output and \
55                     exit code. Times out after `timeout_ms` (default {default_timeout_ms}; max {max_timeout_ms})."
56                ),
57                input_schema: json!({
58                    "type": "object",
59                    "properties": {
60                        "command": {
61                            "type": "string",
62                            "description": "The shell command to execute (passed to `sh -c` on unix, `cmd /C` on windows)."
63                        },
64                        "workdir": {
65                            "type": "string",
66                            "description": "Optional working directory. Must resolve inside the session cwd; relative paths resolve against the session cwd. Defaults to the session cwd."
67                        },
68                        "timeout_ms": {
69                            "type": "integer",
70                            "minimum": 1,
71                            "maximum": max_timeout_ms,
72                            "description": format!(
73                                "Per-call timeout in milliseconds. Defaults to {default_timeout_ms}."
74                            )
75                        }
76                    },
77                    "required": ["command"]
78                }),
79            },
80            default_timeout_ms,
81            max_timeout_ms,
82        }
83    }
84}
85
/// `Default` is equivalent to [`BashTool::new`].
impl Default for BashTool {
    /// Delegates to [`BashTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
91
/// Arguments of a single bash call, deserialized from the tool call's JSON arguments in
/// `run_bash`.
#[derive(Debug, Deserialize)]
struct BashArgs {
    /// Command line handed to the shell backend. Required.
    command: String,
    /// Optional working directory; validated by `resolve_workdir` before use.
    #[serde(default)]
    workdir: Option<String>,
    /// Optional per-call timeout in milliseconds. `run_bash` falls back to the tool's
    /// default when absent, clamps the value to the tool's maximum, and rejects 0.
    #[serde(default)]
    timeout_ms: Option<u64>,
}
100
/// Structured summary of a finished call, serialized into the tool call's `raw_output`.
#[derive(Debug, Serialize)]
struct BashOutput {
    /// `None` when the child process was killed by a signal or timed out; check `signal`
    /// / `timed_out`.
    exit_code: Option<i32>,
    /// The signal name (e.g. `SIGKILL`) if the child process was terminated by a signal;
    /// `None` otherwise. Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    signal: Option<String>,
    /// `true` when the call hit its timeout and the process was killed by the tool.
    timed_out: bool,
    /// Bytes dropped due to the 1 MiB cap (≥0).
    ///
    /// NOTE(review): `run_command` currently stores 1 when the backend flags the output
    /// as truncated and 0 otherwise, so in practice this is a flag rather than a byte
    /// count. The cap itself is not enforced in this file — confirm against the shell
    /// backend.
    truncated_bytes: u64,
    /// Actual elapsed time in milliseconds. Not written when spawn fails.
    duration_ms: u64,
}
116
117impl Tool for BashTool {
118    fn schema(&self) -> &ToolSchema {
119        &self.schema
120    }
121
122    fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
123        // Always destructive — does not parse the command text.
124        SafetyClass::Destructive
125    }
126
127    fn describe<'a>(
128        &'a self,
129        args: &'a serde_json::Value,
130        _ctx: ToolContext<'a>,
131    ) -> BoxFuture<'a, ToolCallDescription> {
132        Box::pin(async move {
133            let command = args
134                .get("command")
135                .and_then(|v| v.as_str())
136                .unwrap_or("")
137                .to_string();
138            let workdir = args
139                .get("workdir")
140                .and_then(|v| v.as_str())
141                .map(|s| s.to_string());
142
143            let title = format!("$ {}", truncate_title(&command));
144            let mut fields = ToolCallUpdateFields::default();
145            fields.title = Some(title);
146            fields.kind = Some(ToolKind::Execute);
147            if let Some(dir) = workdir {
148                fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(dir))]);
149            }
150            ToolCallDescription { fields }
151        })
152    }
153
154    fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
155        let cwd = ctx.cwd.to_path_buf();
156        let cancel = ctx.cancel.clone();
157        let shell = ctx.shell.clone();
158        let default_timeout_ms = self.default_timeout_ms;
159        let max_timeout_ms = self.max_timeout_ms;
160        let fut = async move {
161            run_bash(args, cwd, cancel, shell, default_timeout_ms, max_timeout_ms).await
162        };
163        let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
164        s
165    }
166}
167
168/// A complete bash invocation: parse args, resolve workdir, go through [`ShellBackend`],
169/// assemble the final output. Returns a single [`ToolEvent`] — `Completed` or `Failed`.
170async fn run_bash(
171    args: serde_json::Value,
172    session_cwd: PathBuf,
173    cancel: tokio_util::sync::CancellationToken,
174    shell: Arc<dyn ShellBackend>,
175    default_timeout_ms: u64,
176    max_timeout_ms: u64,
177) -> ToolEvent {
178    let parsed: BashArgs = match serde_json::from_value(args) {
179        Ok(v) => v,
180        Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
181    };
182
183    let timeout = parsed
184        .timeout_ms
185        .unwrap_or(default_timeout_ms)
186        .min(max_timeout_ms);
187    if timeout == 0 {
188        return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(io::Error::new(
189            io::ErrorKind::InvalidInput,
190            "timeout_ms must be > 0",
191        ))));
192    }
193
194    let workdir = match resolve_workdir(&session_cwd, parsed.workdir.as_deref()) {
195        Ok(p) => p,
196        Err(e) => return ToolEvent::Failed(e),
197    };
198
199    let started = std::time::Instant::now();
200
201    let terminal_id = match shell.create(parsed.command.clone(), workdir).await {
202        Ok(id) => id,
203        Err(err) => return ToolEvent::Failed(ToolError::Execution(BoxError::new(err))),
204    };
205
206    let result = run_command(shell.clone(), &terminal_id, &cancel, timeout, started).await;
207    // Release is idempotent at all exit points — the backend guarantees that releasing
208    // the same id multiple times does not error.
209    let _ = shell.release(&terminal_id).await;
210    result
211}
212
213async fn run_command(
214    shell: Arc<dyn ShellBackend>,
215    terminal_id: &TerminalId,
216    cancel: &tokio_util::sync::CancellationToken,
217    timeout: u64,
218    started: std::time::Instant,
219) -> ToolEvent {
220    let mut timed_out = false;
221    let mut canceled = false;
222
223    let timeout_at = tokio::time::sleep(Duration::from_millis(timeout));
224    tokio::pin!(timeout_at);
225
226    // wait_fut must survive the cancel branch. Once an ACP reverse request is sent, the
227    // response must be delivered to a live `oneshot::Receiver`; if we drop `wait_fut`,
228    // the server maps "no receiver" to an internal error and tears down the entire
229    // connection (see `router.respond_with_result(result)?` in
230    // `agent_client_protocol::jsonrpc::incoming_actor::dispatch_dispatch`).
231    //
232    // Solution: make `wait_fut` a `'static` self-owning future (the closure holds
233    // `Arc<shell>` and `id`). In the cancel branch, use [`tokio::spawn`] to detach it and
234    // continue draining the response; in the timeout branch, preserve the "kill then
235    // drain" semantics by awaiting the same future directly.
236    let mut wait_fut: Pin<
237        Box<dyn futures::Future<Output = Result<TerminalExitStatus, ShellError>> + Send>,
238    > = {
239        let shell = shell.clone();
240        let id = terminal_id.clone();
241        Box::pin(async move { shell.wait_for_exit(&id).await })
242    };
243
244    let exit_status = tokio::select! {
245        biased;
246
247        _ = cancel.cancelled() => {
248            canceled = true;
249            None
250        }
251
252        _ = &mut timeout_at => {
253            timed_out = true;
254            None
255        }
256
257        result = &mut wait_fut => {
258            match result {
259                Ok(status) => Some(status),
260                Err(err) => {
261                    return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
262                }
263            }
264        }
265    };
266
267    if canceled {
268        // First send `kill` so the process finishes promptly; `wait_fut` cannot be
269        // dropped (the oneshot on the reverse-request path must have a receiver), so
270        // detach it to `await` elsewhere, keeping a live receiver in the runtime when the
271        // response arrives. For `LocalShellBackend` the future is an in-process
272        // notification, so detaching has no side effects.
273        let _ = shell.kill(terminal_id).await;
274        tokio::spawn(async move {
275            let _ = wait_fut.await;
276        });
277        return ToolEvent::Failed(ToolError::Canceled);
278    }
279
280    // Timeout path: kill first, then wait_for_exit + output to get the final output.
281    let exit_status = match exit_status {
282        Some(status) => Some(status),
283        None => {
284            let _ = shell.kill(terminal_id).await;
285            wait_fut.await.ok()
286        }
287    };
288
289    let output = match shell.output(terminal_id).await {
290        Ok(o) => o,
291        Err(err) => {
292            return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
293        }
294    };
295
296    let duration_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
297
298    let (exit_code, signal_name) = match exit_status.as_ref() {
299        Some(s) => (s.exit_code, s.signal.clone()),
300        None => (None, None),
301    };
302
303    let mut text = output.text;
304    let truncated_bytes: u64 = if output.truncated { 1 } else { 0 };
305    if output.truncated {
306        if !text.is_empty() && !text.ends_with('\n') {
307            text.push('\n');
308        }
309        text.push_str("[output truncated]");
310    }
311    if timed_out {
312        if !text.is_empty() && !text.ends_with('\n') {
313            text.push('\n');
314        }
315        text.push_str(&format!("[timed out after {timeout}ms]"));
316    } else if let Some(sig) = signal_name.as_deref() {
317        if !text.is_empty() && !text.ends_with('\n') {
318            text.push('\n');
319        }
320        text.push_str(&format!("[killed by signal: {sig}]"));
321    } else if let Some(code) = exit_code
322        && code != 0
323    {
324        if !text.is_empty() && !text.ends_with('\n') {
325            text.push('\n');
326        }
327        text.push_str(&format!("[exit code: {code}]"));
328    }
329
330    let raw_output = serde_json::to_value(BashOutput {
331        exit_code,
332        signal: signal_name,
333        timed_out,
334        truncated_bytes,
335        duration_ms,
336    })
337    .unwrap_or(serde_json::Value::Null);
338
339    let mut fields = ToolCallUpdateFields::default();
340    fields.content = Some(vec![ToolCallContent::Content(Content::new(
341        ContentBlock::Text(TextContent::new(text)),
342    ))]);
343    fields.raw_output = Some(raw_output);
344    ToolEvent::Completed(fields)
345}
346
347/// Canonicalize the working directory and verify it is within the session cwd subtree.
348fn resolve_workdir(session_cwd: &Path, requested: Option<&str>) -> Result<PathBuf, ToolError> {
349    let target = match requested {
350        None => session_cwd.to_path_buf(),
351        Some(s) => {
352            let p = Path::new(s);
353            if p.is_absolute() {
354                p.to_path_buf()
355            } else {
356                session_cwd.join(p)
357            }
358        }
359    };
360
361    let canon_target =
362        std::fs::canonicalize(&target).map_err(|e| ToolError::InvalidArgs(BoxError::new(e)))?;
363    let canon_cwd =
364        std::fs::canonicalize(session_cwd).unwrap_or_else(|_| session_cwd.to_path_buf());
365
366    if !canon_target.starts_with(&canon_cwd) {
367        return Err(ToolError::InvalidArgs(BoxError::new(io::Error::new(
368            io::ErrorKind::PermissionDenied,
369            format!(
370                "workdir {} escapes session cwd {}",
371                canon_target.display(),
372                canon_cwd.display()
373            ),
374        ))));
375    }
376
377    Ok(canon_target)
378}
379
380fn truncate_title(s: &str) -> String {
381    if s.chars().count() <= TITLE_TRUNC {
382        return s.to_string();
383    }
384    let truncated: String = s.chars().take(TITLE_TRUNC).collect();
385    format!("{truncated}…")
386}
387
388#[cfg(test)]
389mod tests;