Skip to main content

rmux_client/
attach.rs

1//! Raw terminal lifecycle and attach-stream helpers for attach-mode clients.
2
3use std::io::{self, Read, Write};
4use std::net::Shutdown;
5use std::os::fd::AsFd;
6use std::os::unix::net::UnixStream;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{mpsc, Arc};
9use std::thread;
10use std::time::Duration;
11
12use rmux_proto::{
13    encode_attach_message, AttachFrameDecoder, AttachMessage, AttachedKeystroke, RmuxError,
14    TerminalSize,
15};
16use rustix::event::{poll, PollFd, PollFlags, Timespec};
17use rustix::process::{kill_process, Signal};
18
19use crate::ClientError;
20
21#[path = "attach/resize.rs"]
22mod resize;
23#[path = "attach/screen.rs"]
24mod screen;
25#[path = "attach/terminal.rs"]
26mod terminal;
27#[path = "attach/terminal_cleanup.rs"]
28mod terminal_cleanup;
29
30use resize::{terminal_size_from_fd, ResizeWatcher, SignalMaskGuard};
31use screen::{
32    contains_subslice, AttachScreenTracker, AttachStopDetector, ALT_SCREEN_EXIT_FALLBACK,
33    DETACHED_BANNER_PREFIX, EXITED_BANNER,
34};
35use terminal::current_process_pid;
36pub use terminal::{AttachError, RawTerminal, Result};
37
38#[cfg(test)]
39use terminal_cleanup::fallback_attach_stop_sequence;
40
41const READ_BUFFER_SIZE: usize = 8192;
42const POLL_TIMEOUT: Timespec = Timespec {
43    tv_sec: 0,
44    tv_nsec: 100_000_000,
45};
46
47/// Runs the attach loop using the process stdin/stdout streams.
48pub fn attach_terminal(stream: UnixStream) -> std::result::Result<(), ClientError> {
49    attach_terminal_with_initial_bytes(stream, Vec::new())
50}
51
52/// Runs the attach loop using process stdin/stdout and pre-read stream bytes.
53pub fn attach_terminal_with_initial_bytes(
54    stream: UnixStream,
55    initial_bytes: Vec<u8>,
56) -> std::result::Result<(), ClientError> {
57    let terminal = io::stdin();
58    let input = io::stdin();
59    let output = io::stdout();
60
61    attach_with_terminal_with_initial_bytes(stream, initial_bytes, &terminal, input, output)
62}
63
64/// Runs the attach loop with an explicit terminal file descriptor.
65///
66/// The `terminal` handle is used for raw-mode lifecycle and resize discovery,
67/// while `input` and `output` carry the byte stream.
68pub fn attach_with_terminal<Terminal, Input, Output>(
69    stream: UnixStream,
70    terminal: &Terminal,
71    input: Input,
72    output: Output,
73) -> std::result::Result<(), ClientError>
74where
75    Terminal: AsFd,
76    Input: Read + AsFd + Send + 'static,
77    Output: Write + Send + 'static,
78{
79    attach_with_terminal_with_initial_bytes(stream, Vec::new(), terminal, input, output)
80}
81
82fn attach_with_terminal_with_initial_bytes<Terminal, Input, Output>(
83    stream: UnixStream,
84    initial_bytes: Vec<u8>,
85    terminal: &Terminal,
86    input: Input,
87    output: Output,
88) -> std::result::Result<(), ClientError>
89where
90    Terminal: AsFd,
91    Input: Read + AsFd + Send + 'static,
92    Output: Write + Send + 'static,
93{
94    let raw_terminal = RawTerminal::from_fd(terminal).map_err(ClientError::from)?;
95    let _ = raw_terminal.flush_pending_input();
96    let screen_tracker = AttachScreenTracker::default();
97    let result = drive_attach_with_terminal_state(
98        stream,
99        initial_bytes,
100        terminal,
101        &raw_terminal,
102        &screen_tracker,
103        input,
104        output,
105    );
106    if result.is_err() && !screen_tracker.was_stopped() {
107        let _ = raw_terminal.restore_attach_terminal_state();
108    }
109    let _ = raw_terminal.flush_pending_input();
110    drop(raw_terminal);
111    result
112}
113
114fn drive_attach_with_terminal_state<Terminal, Input, Output>(
115    stream: UnixStream,
116    initial_bytes: Vec<u8>,
117    terminal: &Terminal,
118    raw_terminal: &RawTerminal,
119    screen_tracker: &AttachScreenTracker,
120    input: Input,
121    output: Output,
122) -> std::result::Result<(), ClientError>
123where
124    Terminal: AsFd,
125    Input: Read + AsFd + Send + 'static,
126    Output: Write + Send + 'static,
127{
128    // This helper runs while the caller's `RawTerminal` guard is still alive,
129    // which keeps termios restoration as the last drop on every return path.
130    let _signal_mask = SignalMaskGuard::block_winch().map_err(ClientError::from)?;
131    let (resize_tx, resize_rx) = mpsc::channel();
132    let initial_size = terminal_size_from_fd(terminal).map_err(ClientError::from)?;
133    let terminal_fd = terminal
134        .as_fd()
135        .try_clone_to_owned()
136        .map_err(AttachError::from)?;
137
138    if let Some(initial_size) = initial_size {
139        resize_tx.send(initial_size).map_err(|_| {
140            ClientError::Io(io::Error::other(
141                "resize channel closed before attach start",
142            ))
143        })?;
144    }
145
146    let resize_watcher = ResizeWatcher::spawn(terminal_fd, resize_tx)?;
147    let attach_result = drive_attach_stream_with_locking(
148        stream,
149        initial_bytes,
150        raw_terminal,
151        screen_tracker,
152        input,
153        output,
154        resize_rx,
155    );
156    drop(resize_watcher);
157    attach_result
158}
159
160/// Drives raw attach-stream byte forwarding over an upgraded Unix socket.
161pub fn drive_attach_stream<Input, Output>(
162    stream: UnixStream,
163    input: Input,
164    output: Output,
165    resize_events: mpsc::Receiver<TerminalSize>,
166) -> std::result::Result<(), ClientError>
167where
168    Input: Read + AsFd + Send + 'static,
169    Output: Write + Send + 'static,
170{
171    drive_attach_stream_inner(
172        stream,
173        Vec::new(),
174        None,
175        AttachScreenTracker::default(),
176        input,
177        output,
178        resize_events,
179    )
180}
181
182fn drive_attach_stream_with_locking<Input, Output>(
183    stream: UnixStream,
184    initial_bytes: Vec<u8>,
185    raw_terminal: &RawTerminal,
186    screen_tracker: &AttachScreenTracker,
187    input: Input,
188    output: Output,
189    resize_events: mpsc::Receiver<TerminalSize>,
190) -> std::result::Result<(), ClientError>
191where
192    Input: Read + AsFd + Send + 'static,
193    Output: Write + Send + 'static,
194{
195    drive_attach_stream_inner(
196        stream,
197        initial_bytes,
198        Some(raw_terminal),
199        screen_tracker.clone(),
200        input,
201        output,
202        resize_events,
203    )
204}
205
206fn drive_attach_stream_inner<Input, Output>(
207    stream: UnixStream,
208    initial_bytes: Vec<u8>,
209    raw_terminal: Option<&RawTerminal>,
210    screen_tracker: AttachScreenTracker,
211    input: Input,
212    output: Output,
213    resize_events: mpsc::Receiver<TerminalSize>,
214) -> std::result::Result<(), ClientError>
215where
216    Input: Read + AsFd + Send + 'static,
217    Output: Write + Send + 'static,
218{
219    let control = stream.try_clone().map_err(ClientError::Io)?;
220    let mut lock_stream = stream.try_clone().map_err(ClientError::Io)?;
221    let input_stream = stream.try_clone().map_err(ClientError::Io)?;
222    let closed = Arc::new(AtomicBool::new(false));
223    let input_closed = Arc::clone(&closed);
224    let output_closed = Arc::clone(&closed);
225    let locked = Arc::new(AtomicBool::new(false));
226    let input_locked = Arc::clone(&locked);
227    let output_locked = Arc::clone(&locked);
228    let (action_tx, action_rx) = mpsc::channel();
229
230    let input_thread = thread::spawn(move || {
231        input_loop(
232            input_stream,
233            input,
234            resize_events,
235            input_closed,
236            input_locked,
237        )
238    });
239    let output_screen_tracker = screen_tracker.clone();
240    let output_thread = thread::spawn(move || {
241        output_loop(
242            stream,
243            initial_bytes,
244            output,
245            output_closed,
246            output_locked,
247            output_screen_tracker,
248            action_tx,
249        )
250    });
251
252    let output_result = wait_for_output_thread(
253        output_thread,
254        raw_terminal,
255        &mut lock_stream,
256        &locked,
257        action_rx,
258    )?;
259    closed.store(true, Ordering::SeqCst);
260    let _ = control.shutdown(Shutdown::Both);
261    let input_result = join_attach_thread(input_thread)?;
262
263    output_result?;
264    input_result
265}
266
267fn input_loop<Input>(
268    mut stream: UnixStream,
269    mut input: Input,
270    resize_events: mpsc::Receiver<TerminalSize>,
271    closed: Arc<AtomicBool>,
272    locked: Arc<AtomicBool>,
273) -> std::result::Result<(), ClientError>
274where
275    Input: Read + AsFd,
276{
277    let mut read_buffer = [0_u8; READ_BUFFER_SIZE];
278
279    loop {
280        if closed.load(Ordering::SeqCst) {
281            return Ok(());
282        }
283
284        drain_resize_events(&mut stream, &resize_events)?;
285        if locked.load(Ordering::SeqCst) {
286            thread::sleep(Duration::from_millis(20));
287            continue;
288        }
289
290        let mut fds = [PollFd::new(
291            &input,
292            PollFlags::IN | PollFlags::ERR | PollFlags::HUP,
293        )];
294        match poll(&mut fds, Some(&POLL_TIMEOUT)) {
295            Ok(0) => continue,
296            Ok(_) => {}
297            Err(rustix::io::Errno::INTR) => continue,
298            Err(error) => return Err(ClientError::Io(error.into())),
299        }
300
301        let ready = fds[0].revents();
302        if ready.is_empty() {
303            continue;
304        }
305        if closed.load(Ordering::SeqCst) {
306            return Ok(());
307        }
308        if !ready.contains(PollFlags::IN) {
309            if ready.contains(PollFlags::HUP) || ready.contains(PollFlags::ERR) {
310                shutdown_attach_writes(&stream)?;
311                return Ok(());
312            }
313            continue;
314        }
315
316        let bytes_read = match input.read(&mut read_buffer) {
317            Ok(0) => {
318                shutdown_attach_writes(&stream)?;
319                return Ok(());
320            }
321            Ok(bytes_read) => bytes_read,
322            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
323            Err(error) => return Err(ClientError::Io(error)),
324        };
325
326        write_attach_message(
327            &mut stream,
328            AttachMessage::Keystroke(AttachedKeystroke::new(read_buffer[..bytes_read].to_vec())),
329        )?;
330    }
331}
332
333fn output_loop<Output>(
334    mut stream: UnixStream,
335    initial_bytes: Vec<u8>,
336    mut output: Output,
337    closed: Arc<AtomicBool>,
338    locked: Arc<AtomicBool>,
339    screen_tracker: AttachScreenTracker,
340    action_tx: mpsc::Sender<ClientAttachAction>,
341) -> std::result::Result<(), ClientError>
342where
343    Output: Write,
344{
345    let mut decoder = AttachFrameDecoder::new();
346    decoder.push_bytes(&initial_bytes);
347    let mut read_buffer = [0_u8; READ_BUFFER_SIZE];
348    let mut stop_detector = AttachStopDetector::new(screen_tracker.clone());
349
350    loop {
351        while let Some(message) = decoder.next_message().map_err(ClientError::from)? {
352            match message {
353                AttachMessage::Data(bytes) => {
354                    if contains_subslice(&bytes, ALT_SCREEN_EXIT_FALLBACK)
355                        || contains_subslice(&bytes, DETACHED_BANNER_PREFIX)
356                        || contains_subslice(&bytes, EXITED_BANNER)
357                    {
358                        screen_tracker.mark_stopped();
359                    }
360                    stop_detector.observe(&bytes);
361                    if locked.load(Ordering::SeqCst) {
362                        continue;
363                    }
364                    output.write_all(&bytes).map_err(ClientError::Io)?;
365                    output.flush().map_err(ClientError::Io)?;
366                }
367                AttachMessage::KeyDispatched(_) => {}
368                AttachMessage::Resize(_) => {
369                    return Err(ClientError::Protocol(RmuxError::Decode(
370                        "received unexpected resize message from attach stream".to_owned(),
371                    )));
372                }
373                AttachMessage::Lock(command) => {
374                    locked.store(true, Ordering::SeqCst);
375                    action_tx
376                        .send(ClientAttachAction::Lock(command))
377                        .map_err(|_| {
378                            ClientError::Io(io::Error::other("lock request receiver closed"))
379                        })?;
380                }
381                AttachMessage::LockShellCommand(command) => {
382                    locked.store(true, Ordering::SeqCst);
383                    action_tx
384                        .send(ClientAttachAction::Lock(command.command().to_owned()))
385                        .map_err(|_| {
386                            ClientError::Io(io::Error::other("lock request receiver closed"))
387                        })?;
388                }
389                AttachMessage::Suspend => {
390                    locked.store(true, Ordering::SeqCst);
391                    action_tx.send(ClientAttachAction::Suspend).map_err(|_| {
392                        ClientError::Io(io::Error::other("suspend request receiver closed"))
393                    })?;
394                }
395                AttachMessage::DetachKill => {
396                    closed.store(true, Ordering::SeqCst);
397                    action_tx
398                        .send(ClientAttachAction::DetachKill)
399                        .map_err(|_| {
400                            ClientError::Io(io::Error::other("detach request receiver closed"))
401                        })?;
402                    return Ok(());
403                }
404                AttachMessage::DetachExec(command) => {
405                    closed.store(true, Ordering::SeqCst);
406                    action_tx
407                        .send(ClientAttachAction::DetachExec(command))
408                        .map_err(|_| {
409                            ClientError::Io(io::Error::other("detach request receiver closed"))
410                        })?;
411                    return Ok(());
412                }
413                AttachMessage::DetachExecShellCommand(command) => {
414                    closed.store(true, Ordering::SeqCst);
415                    action_tx
416                        .send(ClientAttachAction::DetachExec(command.command().to_owned()))
417                        .map_err(|_| {
418                            ClientError::Io(io::Error::other("detach request receiver closed"))
419                        })?;
420                    return Ok(());
421                }
422                AttachMessage::Unlock => {
423                    return Err(ClientError::Protocol(RmuxError::Decode(
424                        "received unexpected unlock message from attach stream".to_owned(),
425                    )));
426                }
427                AttachMessage::Keystroke(_) => {
428                    return Err(ClientError::Protocol(RmuxError::Decode(
429                        "received unexpected keystroke message from attach stream".to_owned(),
430                    )));
431                }
432            }
433        }
434
435        let bytes_read = match stream.read(&mut read_buffer) {
436            Ok(0) => {
437                closed.store(true, Ordering::SeqCst);
438                if screen_tracker.was_stopped() {
439                    return Ok(());
440                }
441                return Err(ClientError::Io(io::Error::new(
442                    io::ErrorKind::UnexpectedEof,
443                    "attach stream closed before attach-stop sequence",
444                )));
445            }
446            Ok(bytes_read) => bytes_read,
447            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
448            Err(error)
449                if screen_tracker.was_stopped()
450                    && matches!(
451                        error.kind(),
452                        io::ErrorKind::ConnectionReset | io::ErrorKind::BrokenPipe
453                    ) =>
454            {
455                return Ok(());
456            }
457            Err(error) => return Err(ClientError::Io(error)),
458        };
459
460        decoder.push_bytes(&read_buffer[..bytes_read]);
461    }
462}
463
464fn wait_for_output_thread(
465    output_thread: thread::JoinHandle<std::result::Result<(), ClientError>>,
466    raw_terminal: Option<&RawTerminal>,
467    lock_stream: &mut UnixStream,
468    locked: &Arc<AtomicBool>,
469    action_rx: mpsc::Receiver<ClientAttachAction>,
470) -> std::result::Result<std::result::Result<(), ClientError>, ClientError> {
471    loop {
472        match action_rx.recv_timeout(Duration::from_millis(20)) {
473            Ok(action) => handle_attach_action(raw_terminal, lock_stream, locked, action)?,
474            Err(mpsc::RecvTimeoutError::Timeout) if output_thread.is_finished() => break,
475            Err(mpsc::RecvTimeoutError::Timeout) => {}
476            Err(mpsc::RecvTimeoutError::Disconnected) => break,
477        }
478    }
479
480    while let Ok(action) = action_rx.try_recv() {
481        handle_attach_action(raw_terminal, lock_stream, locked, action)?;
482    }
483
484    join_attach_thread(output_thread)
485}
486
487fn handle_attach_action(
488    raw_terminal: Option<&RawTerminal>,
489    lock_stream: &mut UnixStream,
490    locked: &Arc<AtomicBool>,
491    action: ClientAttachAction,
492) -> std::result::Result<(), ClientError> {
493    match action {
494        ClientAttachAction::Lock(command) => {
495            let Some(raw_terminal) = raw_terminal else {
496                locked.store(false, Ordering::SeqCst);
497                return Err(ClientError::Protocol(RmuxError::Decode(
498                    "received unexpected lock request without a managed terminal".to_owned(),
499                )));
500            };
501            raw_terminal
502                .run_lock_command(&command)
503                .map_err(ClientError::from)?;
504            write_attach_message(lock_stream, AttachMessage::Unlock)?;
505            locked.store(false, Ordering::SeqCst);
506            Ok(())
507        }
508        ClientAttachAction::Suspend => {
509            let Some(raw_terminal) = raw_terminal else {
510                locked.store(false, Ordering::SeqCst);
511                return Err(ClientError::Protocol(RmuxError::Decode(
512                    "received unexpected suspend request without a managed terminal".to_owned(),
513                )));
514            };
515            raw_terminal.suspend_self().map_err(ClientError::from)?;
516            write_attach_message(lock_stream, AttachMessage::Unlock)?;
517            locked.store(false, Ordering::SeqCst);
518            Ok(())
519        }
520        ClientAttachAction::DetachKill => {
521            if let Some(raw_terminal) = raw_terminal {
522                raw_terminal.restore().map_err(ClientError::from)?;
523            }
524            kill_process(current_process_pid().map_err(ClientError::Io)?, Signal::HUP)
525                .map_err(|error| ClientError::Io(error.into()))?;
526            Ok(())
527        }
528        ClientAttachAction::DetachExec(command) => {
529            let Some(raw_terminal) = raw_terminal else {
530                return Err(ClientError::Protocol(RmuxError::Decode(
531                    "received unexpected detach exec request without a managed terminal".to_owned(),
532                )));
533            };
534            raw_terminal
535                .run_detach_exec_command(&command)
536                .map_err(ClientError::from)
537        }
538    }
539}
540
541fn drain_resize_events(
542    stream: &mut UnixStream,
543    resize_events: &mpsc::Receiver<TerminalSize>,
544) -> std::result::Result<(), ClientError> {
545    while let Ok(size) = resize_events.try_recv() {
546        write_attach_message(stream, AttachMessage::Resize(size))?;
547    }
548
549    Ok(())
550}
551
552fn write_attach_message(
553    stream: &mut UnixStream,
554    message: AttachMessage,
555) -> std::result::Result<(), ClientError> {
556    let frame = encode_attach_message(&message).map_err(ClientError::from)?;
557    stream.write_all(&frame).map_err(ClientError::Io)
558}
559
560fn join_attach_thread(
561    thread: thread::JoinHandle<std::result::Result<(), ClientError>>,
562) -> std::result::Result<std::result::Result<(), ClientError>, ClientError> {
563    thread
564        .join()
565        .map_err(|_| ClientError::Io(io::Error::other("attach thread panicked")))
566}
567
568fn shutdown_attach_writes(stream: &UnixStream) -> std::result::Result<(), ClientError> {
569    match stream.shutdown(Shutdown::Write) {
570        Ok(()) => Ok(()),
571        Err(error) if error.kind() == io::ErrorKind::NotConnected => Ok(()),
572        Err(error) => Err(ClientError::Io(error)),
573    }
574}
575
/// Work the output loop hands to the supervising thread, which carries it out
/// in `handle_attach_action`.
#[derive(Debug)]
enum ClientAttachAction {
    /// Run the given lock command through the managed terminal.
    Lock(String),
    /// Suspend the client through the managed terminal.
    Suspend,
    /// Restore the terminal, then send `SIGHUP` to the current process.
    DetachKill,
    /// Run the given detach-exec command through the managed terminal.
    DetachExec(String),
}
583
584#[cfg(test)]
585mod tests;