// aft/bash_background/pty_process.rs
1use std::collections::HashMap;
2use std::fs::{self, OpenOptions};
3use std::io::{self, Read, Write};
4#[cfg(unix)]
5use std::os::unix::fs::PermissionsExt;
6use std::path::Path;
7#[cfg(unix)]
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Mutex};
11use std::thread;
12
13use portable_pty::{CommandBuilder, PtySize};
14
15use super::persistence::{atomic_write, ExitMarker, TaskPaths};
16use super::pty_runtime::{CompletionCoordinator, PtyRuntime};
17
18#[allow(clippy::too_many_arguments)]
19pub(crate) fn spawn_pty_for_command(
20    task_id: &str,
21    session_id: &str,
22    user_command: &str,
23    paths: &TaskPaths,
24    workdir: &Path,
25    env: &HashMap<String, String>,
26    rows: u16,
27    cols: u16,
28    wake_tx: crossbeam_channel::Sender<()>,
29) -> Result<PtyRuntime, String> {
30    #[cfg(unix)]
31    {
32        let shell = resolve_posix_shell();
33        let mut command = CommandBuilder::new(shell.as_os_str());
34        command.arg("-c");
35        command.arg(user_command);
36        command.cwd(workdir.as_os_str());
37        for (key, value) in env {
38            command.env(key, value);
39        }
40        try_spawn_pty(task_id, session_id, command, paths, rows, cols, wake_tx)
41    }
42    #[cfg(windows)]
43    {
44        use crate::windows_shell::shell_candidates;
45
46        let candidates = shell_candidates();
47        let mut last_err = String::from("no Windows shell candidates available");
48
49        for shell in candidates {
50            let wrapper_body = shell.wrapper_script_bytes(user_command, &paths.exit);
51            let wrapper_path = windows_wrapper_path(paths, &shell);
52            if let Err(error) = fs::write(&wrapper_path, wrapper_body) {
53                last_err = format!("write wrapper {wrapper_path:?}: {error}");
54                continue;
55            }
56
57            let mut command = CommandBuilder::new(shell.binary().as_ref());
58            for arg in shell.pty_wrapper_args(&wrapper_path) {
59                command.arg(arg);
60            }
61            command.cwd(workdir.as_os_str());
62            for (key, value) in env {
63                command.env(key, value);
64            }
65
66            match try_spawn_pty(
67                task_id,
68                session_id,
69                command,
70                paths,
71                rows,
72                cols,
73                wake_tx.clone(),
74            ) {
75                Ok(runtime) => return Ok(runtime),
76                Err(error) => {
77                    let msg = format!("{shell:?}: {error}");
78                    if msg.contains("NotFound") || msg.contains("not recognized") {
79                        last_err = msg;
80                        continue;
81                    }
82                    return Err(msg);
83                }
84            }
85        }
86
87        Err(last_err)
88    }
89}
90
91#[cfg(unix)]
92fn resolve_posix_shell() -> PathBuf {
93    resolve_posix_shell_with(
94        || std::env::var_os("SHELL").map(PathBuf::from),
95        is_executable_file,
96    )
97}
98
/// Resolves the shell path from injectable sources, so the policy can be tested
/// without touching the real environment or filesystem.
///
/// Order of preference:
/// 1. the path returned by `shell_env`, if it is non-empty and `is_executable`;
/// 2. the first of `/bin/bash`, `/bin/sh`, `/bin/zsh` that `is_executable`;
/// 3. `/bin/sh`, unconditionally, as a last resort.
#[cfg(unix)]
fn resolve_posix_shell_with<S, X>(shell_env: S, is_executable: X) -> PathBuf
where
    S: FnOnce() -> Option<PathBuf>,
    X: Fn(&Path) -> bool,
{
    const FALLBACK_SHELLS: [&str; 3] = ["/bin/bash", "/bin/sh", "/bin/zsh"];

    // An empty `$SHELL` is rejected before the executability probe runs.
    let from_env = shell_env()
        .filter(|candidate| !candidate.as_os_str().is_empty())
        .filter(|candidate| is_executable(candidate.as_path()));

    from_env
        .or_else(|| {
            FALLBACK_SHELLS
                .iter()
                .copied()
                .map(PathBuf::from)
                .find(|candidate| is_executable(candidate.as_path()))
        })
        .unwrap_or_else(|| PathBuf::from("/bin/sh"))
}

/// Returns `true` when `path` resolves (following symlinks) to a regular file
/// with at least one execute permission bit set (owner, group, or other).
///
/// Any metadata error, such as a missing path, yields `false`.
#[cfg(unix)]
fn is_executable_file(path: &Path) -> bool {
    const ANY_EXECUTE_BIT: u32 = 0o111;

    match fs::metadata(path) {
        Ok(meta) if meta.is_file() => meta.permissions().mode() & ANY_EXECUTE_BIT != 0,
        _ => false,
    }
}

/// Builds the path of the wrapper script written for `shell`.
///
/// The path is `<stem>.<ext>` inside the task directory `paths.dir`:
/// * `<stem>` is the file stem of `paths.json`, or `wrapper` when that stem is
///   missing or not valid UTF-8;
/// * `<ext>` is `ps1` for PowerShell variants, `bat` for cmd, and `sh` for
///   POSIX shells.
#[cfg(windows)]
fn windows_wrapper_path(
    paths: &TaskPaths,
    shell: &crate::windows_shell::WindowsShell,
) -> std::path::PathBuf {
    use crate::windows_shell::WindowsShell;

    let stem = paths
        .json
        .file_stem()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or("wrapper");
    let extension = match shell {
        WindowsShell::Pwsh | WindowsShell::Powershell => "ps1",
        WindowsShell::Cmd => "bat",
        WindowsShell::Posix(_) => "sh",
    };
    paths.dir.join(format!("{stem}.{extension}"))
}

147#[allow(clippy::too_many_arguments)]
148fn try_spawn_pty(
149    task_id: &str,
150    session_id: &str,
151    command: CommandBuilder,
152    paths: &TaskPaths,
153    rows: u16,
154    cols: u16,
155    wake_tx: crossbeam_channel::Sender<()>,
156) -> Result<PtyRuntime, String> {
157    let pty_system = portable_pty::native_pty_system();
158    let pair = pty_system
159        .openpty(PtySize {
160            rows,
161            cols,
162            pixel_width: 0,
163            pixel_height: 0,
164        })
165        .map_err(|error| format!("open PTY failed: {error}"))?;
166    let child = pair
167        .slave
168        .spawn_command(command)
169        .map_err(|error| format!("spawn PTY command failed: {error}"))?;
170    let child_pid = child.process_id();
171    let killer = child.clone_killer();
172    let reader = pair
173        .master
174        .try_clone_reader()
175        .map_err(|error| format!("clone PTY reader failed: {error}"))?;
176    let writer = pair
177        .master
178        .take_writer()
179        .map_err(|error| format!("take PTY writer failed: {error}"))?;
180
181    let reader_done = Arc::new(AtomicBool::new(false));
182    let exit_observed = Arc::new(AtomicBool::new(false));
183    let was_killed = Arc::new(AtomicBool::new(false));
184    let coordinator = Arc::new(CompletionCoordinator::new(
185        task_id.to_string(),
186        session_id.to_string(),
187        wake_tx,
188    ));
189
190    let writer = Arc::new(Mutex::new(writer));
191    spawn_reader(
192        reader,
193        paths.pty.clone(),
194        Arc::clone(&reader_done),
195        Arc::clone(&coordinator),
196        Some(Arc::clone(&writer)),
197    );
198    spawn_waiter(
199        child,
200        paths.exit.clone(),
201        Arc::clone(&was_killed),
202        Arc::clone(&exit_observed),
203        Arc::clone(&coordinator),
204    );
205
206    Ok(PtyRuntime {
207        master: Some(pair.master),
208        writer,
209        killer,
210        child_pid,
211        reader_done,
212        exit_observed,
213        was_killed,
214        coordinator,
215    })
216}
217
/// Number of trailing stream bytes that `DsrScanner` carries from one read to
/// the next.
///
/// The DSR cursor-position query `ESC [ 6 n` is 4 bytes long. A query that
/// straddles a read boundary therefore leaves at most 3 of its bytes before the
/// boundary, and carrying that many bytes forward is enough to catch it.
const DSR_CARRY_OVER: usize = b"\x1b[6n".len() - 1;

223pub(crate) fn spawn_reader(
224    mut reader: Box<dyn Read + Send>,
225    spill_path: std::path::PathBuf,
226    reader_done: Arc<AtomicBool>,
227    coordinator: Arc<CompletionCoordinator>,
228    writer: Option<Arc<Mutex<Box<dyn Write + Send>>>>,
229) {
230    thread::spawn(move || {
231        let result = (|| -> io::Result<()> {
232            if let Some(parent) = spill_path.parent() {
233                fs::create_dir_all(parent)?;
234            }
235            let mut file = OpenOptions::new()
236                .create(true)
237                .append(true)
238                .open(&spill_path)?;
239            let mut buf = [0_u8; 8192];
240            let mut dsr = DsrScanner::default();
241            loop {
242                match reader.read(&mut buf) {
243                    Ok(0) => break,
244                    Ok(n) => {
245                        file.write_all(&buf[..n])?;
246                        file.flush()?;
247                        if dsr.scan(&buf[..n]) {
248                            // Some Windows console hosts/apps query the
249                            // terminal cursor position with DSR (ESC[6n)
250                            // before accepting input. A real terminal answers
251                            // with ESC[row;colR; without that response the
252                            // process can sit forever after emitting only the
253                            // query. We own both ends of the PTY, so provide a
254                            // conservative 1;1 response.
255                            if let Some(writer) = writer.as_ref() {
256                                if let Ok(mut writer) = writer.lock() {
257                                    let _ = writer.write_all(b"\x1b[1;1R");
258                                    let _ = writer.flush();
259                                }
260                            }
261                        }
262                    }
263                    Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
264                    Err(error) => return Err(error),
265                }
266            }
267            Ok(())
268        })();
269        if let Err(error) = result {
270            crate::slog_warn!(
271                "PTY reader for {}:{} stopped with error: {error}",
272                coordinator.session_id,
273                coordinator.task_id
274            );
275        }
276        reader_done.store(true, Ordering::SeqCst);
277        coordinator.signal_one_done();
278    });
279}
280
281/// Detects the DSR cursor-position query `\x1b[6n` (4 bytes) in a byte stream
282/// delivered in arbitrary chunks. A `read()` may return as little as one byte,
283/// so the 4-byte needle can split across ANY number of reads. We keep a rolling
284/// carry of the last 3 bytes seen and prepend it to each new chunk before
285/// scanning. Crucially the carry is taken from the COMBINED (carry + chunk)
286/// buffer, not just the chunk's own tail — that is what lets detection survive
287/// more than two reads (e.g. `\x1b`, `[`, `6`, `n` arriving as four reads).
288#[derive(Default)]
289struct DsrScanner {
290    carry: Vec<u8>,
291}
292
293impl DsrScanner {
294    fn scan(&mut self, chunk: &[u8]) -> bool {
295        let mut combined = Vec::with_capacity(self.carry.len() + chunk.len());
296        combined.extend_from_slice(&self.carry);
297        combined.extend_from_slice(chunk);
298        let detected = combined.windows(4).any(|w| w == b"\x1b[6n");
299        // Carry the last 3 bytes of the combined stream forward: a 4-byte
300        // needle straddling this boundary keeps at most 3 bytes on this side.
301        let start = combined.len().saturating_sub(DSR_CARRY_OVER);
302        self.carry.clear();
303        self.carry.extend_from_slice(&combined[start..]);
304        detected
305    }
306}
307
308pub(crate) fn spawn_waiter(
309    mut child: Box<dyn portable_pty::Child + Send + Sync>,
310    exit_path: std::path::PathBuf,
311    was_killed: Arc<AtomicBool>,
312    exit_observed: Arc<AtomicBool>,
313    coordinator: Arc<CompletionCoordinator>,
314) {
315    thread::spawn(move || {
316        let marker = loop {
317            match child.wait() {
318                Ok(status) => {
319                    if was_killed.load(Ordering::SeqCst) {
320                        break ExitMarker::Killed;
321                    }
322                    let code = i32::try_from(status.exit_code()).unwrap_or(i32::MAX);
323                    break ExitMarker::Code(code);
324                }
325                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
326                Err(error) => {
327                    crate::slog_warn!(
328                        "PTY waiter for {}:{} failed: {error}",
329                        coordinator.session_id,
330                        coordinator.task_id
331                    );
332                    break ExitMarker::Killed;
333                }
334            }
335        };
336
337        if let Err(error) = write_exit_marker(&exit_path, &marker, &coordinator.task_id) {
338            crate::slog_warn!(
339                "PTY waiter for {}:{} failed to write exit marker: {error}",
340                coordinator.session_id,
341                coordinator.task_id
342            );
343        }
344        exit_observed.store(true, Ordering::SeqCst);
345        coordinator.signal_one_done();
346    });
347}
348
349fn write_exit_marker(path: &Path, marker: &ExitMarker, task_id: &str) -> io::Result<()> {
350    let content = match marker {
351        ExitMarker::Code(code) => code.to_string(),
352        ExitMarker::Killed => "killed".to_string(),
353    };
354    atomic_write(path, content.as_bytes(), task_id)
355}
356
// The shell-resolution tests call `#[cfg(unix)]`-only functions, so the whole
// module is gated on `unix` to avoid unused-import / dead-code warnings when
// cross-compiling for Windows.
#[cfg(all(test, unix))]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use portable_pty::{Child, ChildKiller, ExitStatus};

    use super::*;

    /// Killer stub: `kill` always succeeds and clones are fresh stubs.
    #[derive(Debug)]
    struct NoopKiller;

    impl ChildKiller for NoopKiller {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(NoopKiller)
        }
    }

    /// Fake child whose first `wait()` fails with `Interrupted`; every later
    /// wait reports exit code 0.
    #[derive(Debug)]
    struct InterruptedOnceChild {
        waits: usize,
    }

    impl ChildKiller for InterruptedOnceChild {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(NoopKiller)
        }
    }

    impl Child for InterruptedOnceChild {
        fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
            Ok(None)
        }

        fn wait(&mut self) -> io::Result<ExitStatus> {
            self.waits += 1;
            match self.waits {
                1 => Err(io::ErrorKind::Interrupted.into()),
                _ => Ok(ExitStatus::with_exit_code(0)),
            }
        }

        fn process_id(&self) -> Option<u32> {
            None
        }
    }

    #[test]
    fn pty_shell_prefers_executable_shell_env() {
        let shell = PathBuf::from("/custom/zsh");
        let resolved =
            resolve_posix_shell_with(|| Some(shell.clone()), |path| path == shell.as_path());

        assert_eq!(resolved, shell);
    }

    #[test]
    fn pty_shell_ignores_unusable_shell_env_and_uses_fallback_order() {
        let resolved = resolve_posix_shell_with(
            || Some(PathBuf::from("/missing/fish")),
            |path| path == Path::new("/bin/sh") || path == Path::new("/bin/zsh"),
        );

        assert_eq!(resolved, PathBuf::from("/bin/sh"));
    }

    #[test]
    fn pty_shell_uses_bin_bash_before_later_fallbacks() {
        let resolved = resolve_posix_shell_with(
            || None,
            |path| path == Path::new("/bin/bash") || path == Path::new("/bin/sh"),
        );

        assert_eq!(resolved, PathBuf::from("/bin/bash"));
    }

    #[test]
    fn pty_waiter_retries_wait_on_interrupted() {
        let temp = tempfile::tempdir().unwrap();
        let exit_path = temp.path().join("task.exit");
        let (wake_tx, wake_rx) = crossbeam_channel::bounded(1);
        let coordinator = Arc::new(CompletionCoordinator::new(
            "task".to_string(),
            "session".to_string(),
            wake_tx,
        ));
        let exit_observed = Arc::new(AtomicBool::new(false));

        spawn_waiter(
            Box::new(InterruptedOnceChild { waits: 0 }),
            exit_path.clone(),
            Arc::new(AtomicBool::new(false)),
            Arc::clone(&exit_observed),
            Arc::clone(&coordinator),
        );
        // Stand in for the reader thread's completion signal.
        coordinator.signal_one_done();

        let deadline = Instant::now() + Duration::from_secs(2);
        while !exit_observed.load(Ordering::SeqCst) {
            assert!(Instant::now() < deadline, "waiter did not finish in time");
            std::thread::sleep(Duration::from_millis(10));
        }
        wake_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(fs::read_to_string(exit_path).unwrap(), "0");
    }

    /// Feeds `chunks` to a fresh scanner and returns how many chunks reported a
    /// detection. Each input contains the needle at most once, so a correct
    /// scanner returns 1 for those inputs (it detects once and never
    /// double-fires).
    fn scan_chunks(chunks: &[&[u8]]) -> usize {
        let mut scanner = DsrScanner::default();
        chunks.iter().filter(|chunk| scanner.scan(chunk)).count()
    }

    #[test]
    fn dsr_detected_within_single_read() {
        assert_eq!(scan_chunks(&[b"\x1b[6n"]), 1);
    }

    #[test]
    fn dsr_detected_two_read_splits() {
        assert_eq!(scan_chunks(&[b"\x1b[6", b"n"]), 1); // 3/1
        assert_eq!(scan_chunks(&[b"\x1b[", b"6n"]), 1); // 2/2
        assert_eq!(scan_chunks(&[b"\x1b", b"[6n"]), 1); // 1/3
    }

    #[test]
    fn dsr_detected_across_more_than_two_reads() {
        // The original carry-over (keep only the chunk's own last 3 bytes)
        // missed these, because by the time `n` arrived the `\x1b` had aged out.
        assert_eq!(scan_chunks(&[b"\x1b", b"[", b"6", b"n"]), 1); // 1/1/1/1
        assert_eq!(scan_chunks(&[b"\x1b", b"[", b"6n"]), 1); // 1/1/2
        assert_eq!(scan_chunks(&[b"\x1b[", b"6", b"n"]), 1); // 2/1/1
        // Unrelated leading noise pushes the needle across reads.
        assert_eq!(scan_chunks(&[b"junk\x1b", b"[6", b"n more"]), 1);
    }

    #[test]
    fn dsr_detected_once_with_surrounding_output() {
        // A single sequence embedded in a larger single read fires exactly once.
        assert_eq!(scan_chunks(&[b"hello\x1b[6nworld"]), 1);
    }

    #[test]
    fn dsr_not_detected_no_match() {
        assert_eq!(scan_chunks(&[b"abc", b"def"]), 0);
        assert_eq!(scan_chunks(&[b"hello"]), 0);
        // A partial sequence that is never completed must not fire.
        assert_eq!(scan_chunks(&[b"\x1b[6", b"x"]), 0);
    }
}
527        assert_eq!(scan_chunks(&[b"hello"]), 0);
528        // Partial-but-never-completed sequence must not fire.
529        assert_eq!(scan_chunks(&[b"\x1b[6", b"x"]), 0);
530    }
531}