// gritty/server.rs

use crate::protocol::{Frame, FrameCodec};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use nix::pty::openpty;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::ops::ControlFlow;
use std::os::fd::{AsRawFd, OwnedFd};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::unix::AsyncFd;
use tokio::net::{UnixListener, UnixStream};
use tokio::process::Command;
use tokio::sync::{broadcast, mpsc};
use tokio_util::codec::Framed;
use tracing::{debug, info, warn};

/// Wrapper to distinguish active, tail, and send connections arriving via channel.
pub enum ClientConn {
    Active(Framed<UnixStream, FrameCodec>),
    Tail(Framed<UnixStream, FrameCodec>),
    /// Raw stream for file transfer (local-side commands routed through daemon).
    Send(UnixStream),
}

/// Events broadcast to tail clients.
#[derive(Clone)]
enum TailEvent {
    Data(Bytes),
    Exit { code: i32 },
}

pub struct SessionMetadata {
    pub pty_path: String,
    pub shell_pid: u32,
    pub created_at: u64,
    pub attached: AtomicBool,
    pub last_heartbeat: AtomicU64,
}

/// Wraps a child process and its process group ID.
/// On drop, sends SIGHUP to the entire process group.
struct ManagedChild {
    child: tokio::process::Child,
    pgid: nix::unistd::Pid,
}

impl ManagedChild {
    fn new(child: tokio::process::Child) -> Self {
        let pid = child.id().expect("child should have pid") as i32;
        Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
    }
}

impl Drop for ManagedChild {
    fn drop(&mut self) {
        let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
        let _ = self.child.try_wait();
    }
}

/// Why the relay loop exited.
enum RelayExit {
    /// Client disconnected — re-accept.
    ClientGone,
    /// Shell exited with a code — we're done.
    ShellExited(i32),
}

/// Events from agent connection tasks to the main relay loop.
enum AgentEvent {
    Accepted { channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    Data { channel_id: u32, data: Bytes },
    Closed { channel_id: u32 },
}

/// Events from open socket acceptor to the main relay loop.
enum OpenEvent {
    Url(String),
}

/// Events from tunnel TCP connection tasks to the main relay loop.
enum TunnelEvent {
    Connected { channel_id: u32, stream: tokio::net::TcpStream },
    ConnectFailed { channel_id: u32 },
    Data { channel_id: u32, data: Bytes },
    Closed { channel_id: u32 },
}

/// Events from port forward TCP acceptors and connections to the main relay loop.
enum PortForwardEvent {
    /// Svc socket requested a port forward.
    Requested { stream: UnixStream, direction: u8, listen_port: u16, target_port: u16 },
    /// TCP connection accepted on a listening port.
    Accepted { forward_id: u32, channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    /// Background TCP connect completed (remote-fwd PortForwardOpen).
    Connected { forward_id: u32, channel_id: u32, stream: tokio::net::TcpStream },
    /// Background TCP connect failed.
    ConnectFailed { channel_id: u32 },
    /// Data from a TCP connection.
    Data { channel_id: u32, data: Bytes },
    /// TCP connection closed.
    Closed { channel_id: u32 },
    /// Svc socket dropped -- tear down the forward.
    Stopped { forward_id: u32 },
}

/// Per-forward state tracked by the server.
struct PortForwardState {
    /// Handle for the TCP listener task (aborted on teardown).
    listener_handle: Option<tokio::task::JoinHandle<()>>,
    /// Active TCP relay channels.
    channels: HashMap<u32, mpsc::Sender<Bytes>>,
    /// Handle for the svc stop watcher (aborted on teardown).
    stop_handle: Option<tokio::task::JoinHandle<()>>,
    target_port: u16,
}

/// Grouped state for SSH agent forwarding within a session.
struct AgentForwardState {
    enabled: bool,
    channels: HashMap<u32, mpsc::Sender<Bytes>>,
    acceptor: Option<tokio::task::JoinHandle<()>>,
    next_channel_id: Arc<AtomicU32>,
    socket_path: PathBuf,
}

impl AgentForwardState {
    fn new(socket_path: PathBuf) -> Self {
        Self {
            enabled: false,
            channels: HashMap::new(),
            acceptor: None,
            next_channel_id: Arc::new(AtomicU32::new(0)),
            socket_path,
        }
    }

    fn teardown(&mut self) {
        self.channels.clear();
        self.enabled = false;
        if let Some(handle) = self.acceptor.take() {
            handle.abort();
        }
        cleanup_socket(&self.socket_path);
    }
}

/// Grouped state for the OAuth callback reverse tunnel (multi-channel).
struct TunnelRelayState {
    port: Option<u16>,
    channels: HashMap<u32, mpsc::Sender<Bytes>>,
    idle_deadline: Option<tokio::time::Instant>,
    idle_timeout: Duration,
}

impl TunnelRelayState {
    fn new(idle_timeout: Duration) -> Self {
        Self { port: None, channels: HashMap::new(), idle_deadline: None, idle_timeout }
    }

    fn teardown(&mut self) {
        self.channels.clear();
        self.port = None;
        self.idle_deadline = None;
    }
}

/// Grouped state for TCP port forwarding (local and remote).
struct PortForwardTable {
    forwards: HashMap<u32, PortForwardState>,
    channels: HashMap<u32, (u32, mpsc::Sender<Bytes>)>,
    pending_remote: HashMap<u32, UnixStream>,
    next_forward_id: u32,
    next_channel_id: Arc<AtomicU32>,
}

impl PortForwardTable {
    fn new() -> Self {
        Self {
            forwards: HashMap::new(),
            channels: HashMap::new(),
            pending_remote: HashMap::new(),
            next_forward_id: 0,
            next_channel_id: Arc::new(AtomicU32::new(0)),
        }
    }

    fn teardown(&mut self) {
        for (_, pf) in self.forwards.drain() {
            if let Some(h) = pf.listener_handle {
                h.abort();
            }
            if let Some(h) = pf.stop_handle {
                h.abort();
            }
        }
        self.channels.clear();
        self.pending_remote.clear();
    }
}

/// File manifest entry parsed from sender protocol.
struct FileManifest {
    files: Vec<(String, u64)>, // (basename, size)
}

impl FileManifest {
    fn total_bytes(&self) -> u64 {
        self.files.iter().map(|(_, s)| s).sum()
    }
}

/// Events from the send socket acceptor to the main relay loop.
enum SendEvent {
    SenderArrived { stream: UnixStream, manifest: FileManifest },
    ReceiverArrived { stream: UnixStream },
}

/// State machine for file transfer rendezvous.
enum TransferState {
    Idle,
    WaitingForReceiver { sender_stream: UnixStream, manifest: FileManifest },
    WaitingForSender { receiver_stream: UnixStream },
    Active { relay_handle: tokio::task::JoinHandle<()> },
}

/// Sanitize a filename: strip path separators, reject ".." and empty names.
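/// A sketch of the intended behavior (hypothetical inputs):
///
/// ```ignore
/// assert_eq!(sanitize_filename("notes.txt"), Some("notes.txt".to_string()));
/// assert_eq!(sanitize_filename("/tmp/notes.txt"), Some("notes.txt".to_string()));
/// assert_eq!(sanitize_filename(".."), None);
/// assert_eq!(sanitize_filename("evil\\name"), None);
/// ```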
fn sanitize_filename(name: &str) -> Option<String> {
    let basename = Path::new(name).file_name().and_then(|n| n.to_str()).unwrap_or(name);
    if basename.is_empty() || basename == ".." || basename == "." {
        return None;
    }
    if basename.contains('\0') || basename.contains('\\') {
        return None;
    }
    Some(basename.to_string())
}

/// Extract redirect port from a URL's redirect_uri/redirect_url query parameter.
/// Returns Some(port) if the redirect target is localhost or 127.0.0.1 with a port.
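///
/// A sketch with a hypothetical URL:
///
/// ```ignore
/// let url = "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A8400%2Fcallback";
/// assert_eq!(extract_redirect_port(url), Some(8400));
/// ```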
fn extract_redirect_port(url_str: &str) -> Option<u16> {
    let parsed = url::Url::parse(url_str).ok()?;
    for (key, value) in parsed.query_pairs() {
        if key != "redirect_uri" && key != "redirect_url" {
            continue;
        }
        let redirect = url::Url::parse(&value).ok()?;
        match redirect.host_str()? {
            "localhost" | "127.0.0.1" => return redirect.port(),
            _ => {}
        }
    }
    None
}

/// Check if a TCP port is in use by attempting to bind it.
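/// Note: inherently racy -- the port can change hands between this probe and any
/// later use, so callers can only treat the result as a hint.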
fn port_in_use(port: u16) -> bool {
    std::net::TcpListener::bind(("127.0.0.1", port)).is_err()
}

/// Spawn the agent acceptor task that accepts connections on the agent socket
/// and creates per-connection relay tasks.
fn spawn_agent_acceptor(
    listener: UnixListener,
    event_tx: mpsc::UnboundedSender<AgentEvent>,
    next_channel_id: Arc<AtomicU32>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            let (stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    debug!("agent listener accept error: {e}");
                    break;
                }
            };

            if let Err(e) = crate::security::verify_peer_uid(&stream) {
                warn!("agent socket: {e}");
                continue;
            }

            let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);

            let (read_half, write_half) = stream.into_split();
            let data_tx = event_tx.clone();
            let close_tx = event_tx.clone();
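            // spawn_channel_relay pumps bytes both ways: socket reads are fed to
            // the data callback, the close callback fires when the socket shuts
            // down, and the returned Sender is the write side toward the socket.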
            let writer_tx = crate::spawn_channel_relay(
                channel_id,
                read_half,
                write_half,
                move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
                move |id| {
                    let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
                },
            );

            // Notify the relay loop about the new connection
            if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
                break; // relay loop is gone
            }
        }
    })
}

/// Spawn a TCP acceptor for port forwarding. Each accepted connection assigns
/// a channel_id and spawns a bidirectional relay via `spawn_channel_relay`.
fn spawn_pf_tcp_acceptor(
    listener: tokio::net::TcpListener,
    forward_id: u32,
    next_channel_id: Arc<AtomicU32>,
    event_tx: mpsc::UnboundedSender<PortForwardEvent>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            let (stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    debug!(forward_id, "pf tcp listener accept error: {e}");
                    break;
                }
            };

            let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
            let (read_half, write_half) = stream.into_split();
            let data_tx = event_tx.clone();
            let close_tx = event_tx.clone();
            let writer_tx = crate::spawn_channel_relay(
                channel_id,
                read_half,
                write_half,
                move |id, data| {
                    data_tx.send(PortForwardEvent::Data { channel_id: id, data }).is_ok()
                },
                move |id| {
                    let _ = close_tx.send(PortForwardEvent::Closed { channel_id: id });
                },
            );

            if event_tx
                .send(PortForwardEvent::Accepted { forward_id, channel_id, writer_tx })
                .is_err()
            {
                break;
            }
        }
    })
}

/// Spawn a task that watches a svc stream for EOF and sends a Stopped event.
fn spawn_pf_svc_watcher(
    mut stream: UnixStream,
    forward_id: u32,
    event_tx: mpsc::UnboundedSender<PortForwardEvent>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut buf = [0u8; 1];
        // Wait for EOF, an error, or a stray byte -- after setup the svc client
        // sends nothing more, so in practice this fires when it disconnects.
        let _ = stream.read(&mut buf).await;
        let _ = event_tx.send(PortForwardEvent::Stopped { forward_id });
    })
}

/// Maximum URL length accepted on the service socket.
const URL_MAX_LEN: usize = 4096;

/// Parse sender manifest from the send socket stream.
/// Expected format after 'S' byte: file_count(u32 BE), then for each file:
///   filename_len(u16 BE), filename(UTF-8 bytes), file_size(u64 BE)
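///
/// e.g. a single 3-byte file named `a.txt` would arrive as (hypothetical bytes):
/// `00 00 00 01 | 00 05 | 61 2e 74 78 74 | 00 00 00 00 00 00 00 03`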
async fn parse_sender_manifest(stream: &mut UnixStream) -> io::Result<FileManifest> {
    let mut buf4 = [0u8; 4];
    stream.read_exact(&mut buf4).await?;
    let file_count = u32::from_be_bytes(buf4);
    if file_count == 0 || file_count > 10_000 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid file count"));
    }
    let mut files = Vec::with_capacity(file_count as usize);
    for _ in 0..file_count {
        let mut buf2 = [0u8; 2];
        stream.read_exact(&mut buf2).await?;
        let name_len = u16::from_be_bytes(buf2) as usize;
        if name_len == 0 || name_len > 4096 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid filename length"));
        }
        let mut name_buf = vec![0u8; name_len];
        stream.read_exact(&mut name_buf).await?;
        let name = String::from_utf8(name_buf)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let name = match sanitize_filename(&name) {
            Some(n) => n,
            None => {
                return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid filename"));
            }
        };
        let mut buf8 = [0u8; 8];
        stream.read_exact(&mut buf8).await?;
        let file_size = u64::from_be_bytes(buf8);
        files.push((name, file_size));
    }
    Ok(FileManifest { files })
}

/// Parse receiver dest_dir from the send socket stream.
/// Expected format after 'R' byte: UTF-8 string, newline-terminated.
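/// e.g. a receiver might write `/home/me/Downloads\n` (hypothetical path).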
async fn parse_receiver_dest(stream: &mut UnixStream) -> io::Result<String> {
    let mut buf = Vec::with_capacity(256);
    let mut byte = [0u8; 1];
    loop {
        match stream.read_exact(&mut byte).await {
            Ok(_) => {
                if byte[0] == b'\n' {
                    break;
                }
                buf.push(byte[0]);
                if buf.len() > 4096 {
                    return Err(io::Error::new(io::ErrorKind::InvalidData, "dest dir too long"));
                }
            }
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e),
        }
    }
    String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Spawn the unified service socket acceptor. Reads a SvcRequest discriminator
/// byte then dispatches to the appropriate handler.
fn spawn_svc_acceptor(
    listener: UnixListener,
    open_event_tx: mpsc::UnboundedSender<OpenEvent>,
    send_event_tx: mpsc::UnboundedSender<SendEvent>,
    pf_event_tx: mpsc::UnboundedSender<PortForwardEvent>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    debug!("svc listener accept error: {e}");
                    break;
                }
            };

            // Lenient peer UID check: reject known-bad UIDs but tolerate
            // OS-level errors (fire-and-forget open connections on macOS may
            // disconnect before getpeereid returns).
            match crate::security::verify_peer_uid(&stream) {
                Ok(()) => {}
                Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
                    warn!("svc socket: {e}");
                    continue;
                }
                Err(e) => {
                    debug!("svc socket peer_cred unavailable: {e}");
                }
            }

            let otx = open_event_tx.clone();
            let stx = send_event_tx.clone();
            let ptx = pf_event_tx.clone();
            tokio::spawn(async move {
                // Read discriminator byte
                let mut disc = [0u8; 1];
                if stream.read_exact(&mut disc).await.is_err() {
                    return;
                }
                match crate::protocol::SvcRequest::from_byte(disc[0]) {
                    Some(crate::protocol::SvcRequest::OpenUrl) => {
                        let mut buf = vec![0u8; URL_MAX_LEN];
                        let mut total = 0;
                        loop {
                            match stream.read(&mut buf[total..]).await {
                                Ok(0) => break,
                                Ok(n) => {
                                    total += n;
                                    if buf[..total].contains(&b'\n') || total >= buf.len() {
                                        break;
                                    }
                                }
                                Err(_) => return,
                            }
                        }
                        let s = String::from_utf8_lossy(&buf[..total]);
                        let url = s.trim();
                        if !url.is_empty() {
                            let _ = otx.send(OpenEvent::Url(url.to_string()));
                        }
                    }
                    Some(crate::protocol::SvcRequest::Send) => {
                        match parse_sender_manifest(&mut stream).await {
                            Ok(manifest) => {
                                let _ = stx.send(SendEvent::SenderArrived { stream, manifest });
                            }
                            Err(e) => debug!("svc socket: bad sender manifest: {e}"),
                        }
                    }
                    Some(crate::protocol::SvcRequest::Receive) => {
                        match parse_receiver_dest(&mut stream).await {
                            Ok(_) => {
                                let _ = stx.send(SendEvent::ReceiverArrived { stream });
                            }
                            Err(e) => debug!("svc socket: bad receiver dest: {e}"),
                        }
                    }
                    Some(crate::protocol::SvcRequest::PortForward) => {
                        // Wire: [direction: u8][listen_port: u16 BE][target_port: u16 BE]
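                        // e.g. `00 1f 90 1f 90` = direction 0 (local), listen 8080, target 8080 (hypothetical).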
                        let mut hdr = [0u8; 5];
                        if stream.read_exact(&mut hdr).await.is_err() {
                            return;
                        }
                        let direction = hdr[0];
                        let listen_port = u16::from_be_bytes([hdr[1], hdr[2]]);
                        let target_port = u16::from_be_bytes([hdr[3], hdr[4]]);
                        let _ = ptx.send(PortForwardEvent::Requested {
                            stream,
                            direction,
                            listen_port,
                            target_port,
                        });
                    }
                    None => {
                        debug!("svc socket: unknown request byte: 0x{:02x}", disc[0]);
                    }
                }
            });
        }
    })
}

/// Spawn the transfer relay task. Reads file data from sender, writes to receiver.
/// Sends SendOffer/SendDone/SendCancel notification frames to the active client.
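/// Receiver-side wire format (as written below): file_count(u32 BE), then per
/// file: filename_len(u16 BE), filename, file_size(u64 BE), raw file bytes;
/// a zero filename_len terminates the stream.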
fn spawn_transfer_relay(
    mut sender: UnixStream,
    mut receiver: UnixStream,
    manifest: FileManifest,
    notify_tx: mpsc::UnboundedSender<Frame>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        use tokio::io::AsyncWriteExt;

        let file_count = manifest.files.len() as u32;
        let total_bytes = manifest.total_bytes();

        // Notify active client about transfer start
        let _ = notify_tx.send(Frame::SendOffer { file_count, total_bytes });

        // Signal sender to start streaming (write 0x01 go byte)
        if sender.write_all(&[0x01]).await.is_err() {
            let _ = notify_tx.send(Frame::SendCancel { reason: "sender disconnected".into() });
            return;
        }

        // Write file_count to receiver
        if receiver.write_all(&file_count.to_be_bytes()).await.is_err() {
            let _ = notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
            return;
        }

        // For each file: write metadata to receiver, then relay file data
        let mut buf = vec![0u8; 64 * 1024];
        for (name, size) in &manifest.files {
            // Write per-file header to receiver
            let name_bytes = name.as_bytes();
            let mut hdr = Vec::with_capacity(2 + name_bytes.len() + 8);
            hdr.extend_from_slice(&(name_bytes.len() as u16).to_be_bytes());
            hdr.extend_from_slice(name_bytes);
            hdr.extend_from_slice(&size.to_be_bytes());
            if receiver.write_all(&hdr).await.is_err() {
                let _ =
                    notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
                return;
            }

            // Relay exactly file_size bytes from sender to receiver
            let mut remaining = *size;
            while remaining > 0 {
                let to_read = (remaining as usize).min(buf.len());
                match sender.read_exact(&mut buf[..to_read]).await {
                    Ok(_) => {
                        if receiver.write_all(&buf[..to_read]).await.is_err() {
                            let _ = notify_tx
                                .send(Frame::SendCancel { reason: "receiver disconnected".into() });
                            return;
                        }
                        remaining -= to_read as u64;
                    }
                    Err(_) => {
                        let _ = notify_tx
                            .send(Frame::SendCancel { reason: "sender disconnected".into() });
                        return;
                    }
                }
            }
        }

        // Write sentinel: filename_len = 0
        if receiver.write_all(&[0u8; 2]).await.is_err() {
            let _ = notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
            return;
        }

        let _ = notify_tx.send(Frame::SendDone);
    })
}

/// Relay broadcast events to a tail client. Handles Ping/Pong for keepalive.
async fn tail_relay(
    mut framed: Framed<UnixStream, FrameCodec>,
    mut rx: broadcast::Receiver<TailEvent>,
) {
    loop {
        tokio::select! {
            event = rx.recv() => match event {
                Ok(TailEvent::Data(chunk)) => {
                    if framed.send(Frame::Data(chunk)).await.is_err() { break; }
                }
                Ok(TailEvent::Exit { code }) => {
                    let _ = framed.send(Frame::Exit { code }).await;
                    break;
                }
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => break,
            },
            frame = framed.next() => match frame {
                Some(Ok(Frame::Ping)) => { let _ = framed.send(Frame::Pong).await; }
                _ => break,
            },
        }
    }
}

/// Drain ring buffer contents to a tail client, then subscribe to broadcast and spawn relay.
fn spawn_tail(
    mut framed: Framed<UnixStream, FrameCodec>,
    ring_buf: &VecDeque<Bytes>,
    tail_tx: &broadcast::Sender<TailEvent>,
) {
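    // Subscribe before snapshotting the ring buffer so chunks broadcast during
    // the drain are queued on the receiver rather than missed.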
    let rx = tail_tx.subscribe();
    let chunks: Vec<Bytes> = ring_buf.iter().cloned().collect();
    tokio::spawn(async move {
        for chunk in chunks {
            if framed.send(Frame::Data(chunk)).await.is_err() {
                return;
            }
        }
        tail_relay(framed, rx).await;
    });
}

/// Groups references needed by inner relay handler methods.
/// `framed` is kept outside (passed to handlers) so `tokio::select!` can
/// poll `framed.next()` independently without conflicting borrows.
struct ServerRelay<'a> {
    async_master: &'a AsyncFd<OwnedFd>,
    agent: &'a mut AgentForwardState,
    tunnel: &'a mut TunnelRelayState,
    pf: &'a mut PortForwardTable,
    transfer_state: &'a mut TransferState,
    open_forward_enabled: &'a mut bool,
    tail_tx: &'a broadcast::Sender<TailEvent>,
    metadata_slot: &'a Arc<OnceLock<SessionMetadata>>,
    agent_event_tx: &'a mpsc::UnboundedSender<AgentEvent>,
    tunnel_event_tx: &'a mpsc::UnboundedSender<TunnelEvent>,
    pf_event_tx: &'a mpsc::UnboundedSender<PortForwardEvent>,
    send_notify_tx: &'a mpsc::UnboundedSender<Frame>,
}

impl ServerRelay<'_> {
    async fn handle_client_frame(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        frame: Option<Result<Frame, io::Error>>,
    ) -> Result<ControlFlow<RelayExit>, anyhow::Error> {
        match frame {
            Some(Ok(Frame::Data(data))) => {
                debug!(len = data.len(), "socket -> pty");
                let mut written = 0;
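                // Non-blocking PTY writes can be short; loop until the whole
                // frame has been written.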
                while written < data.len() {
                    let mut guard = self.async_master.writable().await?;
                    match guard.try_io(|inner| {
                        nix::unistd::write(inner, &data[written..]).map_err(io::Error::from)
                    }) {
                        Ok(Ok(n)) => written += n,
                        Ok(Err(e)) => return Err(e.into()),
                        Err(_would_block) => continue,
                    }
                }
            }
            Some(Ok(Frame::Resize { cols, rows })) => {
                let (cols, rows) = crate::security::clamp_winsize(cols, rows);
                debug!(cols, rows, "resize pty");
                let ws = libc::winsize { ws_row: rows, ws_col: cols, ws_xpixel: 0, ws_ypixel: 0 };
                unsafe {
                    libc::ioctl(self.async_master.as_raw_fd(), libc::TIOCSWINSZ, &ws as *const _);
                }
                if let Ok(pgid) = nix::unistd::tcgetpgrp(self.async_master) {
                    let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
                }
            }
            Some(Ok(Frame::Ping)) => {
                if let Some(meta) = self.metadata_slot.get() {
                    let now = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs();
                    meta.last_heartbeat.store(now, Ordering::Relaxed);
                }
                let _ = framed.send(Frame::Pong).await;
            }
            Some(Ok(Frame::AgentForward)) => {
                debug!("agent forwarding enabled by client");
                self.agent.enabled = true;
                if self.agent.acceptor.is_none() {
                    if let Some(listener) = bind_agent_listener(&self.agent.socket_path) {
                        self.agent.acceptor = Some(spawn_agent_acceptor(
                            listener,
                            self.agent_event_tx.clone(),
                            self.agent.next_channel_id.clone(),
                        ));
                    }
                }
            }
            Some(Ok(Frame::AgentData { channel_id, data })) => {
                if let Some(tx) = self.agent.channels.get(&channel_id) {
                    let _ = tx.send(data).await;
                }
            }
            Some(Ok(Frame::AgentClose { channel_id })) => {
                self.agent.channels.remove(&channel_id);
            }
            Some(Ok(Frame::OpenForward)) => {
                debug!("open forwarding enabled by client");
                *self.open_forward_enabled = true;
            }
            Some(Ok(Frame::TunnelOpen { channel_id })) => {
                if let Some(port) = self.tunnel.port {
                    self.tunnel.idle_deadline = None;
                    let tx = self.tunnel_event_tx.clone();
                    tokio::spawn(async move {
                        match tokio::net::TcpStream::connect(("127.0.0.1", port)).await {
                            Ok(stream) => {
                                let _ = tx.send(TunnelEvent::Connected { channel_id, stream });
                            }
                            Err(_) => {
                                let _ = tx.send(TunnelEvent::ConnectFailed { channel_id });
                            }
                        }
                    });
                }
            }
            Some(Ok(Frame::TunnelData { channel_id, data })) => {
                if let Some(tx) = self.tunnel.channels.get(&channel_id) {
                    let _ = tx.send(data).await;
                }
            }
            Some(Ok(Frame::TunnelClose { channel_id })) => {
                self.tunnel.channels.remove(&channel_id);
                if self.tunnel.channels.is_empty() && self.tunnel.port.is_some() {
                    self.tunnel.idle_deadline =
                        Some(tokio::time::Instant::now() + self.tunnel.idle_timeout);
                }
            }
            Some(Ok(Frame::PortForwardReady { forward_id })) => {
                if let Some(mut svc_stream) = self.pf.pending_remote.remove(&forward_id) {
                    use tokio::io::AsyncWriteExt;
                    let _ = svc_stream.write_all(&[0x01]).await;
                    let stop_handle =
                        spawn_pf_svc_watcher(svc_stream, forward_id, self.pf_event_tx.clone());
                    if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
                        fwd.stop_handle = Some(stop_handle);
                    }
                }
            }
            Some(Ok(Frame::PortForwardOpen { forward_id, channel_id, target_port })) => {
                if self.pf.forwards.contains_key(&forward_id) {
                    let tx = self.pf_event_tx.clone();
                    tokio::spawn(async move {
                        match tokio::net::TcpStream::connect(("127.0.0.1", target_port)).await {
                            Ok(stream) => {
                                let _ = tx.send(PortForwardEvent::Connected {
                                    forward_id,
                                    channel_id,
                                    stream,
                                });
                            }
                            Err(_) => {
                                let _ = tx.send(PortForwardEvent::ConnectFailed { channel_id });
                            }
                        }
                    });
                }
            }
            Some(Ok(Frame::PortForwardData { channel_id, data })) => {
                if let Some((_, tx)) = self.pf.channels.get(&channel_id) {
                    let _ = tx.send(data).await;
                }
            }
            Some(Ok(Frame::PortForwardClose { channel_id })) => {
                if let Some((forward_id, _)) = self.pf.channels.remove(&channel_id) {
                    if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
                        fwd.channels.remove(&channel_id);
                    }
                }
            }
            Some(Ok(Frame::PortForwardStop { forward_id })) => {
                if let Some(mut svc_stream) = self.pf.pending_remote.remove(&forward_id) {
                    use tokio::io::AsyncWriteExt;
                    let _ = svc_stream.write_all(&[0x02]).await;
                    let _ = svc_stream.write_all(b"client declined forward").await;
                }
                if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
                    if let Some(h) = fwd.listener_handle {
                        h.abort();
                    }
                    if let Some(h) = fwd.stop_handle {
                        h.abort();
                    }
                    for ch_id in fwd.channels.keys() {
                        self.pf.channels.remove(ch_id);
                    }
                }
            }
            Some(Ok(Frame::Exit { .. })) | None => {
                return Ok(ControlFlow::Break(RelayExit::ClientGone));
            }
            Some(Ok(_)) => {}
            Some(Err(e)) => return Err(e.into()),
        }
        Ok(ControlFlow::Continue(()))
    }

    async fn handle_agent_event(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        event: Option<AgentEvent>,
    ) {
        match event {
            Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
                if self.agent.enabled {
                    self.agent.channels.insert(channel_id, writer_tx);
                    let _ = framed.send(Frame::AgentOpen { channel_id }).await;
                }
            }
            Some(AgentEvent::Data { channel_id, data }) => {
                if self.agent.enabled && self.agent.channels.contains_key(&channel_id) {
                    let _ = framed.send(Frame::AgentData { channel_id, data }).await;
                }
            }
            Some(AgentEvent::Closed { channel_id }) => {
                if self.agent.channels.remove(&channel_id).is_some() {
                    let _ = framed.send(Frame::AgentClose { channel_id }).await;
                }
            }
            None => {
                debug!("agent event channel closed");
            }
        }
    }

    async fn handle_open_event(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        event: Option<OpenEvent>,
    ) {
        match event {
            Some(OpenEvent::Url(url)) => {
                if *self.open_forward_enabled {
                    if let Some(port) = extract_redirect_port(&url) {
                        if port_in_use(port) {
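                            // Something on this (server-side) machine is already
                            // listening on the callback port, so ask the client to
                            // listen on it too and tunnel callbacks back here.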
                            debug!(port, "setting up reverse tunnel for OAuth callback");
                            self.tunnel.port = Some(port);
                            let _ = framed.send(Frame::TunnelListen { port }).await;
                        }
                    }
                    let _ = framed.send(Frame::OpenUrl { url }).await;
                }
            }
            None => {
                debug!("open event channel closed");
            }
        }
    }

    async fn handle_tunnel_event(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        event: Option<TunnelEvent>,
    ) {
        match event {
            Some(TunnelEvent::Connected { channel_id, stream }) => {
                if self.tunnel.port.is_some() {
                    debug!(channel_id, "tunnel channel connected");
                    self.tunnel.idle_deadline = None;
                    let (mut read_half, write_half) = stream.into_split();
                    let (writer_tx, mut writer_rx) =
                        mpsc::channel::<Bytes>(crate::CHANNEL_RELAY_BUFFER);
                    self.tunnel.channels.insert(channel_id, writer_tx);

                    // Writer task: channel -> TCP
                    tokio::spawn(async move {
                        use tokio::io::AsyncWriteExt;
                        let mut writer = write_half;
                        while let Some(data) = writer_rx.recv().await {
                            if writer.write_all(&data).await.is_err() {
                                break;
                            }
                        }
                    });

                    // Reader task: TCP -> TunnelEvent
                    let tx = self.tunnel_event_tx.clone();
                    tokio::spawn(async move {
                        let mut buf = vec![0u8; 4096];
                        loop {
                            match read_half.read(&mut buf).await {
                                Ok(0) | Err(_) => {
                                    let _ = tx.send(TunnelEvent::Closed { channel_id });
                                    break;
                                }
                                Ok(n) => {
                                    let data = Bytes::copy_from_slice(&buf[..n]);
                                    if tx.send(TunnelEvent::Data { channel_id, data }).is_err() {
                                        break;
                                    }
                                }
                            }
                        }
                    });
                }
            }
            Some(TunnelEvent::ConnectFailed { channel_id }) => {
                let _ = framed.send(Frame::TunnelClose { channel_id }).await;
            }
            Some(TunnelEvent::Data { channel_id, data }) => {
                let _ = framed.send(Frame::TunnelData { channel_id, data }).await;
            }
            Some(TunnelEvent::Closed { channel_id }) => {
                self.tunnel.channels.remove(&channel_id);
                let _ = framed.send(Frame::TunnelClose { channel_id }).await;
                if self.tunnel.channels.is_empty() && self.tunnel.port.is_some() {
                    self.tunnel.idle_deadline =
                        Some(tokio::time::Instant::now() + self.tunnel.idle_timeout);
                }
            }
            None => {}
        }
    }

    async fn handle_send_notification(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        notification: Option<Frame>,
    ) {
        if let Some(frame) = notification {
            if matches!(frame, Frame::SendDone | Frame::SendCancel { .. }) {
                *self.transfer_state = TransferState::Idle;
            }
            let _ = framed.send(frame).await;
        }
    }

    async fn handle_pf_event(
        &mut self,
        framed: &mut Framed<UnixStream, FrameCodec>,
        event: Option<PortForwardEvent>,
    ) {
        match event {
            Some(PortForwardEvent::Requested { mut stream, direction, listen_port, target_port }) => {
                use tokio::io::AsyncWriteExt;
                let fwd_id = self.pf.next_forward_id;
                self.pf.next_forward_id += 1;
                if direction == 0 {
                    // Local-forward: server binds TCP, forwards to client
                    match tokio::net::TcpListener::bind(("127.0.0.1", listen_port)).await {
                        Ok(listener) => {
                            debug!(fwd_id, listen_port, target_port, "local-forward: bound");
                            let handle = spawn_pf_tcp_acceptor(
                                listener,
                                fwd_id,
                                self.pf.next_channel_id.clone(),
                                self.pf_event_tx.clone(),
                            );
                            // Ack the svc client, then watch its stream for teardown.
                            let _ = stream.write_all(&[0x01]).await;
                            let stop_handle =
                                spawn_pf_svc_watcher(stream, fwd_id, self.pf_event_tx.clone());
                            self.pf.forwards.insert(
                                fwd_id,
                                PortForwardState {
                                    listener_handle: Some(handle),
                                    channels: HashMap::new(),
                                    stop_handle: Some(stop_handle),
                                    target_port,
                                },
                            );
                        }
                        Err(e) => {
                            debug!(listen_port, "local-forward: bind failed: {e}");
                            let msg = format!("bind failed: {e}");
                            let _ = stream.write_all(&[0x02]).await;
                            let _ = stream.write_all(msg.as_bytes()).await;
                        }
                    }
                } else {
                    // Remote-forward: tell client to bind, wait for Ready
                    let _ = framed
                        .send(Frame::PortForwardListen {
                            forward_id: fwd_id,
                            listen_port,
                            target_port,
                        })
                        .await;
                    self.pf.pending_remote.insert(fwd_id, stream);
                    self.pf.forwards.insert(
                        fwd_id,
                        PortForwardState {
                            listener_handle: None,
                            channels: HashMap::new(),
                            stop_handle: None,
                            target_port,
                        },
                    );
                }
            }
            Some(PortForwardEvent::Accepted { forward_id, channel_id, writer_tx }) => {
                if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
                    fwd.channels.insert(channel_id, writer_tx.clone());
                    self.pf.channels.insert(channel_id, (forward_id, writer_tx));
                    let _ = framed
                        .send(Frame::PortForwardOpen {
                            forward_id,
                            channel_id,
                            target_port: fwd.target_port,
                        })
                        .await;
                }
            }
            Some(PortForwardEvent::Connected { forward_id, channel_id, stream }) => {
                if self.pf.forwards.contains_key(&forward_id) {
                    let (read_half, write_half) = stream.into_split();
                    let data_tx = self.pf_event_tx.clone();
                    let close_tx = self.pf_event_tx.clone();
                    let writer_tx = crate::spawn_channel_relay(
                        channel_id,
                        read_half,
                        write_half,
                        move |id, data| {
                            data_tx.send(PortForwardEvent::Data { channel_id: id, data }).is_ok()
                        },
                        move |id| {
                            let _ = close_tx.send(PortForwardEvent::Closed { channel_id: id });
                        },
                    );
                    self.pf.channels.insert(channel_id, (forward_id, writer_tx.clone()));
                    if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
                        fwd.channels.insert(channel_id, writer_tx);
                    }
                }
            }
            Some(PortForwardEvent::ConnectFailed { channel_id }) => {
                let _ = framed.send(Frame::PortForwardClose { channel_id }).await;
            }
            Some(PortForwardEvent::Data { channel_id, data }) => {
                if self.pf.channels.contains_key(&channel_id) {
                    let _ = framed.send(Frame::PortForwardData { channel_id, data }).await;
                }
            }
            Some(PortForwardEvent::Closed { channel_id }) => {
                if let Some((forward_id, _)) = self.pf.channels.remove(&channel_id) {
                    let _ = framed.send(Frame::PortForwardClose { channel_id }).await;
                    if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
                        fwd.channels.remove(&channel_id);
                    }
                }
            }
            Some(PortForwardEvent::Stopped { forward_id }) => {
                debug!(forward_id, "port forward stopped (svc socket dropped)");
                if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
                    if let Some(h) = fwd.listener_handle {
                        h.abort();
                    }
                    if let Some(h) = fwd.stop_handle {
                        h.abort();
                    }
                    for ch_id in fwd.channels.keys() {
                        self.pf.channels.remove(ch_id);
                    }
                    let _ = framed.send(Frame::PortForwardStop { forward_id }).await;
                }
            }
            None => {}
        }
    }
}

#[allow(clippy::too_many_arguments)]
pub async fn run(
    mut client_rx: mpsc::UnboundedReceiver<ClientConn>,
    metadata_slot: Arc<OnceLock<SessionMetadata>>,
    agent_socket_path: PathBuf,
    svc_socket_path: PathBuf,
    session_id: u32,
    session_name: Option<String>,
    command: Option<String>,
    ring_buffer_cap: usize,
    oauth_tunnel_idle_timeout: u64,
) -> anyhow::Result<()> {
    // Allocate PTY (once, before accept loop)
    let pty = openpty(None, None)?;
    let master: OwnedFd = pty.master;
    let slave: OwnedFd = pty.slave;

    // Get PTY slave name before we drop the slave fd
    let pty_path =
        nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();

    // Dup slave fds for shell stdio (before dropping slave)
    let slave_fd = slave.as_raw_fd();
    let stdin_fd = crate::security::checked_dup(slave_fd)?;
    let stdout_fd = crate::security::checked_dup(slave_fd)?;
    let stderr_fd = crate::security::checked_dup(slave_fd)?;
    let raw_stdin = stdin_fd.as_raw_fd();
    drop(slave);

    // Set master to non-blocking for AsyncFd
    let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
    let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
    oflags |= nix::fcntl::OFlag::O_NONBLOCK;
    nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;

    let async_master = AsyncFd::new(master)?;
    let mut buf = vec![0u8; 4096];
    let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
    let mut ring_buf_size: usize = 0;
    let mut ring_buf_dropped: usize = 0;

    // Agent forwarding state
    let mut agent = AgentForwardState::new(agent_socket_path);
    let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();

    // Open forwarding state (no acceptor -- svc_acceptor handles the socket)
    let mut open_forward_enabled = false;

    // Tunnel state (reverse TCP tunnel for OAuth callbacks)
    let (tunnel_event_tx, mut tunnel_event_rx) = mpsc::unbounded_channel::<TunnelEvent>();
    let mut tunnel = TunnelRelayState::new(Duration::from_secs(oauth_tunnel_idle_timeout));

    // Broadcast channel for tail clients
    let (tail_tx, _) = broadcast::channel::<TailEvent>(256);

    // Open and send event channels (created at session start, persist across clients)
    let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();
    let (send_event_tx, mut send_event_rx) = mpsc::unbounded_channel::<SendEvent>();
    let (send_notify_tx, mut send_notify_rx) = mpsc::unbounded_channel::<Frame>();
    let mut transfer_state = TransferState::Idle;
    let mut svc_acceptor: Option<tokio::task::JoinHandle<()>> = None;

    // Port forward event channel and state
    let (pf_event_tx, mut pf_event_rx) = mpsc::unbounded_channel::<PortForwardEvent>();
    let mut pf = PortForwardTable::new();

    // Bind unified service socket immediately (always available)
    if let Some(listener) = bind_agent_listener(&svc_socket_path) {
        svc_acceptor = Some(spawn_svc_acceptor(
            listener,
            open_event_tx.clone(),
            send_event_tx.clone(),
            pf_event_tx.clone(),
        ));
    }

    // Wait for first active client before spawning shell (so we can read Env frame).
    // Tail and send clients that arrive before the first active client get handled
    // appropriately (tail subscribed to broadcast, send queued for rendezvous).
    let mut framed = loop {
        tokio::select! {
            client = client_rx.recv() => match client {
                Some(ClientConn::Active(f)) => {
                    info!("first client connected via channel");
                    break f;
                }
                Some(ClientConn::Tail(f)) => {
                    info!("tail client connected before shell spawn");
                    spawn_tail(f, &ring_buf, &tail_tx);
                    continue;
                }
                Some(ClientConn::Send(stream)) => {
                    handle_send_stream(stream, &send_event_tx);
                    continue;
                }
                None => {
                    info!("client channel closed before first client");
                    cleanup_socket(&agent.socket_path);
                    cleanup_socket(&svc_socket_path);
                    return Ok(());
                }
            },
            event = send_event_rx.recv() => {
                if let Some(event) = event {
                    handle_send_event(event, &mut transfer_state, &send_notify_tx);
                }
                continue;
            }
        }
    };

    // Read optional Env frame from first client (2s timeout -- generous for slow SSH tunnels)
    let env_vars =
        match tokio::time::timeout(std::time::Duration::from_secs(2), framed.next()).await {
            Ok(Some(Ok(Frame::Env { vars }))) => {
                debug!(count = vars.len(), "received env vars from client");
                vars
            }
            _ => Vec::new(),
        };

    // Spawn shell (or custom command) on slave PTY
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
    let home = std::env::var("HOME").ok();
    let mut cmd = Command::new(&shell);
    if let Some(ref user_cmd) = command {
        cmd.arg("-c").arg(user_cmd);
    } else {
        cmd.arg("-l");
    }
    if let Some(ref dir) = home {
        cmd.current_dir(dir);
    }
    const ALLOWED_ENV_KEYS: &[&str] = &["TERM", "LANG", "COLORTERM"];
    for (k, v) in &env_vars {
        if ALLOWED_ENV_KEYS.contains(&k.as_str()) {
            cmd.env(k, v);
        } else if k == "BROWSER" {
            // Client signals open forwarding desired; resolve to server-side binary
            let exe = std::env::current_exe()
                .ok()
                .and_then(|p| p.to_str().map(String::from))
                .unwrap_or_else(|| "gritty".into());
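            // Yields e.g. BROWSER="/usr/local/bin/gritty open" (hypothetical path).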
1247            cmd.env("BROWSER", format!("{exe} open"));
1248        } else {
1249            warn!(key = k, "ignoring disallowed env var from client");
1250        }
1251    }
1252    // Set SSH_AUTH_SOCK to the agent socket path
1253    cmd.env("SSH_AUTH_SOCK", &agent.socket_path);
1254    // Set GRITTY_SOCK so `gritty open`/`gritty send`/`gritty receive` find the service socket
1255    cmd.env("GRITTY_SOCK", &svc_socket_path);
1256    // Session context env vars
1257    cmd.env("GRITTY_SESSION", session_id.to_string());
1258    if let Some(ref name) = session_name {
1259        cmd.env("GRITTY_SESSION_NAME", name);
1260    }
1261    let mut managed = ManagedChild::new(unsafe {
1262        cmd.pre_exec(move || {
1263            nix::unistd::setsid().map_err(io::Error::other)?;
1264            libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
1265            Ok(())
1266        })
1267        .stdin(Stdio::from(stdin_fd))
1268        .stdout(Stdio::from(stdout_fd))
1269        .stderr(Stdio::from(stderr_fd))
1270        .spawn()?
1271    });
1272
1273    let shell_pid = managed.child.id().unwrap_or(0);
1274    let created_at = std::time::SystemTime::now()
1275        .duration_since(std::time::UNIX_EPOCH)
1276        .unwrap_or_default()
1277        .as_secs();
1278
1279    let _ = metadata_slot.set(SessionMetadata {
1280        pty_path,
1281        shell_pid,
1282        created_at,
1283        attached: AtomicBool::new(false),
1284        last_heartbeat: AtomicU64::new(0),
1285    });
1286
1287    // First client is already connected — enter relay directly
1288    metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);
1289
1290    // Outer loop: accept clients via channel. PTY persists across reconnects.
1291    // First iteration skips client-wait (first client already connected above).
1292    let mut first_client = true;
1293    loop {
1294        if !first_client {
            let got_client = 'drain: loop {
                tokio::select! {
                    client = client_rx.recv() => {
                        match client {
                            Some(ClientConn::Active(f)) => {
                                info!("client connected via channel");
                                framed = f;
                                break 'drain true;
                            }
                            Some(ClientConn::Tail(f)) => {
                                info!("tail client connected while disconnected");
                                spawn_tail(f, &ring_buf, &tail_tx);
                                continue;
                            }
                            Some(ClientConn::Send(stream)) => {
                                handle_send_stream(stream, &send_event_tx);
                                continue;
                            }
                            None => {
                                info!("client channel closed");
                                break 'drain false;
                            }
                        }
                    }
                    status = managed.child.wait() => {
                        let code = status?.code().unwrap_or(1);
                        info!(code, "shell exited while awaiting client");
                        break 'drain false;
                    }
                    ready = async_master.readable() => {
                        let mut guard = ready?;
                        match guard.try_io(|inner| {
                            nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                        }) {
                            Ok(Ok(0)) => {
                                debug!("pty EOF while disconnected");
                                break 'drain false;
                            }
                            Ok(Ok(n)) => {
                                let chunk = Bytes::copy_from_slice(&buf[..n]);
                                if tail_tx.receiver_count() > 0 {
                                    let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
                                }
                                ring_buf_size += chunk.len();
                                ring_buf.push_back(chunk);
                                // Evict oldest chunks once over capacity, counting
                                // dropped bytes so the next client can be told.
                                while ring_buf_size > ring_buffer_cap {
                                    if let Some(old) = ring_buf.pop_front() {
                                        ring_buf_size -= old.len();
                                        ring_buf_dropped += old.len();
                                    }
                                }
                            }
                            Ok(Err(e)) => {
                                if e.raw_os_error() == Some(libc::EIO) {
                                    debug!("pty EIO while disconnected");
                                    break 'drain false;
                                }
                                return Err(e.into());
                            }
                            Err(_would_block) => continue,
                        }
                    }
                    event = send_event_rx.recv() => {
                        if let Some(event) = event {
                            handle_send_event(event, &mut transfer_state, &send_notify_tx);
                        }
                        continue;
                    }
                }
            };
            if !got_client {
                break;
            }

            if let Some(meta) = metadata_slot.get() {
                meta.attached.store(true, Ordering::Relaxed);
            }
        }
        first_client = false;

        // Flush any buffered PTY output to the new client
        if !ring_buf.is_empty() {
            debug!(
                chunks = ring_buf.len(),
                bytes = ring_buf_size,
                dropped = ring_buf_dropped,
                "flushing ring buffer"
            );
            if ring_buf_dropped > 0 {
                let msg = format!(
                    "\r\n\x1b[2;33m[gritty: {} bytes of output dropped]\x1b[0m\r\n",
                    ring_buf_dropped
                );
                framed.send(Frame::Data(Bytes::from(msg))).await?;
                ring_buf_dropped = 0;
            }
            while let Some(chunk) = ring_buf.pop_front() {
                framed.send(Frame::Data(chunk)).await?;
            }
            ring_buf_size = 0;
        }

        // Inner loop: relay between socket and PTY.
        // Scoped block so ServerRelay borrows are released before
        // the post-loop code accesses the underlying state directly.
        let exit = {
            let mut relay = ServerRelay {
                async_master: &async_master,
                agent: &mut agent,
                tunnel: &mut tunnel,
                pf: &mut pf,
                transfer_state: &mut transfer_state,
                open_forward_enabled: &mut open_forward_enabled,
                tail_tx: &tail_tx,
                metadata_slot: &metadata_slot,
                agent_event_tx: &agent_event_tx,
                tunnel_event_tx: &tunnel_event_tx,
                pf_event_tx: &pf_event_tx,
                send_notify_tx: &send_notify_tx,
            };
            loop {
                tokio::select! {
                    frame = framed.next() => {
                        if let ControlFlow::Break(exit) = relay.handle_client_frame(&mut framed, frame).await? {
                            break exit;
                        }
                    }

                    ready = relay.async_master.readable() => {
                        let mut guard = ready?;
                        match guard.try_io(|inner| {
                            nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                        }) {
                            Ok(Ok(0)) => {
                                debug!("pty EOF");
                                break RelayExit::ShellExited(0);
                            }
                            Ok(Ok(n)) => {
                                debug!(len = n, "pty -> socket");
                                let chunk = Bytes::copy_from_slice(&buf[..n]);
                                if relay.tail_tx.receiver_count() > 0 {
                                    let _ = relay.tail_tx.send(TailEvent::Data(chunk.clone()));
                                }
                                framed.send(Frame::Data(chunk)).await?;
                            }
                            Ok(Err(e)) => {
                                if e.raw_os_error() == Some(libc::EIO) {
                                    debug!("pty EIO (shell exited)");
                                    break RelayExit::ShellExited(0);
                                }
                                return Err(e.into());
                            }
                            Err(_would_block) => continue,
                        }
                    }

                    new_client = client_rx.recv() => {
                        match new_client {
                            Some(ClientConn::Active(new_framed)) => {
                                info!("new client via channel, detaching old client");
                                let _ = framed.send(Frame::Detached).await;
                                relay.agent.teardown();
                                relay.tunnel.teardown();
                                relay.pf.teardown();
                                *relay.open_forward_enabled = false;
                                // Inform the new client about the takeover
                                let was_attached = relay.metadata_slot.get()
                                    .map(|m| m.attached.load(Ordering::Relaxed))
                                    .unwrap_or(false);
                                framed = new_framed;
                                if was_attached {
                                    let hb_age = relay.metadata_slot.get()
                                        .and_then(|m| {
                                            let hb = m.last_heartbeat.load(Ordering::Relaxed);
                                            if hb == 0 { return None; }
                                            let now = std::time::SystemTime::now()
                                                .duration_since(std::time::UNIX_EPOCH)
                                                .unwrap_or_default()
                                                .as_secs();
                                            Some(now.saturating_sub(hb))
                                        });
                                    let hb_str = match hb_age {
                                        Some(s) => format!("{s}s ago"),
                                        None => "n/a".to_string(),
                                    };
                                    let msg = format!(
                                        "\r\n\x1b[2;33m[gritty: took over session (was active, heartbeat {hb_str})]\x1b[0m\r\n"
                                    );
                                    let _ = framed.send(Frame::Data(Bytes::from(msg))).await;
                                }
                            }
                            Some(ClientConn::Tail(f)) => {
                                info!("tail client connected while active");
                                spawn_tail(f, &ring_buf, relay.tail_tx);
                            }
                            Some(ClientConn::Send(stream)) => {
                                handle_send_stream(stream, &send_event_tx);
                            }
                            None => {}
                        }
                    }

                    event = agent_event_rx.recv() => {
                        relay.handle_agent_event(&mut framed, event).await;
                    }

                    event = open_event_rx.recv() => {
                        relay.handle_open_event(&mut framed, event).await;
                    }

                    event = tunnel_event_rx.recv() => {
                        relay.handle_tunnel_event(&mut framed, event).await;
                    }

                    // Sleep until the tunnel's idle deadline, or forever when
                    // no deadline is set.
                    _ = async {
                        match relay.tunnel.idle_deadline {
                            Some(deadline) => tokio::time::sleep_until(deadline).await,
                            None => std::future::pending().await,
                        }
                    } => {
                        if relay.tunnel.channels.is_empty() {
                            debug!("tunnel idle timeout, tearing down");
                            relay.tunnel.teardown();
                        }
                    }

                    event = send_event_rx.recv() => {
                        if let Some(event) = event {
                            handle_send_event(event, relay.transfer_state, relay.send_notify_tx);
                        }
                    }

                    notification = send_notify_rx.recv() => {
                        relay.handle_send_notification(&mut framed, notification).await;
                    }

                    event = pf_event_rx.recv() => {
                        relay.handle_pf_event(&mut framed, event).await;
                    }

                    status = managed.child.wait() => {
                        let code = status?.code().unwrap_or(1);
                        info!(code, "shell exited");
                        let _ = relay.tail_tx.send(TailEvent::Exit { code });
                        break RelayExit::ShellExited(code);
                    }
                }
            }
        };

        match exit {
            RelayExit::ClientGone => {
                if let Some(meta) = metadata_slot.get() {
                    meta.attached.store(false, Ordering::Relaxed);
                }
                agent.teardown();
                tunnel.teardown();
                pf.teardown();
                open_forward_enabled = false;
                info!("client disconnected, waiting for reconnect");
                continue;
            }
            RelayExit::ShellExited(mut code) => {
                // PTY EOF/EIO may fire before child.wait() -- especially on
                // macOS where the race window is wider. Give the child a
                // moment to actually exit so we can capture the real code.
                if let Ok(Ok(status)) = tokio::time::timeout(
                    std::time::Duration::from_millis(500),
                    managed.child.wait(),
                )
                .await
                {
                    code = status.code().unwrap_or(code);
                }
                let _ = tail_tx.send(TailEvent::Exit { code });
                let _ = framed.send(Frame::Exit { code }).await;
                info!(code, "session ended");
                break;
            }
        }
    }

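    // Session is over: remove the per-session sockets and stop the svc
    // socket acceptor task.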
    cleanup_socket(&agent.socket_path);
    cleanup_socket(&svc_socket_path);
    if let Some(handle) = svc_acceptor.take() {
        handle.abort();
    }
    Ok(())
}

/// Handle a raw send stream from ClientConn::Send (local-side commands).
/// Spawns a task to read the SvcRequest discriminator and manifest/dest, then sends the event.
fn handle_send_stream(mut stream: UnixStream, send_event_tx: &mpsc::UnboundedSender<SendEvent>) {
    let etx = send_event_tx.clone();
    tokio::spawn(async move {
        let mut disc = [0u8; 1];
        if stream.read_exact(&mut disc).await.is_err() {
            return;
        }
        match crate::protocol::SvcRequest::from_byte(disc[0]) {
            Some(crate::protocol::SvcRequest::Send) => {
                match parse_sender_manifest(&mut stream).await {
                    Ok(manifest) => {
                        let _ = etx.send(SendEvent::SenderArrived { stream, manifest });
                    }
                    Err(e) => debug!("send stream: bad sender manifest: {e}"),
                }
            }
            Some(crate::protocol::SvcRequest::Receive) => {
                match parse_receiver_dest(&mut stream).await {
                    Ok(_) => {
                        let _ = etx.send(SendEvent::ReceiverArrived { stream });
                    }
                    Err(e) => debug!("send stream: bad receiver dest: {e}"),
                }
            }
            _ => {}
        }
    });
}

/// Check if a stream's peer has disconnected (EOF or error).
/// Returns true if the stream is dead and should be discarded.
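/// Caveats (best-effort probe): tokio's try_read returns WouldBlock until
/// read readiness has been observed, so a peer that closed a moment ago can
/// still probe as alive, and if the peer already sent data the probe
/// consumes one byte. Both are tolerable here because a parked
/// sender/receiver is not expected to transmit before it is paired.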
fn stream_is_dead(stream: &UnixStream) -> bool {
    let mut probe = [0u8; 1];
    match stream.try_read(&mut probe) {
        Ok(0) => true,                                                     // EOF
        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => false, // alive
        Err(_) => true,                                                    // error
        Ok(_) => false, // unexpected data, treat as alive
    }
}

/// Handle a send event: pair sender and receiver for rendezvous.
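/// State sketch (mirrors the match arms below):
///   Idle               + sender   -> WaitingForReceiver
///   Idle               + receiver -> WaitingForSender
///   WaitingForSender   + sender   -> Active (relay spawned)
///   WaitingForReceiver + receiver -> Active (relay spawned)
///   Active             + either   -> old relay cancelled, new arrival waits
/// A waiting counterpart whose stream probes dead is discarded, and a second
/// arrival of the same role replaces the one already waiting.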
fn handle_send_event(
    event: SendEvent,
    state: &mut TransferState,
    notify_tx: &mpsc::UnboundedSender<Frame>,
) {
    match event {
        SendEvent::SenderArrived { stream, manifest } => {
            let old = std::mem::replace(state, TransferState::Idle);
            match old {
                TransferState::WaitingForSender { receiver_stream }
                    if !stream_is_dead(&receiver_stream) =>
                {
                    info!(
                        files = manifest.files.len(),
                        bytes = manifest.total_bytes(),
                        "transfer: sender+receiver paired"
                    );
                    let handle =
                        spawn_transfer_relay(stream, receiver_stream, manifest, notify_tx.clone());
                    *state = TransferState::Active { relay_handle: handle };
                }
                _ => {
                    if let TransferState::Active { relay_handle } = old {
                        let _ = notify_tx.send(Frame::SendCancel {
                            reason: "superseded by new sender".to_string(),
                        });
                        relay_handle.abort();
                    }
                    info!(files = manifest.files.len(), "transfer: sender waiting for receiver");
                    *state = TransferState::WaitingForReceiver { sender_stream: stream, manifest };
                }
            }
        }
        SendEvent::ReceiverArrived { stream } => {
            let old = std::mem::replace(state, TransferState::Idle);
            match old {
                TransferState::WaitingForReceiver { sender_stream, manifest }
                    if !stream_is_dead(&sender_stream) =>
                {
                    info!(
                        files = manifest.files.len(),
                        bytes = manifest.total_bytes(),
                        "transfer: receiver+sender paired"
                    );
                    let handle =
                        spawn_transfer_relay(sender_stream, stream, manifest, notify_tx.clone());
                    *state = TransferState::Active { relay_handle: handle };
                }
                _ => {
                    if let TransferState::Active { relay_handle } = old {
                        let _ = notify_tx.send(Frame::SendCancel {
                            reason: "superseded by new receiver".to_string(),
                        });
                        relay_handle.abort();
                    }
                    info!("transfer: receiver waiting for sender");
                    *state = TransferState::WaitingForSender { receiver_stream: stream };
                }
            }
        }
    }
}

fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
    match crate::security::bind_unix_listener(path) {
        Ok(listener) => {
            info!(path = %path.display(), "agent socket listening");
            Some(listener)
        }
        Err(e) => {
            warn!("failed to bind agent socket at {}: {e}", path.display());
            None
        }
    }
}

fn cleanup_socket(path: &Path) {
    let _ = std::fs::remove_file(path);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_redirect_port_basic() {
        let url = "https://accounts.google.com/o/oauth2/auth?redirect_uri=http://localhost:8080/callback&client_id=xyz";
        assert_eq!(extract_redirect_port(url), Some(8080));
    }

    #[test]
    fn extract_redirect_port_127() {
        let url = "https://auth.example.com/authorize?redirect_uri=http://127.0.0.1:9090/cb";
        assert_eq!(extract_redirect_port(url), Some(9090));
    }

    #[test]
    fn extract_redirect_port_url_encoded() {
        let url = "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A3000%2Fcallback";
        assert_eq!(extract_redirect_port(url), Some(3000));
    }

    #[test]
    fn extract_redirect_port_no_port() {
        let url = "https://auth.example.com/authorize?redirect_uri=http://localhost/callback";
        assert_eq!(extract_redirect_port(url), None);
    }

    #[test]
    fn extract_redirect_port_no_redirect_uri() {
        let url = "https://example.com/page?foo=bar";
        assert_eq!(extract_redirect_port(url), None);
    }

    #[test]
    fn extract_redirect_port_non_localhost() {
        let url =
            "https://auth.example.com/authorize?redirect_uri=https://example.com:8080/callback";
        assert_eq!(extract_redirect_port(url), None);
    }

    #[test]
    fn extract_redirect_port_https_redirect() {
        let url = "https://auth.example.com/authorize?redirect_uri=https://localhost:4443/callback";
        assert_eq!(extract_redirect_port(url), Some(4443));
    }

    #[test]
    fn extract_redirect_url_variant() {
        let url = "https://auth.example.com/authorize?redirect_url=http://localhost:5000/cb";
        assert_eq!(extract_redirect_port(url), Some(5000));
    }
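
    // A minimal liveness-probe test, added as a sketch: it assumes tokio's
    // test macro and UnixStream::pair are usable in this crate's tests.
    #[tokio::test]
    async fn stream_is_dead_detects_peer_close() {
        let (local, remote) = UnixStream::pair().expect("socketpair");
        drop(remote);
        // Wait for read readiness so try_read inside stream_is_dead can
        // observe the EOF instead of returning WouldBlock.
        local.readable().await.expect("readable");
        assert!(stream_is_dead(&local));
    }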
}