1use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Mutex;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::{Arc, OnceLock};
24
25use defect_agent::error::BoxError;
26use defect_agent::shell::{ShellBackend, ShellError, ShellOutput, TerminalExitStatus, TerminalId};
27use futures::future::BoxFuture;
28use tokio::io::{AsyncBufReadExt, BufReader};
29use tokio::process::{Child, Command};
30use tokio::sync::Notify;
31
/// Upper bound (1 MiB) on the captured output kept per terminal; further bytes are only counted.
const MAX_OUTPUT_BYTES: usize = 1 << 20;
33
34pub struct LocalShellBackend {
37 terminals: Mutex<HashMap<TerminalId, Arc<TerminalState>>>,
38}
39
40impl LocalShellBackend {
41 pub fn new() -> Self {
42 Self {
43 terminals: Mutex::new(HashMap::new()),
44 }
45 }
46
47 fn lookup(&self, id: &TerminalId) -> Result<Arc<TerminalState>, ShellError> {
48 let guard = self
49 .terminals
50 .lock()
51 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
52 guard
53 .get(id)
54 .cloned()
55 .ok_or_else(|| ShellError::NotFound(id.clone()))
56 }
57}
58
59impl Default for LocalShellBackend {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65struct TerminalState {
68 output: Mutex<OutputBuffer>,
69 exit: Mutex<Option<TerminalExitStatus>>,
70 exit_notify: Notify,
71 kill_notify: Notify,
77}
78
79#[derive(Debug, thiserror::Error)]
80#[error("local shell backend mutex poisoned")]
81struct PoisonedTable;
82
83impl ShellBackend for LocalShellBackend {
84 fn create(
85 &self,
86 command: String,
87 cwd: PathBuf,
88 ) -> BoxFuture<'_, Result<TerminalId, ShellError>> {
89 Box::pin(async move {
90 let mut cmd = build_command(&command);
91 cmd.current_dir(&cwd)
92 .stdin(std::process::Stdio::null())
93 .stdout(std::process::Stdio::piped())
94 .stderr(std::process::Stdio::piped())
95 .kill_on_drop(true);
96
97 let mut child = cmd
98 .spawn()
99 .map_err(|err| ShellError::Backend(BoxError::new(err)))?;
100
101 let stdout = child.stdout.take().expect("piped stdout");
102 let stderr = child.stderr.take().expect("piped stderr");
103
104 let id = next_terminal_id();
105 let state = Arc::new(TerminalState {
106 output: Mutex::new(OutputBuffer::new()),
107 exit: Mutex::new(None),
108 exit_notify: Notify::new(),
109 kill_notify: Notify::new(),
110 });
111
112 {
113 let mut guard = self
114 .terminals
115 .lock()
116 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
117 guard.insert(id.clone(), state.clone());
118 }
119
120 tokio::spawn(reader_task(state, child, stdout, stderr));
121
122 Ok(id)
123 })
124 }
125
126 fn output(&self, id: &TerminalId) -> BoxFuture<'_, Result<ShellOutput, ShellError>> {
127 let id = id.clone();
128 Box::pin(async move {
129 let state = self.lookup(&id)?;
130 let (text, truncated) = {
131 let buf = state
132 .output
133 .lock()
134 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
135 (
136 String::from_utf8_lossy(buf.as_bytes()).into_owned(),
137 buf.truncated() > 0,
138 )
139 };
140 let exit_status = {
141 let exit = state
142 .exit
143 .lock()
144 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
145 exit.clone()
146 };
147 Ok(ShellOutput {
148 text,
149 truncated,
150 exit_status,
151 })
152 })
153 }
154
155 fn wait_for_exit(
156 &self,
157 id: &TerminalId,
158 ) -> BoxFuture<'_, Result<TerminalExitStatus, ShellError>> {
159 let id = id.clone();
160 Box::pin(async move {
161 let state = self.lookup(&id)?;
162 loop {
163 {
164 let exit = state
165 .exit
166 .lock()
167 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
168 if let Some(status) = exit.as_ref() {
169 return Ok(status.clone());
170 }
171 }
172 let notified = state.exit_notify.notified();
176 tokio::pin!(notified);
177 {
178 let exit = state
179 .exit
180 .lock()
181 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
182 if let Some(status) = exit.as_ref() {
183 return Ok(status.clone());
184 }
185 }
186 notified.await;
187 }
188 })
189 }
190
191 fn release(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
192 let id = id.clone();
193 Box::pin(async move {
194 let removed = {
195 let mut guard = self
196 .terminals
197 .lock()
198 .map_err(|_| ShellError::Backend(BoxError::new(PoisonedTable)))?;
199 guard.remove(&id)
200 };
201 if let Some(state) = removed {
205 state.kill_notify.notify_one();
206 }
207 Ok(())
208 })
209 }
210
211 fn kill(&self, id: &TerminalId) -> BoxFuture<'_, Result<(), ShellError>> {
212 let id = id.clone();
213 Box::pin(async move {
214 let state = self.lookup(&id)?;
215 state.kill_notify.notify_one();
216 Ok(())
217 })
218 }
219}
220
221async fn reader_task(
222 state: Arc<TerminalState>,
223 mut child: Child,
224 stdout: tokio::process::ChildStdout,
225 stderr: tokio::process::ChildStderr,
226) {
227 let mut stdout_lines = BufReader::new(stdout).lines();
228 let mut stderr_lines = BufReader::new(stderr).lines();
229 let mut stdout_open = true;
230 let mut stderr_open = true;
231 let mut killed = false;
232
233 while stdout_open || stderr_open {
234 tokio::select! {
235 _ = state.kill_notify.notified(), if !killed => {
236 killed = true;
237 let _ = child.start_kill();
238 }
245 line = stdout_lines.next_line(), if stdout_open => {
246 match line {
247 Ok(Some(mut l)) => {
248 l.push('\n');
249 if let Ok(mut buf) = state.output.lock() {
250 buf.push(l.as_bytes());
251 }
252 }
253 _ => stdout_open = false,
254 }
255 }
256 line = stderr_lines.next_line(), if stderr_open => {
257 match line {
258 Ok(Some(mut l)) => {
259 l.push('\n');
260 if let Ok(mut buf) = state.output.lock() {
261 buf.push(l.as_bytes());
262 }
263 }
264 _ => stderr_open = false,
265 }
266 }
267 }
268 }
269 let _ = killed;
273
274 let wait_result = child.wait().await;
275 let status = decode_status(wait_result.ok().as_ref());
276 if let Ok(mut exit) = state.exit.lock() {
277 *exit = Some(status);
278 }
279 state.exit_notify.notify_waiters();
280}
281
282#[cfg(unix)]
283fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
284 use std::os::unix::process::ExitStatusExt;
285 match status {
286 None => TerminalExitStatus {
287 exit_code: None,
288 signal: None,
289 },
290 Some(s) => {
291 if let Some(code) = s.code() {
292 TerminalExitStatus {
293 exit_code: Some(code),
294 signal: None,
295 }
296 } else if let Some(sig) = s.signal() {
297 TerminalExitStatus {
298 exit_code: None,
299 signal: Some(signal_name(sig)),
300 }
301 } else {
302 TerminalExitStatus {
303 exit_code: None,
304 signal: None,
305 }
306 }
307 }
308 }
309}
310
/// Converts the result of waiting on the child into a [`TerminalExitStatus`].
///
/// Windows has no POSIX signals, so only the exit code (if any) is reported.
#[cfg(windows)]
fn decode_status(status: Option<&std::process::ExitStatus>) -> TerminalExitStatus {
    TerminalExitStatus {
        exit_code: status.and_then(|s| s.code()),
        signal: None,
    }
}
324
/// Maps a Unix signal number to its conventional name (e.g. `9` -> `"SIGKILL"`).
///
/// Only signals whose numbers are identical on Linux, macOS and the BSDs are
/// named. This covers the usual crash signals (`SIGILL`, `SIGTRAP`, `SIGFPE`,
/// `SIGSEGV`). Platform-dependent numbers such as `SIGBUS` or `SIGUSR1`, and
/// anything else, are rendered as `SIG#<n>`.
#[cfg(unix)]
fn signal_name(sig: i32) -> String {
    let name = match sig {
        1 => "SIGHUP",
        2 => "SIGINT",
        3 => "SIGQUIT",
        4 => "SIGILL",
        5 => "SIGTRAP",
        6 => "SIGABRT",
        8 => "SIGFPE",
        9 => "SIGKILL",
        11 => "SIGSEGV",
        13 => "SIGPIPE",
        14 => "SIGALRM",
        15 => "SIGTERM",
        other => return format!("SIG#{other}"),
    };
    name.to_owned()
}
339
340#[cfg(unix)]
341fn build_command(command: &str) -> Command {
342 let mut cmd = Command::new("/bin/sh");
343 cmd.arg("-c").arg(command);
344 cmd
345}
346
/// Builds a command that runs `command` through `cmd /C`.
///
/// NOTE(review): `args` quotes `command` using MSVC argv rules (`\"`), which
/// cmd.exe does not understand, so commands containing double quotes may be
/// mangled. Consider passing it as a raw argument instead; confirm against the
/// tokio version in use.
#[cfg(windows)]
fn build_command(command: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", command]);
    shell
}
353
354struct OutputBuffer {
357 bytes: Vec<u8>,
358 truncated: u64,
359}
360
361impl OutputBuffer {
362 fn new() -> Self {
363 Self {
364 bytes: Vec::new(),
365 truncated: 0,
366 }
367 }
368
369 fn push(&mut self, chunk: &[u8]) {
370 let remaining = MAX_OUTPUT_BYTES.saturating_sub(self.bytes.len());
371 if remaining == 0 {
372 self.truncated += chunk.len() as u64;
373 return;
374 }
375 if chunk.len() <= remaining {
376 self.bytes.extend_from_slice(chunk);
377 } else {
378 self.bytes
379 .extend_from_slice(chunk.get(..remaining).unwrap_or(chunk));
380 self.truncated += (chunk.len() - remaining) as u64;
381 }
382 }
383
384 fn as_bytes(&self) -> &[u8] {
385 &self.bytes
386 }
387
388 fn truncated(&self) -> u64 {
389 self.truncated
390 }
391}
392
393fn next_terminal_id() -> TerminalId {
396 static COUNTER: AtomicU64 = AtomicU64::new(0);
397 static PREFIX: OnceLock<String> = OnceLock::new();
398 let prefix = PREFIX.get_or_init(|| {
399 let ts = std::time::SystemTime::now()
400 .duration_since(std::time::UNIX_EPOCH)
401 .map(|d| d.as_nanos())
402 .unwrap_or(0);
403 format!("local-{ts:x}")
404 });
405 let n = COUNTER.fetch_add(1, Ordering::Relaxed);
406 TerminalId::new(format!("{prefix}-{n:x}"))
407}
408
409#[cfg(test)]
410mod tests;