Skip to main content

gritty/
client.rs

1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::sys::termios::{self, FlushArg, LocalFlags, SetArg, SpecialCharacterIndices, Termios};
5use std::collections::HashMap;
6use std::io::{self, Read, Write};
7use std::ops::ControlFlow;
8use std::os::fd::{AsFd, AsRawFd, BorrowedFd};
9use std::path::Path;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::time::Duration;
13use tokio::io::unix::AsyncFd;
14use tokio::net::UnixStream;
15use tokio::signal::unix::{SignalKind, signal};
16use tokio::sync::mpsc;
17use tokio::time::Instant;
18
/// How one run of the client relay loop ended.
///
/// The variant tells the caller whether to stop with an exit code or to try
/// connecting to the server again.
enum RelayExit {
    /// Stop with this exit code: a shell/server-reported exit status, or the
    /// result of detaching or a signal.
    Exit(i32),
    /// The server connection was lost; the caller should reconnect.
    Disconnected,
}
26use tokio_util::codec::Framed;
27use tracing::{debug, info};
28
// --- Escape sequence processing (SSH-style ~. detach, ~R reconnect, ~^Z suspend, ~# status, ~? help) ---
30
/// Help text listing the supported escape sequences (the `~?` escape).
/// Every line ends in CRLF so it renders correctly while the terminal is in
/// raw mode.
const ESCAPE_HELP: &[u8] = concat!(
    "\r\nSupported escape sequences:\r\n",
    "~.  - detach from session\r\n",
    "~R  - force reconnect\r\n",
    "~^Z - suspend client\r\n",
    "~#  - session status and RTT\r\n",
    "~?  - this message\r\n",
    "~~  - send the escape character by typing it twice\r\n",
    "(Note that escapes are only recognized immediately after newline.)\r\n",
)
.as_bytes();
39
/// Where the escape recognizer currently is relative to the start of a line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscapeState {
    /// Mid-line: a `~` here is ordinary data.
    Normal,
    /// At the start of a line (after `\n`/`\r`, or at session start): a `~`
    /// here begins an escape sequence.
    AfterNewline,
    /// A line-start `~` has been seen and is withheld until the next byte
    /// decides whether it is an escape command or literal data.
    AfterTilde,
}

/// One item produced by [`EscapeProcessor::process`], in input order.
#[derive(Debug, PartialEq, Eq)]
enum EscapeAction {
    /// Bytes to forward unchanged.
    Data(Vec<u8>),
    /// `~.` — detach from the session.
    Detach,
    /// `~R` — force a reconnect.
    Reconnect,
    /// `~^Z` (tilde followed by byte 0x1a) — suspend the client.
    Suspend,
    /// `~#` — show session status and RTT.
    Status,
    /// `~?` — show the escape help text.
    Help,
}

/// SSH-style escape-sequence recognizer for keyboard input.
///
/// The state is kept between calls to [`process`](Self::process), so an escape
/// whose bytes arrive in separate reads (e.g. `~` then `.`) is still recognized.
struct EscapeProcessor {
    state: EscapeState,
}

impl EscapeProcessor {
    /// Creates a processor positioned at the start of a line, so an escape typed
    /// as the very first input is recognized.
    fn new() -> Self {
        Self { state: EscapeState::AfterNewline }
    }

    /// Splits `input` into data to forward and escape commands, preserving order.
    ///
    /// Escapes are only recognized when `~` is the first byte of a line. Such a
    /// `~` is withheld until the following byte (possibly in a later call) shows
    /// what it means:
    /// - a known command byte yields the matching action;
    /// - `~~` yields one literal `~`;
    /// - anything else releases the `~` together with that byte as data.
    ///
    /// Data preceding a command is emitted before it.
    ///
    /// `~.` and `~R` end processing immediately, and any bytes after them in
    /// `input` are dropped. After every command, including those two, the
    /// processor is back in mid-line state. A later call therefore never
    /// interprets its first byte as the tail of an already-handled escape.
    fn process(&mut self, input: &[u8]) -> Vec<EscapeAction> {
        /// Moves accumulated data (if any) into `actions` as a single `Data` item.
        fn flush(actions: &mut Vec<EscapeAction>, pending: &mut Vec<u8>) {
            if !pending.is_empty() {
                actions.push(EscapeAction::Data(std::mem::take(pending)));
            }
        }

        fn is_line_end(byte: u8) -> bool {
            byte == b'\n' || byte == b'\r'
        }

        let mut actions = Vec::new();
        let mut pending = Vec::new();

        for &byte in input {
            match self.state {
                EscapeState::Normal => {
                    if is_line_end(byte) {
                        self.state = EscapeState::AfterNewline;
                    }
                    pending.push(byte);
                }
                EscapeState::AfterNewline => {
                    if byte == b'~' {
                        // Withhold the tilde; what came before it can go out now.
                        flush(&mut actions, &mut pending);
                        self.state = EscapeState::AfterTilde;
                    } else {
                        // Blank lines keep us at line start; anything else is mid-line.
                        if !is_line_end(byte) {
                            self.state = EscapeState::Normal;
                        }
                        pending.push(byte);
                    }
                }
                EscapeState::AfterTilde => {
                    let command = match byte {
                        b'.' => EscapeAction::Detach,
                        b'R' => EscapeAction::Reconnect,
                        0x1a => EscapeAction::Suspend, // Ctrl-Z
                        b'#' => EscapeAction::Status,
                        b'?' => EscapeAction::Help,
                        b'~' => {
                            // `~~` sends a single literal tilde.
                            pending.push(b'~');
                            self.state = EscapeState::Normal;
                            continue;
                        }
                        other => {
                            // Not an escape: release the withheld tilde with this byte.
                            pending.push(b'~');
                            pending.push(other);
                            self.state = if is_line_end(other) {
                                EscapeState::AfterNewline
                            } else {
                                EscapeState::Normal
                            };
                            continue;
                        }
                    };

                    flush(&mut actions, &mut pending);
                    // Reset for every command, including Detach/Reconnect. Leaving
                    // AfterTilde behind would make the next byte fed in (e.g. after
                    // a reconnect) be misread as an escape command.
                    self.state = EscapeState::Normal;
                    let stop = matches!(command, EscapeAction::Detach | EscapeAction::Reconnect);
                    actions.push(command);
                    if stop {
                        return actions;
                    }
                }
            }
        }

        flush(&mut actions, &mut pending);
        actions
    }
}
159
160fn suspend(raw_guard: &RawModeGuard, nb_guard: &NonBlockGuard) -> anyhow::Result<()> {
161    // Restore cooked mode and blocking stdin so the parent shell works normally
162    termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw_guard.original)?;
163    let _ = nix::fcntl::fcntl(nb_guard.fd, nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags));
164
165    nix::sys::signal::kill(nix::unistd::Pid::from_raw(0), nix::sys::signal::Signal::SIGTSTP)?;
166
167    // After resume (fg): re-enter raw mode and non-blocking stdin
168    let _ = nix::fcntl::fcntl(
169        nb_guard.fd,
170        nix::fcntl::FcntlArg::F_SETFL(nb_guard.original_flags | nix::fcntl::OFlag::O_NONBLOCK),
171    );
172    let mut raw = raw_guard.original.clone();
173    termios::cfmakeraw(&mut raw);
174    termios::tcsetattr(raw_guard.fd, SetArg::TCSAFLUSH, &raw)?;
175    Ok(())
176}
177
/// Upper bound on how long a single frame send may take; `timed_send` reports
/// failure once it is exceeded.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
179
180struct NonBlockGuard {
181    fd: BorrowedFd<'static>,
182    original_flags: nix::fcntl::OFlag,
183}
184
185impl NonBlockGuard {
186    fn set(fd: BorrowedFd<'static>) -> nix::Result<Self> {
187        let flags = nix::fcntl::fcntl(fd, nix::fcntl::FcntlArg::F_GETFL)?;
188        let original_flags = nix::fcntl::OFlag::from_bits_truncate(flags);
189        nix::fcntl::fcntl(
190            fd,
191            nix::fcntl::FcntlArg::F_SETFL(original_flags | nix::fcntl::OFlag::O_NONBLOCK),
192        )?;
193        Ok(Self { fd, original_flags })
194    }
195}
196
197impl Drop for NonBlockGuard {
198    fn drop(&mut self) {
199        let _ = nix::fcntl::fcntl(self.fd, nix::fcntl::FcntlArg::F_SETFL(self.original_flags));
200    }
201}
202
203struct RawModeGuard {
204    fd: BorrowedFd<'static>,
205    original: Termios,
206}
207
208impl RawModeGuard {
209    fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
210        let original = termios::tcgetattr(fd)?;
211        let mut raw = original.clone();
212        termios::cfmakeraw(&mut raw);
213        termios::tcsetattr(fd, SetArg::TCSAFLUSH, &raw)?;
214        Ok(Self { fd, original })
215    }
216}
217
218impl Drop for RawModeGuard {
219    fn drop(&mut self) {
220        let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
221    }
222}
223
224/// Suppresses stdin echo for tail mode: disables ECHO and ICANON but keeps
225/// ISIG so Ctrl-C still generates SIGINT. Flushes pending input on drop.
226struct SuppressInputGuard {
227    fd: BorrowedFd<'static>,
228    original: Termios,
229}
230
231impl SuppressInputGuard {
232    fn enter(fd: BorrowedFd<'static>) -> nix::Result<Self> {
233        let original = termios::tcgetattr(fd)?;
234        let mut modified = original.clone();
235        modified.local_flags.remove(LocalFlags::ECHO | LocalFlags::ICANON);
236        modified.control_chars[SpecialCharacterIndices::VMIN as usize] = 1;
237        modified.control_chars[SpecialCharacterIndices::VTIME as usize] = 0;
238        termios::tcsetattr(fd, SetArg::TCSAFLUSH, &modified)?;
239        Ok(Self { fd, original })
240    }
241}
242
243impl Drop for SuppressInputGuard {
244    fn drop(&mut self) {
245        let _ = termios::tcflush(self.fd, FlushArg::TCIFLUSH);
246        let _ = termios::tcsetattr(self.fd, SetArg::TCSAFLUSH, &self.original);
247    }
248}
249
250/// Write all bytes to stdout asynchronously via AsyncFd.
251/// Used in relay mode where stdout is non-blocking (shares fd with stdin).
252async fn write_stdout_async(fd: &AsyncFd<std::os::fd::OwnedFd>, data: &[u8]) -> io::Result<()> {
253    let mut written = 0;
254    while written < data.len() {
255        let mut guard = fd.writable().await?;
256        match guard
257            .try_io(|inner| nix::unistd::write(inner, &data[written..]).map_err(io::Error::from))
258        {
259            Ok(Ok(0)) => {
260                return Err(io::Error::new(io::ErrorKind::WriteZero, "stdout closed"));
261            }
262            Ok(Ok(n)) => written += n,
263            Ok(Err(e)) => return Err(e),
264            Err(_would_block) => continue,
265        }
266    }
267    Ok(())
268}
269
270/// Format a byte count as a human-readable size string.
271pub fn format_size(bytes: u64) -> String {
272    humansize::format_size(bytes, humansize::BINARY)
273}
274
/// Renders `text` as `[text]` in the SGR color `sgr`, then resets the color.
/// A CRLF before and after puts the note on its own line.
fn colored_bracket_msg(sgr: &str, text: &str) -> String {
    format!("\r\n\x1b[{sgr}m[{text}]\x1b[0m\r\n")
}

/// Dim yellow bracketed status line.
fn status_msg(text: &str) -> String {
    colored_bracket_msg("2;33", text)
}

/// Green bracketed success line.
fn success_msg(text: &str) -> String {
    colored_bracket_msg("32", text)
}

/// Red bracketed error line.
fn error_msg(text: &str) -> String {
    colored_bracket_msg("31", text)
}
286
287fn get_terminal_size() -> (u16, u16) {
288    terminal_size::terminal_size().map(|(w, h)| (w.0, h.0)).unwrap_or((80, 24))
289}
290
291/// Send a frame with a timeout. Returns false if the send failed or timed out.
292async fn timed_send(framed: &mut Framed<UnixStream, FrameCodec>, frame: Frame) -> bool {
293    match tokio::time::timeout(SEND_TIMEOUT, framed.send(frame)).await {
294        Ok(Ok(())) => true,
295        Ok(Err(e)) => {
296            debug!("send error: {e}");
297            false
298        }
299        Err(_) => {
300            debug!("send timed out");
301            false
302        }
303    }
304}
305
/// Default spacing between heartbeats (the loop that uses it is outside this view).
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Default heartbeat timeout; presumably how long to wait for a pong before
/// treating the connection as dead. TODO confirm against the relay loop.
const DEFAULT_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(15_000);
308
309/// Events from local agent connection tasks to the relay loop.
310enum AgentEvent {
311    Data { channel_id: u32, data: Bytes },
312    Closed { channel_id: u32 },
313}
314
315/// Events from the tunnel TCP listener/connection to the relay loop.
316enum ClientTunnelEvent {
317    Accepted { channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
318    Data { channel_id: u32, data: Bytes },
319    Closed { channel_id: u32 },
320}
321
322/// Events from port forward TCP acceptors/connections on the client side.
323enum ClientPortForwardEvent {
324    Accepted { forward_id: u32, channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
325    Data { channel_id: u32, data: Bytes },
326    Closed { channel_id: u32 },
327}
328
329/// Per-forward state on the client side.
330struct ClientPortForwardState {
331    listener_handle: Option<tokio::task::JoinHandle<()>>,
332    target_port: u16,
333}
334
335/// Grouped state for agent channel management on the client side.
336struct ClientAgentState {
337    channels: HashMap<u32, mpsc::Sender<Bytes>>,
338}
339
340impl ClientAgentState {
341    fn new() -> Self {
342        Self { channels: HashMap::new() }
343    }
344
345    fn teardown(&mut self) {
346        self.channels.clear();
347    }
348}
349
350/// Grouped state for the OAuth callback tunnel on the client side (multi-channel).
351struct ClientTunnelState {
352    listener: Option<tokio::task::JoinHandle<()>>,
353    channels: HashMap<u32, mpsc::Sender<Bytes>>,
354    next_channel_id: Arc<AtomicU32>,
355}
356
357impl ClientTunnelState {
358    fn new() -> Self {
359        Self {
360            listener: None,
361            channels: HashMap::new(),
362            next_channel_id: Arc::new(AtomicU32::new(0)),
363        }
364    }
365
366    fn teardown(&mut self) {
367        self.channels.clear();
368        if let Some(handle) = self.listener.take() {
369            handle.abort();
370        }
371    }
372}
373
374/// Grouped state for TCP port forwarding on the client side.
375struct ClientPortForwardTable {
376    forwards: HashMap<u32, ClientPortForwardState>,
377    channels: HashMap<u32, (u32, mpsc::Sender<Bytes>)>,
378    next_channel_id: std::sync::Arc<std::sync::atomic::AtomicU32>,
379}
380
381impl ClientPortForwardTable {
382    fn new() -> Self {
383        Self {
384            forwards: HashMap::new(),
385            channels: HashMap::new(),
386            next_channel_id: std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)),
387        }
388    }
389
390    fn teardown(&mut self) {
391        for (_, fwd) in self.forwards.drain() {
392            if let Some(h) = fwd.listener_handle {
393                h.abort();
394            }
395        }
396        self.channels.clear();
397    }
398}
399
400/// Send session setup frames (env, agent/open forwarding, resize, redraw).
401/// Returns false if the connection dropped during setup.
402async fn send_init_frames(
403    framed: &mut Framed<UnixStream, FrameCodec>,
404    env_vars: &[(String, String)],
405    forward_agent: bool,
406    agent_socket: Option<&str>,
407    forward_open: bool,
408    redraw: bool,
409) -> bool {
410    if !env_vars.is_empty() && !timed_send(framed, Frame::Env { vars: env_vars.to_vec() }).await {
411        return false;
412    }
413    if forward_agent && agent_socket.is_some() && !timed_send(framed, Frame::AgentForward).await {
414        return false;
415    }
416    if forward_open && !timed_send(framed, Frame::OpenForward).await {
417        return false;
418    }
419    let (cols, rows) = get_terminal_size();
420    if !timed_send(framed, Frame::Resize { cols, rows }).await {
421        return false;
422    }
423    if redraw && !timed_send(framed, Frame::Data(Bytes::from_static(b"\x0c"))).await {
424        return false;
425    }
426    true
427}
428
429/// `framed` is kept outside (passed to handlers) so `tokio::select!` can
430/// poll `framed.next()` independently without conflicting borrows.
431struct ClientRelay<'a> {
432    async_stdout: &'a AsyncFd<std::os::fd::OwnedFd>,
433    agent: &'a mut ClientAgentState,
434    agent_event_tx: &'a mpsc::UnboundedSender<AgentEvent>,
435    agent_socket: Option<&'a str>,
436    tunnel: &'a mut ClientTunnelState,
437    tunnel_event_tx: &'a mpsc::UnboundedSender<ClientTunnelEvent>,
438    oauth_redirect: bool,
439    oauth_timeout: u64,
440    pf: &'a mut ClientPortForwardTable,
441    pf_event_tx: &'a mpsc::UnboundedSender<ClientPortForwardEvent>,
442    last_pong: &'a mut Instant,
443    last_ping_sent: &'a mut Instant,
444    last_rtt: &'a mut Option<Duration>,
445    connected_at: Instant,
446    bytes_relayed: &'a mut u64,
447}
448
449impl ClientRelay<'_> {
450    async fn handle_server_frame(
451        &mut self,
452        framed: &mut Framed<UnixStream, FrameCodec>,
453        frame: Option<Result<Frame, io::Error>>,
454    ) -> Result<ControlFlow<RelayExit>, anyhow::Error> {
455        match frame {
456            Some(Ok(Frame::Data(data))) => {
457                debug!(len = data.len(), "socket → stdout");
458                *self.bytes_relayed += data.len() as u64;
459                write_stdout_async(self.async_stdout, &data).await?;
460            }
461            Some(Ok(Frame::Pong)) => {
462                *self.last_rtt = Some(self.last_ping_sent.elapsed());
463                debug!(rtt_ms = self.last_rtt.unwrap().as_secs_f64() * 1000.0, "pong received");
464                *self.last_pong = Instant::now();
465            }
466            Some(Ok(Frame::Exit { code })) => {
467                debug!(code, "server sent exit");
468                return Ok(ControlFlow::Break(RelayExit::Exit(code)));
469            }
470            Some(Ok(Frame::Detached)) => {
471                info!("detached by another client");
472                self.agent.teardown();
473                self.tunnel.teardown();
474                self.pf.teardown();
475                write_stdout_async(self.async_stdout, status_msg("detached").as_bytes()).await?;
476                return Ok(ControlFlow::Break(RelayExit::Exit(0)));
477            }
478            Some(Ok(Frame::AgentOpen { channel_id })) => {
479                if let Some(sock_path) = self.agent_socket {
480                    match tokio::net::UnixStream::connect(sock_path).await {
481                        Ok(stream) => {
482                            let (read_half, write_half) = stream.into_split();
483                            let data_tx = self.agent_event_tx.clone();
484                            let close_tx = self.agent_event_tx.clone();
485                            let writer_tx = crate::spawn_channel_relay(
486                                channel_id,
487                                read_half,
488                                write_half,
489                                move |id, data| {
490                                    data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok()
491                                },
492                                move |id| {
493                                    let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
494                                },
495                            );
496                            self.agent.channels.insert(channel_id, writer_tx);
497                        }
498                        Err(e) => {
499                            debug!("failed to connect to local agent: {e}");
500                            let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
501                        }
502                    }
503                } else {
504                    let _ = timed_send(framed, Frame::AgentClose { channel_id }).await;
505                }
506            }
507            Some(Ok(Frame::AgentData { channel_id, data })) => {
508                if let Some(tx) = self.agent.channels.get(&channel_id) {
509                    let _ = tx.send(data).await;
510                }
511            }
512            Some(Ok(Frame::AgentClose { channel_id })) => {
513                self.agent.channels.remove(&channel_id);
514            }
515            Some(Ok(Frame::OpenUrl { url })) => {
516                if url.starts_with("http://") || url.starts_with("https://") {
517                    debug!("opening URL locally: {url}");
518                    tokio::task::spawn_blocking(move || {
519                        let _ = opener::open(&url);
520                    });
521                } else {
522                    debug!("rejected non-http(s) URL: {url}");
523                }
524            }
525            Some(Ok(Frame::TunnelListen { port })) => {
526                if !self.oauth_redirect {
527                    debug!(port, "tunnel: oauth-redirect disabled, declining");
528                    let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
529                } else {
530                    // Bind synchronously to guarantee port is ready before OpenUrl
531                    match std::net::TcpListener::bind(("127.0.0.1", port)) {
532                        Ok(std_listener) => {
533                            debug!(port, "tunnel: bound local port");
534                            std_listener.set_nonblocking(true).ok();
535                            let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
536                            let tx = self.tunnel_event_tx.clone();
537                            let timeout = self.oauth_timeout;
538                            let next_id = Arc::clone(&self.tunnel.next_channel_id);
539                            self.tunnel.listener = Some(tokio::spawn(async move {
540                                let deadline =
541                                    tokio::time::Instant::now() + Duration::from_secs(timeout);
542                                loop {
543                                    let accept =
544                                        tokio::time::timeout_at(deadline, listener.accept()).await;
545                                    match accept {
546                                        Ok(Ok((stream, _))) => {
547                                            let channel_id =
548                                                next_id.fetch_add(1, Ordering::Relaxed);
549                                            let (read_half, write_half) = stream.into_split();
550                                            let (writer_tx, mut writer_rx) =
551                                                mpsc::channel::<Bytes>(crate::CHANNEL_RELAY_BUFFER);
552
553                                            // Writer task: channel -> TCP
554                                            tokio::spawn(async move {
555                                                use tokio::io::AsyncWriteExt;
556                                                let mut writer = write_half;
557                                                while let Some(data) = writer_rx.recv().await {
558                                                    if writer.write_all(&data).await.is_err() {
559                                                        break;
560                                                    }
561                                                }
562                                            });
563
564                                            let _ = tx.send(ClientTunnelEvent::Accepted {
565                                                channel_id,
566                                                writer_tx,
567                                            });
568
569                                            // Reader task: TCP -> events (spawned so we
570                                            // can keep accepting new connections)
571                                            let reader_tx = tx.clone();
572                                            tokio::spawn(async move {
573                                                use tokio::io::AsyncReadExt;
574                                                let mut read_half = read_half;
575                                                let mut buf = vec![0u8; 4096];
576                                                loop {
577                                                    match read_half.read(&mut buf).await {
578                                                        Ok(0) | Err(_) => {
579                                                            let _ = reader_tx.send(
580                                                                ClientTunnelEvent::Closed {
581                                                                    channel_id,
582                                                                },
583                                                            );
584                                                            break;
585                                                        }
586                                                        Ok(n) => {
587                                                            let data =
588                                                                Bytes::copy_from_slice(&buf[..n]);
589                                                            if reader_tx
590                                                                .send(ClientTunnelEvent::Data {
591                                                                    channel_id,
592                                                                    data,
593                                                                })
594                                                                .is_err()
595                                                            {
596                                                                break;
597                                                            }
598                                                        }
599                                                    }
600                                                }
601                                            });
602                                        }
603                                        _ => {
604                                            debug!(port, "tunnel: accept timed out or failed");
605                                            break;
606                                        }
607                                    }
608                                }
609                            }));
610                        }
611                        Err(e) => {
612                            debug!(port, "tunnel: bind failed: {e}");
613                            let _ = timed_send(framed, Frame::TunnelClose { channel_id: 0 }).await;
614                        }
615                    }
616                }
617            }
618            Some(Ok(Frame::SendOffer { file_count, total_bytes })) => {
619                let size_str = format_size(total_bytes);
620                let s = if file_count == 1 { "" } else { "s" };
621                write_stdout_async(
622                    self.async_stdout,
623                    status_msg(&format!("gritty: receiving {file_count} file{s} ({size_str})"))
624                        .as_bytes(),
625                )
626                .await?;
627            }
628            Some(Ok(Frame::SendDone)) => {
629                write_stdout_async(
630                    self.async_stdout,
631                    success_msg("gritty: transfer complete").as_bytes(),
632                )
633                .await?;
634            }
635            Some(Ok(Frame::SendCancel { reason })) => {
636                write_stdout_async(
637                    self.async_stdout,
638                    error_msg(&format!("gritty: transfer cancelled: {reason}")).as_bytes(),
639                )
640                .await?;
641            }
642            Some(Ok(Frame::TunnelData { channel_id, data })) => {
643                if let Some(tx) = self.tunnel.channels.get(&channel_id) {
644                    let _ = tx.send(data).await;
645                }
646            }
647            Some(Ok(Frame::TunnelClose { channel_id })) => {
648                self.tunnel.channels.remove(&channel_id);
649            }
650            // Port forward: server asks client to bind a port (remote-fwd)
651            Some(Ok(Frame::PortForwardListen { forward_id, listen_port, target_port })) => {
652                match std::net::TcpListener::bind(("127.0.0.1", listen_port)) {
653                    Ok(std_listener) => {
654                        debug!(forward_id, listen_port, "port forward: bound local port");
655                        std_listener.set_nonblocking(true).ok();
656                        let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
657                        let tx = self.pf_event_tx.clone();
658                        let nid = self.pf.next_channel_id.clone();
659                        let handle = tokio::spawn(async move {
660                            loop {
661                                let (stream, _) = match listener.accept().await {
662                                    Ok(conn) => conn,
663                                    Err(_) => break,
664                                };
665                                let channel_id =
666                                    nid.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
667                                let (read_half, write_half) = stream.into_split();
668                                let data_tx = tx.clone();
669                                let close_tx = tx.clone();
670                                let writer_tx = crate::spawn_channel_relay(
671                                    channel_id,
672                                    read_half,
673                                    write_half,
674                                    move |id, data| {
675                                        data_tx
676                                            .send(ClientPortForwardEvent::Data {
677                                                channel_id: id,
678                                                data,
679                                            })
680                                            .is_ok()
681                                    },
682                                    move |id| {
683                                        let _ = close_tx.send(ClientPortForwardEvent::Closed {
684                                            channel_id: id,
685                                        });
686                                    },
687                                );
688                                if tx
689                                    .send(ClientPortForwardEvent::Accepted {
690                                        forward_id,
691                                        channel_id,
692                                        writer_tx,
693                                    })
694                                    .is_err()
695                                {
696                                    break;
697                                }
698                            }
699                        });
700                        self.pf.forwards.insert(
701                            forward_id,
702                            ClientPortForwardState { listener_handle: Some(handle), target_port },
703                        );
704                        if !timed_send(framed, Frame::PortForwardReady { forward_id }).await {
705                            return Ok(ControlFlow::Break(RelayExit::Disconnected));
706                        }
707                    }
708                    Err(e) => {
709                        debug!(forward_id, listen_port, "port forward: bind failed: {e}");
710                        let _ = timed_send(framed, Frame::PortForwardStop { forward_id }).await;
711                    }
712                }
713            }
714            // Port forward: new TCP connection from server side (local-fwd)
715            Some(Ok(Frame::PortForwardOpen { forward_id, channel_id, target_port })) => {
716                if self.pf.forwards.contains_key(&forward_id) || forward_id == u32::MAX {
717                    // forward_id == u32::MAX is a "don't track" sentinel for local-fwd
718                    match tokio::net::TcpStream::connect(("127.0.0.1", target_port)).await {
719                        Ok(stream) => {
720                            let (read_half, write_half) = stream.into_split();
721                            let data_tx = self.pf_event_tx.clone();
722                            let close_tx = self.pf_event_tx.clone();
723                            let writer_tx = crate::spawn_channel_relay(
724                                channel_id,
725                                read_half,
726                                write_half,
727                                move |id, data| {
728                                    data_tx
729                                        .send(ClientPortForwardEvent::Data { channel_id: id, data })
730                                        .is_ok()
731                                },
732                                move |id| {
733                                    let _ = close_tx
734                                        .send(ClientPortForwardEvent::Closed { channel_id: id });
735                                },
736                            );
737                            self.pf.channels.insert(channel_id, (forward_id, writer_tx));
738                        }
739                        Err(e) => {
740                            debug!(channel_id, target_port, "pf connect failed: {e}");
741                            let _ =
742                                timed_send(framed, Frame::PortForwardClose { channel_id }).await;
743                        }
744                    }
745                }
746            }
747            // Port forward: channel data from server
748            Some(Ok(Frame::PortForwardData { channel_id, data })) => {
749                if let Some((_, tx)) = self.pf.channels.get(&channel_id) {
750                    let _ = tx.send(data).await;
751                }
752            }
753            // Port forward: channel closed by server
754            Some(Ok(Frame::PortForwardClose { channel_id })) => {
755                self.pf.channels.remove(&channel_id);
756            }
757            // Port forward: teardown from server
758            Some(Ok(Frame::PortForwardStop { forward_id })) => {
759                if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
760                    if let Some(h) = fwd.listener_handle {
761                        h.abort();
762                    }
763                }
764                // Remove channels belonging to this forward
765                self.pf.channels.retain(|_, (fid, _)| *fid != forward_id);
766            }
767            Some(Ok(_)) => {} // ignore control/resize frames
768            Some(Err(e)) => {
769                debug!("server connection error: {e}");
770                return Ok(ControlFlow::Break(RelayExit::Disconnected));
771            }
772            None => {
773                debug!("server disconnected");
774                return Ok(ControlFlow::Break(RelayExit::Disconnected));
775            }
776        }
777        Ok(ControlFlow::Continue(()))
778    }
779
780    async fn handle_agent_event(
781        &mut self,
782        framed: &mut Framed<UnixStream, FrameCodec>,
783        event: Option<AgentEvent>,
784    ) -> bool {
785        match event {
786            Some(AgentEvent::Data { channel_id, data }) => {
787                if self.agent.channels.contains_key(&channel_id)
788                    && !timed_send(framed, Frame::AgentData { channel_id, data }).await
789                {
790                    return false;
791                }
792            }
793            Some(AgentEvent::Closed { channel_id }) => {
794                if self.agent.channels.remove(&channel_id).is_some()
795                    && !timed_send(framed, Frame::AgentClose { channel_id }).await
796                {
797                    return false;
798                }
799            }
800            None => {} // no more agent tasks
801        }
802        true
803    }
804
805    async fn handle_tunnel_event(
806        &mut self,
807        framed: &mut Framed<UnixStream, FrameCodec>,
808        event: Option<ClientTunnelEvent>,
809    ) -> bool {
810        match event {
811            Some(ClientTunnelEvent::Accepted { channel_id, writer_tx }) => {
812                self.tunnel.channels.insert(channel_id, writer_tx);
813                if !timed_send(framed, Frame::TunnelOpen { channel_id }).await {
814                    return false;
815                }
816            }
817            Some(ClientTunnelEvent::Data { channel_id, data }) => {
818                if !timed_send(framed, Frame::TunnelData { channel_id, data }).await {
819                    return false;
820                }
821            }
822            Some(ClientTunnelEvent::Closed { channel_id }) => {
823                self.tunnel.channels.remove(&channel_id);
824                if !timed_send(framed, Frame::TunnelClose { channel_id }).await {
825                    return false;
826                }
827            }
828            None => {}
829        }
830        true
831    }
832
833    async fn handle_pf_event(
834        &mut self,
835        framed: &mut Framed<UnixStream, FrameCodec>,
836        event: Option<ClientPortForwardEvent>,
837    ) -> bool {
838        match event {
839            Some(ClientPortForwardEvent::Accepted { forward_id, channel_id, writer_tx }) => {
840                if let Some(fwd) = self.pf.forwards.get(&forward_id) {
841                    let target_port = fwd.target_port;
842                    self.pf.channels.insert(channel_id, (forward_id, writer_tx));
843                    if !timed_send(
844                        framed,
845                        Frame::PortForwardOpen { forward_id, channel_id, target_port },
846                    )
847                    .await
848                    {
849                        return false;
850                    }
851                }
852            }
853            Some(ClientPortForwardEvent::Data { channel_id, data }) => {
854                if self.pf.channels.contains_key(&channel_id)
855                    && !timed_send(framed, Frame::PortForwardData { channel_id, data }).await
856                {
857                    return false;
858                }
859            }
860            Some(ClientPortForwardEvent::Closed { channel_id }) => {
861                if self.pf.channels.remove(&channel_id).is_some()
862                    && !timed_send(framed, Frame::PortForwardClose { channel_id }).await
863                {
864                    return false;
865                }
866            }
867            None => {}
868        }
869        true
870    }
871}
872
873/// Relay between stdin/stdout and the framed socket.
874#[allow(clippy::too_many_arguments)]
875async fn relay(
876    framed: &mut Framed<UnixStream, FrameCodec>,
877    async_stdin: &AsyncFd<io::Stdin>,
878    async_stdout: &AsyncFd<std::os::fd::OwnedFd>,
879    sigwinch: &mut tokio::signal::unix::Signal,
880    buf: &mut [u8],
881    mut escape: Option<&mut EscapeProcessor>,
882    raw_guard: &RawModeGuard,
883    nb_guard: &NonBlockGuard,
884    agent_socket: Option<&str>,
885    oauth_redirect: bool,
886    oauth_timeout: u64,
887    session: &str,
888    hb_interval: Duration,
889    hb_timeout: Duration,
890) -> anyhow::Result<RelayExit> {
891    let mut sigterm = signal(SignalKind::terminate())?;
892    let mut sighup = signal(SignalKind::hangup())?;
893
894    let mut heartbeat_interval = tokio::time::interval(hb_interval);
895    heartbeat_interval.reset(); // first tick is immediate otherwise; delay it
896    let mut last_pong = Instant::now();
897    let mut last_ping_sent = Instant::now();
898    let mut last_rtt: Option<Duration> = None;
899
900    // Agent channel management
901    let mut agent = ClientAgentState::new();
902    let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();
903
904    // Tunnel state (reverse TCP tunnel for OAuth callbacks)
905    let mut tunnel = ClientTunnelState::new();
906    let (tunnel_event_tx, mut tunnel_event_rx) = mpsc::unbounded_channel::<ClientTunnelEvent>();
907
908    // Port forward state
909    let (pf_event_tx, mut pf_event_rx) = mpsc::unbounded_channel::<ClientPortForwardEvent>();
910    let mut pf = ClientPortForwardTable::new();
911
912    let mut bytes_relayed = 0u64;
913    let mut relay = ClientRelay {
914        async_stdout,
915        agent: &mut agent,
916        agent_event_tx: &agent_event_tx,
917        agent_socket,
918        tunnel: &mut tunnel,
919        tunnel_event_tx: &tunnel_event_tx,
920        oauth_redirect,
921        oauth_timeout,
922        pf: &mut pf,
923        pf_event_tx: &pf_event_tx,
924        last_pong: &mut last_pong,
925        last_ping_sent: &mut last_ping_sent,
926        last_rtt: &mut last_rtt,
927        connected_at: Instant::now(),
928        bytes_relayed: &mut bytes_relayed,
929    };
930
931    loop {
932        tokio::select! {
933            ready = async_stdin.readable() => {
934                let mut guard = ready?;
935                match guard.try_io(|inner| inner.get_ref().read(buf)) {
936                    Ok(Ok(0)) => {
937                        debug!("stdin EOF");
938                        return Ok(RelayExit::Exit(0));
939                    }
940                    Ok(Ok(n)) => {
941                        debug!(len = n, "stdin → socket");
942                        if let Some(ref mut esc) = escape {
943                            for action in esc.process(&buf[..n]) {
944                                match action {
945                                    EscapeAction::Data(data) => {
946                                        if !timed_send(framed, Frame::Data(Bytes::from(data))).await {
947                                            return Ok(RelayExit::Disconnected);
948                                        }
949                                    }
950                                    EscapeAction::Detach => {
951                                        write_stdout_async(async_stdout, status_msg("detached").as_bytes()).await?;
952                                        return Ok(RelayExit::Exit(0));
953                                    }
954                                    EscapeAction::Reconnect => {
955                                        write_stdout_async(async_stdout, status_msg("force reconnect").as_bytes()).await?;
956                                        return Ok(RelayExit::Disconnected);
957                                    }
958                                    EscapeAction::Suspend => {
959                                        suspend(raw_guard, nb_guard)?;
960                                        // Re-sync terminal size after resume
961                                        let (cols, rows) = get_terminal_size();
962                                        if !timed_send(framed, Frame::Resize { cols, rows }).await {
963                                            return Ok(RelayExit::Disconnected);
964                                        }
965                                    }
966                                    EscapeAction::Status => {
967                                        let rtt_str = match *relay.last_rtt {
968                                            Some(d) => format!("{:.1}ms", d.as_secs_f64() * 1000.0),
969                                            None => "n/a".to_string(),
970                                        };
971                                        let uptime = relay.connected_at.elapsed();
972                                        let uptime_str = if uptime.as_secs() >= 3600 {
973                                            format!(
974                                                "{}h {}m {}s",
975                                                uptime.as_secs() / 3600,
976                                                (uptime.as_secs() % 3600) / 60,
977                                                uptime.as_secs() % 60,
978                                            )
979                                        } else if uptime.as_secs() >= 60 {
980                                            format!(
981                                                "{}m {}s",
982                                                uptime.as_secs() / 60,
983                                                uptime.as_secs() % 60,
984                                            )
985                                        } else {
986                                            format!("{}s", uptime.as_secs())
987                                        };
988                                        let bytes_str = format_size(*relay.bytes_relayed);
989                                        let agent_info = if relay.agent_socket.is_some() {
990                                            format!(
991                                                "on ({} channels)",
992                                                relay.agent.channels.len()
993                                            )
994                                        } else {
995                                            "off".to_string()
996                                        };
997                                        let open_str = if relay.oauth_redirect { "on" } else { "off" };
998                                        let mut pf_lines = Vec::new();
999                                        for (&fwd_id, fwd) in &relay.pf.forwards {
1000                                            let ch_count = relay.pf.channels.values()
1001                                                .filter(|(fid, _)| *fid == fwd_id)
1002                                                .count();
1003                                            pf_lines.push(format!(
1004                                                "    :{} ({} connections)",
1005                                                fwd.target_port,
1006                                                ch_count,
1007                                            ));
1008                                        }
1009                                        let tunnel_str = if !relay.tunnel.channels.is_empty() {
1010                                            format!("active ({} channels)", relay.tunnel.channels.len())
1011                                        } else if relay.tunnel.listener.is_some() {
1012                                            "listening".to_string()
1013                                        } else {
1014                                            "idle".to_string()
1015                                        };
1016                                        let mut status = format!(
1017                                            "\r\n\x1b[2;33m[gritty status]\r\n\
1018                                             \x1b[0m\x1b[2m  session: {session}\r\n\
1019                                             \x1b[0m\x1b[2m  rtt: {rtt_str}\r\n\
1020                                             \x1b[0m\x1b[2m  connected: {uptime_str}\r\n\
1021                                             \x1b[0m\x1b[2m  bytes relayed: {bytes_str}\r\n\
1022                                             \x1b[0m\x1b[2m  agent forwarding: {agent_info}\r\n\
1023                                             \x1b[0m\x1b[2m  open forwarding: {open_str}\r\n\
1024                                             \x1b[0m\x1b[2m  oauth tunnel: {tunnel_str}\r\n",
1025                                        );
1026                                        for line in &pf_lines {
1027                                            status.push_str(&format!(
1028                                                "\x1b[0m\x1b[2m  port forward{line}\r\n"
1029                                            ));
1030                                        }
1031                                        status.push_str("\x1b[0m");
1032                                        write_stdout_async(
1033                                            async_stdout,
1034                                            status.as_bytes(),
1035                                        ).await?;
1036                                    }
1037                                    EscapeAction::Help => {
1038                                        write_stdout_async(async_stdout, ESCAPE_HELP).await?;
1039                                    }
1040                                }
1041                            }
1042                        } else if !timed_send(framed, Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await {
1043                            return Ok(RelayExit::Disconnected);
1044                        }
1045                    }
1046                    Ok(Err(e)) => return Err(e.into()),
1047                    Err(_would_block) => continue,
1048                }
1049            }
1050
1051            frame = framed.next() => {
1052                if let ControlFlow::Break(exit) = relay.handle_server_frame(framed, frame).await? {
1053                    return Ok(exit);
1054                }
1055            }
1056
1057            event = agent_event_rx.recv() => {
1058                if !relay.handle_agent_event(framed, event).await {
1059                    return Ok(RelayExit::Disconnected);
1060                }
1061            }
1062
1063            event = tunnel_event_rx.recv() => {
1064                if !relay.handle_tunnel_event(framed, event).await {
1065                    return Ok(RelayExit::Disconnected);
1066                }
1067            }
1068
1069            event = pf_event_rx.recv() => {
1070                if !relay.handle_pf_event(framed, event).await {
1071                    return Ok(RelayExit::Disconnected);
1072                }
1073            }
1074
1075            _ = sigwinch.recv() => {
1076                let (cols, rows) = get_terminal_size();
1077                debug!(cols, rows, "SIGWINCH → resize");
1078                if !timed_send(framed, Frame::Resize { cols, rows }).await {
1079                    return Ok(RelayExit::Disconnected);
1080                }
1081            }
1082
1083            _ = heartbeat_interval.tick() => {
1084                if relay.last_pong.elapsed() > hb_timeout {
1085                    debug!("heartbeat timeout");
1086                    return Ok(RelayExit::Disconnected);
1087                }
1088                *relay.last_ping_sent = Instant::now();
1089                if !timed_send(framed, Frame::Ping).await {
1090                    return Ok(RelayExit::Disconnected);
1091                }
1092            }
1093
1094            _ = sigterm.recv() => {
1095                debug!("SIGTERM received, exiting");
1096                return Ok(RelayExit::Exit(1));
1097            }
1098
1099            _ = sighup.recv() => {
1100                debug!("SIGHUP received, exiting");
1101                return Ok(RelayExit::Exit(1));
1102            }
1103        }
1104    }
1105}
1106
1107#[allow(clippy::too_many_arguments)]
1108pub async fn run(
1109    session: &str,
1110    mut framed: Framed<UnixStream, FrameCodec>,
1111    redraw: bool,
1112    ctl_path: &Path,
1113    env_vars: Vec<(String, String)>,
1114    no_escape: bool,
1115    forward_agent: bool,
1116    forward_open: bool,
1117    oauth_redirect: bool,
1118    oauth_timeout: u64,
1119    heartbeat_interval: u64,
1120    heartbeat_timeout: u64,
1121) -> anyhow::Result<i32> {
1122    let stdin = io::stdin();
1123    let stdin_fd = stdin.as_fd();
1124    // Safety: stdin lives for the duration of the program
1125    let stdin_borrowed: BorrowedFd<'static> =
1126        unsafe { BorrowedFd::borrow_raw(stdin_fd.as_raw_fd()) };
1127    let raw_guard = RawModeGuard::enter(stdin_borrowed)?;
1128
1129    // Set stdin to non-blocking for AsyncFd — guard restores on drop.
1130    // Declared BEFORE async_stdin so it drops AFTER AsyncFd (reverse drop order).
1131    let nb_guard = NonBlockGuard::set(stdin_borrowed)?;
1132    let async_stdin = AsyncFd::new(io::stdin())?;
1133    // dup() stdout so we get an independent fd for AsyncFd (stdin may share the same fd).
1134    let stdout_fd = crate::security::checked_dup(io::stdout().as_raw_fd())?;
1135    let async_stdout = AsyncFd::new(stdout_fd)?;
1136    let mut sigwinch = signal(SignalKind::window_change())?;
1137    let mut buf = vec![0u8; 4096];
1138    let mut current_redraw = redraw;
1139    let mut current_env = env_vars;
1140    let mut escape = if no_escape { None } else { Some(EscapeProcessor::new()) };
1141    let agent_socket = if forward_agent { std::env::var("SSH_AUTH_SOCK").ok() } else { None };
1142
1143    loop {
1144        let result = if send_init_frames(
1145            &mut framed,
1146            &current_env,
1147            forward_agent,
1148            agent_socket.as_deref(),
1149            forward_open,
1150            current_redraw,
1151        )
1152        .await
1153        {
1154            relay(
1155                &mut framed,
1156                &async_stdin,
1157                &async_stdout,
1158                &mut sigwinch,
1159                &mut buf,
1160                escape.as_mut(),
1161                &raw_guard,
1162                &nb_guard,
1163                agent_socket.as_deref(),
1164                oauth_redirect,
1165                oauth_timeout,
1166                session,
1167                Duration::from_secs(heartbeat_interval),
1168                Duration::from_secs(heartbeat_timeout),
1169            )
1170            .await?
1171        } else {
1172            RelayExit::Disconnected
1173        };
1174        match result {
1175            RelayExit::Exit(code) => return Ok(code),
1176            RelayExit::Disconnected => {
1177                // Env vars only sent on first connection; clear for reconnect
1178                current_env.clear();
1179                // Disconnected — try to reconnect
1180                let reconnect_started = Instant::now();
1181                write_stdout_async(&async_stdout, status_msg("reconnecting...").as_bytes()).await?;
1182
1183                loop {
1184                    // Race sleep against stdin so Ctrl-C is instant
1185                    tokio::select! {
1186                        _ = tokio::time::sleep(Duration::from_secs(1)) => {}
1187                        _ = async_stdin.readable() => {
1188                            let mut peek = [0u8; 1];
1189                            match async_stdin.get_ref().read(&mut peek) {
1190                                Ok(1) if peek[0] == 0x03 => {
1191                                    write_stdout_async(&async_stdout, b"\r\n").await?;
1192                                    return Ok(1);
1193                                }
1194                                _ => {}
1195                            }
1196                            continue;
1197                        }
1198                    }
1199
1200                    // Update elapsed time on each retry
1201                    let elapsed = reconnect_started.elapsed().as_secs();
1202                    write_stdout_async(
1203                        &async_stdout,
1204                        format!("\r{}", status_msg(&format!("reconnecting... {elapsed}s")))
1205                            .as_bytes(),
1206                    )
1207                    .await?;
1208
1209                    let stream = match UnixStream::connect(ctl_path).await {
1210                        Ok(s) => s,
1211                        Err(_) => continue,
1212                    };
1213
1214                    let mut new_framed = Framed::new(stream, FrameCodec);
1215                    if crate::handshake(&mut new_framed).await.is_err() {
1216                        continue;
1217                    }
1218                    if new_framed
1219                        .send(Frame::Attach { session: session.to_string() })
1220                        .await
1221                        .is_err()
1222                    {
1223                        continue;
1224                    }
1225
1226                    match new_framed.next().await {
1227                        Some(Ok(Frame::Ok)) => {
1228                            write_stdout_async(
1229                                &async_stdout,
1230                                success_msg("reconnected").as_bytes(),
1231                            )
1232                            .await?;
1233                            framed = new_framed;
1234                            current_redraw = true;
1235                            break;
1236                        }
1237                        Some(Ok(Frame::Error { message })) => {
1238                            write_stdout_async(
1239                                &async_stdout,
1240                                error_msg(&format!("session gone: {message}")).as_bytes(),
1241                            )
1242                            .await?;
1243                            return Ok(1);
1244                        }
1245                        _ => continue,
1246                    }
1247                }
1248            }
1249        }
1250    }
1251}
1252
1253/// Read-only tail of a session's PTY output.
1254/// No raw mode, no stdin, no escape processing, no forwarding.
1255/// Ctrl-C triggers clean exit with terminal reset.
1256pub async fn tail(
1257    session: &str,
1258    mut framed: Framed<UnixStream, FrameCodec>,
1259    ctl_path: &Path,
1260) -> anyhow::Result<i32> {
1261    // Suppress stdin echo — tail is read-only. Guard restores on drop.
1262    let stdin_fd = unsafe { BorrowedFd::borrow_raw(libc::STDIN_FILENO) };
1263    let _input_guard = SuppressInputGuard::enter(stdin_fd).ok();
1264
1265    // Drain stdin in background, ring bell on first keystroke
1266    tokio::task::spawn_blocking(|| {
1267        let mut buf = [0u8; 64];
1268        let mut belled = false;
1269        loop {
1270            match io::stdin().read(&mut buf) {
1271                Ok(0) | Err(_) => break,
1272                Ok(_) if !belled => {
1273                    let _ = io::stderr().write_all(b"\x07");
1274                    let _ = io::stderr().flush();
1275                    belled = true;
1276                }
1277                _ => {}
1278            }
1279        }
1280    });
1281
1282    let mut heartbeat_interval = tokio::time::interval(DEFAULT_HEARTBEAT_INTERVAL);
1283    heartbeat_interval.reset();
1284    let mut last_pong = Instant::now();
1285    let mut sigint = signal(SignalKind::interrupt())?;
1286    let mut sigterm = signal(SignalKind::terminate())?;
1287    let mut sighup = signal(SignalKind::hangup())?;
1288    let mut stdout = tokio::io::stdout();
1289
1290    let code = 'outer: loop {
1291        let result = 'relay: loop {
1292            tokio::select! {
1293                frame = framed.next() => {
1294                    match frame {
1295                        Some(Ok(Frame::Data(data))) => {
1296                            use tokio::io::AsyncWriteExt;
1297                            stdout.write_all(&data).await?;
1298                        }
1299                        Some(Ok(Frame::Pong)) => {
1300                            last_pong = Instant::now();
1301                        }
1302                        Some(Ok(Frame::Exit { code })) => {
1303                            break 'relay Some(code);
1304                        }
1305                        Some(Ok(_)) => {}
1306                        Some(Err(e)) => {
1307                            debug!("tail connection error: {e}");
1308                            break 'relay None;
1309                        }
1310                        None => {
1311                            debug!("tail server disconnected");
1312                            break 'relay None;
1313                        }
1314                    }
1315                }
1316                _ = heartbeat_interval.tick() => {
1317                    if last_pong.elapsed() > DEFAULT_HEARTBEAT_TIMEOUT {
1318                        debug!("tail heartbeat timeout");
1319                        break 'relay None;
1320                    }
1321                    if framed.send(Frame::Ping).await.is_err() {
1322                        break 'relay None;
1323                    }
1324                }
1325                _ = sigint.recv() => {
1326                    break 'outer 0;
1327                }
1328                _ = sigterm.recv() => {
1329                    break 'outer 1;
1330                }
1331                _ = sighup.recv() => {
1332                    break 'outer 1;
1333                }
1334            }
1335        };
1336
1337        match result {
1338            Some(code) => break code,
1339            None => {
1340                let reconnect_started = Instant::now();
1341                eprintln!("\x1b[2;33m[reconnecting...]\x1b[0m");
1342                loop {
1343                    tokio::time::sleep(Duration::from_secs(1)).await;
1344                    let elapsed = reconnect_started.elapsed().as_secs();
1345                    eprint!("\r\x1b[2;33m[reconnecting... {elapsed}s]\x1b[0m");
1346
1347                    let stream = match UnixStream::connect(ctl_path).await {
1348                        Ok(s) => s,
1349                        Err(_) => continue,
1350                    };
1351
1352                    let mut new_framed = Framed::new(stream, FrameCodec);
1353                    if crate::handshake(&mut new_framed).await.is_err() {
1354                        continue;
1355                    }
1356                    if new_framed.send(Frame::Tail { session: session.to_string() }).await.is_err()
1357                    {
1358                        continue;
1359                    }
1360
1361                    match new_framed.next().await {
1362                        Some(Ok(Frame::Ok)) => {
1363                            eprintln!("\x1b[32m[reconnected]\x1b[0m");
1364                            framed = new_framed;
1365                            heartbeat_interval.reset();
1366                            last_pong = Instant::now();
1367                            break;
1368                        }
1369                        Some(Ok(Frame::Error { message })) => {
1370                            eprintln!("\x1b[31m[session gone: {message}]\x1b[0m");
1371                            break 'outer 1;
1372                        }
1373                        _ => continue,
1374                    }
1375                }
1376            }
1377        }
1378    };
1379
1380    // Reset terminal state: clear attributes and show cursor.
1381    // PTY output may have left colors/bold set or cursor hidden.
1382    {
1383        use tokio::io::AsyncWriteExt;
1384        let _ = stdout.write_all(b"\x1b[0m\x1b[?25h").await;
1385    }
1386    Ok(code)
1387}
1388
#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for `EscapeProcessor`: how `~` sequences are recognized only
    // at the start of a line and how input is split into data vs. actions.

    /// Builds a passthrough `Data` action from a byte string.
    fn data(bytes: &[u8]) -> EscapeAction {
        EscapeAction::Data(bytes.to_vec())
    }

    /// A processor positioned mid-line (state `Normal`): a `~` is only
    /// treated as an escape once a `\n` or `\r` has been fed.
    fn mid_line() -> EscapeProcessor {
        EscapeProcessor { state: EscapeState::Normal }
    }

    /// Expected output for a lone `\n` followed by a single escape result.
    fn newline_then(action: EscapeAction) -> Vec<EscapeAction> {
        vec![data(b"\n"), action]
    }

    #[test]
    fn normal_passthrough() {
        // A fresh processor starts in AfterNewline. The leading 'h' drops it
        // back to Normal, so the whole buffer passes through as data.
        let out = EscapeProcessor::new().process(b"hello");
        assert_eq!(out, vec![data(b"hello")]);
    }

    #[test]
    fn tilde_after_newline_detach() {
        let out = mid_line().process(b"\n~.");
        assert_eq!(out, newline_then(EscapeAction::Detach));
    }

    #[test]
    fn tilde_after_cr_detach() {
        let out = mid_line().process(b"\r~.");
        assert_eq!(out, vec![data(b"\r"), EscapeAction::Detach]);
    }

    #[test]
    fn tilde_not_after_newline() {
        // A tilde in the middle of a line is ordinary data.
        let out = mid_line().process(b"a~.");
        assert_eq!(out, vec![data(b"a~.")]);
    }

    #[test]
    fn initial_state_detach() {
        let out = EscapeProcessor::new().process(b"~.");
        assert_eq!(out, vec![EscapeAction::Detach]);
    }

    #[test]
    fn tilde_suspend() {
        let out = mid_line().process(b"\n~\x1a");
        assert_eq!(out, newline_then(EscapeAction::Suspend));
    }

    #[test]
    fn tilde_status() {
        let out = mid_line().process(b"\n~#");
        assert_eq!(out, newline_then(EscapeAction::Status));
    }

    #[test]
    fn tilde_reconnect() {
        let out = mid_line().process(b"\n~R");
        assert_eq!(out, newline_then(EscapeAction::Reconnect));
    }

    #[test]
    fn tilde_reconnect_stops_processing() {
        // Bytes after the reconnect escape are not emitted.
        let out = mid_line().process(b"\n~Rremaining");
        assert_eq!(out, newline_then(EscapeAction::Reconnect));
    }

    #[test]
    fn tilde_help() {
        let out = mid_line().process(b"\n~?");
        assert_eq!(out, newline_then(EscapeAction::Help));
    }

    #[test]
    fn double_tilde() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~~"), newline_then(data(b"~")));
        // After the literal tilde we are mid-line again.
        assert_eq!(ep.state, EscapeState::Normal);
    }

    #[test]
    fn tilde_unknown_char() {
        // An unrecognized escape forwards both the tilde and the character.
        let out = mid_line().process(b"\n~x");
        assert_eq!(out, newline_then(data(b"~x")));
    }

    #[test]
    fn split_across_reads() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        // The tilde alone is held back until the next byte decides its meaning.
        assert!(ep.process(b"~").is_empty());
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }

    #[test]
    fn split_tilde_then_normal() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n"), vec![data(b"\n")]);
        assert!(ep.process(b"~").is_empty());
        // The buffered tilde is released together with the ordinary byte.
        assert_eq!(ep.process(b"a"), vec![data(b"~a")]);
    }

    #[test]
    fn multiple_escapes_one_buffer() {
        let out = mid_line().process(b"\n~?\n~.");
        let mut expected = newline_then(EscapeAction::Help);
        expected.extend(newline_then(EscapeAction::Detach));
        assert_eq!(out, expected);
    }

    #[test]
    fn consecutive_newlines() {
        let out = mid_line().process(b"\n\n\n~.");
        assert_eq!(out, vec![data(b"\n\n\n"), EscapeAction::Detach]);
    }

    #[test]
    fn detach_stops_processing() {
        // Bytes after the detach escape are not emitted.
        let out = mid_line().process(b"\n~.remaining");
        assert_eq!(out, newline_then(EscapeAction::Detach));
    }

    #[test]
    fn tilde_then_newline() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~\n"), newline_then(data(b"~\n")));
        // The trailing newline re-arms escape detection.
        assert_eq!(ep.state, EscapeState::AfterNewline);
    }

    #[test]
    fn empty_input() {
        assert!(EscapeProcessor::new().process(b"").is_empty());
    }

    #[test]
    fn only_tilde_buffered() {
        let mut ep = mid_line();
        assert_eq!(ep.process(b"\n~"), vec![data(b"\n")]);
        assert_eq!(ep.state, EscapeState::AfterTilde);
        assert_eq!(ep.process(b"."), vec![EscapeAction::Detach]);
    }
}