Skip to main content

aft/bash_background/
pty_process.rs

1use std::collections::HashMap;
2use std::fs::{self, OpenOptions};
3use std::io::{self, Read, Write};
4#[cfg(unix)]
5use std::os::unix::fs::PermissionsExt;
6use std::path::Path;
7#[cfg(unix)]
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Mutex};
11use std::thread;
12
13use portable_pty::{CommandBuilder, PtySize};
14
15use super::persistence::{atomic_write, ExitMarker, TaskPaths};
16use super::pty_runtime::{CompletionCoordinator, PtyRuntime};
17
18#[allow(clippy::too_many_arguments)]
19pub(crate) fn spawn_pty_for_command(
20    task_id: &str,
21    session_id: &str,
22    user_command: &str,
23    paths: &TaskPaths,
24    workdir: &Path,
25    env: &HashMap<String, String>,
26    rows: u16,
27    cols: u16,
28    wake_tx: crossbeam_channel::Sender<()>,
29) -> Result<PtyRuntime, String> {
30    #[cfg(unix)]
31    {
32        let shell = resolve_posix_shell();
33        let mut command = CommandBuilder::new(shell.as_os_str());
34        command.arg("-c");
35        command.arg(user_command);
36        command.cwd(workdir.as_os_str());
37        for (key, value) in env {
38            command.env(key, value);
39        }
40        try_spawn_pty(task_id, session_id, command, paths, rows, cols, wake_tx)
41    }
42    #[cfg(windows)]
43    {
44        use crate::windows_shell::shell_candidates;
45
46        let candidates = shell_candidates();
47        let mut last_err = String::from("no Windows shell candidates available");
48
49        for shell in candidates {
50            let wrapper_body = shell.wrapper_script_bytes(user_command, &paths.exit);
51            let wrapper_path = windows_wrapper_path(paths, &shell);
52            if let Err(error) = fs::write(&wrapper_path, wrapper_body) {
53                last_err = format!("write wrapper {wrapper_path:?}: {error}");
54                continue;
55            }
56
57            let mut command = CommandBuilder::new(shell.binary().as_ref());
58            for arg in shell.pty_wrapper_args(&wrapper_path) {
59                command.arg(arg);
60            }
61            command.cwd(workdir.as_os_str());
62            for (key, value) in env {
63                command.env(key, value);
64            }
65
66            match try_spawn_pty(
67                task_id,
68                session_id,
69                command,
70                paths,
71                rows,
72                cols,
73                wake_tx.clone(),
74            ) {
75                Ok(runtime) => return Ok(runtime),
76                Err(error) => {
77                    let msg = format!("{shell:?}: {error}");
78                    if msg.contains("NotFound") || msg.contains("not recognized") {
79                        last_err = msg;
80                        continue;
81                    }
82                    return Err(msg);
83                }
84            }
85        }
86
87        Err(last_err)
88    }
89}
90
91#[cfg(unix)]
92fn resolve_posix_shell() -> PathBuf {
93    resolve_posix_shell_with(
94        || std::env::var_os("SHELL").map(PathBuf::from),
95        is_executable_file,
96    )
97}
98
/// Chooses a POSIX shell with injectable environment and filesystem probes.
///
/// Returns the path from `shell_env()` when it is non-empty and
/// `is_executable` accepts it. Otherwise it returns the first of `/bin/bash`,
/// `/bin/sh`, `/bin/zsh` that `is_executable` accepts. If none is accepted,
/// it falls back to `/bin/sh` unconditionally.
#[cfg(unix)]
fn resolve_posix_shell_with<S, X>(shell_env: S, is_executable: X) -> PathBuf
where
    S: FnOnce() -> Option<PathBuf>,
    X: Fn(&Path) -> bool,
{
    const FALLBACK_SHELLS: [&str; 3] = ["/bin/bash", "/bin/sh", "/bin/zsh"];

    let usable = |candidate: &Path| !candidate.as_os_str().is_empty() && is_executable(candidate);

    match shell_env() {
        Some(preferred) if usable(preferred.as_path()) => preferred,
        _ => FALLBACK_SHELLS
            .iter()
            .copied()
            .map(PathBuf::from)
            .find(|candidate| is_executable(candidate))
            .unwrap_or_else(|| PathBuf::from("/bin/sh")),
    }
}
120
/// Returns `true` when `path` resolves (following symlinks, via
/// `fs::metadata`) to a regular file with at least one execute bit set.
///
/// Any metadata error, such as a missing path, yields `false`.
#[cfg(unix)]
fn is_executable_file(path: &Path) -> bool {
    // Owner, group, or other execute permission.
    const ANY_EXECUTE_BIT: u32 = 0o111;

    match fs::metadata(path) {
        Ok(meta) => meta.is_file() && (meta.permissions().mode() & ANY_EXECUTE_BIT) != 0,
        Err(_) => false,
    }
}
127
/// Builds the path of the wrapper script written for `shell`.
///
/// The script lives in `paths.dir`. It reuses the stem of `paths.json`, or
/// `wrapper` when that stem is missing or not valid UTF-8, and takes an
/// extension matching the shell: `ps1`, `bat`, or `sh`.
#[cfg(windows)]
fn windows_wrapper_path(
    paths: &TaskPaths,
    shell: &crate::windows_shell::WindowsShell,
) -> std::path::PathBuf {
    use crate::windows_shell::WindowsShell;

    let extension = match shell {
        WindowsShell::Pwsh | WindowsShell::Powershell => "ps1",
        WindowsShell::Cmd => "bat",
        WindowsShell::Posix(_) => "sh",
    };
    let stem = match paths.json.file_stem().and_then(|raw| raw.to_str()) {
        Some(stem) => stem,
        None => "wrapper",
    };
    let file_name = format!("{stem}.{extension}");
    paths.dir.join(file_name)
}
146
147#[allow(clippy::too_many_arguments)]
148fn try_spawn_pty(
149    task_id: &str,
150    session_id: &str,
151    command: CommandBuilder,
152    paths: &TaskPaths,
153    rows: u16,
154    cols: u16,
155    wake_tx: crossbeam_channel::Sender<()>,
156) -> Result<PtyRuntime, String> {
157    let pty_system = portable_pty::native_pty_system();
158    let pair = pty_system
159        .openpty(PtySize {
160            rows,
161            cols,
162            pixel_width: 0,
163            pixel_height: 0,
164        })
165        .map_err(|error| format!("open PTY failed: {error}"))?;
166    let child = pair
167        .slave
168        .spawn_command(command)
169        .map_err(|error| format!("spawn PTY command failed: {error}"))?;
170    let child_pid = child.process_id();
171    let killer = child.clone_killer();
172    let reader = pair
173        .master
174        .try_clone_reader()
175        .map_err(|error| format!("clone PTY reader failed: {error}"))?;
176    let writer = pair
177        .master
178        .take_writer()
179        .map_err(|error| format!("take PTY writer failed: {error}"))?;
180
181    let reader_done = Arc::new(AtomicBool::new(false));
182    let exit_observed = Arc::new(AtomicBool::new(false));
183    let was_killed = Arc::new(AtomicBool::new(false));
184    let coordinator = Arc::new(CompletionCoordinator::new(
185        task_id.to_string(),
186        session_id.to_string(),
187        wake_tx,
188    ));
189
190    let writer = Arc::new(Mutex::new(writer));
191    spawn_reader(
192        reader,
193        paths.pty.clone(),
194        Arc::clone(&reader_done),
195        Arc::clone(&coordinator),
196        Some(Arc::clone(&writer)),
197    );
198    spawn_waiter(
199        child,
200        paths.exit.clone(),
201        Arc::clone(&was_killed),
202        Arc::clone(&exit_observed),
203        Arc::clone(&coordinator),
204    );
205
206    Ok(PtyRuntime {
207        master: Some(pair.master),
208        writer,
209        killer,
210        child_pid,
211        reader_done,
212        exit_observed,
213        was_killed,
214        coordinator,
215    })
216}
217
218pub(crate) fn spawn_reader(
219    mut reader: Box<dyn Read + Send>,
220    spill_path: std::path::PathBuf,
221    reader_done: Arc<AtomicBool>,
222    coordinator: Arc<CompletionCoordinator>,
223    writer: Option<Arc<Mutex<Box<dyn Write + Send>>>>,
224) {
225    thread::spawn(move || {
226        let result = (|| -> io::Result<()> {
227            if let Some(parent) = spill_path.parent() {
228                fs::create_dir_all(parent)?;
229            }
230            let mut file = OpenOptions::new()
231                .create(true)
232                .append(true)
233                .open(&spill_path)?;
234            let mut buf = [0_u8; 8192];
235            loop {
236                match reader.read(&mut buf) {
237                    Ok(0) => break,
238                    Ok(n) => {
239                        file.write_all(&buf[..n])?;
240                        file.flush()?;
241                        if buf[..n].windows(4).any(|window| window == b"\x1b[6n") {
242                            // Some Windows console hosts/apps query the
243                            // terminal cursor position with DSR (ESC[6n)
244                            // before accepting input. A real terminal answers
245                            // with ESC[row;colR; without that response the
246                            // process can sit forever after emitting only the
247                            // query. We own both ends of the PTY, so provide a
248                            // conservative 1;1 response.
249                            if let Some(writer) = writer.as_ref() {
250                                if let Ok(mut writer) = writer.lock() {
251                                    let _ = writer.write_all(b"\x1b[1;1R");
252                                    let _ = writer.flush();
253                                }
254                            }
255                        }
256                    }
257                    Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
258                    Err(error) => return Err(error),
259                }
260            }
261            Ok(())
262        })();
263        if let Err(error) = result {
264            crate::slog_warn!(
265                "PTY reader for {}:{} stopped with error: {error}",
266                coordinator.session_id,
267                coordinator.task_id
268            );
269        }
270        reader_done.store(true, Ordering::SeqCst);
271        coordinator.signal_one_done();
272    });
273}
274
275pub(crate) fn spawn_waiter(
276    mut child: Box<dyn portable_pty::Child + Send + Sync>,
277    exit_path: std::path::PathBuf,
278    was_killed: Arc<AtomicBool>,
279    exit_observed: Arc<AtomicBool>,
280    coordinator: Arc<CompletionCoordinator>,
281) {
282    thread::spawn(move || {
283        let marker = loop {
284            match child.wait() {
285                Ok(status) => {
286                    if was_killed.load(Ordering::SeqCst) {
287                        break ExitMarker::Killed;
288                    }
289                    let code = i32::try_from(status.exit_code()).unwrap_or(i32::MAX);
290                    break ExitMarker::Code(code);
291                }
292                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
293                Err(error) => {
294                    crate::slog_warn!(
295                        "PTY waiter for {}:{} failed: {error}",
296                        coordinator.session_id,
297                        coordinator.task_id
298                    );
299                    break ExitMarker::Killed;
300                }
301            }
302        };
303
304        if let Err(error) = write_exit_marker(&exit_path, &marker, &coordinator.task_id) {
305            crate::slog_warn!(
306                "PTY waiter for {}:{} failed to write exit marker: {error}",
307                coordinator.session_id,
308                coordinator.task_id
309            );
310        }
311        exit_observed.store(true, Ordering::SeqCst);
312        coordinator.signal_one_done();
313    });
314}
315
316fn write_exit_marker(path: &Path, marker: &ExitMarker, task_id: &str) -> io::Result<()> {
317    let content = match marker {
318        ExitMarker::Code(code) => code.to_string(),
319        ExitMarker::Killed => "killed".to_string(),
320    };
321    atomic_write(path, content.as_bytes(), task_id)
322}
323
// Every test here drives Unix-only code: the `#[cfg(unix)]` shell-resolution
// helpers and `spawn_waiter` with a fake child. The whole module is therefore
// gated on `unix`, which also keeps Windows cross-compiles free of
// unused-import / dead-code warnings.
#[cfg(all(test, unix))]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use portable_pty::{Child, ChildKiller, ExitStatus};

    use super::*;

    /// Killer that accepts every kill request and does nothing.
    #[derive(Debug)]
    struct FakeKiller;

    impl ChildKiller for FakeKiller {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    /// Fake child whose first `wait()` fails with `Interrupted` and whose
    /// later `wait()` calls report exit code 0.
    #[derive(Debug)]
    struct InterruptedOnceChild {
        waits: usize,
    }

    impl ChildKiller for InterruptedOnceChild {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    impl Child for InterruptedOnceChild {
        fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
            Ok(None)
        }

        fn wait(&mut self) -> io::Result<ExitStatus> {
            self.waits += 1;
            match self.waits {
                1 => Err(io::Error::from(io::ErrorKind::Interrupted)),
                _ => Ok(ExitStatus::with_exit_code(0)),
            }
        }

        fn process_id(&self) -> Option<u32> {
            None
        }

        // Only compiled on Windows; kept so the fake remains a complete
        // `Child` impl if the module gate is ever widened.
        #[cfg(windows)]
        fn as_raw_handle(&self) -> Option<std::os::windows::io::RawHandle> {
            None
        }
    }

    /// Polls `flag` every 10 ms and fails the test if it is still unset once
    /// `timeout` has elapsed.
    fn wait_for_flag(flag: &AtomicBool, timeout: Duration) {
        let started = Instant::now();
        while !flag.load(Ordering::SeqCst) {
            assert!(started.elapsed() < timeout);
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn pty_shell_prefers_executable_shell_env() {
        let custom = PathBuf::from("/custom/zsh");

        let picked =
            resolve_posix_shell_with(|| Some(custom.clone()), |path| path == custom.as_path());

        assert_eq!(picked, custom);
    }

    #[test]
    fn pty_shell_ignores_unusable_shell_env_and_uses_fallback_order() {
        let picked = resolve_posix_shell_with(
            || Some(PathBuf::from("/missing/fish")),
            |path| path == Path::new("/bin/sh") || path == Path::new("/bin/zsh"),
        );

        assert_eq!(picked, PathBuf::from("/bin/sh"));
    }

    #[test]
    fn pty_shell_uses_bin_bash_before_later_fallbacks() {
        let picked = resolve_posix_shell_with(
            || None,
            |path| path == Path::new("/bin/bash") || path == Path::new("/bin/sh"),
        );

        assert_eq!(picked, PathBuf::from("/bin/bash"));
    }

    #[test]
    fn pty_waiter_retries_wait_on_interrupted() {
        let temp = tempfile::tempdir().unwrap();
        let exit_path = temp.path().join("task.exit");
        let (wake_tx, wake_rx) = crossbeam_channel::bounded(1);
        let coordinator = Arc::new(CompletionCoordinator::new(
            "task".to_string(),
            "session".to_string(),
            wake_tx,
        ));
        let exit_observed = Arc::new(AtomicBool::new(false));

        spawn_waiter(
            Box::new(InterruptedOnceChild { waits: 0 }),
            exit_path.clone(),
            Arc::new(AtomicBool::new(false)),
            Arc::clone(&exit_observed),
            Arc::clone(&coordinator),
        );
        // Stand in for the reader thread, which also calls signal_one_done().
        coordinator.signal_one_done();

        wait_for_flag(&exit_observed, Duration::from_secs(2));
        wake_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(fs::read_to_string(exit_path).unwrap(), "0");
    }
}