1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, SetArg, Termios};
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
7use std::path::Path;
8use std::time::Duration;
9use tokio::io::unix::AsyncFd;
10use tokio::net::UnixStream;
11use tokio::signal::unix::{SignalKind, signal};
12use tokio::time::Instant;
13use tokio_util::codec::Framed;
14use tracing::{debug, info};
15
/// Text printed for the `~?` escape sequence.
///
/// Each escape line is indented by one space; the closing note is not. The text is built
/// with `concat!` rather than a `\`-continued byte string: a trailing `\` inside a string
/// literal also skips the next line's leading whitespace, which would silently strip the
/// indentation from every escape line.
const ESCAPE_HELP: &[u8] = concat!(
    "\r\nSupported escape sequences:\r\n",
    " ~. - detach from session\r\n",
    " ~^Z - suspend client\r\n",
    " ~? - this message\r\n",
    " ~~ - send the escape character by typing it twice\r\n",
    "(Note that escapes are only recognized immediately after newline.)\r\n",
)
.as_bytes();
24
/// Where the escape recognizer stands relative to the most recent line break.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum EscapeState {
    /// Mid-line: a `~` is ordinary data.
    Normal,
    /// Just after `\n` or `\r` (or at the very start of input): a `~` may begin an escape.
    AfterNewline,
    /// A line-initial `~` was seen; the next byte decides which escape, if any, it is.
    AfterTilde,
}
31
/// One step of output produced by the escape processor, in input order.
#[derive(Debug, Eq, PartialEq)]
enum EscapeAction {
    /// Bytes to forward to the session unchanged.
    Data(Vec<u8>),
    /// `~.` — detach from the session.
    Detach,
    /// `~^Z` — suspend the client.
    Suspend,
    /// `~?` — show the escape help text.
    Help,
}
39
40struct EscapeProcessor {
41 state: EscapeState,
42}
43
44impl EscapeProcessor {
45 fn new() -> Self {
46 Self { state: EscapeState::AfterNewline }
47 }
48
49 fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
50 let mut actions = Vec::new();
51 let mut data_buf = Vec::new();
52
53 for &b in input {
54 match self.state {
55 EscapeState::Normal => {
56 if b == b'\n' || b == b'\r' {
57 self.state = EscapeState::AfterNewline;
58 }
59 data_buf.push(b);
60 }
61 EscapeState::AfterNewline => {
62 if b == b'~' {
63 self.state = EscapeState::AfterTilde;
64 if !data_buf.is_empty() {
66 actions.push(EscapeAction::Data(std::mem::take(&mut data_buf)));
67 }
68 } else if b == b'\n' || b == b'\r' {
69 data_buf.push(b);
71 } else {
72 self.state = EscapeState::Normal;
73 data_buf.push(b);
74 }
75 }
76 EscapeState::AfterTilde => {
77 match b {
78 b'.' => {
79 if !data_buf.is_empty() {
80 actions.push(EscapeAction::Data(std::mem::take(&mut data_buf)));
81 }
82 actions.push(EscapeAction::Detach);
83 return actions; }
85 0x1a => {
86 if !data_buf.is_empty() {
88 actions.push(EscapeAction::Data(std::mem::take(&mut data_buf)));
89 }
90 actions.push(EscapeAction::Suspend);
91 self.state = EscapeState::Normal;
92 }
93 b'?' => {
94 if !data_buf.is_empty() {
95 actions.push(EscapeAction::Data(std::mem::take(&mut data_buf)));
96 }
97 actions.push(EscapeAction::Help);
98 self.state = EscapeState::Normal;
99 }
100 b'~' => {
101 data_buf.push(b'~');
103 self.state = EscapeState::Normal;
104 }
105 b'\n' | b'\r' => {
106 data_buf.push(b'~');
108 data_buf.push(b);
109 self.state = EscapeState::AfterNewline;
110 }
111 _ => {
112 data_buf.push(b'~');
114 data_buf.push(b);
115 self.state = EscapeState::Normal;
116 }
117 }
118 }
119 }
120 }
121
122 if !data_buf.is_empty() {
123 actions.push(EscapeAction::Data(data_buf));
124 }
125 actions
126 }
127}
128
129fn suspend(raw_guard: &RawModeGuard, nb_guard: &NonBlockGuard) -> anyhow::Result<()> {
130 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw_guard.original)?;
132 let _ = nix::fcntl::fcntl(nb_guard.fd, nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags));
133
134 nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
135
136 let _ = nix::fcntl::fcntl(
138 nb_guard.fd,
139 nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags | nix::fcntl::OFlag::O_NONBLOCK),
140 );
141 let mut raw = raw_guard.original.clone();
142 termios::cfmakeraw(&mut raw);
143 termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw)?;
144 Ok(())
145}
146
/// Upper bound on sending a single frame; exceeding it is treated as a lost connection.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
148
149struct NonBlockGuard {
150 fd: BorrowedFd<'static>,
151 original_flags: nix::fcntl::OFlag,
152}
153
154impl NonBlockGuard {
155 fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
156 let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
157 let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
158 nix::fcntl::fcntl(
159 fd,
160 nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
161 )?;
162 Ok(Self { fd, original_flags })
163 }
164}
165
166impl Drop for NonBlockGuard {
167 fn drop(&mut self) {
168 let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
169 }
170}
171
172struct RawModeGuard {
173 fd: BorrowedFd<'static>,
174 original: Termios,
175}
176
177impl RawModeGuard {
178 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
179 let original = termios::tcgetattr(fd)?;
180 let mut raw = original.clone();
181 termios::cfmakeraw(&mut raw);
182 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
183 Ok(Self { fd, original })
184 }
185}
186
187impl Drop for RawModeGuard {
188 fn drop(&mut self) {
189 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
190 }
191}
192
/// Writes all of `data` to stdout, then flushes.
///
/// If stdout shares its open file description with stdin (the usual case when both are
/// the same terminal), the `O_NONBLOCK` flag this client sets on stdin applies here too.
/// Writing and flushing can therefore report `WouldBlock`; those errors and `Interrupted`
/// are retried. The retry only yields the thread and does not wait for writability, so a
/// stalled terminal makes this busy-wait.
///
/// # Errors
///
/// Returns any other I/O error. Returns `WriteZero` if stdout accepts no bytes for a
/// non-empty write, because no further progress is possible and retrying would loop
/// forever.
fn write_stdout(data: &[u8]) -> io::Result<()> {
    let mut stdout = io::stdout();
    let mut remaining = data;
    while !remaining.is_empty() {
        let n = retry_transient(|| stdout.write(remaining))?;
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "stdout accepted no bytes",
            ));
        }
        remaining = &remaining[n..];
    }
    retry_transient(|| stdout.flush())
}

/// Re-runs `op` while it fails with `WouldBlock` or `Interrupted` and returns the first
/// other outcome. The thread yields between `WouldBlock` attempts.
fn retry_transient<T>(mut op: impl FnMut() -> io::Result<T>) -> io::Result<T> {
    loop {
        match op() {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::yield_now(),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            outcome => return outcome,
        }
    }
}
220
221fn get_terminal_size() -> (u16, u16) {
222 let mut ws: libc::winsize = unsafe { std::mem::zeroed() };
223 unsafe { libc::ioctl(libc::STDIN_FILENO, libc::TIOCGWINSZ, &mut ws) };
224 (ws.ws_col, ws.ws_row)
225}
226
227async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
229 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
230 Ok(Ok(())) => true,
231 Ok(Err(e)) => {
232 debug!("send error: {e}");
233 false
234 }
235 Err(_) => {
236 debug!("send timed out");
237 false
238 }
239 }
240}
241
/// How often a `Ping` is sent to the server while attached.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Longest allowed gap since the last `Pong`, checked on every heartbeat tick. Exceeding
/// it ends the relay so the client reconnects.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
244
245async fn relay(
248 framed: &mut Framed<UnixStream, FrameCodec>,
249 async_stdin: &AsyncFd<io::Stdin>,
250 sigwinch: &mut tokio::signal::unix::Signal,
251 buf: &mut [u8],
252 redraw: bool,
253 env_vars: &[(String, String)],
254 mut escape: Option<&mut EscapeProcessor>,
255 raw_guard: &RawModeGuard,
256 nb_guard: &NonBlockGuard,
257) -> anyhow::Result<Option<i32>> {
258 if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
260 return Ok(None);
261 }
262 let (cols, rows) = get_terminal_size();
264 if !timed_send(framed, Frame::Resize { cols, rows }).await {
265 return Ok(None);
266 }
267 if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
269 return Ok(None);
270 }
271
272 let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
273 heartbeat_interval.reset(); let mut last_pong = Instant::now();
275
276 loop {
277 tokio::select! {
278 ready = async_stdin.readable() => {
279 let mut guard = ready?;
280 match guard.try_io(|inner| inner.get_ref().read(buf)) {
281 Ok(Ok(0)) => {
282 debug!("stdin EOF");
283 return Ok(Some(0));
284 }
285 Ok(Ok(n)) => {
286 debug!(len = n, "stdin → socket");
287 if let Some(ref mut esc) = escape {
288 for action in esc.process(&buf[..n]) {
289 match action {
290 EscapeAction::Data(data) => {
291 if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
292 return Ok(None);
293 }
294 }
295 EscapeAction::Detach => {
296 write_stdout(b"\r\n[detached]\r\n")?;
297 return Ok(Some(0));
298 }
299 EscapeAction::Suspend => {
300 suspend(raw_guard, nb_guard)?;
301 let (cols, rows) = get_terminal_size();
303 if !timed_send(framed, Frame::Resize { cols, rows }).await {
304 return Ok(None);
305 }
306 }
307 EscapeAction::Help => {
308 write_stdout(ESCAPE_HELP)?;
309 }
310 }
311 }
312 } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
313 return Ok(None);
314 }
315 }
316 Ok(Err(e)) => return Err(e.into()),
317 Err(_would_block) => continue,
318 }
319 }
320
321 frame = framed.next() => {
322 match frame {
323 Some(Ok(Frame::Data(data))) => {
324 debug!(len = data.len(), "socket → stdout");
325 write_stdout(&data)?;
326 }
327 Some(Ok(Frame::Pong)) => {
328 debug!("pong received");
329 last_pong = Instant::now();
330 }
331 Some(Ok(Frame::Exit { code })) => {
332 info!(code, "server sent exit");
333 return Ok(Some(code));
334 }
335 Some(Ok(Frame::Detached)) => {
336 info!("detached by another client");
337 write_stdout(b"[detached]\r\n")?;
338 return Ok(Some(0));
339 }
340 Some(Ok(_)) => {} Some(Err(e)) => {
342 debug!("server connection error: {e}");
343 return Ok(None);
344 }
345 None => {
346 debug!("server disconnected");
347 return Ok(None);
348 }
349 }
350 }
351
352 _ = sigwinch.recv() => {
353 let (cols, rows) = get_terminal_size();
354 debug!(cols, rows, "SIGWINCH → resize");
355 if !timed_send(framed, Frame::Resize { cols, rows }).await {
356 return Ok(None);
357 }
358 }
359
360 _ = heartbeat_interval.tick() => {
361 if last_pong.elapsed() > HEARTBEAT_TIMEOUT {
362 debug!("heartbeat timeout");
363 return Ok(None);
364 }
365 if !timed_send(framed, Frame::Ping).await {
366 return Ok(None);
367 }
368 }
369 }
370 }
371}
372
373pub async fn run(
374 session: &str,
375 mut framed: Framed<UnixStream, FrameCodec>,
376 redraw: bool,
377 ctl_path: &Path,
378 env_vars: Vec<(String, String)>,
379 no_escape: bool,
380) -> anyhow::Result<i32> {
381 let stdin = io::stdin();
382 let stdin_fd = stdin.as_fd();
383 let stdin_borrowed: BorrowedFd<'static> =
385 unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
386 let raw_guard = RawModeGuard::enter(stdin_borrowed)?;
387
388 let nb_guard = NonBlockGuard::set(stdin_borrowed)?;
391 let async_stdin = AsyncFd::new(io::stdin())?;
392 let mut sigwinch = signal(SignalKind::window_change())?;
393 let mut buf = vec![0u8; 4096];
394 let mut current_redraw = redraw;
395 let mut current_env = env_vars;
396 let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
397
398 loop {
399 match relay(
400 &mut framed,
401 &async_stdin,
402 &mut sigwinch,
403 &mut buf,
404 current_redraw,
405 ¤t_env,
406 escape.as_mut(),
407 &raw_guard,
408 &nb_guard,
409 )
410 .await?
411 {
412 Some(code) => return Ok(code),
413 None => {
414 current_env.clear();
416 write_stdout(b"[reconnecting...]\r\n")?;
418
419 loop {
420 tokio::time::sleep(Duration::from_secs(1)).await;
421
422 {
424 let mut peek = [0u8; 1];
425 match io::stdin().read(&mut peek) {
426 Ok(1) if peek[0] == 0x03 => {
427 write_stdout(b"\r\n")?;
428 return Ok(1);
429 }
430 _ => {}
431 }
432 }
433
434 let stream = match UnixStream::connect(ctl_path).await {
435 Ok(s) => s,
436 Err(_) => continue,
437 };
438
439 let mut new_framed = Framed::new(stream, FrameCodec);
440 if new_framed
441 .send(Frame::Attach { session: session.to_string() })
442 .await
443 .is_err()
444 {
445 continue;
446 }
447
448 match new_framed.next().await {
449 Some(Ok(Frame::Ok)) => {
450 write_stdout(b"[reconnected]\r\n")?;
451 framed = new_framed;
452 current_redraw = true;
453 break;
454 }
455 Some(Ok(Frame::Error { message })) => {
456 let msg = format!("[session gone: {message}]\r\n");
457 write_stdout(msg.as_bytes())?;
458 return Ok(1);
459 }
460 _ => continue,
461 }
462 }
463 }
464 }
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// A processor positioned mid-line, so an escape needs an explicit newline first.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Shorthand for an `EscapeAction::Data` holding `bytes`.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    #[test]
    fn normal_passthrough() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"hello"), vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        assert_eq!(mid_line().process(b"\n~."), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        assert_eq!(mid_line().process(b"\r~."), vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        assert_eq!(mid_line().process(b"a~."), vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        assert_eq!(EscapeProcessor::new().process(b"~."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        assert_eq!(mid_line().process(b"\n~\x1a"), vec![data(b"\n"), EscapeAction::Suspend]);
    }

    #[test]
    fn tilde_help() {
        assert_eq!(mid_line().process(b"\n~?"), vec![data(b"\n"), EscapeAction::Help]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~~"), vec![data(b"\n"), data(b"~")]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        assert_eq!(mid_line().process(b"\n~x"), vec![data(b"\n"), data(b"~x")]);
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let expected = vec![data(b"\n"), EscapeAction::Help, data(b"\n"), EscapeAction::Detach];
        assert_eq!(mid_line().process(b"\n~?\n~."), expected);
    }

    #[test]
    fn consecutive_newlines() {
        assert_eq!(
            mid_line().process(b"\n\n\n~."),
            vec![data(b"\n\n\n"), EscapeAction::Detach]
        );
    }

    #[test]
    fn detach_stops_processing() {
        assert_eq!(
            mid_line().process(b"\n~.remaining"),
            vec![data(b"\n"), EscapeAction::Detach]
        );
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\n"), vec![data(b"\n"), data(b"~\n")]);
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        assert!(EscapeProcessor::new().process(b"").is_empty());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}