1use std::io::{self, Read, Write};
4use std::net::Shutdown;
5use std::os::fd::AsFd;
6use std::os::unix::net::UnixStream;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{mpsc, Arc};
9use std::thread;
10use std::time::Duration;
11
12use rmux_proto::{
13 encode_attach_message, AttachFrameDecoder, AttachMessage, AttachedKeystroke, RmuxError,
14 TerminalSize,
15};
16use rustix::event::{poll, PollFd, PollFlags, Timespec};
17use rustix::process::{kill_process, Signal};
18
19use crate::ClientError;
20
21#[path = "attach/resize.rs"]
22mod resize;
23#[path = "attach/screen.rs"]
24mod screen;
25#[path = "attach/terminal.rs"]
26mod terminal;
27#[path = "attach/terminal_cleanup.rs"]
28mod terminal_cleanup;
29
30use resize::{terminal_size_from_fd, ResizeWatcher, SignalMaskGuard};
31use screen::{
32 contains_subslice, AttachScreenTracker, AttachStopDetector, ALT_SCREEN_EXIT_FALLBACK,
33 DETACHED_BANNER_PREFIX, EXITED_BANNER,
34};
35use terminal::current_process_pid;
36pub use terminal::{AttachError, RawTerminal, Result};
37
38#[cfg(test)]
39use terminal_cleanup::fallback_attach_stop_sequence;
40
41const READ_BUFFER_SIZE: usize = 8192;
42const POLL_TIMEOUT: Timespec = Timespec {
43 tv_sec: 0,
44 tv_nsec: 100_000_000,
45};
46
47pub fn attach_terminal(stream: UnixStream) -> std::result::Result<(), ClientError> {
49 attach_terminal_with_initial_bytes(stream, Vec::new())
50}
51
52pub fn attach_terminal_with_initial_bytes(
54 stream: UnixStream,
55 initial_bytes: Vec<u8>,
56) -> std::result::Result<(), ClientError> {
57 let terminal = io::stdin();
58 let input = io::stdin();
59 let output = io::stdout();
60
61 attach_with_terminal_with_initial_bytes(stream, initial_bytes, &terminal, input, output)
62}
63
64pub fn attach_with_terminal<Terminal, Input, Output>(
69 stream: UnixStream,
70 terminal: &Terminal,
71 input: Input,
72 output: Output,
73) -> std::result::Result<(), ClientError>
74where
75 Terminal: AsFd,
76 Input: Read + AsFd + Send + 'static,
77 Output: Write + Send + 'static,
78{
79 attach_with_terminal_with_initial_bytes(stream, Vec::new(), terminal, input, output)
80}
81
82fn attach_with_terminal_with_initial_bytes<Terminal, Input, Output>(
83 stream: UnixStream,
84 initial_bytes: Vec<u8>,
85 terminal: &Terminal,
86 input: Input,
87 output: Output,
88) -> std::result::Result<(), ClientError>
89where
90 Terminal: AsFd,
91 Input: Read + AsFd + Send + 'static,
92 Output: Write + Send + 'static,
93{
94 let raw_terminal = RawTerminal::from_fd(terminal).map_err(ClientError::from)?;
95 let _ = raw_terminal.flush_pending_input();
96 let screen_tracker = AttachScreenTracker::default();
97 let result = drive_attach_with_terminal_state(
98 stream,
99 initial_bytes,
100 terminal,
101 &raw_terminal,
102 &screen_tracker,
103 input,
104 output,
105 );
106 if result.is_err() && !screen_tracker.was_stopped() {
107 let _ = raw_terminal.restore_attach_terminal_state();
108 }
109 let _ = raw_terminal.flush_pending_input();
110 drop(raw_terminal);
111 result
112}
113
114fn drive_attach_with_terminal_state<Terminal, Input, Output>(
115 stream: UnixStream,
116 initial_bytes: Vec<u8>,
117 terminal: &Terminal,
118 raw_terminal: &RawTerminal,
119 screen_tracker: &AttachScreenTracker,
120 input: Input,
121 output: Output,
122) -> std::result::Result<(), ClientError>
123where
124 Terminal: AsFd,
125 Input: Read + AsFd + Send + 'static,
126 Output: Write + Send + 'static,
127{
128 let _signal_mask = SignalMaskGuard::block_winch().map_err(ClientError::from)?;
131 let (resize_tx, resize_rx) = mpsc::channel();
132 let initial_size = terminal_size_from_fd(terminal).map_err(ClientError::from)?;
133 let terminal_fd = terminal
134 .as_fd()
135 .try_clone_to_owned()
136 .map_err(AttachError::from)?;
137
138 if let Some(initial_size) = initial_size {
139 resize_tx.send(initial_size).map_err(|_| {
140 ClientError::Io(io::Error::other(
141 "resize channel closed before attach start",
142 ))
143 })?;
144 }
145
146 let resize_watcher = ResizeWatcher::spawn(terminal_fd, resize_tx)?;
147 let attach_result = drive_attach_stream_with_locking(
148 stream,
149 initial_bytes,
150 raw_terminal,
151 screen_tracker,
152 input,
153 output,
154 resize_rx,
155 );
156 drop(resize_watcher);
157 attach_result
158}
159
160pub fn drive_attach_stream<Input, Output>(
162 stream: UnixStream,
163 input: Input,
164 output: Output,
165 resize_events: mpsc::Receiver<TerminalSize>,
166) -> std::result::Result<(), ClientError>
167where
168 Input: Read + AsFd + Send + 'static,
169 Output: Write + Send + 'static,
170{
171 drive_attach_stream_inner(
172 stream,
173 Vec::new(),
174 None,
175 AttachScreenTracker::default(),
176 input,
177 output,
178 resize_events,
179 )
180}
181
182fn drive_attach_stream_with_locking<Input, Output>(
183 stream: UnixStream,
184 initial_bytes: Vec<u8>,
185 raw_terminal: &RawTerminal,
186 screen_tracker: &AttachScreenTracker,
187 input: Input,
188 output: Output,
189 resize_events: mpsc::Receiver<TerminalSize>,
190) -> std::result::Result<(), ClientError>
191where
192 Input: Read + AsFd + Send + 'static,
193 Output: Write + Send + 'static,
194{
195 drive_attach_stream_inner(
196 stream,
197 initial_bytes,
198 Some(raw_terminal),
199 screen_tracker.clone(),
200 input,
201 output,
202 resize_events,
203 )
204}
205
206fn drive_attach_stream_inner<Input, Output>(
207 stream: UnixStream,
208 initial_bytes: Vec<u8>,
209 raw_terminal: Option<&RawTerminal>,
210 screen_tracker: AttachScreenTracker,
211 input: Input,
212 output: Output,
213 resize_events: mpsc::Receiver<TerminalSize>,
214) -> std::result::Result<(), ClientError>
215where
216 Input: Read + AsFd + Send + 'static,
217 Output: Write + Send + 'static,
218{
219 let control = stream.try_clone().map_err(ClientError::Io)?;
220 let mut lock_stream = stream.try_clone().map_err(ClientError::Io)?;
221 let input_stream = stream.try_clone().map_err(ClientError::Io)?;
222 let closed = Arc::new(AtomicBool::new(false));
223 let input_closed = Arc::clone(&closed);
224 let output_closed = Arc::clone(&closed);
225 let locked = Arc::new(AtomicBool::new(false));
226 let input_locked = Arc::clone(&locked);
227 let output_locked = Arc::clone(&locked);
228 let (action_tx, action_rx) = mpsc::channel();
229
230 let input_thread = thread::spawn(move || {
231 input_loop(
232 input_stream,
233 input,
234 resize_events,
235 input_closed,
236 input_locked,
237 )
238 });
239 let output_screen_tracker = screen_tracker.clone();
240 let output_thread = thread::spawn(move || {
241 output_loop(
242 stream,
243 initial_bytes,
244 output,
245 output_closed,
246 output_locked,
247 output_screen_tracker,
248 action_tx,
249 )
250 });
251
252 let output_result = wait_for_output_thread(
253 output_thread,
254 raw_terminal,
255 &mut lock_stream,
256 &locked,
257 action_rx,
258 )?;
259 closed.store(true, Ordering::SeqCst);
260 let _ = control.shutdown(Shutdown::Both);
261 let input_result = join_attach_thread(input_thread)?;
262
263 output_result?;
264 input_result
265}
266
267fn input_loop<Input>(
268 mut stream: UnixStream,
269 mut input: Input,
270 resize_events: mpsc::Receiver<TerminalSize>,
271 closed: Arc<AtomicBool>,
272 locked: Arc<AtomicBool>,
273) -> std::result::Result<(), ClientError>
274where
275 Input: Read + AsFd,
276{
277 let mut read_buffer = [0_u8; READ_BUFFER_SIZE];
278
279 loop {
280 if closed.load(Ordering::SeqCst) {
281 return Ok(());
282 }
283
284 drain_resize_events(&mut stream, &resize_events)?;
285 if locked.load(Ordering::SeqCst) {
286 thread::sleep(Duration::from_millis(20));
287 continue;
288 }
289
290 let mut fds = [PollFd::new(
291 &input,
292 PollFlags::IN | PollFlags::ERR | PollFlags::HUP,
293 )];
294 match poll(&mut fds, Some(&POLL_TIMEOUT)) {
295 Ok(0) => continue,
296 Ok(_) => {}
297 Err(rustix::io::Errno::INTR) => continue,
298 Err(error) => return Err(ClientError::Io(error.into())),
299 }
300
301 let ready = fds[0].revents();
302 if ready.is_empty() {
303 continue;
304 }
305 if closed.load(Ordering::SeqCst) {
306 return Ok(());
307 }
308 if !ready.contains(PollFlags::IN) {
309 if ready.contains(PollFlags::HUP) || ready.contains(PollFlags::ERR) {
310 shutdown_attach_writes(&stream)?;
311 return Ok(());
312 }
313 continue;
314 }
315
316 let bytes_read = match input.read(&mut read_buffer) {
317 Ok(0) => {
318 shutdown_attach_writes(&stream)?;
319 return Ok(());
320 }
321 Ok(bytes_read) => bytes_read,
322 Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
323 Err(error) => return Err(ClientError::Io(error)),
324 };
325
326 write_attach_message(
327 &mut stream,
328 AttachMessage::Keystroke(AttachedKeystroke::new(read_buffer[..bytes_read].to_vec())),
329 )?;
330 }
331}
332
/// Output half of an attach session: decodes attach frames and writes terminal data
/// to `output`.
///
/// `initial_bytes` are pushed into the decoder before anything is read from `stream`
/// (presumably bytes the caller already received from the socket — confirm).
///
/// Control frames are forwarded to the controlling thread through `action_tx`:
/// - lock, lock-shell-command and suspend set `locked` before sending the request.
///   This stops output here, and, via the shared flag, makes the input loop pause.
/// - the detach variants set `closed` before sending the request, then end the loop.
///
/// Returns `Ok(())` after a detach request. It also returns `Ok(())` on EOF or on a
/// connection reset / broken pipe, but only once the screen tracker has recorded a stop.
///
/// # Errors
/// - decode failures from the frame decoder;
/// - protocol errors for frames a server must not send to a client
///   (`Resize`, `Unlock`, `Keystroke`);
/// - an `UnexpectedEof` I/O error if the stream closes before a stop was recorded;
/// - I/O errors writing to `output` or reading `stream`;
/// - failure to hand an action to the receiver.
fn output_loop<Output>(
    mut stream: UnixStream,
    initial_bytes: Vec<u8>,
    mut output: Output,
    closed: Arc<AtomicBool>,
    locked: Arc<AtomicBool>,
    screen_tracker: AttachScreenTracker,
    action_tx: mpsc::Sender<ClientAttachAction>,
) -> std::result::Result<(), ClientError>
where
    Output: Write,
{
    let mut decoder = AttachFrameDecoder::new();
    decoder.push_bytes(&initial_bytes);
    let mut read_buffer = [0_u8; READ_BUFFER_SIZE];
    // Shares stop-state with `screen_tracker` through a clone.
    let mut stop_detector = AttachStopDetector::new(screen_tracker.clone());

    loop {
        // Drain every complete frame currently buffered before reading more.
        while let Some(message) = decoder.next_message().map_err(ClientError::from)? {
            match message {
                AttachMessage::Data(bytes) => {
                    // Mark a stop immediately when one frame contains a stop marker. The
                    // detector also observes every frame (presumably to catch markers
                    // split across frames — confirm in attach/screen.rs).
                    if contains_subslice(&bytes, ALT_SCREEN_EXIT_FALLBACK)
                        || contains_subslice(&bytes, DETACHED_BANNER_PREFIX)
                        || contains_subslice(&bytes, EXITED_BANNER)
                    {
                        screen_tracker.mark_stopped();
                    }
                    stop_detector.observe(&bytes);
                    // While locked, data is still observed for stop detection but not
                    // rendered. `continue` moves on to the next decoded frame.
                    if locked.load(Ordering::SeqCst) {
                        continue;
                    }
                    output.write_all(&bytes).map_err(ClientError::Io)?;
                    output.flush().map_err(ClientError::Io)?;
                }
                AttachMessage::KeyDispatched(_) => {}
                AttachMessage::Resize(_) => {
                    return Err(ClientError::Protocol(RmuxError::Decode(
                        "received unexpected resize message from attach stream".to_owned(),
                    )));
                }
                // The flag is set before the send so local output and input pause before
                // the controlling thread starts the lock command.
                AttachMessage::Lock(command) => {
                    locked.store(true, Ordering::SeqCst);
                    action_tx
                        .send(ClientAttachAction::Lock(command))
                        .map_err(|_| {
                            ClientError::Io(io::Error::other("lock request receiver closed"))
                        })?;
                }
                AttachMessage::LockShellCommand(command) => {
                    locked.store(true, Ordering::SeqCst);
                    action_tx
                        .send(ClientAttachAction::Lock(command.command().to_owned()))
                        .map_err(|_| {
                            ClientError::Io(io::Error::other("lock request receiver closed"))
                        })?;
                }
                AttachMessage::Suspend => {
                    locked.store(true, Ordering::SeqCst);
                    action_tx.send(ClientAttachAction::Suspend).map_err(|_| {
                        ClientError::Io(io::Error::other("suspend request receiver closed"))
                    })?;
                }
                // Detach: signal the input loop to stop, hand off the request, and end.
                AttachMessage::DetachKill => {
                    closed.store(true, Ordering::SeqCst);
                    action_tx
                        .send(ClientAttachAction::DetachKill)
                        .map_err(|_| {
                            ClientError::Io(io::Error::other("detach request receiver closed"))
                        })?;
                    return Ok(());
                }
                AttachMessage::DetachExec(command) => {
                    closed.store(true, Ordering::SeqCst);
                    action_tx
                        .send(ClientAttachAction::DetachExec(command))
                        .map_err(|_| {
                            ClientError::Io(io::Error::other("detach request receiver closed"))
                        })?;
                    return Ok(());
                }
                AttachMessage::DetachExecShellCommand(command) => {
                    closed.store(true, Ordering::SeqCst);
                    action_tx
                        .send(ClientAttachAction::DetachExec(command.command().to_owned()))
                        .map_err(|_| {
                            ClientError::Io(io::Error::other("detach request receiver closed"))
                        })?;
                    return Ok(());
                }
                // Unlock and Keystroke travel client -> server only.
                AttachMessage::Unlock => {
                    return Err(ClientError::Protocol(RmuxError::Decode(
                        "received unexpected unlock message from attach stream".to_owned(),
                    )));
                }
                AttachMessage::Keystroke(_) => {
                    return Err(ClientError::Protocol(RmuxError::Decode(
                        "received unexpected keystroke message from attach stream".to_owned(),
                    )));
                }
            }
        }

        let bytes_read = match stream.read(&mut read_buffer) {
            Ok(0) => {
                // EOF is a clean end only if a stop was recorded first.
                closed.store(true, Ordering::SeqCst);
                if screen_tracker.was_stopped() {
                    return Ok(());
                }
                return Err(ClientError::Io(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "attach stream closed before attach-stop sequence",
                )));
            }
            Ok(bytes_read) => bytes_read,
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            // After a recorded stop, an abrupt server-side close is also treated as success.
            Err(error)
                if screen_tracker.was_stopped()
                    && matches!(
                        error.kind(),
                        io::ErrorKind::ConnectionReset | io::ErrorKind::BrokenPipe
                    ) =>
            {
                return Ok(());
            }
            Err(error) => return Err(ClientError::Io(error)),
        };

        decoder.push_bytes(&read_buffer[..bytes_read]);
    }
}
463
464fn wait_for_output_thread(
465 output_thread: thread::JoinHandle<std::result::Result<(), ClientError>>,
466 raw_terminal: Option<&RawTerminal>,
467 lock_stream: &mut UnixStream,
468 locked: &Arc<AtomicBool>,
469 action_rx: mpsc::Receiver<ClientAttachAction>,
470) -> std::result::Result<std::result::Result<(), ClientError>, ClientError> {
471 loop {
472 match action_rx.recv_timeout(Duration::from_millis(20)) {
473 Ok(action) => handle_attach_action(raw_terminal, lock_stream, locked, action)?,
474 Err(mpsc::RecvTimeoutError::Timeout) if output_thread.is_finished() => break,
475 Err(mpsc::RecvTimeoutError::Timeout) => {}
476 Err(mpsc::RecvTimeoutError::Disconnected) => break,
477 }
478 }
479
480 while let Ok(action) = action_rx.try_recv() {
481 handle_attach_action(raw_terminal, lock_stream, locked, action)?;
482 }
483
484 join_attach_thread(output_thread)
485}
486
487fn handle_attach_action(
488 raw_terminal: Option<&RawTerminal>,
489 lock_stream: &mut UnixStream,
490 locked: &Arc<AtomicBool>,
491 action: ClientAttachAction,
492) -> std::result::Result<(), ClientError> {
493 match action {
494 ClientAttachAction::Lock(command) => {
495 let Some(raw_terminal) = raw_terminal else {
496 locked.store(false, Ordering::SeqCst);
497 return Err(ClientError::Protocol(RmuxError::Decode(
498 "received unexpected lock request without a managed terminal".to_owned(),
499 )));
500 };
501 raw_terminal
502 .run_lock_command(&command)
503 .map_err(ClientError::from)?;
504 write_attach_message(lock_stream, AttachMessage::Unlock)?;
505 locked.store(false, Ordering::SeqCst);
506 Ok(())
507 }
508 ClientAttachAction::Suspend => {
509 let Some(raw_terminal) = raw_terminal else {
510 locked.store(false, Ordering::SeqCst);
511 return Err(ClientError::Protocol(RmuxError::Decode(
512 "received unexpected suspend request without a managed terminal".to_owned(),
513 )));
514 };
515 raw_terminal.suspend_self().map_err(ClientError::from)?;
516 write_attach_message(lock_stream, AttachMessage::Unlock)?;
517 locked.store(false, Ordering::SeqCst);
518 Ok(())
519 }
520 ClientAttachAction::DetachKill => {
521 if let Some(raw_terminal) = raw_terminal {
522 raw_terminal.restore().map_err(ClientError::from)?;
523 }
524 kill_process(current_process_pid().map_err(ClientError::Io)?, Signal::HUP)
525 .map_err(|error| ClientError::Io(error.into()))?;
526 Ok(())
527 }
528 ClientAttachAction::DetachExec(command) => {
529 let Some(raw_terminal) = raw_terminal else {
530 return Err(ClientError::Protocol(RmuxError::Decode(
531 "received unexpected detach exec request without a managed terminal".to_owned(),
532 )));
533 };
534 raw_terminal
535 .run_detach_exec_command(&command)
536 .map_err(ClientError::from)
537 }
538 }
539}
540
541fn drain_resize_events(
542 stream: &mut UnixStream,
543 resize_events: &mpsc::Receiver<TerminalSize>,
544) -> std::result::Result<(), ClientError> {
545 while let Ok(size) = resize_events.try_recv() {
546 write_attach_message(stream, AttachMessage::Resize(size))?;
547 }
548
549 Ok(())
550}
551
552fn write_attach_message(
553 stream: &mut UnixStream,
554 message: AttachMessage,
555) -> std::result::Result<(), ClientError> {
556 let frame = encode_attach_message(&message).map_err(ClientError::from)?;
557 stream.write_all(&frame).map_err(ClientError::Io)
558}
559
560fn join_attach_thread(
561 thread: thread::JoinHandle<std::result::Result<(), ClientError>>,
562) -> std::result::Result<std::result::Result<(), ClientError>, ClientError> {
563 thread
564 .join()
565 .map_err(|_| ClientError::Io(io::Error::other("attach thread panicked")))
566}
567
568fn shutdown_attach_writes(stream: &UnixStream) -> std::result::Result<(), ClientError> {
569 match stream.shutdown(Shutdown::Write) {
570 Ok(()) => Ok(()),
571 Err(error) if error.kind() == io::ErrorKind::NotConnected => Ok(()),
572 Err(error) => Err(ClientError::Io(error)),
573 }
574}
575
/// A request the output loop hands to the controlling thread (see `handle_attach_action`).
#[derive(Debug)]
enum ClientAttachAction {
    /// Run the given lock command, then send `Unlock` to the server.
    Lock(String),
    /// Suspend this client process, then send `Unlock` to the server.
    Suspend,
    /// Restore the terminal (if managed) and send `SIGHUP` to this process.
    DetachKill,
    /// Hand the given command to the managed terminal's detach-exec handling.
    DetachExec(String),
}
583
584#[cfg(test)]
585mod tests;