Skip to main content

defect_tools/
bash.rs

//! Bash built-in tool: runs a non-interactive shell command, merges stdout/stderr, and
//! reports the result as a single tool event (`Completed` or `Failed`).
3
4use std::io;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use agent_client_protocol_schema::{
11    Content, ContentBlock, TextContent, ToolCallContent, ToolCallLocation, ToolCallUpdateFields,
12    ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::shell::{ShellBackend, ShellError, TerminalExitStatus, TerminalId};
16use defect_agent::tool::{
17    SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18    ToolStream,
19};
20use defect_config::BashToolConfig;
21use futures::future::BoxFuture;
22use futures::stream;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25
26const TITLE_TRUNC: usize = 80;
27
28/// Built-in bash tool. No internal state — a singleton `Arc::new(BashTool::new())`
29/// suffices.
30pub struct BashTool {
31    schema: ToolSchema,
32    default_timeout_ms: u64,
33    max_timeout_ms: u64,
34}
35
36impl BashTool {
37    pub fn new() -> Self {
38        Self::from_config(&BashToolConfig::default())
39    }
40
41    pub fn from_config(config: &BashToolConfig) -> Self {
42        let default_timeout_ms = config.default_timeout_ms.max(1);
43        let max_timeout_ms = config.max_timeout_ms.max(default_timeout_ms);
44        Self {
45            schema: ToolSchema {
46                name: "bash".to_string(),
47                description: format!(
48                    "Run a non-interactive shell command. \
49                     Captures stdout and stderr (merged); returns combined output and \
50                     exit code. Times out after `timeout_ms` (default {default_timeout_ms}; max {max_timeout_ms})."
51                ),
52                input_schema: json!({
53                    "type": "object",
54                    "properties": {
55                        "command": {
56                            "type": "string",
57                            "description": "The shell command to execute (passed to `sh -c` on unix, `cmd /C` on windows)."
58                        },
59                        "workdir": {
60                            "type": "string",
61                            "description": "Optional working directory. Must resolve inside the session cwd; relative paths resolve against the session cwd. Defaults to the session cwd."
62                        },
63                        "timeout_ms": {
64                            "type": "integer",
65                            "minimum": 1,
66                            "maximum": max_timeout_ms,
67                            "description": format!(
68                                "Per-call timeout in milliseconds. Defaults to {default_timeout_ms}."
69                            )
70                        }
71                    },
72                    "required": ["command"]
73                }),
74            },
75            default_timeout_ms,
76            max_timeout_ms,
77        }
78    }
79}
80
81impl Default for BashTool {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87#[derive(Debug, Deserialize)]
88struct BashArgs {
89    command: String,
90    #[serde(default)]
91    workdir: Option<String>,
92    #[serde(default)]
93    timeout_ms: Option<u64>,
94}
95
96#[derive(Debug, Serialize)]
97struct BashOutput {
98    /// `None` when the child process was killed by a signal or timed out; check `signal`
99    /// / `timed_out`.
100    exit_code: Option<i32>,
101    /// The signal name (e.g. `SIGKILL`) if the child process was terminated by a signal;
102    /// `None` otherwise.
103    #[serde(skip_serializing_if = "Option::is_none")]
104    signal: Option<String>,
105    timed_out: bool,
106    /// Bytes dropped due to the 1 MiB cap (≥0).
107    truncated_bytes: u64,
108    /// Actual elapsed time in milliseconds. Not written when spawn fails.
109    duration_ms: u64,
110}
111
112impl Tool for BashTool {
113    fn schema(&self) -> &ToolSchema {
114        &self.schema
115    }
116
117    fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
118        // Always destructive — does not parse the command text.
119        SafetyClass::Destructive
120    }
121
122    fn describe<'a>(
123        &'a self,
124        args: &'a serde_json::Value,
125        _ctx: ToolContext<'a>,
126    ) -> BoxFuture<'a, ToolCallDescription> {
127        Box::pin(async move {
128            let command = args
129                .get("command")
130                .and_then(|v| v.as_str())
131                .unwrap_or("")
132                .to_string();
133            let workdir = args
134                .get("workdir")
135                .and_then(|v| v.as_str())
136                .map(|s| s.to_string());
137
138            let title = format!("$ {}", truncate_title(&command));
139            let mut fields = ToolCallUpdateFields::default();
140            fields.title = Some(title);
141            fields.kind = Some(ToolKind::Execute);
142            if let Some(dir) = workdir {
143                fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(dir))]);
144            }
145            ToolCallDescription { fields }
146        })
147    }
148
149    fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
150        let cwd = ctx.cwd.to_path_buf();
151        let cancel = ctx.cancel.clone();
152        let shell = ctx.shell.clone();
153        let default_timeout_ms = self.default_timeout_ms;
154        let max_timeout_ms = self.max_timeout_ms;
155        let fut = async move {
156            run_bash(args, cwd, cancel, shell, default_timeout_ms, max_timeout_ms).await
157        };
158        let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
159        s
160    }
161}
162
163/// A complete bash invocation: parse args, resolve workdir, go through [`ShellBackend`],
164/// assemble the final output. Returns a single [`ToolEvent`] — `Completed` or `Failed`.
165async fn run_bash(
166    args: serde_json::Value,
167    session_cwd: PathBuf,
168    cancel: tokio_util::sync::CancellationToken,
169    shell: Arc<dyn ShellBackend>,
170    default_timeout_ms: u64,
171    max_timeout_ms: u64,
172) -> ToolEvent {
173    let parsed: BashArgs = match serde_json::from_value(args) {
174        Ok(v) => v,
175        Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
176    };
177
178    let timeout = parsed
179        .timeout_ms
180        .unwrap_or(default_timeout_ms)
181        .min(max_timeout_ms);
182    if timeout == 0 {
183        return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(io::Error::new(
184            io::ErrorKind::InvalidInput,
185            "timeout_ms must be > 0",
186        ))));
187    }
188
189    let workdir = match resolve_workdir(&session_cwd, parsed.workdir.as_deref()) {
190        Ok(p) => p,
191        Err(e) => return ToolEvent::Failed(e),
192    };
193
194    let started = std::time::Instant::now();
195
196    let terminal_id = match shell.create(parsed.command.clone(), workdir).await {
197        Ok(id) => id,
198        Err(err) => return ToolEvent::Failed(ToolError::Execution(BoxError::new(err))),
199    };
200
201    let result = run_command(shell.clone(), &terminal_id, &cancel, timeout, started).await;
202    // Release is idempotent at all exit points — the backend guarantees that releasing
203    // the same id multiple times does not error.
204    let _ = shell.release(&terminal_id).await;
205    result
206}
207
208async fn run_command(
209    shell: Arc<dyn ShellBackend>,
210    terminal_id: &TerminalId,
211    cancel: &tokio_util::sync::CancellationToken,
212    timeout: u64,
213    started: std::time::Instant,
214) -> ToolEvent {
215    let mut timed_out = false;
216    let mut canceled = false;
217
218    let timeout_at = tokio::time::sleep(Duration::from_millis(timeout));
219    tokio::pin!(timeout_at);
220
221    // wait_fut must survive the cancel branch. Once an ACP reverse request is sent, the
222    // response must be delivered to a live `oneshot::Receiver`; if we drop `wait_fut`,
223    // the server maps "no receiver" to an internal error and tears down the entire
224    // connection (see `router.respond_with_result(result)?` in
225    // `agent_client_protocol::jsonrpc::incoming_actor::dispatch_dispatch`).
226    //
227    // Solution: make `wait_fut` a `'static` self-owning future (the closure holds
228    // `Arc<shell>` and `id`). In the cancel branch, use [`tokio::spawn`] to detach it and
229    // continue draining the response; in the timeout branch, preserve the "kill then
230    // drain" semantics by awaiting the same future directly.
231    let mut wait_fut: Pin<
232        Box<dyn futures::Future<Output = Result<TerminalExitStatus, ShellError>> + Send>,
233    > = {
234        let shell = shell.clone();
235        let id = terminal_id.clone();
236        Box::pin(async move { shell.wait_for_exit(&id).await })
237    };
238
239    let exit_status = tokio::select! {
240        biased;
241
242        _ = cancel.cancelled() => {
243            canceled = true;
244            None
245        }
246
247        _ = &mut timeout_at => {
248            timed_out = true;
249            None
250        }
251
252        result = &mut wait_fut => {
253            match result {
254                Ok(status) => Some(status),
255                Err(err) => {
256                    return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
257                }
258            }
259        }
260    };
261
262    if canceled {
263        // First send `kill` so the process finishes promptly; `wait_fut` cannot be
264        // dropped (the oneshot on the reverse-request path must have a receiver), so
265        // detach it to `await` elsewhere, keeping a live receiver in the runtime when the
266        // response arrives. For `LocalShellBackend` the future is an in-process
267        // notification, so detaching has no side effects.
268        let _ = shell.kill(terminal_id).await;
269        tokio::spawn(async move {
270            let _ = wait_fut.await;
271        });
272        return ToolEvent::Failed(ToolError::Canceled);
273    }
274
275    // Timeout path: kill first, then wait_for_exit + output to get the final output.
276    let exit_status = match exit_status {
277        Some(status) => Some(status),
278        None => {
279            let _ = shell.kill(terminal_id).await;
280            wait_fut.await.ok()
281        }
282    };
283
284    let output = match shell.output(terminal_id).await {
285        Ok(o) => o,
286        Err(err) => {
287            return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
288        }
289    };
290
291    let duration_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
292
293    let (exit_code, signal_name) = match exit_status.as_ref() {
294        Some(s) => (s.exit_code, s.signal.clone()),
295        None => (None, None),
296    };
297
298    let mut text = output.text;
299    let truncated_bytes: u64 = if output.truncated { 1 } else { 0 };
300    if output.truncated {
301        if !text.is_empty() && !text.ends_with('\n') {
302            text.push('\n');
303        }
304        text.push_str("[output truncated]");
305    }
306    if timed_out {
307        if !text.is_empty() && !text.ends_with('\n') {
308            text.push('\n');
309        }
310        text.push_str(&format!("[timed out after {timeout}ms]"));
311    } else if let Some(sig) = signal_name.as_deref() {
312        if !text.is_empty() && !text.ends_with('\n') {
313            text.push('\n');
314        }
315        text.push_str(&format!("[killed by signal: {sig}]"));
316    } else if let Some(code) = exit_code
317        && code != 0
318    {
319        if !text.is_empty() && !text.ends_with('\n') {
320            text.push('\n');
321        }
322        text.push_str(&format!("[exit code: {code}]"));
323    }
324
325    let raw_output = serde_json::to_value(BashOutput {
326        exit_code,
327        signal: signal_name,
328        timed_out,
329        truncated_bytes,
330        duration_ms,
331    })
332    .unwrap_or(serde_json::Value::Null);
333
334    let mut fields = ToolCallUpdateFields::default();
335    fields.content = Some(vec![ToolCallContent::Content(Content::new(
336        ContentBlock::Text(TextContent::new(text)),
337    ))]);
338    fields.raw_output = Some(raw_output);
339    ToolEvent::Completed(fields)
340}
341
342/// Canonicalize the working directory and verify it is within the session cwd subtree.
343fn resolve_workdir(session_cwd: &Path, requested: Option<&str>) -> Result<PathBuf, ToolError> {
344    let target = match requested {
345        None => session_cwd.to_path_buf(),
346        Some(s) => {
347            let p = Path::new(s);
348            if p.is_absolute() {
349                p.to_path_buf()
350            } else {
351                session_cwd.join(p)
352            }
353        }
354    };
355
356    let canon_target =
357        std::fs::canonicalize(&target).map_err(|e| ToolError::InvalidArgs(BoxError::new(e)))?;
358    let canon_cwd =
359        std::fs::canonicalize(session_cwd).unwrap_or_else(|_| session_cwd.to_path_buf());
360
361    if !canon_target.starts_with(&canon_cwd) {
362        return Err(ToolError::InvalidArgs(BoxError::new(io::Error::new(
363            io::ErrorKind::PermissionDenied,
364            format!(
365                "workdir {} escapes session cwd {}",
366                canon_target.display(),
367                canon_cwd.display()
368            ),
369        ))));
370    }
371
372    Ok(canon_target)
373}
374
375fn truncate_title(s: &str) -> String {
376    if s.chars().count() <= TITLE_TRUNC {
377        return s.to_string();
378    }
379    let truncated: String = s.chars().take(TITLE_TRUNC).collect();
380    format!("{truncated}…")
381}
382
383#[cfg(test)]
384mod tests;