// opi_coding_agent/tool/bash.rs

1use std::future::Future;
2use std::path::PathBuf;
3use std::pin::Pin;
4use std::time::Duration;
5
6use opi_agent::tool::{ExecutionMode, Tool, ToolError, ToolResult};
7use opi_ai::message::{OutputContent, ToolDef};
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tokio_util::sync::CancellationToken;
11
12#[derive(Debug, Deserialize, JsonSchema)]
13pub struct BashArgs {
14    /// Command to execute.
15    pub command: String,
16    /// Timeout in seconds (optional, defaults to 30).
17    pub timeout_secs: Option<u64>,
18}
19
20pub struct BashTool {
21    workspace_root: PathBuf,
22    schema: serde_json::Value,
23}
24
25impl BashTool {
26    pub fn new(workspace_root: PathBuf) -> Self {
27        let schema = schemars::schema_for!(BashArgs);
28        Self {
29            workspace_root,
30            schema: serde_json::to_value(&schema).unwrap_or_default(),
31        }
32    }
33}
34
35impl Tool for BashTool {
36    fn definition(&self) -> ToolDef {
37        ToolDef {
38            name: "bash".into(),
39            description: "Execute a shell command with timeout and streamed output.".into(),
40            input_schema: self.schema.clone(),
41        }
42    }
43
44    fn execute(
45        &self,
46        _call_id: &str,
47        arguments: serde_json::Value,
48        signal: CancellationToken,
49        _on_update: Option<opi_agent::tool::UpdateCallback>,
50    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
51        let args: BashArgs = match serde_json::from_value(arguments) {
52            Ok(a) => a,
53            Err(e) => {
54                return Box::pin(async move {
55                    Ok(ToolResult {
56                        content: vec![OutputContent::Text {
57                            text: format!("invalid arguments: {e}"),
58                        }],
59                        details: None,
60                        is_error: true,
61                        terminate: false,
62                    })
63                });
64            }
65        };
66        let timeout = Duration::from_secs(args.timeout_secs.unwrap_or(30));
67        let command = args.command;
68        let cwd = self.workspace_root.clone();
69        let workspace_root = self.workspace_root.clone();
70        Box::pin(async move {
71            let (program, args) = if cfg!(windows) {
72                ("cmd".to_string(), vec!["/C", &command])
73            } else {
74                ("sh".to_string(), vec!["-c", &command])
75            };
76
77            let mut cmd = tokio::process::Command::new(&program);
78            cmd.args(&args).current_dir(&cwd);
79            let child = cmd
80                .stdout(std::process::Stdio::piped())
81                .stderr(std::process::Stdio::piped())
82                .spawn();
83
84            let mut child = match child {
85                Ok(c) => c,
86                Err(e) => {
87                    return Ok(ToolResult {
88                        content: vec![OutputContent::Text {
89                            text: format!("failed to spawn command: {e}"),
90                        }],
91                        details: None,
92                        is_error: true,
93                        terminate: false,
94                    });
95                }
96            };
97
98            let timeout_future = tokio::time::sleep(timeout);
99            let cancel_future = signal.cancelled();
100
101            tokio::pin!(timeout_future);
102            tokio::pin!(cancel_future);
103
104            let result = tokio::select! {
105                status = child.wait() => {
106                    match status {
107                        Ok(s) => {
108                            let stdout = child.stdout.take();
109                            let stderr = child.stderr.take();
110                            let output_fut = async {
111                                let out = match stdout {
112                                    Some(mut s) => {
113                                        let mut buf = Vec::new();
114                                        use tokio::io::AsyncReadExt;
115                                        let _ = s.read_to_end(&mut buf).await;
116                                        String::from_utf8_lossy(&buf).into_owned()
117                                    }
118                                    None => String::new(),
119                                };
120                                let err = match stderr {
121                                    Some(mut s) => {
122                                        let mut buf = Vec::new();
123                                        use tokio::io::AsyncReadExt;
124                                        let _ = s.read_to_end(&mut buf).await;
125                                        String::from_utf8_lossy(&buf).into_owned()
126                                    }
127                                    None => String::new(),
128                                };
129                                (out, err)
130                            };
131                            let (stdout, stderr) = output_fut.await;
132                            Ok((stdout, stderr, s.code()))
133                        }
134                        Err(e) => Err(format!("failed to wait for process: {e}")),
135                    }
136                }
137                _ = &mut timeout_future => {
138                    let _ = child.kill().await;
139                    Err("command timed out".into())
140                }
141                _ = &mut cancel_future => {
142                    let _ = child.kill().await;
143                    Err("command cancelled".into())
144                }
145            };
146
147            match result {
148                Ok((stdout, stderr, exit_code)) => {
149                    let mut output = stdout;
150                    if !stderr.is_empty() {
151                        if !output.is_empty() {
152                            output.push('\n');
153                        }
154                        output.push_str(&stderr);
155                    }
156
157                    let is_error = exit_code != Some(0);
158                    let details = serde_json::json!({
159                        "command": command,
160                        "cwd": cwd.to_string_lossy(),
161                        "exit_code": exit_code,
162                        "workspace_root": workspace_root.to_string_lossy(),
163                    });
164
165                    Ok(ToolResult {
166                        content: vec![OutputContent::Text { text: output }],
167                        details: Some(details),
168                        is_error,
169                        terminate: false,
170                    })
171                }
172                Err(msg) => Ok(ToolResult {
173                    content: vec![OutputContent::Text { text: msg }],
174                    details: Some(serde_json::json!({
175                        "command": command,
176                        "cwd": cwd.to_string_lossy(),
177                        "timed_out": true,
178                    })),
179                    is_error: true,
180                    terminate: false,
181                }),
182            }
183        })
184    }
185
186    fn execution_mode(&self) -> ExecutionMode {
187        ExecutionMode::Sequential
188    }
189}