1use std::io;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use agent_client_protocol_schema::{
11 Content, ContentBlock, TextContent, ToolCallContent, ToolCallLocation, ToolCallUpdateFields,
12 ToolKind,
13};
14use defect_agent::error::BoxError;
15use defect_agent::shell::{ShellBackend, ShellError, TerminalExitStatus, TerminalId};
16use defect_agent::tool::{
17 SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
18 ToolStream,
19};
20use defect_config::BashToolConfig;
21use futures::future::BoxFuture;
22use futures::stream;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25
/// Maximum number of characters of the command shown in a tool-call title;
/// longer commands are cut and suffixed with `…` by `truncate_title`.
const TITLE_TRUNC: usize = 80;
27
/// The `bash` tool: runs a non-interactive shell command through the session's
/// [`ShellBackend`] and reports its output and exit status.
pub struct BashTool {
    /// Schema advertised for this tool (name, description, JSON input schema).
    schema: ToolSchema,
    /// Timeout used when a call omits `timeout_ms`; at least 1 (see `from_config`).
    default_timeout_ms: u64,
    /// Upper bound applied to any requested timeout; never below `default_timeout_ms`.
    max_timeout_ms: u64,
}
35
36impl BashTool {
37 pub fn new() -> Self {
38 Self::from_config(&BashToolConfig::default())
39 }
40
41 pub fn from_config(config: &BashToolConfig) -> Self {
42 let default_timeout_ms = config.default_timeout_ms.max(1);
43 let max_timeout_ms = config.max_timeout_ms.max(default_timeout_ms);
44 Self {
45 schema: ToolSchema {
46 name: "bash".to_string(),
47 description: format!(
48 "Run a non-interactive shell command. \
49 Captures stdout and stderr (merged); returns combined output and \
50 exit code. Times out after `timeout_ms` (default {default_timeout_ms}; max {max_timeout_ms})."
51 ),
52 input_schema: json!({
53 "type": "object",
54 "properties": {
55 "command": {
56 "type": "string",
57 "description": "The shell command to execute (passed to `sh -c` on unix, `cmd /C` on windows)."
58 },
59 "workdir": {
60 "type": "string",
61 "description": "Optional working directory. Must resolve inside the session cwd; relative paths resolve against the session cwd. Defaults to the session cwd."
62 },
63 "timeout_ms": {
64 "type": "integer",
65 "minimum": 1,
66 "maximum": max_timeout_ms,
67 "description": format!(
68 "Per-call timeout in milliseconds. Defaults to {default_timeout_ms}."
69 )
70 }
71 },
72 "required": ["command"]
73 }),
74 },
75 default_timeout_ms,
76 max_timeout_ms,
77 }
78 }
79}
80
81impl Default for BashTool {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
/// Arguments accepted by the `bash` tool, deserialized from the call's JSON input.
#[derive(Debug, Deserialize)]
struct BashArgs {
    /// Shell command line to run.
    command: String,
    /// Optional working directory; resolved and confined to the session cwd
    /// by `resolve_workdir`.
    #[serde(default)]
    workdir: Option<String>,
    /// Optional per-call timeout in milliseconds; falls back to the tool's
    /// default and is clamped to the tool's maximum in `run_bash`.
    #[serde(default)]
    timeout_ms: Option<u64>,
}
95
/// Structured result serialized into the completed call's `raw_output`.
#[derive(Debug, Serialize)]
struct BashOutput {
    /// Process exit code; `None` when no exit status was obtained (e.g. the
    /// wait after a timeout kill failed) or the backend did not report one.
    exit_code: Option<i32>,
    /// Name of the terminating signal, if the backend reported one; omitted
    /// from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    signal: Option<String>,
    /// Whether the timeout fired before the command exited (the command is
    /// then killed).
    timed_out: bool,
    /// NOTE(review): despite the name this is not a byte count. `run_command`
    /// sets it to 1 when the backend flags the output as truncated and 0
    /// otherwise, since the backend appears to expose only a boolean. Renaming
    /// it would change the serialized `raw_output` schema.
    truncated_bytes: u64,
    /// Milliseconds from just before the terminal was created until the
    /// output was collected.
    duration_ms: u64,
}
111
112impl Tool for BashTool {
113 fn schema(&self) -> &ToolSchema {
114 &self.schema
115 }
116
117 fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
118 SafetyClass::Destructive
120 }
121
122 fn describe<'a>(
123 &'a self,
124 args: &'a serde_json::Value,
125 _ctx: ToolContext<'a>,
126 ) -> BoxFuture<'a, ToolCallDescription> {
127 Box::pin(async move {
128 let command = args
129 .get("command")
130 .and_then(|v| v.as_str())
131 .unwrap_or("")
132 .to_string();
133 let workdir = args
134 .get("workdir")
135 .and_then(|v| v.as_str())
136 .map(|s| s.to_string());
137
138 let title = format!("$ {}", truncate_title(&command));
139 let mut fields = ToolCallUpdateFields::default();
140 fields.title = Some(title);
141 fields.kind = Some(ToolKind::Execute);
142 if let Some(dir) = workdir {
143 fields.locations = Some(vec![ToolCallLocation::new(PathBuf::from(dir))]);
144 }
145 ToolCallDescription { fields }
146 })
147 }
148
149 fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
150 let cwd = ctx.cwd.to_path_buf();
151 let cancel = ctx.cancel.clone();
152 let shell = ctx.shell.clone();
153 let default_timeout_ms = self.default_timeout_ms;
154 let max_timeout_ms = self.max_timeout_ms;
155 let fut = async move {
156 run_bash(args, cwd, cancel, shell, default_timeout_ms, max_timeout_ms).await
157 };
158 let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
159 s
160 }
161}
162
163async fn run_bash(
166 args: serde_json::Value,
167 session_cwd: PathBuf,
168 cancel: tokio_util::sync::CancellationToken,
169 shell: Arc<dyn ShellBackend>,
170 default_timeout_ms: u64,
171 max_timeout_ms: u64,
172) -> ToolEvent {
173 let parsed: BashArgs = match serde_json::from_value(args) {
174 Ok(v) => v,
175 Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
176 };
177
178 let timeout = parsed
179 .timeout_ms
180 .unwrap_or(default_timeout_ms)
181 .min(max_timeout_ms);
182 if timeout == 0 {
183 return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(io::Error::new(
184 io::ErrorKind::InvalidInput,
185 "timeout_ms must be > 0",
186 ))));
187 }
188
189 let workdir = match resolve_workdir(&session_cwd, parsed.workdir.as_deref()) {
190 Ok(p) => p,
191 Err(e) => return ToolEvent::Failed(e),
192 };
193
194 let started = std::time::Instant::now();
195
196 let terminal_id = match shell.create(parsed.command.clone(), workdir).await {
197 Ok(id) => id,
198 Err(err) => return ToolEvent::Failed(ToolError::Execution(BoxError::new(err))),
199 };
200
201 let result = run_command(shell.clone(), &terminal_id, &cancel, timeout, started).await;
202 let _ = shell.release(&terminal_id).await;
205 result
206}
207
208async fn run_command(
209 shell: Arc<dyn ShellBackend>,
210 terminal_id: &TerminalId,
211 cancel: &tokio_util::sync::CancellationToken,
212 timeout: u64,
213 started: std::time::Instant,
214) -> ToolEvent {
215 let mut timed_out = false;
216 let mut canceled = false;
217
218 let timeout_at = tokio::time::sleep(Duration::from_millis(timeout));
219 tokio::pin!(timeout_at);
220
221 let mut wait_fut: Pin<
232 Box<dyn futures::Future<Output = Result<TerminalExitStatus, ShellError>> + Send>,
233 > = {
234 let shell = shell.clone();
235 let id = terminal_id.clone();
236 Box::pin(async move { shell.wait_for_exit(&id).await })
237 };
238
239 let exit_status = tokio::select! {
240 biased;
241
242 _ = cancel.cancelled() => {
243 canceled = true;
244 None
245 }
246
247 _ = &mut timeout_at => {
248 timed_out = true;
249 None
250 }
251
252 result = &mut wait_fut => {
253 match result {
254 Ok(status) => Some(status),
255 Err(err) => {
256 return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
257 }
258 }
259 }
260 };
261
262 if canceled {
263 let _ = shell.kill(terminal_id).await;
269 tokio::spawn(async move {
270 let _ = wait_fut.await;
271 });
272 return ToolEvent::Failed(ToolError::Canceled);
273 }
274
275 let exit_status = match exit_status {
277 Some(status) => Some(status),
278 None => {
279 let _ = shell.kill(terminal_id).await;
280 wait_fut.await.ok()
281 }
282 };
283
284 let output = match shell.output(terminal_id).await {
285 Ok(o) => o,
286 Err(err) => {
287 return ToolEvent::Failed(ToolError::Execution(BoxError::new(err)));
288 }
289 };
290
291 let duration_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
292
293 let (exit_code, signal_name) = match exit_status.as_ref() {
294 Some(s) => (s.exit_code, s.signal.clone()),
295 None => (None, None),
296 };
297
298 let mut text = output.text;
299 let truncated_bytes: u64 = if output.truncated { 1 } else { 0 };
300 if output.truncated {
301 if !text.is_empty() && !text.ends_with('\n') {
302 text.push('\n');
303 }
304 text.push_str("[output truncated]");
305 }
306 if timed_out {
307 if !text.is_empty() && !text.ends_with('\n') {
308 text.push('\n');
309 }
310 text.push_str(&format!("[timed out after {timeout}ms]"));
311 } else if let Some(sig) = signal_name.as_deref() {
312 if !text.is_empty() && !text.ends_with('\n') {
313 text.push('\n');
314 }
315 text.push_str(&format!("[killed by signal: {sig}]"));
316 } else if let Some(code) = exit_code
317 && code != 0
318 {
319 if !text.is_empty() && !text.ends_with('\n') {
320 text.push('\n');
321 }
322 text.push_str(&format!("[exit code: {code}]"));
323 }
324
325 let raw_output = serde_json::to_value(BashOutput {
326 exit_code,
327 signal: signal_name,
328 timed_out,
329 truncated_bytes,
330 duration_ms,
331 })
332 .unwrap_or(serde_json::Value::Null);
333
334 let mut fields = ToolCallUpdateFields::default();
335 fields.content = Some(vec![ToolCallContent::Content(Content::new(
336 ContentBlock::Text(TextContent::new(text)),
337 ))]);
338 fields.raw_output = Some(raw_output);
339 ToolEvent::Completed(fields)
340}
341
342fn resolve_workdir(session_cwd: &Path, requested: Option<&str>) -> Result<PathBuf, ToolError> {
344 let target = match requested {
345 None => session_cwd.to_path_buf(),
346 Some(s) => {
347 let p = Path::new(s);
348 if p.is_absolute() {
349 p.to_path_buf()
350 } else {
351 session_cwd.join(p)
352 }
353 }
354 };
355
356 let canon_target =
357 std::fs::canonicalize(&target).map_err(|e| ToolError::InvalidArgs(BoxError::new(e)))?;
358 let canon_cwd =
359 std::fs::canonicalize(session_cwd).unwrap_or_else(|_| session_cwd.to_path_buf());
360
361 if !canon_target.starts_with(&canon_cwd) {
362 return Err(ToolError::InvalidArgs(BoxError::new(io::Error::new(
363 io::ErrorKind::PermissionDenied,
364 format!(
365 "workdir {} escapes session cwd {}",
366 canon_target.display(),
367 canon_cwd.display()
368 ),
369 ))));
370 }
371
372 Ok(canon_target)
373}
374
375fn truncate_title(s: &str) -> String {
376 if s.chars().count() <= TITLE_TRUNC {
377 return s.to_string();
378 }
379 let truncated: String = s.chars().take(TITLE_TRUNC).collect();
380 format!("{truncated}…")
381}
382
// Unit tests for this module live in a separate `tests` module file.
#[cfg(test)]
mod tests;