Skip to main content

aft/bash_background/
pty_process.rs

1use std::collections::HashMap;
2use std::fs::{self, OpenOptions};
3use std::io::{self, Read, Write};
4#[cfg(unix)]
5use std::os::unix::fs::PermissionsExt;
6use std::path::{Path, PathBuf};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::thread;
10
11use portable_pty::{CommandBuilder, PtySize};
12
13use super::persistence::{atomic_write, ExitMarker, TaskPaths};
14use super::pty_runtime::{CompletionCoordinator, PtyRuntime};
15
16#[allow(clippy::too_many_arguments)]
17pub(crate) fn spawn_pty_for_command(
18    task_id: &str,
19    session_id: &str,
20    user_command: &str,
21    paths: &TaskPaths,
22    workdir: &Path,
23    env: &HashMap<String, String>,
24    rows: u16,
25    cols: u16,
26    wake_tx: crossbeam_channel::Sender<()>,
27) -> Result<PtyRuntime, String> {
28    #[cfg(unix)]
29    {
30        let shell = resolve_posix_shell();
31        let mut command = CommandBuilder::new(shell.as_os_str());
32        command.arg("-c");
33        command.arg(user_command);
34        command.cwd(workdir.as_os_str());
35        for (key, value) in env {
36            command.env(key, value);
37        }
38        try_spawn_pty(task_id, session_id, command, paths, rows, cols, wake_tx)
39    }
40    #[cfg(windows)]
41    {
42        use crate::windows_shell::shell_candidates;
43
44        let candidates = shell_candidates();
45        let mut last_err = String::from("no Windows shell candidates available");
46
47        for shell in candidates {
48            let wrapper_body = shell.wrapper_script(user_command, &paths.exit);
49            let wrapper_path = windows_wrapper_path(paths, &shell);
50            if let Err(error) = fs::write(&wrapper_path, wrapper_body) {
51                last_err = format!("write wrapper {wrapper_path:?}: {error}");
52                continue;
53            }
54
55            let mut command = CommandBuilder::new(shell.binary().as_ref());
56            for arg in shell.pty_wrapper_args(&wrapper_path) {
57                command.arg(arg);
58            }
59            command.cwd(workdir.as_os_str());
60            for (key, value) in env {
61                command.env(key, value);
62            }
63
64            match try_spawn_pty(
65                task_id,
66                session_id,
67                command,
68                paths,
69                rows,
70                cols,
71                wake_tx.clone(),
72            ) {
73                Ok(runtime) => return Ok(runtime),
74                Err(error) => {
75                    let msg = format!("{shell:?}: {error}");
76                    if msg.contains("NotFound") || msg.contains("not recognized") {
77                        last_err = msg;
78                        continue;
79                    }
80                    return Err(msg);
81                }
82            }
83        }
84
85        Err(last_err)
86    }
87}
88
89#[cfg(unix)]
90fn resolve_posix_shell() -> PathBuf {
91    resolve_posix_shell_with(
92        || std::env::var_os("SHELL").map(PathBuf::from),
93        is_executable_file,
94    )
95}
96
/// Shell-selection policy, with the environment lookup and the executability probe injected
/// so the policy can be tested without touching the real filesystem.
///
/// Candidates, in order:
/// 1. the value returned by `shell_env`, if it is non-empty and `is_executable` accepts it;
/// 2. the first of `/bin/bash`, `/bin/sh`, `/bin/zsh` that `is_executable` accepts;
/// 3. `/bin/sh`, unconditionally.
#[cfg(unix)]
fn resolve_posix_shell_with<S, X>(shell_env: S, is_executable: X) -> PathBuf
where
    S: FnOnce() -> Option<PathBuf>,
    X: Fn(&Path) -> bool,
{
    const FALLBACKS: [&str; 3] = ["/bin/bash", "/bin/sh", "/bin/zsh"];

    let env_shell_usable =
        |candidate: &Path| !candidate.as_os_str().is_empty() && is_executable(candidate);

    shell_env()
        .filter(|candidate| env_shell_usable(candidate))
        .or_else(|| {
            FALLBACKS
                .iter()
                .copied()
                .map(PathBuf::from)
                .find(|candidate| is_executable(candidate))
        })
        .unwrap_or_else(|| PathBuf::from("/bin/sh"))
}
118
/// Returns `true` when `path` (symlinks followed, via `fs::metadata`) is a regular file with at
/// least one of the owner/group/other execute bits set. Any metadata error means `false`.
#[cfg(unix)]
fn is_executable_file(path: &Path) -> bool {
    match fs::metadata(path) {
        Ok(meta) if meta.is_file() => meta.permissions().mode() & 0o111 != 0,
        _ => false,
    }
}
125
/// Location of the wrapper script written for `shell`: `<paths.dir>/<stem>.<ext>`.
///
/// `<stem>` is the file stem of `paths.json`, or `wrapper` when that stem is missing or not
/// valid UTF-8. `<ext>` is `ps1` for PowerShell variants, `bat` for cmd, and `sh` for POSIX
/// shells.
#[cfg(windows)]
fn windows_wrapper_path(
    paths: &TaskPaths,
    shell: &crate::windows_shell::WindowsShell,
) -> std::path::PathBuf {
    use crate::windows_shell::WindowsShell;

    let extension = match shell {
        WindowsShell::Cmd => "bat",
        WindowsShell::Posix(_) => "sh",
        WindowsShell::Pwsh | WindowsShell::Powershell => "ps1",
    };
    let stem = paths
        .json
        .file_stem()
        .and_then(|raw| raw.to_str())
        .unwrap_or("wrapper");
    let file_name = format!("{stem}.{extension}");
    paths.dir.join(file_name)
}
144
145#[allow(clippy::too_many_arguments)]
146fn try_spawn_pty(
147    task_id: &str,
148    session_id: &str,
149    command: CommandBuilder,
150    paths: &TaskPaths,
151    rows: u16,
152    cols: u16,
153    wake_tx: crossbeam_channel::Sender<()>,
154) -> Result<PtyRuntime, String> {
155    let pty_system = portable_pty::native_pty_system();
156    let pair = pty_system
157        .openpty(PtySize {
158            rows,
159            cols,
160            pixel_width: 0,
161            pixel_height: 0,
162        })
163        .map_err(|error| format!("open PTY failed: {error}"))?;
164    let child = pair
165        .slave
166        .spawn_command(command)
167        .map_err(|error| format!("spawn PTY command failed: {error}"))?;
168    let child_pid = child.process_id();
169    let killer = child.clone_killer();
170    let reader = pair
171        .master
172        .try_clone_reader()
173        .map_err(|error| format!("clone PTY reader failed: {error}"))?;
174    let writer = pair
175        .master
176        .take_writer()
177        .map_err(|error| format!("take PTY writer failed: {error}"))?;
178
179    let reader_done = Arc::new(AtomicBool::new(false));
180    let exit_observed = Arc::new(AtomicBool::new(false));
181    let was_killed = Arc::new(AtomicBool::new(false));
182    let coordinator = Arc::new(CompletionCoordinator::new(
183        task_id.to_string(),
184        session_id.to_string(),
185        wake_tx,
186    ));
187
188    spawn_reader(
189        reader,
190        paths.pty.clone(),
191        Arc::clone(&reader_done),
192        Arc::clone(&coordinator),
193    );
194    spawn_waiter(
195        child,
196        paths.exit.clone(),
197        Arc::clone(&was_killed),
198        Arc::clone(&exit_observed),
199        Arc::clone(&coordinator),
200    );
201
202    Ok(PtyRuntime {
203        master: Some(pair.master),
204        writer: Arc::new(Mutex::new(writer)),
205        killer,
206        child_pid,
207        reader_done,
208        exit_observed,
209        was_killed,
210        coordinator,
211    })
212}
213
214pub(crate) fn spawn_reader(
215    mut reader: Box<dyn Read + Send>,
216    spill_path: std::path::PathBuf,
217    reader_done: Arc<AtomicBool>,
218    coordinator: Arc<CompletionCoordinator>,
219) {
220    thread::spawn(move || {
221        let result = (|| -> io::Result<()> {
222            if let Some(parent) = spill_path.parent() {
223                fs::create_dir_all(parent)?;
224            }
225            let mut file = OpenOptions::new()
226                .create(true)
227                .append(true)
228                .open(&spill_path)?;
229            let mut buf = [0_u8; 8192];
230            loop {
231                match reader.read(&mut buf) {
232                    Ok(0) => break,
233                    Ok(n) => {
234                        file.write_all(&buf[..n])?;
235                        file.flush()?;
236                    }
237                    Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
238                    Err(error) => return Err(error),
239                }
240            }
241            Ok(())
242        })();
243        if let Err(error) = result {
244            crate::slog_warn!(
245                "PTY reader for {}:{} stopped with error: {error}",
246                coordinator.session_id,
247                coordinator.task_id
248            );
249        }
250        reader_done.store(true, Ordering::SeqCst);
251        coordinator.signal_one_done();
252    });
253}
254
255pub(crate) fn spawn_waiter(
256    mut child: Box<dyn portable_pty::Child + Send + Sync>,
257    exit_path: std::path::PathBuf,
258    was_killed: Arc<AtomicBool>,
259    exit_observed: Arc<AtomicBool>,
260    coordinator: Arc<CompletionCoordinator>,
261) {
262    thread::spawn(move || {
263        let marker = loop {
264            match child.wait() {
265                Ok(status) => {
266                    if was_killed.load(Ordering::SeqCst) {
267                        break ExitMarker::Killed;
268                    }
269                    let code = i32::try_from(status.exit_code()).unwrap_or(i32::MAX);
270                    break ExitMarker::Code(code);
271                }
272                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
273                Err(error) => {
274                    crate::slog_warn!(
275                        "PTY waiter for {}:{} failed: {error}",
276                        coordinator.session_id,
277                        coordinator.task_id
278                    );
279                    break ExitMarker::Killed;
280                }
281            }
282        };
283
284        if let Err(error) = write_exit_marker(&exit_path, &marker, &coordinator.task_id) {
285            crate::slog_warn!(
286                "PTY waiter for {}:{} failed to write exit marker: {error}",
287                coordinator.session_id,
288                coordinator.task_id
289            );
290        }
291        exit_observed.store(true, Ordering::SeqCst);
292        coordinator.signal_one_done();
293    });
294}
295
296fn write_exit_marker(path: &Path, marker: &ExitMarker, task_id: &str) -> io::Result<()> {
297    let content = match marker {
298        ExitMarker::Code(code) => code.to_string(),
299        ExitMarker::Killed => "killed".to_string(),
300    };
301    atomic_write(path, content.as_bytes(), task_id)
302}
303
// Every test here targets unix-only behavior: the shell-resolution helpers are
// `#[cfg(unix)]`, and the waiter test was already pinned to unix. The module is gated once,
// rather than per test, so that the fake `Child`/`ChildKiller` types and the timing/atomic
// imports are not compiled, unused, into Windows test builds. There they produced dead-code
// and unused-import warnings.
#[cfg(all(test, unix))]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use portable_pty::{Child, ChildKiller, ExitStatus};

    use super::*;

    /// Killer stub whose `kill` always succeeds; handed out by the fake child below.
    #[derive(Debug)]
    struct FakeKiller;

    impl ChildKiller for FakeKiller {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    /// Fake child whose first `wait` fails with `Interrupted` and whose later `wait` calls
    /// succeed with exit code 0. Used to exercise the waiter's retry path.
    #[derive(Debug)]
    struct InterruptedOnceChild {
        waits: usize,
    }

    impl ChildKiller for InterruptedOnceChild {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    impl Child for InterruptedOnceChild {
        fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
            Ok(None)
        }

        fn wait(&mut self) -> io::Result<ExitStatus> {
            self.waits += 1;
            if self.waits == 1 {
                Err(io::Error::from(io::ErrorKind::Interrupted))
            } else {
                Ok(ExitStatus::with_exit_code(0))
            }
        }

        fn process_id(&self) -> Option<u32> {
            None
        }
    }

    #[test]
    fn pty_shell_prefers_executable_shell_env() {
        let shell = PathBuf::from("/custom/zsh");
        let resolved =
            resolve_posix_shell_with(|| Some(shell.clone()), |path| path == shell.as_path());

        assert_eq!(resolved, shell);
    }

    #[test]
    fn pty_shell_ignores_unusable_shell_env_and_uses_fallback_order() {
        let resolved = resolve_posix_shell_with(
            || Some(PathBuf::from("/missing/fish")),
            |path| path == Path::new("/bin/sh") || path == Path::new("/bin/zsh"),
        );

        assert_eq!(resolved, PathBuf::from("/bin/sh"));
    }

    #[test]
    fn pty_shell_uses_bin_bash_before_later_fallbacks() {
        let resolved = resolve_posix_shell_with(
            || None,
            |path| path == Path::new("/bin/bash") || path == Path::new("/bin/sh"),
        );

        assert_eq!(resolved, PathBuf::from("/bin/bash"));
    }

    #[test]
    fn pty_waiter_retries_wait_on_interrupted() {
        let temp = tempfile::tempdir().unwrap();
        let exit_path = temp.path().join("task.exit");
        let (wake_tx, wake_rx) = crossbeam_channel::bounded(1);
        let coordinator = Arc::new(CompletionCoordinator::new(
            "task".to_string(),
            "session".to_string(),
            wake_tx,
        ));
        let was_killed = Arc::new(AtomicBool::new(false));
        let exit_observed = Arc::new(AtomicBool::new(false));

        spawn_waiter(
            Box::new(InterruptedOnceChild { waits: 0 }),
            exit_path.clone(),
            was_killed,
            Arc::clone(&exit_observed),
            Arc::clone(&coordinator),
        );
        // Stand in for the reader thread's completion signal.
        coordinator.signal_one_done();

        let started = Instant::now();
        while !exit_observed.load(Ordering::SeqCst) {
            assert!(started.elapsed() < Duration::from_secs(2));
            std::thread::sleep(Duration::from_millis(10));
        }
        wake_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(fs::read_to_string(exit_path).unwrap(), "0");
    }
}