1use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Mutex;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::{Arc, OnceLock};
24
25use defect_agent::error::BoxError;
26use defect_agent::shell::{ShellBackend, ShellError, ShellOutput, TerminalExitStatus, TerminalId};
27use futures::future::BoxFuture;
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::process::{Child, Command};
30use tokio::sync::Notify;
31
/// Default cap on the captured output kept per terminal: 1 MiB (2^20 bytes).
pub const DEFAULT_MAX_OUTPUT_BYTES: usize = 1 << 20;
35
36pub struct LocalShellBackend {
39 terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>,
40 max_output_bytes: usize,
42}
43
44impl LocalShellBackend {
45 pub fn new() -> Self {
46 Self::with_max_output_bytes(DEFAULT_MAX_OUTPUT_BYTES)
47 }
48
49 pub fn with_max_output_bytes(max_output_bytes: usize) -> Self {
52 Self {
53 terminals: Mutex::new(HashMap::new()),
54 max_output_bytes: max_output_bytes.max(1),
55 }
56 }
57
58 fn lookup(&self, id: &TerminalId) -> Result<Arc<TerminalState>, ShellError> {
59 let guard = self
60 .terminals
61 .lock()
62 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
63 guard
64 .get(id)
65 .cloned()
66 .ok_or_else(|| ShellError::NotFound(id.clone()))
67 }
68}
69
70impl Default for LocalShellBackend {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76struct TerminalState {
79 output: Mutex<OutputBuffer>,
80 exit: Mutex<Option<TerminalExitStatus>>,
81 exit_notify: Notify,
82 kill_notify: Notify,
88}
89
90#[derive(Debug, thiserror::Error)]
91#[error("local shell backend mutex poisoned")]
92struct PoisonedTable;
93
94impl ShellBackend for LocalShellBackend {
95 fn create(
96 &self,
97 command: String,
98 cwd: PathBuf,
99 ) -> BoxFuture<'_, Result<TerminalId, ShellError>> {
100 Box::pin(async move {
101 let mut cmd = build_command(&command);
102 cmd.current_dir(&cwd)
103 .stdin(std::process::Stdio::null())
104 .stdout(std::process::Stdio::piped())
105 .stderr(std::process::Stdio::piped())
106 .kill_on_drop(true);
107
108 let mut child = cmd
109 .spawn()
110 .map_err(|err| ShellError::Backend(BoxError::new(err)))?;
111
112 let stdout = child.stdout.take().expect("piped stdout");
113 let stderr = child.stderr.take().expect("piped stderr");
114
115 let id = next_terminal_id();
116 let state = Arc::new(TerminalState {
117 output: Mutex::new(OutputBuffer::new(self.max_output_bytes)),
118 exit: Mutex::new(None),
119 exit_notify: Notify::new(),
120 kill_notify: Notify::new(),
121 });
122
123 {
124 let mut guard = self
125 .terminals
126 .lock()
127 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
128 guard.insert(id.clone(), state.clone());
129 }
130
131 tokio::spawn(reader_task(state, child, stdout, stderr));
132
133 Ok(id)
134 })
135 }
136
137 fn output(&self, id: &TerminalId) -> BoxFuture<'_, Result<ShellOutput, ShellError>> {
138 let id = id.clone();
139 Box::pin(async move {
140 let state = self.lookup(&id)?;
141 let (text, truncated) = {
142 let buf = state
143 .output
144 .lock()
145 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
146 (
147 String::from_utf8_lossy(buf.as_bytes()).into_owned(),
148 buf.truncated() > 0,
149 )
150 };
151 let exit_status = {
152 let exit = state
153 .exit
154 .lock()
155 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
156 exit.clone()
157 };
158 Ok(ShellOutput {
159 text,
160 truncated,
161 exit_status,
162 })
163 })
164 }
165
166 fn wait_for_exit(
167 &self,
168 id: &TerminalId,
169 ) -> BoxFuture<'_, Result<TerminalExitStatus, ShellError>> {
170 let id = id.clone();
171 Box::pin(async move {
172 let state = self.lookup(&id)?;
173 loop {
174 {
175 let exit = state
176 .exit
177 .lock()
178 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
179 if let Some(status) = exit.as_ref() {
180 return Ok(status.clone());
181 }
182 }
183 let notified = state.exit_notify.notified();
187 tokio::pin!(notified);
188 {
189 let exit = state
190 .exit
191 .lock()
192 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
193 if let Some(status) = exit.as_ref() {
194 return Ok(status.clone());
195 }
196 }
197 notified.await;
198 }
199 })
200 }
201
202 fn release(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
203 let id = id.clone();
204 Box::pin(async move {
205 let removed = {
206 let mut guard = self
207 .terminals
208 .lock()
209 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
210 guard.remove(&id)
211 };
212 if let Some(state) = removed {
216 state.kill_notify.notify_one();
217 }
218 Ok(())
219 })
220 }
221
222 fn kill(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
223 let id = id.clone();
224 Box::pin(async move {
225 let state = self.lookup(&id)?;
226 state.kill_notify.notify_one();
227 Ok(())
228 })
229 }
230}
231
232async fn reader_task(
233 state: Arc<TerminalState>,
234 mut child: Child,
235 stdout: tokio::process::ChildStdout,
236 stderr: tokio::process::ChildStderr,
237) {
238 let mut stdout_lines = BufReader::new(stdout).lines();
239 let mut stderr_lines = BufReader::new(stderr).lines();
240 let mut stdout_open = true;
241 let mut stderr_open = true;
242 let mut killed = false;
243
244 while stdout_open || stderr_open {
245 tokio::select! {
246 _ = state.kill_notify.notified(), if !killed => {
247 killed = true;
248 let _ = child.start_kill();
249 }
256 line = stdout_lines.next_line(), if stdout_open => {
257 match line {
258 Ok(Some(mut l)) => {
259 l.push('\n');
260 if let Ok(mut buf) = state.output.lock() {
261 buf.push(l.as_bytes());
262 }
263 }
264 _ => stdout_open = false,
265 }
266 }
267 line = stderr_lines.next_line(), if stderr_open => {
268 match line {
269 Ok(Some(mut l)) => {
270 l.push('\n');
271 if let Ok(mut buf) = state.output.lock() {
272 buf.push(l.as_bytes());
273 }
274 }
275 _ => stderr_open = false,
276 }
277 }
278 }
279 }
280 let _ = killed;
284
285 let wait_result = child.wait().await;
286 let status = decode_status(wait_result.ok().as_ref());
287 if let Ok(mut exit) = state.exit.lock() {
288 *exit = Some(status);
289 }
290 state.exit_notify.notify_waiters();
291}
292
293#[cfg(unix)]
294fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
295 use std::os::unix::process::ExitStatusExt;
296 match status {
297 None => TerminalExitStatus {
298 exit_code: None,
299 signal: None,
300 },
301 Some(s) => {
302 if let Some(code) = s.code() {
303 TerminalExitStatus {
304 exit_code: Some(code),
305 signal: None,
306 }
307 } else if let Some(sig) = s.signal() {
308 TerminalExitStatus {
309 exit_code: None,
310 signal: Some(signal_name(sig)),
311 }
312 } else {
313 TerminalExitStatus {
314 exit_code: None,
315 signal: None,
316 }
317 }
318 }
319 }
320}
321
/// Converts a reaped child's status into a [`TerminalExitStatus`] (Windows flavour).
///
/// Windows has no signals, so only `exit_code` is ever populated. It is `None` when the
/// status is missing (the wait itself failed).
#[cfg(windows)]
fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
    TerminalExitStatus {
        exit_code: status.and_then(|s| s.code()),
        signal: None,
    }
}
335
/// Maps a Unix signal number to its conventional `SIG*` name.
///
/// Only numbers that are identical on Linux, macOS and the BSDs are listed, including the
/// common crash signals (SIGILL, SIGTRAP, SIGFPE, SIGSEGV). Any other number renders as
/// `SIG#<n>`.
#[cfg(unix)]
fn signal_name(sig: i32) -> String {
    let name = match sig {
        1 => "SIGHUP",
        2 => "SIGINT",
        3 => "SIGQUIT",
        4 => "SIGILL",
        5 => "SIGTRAP",
        6 => "SIGABRT",
        8 => "SIGFPE",
        9 => "SIGKILL",
        11 => "SIGSEGV",
        13 => "SIGPIPE",
        14 => "SIGALRM",
        15 => "SIGTERM",
        other => return format!("SIG#{other}"),
    };
    name.to_owned()
}
350
351#[cfg(unix)]
352fn build_command(command: &str) -> Command {
353 let mut cmd = Command::new("/bin/sh");
354 cmd.arg("-c").arg(command);
355 cmd
356}
357
/// Builds a command that runs `command` through `cmd /C`.
///
/// NOTE(review): `args` applies MSVC-style argument quoting, which cmd.exe does not parse
/// the same way. Commands containing double quotes may be mangled; consider a raw argument
/// if that matters. Confirm on Windows.
#[cfg(windows)]
fn build_command(command: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", command]);
    shell
}
364
/// Append-only byte buffer with a hard size cap.
///
/// Bytes that do not fit under the cap are discarded; only how many were dropped is
/// remembered.
struct OutputBuffer {
    /// Upper bound on `bytes.len()`.
    max_bytes: usize,
    /// Retained output: the first `max_bytes` bytes pushed.
    bytes: Vec<u8>,
    /// Number of pushed bytes discarded because the cap was reached.
    truncated: u64,
}
372
373impl OutputBuffer {
374 fn new(max_bytes: usize) -> Self {
375 Self {
376 bytes: Vec::new(),
377 truncated: 0,
378 max_bytes,
379 }
380 }
381
382 fn push(&mut self, chunk: &[u8]) {
383 let remaining = self.max_bytes.saturating_sub(self.bytes.len());
384 if remaining == 0 {
385 self.truncated += chunk.len() as u64;
386 return;
387 }
388 if chunk.len() <= remaining {
389 self.bytes.extend_from_slice(chunk);
390 } else {
391 self.bytes
392 .extend_from_slice(chunk.get(..remaining).unwrap_or(chunk));
393 self.truncated += (chunk.len() - remaining) as u64;
394 }
395 }
396
397 fn as_bytes(&self) -> &[u8] {
398 &self.bytes
399 }
400
401 fn truncated(&self) -> u64 {
402 self.truncated
403 }
404}
405
406fn next_terminal_id() -> TerminalId {
409 static COUNTER: AtomicU64 = AtomicU64::new(0);
410 static PREFIX: OnceLock<String> = OnceLock::new();
411 let prefix = PREFIX.get_or_init(|| {
412 let ts = std::time::SystemTime::now()
413 .duration_since(std::time::UNIX_EPOCH)
414 .map(|d| d.as_nanos())
415 .unwrap_or(0);
416 format!("local-{ts:x}")
417 });
418 let n = COUNTER.fetch_add(1, Ordering::Relaxed);
419 TerminalId::new(format!("{prefix}-{n:x}"))
420}
421
// Unit tests: compiled only under `cfg(test)`, declared out-of-line in a separate
// `tests` module file.
#[cfg(test)]
mod tests;