Skip to main content

gritty/
client.rs

1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, SetArg, Termios};
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
7use std::path::Path;
8use std::time::Duration;
9use tokio::io::unix::AsyncFd;
10use tokio::net::UnixStream;
11use tokio::signal::unix::{SignalKind, signal};
12use tokio::time::Instant;
13use tokio_util::codec::Framed;
14use tracing::{debug, info};
15
16// --- Escape sequence processing (SSH-style ~. detach, ~^Z suspend, ~? help) ---
17
/// Help text written to the local terminal when the user types the `~?` escape.
///
/// Lines are terminated with `\r\n`. The trailing `\` on each source line continues
/// the literal without embedding the source indentation that follows it.
const ESCAPE_HELP: &[u8] = b"\r\nSupported escape sequences:\r\n\
    ~.  - detach from session\r\n\
    ~^Z - suspend client\r\n\
    ~?  - this message\r\n\
    ~~  - send the escape character by typing it twice\r\n\
(Note that escapes are only recognized immediately after newline.)\r\n";
24
/// Where the escape recognizer stands relative to the most recent line break.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscapeState {
    /// Mid-line: a `~` seen here is ordinary data.
    Normal,
    /// The previous byte was `\n` or `\r` (also the state a fresh processor starts in),
    /// so a `~` seen here begins a potential escape.
    AfterNewline,
    /// A `~` at line start has been withheld until the next byte decides its meaning.
    AfterTilde,
}

/// One unit of output from [`EscapeProcessor::process`], emitted in input order.
#[derive(Debug, PartialEq, Eq)]
enum EscapeAction {
    /// Bytes to pass through unchanged.
    Data(Vec<u8>),
    /// `~.` was typed.
    Detach,
    /// `~^Z` was typed.
    Suspend,
    /// `~?` was typed.
    Help,
}

/// Byte-at-a-time recognizer for `~`-prefixed escapes typed at the start of a line.
///
/// The state is kept between calls, so an escape split across two reads is still
/// recognized.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Creates a processor that treats the very first byte as being at line start.
    fn new() -> Self {
        Self { state: EscapeState::AfterNewline }
    }

    /// Feeds `input` through the recognizer and returns the resulting actions.
    ///
    /// Plain bytes are grouped into `Data` actions. A `~` at line start is withheld
    /// (any data collected so far is emitted first) until the following byte arrives:
    /// `.`, Ctrl-Z and `?` become `Detach`, `Suspend` and `Help`; a second `~` yields
    /// one literal `~`; anything else yields the `~` followed by that byte.
    /// `Detach` ends processing immediately and the rest of `input` is dropped.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        /// Moves any collected bytes into `out` as a single `Data` action.
        fn flush(pending: &mut Vec<u8>, out: &mut Vec<EscapeAction>) {
            if !pending.is_empty() {
                out.push(EscapeAction::Data(std::mem::take(pending)));
            }
        }

        let mut out = Vec::new();
        let mut pending = Vec::new();

        for &byte in input {
            match (self.state, byte) {
                // --- mid-line ---
                (EscapeState::Normal, b'\n' | b'\r') => {
                    pending.push(byte);
                    self.state = EscapeState::AfterNewline;
                }
                (EscapeState::Normal, _) => pending.push(byte),

                // --- at line start ---
                (EscapeState::AfterNewline, b'~') => {
                    // Withhold the tilde; everything before it can go out now.
                    flush(&mut pending, &mut out);
                    self.state = EscapeState::AfterTilde;
                }
                (EscapeState::AfterNewline, b'\n' | b'\r') => pending.push(byte),
                (EscapeState::AfterNewline, _) => {
                    pending.push(byte);
                    self.state = EscapeState::Normal;
                }

                // --- a tilde is being withheld ---
                (EscapeState::AfterTilde, b'.') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Detach);
                    return out; // nothing after a detach is processed
                }
                (EscapeState::AfterTilde, 0x1a) => {
                    // Ctrl-Z
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Suspend);
                    self.state = EscapeState::Normal;
                }
                (EscapeState::AfterTilde, b'?') => {
                    flush(&mut pending, &mut out);
                    out.push(EscapeAction::Help);
                    self.state = EscapeState::Normal;
                }
                (EscapeState::AfterTilde, b'~') => {
                    // `~~` collapses to one literal tilde.
                    pending.push(b'~');
                    self.state = EscapeState::Normal;
                }
                (EscapeState::AfterTilde, other) => {
                    // Not an escape after all: release the withheld tilde plus this byte.
                    pending.extend_from_slice(&[b'~', other]);
                    self.state = if matches!(other, b'\n' | b'\r') {
                        EscapeState::AfterNewline
                    } else {
                        EscapeState::Normal
                    };
                }
            }
        }

        flush(&mut pending, &mut out);
        out
    }
}
128
129fn suspend() -> anyhow::Result<()> {
130    nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
131    Ok(())
132}
133
/// Upper bound on how long `timed_send` waits for a single frame send to complete.
const SEND_TIMEOUT: Duration = Duration::from_secs(5);
135
136struct NonBlockGuard {
137    fd: BorrowedFd<'static>,
138    original_flags: nix::fcntl::OFlag,
139}
140
141impl NonBlockGuard {
142    fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
143        let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
144        let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
145        nix::fcntl::fcntl(
146            fd,
147            nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
148        )?;
149        Ok(Self { fd, original_flags })
150    }
151}
152
153impl Drop for NonBlockGuard {
154    fn drop(&mut self) {
155        let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
156    }
157}
158
159struct RawModeGuard {
160    fd: BorrowedFd<'static>,
161    original: Termios,
162}
163
164impl RawModeGuard {
165    fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
166        let original = termios::tcgetattr(fd)?;
167        let mut raw = original.clone();
168        termios::cfmakeraw(&mut raw);
169        termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
170        Ok(Self { fd, original })
171    }
172}
173
174impl Drop for RawModeGuard {
175    fn drop(&mut self) {
176        let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
177    }
178}
179
/// Write all bytes to stdout and flush, retrying on `WouldBlock` and `Interrupted`.
///
/// Needed because setting O_NONBLOCK on stdin also affects stdout
/// when they share the same terminal file description.
///
/// Stdout is locked once for the whole call, so the write and the flush are not
/// interleaved with output from other threads.
///
/// # Errors
/// Returns the first error that is neither `WouldBlock` nor `Interrupted`, or a
/// `WriteZero` error if stdout accepts zero bytes while data remains (retrying in
/// that situation could never make progress).
fn write_stdout(data: &[u8]) -> io::Result<()> {
    let mut out = io::stdout().lock();

    let mut remaining = data;
    while !remaining.is_empty() {
        match out.write(remaining) {
            // Same contract as `Write::write_all`: a zero-length write is an error,
            // not something to retry forever.
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer to stdout",
                ));
            }
            Ok(n) => remaining = &remaining[n..],
            // Non-blocking terminal is momentarily full: let other threads run, then retry.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::yield_now(),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }

    loop {
        match out.flush() {
            Ok(()) => return Ok(()),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => std::thread::yield_now(),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
}
207
208fn get_terminal_size() -> (u16, u16) {
209    let mut ws: libc::winsize = unsafe { std::mem::zeroed() };
210    unsafe { libc::ioctl(libc::STDIN_FILENO, libc::TIOCGWINSZ, &mut ws) };
211    (ws.ws_col, ws.ws_row)
212}
213
214/// Send a frame with a timeout. Returns false if the send failed or timed out.
215async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
216    match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
217        Ok(Ok(())) => true,
218        Ok(Err(e)) => {
219            debug!("send error: {e}");
220            false
221        }
222        Err(_) => {
223            debug!("send timed out");
224            false
225        }
226    }
227}
228
/// How often the relay loop sends a `Ping` frame to the server.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// The relay loop gives up on the connection when a heartbeat tick finds that no
/// `Pong` has been received for longer than this.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15);
231
232/// Relay between stdin/stdout and the framed socket.
233/// Returns `Some(code)` on clean shell exit or detach, `None` on server disconnect / heartbeat timeout.
234async fn relay(
235    framed: &mut Framed<UnixStream, FrameCodec>,
236    async_stdin: &AsyncFd<io::Stdin>,
237    sigwinch: &mut tokio::signal::unix::Signal,
238    buf: &mut [u8],
239    redraw: bool,
240    env_vars: &[(String, String)],
241    mut escape: Option<&mut EscapeProcessor>,
242) -> anyhow::Result<Option<i32>> {
243    // Send env vars before resize (server reads Env frame before spawning shell)
244    if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
245        return Ok(None);
246    }
247    // Send initial window size
248    let (cols, rows) = get_terminal_size();
249    if !timed_send(framed, Frame::Resize { cols, rows }).await {
250        return Ok(None);
251    }
252    // Inject Ctrl-L to force the shell/app to redraw
253    if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
254        return Ok(None);
255    }
256
257    let mut heartbeat_interval = tokio::time::interval(HEARTBEAT_INTERVAL);
258    heartbeat_interval.reset(); // first tick is immediate otherwise; delay it
259    let mut last_pong = Instant::now();
260
261    loop {
262        tokio::select! {
263            ready = async_stdin.readable() => {
264                let mut guard = ready?;
265                match guard.try_io(|inner| inner.get_ref().read(buf)) {
266                    Ok(Ok(0)) => {
267                        debug!("stdin EOF");
268                        return Ok(Some(0));
269                    }
270                    Ok(Ok(n)) => {
271                        debug!(len = n, "stdin → socket");
272                        if let Some(ref mut esc) = escape {
273                            for action in esc.process(&buf[..n]) {
274                                match action {
275                                    EscapeAction::Data(data) => {
276                                        if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
277                                            return Ok(None);
278                                        }
279                                    }
280                                    EscapeAction::Detach => {
281                                        write_stdout(b"\r\n[detached]\r\n")?;
282                                        return Ok(Some(0));
283                                    }
284                                    EscapeAction::Suspend => {
285                                        suspend()?;
286                                        // Re-sync terminal size after resume
287                                        let (cols, rows) = get_terminal_size();
288                                        if !timed_send(framed, Frame::Resize { cols, rows }).await {
289                                            return Ok(None);
290                                        }
291                                    }
292                                    EscapeAction::Help => {
293                                        write_stdout(ESCAPE_HELP)?;
294                                    }
295                                }
296                            }
297                        } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
298                            return Ok(None);
299                        }
300                    }
301                    Ok(Err(e)) => return Err(e.into()),
302                    Err(_would_block) => continue,
303                }
304            }
305
306            frame = framed.next() => {
307                match frame {
308                    Some(Ok(Frame::Data(data))) => {
309                        debug!(len = data.len(), "socket → stdout");
310                        write_stdout(&data)?;
311                    }
312                    Some(Ok(Frame::Pong)) => {
313                        debug!("pong received");
314                        last_pong = Instant::now();
315                    }
316                    Some(Ok(Frame::Exit { code })) => {
317                        info!(code, "server sent exit");
318                        return Ok(Some(code));
319                    }
320                    Some(Ok(Frame::Detached)) => {
321                        info!("detached by another client");
322                        write_stdout(b"[detached]\r\n")?;
323                        return Ok(Some(0));
324                    }
325                    Some(Ok(_)) => {} // ignore control/resize frames
326                    Some(Err(e)) => {
327                        debug!("server connection error: {e}");
328                        return Ok(None);
329                    }
330                    None => {
331                        debug!("server disconnected");
332                        return Ok(None);
333                    }
334                }
335            }
336
337            _ = sigwinch.recv() => {
338                let (cols, rows) = get_terminal_size();
339                debug!(cols, rows, "SIGWINCH → resize");
340                if !timed_send(framed, Frame::Resize { cols, rows }).await {
341                    return Ok(None);
342                }
343            }
344
345            _ = heartbeat_interval.tick() => {
346                if last_pong.elapsed() > HEARTBEAT_TIMEOUT {
347                    debug!("heartbeat timeout");
348                    return Ok(None);
349                }
350                if !timed_send(framed, Frame::Ping).await {
351                    return Ok(None);
352                }
353            }
354        }
355    }
356}
357
358pub async fn run(
359    session: &str,
360    mut framed: Framed<UnixStream, FrameCodec>,
361    redraw: bool,
362    ctl_path: &Path,
363    env_vars: Vec<(String, String)>,
364    no_escape: bool,
365) -> anyhow::Result<i32> {
366    let stdin = io::stdin();
367    let stdin_fd = stdin.as_fd();
368    // Safety: stdin lives for the duration of the program
369    let stdin_borrowed: BorrowedFd<'static> =
370        unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
371    let _guard = RawModeGuard::enter(stdin_borrowed)?;
372
373    // Set stdin to non-blocking for AsyncFd — guard restores on drop.
374    // Declared BEFORE async_stdin so it drops AFTER AsyncFd (reverse drop order).
375    let _nb_guard = NonBlockGuard::set(stdin_borrowed)?;
376    let async_stdin = AsyncFd::new(io::stdin())?;
377    let mut sigwinch = signal(SignalKind::window_change())?;
378    let mut buf = vec![0u8; 4096];
379    let mut current_redraw = redraw;
380    let mut current_env = env_vars;
381    let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
382
383    loop {
384        match relay(
385            &mut framed,
386            &async_stdin,
387            &mut sigwinch,
388            &mut buf,
389            current_redraw,
390            &current_env,
391            escape.as_mut(),
392        )
393        .await?
394        {
395            Some(code) => return Ok(code),
396            None => {
397                // Env vars only sent on first connection; clear for reconnect
398                current_env.clear();
399                // Disconnected — try to reconnect
400                write_stdout(b"[reconnecting...]\r\n")?;
401
402                loop {
403                    tokio::time::sleep(Duration::from_secs(1)).await;
404
405                    // Check for Ctrl-C (0x03) in raw mode
406                    {
407                        let mut peek = [0u8; 1];
408                        match io::stdin().read(&mut peek) {
409                            Ok(1) if peek[0] == 0x03 => {
410                                write_stdout(b"\r\n")?;
411                                return Ok(1);
412                            }
413                            _ => {}
414                        }
415                    }
416
417                    let stream = match UnixStream::connect(ctl_path).await {
418                        Ok(s) => s,
419                        Err(_) => continue,
420                    };
421
422                    let mut new_framed = Framed::new(stream, FrameCodec);
423                    if new_framed
424                        .send(Frame::Attach { session: session.to_string() })
425                        .await
426                        .is_err()
427                    {
428                        continue;
429                    }
430
431                    match new_framed.next().await {
432                        Some(Ok(Frame::Ok)) => {
433                            write_stdout(b"[reconnected]\r\n")?;
434                            framed = new_framed;
435                            current_redraw = true;
436                            break;
437                        }
438                        Some(Ok(Frame::Error { message })) => {
439                            let msg = format!("[session gone: {message}]\r\n");
440                            write_stdout(msg.as_bytes())?;
441                            return Ok(1);
442                        }
443                        _ => continue,
444                    }
445                }
446            }
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// A processor positioned mid-line, so a leading `~` is not an escape.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Shorthand for the `Data` action carrying `bytes`.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    /// Shorthand for "no actions produced".
    fn nothing() -> Vec<EscapeAction> {
        Vec::new()
    }

    #[test]
    fn normal_passthrough() {
        // No line breaks: the first byte moves the fresh processor to Normal.
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"hello"), vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~."), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_after_cr_detach() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\r~."), vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"a~."), vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b"~."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\x1a"), vec![data(b"\n"), EscapeAction::Suspend]);
    }

    #[test]
    fn tilde_help() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~?"), vec![data(b"\n"), EscapeAction::Help]);
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~~"), vec![data(b"\n"), data(b"~")]);
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~x"), vec![data(b"\n"), data(b"~x")]);
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert_eq!(ep.process(b"~"), nothing()); // tilde withheld
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert_eq!(ep.process(b"~"), nothing());
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let mut ep = mid_line();
        let expected =
            vec![data(b"\n"), EscapeAction::Help, data(b"\n"), EscapeAction::Detach];
        assert_eq!(ep.process(b"\n~?\n~."), expected);
    }

    #[test]
    fn consecutive_newlines() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n\n\n~."), vec![data(b"\n\n\n"), EscapeAction::Detach]);
    }

    #[test]
    fn detach_stops_processing() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~.remaining"), vec![data(b"\n"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\n"), vec![data(b"\n"), data(b"~\n")]);
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        let mut ep = EscapeProcessor::new();
        assert_eq!(ep.process(b""), nothing());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}
604}