1use std::{ffi::OsString, io::Read as _};
7
8use snafu::{OptionExt as _, ResultExt as _};
9use tokio::sync::mpsc;
10use tracing::Instrument as _;
11
/// One fixed-size chunk of raw output read from the PTY.
///
/// When fewer bytes than the buffer size were read, the remaining bytes are zero.
pub type BytesFromPTY = [u8; 4_096];
/// One fixed-size chunk of input destined for the PTY.
///
/// The payload is treated as ending at the first zero byte (or filling the whole buffer).
pub type BytesFromSTDIN = [u8; 128];
17
18#[non_exhaustive]
20pub(crate) struct PTY {
21 pub command: Vec<OsString>,
23 pub width: u16,
25 pub height: u16,
27 pub control_tx: tokio::sync::broadcast::Sender<crate::Protocol>,
29 pub output_tx: tokio::sync::mpsc::Sender<crate::pty::BytesFromPTY>,
31}
32
33impl PTY {
34 fn setup_pty(&self) -> Result<portable_pty::PtyPair, crate::errors::PTYError> {
36 tracing::debug!("Setting up PTY");
37 let pty_system = portable_pty::native_pty_system();
38 let pair = pty_system
39 .openpty(Self::pty_size(self.width, self.height))
40 .with_whatever_context(|_| "Error opening PTY")?;
41
42 tracing::debug!("Launching `{:?}` on PTY", self.command);
43 let mut cmd = portable_pty::CommandBuilder::from_argv(self.command.clone());
44 cmd.cwd(
45 std::env::current_dir()
46 .with_whatever_context(|_| "Couldn't get user's current directory")?,
47 );
48 let spawn = pair
49 .slave
50 .spawn_command(cmd)
51 .with_whatever_context(|_| "Error spawning PTY command")?;
52 let killer = spawn.clone_killer();
53 Self::wait_for_pty_end(self.control_tx.clone(), spawn);
54 Self::kill_on_protocol_end(self.control_tx.subscribe(), killer);
55
56 tracing::trace!("Returning PTY pair");
57 Ok(pair)
58 }
59
60 fn pty_reader_loop(
63 pty_reader: std::boxed::Box<dyn std::io::Read + std::marker::Send>,
64 pty_reader_tx: mpsc::Sender<BytesFromPTY>,
65 ) -> tokio::task::JoinHandle<()> {
66 tokio::task::spawn_blocking(move || {
67 let mut reader = std::io::BufReader::new(pty_reader);
68 loop {
69 let mut buffer: BytesFromPTY = [0; 4096];
70
71 let now = std::time::Instant::now();
72 let read_result = reader.read(&mut buffer);
73 let elapsed = now.elapsed();
74
75 match read_result {
76 Ok(0) => {
77 tracing::debug!("PTY reader loop received 0 bytes, exiting...");
78 break;
79 }
80 Ok(n) => {
81 tracing::trace!(
82 "Read {} PTY bytes. Time since last output {:?}",
83 n,
84 elapsed
85 );
86 let send_result = pty_reader_tx.blocking_send(buffer);
87 if let Err(error) = send_result {
88 tracing::error!("Broadcasting PTY output: {error:?}");
89 break;
90 }
91 }
92 Err(error) => tracing::error!("PTY reader: {error:?}"),
93 }
94 }
95 tracing::trace!("Leaving PTY reader loop");
96 })
97 }
98
99 fn wait_for_pty_end(
101 protocol_out: tokio::sync::broadcast::Sender<crate::Protocol>,
102 mut spawn: Box<dyn portable_pty::Child + Send + Sync>,
103 ) {
104 tokio::task::spawn_blocking(move || {
105 tracing::debug!("Starting to wait for PTY end");
106 let waiter_result = spawn.wait();
107 if let Err(error) = waiter_result {
108 tracing::error!("Waiting for PTY: {error:?}");
109 }
110
111 std::thread::sleep(std::time::Duration::from_millis(10));
114
115 let sender_result = protocol_out.send(crate::Protocol::End);
116 if let Err(error) = sender_result {
117 tracing::error!("Sending `Protocol::End` after: {error:?} ");
118 }
119 tracing::info!("PTY ended by its own accord");
120 });
121 }
122
123 fn kill_on_protocol_end(
125 mut protocol_in: tokio::sync::broadcast::Receiver<crate::Protocol>,
126 mut spawn: Box<dyn portable_pty::ChildKiller + Send + Sync>,
127 ) {
128 let current_span = tracing::Span::current();
129 tokio::spawn(
130 async move {
131 tracing::debug!("Starting loop for PTY spawn to receive protocol messages");
132 loop {
133 match protocol_in.recv().await {
134 Ok(message) => {
135 if matches!(message, crate::Protocol::End) {
136 tracing::debug!("PTY received Tattoy message {message:?}");
137 let result = spawn.kill();
138 if let Err(error) = result {
139 let pty_exit = "No such process";
142 if error.to_string().contains(pty_exit) {
143 tracing::debug!("Tried killing PTY that was already gone.");
144 break;
145 }
146
147 tracing::error!("Couldn't kill PTY: {error:?}");
148 }
150
151 tracing::debug!(
152 "`kill()` (which includes OS kill signals) sent to PTY spawn process"
153 );
154 break;
155 }
156 }
157 Err(error) => {
158 tracing::error!("Reading protocol from PTY loop: {error:?}");
159 }
160 }
161 }
162 tracing::debug!("Leaving spawn shutdown listener loop.");
163 }
164 .instrument(current_span),
165 );
166 }
167
168 pub(crate) async fn run(
170 self,
171 user_input_rx: mpsc::Receiver<BytesFromSTDIN>,
172 internal_input_rx: mpsc::Receiver<BytesFromSTDIN>,
173 ) -> Result<(), crate::errors::PTYError> {
174 let (pty_reader_tx, mut pty_reader_rx) = tokio::sync::mpsc::channel(1);
175
176 let mut protocol_for_main_loop = self.control_tx.subscribe();
180
181 let pty_pair = self.setup_pty()?;
182 let pty_writer = pty_pair
183 .master
184 .take_writer()
185 .with_whatever_context(|err| format!("Getting PTY writer: {err:?}"))?;
186 let pty_reader = pty_pair
187 .master
188 .try_clone_reader()
189 .with_whatever_context(|err| format!("Getting PTY reader: {err:?}"))?;
190
191 Self::pty_reader_loop(pty_reader, pty_reader_tx);
192
193 drop(pty_pair.slave);
195
196 let protocol_for_input_loop = self.control_tx.subscribe();
198 let current_span = tracing::Span::current();
199 tokio::spawn(async move {
200 let result = Self::forward_input(
201 user_input_rx,
202 internal_input_rx,
203 pty_writer,
204 pty_pair.master,
205 protocol_for_input_loop,
206 )
207 .instrument(current_span)
208 .await;
209 if let Err(err) = result {
210 tracing::error!("Writing to PTY stream: {err}");
211 }
212 });
213
214 tracing::debug!("Starting PTY reader loop");
215 #[expect(
216 clippy::integer_division_remainder_used,
217 reason = "`tokio::select!` generates this."
218 )]
219 loop {
220 tokio::select! {
221 result = self.read_stream(&mut pty_reader_rx) => {
222 if let Err(error) = result {
223 tracing::error!("{error:?}");
225 snafu::whatever!("{error:?}");
226 }
227 }
228 result = protocol_for_main_loop.recv() => {
229 match result {
230 Ok(message) => {
231 if matches!(message, crate::Protocol::End) {
232 break;
233 }
234 }
235 Err(err) => {
236 tracing::error!("{err:?}");
238 snafu::whatever!("{err:?}");
239 },
240
241 }
242 }
243
244 }
245 }
246
247 tracing::debug!("PTY reader loop finished");
248 Ok(())
249 }
250
251 async fn read_stream(
253 &self,
254 pty_reader_rx: &mut mpsc::Receiver<BytesFromPTY>,
255 ) -> Result<(), crate::errors::PTYError> {
256 let Some(bytes) = pty_reader_rx.recv().await else {
257 return Ok(());
258 };
259
260 let result = self.output_tx.send(bytes).await;
261 if let Err(err) = result {
262 tracing::error!("Sending bytes on PTY output channel: {err}");
263 }
264
265 let output = String::from_utf8_lossy(&bytes)
266 .to_string()
267 .replace('\x1b', "^");
268 tracing::trace!("Sent PTY output, sample:\n{:.500}...", output);
269
270 Ok(())
271 }
272
273 async fn forward_input(
275 mut user_input: mpsc::Receiver<BytesFromSTDIN>,
276 mut internal_input: mpsc::Receiver<BytesFromSTDIN>,
277 mut pty_writer: std::boxed::Box<dyn std::io::Write + std::marker::Send>,
278 pty_master: std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
279 mut protocol: tokio::sync::broadcast::Receiver<crate::Protocol>,
280 ) -> Result<(), crate::errors::PTYError> {
281 tracing::debug!("Starting `forward_input` loop");
282
283 #[expect(
284 clippy::integer_division_remainder_used,
285 reason = "This is generated by the `tokio::select!`"
286 )]
287 loop {
288 tokio::select! {
289 message = protocol.recv() => {
290 Self::handle_protocol_message_for_input_loop(&message, &pty_master)?;
291 if matches!(message, Ok(crate::Protocol::End)) {
292 break;
293 }
294 }
295 Some(some_bytes) = user_input.recv() => {
296 Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
297 }
298 Some(some_bytes) = internal_input.recv() => {
299 Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
300 }
301 }
302 }
303
304 tracing::debug!("`forward_input` loop finished");
305 Ok(())
306 }
307
308 fn handle_protocol_message_for_input_loop(
310 message: &std::result::Result<crate::Protocol, tokio::sync::broadcast::error::RecvError>,
311 pty_master: &std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
312 ) -> Result<(), crate::errors::PTYError> {
313 match message {
314 Ok(crate::Protocol::End) => {
315 tracing::trace!("PTY input forwarder task received {message:?}");
316 return Ok(());
317 }
318 Ok(crate::Protocol::Resize { width, height }) => {
319 tracing::debug!("Resize event received on PTY input loop {message:?}");
320
321 let result = pty_master.resize(Self::pty_size(*width, *height));
322 if result.is_err() {
323 tracing::error!("Couldn't resize underlying PTY subprocesss: {result:?}");
324 }
325 }
326 Ok(_) => (),
327 Err(err) => snafu::whatever!("{err:?}"),
328 }
329
330 Ok(())
331 }
332
333 fn handle_input_bytes(
335 bytes: BytesFromSTDIN,
336 pty_stdin: &mut std::boxed::Box<dyn std::io::Write + std::marker::Send>,
337 ) -> Result<(), crate::errors::PTYError> {
338 tracing::trace!(
339 "Forwarding input to PTY: '{}'",
340 String::from_utf8_lossy(&bytes)
341 .replace('\n', "\\n")
342 .replace('\x1b', "^")
343 );
344
345 let maybe_size = bytes.iter().position(|byte| byte == &0);
346 let size = maybe_size.unwrap_or(128);
347 let byte_slice = bytes.get(0..size).with_whatever_context(|| {
348 "Couldn't get slice of input payload. Should be impossible."
349 })?;
350
351 pty_stdin
352 .write_all(byte_slice)
353 .with_whatever_context(|err| {
354 format!("`handle_input_bytes()`: couldn't write bytes into PTY's STDIN: {err:?}")
355 })?;
356 pty_stdin
357 .flush()
358 .with_whatever_context(|err| format!("Couldn't flush STDIN stream to PTY: {err:?}"))?;
359
360 Ok(())
361 }
362
363 const fn pty_size(width: u16, height: u16) -> portable_pty::PtySize {
365 portable_pty::PtySize {
366 cols: width,
367 rows: height,
368 pixel_width: 0,
372 pixel_height: 0,
373 }
374 }
375
376 pub(crate) fn add_bytes_to_buffer(
378 buffer: &mut BytesFromSTDIN,
379 bytes: &[u8],
380 ) -> Result<(), crate::errors::PTYError> {
381 if bytes.len() > buffer.len() {
382 snafu::whatever!(
383 "Bytes ({}) to add to buffer are more than the buffer size ({}).",
384 bytes.len(),
385 buffer.len()
386 );
387 }
388 for (i, chunk_byte) in bytes.iter().enumerate() {
389 let buffer_byte = buffer
390 .get_mut(i)
391 .with_whatever_context(|| "Couldn't get byte from buffer")?;
392 *buffer_byte = *chunk_byte;
393 }
394
395 Ok(())
396 }
397}
398
399impl Drop for PTY {
400 fn drop(&mut self) {
401 tracing::debug!("PTY dropped, broadcasting `End` signal.");
402
403 let result: Result<_, crate::errors::PTYError> = self
404 .control_tx
405 .send(crate::Protocol::End)
406 .with_whatever_context(|err| {
407 format!("Couldn't send shutdown signal after PTY finished: {err:?}")
408 });
409
410 if let Err(err) = result {
411 tracing::error!("{err:?}");
412 }
413 }
414}
415
#[cfg(test)]
#[expect(clippy::print_stderr, reason = "Tests aren't so strict")]
mod test {
    use super::*;

    /// Starts a `PTY` running `command` in a background task.
    ///
    /// Returns a handle to a task that collects all PTY output into a `String` (it completes
    /// once the PTY's output channel closes), and a sender for user input.
    fn run(
        command: Vec<OsString>,
    ) -> (
        tokio::task::JoinHandle<std::string::String>,
        mpsc::Sender<BytesFromSTDIN>,
    ) {
        let (pty_output_tx, mut pty_output_rx) = mpsc::channel::<BytesFromPTY>(8);
        let (pty_input_tx, pty_input_rx) = mpsc::channel::<BytesFromSTDIN>(1);
        // The sender is dropped straight away: these tests never send internal input.
        let (_, internal_input_rx) = mpsc::channel::<BytesFromSTDIN>(8);
        let (protocol_tx, _) = tokio::sync::broadcast::channel(16);

        let output_task = tokio::spawn(async move {
            tracing::debug!("TEST: Output listener loop starting...");
            let mut result: Vec<u8> = vec![];

            while let Some(bytes) = pty_output_rx.recv().await {
                result.extend(bytes.iter().copied());
            }

            let output = String::from_utf8_lossy(&result).into_owned();
            tracing::debug!("TEST: `interactive()` output: {output:?}");
            output
        });

        tokio::spawn(async move {
            tracing::debug!("TEST: PTY.run() starting...");
            let pty = PTY {
                command,
                width: 10,
                height: 10,
                output_tx: pty_output_tx,
                control_tx: protocol_tx.clone(),
            };
            let result = pty.run(pty_input_rx, internal_input_rx).await;
            if let Err(err) = result {
                tracing::warn!("PTY (for tests) handle: {err:?}");
            }
            tracing::debug!("Test PTY.run() done");
        });

        tracing::debug!("TEST: Leaving run helper...");
        (output_task, pty_input_tx)
    }

    /// Builds a shell command line that `cat`s the `cat_me.txt` fixture and then sleeps.
    fn cat_earth_command() -> String {
        let cat_command = "cat";
        let path = crate::tests::helpers::workspace_dir()
            .join("shadow-terminal")
            .join("src")
            .join("tests")
            .join("cat_me.txt");

        #[cfg(not(target_os = "windows"))]
        let sleep = "&& sleep 0.5";
        #[cfg(target_os = "windows")]
        let sleep = "; Start-Sleep -Milliseconds 5";

        format!("{cat_command} {} {sleep}", path.display())
    }

    /// Copies `input` into a zero-padded `BytesFromSTDIN` buffer.
    ///
    /// Panics if `input` is longer than the buffer.
    fn stdin_bytes(input: &str) -> BytesFromSTDIN {
        let mut buffer: BytesFromSTDIN = [0; 128];
        #[expect(
            clippy::indexing_slicing,
            reason = "How do I do a range slice with []?"
        )]
        buffer[..input.len()].copy_from_slice(input.as_bytes());
        buffer
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_output() {
        let mut command = crate::tests::helpers::get_canonical_shell();

        #[cfg(not(target_os = "windows"))]
        command.push("-c".into());
        #[cfg(target_os = "windows")]
        command.push("-Command".into());

        command.push(cat_earth_command().into());

        let (output_task, _) = run(command);
        let result = output_task.await.unwrap();
        eprintln!("{result}");

        assert!(result.contains("earth"));
    }

    #[cfg(not(target_os = "windows"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn interactive() {
        let (output_task, input_channel) = run(crate::tests::helpers::get_canonical_shell());
        // Give the shell a moment to start before typing into it.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // This test only runs off Windows, so the POSIX `exit` form is the only one needed.
        let command = format!("{} && exit\n", cat_earth_command());

        input_channel
            .send(stdin_bytes(command.as_ref()))
            .await
            .unwrap();

        // The output task completes only once the PTY (and so its output channel) is gone.
        let result = output_task.await.unwrap();
        eprintln!("{result}");

        assert!(result.contains("earth"));
    }
}