Skip to main content

defect_tools/
shell.rs

1//! A [`ShellBackend`] implementation for local processes.
2//!
3//! Originates from the same inline `tokio::process::Command` flow historically used in
4//! the `bash` tool, but moves process management, buffered reads, and exit
5//! synchronization into the backend layer so that `BashTool` interacts only through the
6//! [`ShellBackend`] trait — a local shell execution backend.
7//!
8//! Internal data structures:
9//!
10//! - `LocalShellBackend.terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>`
11//!   Global terminal table.
12//! - `TerminalState` holds the output buffer, `exit` status, `exit_notify`, and
13//!   `kill_notify`.
14//! - Each terminal spawns a **reader task**: blocks reading stdout/stderr → writes into
15//!   buffer → waits on `kill_notify` or both EOFs → calls `child.wait()` → writes `exit`
16//!   → calls `notify_waiters()`. The child is exclusively owned by the reader task to
17//!   avoid lock contention.
18
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Mutex;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::{Arc, OnceLock};
24
25use defect_agent::error::BoxError;
26use defect_agent::shell::{ShellBackend, ShellError, ShellOutput, TerminalExitStatus, TerminalId};
27use futures::future::BoxFuture;
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::process::{Child, Command};
30use tokio::sync::Notify;
31
/// Upper bound, in bytes, on the output retained per terminal (1 MiB).
const MAX_OUTPUT_BYTES: usize = 1 << 20;
33
34/// Local shell backend: each command spawns a `sh -c` child process, with state managed
35/// in the `terminals` table until `release`.
36pub struct LocalShellBackend {
37    terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>,
38}
39
40impl LocalShellBackend {
41    pub fn new() -> Self {
42        Self {
43            terminals: Mutex::new(HashMap::new()),
44        }
45    }
46
47    fn lookup(&self, id: &TerminalId) -> Result<Arc<TerminalState>, ShellError> {
48        let guard = self
49            .terminals
50            .lock()
51            .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
52        guard
53            .get(id)
54            .cloned()
55            .ok_or_else(|| ShellError::NotFound(id.clone()))
56    }
57}
58
59impl Default for LocalShellBackend {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65/// Runtime state for a single terminal. The reader task and `output` / `wait_for_exit` /
66/// `kill` all share access via `Arc<TerminalState>`.
67struct TerminalState {
68    output: Mutex<OutputBuffer>,
69    exit: Mutex<Option<TerminalExitStatus>>,
70    exit_notify: Notify,
71    /// Set by `kill`; the reader task observes it in a `select` and calls
72    /// `Child::start_kill()`. Uses `notify_one()` to buffer a permit, preventing signal
73    /// loss when the reader task has not yet registered a waiter (`notify_waiters` only
74    /// wakes already-registered waiters). The reader task deduplicates via a `killed`
75    /// flag, so multiple kills are equivalent to one.
76    kill_notify: Notify,
77}
78
79#[derive(Debug, thiserror::Error)]
80#[error("local shell backend mutex poisoned")]
81struct PoisonedTable;
82
83impl ShellBackend for LocalShellBackend {
84    fn create(
85        &self,
86        command: String,
87        cwd: PathBuf,
88    ) -> BoxFuture<'_, Result<TerminalId, ShellError>> {
89        Box::pin(async move {
90            let mut cmd = build_command(&command);
91            cmd.current_dir(&cwd)
92                .stdin(std::process::Stdio::null())
93                .stdout(std::process::Stdio::piped())
94                .stderr(std::process::Stdio::piped())
95                .kill_on_drop(true);
96
97            let mut child = cmd
98                .spawn()
99                .map_err(|err| ShellError::Backend(BoxError::new(err)))?;
100
101            let stdout = child.stdout.take().expect("piped stdout");
102            let stderr = child.stderr.take().expect("piped stderr");
103
104            let id = next_terminal_id();
105            let state = Arc::new(TerminalState {
106                output: Mutex::new(OutputBuffer::new()),
107                exit: Mutex::new(None),
108                exit_notify: Notify::new(),
109                kill_notify: Notify::new(),
110            });
111
112            {
113                let mut guard = self
114                    .terminals
115                    .lock()
116                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
117                guard.insert(id.clone(), state.clone());
118            }
119
120            tokio::spawn(reader_task(state, child, stdout, stderr));
121
122            Ok(id)
123        })
124    }
125
126    fn output(&self, id: &TerminalId) -> BoxFuture<'_, Result<ShellOutput, ShellError>> {
127        let id = id.clone();
128        Box::pin(async move {
129            let state = self.lookup(&id)?;
130            let (text, truncated) = {
131                let buf = state
132                    .output
133                    .lock()
134                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
135                (
136                    String::from_utf8_lossy(buf.as_bytes()).into_owned(),
137                    buf.truncated() > 0,
138                )
139            };
140            let exit_status = {
141                let exit = state
142                    .exit
143                    .lock()
144                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
145                exit.clone()
146            };
147            Ok(ShellOutput {
148                text,
149                truncated,
150                exit_status,
151            })
152        })
153    }
154
155    fn wait_for_exit(
156        &self,
157        id: &TerminalId,
158    ) -> BoxFuture<'_, Result<TerminalExitStatus, ShellError>> {
159        let id = id.clone();
160        Box::pin(async move {
161            let state = self.lookup(&id)?;
162            loop {
163                {
164                    let exit = state
165                        .exit
166                        .lock()
167                        .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
168                    if let Some(status) = exit.as_ref() {
169                        return Ok(status.clone());
170                    }
171                }
172                // `notified()` only observes `notify_waiters` calls made **after** it is
173                // registered – so register first, then double-check for an already-set
174                // value to avoid a race.
175                let notified = state.exit_notify.notified();
176                tokio::pin!(notified);
177                {
178                    let exit = state
179                        .exit
180                        .lock()
181                        .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
182                    if let Some(status) = exit.as_ref() {
183                        return Ok(status.clone());
184                    }
185                }
186                notified.await;
187            }
188        })
189    }
190
191    fn release(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
192        let id = id.clone();
193        Box::pin(async move {
194            let removed = {
195                let mut guard = self
196                    .terminals
197                    .lock()
198                    .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
199                guard.remove(&id)
200            };
201            // Notify the reader task to wind down if it is still running. The `Child`
202            // held by the reader task will be dropped when the task exits, triggering the
203            // `kill_on_drop` fallback.
204            if let Some(state) = removed {
205                state.kill_notify.notify_one();
206            }
207            Ok(())
208        })
209    }
210
211    fn kill(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
212        let id = id.clone();
213        Box::pin(async move {
214            let state = self.lookup(&id)?;
215            state.kill_notify.notify_one();
216            Ok(())
217        })
218    }
219}
220
221async fn reader_task(
222    state: Arc<TerminalState>,
223    mut child: Child,
224    stdout: tokio::process::ChildStdout,
225    stderr: tokio::process::ChildStderr,
226) {
227    let mut stdout_lines = BufReader::new(stdout).lines();
228    let mut stderr_lines = BufReader::new(stderr).lines();
229    let mut stdout_open = true;
230    let mut stderr_open = true;
231    let mut killed = false;
232
233    while stdout_open || stderr_open {
234        tokio::select! {
235            _ = state.kill_notify.notified(), if !killed => {
236                killed = true;
237                let _ = child.start_kill();
238                // Continue draining: after `start_kill`, the child process receives
239                // SIGKILL, the pipe fds close, and both `next_line` calls will naturally
240                // return EOF. Note that commands like `sh -c "sleep N"` leave `sleep`
241                // alive because `sh` does not `exec` it; the caller is responsible for
242                // `exec`-ing the long-running part in the shell command (or accepting
243                // that `kill_on_drop` will handle it on release).
244            }
245            line = stdout_lines.next_line(), if stdout_open => {
246                match line {
247                    Ok(Some(mut l)) => {
248                        l.push('\n');
249                        if let Ok(mut buf) = state.output.lock() {
250                            buf.push(l.as_bytes());
251                        }
252                    }
253                    _ => stdout_open = false,
254                }
255            }
256            line = stderr_lines.next_line(), if stderr_open => {
257                match line {
258                    Ok(Some(mut l)) => {
259                        l.push('\n');
260                        if let Ok(mut buf) = state.output.lock() {
261                            buf.push(l.as_bytes());
262                        }
263                    }
264                    _ => stderr_open = false,
265                }
266            }
267        }
268    }
269    // When already killed, `killed` also means "terminated by external request" — the
270    // exit status from the `wait` below reflects the actual signal (SIGKILL/SIGTERM,
271    // etc.).
272    let _ = killed;
273
274    let wait_result = child.wait().await;
275    let status = decode_status(wait_result.ok().as_ref());
276    if let Ok(mut exit) = state.exit.lock() {
277        *exit = Some(status);
278    }
279    state.exit_notify.notify_waiters();
280}
281
282#[cfg(unix)]
283fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
284    use std::os::unix::process::ExitStatusExt;
285    match status {
286        None => TerminalExitStatus {
287            exit_code: None,
288            signal: None,
289        },
290        Some(s) => {
291            if let Some(code) = s.code() {
292                TerminalExitStatus {
293                    exit_code: Some(code),
294                    signal: None,
295                }
296            } else if let Some(sig) = s.signal() {
297                TerminalExitStatus {
298                    exit_code: None,
299                    signal: Some(signal_name(sig)),
300                }
301            } else {
302                TerminalExitStatus {
303                    exit_code: None,
304                    signal: None,
305                }
306            }
307        }
308    }
309}
310
/// Converts an optional process exit status into a [`TerminalExitStatus`] (Windows).
///
/// Only the exit code is reported; `signal` is always `None`.
#[cfg(windows)]
fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
    let exit_code = status.and_then(|s| s.code());
    TerminalExitStatus {
        exit_code,
        signal: None,
    }
}
324
/// Maps a signal number to its conventional name.
///
/// Numbers without an entry in the table are rendered as `SIG#<number>`.
#[cfg(unix)]
fn signal_name(sig: i32) -> String {
    const KNOWN: [(i32, &str); 8] = [
        (1, "SIGHUP"),
        (2, "SIGINT"),
        (3, "SIGQUIT"),
        (6, "SIGABRT"),
        (9, "SIGKILL"),
        (13, "SIGPIPE"),
        (14, "SIGALRM"),
        (15, "SIGTERM"),
    ];

    KNOWN
        .iter()
        .find(|(number, _)| *number == sig)
        .map_or_else(|| format!("SIG#{sig}"), |(_, name)| (*name).to_owned())
}
339
340#[cfg(unix)]
341fn build_command(command: &str) -> Command {
342    let mut cmd = Command::new("/bin/sh");
343    cmd.arg("-c").arg(command);
344    cmd
345}
346
/// Builds the Windows invocation for `command`: `cmd /C <command>`.
#[cfg(windows)]
fn build_command(command: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", command]);
    shell
}
353
354/// An append-only buffer with a 1 MiB cap. Excess bytes are dropped but counted in
355/// `truncated`.
356struct OutputBuffer {
357    bytes: Vec<u8>,
358    truncated: u64,
359}
360
361impl OutputBuffer {
362    fn new() -> Self {
363        Self {
364            bytes: Vec::new(),
365            truncated: 0,
366        }
367    }
368
369    fn push(&mut self, chunk: &[u8]) {
370        let remaining = MAX_OUTPUT_BYTES.saturating_sub(self.bytes.len());
371        if remaining == 0 {
372            self.truncated += chunk.len() as u64;
373            return;
374        }
375        if chunk.len() <= remaining {
376            self.bytes.extend_from_slice(chunk);
377        } else {
378            self.bytes
379                .extend_from_slice(chunk.get(..remaining).unwrap_or(chunk));
380            self.truncated += (chunk.len() - remaining) as u64;
381        }
382    }
383
384    fn as_bytes(&self) -> &[u8] {
385        &self.bytes
386    }
387
388    fn truncated(&self) -> u64 {
389        self.truncated
390    }
391}
392
393/// A monotonically increasing terminal ID generator. The prefix includes the nanos at
394/// process start to avoid conflicts with old IDs from future persistence scenarios.
395fn next_terminal_id() -> TerminalId {
396    static COUNTER: AtomicU64 = AtomicU64::new(0);
397    static PREFIX: OnceLock<String> = OnceLock::new();
398    let prefix = PREFIX.get_or_init(|| {
399        let ts = std::time::SystemTime::now()
400            .duration_since(std::time::UNIX_EPOCH)
401            .map(|d| d.as_nanos())
402            .unwrap_or(0);
403        format!("local-{ts:x}")
404    });
405    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
406    TerminalId::new(format!("{prefix}-{n:x}"))
407}
408
409#[cfg(test)]
410mod tests;