Skip to main content

aft/bash_background/
pty_process.rs

1use std::collections::HashMap;
2use std::fs::{self, OpenOptions};
3use std::io::{self, Read, Write};
4use std::path::Path;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use std::thread;
8
9use portable_pty::{CommandBuilder, PtySize};
10
11use super::persistence::{atomic_write, ExitMarker, TaskPaths};
12use super::pty_runtime::{CompletionCoordinator, PtyRuntime};
13
14#[allow(clippy::too_many_arguments)]
15pub(crate) fn spawn_pty_for_command(
16    task_id: &str,
17    session_id: &str,
18    user_command: &str,
19    paths: &TaskPaths,
20    workdir: &Path,
21    env: &HashMap<String, String>,
22    rows: u16,
23    cols: u16,
24    wake_tx: crossbeam_channel::Sender<()>,
25) -> Result<PtyRuntime, String> {
26    #[cfg(unix)]
27    {
28        let mut command = CommandBuilder::new("/bin/sh");
29        command.arg("-c");
30        command.arg(user_command);
31        command.cwd(workdir.as_os_str());
32        for (key, value) in env {
33            command.env(key, value);
34        }
35        try_spawn_pty(task_id, session_id, command, paths, rows, cols, wake_tx)
36    }
37    #[cfg(windows)]
38    {
39        use crate::windows_shell::shell_candidates;
40
41        let candidates = shell_candidates();
42        let mut last_err = String::from("no Windows shell candidates available");
43
44        for shell in candidates {
45            let wrapper_body = shell.wrapper_script(user_command, &paths.exit);
46            let wrapper_path = windows_wrapper_path(paths, &shell);
47            if let Err(error) = fs::write(&wrapper_path, wrapper_body) {
48                last_err = format!("write wrapper {wrapper_path:?}: {error}");
49                continue;
50            }
51
52            let mut command = CommandBuilder::new(shell.binary().as_ref());
53            for arg in shell.pty_wrapper_args(&wrapper_path) {
54                command.arg(arg);
55            }
56            command.cwd(workdir.as_os_str());
57            for (key, value) in env {
58                command.env(key, value);
59            }
60
61            match try_spawn_pty(
62                task_id,
63                session_id,
64                command,
65                paths,
66                rows,
67                cols,
68                wake_tx.clone(),
69            ) {
70                Ok(runtime) => return Ok(runtime),
71                Err(error) => {
72                    let msg = format!("{shell:?}: {error}");
73                    if msg.contains("NotFound") || msg.contains("not recognized") {
74                        last_err = msg;
75                        continue;
76                    }
77                    return Err(msg);
78                }
79            }
80        }
81
82        Err(last_err)
83    }
84}
85
/// Computes where the wrapper script for `shell` lives.
///
/// The file name is `<stem>.<ext>` inside `paths.dir`, where `<stem>` is the
/// file stem of `paths.json` (or `wrapper` when that stem is missing or not
/// valid UTF-8) and `<ext>` is `ps1` for PowerShell variants, `bat` for
/// `cmd`, and `sh` for POSIX shells.
#[cfg(windows)]
fn windows_wrapper_path(
    paths: &TaskPaths,
    shell: &crate::windows_shell::WindowsShell,
) -> std::path::PathBuf {
    use crate::windows_shell::WindowsShell;

    let stem = paths
        .json
        .file_stem()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or("wrapper");
    let extension = match shell {
        WindowsShell::Pwsh | WindowsShell::Powershell => "ps1",
        WindowsShell::Cmd => "bat",
        WindowsShell::Posix(_) => "sh",
    };
    paths.dir.join(format!("{stem}.{extension}"))
}
104
105#[allow(clippy::too_many_arguments)]
106fn try_spawn_pty(
107    task_id: &str,
108    session_id: &str,
109    command: CommandBuilder,
110    paths: &TaskPaths,
111    rows: u16,
112    cols: u16,
113    wake_tx: crossbeam_channel::Sender<()>,
114) -> Result<PtyRuntime, String> {
115    let pty_system = portable_pty::native_pty_system();
116    let pair = pty_system
117        .openpty(PtySize {
118            rows,
119            cols,
120            pixel_width: 0,
121            pixel_height: 0,
122        })
123        .map_err(|error| format!("open PTY failed: {error}"))?;
124    let child = pair
125        .slave
126        .spawn_command(command)
127        .map_err(|error| format!("spawn PTY command failed: {error}"))?;
128    let child_pid = child.process_id();
129    let killer = child.clone_killer();
130    let reader = pair
131        .master
132        .try_clone_reader()
133        .map_err(|error| format!("clone PTY reader failed: {error}"))?;
134    let writer = pair
135        .master
136        .take_writer()
137        .map_err(|error| format!("take PTY writer failed: {error}"))?;
138
139    let reader_done = Arc::new(AtomicBool::new(false));
140    let exit_observed = Arc::new(AtomicBool::new(false));
141    let was_killed = Arc::new(AtomicBool::new(false));
142    let coordinator = Arc::new(CompletionCoordinator::new(
143        task_id.to_string(),
144        session_id.to_string(),
145        wake_tx,
146    ));
147
148    spawn_reader(
149        reader,
150        paths.pty.clone(),
151        Arc::clone(&reader_done),
152        Arc::clone(&coordinator),
153    );
154    spawn_waiter(
155        child,
156        paths.exit.clone(),
157        Arc::clone(&was_killed),
158        Arc::clone(&exit_observed),
159        Arc::clone(&coordinator),
160    );
161
162    Ok(PtyRuntime {
163        master: Some(pair.master),
164        writer: Arc::new(Mutex::new(writer)),
165        killer,
166        child_pid,
167        reader_done,
168        exit_observed,
169        was_killed,
170        coordinator,
171    })
172}
173
174pub(crate) fn spawn_reader(
175    mut reader: Box<dyn Read + Send>,
176    spill_path: std::path::PathBuf,
177    reader_done: Arc<AtomicBool>,
178    coordinator: Arc<CompletionCoordinator>,
179) {
180    thread::spawn(move || {
181        let result = (|| -> io::Result<()> {
182            if let Some(parent) = spill_path.parent() {
183                fs::create_dir_all(parent)?;
184            }
185            let mut file = OpenOptions::new()
186                .create(true)
187                .append(true)
188                .open(&spill_path)?;
189            let mut buf = [0_u8; 8192];
190            loop {
191                match reader.read(&mut buf) {
192                    Ok(0) => break,
193                    Ok(n) => {
194                        file.write_all(&buf[..n])?;
195                        file.flush()?;
196                    }
197                    Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
198                    Err(error) => return Err(error),
199                }
200            }
201            Ok(())
202        })();
203        if let Err(error) = result {
204            crate::slog_warn!(
205                "PTY reader for {}:{} stopped with error: {error}",
206                coordinator.session_id,
207                coordinator.task_id
208            );
209        }
210        reader_done.store(true, Ordering::SeqCst);
211        coordinator.signal_one_done();
212    });
213}
214
215pub(crate) fn spawn_waiter(
216    mut child: Box<dyn portable_pty::Child + Send + Sync>,
217    exit_path: std::path::PathBuf,
218    was_killed: Arc<AtomicBool>,
219    exit_observed: Arc<AtomicBool>,
220    coordinator: Arc<CompletionCoordinator>,
221) {
222    thread::spawn(move || {
223        let marker = loop {
224            match child.wait() {
225                Ok(status) => {
226                    if was_killed.load(Ordering::SeqCst) {
227                        break ExitMarker::Killed;
228                    }
229                    let code = i32::try_from(status.exit_code()).unwrap_or(i32::MAX);
230                    break ExitMarker::Code(code);
231                }
232                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
233                Err(error) => {
234                    crate::slog_warn!(
235                        "PTY waiter for {}:{} failed: {error}",
236                        coordinator.session_id,
237                        coordinator.task_id
238                    );
239                    break ExitMarker::Killed;
240                }
241            }
242        };
243
244        if let Err(error) = write_exit_marker(&exit_path, &marker, &coordinator.task_id) {
245            crate::slog_warn!(
246                "PTY waiter for {}:{} failed to write exit marker: {error}",
247                coordinator.session_id,
248                coordinator.task_id
249            );
250        }
251        exit_observed.store(true, Ordering::SeqCst);
252        coordinator.signal_one_done();
253    });
254}
255
256fn write_exit_marker(path: &Path, marker: &ExitMarker, task_id: &str) -> io::Result<()> {
257    let content = match marker {
258        ExitMarker::Code(code) => code.to_string(),
259        ExitMarker::Killed => "killed".to_string(),
260    };
261    atomic_write(path, content.as_bytes(), task_id)
262}
263
#[cfg(test)]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use portable_pty::{Child, ChildKiller, ExitStatus};

    use super::*;

    /// `ChildKiller` stub whose `kill` succeeds without doing anything.
    #[derive(Debug)]
    struct FakeKiller;

    impl ChildKiller for FakeKiller {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(Self)
        }
    }

    /// `Child` stub: the first `wait` fails with `ErrorKind::Interrupted`,
    /// and every later `wait` reports exit code 0.
    #[derive(Debug)]
    struct InterruptedOnceChild {
        /// Number of `wait` calls made so far.
        waits: usize,
    }

    impl ChildKiller for InterruptedOnceChild {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    impl Child for InterruptedOnceChild {
        fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
            Ok(None)
        }

        fn wait(&mut self) -> io::Result<ExitStatus> {
            let is_first_wait = self.waits == 0;
            self.waits += 1;
            if is_first_wait {
                return Err(io::ErrorKind::Interrupted.into());
            }
            Ok(ExitStatus::with_exit_code(0))
        }

        fn process_id(&self) -> Option<u32> {
            None
        }

        #[cfg(windows)]
        fn as_raw_handle(&self) -> Option<std::os::windows::io::RawHandle> {
            None
        }
    }

    #[cfg(unix)]
    #[test]
    fn pty_waiter_retries_wait_on_interrupted() {
        let tmp_dir = tempfile::tempdir().unwrap();
        let exit_path = tmp_dir.path().join("task.exit");
        let (wake_tx, wake_rx) = crossbeam_channel::bounded(1);
        let coordinator = Arc::new(CompletionCoordinator::new(
            String::from("task"),
            String::from("session"),
            wake_tx,
        ));
        let exit_observed = Arc::new(AtomicBool::new(false));

        spawn_waiter(
            Box::new(InterruptedOnceChild { waits: 0 }),
            exit_path.clone(),
            Arc::new(AtomicBool::new(false)),
            Arc::clone(&exit_observed),
            Arc::clone(&coordinator),
        );
        // Stands in for the reader thread's `signal_one_done`; the waiter
        // thread supplies the other call.
        coordinator.signal_one_done();

        let deadline = Instant::now() + Duration::from_secs(2);
        while !exit_observed.load(Ordering::SeqCst) {
            assert!(Instant::now() < deadline);
            std::thread::sleep(Duration::from_millis(10));
        }
        wake_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(fs::read_to_string(exit_path).unwrap(), "0");
    }
}