Skip to main content

gritty/
client.rs

1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, SetArg, Termios};
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
7use std::path::Path;
8use std::time::Duration;
9use tokio::io::unix::AsyncFd;
10use tokio::net::UnixStream;
11use tokio::signal::unix::{SignalKind, signal};
12use tokio::time::Instant;
13use tokio_util::codec::Framed;
14use tracing::{debug, info};
15
16// --- Escape sequence processing (SSH-style ~. detach, ~^Z suspend, ~? help) ---
17
/// Help text printed for the `~?` escape.
///
/// Every line ends in `\r\n` because it is written while the terminal is in raw
/// mode. The literal deliberately spans several source lines *without* `\` line
/// continuations: a trailing `\` makes Rust skip all leading whitespace on the
/// next line, which would strip the indentation of the escape list.
const ESCAPE_HELP: &[u8] = b"\r
Supported escape sequences:\r
    ~.  - detach from session\r
    ~^Z - suspend client\r
    ~?  - this message\r
    ~~  - send the escape character by typing it twice\r
(Note that escapes are only recognized immediately after newline.)\r
";
24
/// Where the escape recogniser is relative to the start of a line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscapeState {
    /// Mid-line: a `~` is ordinary data.
    Normal,
    /// At the start of a line (after `\n`/`\r`, or at startup): a `~` may begin an escape.
    AfterNewline,
    /// A line-initial `~` was seen and is being held back until the next byte.
    AfterTilde,
}

/// What the relay should do in response to a chunk of user input.
#[derive(Debug, PartialEq, Eq)]
enum EscapeAction {
    /// Bytes to forward to the session unchanged.
    Data(Vec<u8>),
    /// `~.` — detach from the session.
    Detach,
    /// `~^Z` — suspend the client.
    Suspend,
    /// `~?` — show the escape help.
    Help,
}

/// SSH-style escape recogniser for user keystrokes.
///
/// The state is kept between `process` calls, so an escape split across two
/// reads (`~` in one, `.` in the next) is still recognised.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Create a processor in the `AfterNewline` state, so an escape typed as the
    /// very first input is recognised.
    fn new() -> Self {
        Self { state: EscapeState::AfterNewline }
    }

    /// Split `input` into forwardable data and escape commands.
    ///
    /// Consecutive forwardable bytes are grouped into a single `Data` action; the
    /// current group is cut whenever a line-initial `~` is seen and whenever a
    /// command is emitted. A line-initial `~` is withheld until the next byte
    /// arrives (possibly in a later call): `~~` yields one `~`, and any
    /// unrecognised byte yields `~` followed by that byte. After `Detach`, the
    /// rest of `input` is ignored and the state is left at `AfterTilde`.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        /// State to move to after forwarding `byte` as ordinary data.
        fn state_after(byte: u8) -> EscapeState {
            if matches!(byte, b'\n' | b'\r') {
                EscapeState::AfterNewline
            } else {
                EscapeState::Normal
            }
        }

        /// Emit the bytes collected so far (if any) as one `Data` action.
        fn flush(pending: &mut Vec<u8>, out: &mut Vec<EscapeAction>) {
            if !pending.is_empty() {
                out.push(EscapeAction::Data(std::mem::take(pending)));
            }
        }

        let mut out = Vec::new();
        let mut pending = Vec::new();

        for &byte in input {
            self.state = match (self.state, byte) {
                (EscapeState::AfterNewline, b'~') => {
                    // Hold the tilde back: it may start an escape.
                    flush(&mut pending, &mut out);
                    EscapeState::AfterTilde
                }
                (EscapeState::Normal | EscapeState::AfterNewline, _) => {
                    pending.push(byte);
                    state_after(byte)
                }
                (EscapeState::AfterTilde, b'.') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Detach);
                    // Stop here; remaining input is dropped.
                    return out;
                }
                (EscapeState::AfterTilde, 0x1a) => {
                    // Ctrl-Z
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Suspend);
                    EscapeState::Normal
                }
                (EscapeState::AfterTilde, b'?') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Help);
                    EscapeState::Normal
                }
                (EscapeState::AfterTilde, b'~') => {
                    // `~~` sends a single literal tilde.
                    pending.push(b'~');
                    EscapeState::Normal
                }
                (EscapeState::AfterTilde, _) => {
                    // Not an escape: release the held tilde together with this byte.
                    pending.extend_from_slice(&[b'~', byte]);
                    state_after(byte)
                }
            };
        }

        flush(&mut pending, &mut out);
        out
    }
}
128
129fn suspend(raw_guard: &RawModeGuard, nb_guard: &NonBlockGuard) -> anyhow::Result<()> {
130    // Restore cooked mode and blocking stdin so the parent shell works normally
131    termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw_guard.original)?;
132    let _ = nix::fcntl::fcntl(nb_guard.fd, nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags));
133
134    nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
135
136    // After resume (fg): re-enter raw mode and non-blocking stdin
137    let _ = nix::fcntl::fcntl(
138        nb_guard.fd,
139        nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags | nix::fcntl::OFlag::O_NONBLOCK),
140    );
141    let mut raw = raw_guard.original.clone();
142    termios::cfmakeraw(&mut raw);
143    termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw)?;
144    Ok(())
145}
146
/// Maximum time a single frame send may take before it is treated as failed.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
148
149struct NonBlockGuard {
150    fd: BorrowedFd<'static>,
151    original_flags: nix::fcntl::OFlag,
152}
153
154impl NonBlockGuard {
155    fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
156        let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
157        let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
158        nix::fcntl::fcntl(
159            fd,
160            nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
161        )?;
162        Ok(Self { fd, original_flags })
163    }
164}
165
166impl Drop for NonBlockGuard {
167    fn drop(&mut self) {
168        let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
169    }
170}
171
172struct RawModeGuard {
173    fd: BorrowedFd<'static>,
174    original: Termios,
175}
176
177impl RawModeGuard {
178    fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
179        let original = termios::tcgetattr(fd)?;
180        let mut raw = original.clone();
181        termios::cfmakeraw(&mut raw);
182        termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
183        Ok(Self { fd, original })
184    }
185}
186
187impl Drop for RawModeGuard {
188    fn drop(&mut self) {
189        let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
190    }
191}
192
/// Write all of `data` to stdout and flush it, retrying on `WouldBlock`.
///
/// Needed because setting O_NONBLOCK on stdin also affects stdout when they
/// share the same terminal file description, so writes can fail with
/// `WouldBlock` while the terminal's output queue is full. Rather than spinning
/// on `yield_now`, which keeps a CPU core busy for as long as the terminal stays
/// backed up, the thread sleeps for a short `BACKOFF` before retrying.
/// `Interrupted` is retried immediately.
///
/// # Errors
/// Returns the first error other than `WouldBlock`/`Interrupted`. Returns
/// `WriteZero` if stdout accepts zero bytes, rather than looping forever.
fn write_stdout(data: &[u8]) -> io::Result<()> {
    const BACKOFF: Duration = Duration::from_millis(1);

    // Lock once for the whole write so output from other threads cannot interleave.
    let mut stdout = io::stdout().lock();
    let mut remaining = data;
    while !remaining.is_empty() {
        match stdout.write(remaining) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "stdout accepted zero bytes",
                ));
            }
            Ok(n) => remaining = &remaining[n..],
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::sleep(BACKOFF),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    loop {
        match stdout.flush() {
            Ok(()) => return Ok(()),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::sleep(BACKOFF),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
}
220
221fn get_terminal_size() -> (u16, u16) {
222    let mut ws: libc::winsize = unsafe { std::mem::zeroed() };
223    unsafe { libc::ioctl(libc::STDIN_FILENO, libc::TIOCGWINSZ, &mut ws) };
224    (ws.ws_col, ws.ws_row)
225}
226
227/// Send a frame with a timeout. Returns false if the send failed or timed out.
228async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
229    match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
230        Ok(Ok(())) => true,
231        Ok(Err(e)) => {
232            debug!("send error: {e}");
233            false
234        }
235        Err(_) => {
236            debug!("send timed out");
237            false
238        }
239    }
240}
241
/// How often `relay` sends a `Ping` to the server.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// How long `relay` tolerates no `Pong` before giving up on the connection.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
244
245/// Relay between stdin/stdout and the framed socket.
246/// Returns `Some(code)` on clean shell exit or detach, `None` on server disconnect / heartbeat timeout.
247async fn relay(
248    framed: &mut Framed<UnixStream, FrameCodec>,
249    async_stdin: &AsyncFd<io::Stdin>,
250    sigwinch: &mut tokio::signal::unix::Signal,
251    buf: &mut [u8],
252    redraw: bool,
253    env_vars: &[(String, String)],
254    mut escape: Option<&mut EscapeProcessor>,
255    raw_guard: &RawModeGuard,
256    nb_guard: &NonBlockGuard,
257) -> anyhow::Result<Option<i32>> {
258    // Send env vars before resize (server reads Env frame before spawning shell)
259    if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
260        return Ok(None);
261    }
262    // Send initial window size
263    let (cols, rows) = get_terminal_size();
264    if !timed_send(framed, Frame::Resize { cols, rows }).await {
265        return Ok(None);
266    }
267    // Inject Ctrl-L to force the shell/app to redraw
268    if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
269        return Ok(None);
270    }
271
272    let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
273    heartbeat_interval.reset(); // first tick is immediate otherwise; delay it
274    let mut last_pong = Instant::now();
275
276    loop {
277        tokio::select! {
278            ready = async_stdin.readable() => {
279                let mut guard = ready?;
280                match guard.try_io(|inner| inner.get_ref().read(buf)) {
281                    Ok(Ok(0)) => {
282                        debug!("stdin EOF");
283                        return Ok(Some(0));
284                    }
285                    Ok(Ok(n)) => {
286                        debug!(len = n, "stdin → socket");
287                        if let Some(ref mut esc) = escape {
288                            for action in esc.process(&buf[..n]) {
289                                match action {
290                                    EscapeAction::Data(data) => {
291                                        if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
292                                            return Ok(None);
293                                        }
294                                    }
295                                    EscapeAction::Detach => {
296                                        write_stdout(b"\r\n[detached]\r\n")?;
297                                        return Ok(Some(0));
298                                    }
299                                    EscapeAction::Suspend => {
300                                        suspend(raw_guard, nb_guard)?;
301                                        // Re-sync terminal size after resume
302                                        let (cols, rows) = get_terminal_size();
303                                        if !timed_send(framed, Frame::Resize { cols, rows }).await {
304                                            return Ok(None);
305                                        }
306                                    }
307                                    EscapeAction::Help => {
308                                        write_stdout(ESCAPE_HELP)?;
309                                    }
310                                }
311                            }
312                        } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
313                            return Ok(None);
314                        }
315                    }
316                    Ok(Err(e)) => return Err(e.into()),
317                    Err(_would_block) => continue,
318                }
319            }
320
321            frame = framed.next() => {
322                match frame {
323                    Some(Ok(Frame::Data(data))) => {
324                        debug!(len = data.len(), "socket → stdout");
325                        write_stdout(&data)?;
326                    }
327                    Some(Ok(Frame::Pong)) => {
328                        debug!("pong received");
329                        last_pong = Instant::now();
330                    }
331                    Some(Ok(Frame::Exit { code })) => {
332                        info!(code, "server sent exit");
333                        return Ok(Some(code));
334                    }
335                    Some(Ok(Frame::Detached)) => {
336                        info!("detached by another client");
337                        write_stdout(b"[detached]\r\n")?;
338                        return Ok(Some(0));
339                    }
340                    Some(Ok(_)) => {} // ignore control/resize frames
341                    Some(Err(e)) => {
342                        debug!("server connection error: {e}");
343                        return Ok(None);
344                    }
345                    None => {
346                        debug!("server disconnected");
347                        return Ok(None);
348                    }
349                }
350            }
351
352            _ = sigwinch.recv() => {
353                let (cols, rows) = get_terminal_size();
354                debug!(cols, rows, "SIGWINCH → resize");
355                if !timed_send(framed, Frame::Resize { cols, rows }).await {
356                    return Ok(None);
357                }
358            }
359
360            _ = heartbeat_interval.tick() => {
361                if last_pong.elapsed() > HEARTBEAT_TIMEOUT {
362                    debug!("heartbeat timeout");
363                    return Ok(None);
364                }
365                if !timed_send(framed, Frame::Ping).await {
366                    return Ok(None);
367                }
368            }
369        }
370    }
371}
372
373pub async fn run(
374    session: &str,
375    mut framed: Framed<UnixStream, FrameCodec>,
376    redraw: bool,
377    ctl_path: &Path,
378    env_vars: Vec<(String, String)>,
379    no_escape: bool,
380) -> anyhow::Result<i32> {
381    let stdin = io::stdin();
382    let stdin_fd = stdin.as_fd();
383    // Safety: stdin lives for the duration of the program
384    let stdin_borrowed: BorrowedFd<'static> =
385        unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
386    let raw_guard = RawModeGuard::enter(stdin_borrowed)?;
387
388    // Set stdin to non-blocking for AsyncFd — guard restores on drop.
389    // Declared BEFORE async_stdin so it drops AFTER AsyncFd (reverse drop order).
390    let nb_guard = NonBlockGuard::set(stdin_borrowed)?;
391    let async_stdin = AsyncFd::new(io::stdin())?;
392    let mut sigwinch = signal(SignalKind::window_change())?;
393    let mut buf = vec![0u8; 4096];
394    let mut current_redraw = redraw;
395    let mut current_env = env_vars;
396    let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
397
398    loop {
399        match relay(
400            &mut framed,
401            &async_stdin,
402            &mut sigwinch,
403            &mut buf,
404            current_redraw,
405            &current_env,
406            escape.as_mut(),
407            &raw_guard,
408            &nb_guard,
409        )
410        .await?
411        {
412            Some(code) => return Ok(code),
413            None => {
414                // Env vars only sent on first connection; clear for reconnect
415                current_env.clear();
416                // Disconnected — try to reconnect
417                write_stdout(b"[reconnecting...]\r\n")?;
418
419                loop {
420                    tokio::time::sleep(Duration::from_secs(1)).await;
421
422                    // Check for Ctrl-C (0x03) in raw mode
423                    {
424                        let mut peek = [0u8; 1];
425                        match io::stdin().read(&mut peek) {
426                            Ok(1) if peek[0] == 0x03 => {
427                                write_stdout(b"\r\n")?;
428                                return Ok(1);
429                            }
430                            _ => {}
431                        }
432                    }
433
434                    let stream = match UnixStream::connect(ctl_path).await {
435                        Ok(s) => s,
436                        Err(_) => continue,
437                    };
438
439                    let mut new_framed = Framed::new(stream, FrameCodec);
440                    if new_framed
441                        .send(Frame::Attach { session: session.to_string() })
442                        .await
443                        .is_err()
444                    {
445                        continue;
446                    }
447
448                    match new_framed.next().await {
449                        Some(Ok(Frame::Ok)) => {
450                            write_stdout(b"[reconnected]\r\n")?;
451                            framed = new_framed;
452                            current_redraw = true;
453                            break;
454                        }
455                        Some(Ok(Frame::Error { message })) => {
456                            let msg = format!("[session gone: {message}]\r\n");
457                            write_stdout(msg.as_bytes())?;
458                            return Ok(1);
459                        }
460                        _ => continue,
461                    }
462                }
463            }
464        }
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// A processor positioned mid-line, so a leading `~` is not an escape.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Shorthand for a `Data` action carrying `bytes`.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    #[test]
    fn normal_passthrough() {
        // A fresh processor starts AfterNewline; 'h' moves it to Normal.
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"hello"), vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~."), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\r~."), vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"a~."), vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"~."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\x1a"), vec![data(b"\n"), EscapeAction::Suspend]);
    }

    #[test]
    fn tilde_help() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~?"), vec![data(b"\n"), EscapeAction::Help]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~~"), vec![data(b"\n"), data(b"~")]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~x"), vec![data(b"\n"), data(b"~x")]);
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        // The tilde is held back until the next byte arrives.
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let mut ep = mid_line();
        let expected =
            vec![data(b"\n"), EscapeAction::Help, data(b"\n"), EscapeAction::Detach];
        assert_eq!(ep.process(b"\n~?\n~."), expected);
    }

    #[test]
    fn consecutive_newlines() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n\n\n~."), vec![data(b"\n\n\n"), EscapeAction::Detach]);
    }

    #[test]
    fn detach_stops_processing() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~.remaining"), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\n"), vec![data(b"\n"), data(b"~\n")]);
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        let mut ep = EscapeProcessor::new();
        assert!(ep.process(b"").is_empty());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}