1use std::collections::HashMap;
2use std::fs::{self, OpenOptions};
3use std::io::{self, Read, Write};
4use std::path::Path;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use std::thread;
8
9use portable_pty::{CommandBuilder, PtySize};
10
11use super::persistence::{atomic_write, ExitMarker, TaskPaths};
12use super::pty_runtime::{CompletionCoordinator, PtyRuntime};
13
14#[allow(clippy::too_many_arguments)]
15pub(crate) fn spawn_pty_for_command(
16 task_id: &str,
17 session_id: &str,
18 user_command: &str,
19 paths: &TaskPaths,
20 workdir: &Path,
21 env: &HashMap<String, String>,
22 rows: u16,
23 cols: u16,
24 wake_tx: crossbeam_channel::Sender<()>,
25) -> Result<PtyRuntime, String> {
26 #[cfg(unix)]
27 {
28 let mut command = CommandBuilder::new("/bin/sh");
29 command.arg("-c");
30 command.arg(user_command);
31 command.cwd(workdir.as_os_str());
32 for (key, value) in env {
33 command.env(key, value);
34 }
35 try_spawn_pty(task_id, session_id, command, paths, rows, cols, wake_tx)
36 }
37 #[cfg(windows)]
38 {
39 use crate::windows_shell::shell_candidates;
40
41 let candidates = shell_candidates();
42 let mut last_err = String::from("no Windows shell candidates available");
43
44 for shell in candidates {
45 let wrapper_body = shell.wrapper_script(user_command, &paths.exit);
46 let wrapper_path = windows_wrapper_path(paths, &shell);
47 if let Err(error) = fs::write(&wrapper_path, wrapper_body) {
48 last_err = format!("write wrapper {wrapper_path:?}: {error}");
49 continue;
50 }
51
52 let mut command = CommandBuilder::new(shell.binary().as_ref());
53 for arg in shell.pty_wrapper_args(&wrapper_path) {
54 command.arg(arg);
55 }
56 command.cwd(workdir.as_os_str());
57 for (key, value) in env {
58 command.env(key, value);
59 }
60
61 match try_spawn_pty(
62 task_id,
63 session_id,
64 command,
65 paths,
66 rows,
67 cols,
68 wake_tx.clone(),
69 ) {
70 Ok(runtime) => return Ok(runtime),
71 Err(error) => {
72 let msg = format!("{shell:?}: {error}");
73 if msg.contains("NotFound") || msg.contains("not recognized") {
74 last_err = msg;
75 continue;
76 }
77 return Err(msg);
78 }
79 }
80 }
81
82 Err(last_err)
83 }
84}
85
/// Builds the path of the wrapper script for `shell` inside `paths.dir`.
///
/// The file name reuses the stem of `paths.json` (falling back to
/// `"wrapper"` when the stem is missing or not valid UTF-8) and takes an
/// extension chosen by shell kind: `ps1` for both PowerShell variants, `bat`
/// for `Cmd`, and `sh` for `Posix`.
#[cfg(windows)]
fn windows_wrapper_path(
    paths: &TaskPaths,
    shell: &crate::windows_shell::WindowsShell,
) -> std::path::PathBuf {
    use crate::windows_shell::WindowsShell;

    let stem = paths
        .json
        .file_stem()
        .and_then(|os_stem| os_stem.to_str())
        .unwrap_or("wrapper");
    let extension = match shell {
        WindowsShell::Cmd => "bat",
        WindowsShell::Posix(_) => "sh",
        WindowsShell::Pwsh | WindowsShell::Powershell => "ps1",
    };
    let file_name = format!("{stem}.{extension}");
    paths.dir.join(file_name)
}
104
105#[allow(clippy::too_many_arguments)]
106fn try_spawn_pty(
107 task_id: &str,
108 session_id: &str,
109 command: CommandBuilder,
110 paths: &TaskPaths,
111 rows: u16,
112 cols: u16,
113 wake_tx: crossbeam_channel::Sender<()>,
114) -> Result<PtyRuntime, String> {
115 let pty_system = portable_pty::native_pty_system();
116 let pair = pty_system
117 .openpty(PtySize {
118 rows,
119 cols,
120 pixel_width: 0,
121 pixel_height: 0,
122 })
123 .map_err(|error| format!("open PTY failed: {error}"))?;
124 let child = pair
125 .slave
126 .spawn_command(command)
127 .map_err(|error| format!("spawn PTY command failed: {error}"))?;
128 let child_pid = child.process_id();
129 let killer = child.clone_killer();
130 let reader = pair
131 .master
132 .try_clone_reader()
133 .map_err(|error| format!("clone PTY reader failed: {error}"))?;
134 let writer = pair
135 .master
136 .take_writer()
137 .map_err(|error| format!("take PTY writer failed: {error}"))?;
138
139 let reader_done = Arc::new(AtomicBool::new(false));
140 let exit_observed = Arc::new(AtomicBool::new(false));
141 let was_killed = Arc::new(AtomicBool::new(false));
142 let coordinator = Arc::new(CompletionCoordinator::new(
143 task_id.to_string(),
144 session_id.to_string(),
145 wake_tx,
146 ));
147
148 spawn_reader(
149 reader,
150 paths.pty.clone(),
151 Arc::clone(&reader_done),
152 Arc::clone(&coordinator),
153 );
154 spawn_waiter(
155 child,
156 paths.exit.clone(),
157 Arc::clone(&was_killed),
158 Arc::clone(&exit_observed),
159 Arc::clone(&coordinator),
160 );
161
162 Ok(PtyRuntime {
163 master: Some(pair.master),
164 writer: Arc::new(Mutex::new(writer)),
165 killer,
166 child_pid,
167 reader_done,
168 exit_observed,
169 was_killed,
170 coordinator,
171 })
172}
173
174pub(crate) fn spawn_reader(
175 mut reader: Box<dyn Read + Send>,
176 spill_path: std::path::PathBuf,
177 reader_done: Arc<AtomicBool>,
178 coordinator: Arc<CompletionCoordinator>,
179) {
180 thread::spawn(move || {
181 let result = (|| -> io::Result<()> {
182 if let Some(parent) = spill_path.parent() {
183 fs::create_dir_all(parent)?;
184 }
185 let mut file = OpenOptions::new()
186 .create(true)
187 .append(true)
188 .open(&spill_path)?;
189 let mut buf = [0_u8; 8192];
190 loop {
191 match reader.read(&mut buf) {
192 Ok(0) => break,
193 Ok(n) => {
194 file.write_all(&buf[..n])?;
195 file.flush()?;
196 }
197 Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
198 Err(error) => return Err(error),
199 }
200 }
201 Ok(())
202 })();
203 if let Err(error) = result {
204 crate::slog_warn!(
205 "PTY reader for {}:{} stopped with error: {error}",
206 coordinator.session_id,
207 coordinator.task_id
208 );
209 }
210 reader_done.store(true, Ordering::SeqCst);
211 coordinator.signal_one_done();
212 });
213}
214
215pub(crate) fn spawn_waiter(
216 mut child: Box<dyn portable_pty::Child + Send + Sync>,
217 exit_path: std::path::PathBuf,
218 was_killed: Arc<AtomicBool>,
219 exit_observed: Arc<AtomicBool>,
220 coordinator: Arc<CompletionCoordinator>,
221) {
222 thread::spawn(move || {
223 let marker = loop {
224 match child.wait() {
225 Ok(status) => {
226 if was_killed.load(Ordering::SeqCst) {
227 break ExitMarker::Killed;
228 }
229 let code = i32::try_from(status.exit_code()).unwrap_or(i32::MAX);
230 break ExitMarker::Code(code);
231 }
232 Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
233 Err(error) => {
234 crate::slog_warn!(
235 "PTY waiter for {}:{} failed: {error}",
236 coordinator.session_id,
237 coordinator.task_id
238 );
239 break ExitMarker::Killed;
240 }
241 }
242 };
243
244 if let Err(error) = write_exit_marker(&exit_path, &marker, &coordinator.task_id) {
245 crate::slog_warn!(
246 "PTY waiter for {}:{} failed to write exit marker: {error}",
247 coordinator.session_id,
248 coordinator.task_id
249 );
250 }
251 exit_observed.store(true, Ordering::SeqCst);
252 coordinator.signal_one_done();
253 });
254}
255
256fn write_exit_marker(path: &Path, marker: &ExitMarker, task_id: &str) -> io::Result<()> {
257 let content = match marker {
258 ExitMarker::Code(code) => code.to_string(),
259 ExitMarker::Killed => "killed".to_string(),
260 };
261 atomic_write(path, content.as_bytes(), task_id)
262}
263
#[cfg(test)]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use portable_pty::{Child, ChildKiller, ExitStatus};

    use super::*;

    /// Killer stub whose `kill` always reports success.
    #[derive(Debug)]
    struct FakeKiller;

    impl ChildKiller for FakeKiller {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    /// Child stub whose first `wait` fails with `Interrupted` and whose later
    /// `wait` calls report exit code 0.
    #[derive(Debug)]
    struct InterruptedOnceChild {
        /// Number of `wait` calls seen so far.
        waits: usize,
    }

    impl ChildKiller for InterruptedOnceChild {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }

        fn clone_killer(&self) -> Box<dyn ChildKiller + Send + Sync> {
            Box::new(FakeKiller)
        }
    }

    impl Child for InterruptedOnceChild {
        fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
            Ok(None)
        }

        fn wait(&mut self) -> io::Result<ExitStatus> {
            self.waits += 1;
            match self.waits {
                1 => Err(io::ErrorKind::Interrupted.into()),
                _ => Ok(ExitStatus::with_exit_code(0)),
            }
        }

        fn process_id(&self) -> Option<u32> {
            None
        }

        #[cfg(windows)]
        fn as_raw_handle(&self) -> Option<std::os::windows::io::RawHandle> {
            None
        }
    }

    /// The waiter must retry after an `Interrupted` wait and still record the
    /// real exit code.
    #[cfg(unix)]
    #[test]
    fn pty_waiter_retries_wait_on_interrupted() {
        let temp = tempfile::tempdir().unwrap();
        let exit_path = temp.path().join("task.exit");
        let (wake_tx, wake_rx) = crossbeam_channel::bounded(1);
        let coordinator = Arc::new(CompletionCoordinator::new(
            "task".to_string(),
            "session".to_string(),
            wake_tx,
        ));
        let exit_observed = Arc::new(AtomicBool::new(false));
        let was_killed = Arc::new(AtomicBool::new(false));

        spawn_waiter(
            Box::new(InterruptedOnceChild { waits: 0 }),
            exit_path.clone(),
            was_killed,
            Arc::clone(&exit_observed),
            Arc::clone(&coordinator),
        );
        // No reader thread is started here, so the test reports the second
        // completion itself. NOTE(review): presumably the coordinator only
        // wakes once both sides have signalled; confirm against
        // `CompletionCoordinator`.
        coordinator.signal_one_done();

        // Poll for the waiter's flag, failing if it takes two seconds or more.
        let deadline = Duration::from_secs(2);
        let poll_interval = Duration::from_millis(10);
        let started = Instant::now();
        while !exit_observed.load(Ordering::SeqCst) {
            assert!(started.elapsed() < deadline);
            std::thread::sleep(poll_interval);
        }

        wake_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        let recorded = fs::read_to_string(exit_path).unwrap();
        assert_eq!(recorded, "0");
    }
}