Skip to main content

defect_tools/
shell.rs

1//! A [`ShellBackend`] implementation for local processes.
2//!
3//! Originates from the same inline `tokio::process::Command` flow historically used in
4//! the `bash` tool, but moves process management, buffered reads, and exit
5//! synchronization into the backend layer so that `BashTool` interacts only through the
6//! [`ShellBackend`] trait — a local shell execution backend.
7//!
8//! Internal data structures:
9//!
10//! - `LocalShellBackend.terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>`
11//!   Global terminal table.
12//! - `TerminalState` holds the output buffer, `exit` status, `exit_notify`, and
13//!   `kill_notify`.
14//! - Each terminal spawns a **reader task**: blocks reading stdout/stderr → writes into
15//!   buffer → waits on `kill_notify` or both EOFs → calls `child.wait()` → writes `exit`
16//!   → calls `notify_waiters()`. The child is exclusively owned by the reader task to
17//!   avoid lock contention.
18
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Mutex;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::{Arc, OnceLock};
24
25use defect_agent::error::BoxError;
26use defect_agent::shell::{ShellBackend, ShellError, ShellOutput, TerminalExitStatus, TerminalId};
27use futures::future::BoxFuture;
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::process::{Child, Command};
30use tokio::sync::Notify;
31
32/// Default per-terminal captured-output cap (1 MiB). Output beyond this is dropped and
33/// counted as `truncated`.
34pub const DEFAULT_MAX_OUTPUT_BYTES: usize = 1024 * 1024;
35
36/// Local shell backend: each command spawns a `sh -c` child process, with state managed
37/// in the `terminals` table until `release`.
38pub struct LocalShellBackend {
39    terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>,
40    /// Maximum bytes of merged stdout/stderr captured per terminal; excess is truncated.
41    max_output_bytes: usize,
42}
43
44impl LocalShellBackend {
45    pub fn new() -> Self {
46        Self::with_max_output_bytes(DEFAULT_MAX_OUTPUT_BYTES)
47    }
48
49    /// Constructs a backend with an explicit captured-output cap. `0` is clamped to `1` so
50    /// at least one byte can always be captured.
51    pub fn with_max_output_bytes(max_output_bytes: usize) -> Self {
52        Self {
53            terminals: Mutex::new(HashMap::new()),
54            max_output_bytes: max_output_bytes.max(1),
55        }
56    }
57
58    fn lookup(&self, id: &TerminalId) -> Result<Arc<TerminalState>, ShellError> {
59        let guard = self
60            .terminals
61            .lock()
62            .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
63        guard
64            .get(id)
65            .cloned()
66            .ok_or_else(|| ShellError::NotFound(id.clone()))
67    }
68}
69
70impl Default for LocalShellBackend {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76/// Runtime state for a single terminal. The reader task and `output` / `wait_for_exit` /
77/// `kill` all share access via `Arc<TerminalState>`.
78struct TerminalState {
79    output: Mutex<OutputBuffer>,
80    exit: Mutex<Option<TerminalExitStatus>>,
81    exit_notify: Notify,
82    /// Set by `kill`; the reader task observes it in a `select` and calls
83    /// `Child::start_kill()`. Uses `notify_one()` to buffer a permit, preventing signal
84    /// loss when the reader task has not yet registered a waiter (`notify_waiters` only
85    /// wakes already-registered waiters). The reader task deduplicates via a `killed`
86    /// flag, so multiple kills are equivalent to one.
87    kill_notify: Notify,
88}
89
90#[derive(Debug, thiserror::Error)]
91#[error("local shell backend mutex poisoned")]
92struct PoisonedTable;
93
94impl ShellBackend for LocalShellBackend {
95    fn create(
96        &self,
97        command: String,
98        cwd: PathBuf,
99    ) -> BoxFuture<'_, Result<TerminalId, ShellError>> {
100        Box::pin(async move {
101            let mut cmd = build_command(&command);
102            cmd.current_dir(&cwd)
103                .stdin(std::process::Stdio::null())
104                .stdout(std::process::Stdio::piped())
105                .stderr(std::process::Stdio::piped())
106                .kill_on_drop(true);
107
108            let mut child = cmd
109                .spawn()
110                .map_err(|err| ShellError::Backend(BoxError::new(err)))?;
111
112            let stdout = child.stdout.take().expect("piped stdout");
113            let stderr = child.stderr.take().expect("piped stderr");
114
115            let id = next_terminal_id();
116            let state = Arc::new(TerminalState {
117                output: Mutex::new(OutputBuffer::new(self.max_output_bytes)),
118                exit: Mutex::new(None),
119                exit_notify: Notify::new(),
120                kill_notify: Notify::new(),
121            });
122
123            {
124                let mut guard = self
125                    .terminals
126                    .lock()
127                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
128                guard.insert(id.clone(), state.clone());
129            }
130
131            tokio::spawn(reader_task(state, child, stdout, stderr));
132
133            Ok(id)
134        })
135    }
136
137    fn output(&self, id: &TerminalId) -> BoxFuture<'_, Result<ShellOutput, ShellError>> {
138        let id = id.clone();
139        Box::pin(async move {
140            let state = self.lookup(&id)?;
141            let (text, truncated) = {
142                let buf = state
143                    .output
144                    .lock()
145                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
146                (
147                    String::from_utf8_lossy(buf.as_bytes()).into_owned(),
148                    buf.truncated() > 0,
149                )
150            };
151            let exit_status = {
152                let exit = state
153                    .exit
154                    .lock()
155                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
156                exit.clone()
157            };
158            Ok(ShellOutput {
159                text,
160                truncated,
161                exit_status,
162            })
163        })
164    }
165
166    fn wait_for_exit(
167        &self,
168        id: &TerminalId,
169    ) -> BoxFuture<'_, Result<TerminalExitStatus, ShellError>> {
170        let id = id.clone();
171        Box::pin(async move {
172            let state = self.lookup(&id)?;
173            loop {
174                {
175                    let exit = state
176                        .exit
177                        .lock()
178                        .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
179                    if let Some(status) = exit.as_ref() {
180                        return Ok(status.clone());
181                    }
182                }
183                // `notified()` only observes `notify_waiters` calls made **after** it is
184                // registered – so register first, then double-check for an already-set
185                // value to avoid a race.
186                let notified = state.exit_notify.notified();
187                tokio::pin!(notified);
188                {
189                    let exit = state
190                        .exit
191                        .lock()
192                        .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
193                    if let Some(status) = exit.as_ref() {
194                        return Ok(status.clone());
195                    }
196                }
197                notified.await;
198            }
199        })
200    }
201
202    fn release(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
203        let id = id.clone();
204        Box::pin(async move {
205            let removed = {
206                let mut guard = self
207                    .terminals
208                    .lock()
209                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
210                guard.remove(&id)
211            };
212            // Notify the reader task to wind down if it is still running. The `Child`
213            // held by the reader task will be dropped when the task exits, triggering the
214            // `kill_on_drop` fallback.
215            if let Some(state) = removed {
216                state.kill_notify.notify_one();
217            }
218            Ok(())
219        })
220    }
221
222    fn kill(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
223        let id = id.clone();
224        Box::pin(async move {
225            let state = self.lookup(&id)?;
226            state.kill_notify.notify_one();
227            Ok(())
228        })
229    }
230}
231
232async fn reader_task(
233    state: Arc<TerminalState>,
234    mut child: Child,
235    stdout: tokio::process::ChildStdout,
236    stderr: tokio::process::ChildStderr,
237) {
238    let mut stdout_lines = BufReader::new(stdout).lines();
239    let mut stderr_lines = BufReader::new(stderr).lines();
240    let mut stdout_open = true;
241    let mut stderr_open = true;
242    let mut killed = false;
243
244    while stdout_open || stderr_open {
245        tokio::select! {
246            _ = state.kill_notify.notified(), if !killed => {
247                killed = true;
248                let _ = child.start_kill();
249                // Continue draining: after `start_kill`, the child process receives
250                // SIGKILL, the pipe fds close, and both `next_line` calls will naturally
251                // return EOF. Note that commands like `sh -c "sleep N"` leave `sleep`
252                // alive because `sh` does not `exec` it; the caller is responsible for
253                // `exec`-ing the long-running part in the shell command (or accepting
254                // that `kill_on_drop` will handle it on release).
255            }
256            line = stdout_lines.next_line(), if stdout_open => {
257                match line {
258                    Ok(Some(mut l)) => {
259                        l.push('\n');
260                        if let Ok(mut buf) = state.output.lock() {
261                            buf.push(l.as_bytes());
262                        }
263                    }
264                    _ => stdout_open = false,
265                }
266            }
267            line = stderr_lines.next_line(), if stderr_open => {
268                match line {
269                    Ok(Some(mut l)) => {
270                        l.push('\n');
271                        if let Ok(mut buf) = state.output.lock() {
272                            buf.push(l.as_bytes());
273                        }
274                    }
275                    _ => stderr_open = false,
276                }
277            }
278        }
279    }
280    // When already killed, `killed` also means "terminated by external request" — the
281    // exit status from the `wait` below reflects the actual signal (SIGKILL/SIGTERM,
282    // etc.).
283    let _ = killed;
284
285    let wait_result = child.wait().await;
286    let status = decode_status(wait_result.ok().as_ref());
287    if let Ok(mut exit) = state.exit.lock() {
288        *exit = Some(status);
289    }
290    state.exit_notify.notify_waiters();
291}
292
293#[cfg(unix)]
294fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
295    use std::os::unix::process::ExitStatusExt;
296    match status {
297        None => TerminalExitStatus {
298            exit_code: None,
299            signal: None,
300        },
301        Some(s) => {
302            if let Some(code) = s.code() {
303                TerminalExitStatus {
304                    exit_code: Some(code),
305                    signal: None,
306                }
307            } else if let Some(sig) = s.signal() {
308                TerminalExitStatus {
309                    exit_code: None,
310                    signal: Some(signal_name(sig)),
311                }
312            } else {
313                TerminalExitStatus {
314                    exit_code: None,
315                    signal: None,
316                }
317            }
318        }
319    }
320}
321
/// Translates the result of waiting on the child into a [`TerminalExitStatus`]
/// (Windows). No signal is ever reported; `exit_code` is `None` when the wait failed or
/// the status carries no code.
#[cfg(windows)]
fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
    let exit_code = status.and_then(|s| s.code());
    TerminalExitStatus {
        exit_code,
        signal: None,
    }
}
335
/// Returns the conventional name for a Unix signal number.
///
/// Only signals whose numbers are the same on Linux and the BSD family (including
/// macOS) are named; this now includes the common crash signals `SIGILL`, `SIGTRAP`,
/// `SIGFPE` and `SIGSEGV`, which previously fell through to the opaque form. Signals
/// whose numbers differ between platforms (e.g. `SIGBUS`, `SIGUSR1`) and any unknown
/// number are rendered as `SIG#<n>`.
#[cfg(unix)]
fn signal_name(sig: i32) -> String {
    let known = match sig {
        1 => "SIGHUP",
        2 => "SIGINT",
        3 => "SIGQUIT",
        4 => "SIGILL",
        5 => "SIGTRAP",
        6 => "SIGABRT",
        8 => "SIGFPE",
        9 => "SIGKILL",
        11 => "SIGSEGV",
        13 => "SIGPIPE",
        14 => "SIGALRM",
        15 => "SIGTERM",
        other => return format!("SIG#{other}"),
    };
    known.to_owned()
}
350
351#[cfg(unix)]
352fn build_command(command: &str) -> Command {
353    let mut cmd = Command::new("/bin/sh");
354    cmd.arg("-c").arg(command);
355    cmd
356}
357
/// Builds the process invocation for `command` on Windows: `cmd /C <command>`.
#[cfg(windows)]
fn build_command(command: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", command]);
    shell
}
364
/// Append-only byte buffer with a fixed capacity chosen at construction time.
///
/// Bytes that do not fit are discarded; the number of discarded bytes is accumulated
/// in `truncated`.
struct OutputBuffer {
    /// Bytes retained so far; never longer than `max_bytes`.
    bytes: Vec<u8>,
    /// Total number of bytes dropped because the buffer was full.
    truncated: u64,
    /// Maximum number of bytes to retain.
    max_bytes: usize,
}

impl OutputBuffer {
    /// Creates an empty buffer that retains at most `max_bytes` bytes.
    fn new(max_bytes: usize) -> Self {
        Self {
            bytes: Vec::new(),
            truncated: 0,
            max_bytes,
        }
    }

    /// Appends as much of `chunk` as still fits and counts the remainder as truncated.
    fn push(&mut self, chunk: &[u8]) {
        let room = self.max_bytes.saturating_sub(self.bytes.len());
        let kept = room.min(chunk.len());
        let (head, tail) = chunk.split_at(kept);
        self.bytes.extend_from_slice(head);
        self.truncated += tail.len() as u64;
    }

    /// The bytes retained so far.
    fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }

    /// Number of bytes dropped so far.
    fn truncated(&self) -> u64 {
        self.truncated
    }
}
405
406/// A monotonically increasing terminal ID generator. The prefix includes the nanos at
407/// process start to avoid conflicts with old IDs from future persistence scenarios.
408fn next_terminal_id() -> TerminalId {
409    static COUNTER: AtomicU64 = AtomicU64::new(0);
410    static PREFIX: OnceLock<String> = OnceLock::new();
411    let prefix = PREFIX.get_or_init(|| {
412        let ts = std::time::SystemTime::now()
413            .duration_since(std::time::UNIX_EPOCH)
414            .map(|d| d.as_nanos())
415            .unwrap_or(0);
416        format!("local-{ts:x}")
417    });
418    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
419    TerminalId::new(format!("{prefix}-{n:x}"))
420}
421
422#[cfg(test)]
423mod tests;