1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, SetArg, Termios};
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
7use std::path::Path;
8use std::time::Duration;
9use tokio::io::unix::AsyncFd;
10use tokio::net::UnixStream;
11use tokio::signal::unix::{SignalKind, signal};
12use tokio::time::Instant;
13use tokio_util::codec::Framed;
14use tracing::{debug, info};
15
/// Help text printed locally for the `~?` escape sequence.
///
/// Lines end in `\r\n`, presumably because this is written while the terminal
/// is in raw mode (no output newline translation).
/// Note: a `\` at the end of a line inside this literal also skips all leading
/// whitespace of the following line, so the entries are printed unindented.
const ESCAPE_HELP: &[u8] = b"\r\nSupported escape sequences:\r\n\
 ~. - detach from session\r\n\
 ~^Z - suspend client\r\n\
 ~? - this message\r\n\
 ~~ - send the escape character by typing it twice\r\n\
(Note that escapes are only recognized immediately after newline.)\r\n";
24
/// Position of the escape recognizer relative to the most recent line break.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum EscapeState {
    /// Inside a line: `~` is ordinary data.
    Normal,
    /// Right after `\r`/`\n` (or at start-up): a `~` here may begin an escape.
    AfterNewline,
    /// A `~` following a line break has been seen and is being held back.
    AfterTilde,
}

/// One unit of output from [`EscapeProcessor::process`].
#[derive(Debug, Eq, PartialEq)]
enum EscapeAction {
    /// Bytes to forward verbatim.
    Data(Vec<u8>),
    /// `~.` was typed.
    Detach,
    /// `~` followed by Ctrl-Z (0x1a) was typed.
    Suspend,
    /// `~?` was typed.
    Help,
}

/// Stateful filter separating forwarded input from ssh-style `~` escapes.
///
/// The state survives between `process` calls, so an escape split across
/// several reads is still recognized.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Creates a processor that behaves as if a newline was just typed, so an
    /// escape is recognized as the very first input.
    fn new() -> Self {
        EscapeProcessor { state: EscapeState::AfterNewline }
    }

    /// Splits `input` into [`EscapeAction`]s and advances the recognizer state.
    ///
    /// Runs of pass-through bytes are coalesced into one `Data` action, and any
    /// pending data is emitted before a command action so ordering is kept. A
    /// held-back `~` yields nothing until the next byte arrives. Processing stops
    /// at `~.`; bytes after it are dropped and the state is left as-is.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        // Moves buffered pass-through bytes (if any) into `out` as one `Data`.
        fn flush(pending: &mut Vec<u8>, out: &mut Vec<EscapeAction>) {
            if !pending.is_empty() {
                out.push(EscapeAction::Data(std::mem::take(pending)));
            }
        }

        // State reached after forwarding `byte` as plain data.
        fn after_data(byte: u8) -> EscapeState {
            if matches!(byte, b'\n' | b'\r') {
                EscapeState::AfterNewline
            } else {
                EscapeState::Normal
            }
        }

        let mut out = Vec::new();
        let mut pending = Vec::new();

        for &byte in input {
            self.state = match (self.state, byte) {
                // A tilde right after a line break is held back: it may start an escape.
                (EscapeState::AfterNewline, b'~') => {
                    flush(&mut pending, &mut out);
                    EscapeState::AfterTilde
                }
                // Everything else outside an escape is ordinary data.
                (EscapeState::Normal | EscapeState::AfterNewline, _) => {
                    pending.push(byte);
                    after_data(byte)
                }
                (EscapeState::AfterTilde, b'.') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Detach);
                    return out;
                }
                (EscapeState::AfterTilde, 0x1a) => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Suspend);
                    EscapeState::Normal
                }
                (EscapeState::AfterTilde, b'?') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Help);
                    EscapeState::Normal
                }
                // `~~` forwards a single literal tilde.
                (EscapeState::AfterTilde, b'~') => {
                    pending.push(b'~');
                    EscapeState::Normal
                }
                // Not a command: forward the held tilde together with this byte.
                (EscapeState::AfterTilde, _) => {
                    pending.extend_from_slice(&[b'~', byte]);
                    after_data(byte)
                }
            };
        }

        flush(&mut pending, &mut out);
        out
    }
}
128
129fn suspend() -> anyhow::Result<()> {
130 nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
131 Ok(())
132}
133
/// Longest time a single frame send (including flushing) may take before
/// `timed_send` gives up on it.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
135
136struct NonBlockGuard {
137 fd: BorrowedFd<'static>,
138 original_flags: nix::fcntl::OFlag,
139}
140
141impl NonBlockGuard {
142 fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
143 let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
144 let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
145 nix::fcntl::fcntl(
146 fd,
147 nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
148 )?;
149 Ok(Self { fd, original_flags })
150 }
151}
152
153impl Drop for NonBlockGuard {
154 fn drop(&mut self) {
155 let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
156 }
157}
158
159struct RawModeGuard {
160 fd: BorrowedFd<'static>,
161 original: Termios,
162}
163
164impl RawModeGuard {
165 fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
166 let original = termios::tcgetattr(fd)?;
167 let mut raw = original.clone();
168 termios::cfmakeraw(&mut raw);
169 termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
170 Ok(Self { fd, original })
171 }
172}
173
174impl Drop for RawModeGuard {
175 fn drop(&mut self) {
176 let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
177 }
178}
179
/// Writes all of `data` to stdout and flushes it.
///
/// Stdout may be non-blocking: it typically shares the terminal's open file
/// description, on which `O_NONBLOCK` is set for stdin. `WouldBlock` is
/// therefore retried after a short sleep rather than treated as an error, and
/// `Interrupted` is retried immediately.
///
/// # Errors
/// Returns any other I/O error, or `WriteZero` if stdout accepts no bytes.
fn write_stdout(data: &[u8]) -> io::Result<()> {
    // Pause between retries while the terminal cannot accept more output.
    // Sleeping instead of spinning keeps a backed-up terminal from burning a core.
    const RETRY_DELAY: Duration = Duration::from_millis(1);

    let mut out = io::stdout().lock();
    let mut rest = data;
    while !rest.is_empty() {
        match out.write(rest) {
            // A zero-length write would otherwise make this loop spin forever.
            Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
            Ok(n) => rest = &rest[n..],
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::sleep(RETRY_DELAY),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    loop {
        match out.flush() {
            Ok(()) => return Ok(()),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::sleep(RETRY_DELAY),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
}
207
208fn get_terminal_size() -> (u16, u16) {
209 let mut ws: libc::winsize = unsafe { std::mem::zeroed() };
210 unsafe { libc::ioctl(libc::STDIN_FILENO, libc::TIOCGWINSZ, &mut ws) };
211 (ws.ws_col, ws.ws_row)
212}
213
214async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
216 match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
217 Ok(Ok(())) => true,
218 Ok(Err(e)) => {
219 debug!("send error: {e}");
220 false
221 }
222 Err(_) => {
223 debug!("send timed out");
224 false
225 }
226 }
227}
228
/// Period between `Ping` frames sent to the server while relaying.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Maximum age of the last `Pong` before the connection is declared dead.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
231
232async fn relay(
235 framed: &mut Framed<UnixStream, FrameCodec>,
236 async_stdin: &AsyncFd<io::Stdin>,
237 sigwinch: &mut tokio::signal::unix::Signal,
238 buf: &mut [u8],
239 redraw: bool,
240 env_vars: &[(String, String)],
241 mut escape: Option<&mut EscapeProcessor>,
242) -> anyhow::Result<Option<i32>> {
243 if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
245 return Ok(None);
246 }
247 let (cols, rows) = get_terminal_size();
249 if !timed_send(framed, Frame::Resize { cols, rows }).await {
250 return Ok(None);
251 }
252 if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
254 return Ok(None);
255 }
256
257 let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
258 heartbeat_interval.reset(); let mut last_pong = Instant::now();
260
261 loop {
262 tokio::select! {
263 ready = async_stdin.readable() => {
264 let mut guard = ready?;
265 match guard.try_io(|inner| inner.get_ref().read(buf)) {
266 Ok(Ok(0)) => {
267 debug!("stdin EOF");
268 return Ok(Some(0));
269 }
270 Ok(Ok(n)) => {
271 debug!(len = n, "stdin → socket");
272 if let Some(ref mut esc) = escape {
273 for action in esc.process(&buf[..n]) {
274 match action {
275 EscapeAction::Data(data) => {
276 if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
277 return Ok(None);
278 }
279 }
280 EscapeAction::Detach => {
281 write_stdout(b"\r\n[detached]\r\n")?;
282 return Ok(Some(0));
283 }
284 EscapeAction::Suspend => {
285 suspend()?;
286 let (cols, rows) = get_terminal_size();
288 if !timed_send(framed, Frame::Resize { cols, rows }).await {
289 return Ok(None);
290 }
291 }
292 EscapeAction::Help => {
293 write_stdout(ESCAPE_HELP)?;
294 }
295 }
296 }
297 } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
298 return Ok(None);
299 }
300 }
301 Ok(Err(e)) => return Err(e.into()),
302 Err(_would_block) => continue,
303 }
304 }
305
306 frame = framed.next() => {
307 match frame {
308 Some(Ok(Frame::Data(data))) => {
309 debug!(len = data.len(), "socket → stdout");
310 write_stdout(&data)?;
311 }
312 Some(Ok(Frame::Pong)) => {
313 debug!("pong received");
314 last_pong = Instant::now();
315 }
316 Some(Ok(Frame::Exit { code })) => {
317 info!(code, "server sent exit");
318 return Ok(Some(code));
319 }
320 Some(Ok(Frame::Detached)) => {
321 info!("detached by another client");
322 write_stdout(b"[detached]\r\n")?;
323 return Ok(Some(0));
324 }
325 Some(Ok(_)) => {} Some(Err(e)) => {
327 debug!("server connection error: {e}");
328 return Ok(None);
329 }
330 None => {
331 debug!("server disconnected");
332 return Ok(None);
333 }
334 }
335 }
336
337 _ = sigwinch.recv() => {
338 let (cols, rows) = get_terminal_size();
339 debug!(cols, rows, "SIGWINCH → resize");
340 if !timed_send(framed, Frame::Resize { cols, rows }).await {
341 return Ok(None);
342 }
343 }
344
345 _ = heartbeat_interval.tick() => {
346 if last_pong.elapsed() > HEARTBEAT_TIMEOUT {
347 debug!("heartbeat timeout");
348 return Ok(None);
349 }
350 if !timed_send(framed, Frame::Ping).await {
351 return Ok(None);
352 }
353 }
354 }
355 }
356}
357
358pub async fn run(
359 session: &str,
360 mut framed: Framed<UnixStream, FrameCodec>,
361 redraw: bool,
362 ctl_path: &Path,
363 env_vars: Vec<(String, String)>,
364 no_escape: bool,
365) -> anyhow::Result<i32> {
366 let stdin = io::stdin();
367 let stdin_fd = stdin.as_fd();
368 let stdin_borrowed: BorrowedFd<'static> =
370 unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
371 let _guard = RawModeGuard::enter(stdin_borrowed)?;
372
373 let _nb_guard = NonBlockGuard::set(stdin_borrowed)?;
376 let async_stdin = AsyncFd::new(io::stdin())?;
377 let mut sigwinch = signal(SignalKind::window_change())?;
378 let mut buf = vec![0u8; 4096];
379 let mut current_redraw = redraw;
380 let mut current_env = env_vars;
381 let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
382
383 loop {
384 match relay(
385 &mut framed,
386 &async_stdin,
387 &mut sigwinch,
388 &mut buf,
389 current_redraw,
390 ¤t_env,
391 escape.as_mut(),
392 )
393 .await?
394 {
395 Some(code) => return Ok(code),
396 None => {
397 current_env.clear();
399 write_stdout(b"[reconnecting...]\r\n")?;
401
402 loop {
403 tokio::time::sleep(Duration::from_secs(1)).await;
404
405 {
407 let mut peek = [0u8; 1];
408 match io::stdin().read(&mut peek) {
409 Ok(1) if peek[0] == 0x03 => {
410 write_stdout(b"\r\n")?;
411 return Ok(1);
412 }
413 _ => {}
414 }
415 }
416
417 let stream = match UnixStream::connect(ctl_path).await {
418 Ok(s) => s,
419 Err(_) => continue,
420 };
421
422 let mut new_framed = Framed::new(stream, FrameCodec);
423 if new_framed
424 .send(Frame::Attach { session: session.to_string() })
425 .await
426 .is_err()
427 {
428 continue;
429 }
430
431 match new_framed.next().await {
432 Some(Ok(Frame::Ok)) => {
433 write_stdout(b"[reconnected]\r\n")?;
434 framed = new_framed;
435 current_redraw = true;
436 break;
437 }
438 Some(Ok(Frame::Error { message })) => {
439 let msg = format!("[session gone: {message}]\r\n");
440 write_stdout(msg.as_bytes())?;
441 return Ok(1);
442 }
443 _ => continue,
444 }
445 }
446 }
447 }
448 }
449}
450
/// Unit tests for the `~` escape-sequence state machine.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normal_passthrough() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"hello");
        assert_eq!(actions, vec![EscapeAction::Data(b"hello".to_vec())]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Detach,]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\r~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"\r".to_vec()), EscapeAction::Detach,]);
    }

    #[test]
    fn tilde_not_after_newline() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"a~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"a~.".to_vec())]);
    }

    #[test]
    fn initial_state_detach() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"~.");
        assert_eq!(actions, vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~\x1a");
        assert_eq!(actions, vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Suspend,]);
    }

    #[test]
    fn tilde_help() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~?");
        assert_eq!(actions, vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Help,]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~~");
        assert_eq!(
            actions,
            vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Data(b"~".to_vec()),]
        );
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~x");
        assert_eq!(
            actions,
            vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Data(b"~x".to_vec()),]
        );
    }

    #[test]
    fn split_across_reads() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let a1 = ep.process(b"\n");
        assert_eq!(a1, vec![EscapeAction::Data(b"\n".to_vec())]);
        let a2 = ep.process(b"~");
        assert_eq!(a2, vec![]);
        let a3 = ep.process(b".");
        assert_eq!(a3, vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let a1 = ep.process(b"\n");
        assert_eq!(a1, vec![EscapeAction::Data(b"\n".to_vec())]);
        let a2 = ep.process(b"~");
        assert_eq!(a2, vec![]);
        let a3 = ep.process(b"a");
        assert_eq!(a3, vec![EscapeAction::Data(b"~a".to_vec())]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~?\n~.");
        assert_eq!(
            actions,
            vec![
                EscapeAction::Data(b"\n".to_vec()),
                EscapeAction::Help,
                EscapeAction::Data(b"\n".to_vec()),
                EscapeAction::Detach,
            ]
        );
    }

    #[test]
    fn consecutive_newlines() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n\n\n~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"\n\n\n".to_vec()), EscapeAction::Detach,]);
    }

    #[test]
    fn detach_stops_processing() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~.remaining");
        assert_eq!(actions, vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Detach,]);
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~\n");
        assert_eq!(
            actions,
            vec![EscapeAction::Data(b"\n".to_vec()), EscapeAction::Data(b"~\n".to_vec()),]
        );
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"");
        assert_eq!(actions, vec![]);
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let a1 = ep.process(b"\n~");
        assert_eq!(a1, vec![EscapeAction::Data(b"\n".to_vec())]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        let a2 = ep.process(b".");
        assert_eq!(a2, vec![EscapeAction::Detach]);
    }

    // After `~^Z` the processor is back in `Normal`, so a bare `~.` is plain data.
    #[test]
    fn suspend_then_tilde_is_data() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"\n~\x1a~.");
        assert_eq!(
            actions,
            vec![
                EscapeAction::Data(b"\n".to_vec()),
                EscapeAction::Suspend,
                EscapeAction::Data(b"~.".to_vec()),
            ]
        );
    }

    // `~?` leaves the processor in `Normal` and following bytes pass through.
    #[test]
    fn help_then_state_is_normal() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"~?x");
        assert_eq!(actions, vec![EscapeAction::Help, EscapeAction::Data(b"x".to_vec())]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    // A tilde followed by CR is forwarded and re-arms escape recognition.
    #[test]
    fn tilde_then_cr_reenters_after_newline() {
        let mut ep = EscapeProcessor::new();
        let actions = ep.process(b"~\r~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"~\r".to_vec()), EscapeAction::Detach]);
    }

    // CRLF line endings keep the processor in `AfterNewline` across both bytes.
    #[test]
    fn crlf_then_detach() {
        let mut ep = EscapeProcessor { state: EscapeState::Normal };
        let actions = ep.process(b"ls\r\n~.");
        assert_eq!(actions, vec![EscapeAction::Data(b"ls\r\n".to_vec()), EscapeAction::Detach]);
    }
}