//! gritty server — PTY session daemon: client relay, SSH agent forwarding,
//! and `gritty open` URL forwarding over a unix-socket frame protocol.
1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::collections::{HashMap, VecDeque};
6use std::io;
7use std::os::fd::{AsRawFd, OwnedFd};
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
11use std::sync::{Arc, OnceLock};
12use tokio::io::AsyncReadExt;
13use tokio::io::unix::AsyncFd;
14use tokio::net::{UnixListener, UnixStream};
15use tokio::process::Command;
16use tokio::sync::mpsc;
17use tokio_util::codec::Framed;
18use tracing::{debug, info, warn};
19
/// Per-session bookkeeping, set once (via `OnceLock`) after the shell is
/// spawned in `run` and read by status/heartbeat code.
pub struct SessionMetadata {
    /// Path of the PTY slave device (e.g. `/dev/pts/3`); empty string if
    /// `ttyname` failed at setup time.
    pub pty_path: String,
    /// PID of the spawned login shell; 0 if the PID was unavailable.
    pub shell_pid: u32,
    /// Session creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Whether a client is currently attached. Updated with relaxed
    /// ordering — advisory only, not used for synchronization.
    pub attached: AtomicBool,
    /// Seconds-since-epoch of the last `Ping` frame received from a client.
    pub last_heartbeat: AtomicU64,
}
27
/// Wraps a child process and its process group ID.
/// On drop, sends SIGHUP to the entire process group.
struct ManagedChild {
    // The spawned shell process.
    child: tokio::process::Child,
    // Process group to signal on drop. Equals the child's own PID because
    // the shell calls setsid() in its pre_exec hook (see `run`), making it
    // a session/process-group leader.
    pgid: nix::unistd::Pid,
}

impl ManagedChild {
    /// Wrap a freshly spawned child.
    ///
    /// Must be called before the child has been reaped: `Child::id()`
    /// returns `None` once the process has been waited on, and the
    /// `expect` below would panic.
    fn new(child: tokio::process::Child) -> Self {
        let pid = child.id().expect("child should have pid") as i32;
        Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
    }
}

impl Drop for ManagedChild {
    /// Best-effort teardown: hang up the whole process group, then reap the
    /// child if it has already exited. Both results are deliberately
    /// ignored — the group may already be gone.
    fn drop(&mut self) {
        let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
        let _ = self.child.try_wait();
    }
}
48
/// Why the relay loop exited.
enum RelayExit {
    /// Client disconnected — re-accept.
    ClientGone,
    /// Shell exited with a code — we're done. The code may be a provisional
    /// 0 (PTY EOF/EIO seen first); `run` refines it via `try_wait` before
    /// reporting.
    ShellExited(i32),
}
56
/// Events from agent connection tasks to the main relay loop.
enum AgentEvent {
    /// A new agent-socket connection was accepted; `writer_tx` is the sink
    /// for bytes to write back to that connection. Dropping it closes the
    /// connection.
    Accepted { channel_id: u32, writer_tx: mpsc::UnboundedSender<Bytes> },
    /// Bytes read from the agent connection, to be forwarded to the client.
    Data { channel_id: u32, data: Bytes },
    /// The agent connection closed on its end.
    Closed { channel_id: u32 },
}
63
/// Events from open socket acceptor to the main relay loop.
enum OpenEvent {
    /// A URL read from a `gritty open` client connection.
    Url(String),
}
68
69/// Spawn the agent acceptor task that accepts connections on the agent socket
70/// and creates per-connection relay tasks.
71fn spawn_agent_acceptor(
72    listener: UnixListener,
73    event_tx: mpsc::UnboundedSender<AgentEvent>,
74    next_channel_id: Arc<AtomicU32>,
75) -> tokio::task::JoinHandle<()> {
76    tokio::spawn(async move {
77        loop {
78            let (stream, _) = match listener.accept().await {
79                Ok(conn) => conn,
80                Err(e) => {
81                    debug!("agent listener accept error: {e}");
82                    break;
83                }
84            };
85
86            let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
87
88            let (read_half, write_half) = stream.into_split();
89            let data_tx = event_tx.clone();
90            let close_tx = event_tx.clone();
91            let writer_tx = crate::spawn_channel_relay(
92                channel_id,
93                read_half,
94                write_half,
95                move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
96                move |id| {
97                    let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
98                },
99            );
100
101            // Notify the relay loop about the new connection
102            if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
103                break; // relay loop is gone
104            }
105        }
106    })
107}
108
109/// Spawn the open acceptor task that accepts connections on the open socket,
110/// reads a URL (up to 8KB, newline-terminated or EOF), and sends it as an event.
111fn spawn_open_acceptor(
112    listener: UnixListener,
113    event_tx: mpsc::UnboundedSender<OpenEvent>,
114) -> tokio::task::JoinHandle<()> {
115    tokio::spawn(async move {
116        loop {
117            let (mut stream, _) = match listener.accept().await {
118                Ok(conn) => conn,
119                Err(e) => {
120                    debug!("open listener accept error: {e}");
121                    break;
122                }
123            };
124
125            let etx = event_tx.clone();
126            tokio::spawn(async move {
127                let mut buf = vec![0u8; 8192];
128                let mut total = 0;
129                loop {
130                    match stream.read(&mut buf[total..]).await {
131                        Ok(0) => break,
132                        Ok(n) => {
133                            total += n;
134                            // Stop at newline or buffer full
135                            if buf[..total].contains(&b'\n') || total >= buf.len() {
136                                break;
137                            }
138                        }
139                        Err(_) => return,
140                    }
141                }
142                let s = String::from_utf8_lossy(&buf[..total]);
143                let url = s.trim();
144                if !url.is_empty() {
145                    let _ = etx.send(OpenEvent::Url(url.to_string()));
146                }
147            });
148        }
149    })
150}
151
/// Main session loop: allocate the PTY, spawn the login shell, and relay
/// bytes between the attached client (delivered on `client_rx`) and the PTY
/// master until the shell exits.
///
/// The PTY and shell persist across client reconnects: while detached, PTY
/// output accumulates in a bounded (1 MB) ring buffer and is replayed to the
/// next client. The client frame stream also multiplexes SSH agent
/// forwarding (`SSH_AUTH_SOCK`) and `gritty open` URL forwarding
/// (`GRITTY_OPEN_SOCK`); both are torn down whenever the client detaches.
///
/// # Errors
/// Returns an error on PTY allocation/configuration failure, shell spawn
/// failure, unexpected PTY I/O errors, or frame codec errors.
pub async fn run(
    mut client_rx: mpsc::UnboundedReceiver<Framed<UnixStream, FrameCodec>>,
    metadata_slot: Arc<OnceLock<SessionMetadata>>,
    agent_socket_path: PathBuf,
    open_socket_path: PathBuf,
) -> anyhow::Result<()> {
    // Allocate PTY (once, before accept loop)
    let pty = openpty(None, None)?;
    let master: OwnedFd = pty.master;
    let slave: OwnedFd = pty.slave;

    // Get PTY slave name before we drop the slave fd
    let pty_path =
        nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();

    // Dup slave fds for shell stdio (before dropping slave)
    let slave_fd = slave.as_raw_fd();
    let stdin_fd = crate::security::checked_dup(slave_fd)?;
    let stdout_fd = crate::security::checked_dup(slave_fd)?;
    let stderr_fd = crate::security::checked_dup(slave_fd)?;
    let raw_stdin = stdin_fd.as_raw_fd();
    drop(slave);

    // Set master to non-blocking for AsyncFd
    let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
    let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
    oflags |= nix::fcntl::OFlag::O_NONBLOCK;
    nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;

    let async_master = AsyncFd::new(master)?;
    let mut buf = vec![0u8; 4096];
    // Ring buffer of PTY output captured while no client is attached,
    // bounded by RING_BUF_CAP (oldest chunks evicted first).
    let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
    let mut ring_buf_size: usize = 0;
    const RING_BUF_CAP: usize = 1 << 20; // 1 MB

    // Agent event channel persists across acceptor lifetimes
    let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();

    // Wait for first client before spawning shell (so we can read Env frame)
    let mut framed = match client_rx.recv().await {
        Some(framed) => {
            info!("first client connected via channel");
            framed
        }
        None => {
            info!("client channel closed before first client");
            cleanup_socket(&agent_socket_path);
            return Ok(());
        }
    };

    // Read optional Env frame from first client (100ms timeout)
    let env_vars =
        match tokio::time::timeout(std::time::Duration::from_millis(100), framed.next()).await {
            Ok(Some(Ok(Frame::Env { vars }))) => {
                debug!(count = vars.len(), "received env vars from client");
                vars
            }
            // Timeout, stream end, decode error, or a non-Env frame all fall
            // through to "no client env".
            // NOTE(review): a non-Env first frame is consumed and dropped
            // here — TODO confirm clients always send Env before anything
            // else.
            _ => Vec::new(),
        };

    // Spawn login shell on slave PTY
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
    let home = std::env::var("HOME").ok();
    let mut cmd = Command::new(&shell);
    cmd.arg("-l");
    if let Some(ref dir) = home {
        cmd.current_dir(dir);
    }
    for (k, v) in &env_vars {
        cmd.env(k, v);
    }
    // Set SSH_AUTH_SOCK to the agent socket path
    cmd.env("SSH_AUTH_SOCK", &agent_socket_path);
    // Set GRITTY_OPEN_SOCK so `gritty open` can find the open socket
    cmd.env("GRITTY_OPEN_SOCK", &open_socket_path);
    // SAFETY: the pre_exec hook runs in the forked child before exec and
    // performs only setsid() plus a TIOCSCTTY ioctl, making the slave PTY
    // the controlling terminal of the child's new session.
    let mut managed = ManagedChild::new(unsafe {
        cmd.pre_exec(move || {
            nix::unistd::setsid().map_err(io::Error::other)?;
            libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
            Ok(())
        })
        .stdin(Stdio::from(stdin_fd))
        .stdout(Stdio::from(stdout_fd))
        .stderr(Stdio::from(stderr_fd))
        .spawn()?
    });

    let shell_pid = managed.child.id().unwrap_or(0);
    let created_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    // Publish session metadata; set() can only fail if already set, which
    // is ignored on purpose.
    let _ = metadata_slot.set(SessionMetadata {
        pty_path,
        shell_pid,
        created_at,
        attached: AtomicBool::new(false),
        last_heartbeat: AtomicU64::new(0),
    });

    // First client is already connected — enter relay directly
    metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);

    // Agent forwarding state
    let mut agent_forward_enabled = false;
    let mut agent_channels: HashMap<u32, mpsc::UnboundedSender<Bytes>> = HashMap::new();
    let mut agent_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let next_agent_channel_id = Arc::new(AtomicU32::new(0));

    // Open forwarding state
    let mut open_forward_enabled = false;
    let mut open_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();

    // Disable both forwarders and remove their socket files. Invoked
    // whenever the current client goes away — forwarding is per-client and
    // re-negotiated on each attach.
    let teardown_forwarding =
        |agent_channels: &mut HashMap<u32, mpsc::UnboundedSender<Bytes>>,
         agent_forward_enabled: &mut bool,
         agent_acceptor: &mut Option<tokio::task::JoinHandle<()>>,
         open_forward_enabled: &mut bool,
         open_acceptor: &mut Option<tokio::task::JoinHandle<()>>| {
            agent_channels.clear();
            *agent_forward_enabled = false;
            if let Some(handle) = agent_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&agent_socket_path);
            *open_forward_enabled = false;
            if let Some(handle) = open_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&open_socket_path);
        };

    // Outer loop: accept clients via channel. PTY persists across reconnects.
    // First iteration skips client-wait (first client already connected above).
    let mut first_client = true;
    loop {
        if !first_client {
            // Detached: keep draining PTY output into the ring buffer while
            // waiting for the next client or shell exit.
            let got_client = 'drain: loop {
                tokio::select! {
                    client = client_rx.recv() => {
                        match client {
                            Some(f) => {
                                info!("client connected via channel");
                                framed = f;
                                break 'drain true;
                            }
                            None => {
                                info!("client channel closed");
                                break 'drain false;
                            }
                        }
                    }
                    status = managed.child.wait() => {
                        let code = status?.code().unwrap_or(1);
                        info!(code, "shell exited while awaiting client");
                        break 'drain false;
                    }
                    ready = async_master.readable() => {
                        let mut guard = ready?;
                        match guard.try_io(|inner| {
                            nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                        }) {
                            Ok(Ok(0)) => {
                                debug!("pty EOF while disconnected");
                                break 'drain false;
                            }
                            Ok(Ok(n)) => {
                                // Buffer output for replay; evict oldest
                                // chunks once over capacity.
                                let chunk = Bytes::copy_from_slice(&buf[..n]);
                                ring_buf_size += chunk.len();
                                ring_buf.push_back(chunk);
                                while ring_buf_size > RING_BUF_CAP {
                                    if let Some(old) = ring_buf.pop_front() {
                                        ring_buf_size -= old.len();
                                    }
                                }
                            }
                            Ok(Err(e)) => {
                                // EIO from a PTY master typically means the
                                // slave side (shell) is gone.
                                if e.raw_os_error() == Some(libc::EIO) {
                                    debug!("pty EIO while disconnected");
                                    break 'drain false;
                                }
                                return Err(e.into());
                            }
                            Err(_would_block) => continue,
                        }
                    }
                }
            };
            if !got_client {
                break;
            }

            if let Some(meta) = metadata_slot.get() {
                meta.attached.store(true, Ordering::Relaxed);
            }
        }
        first_client = false;

        // Flush any buffered PTY output to the new client
        if !ring_buf.is_empty() {
            debug!(chunks = ring_buf.len(), bytes = ring_buf_size, "flushing ring buffer");
            while let Some(chunk) = ring_buf.pop_front() {
                framed.send(Frame::Data(chunk)).await?;
            }
            ring_buf_size = 0;
        }

        // Inner loop: relay between socket and PTY
        let exit = loop {
            tokio::select! {
                frame = framed.next() => {
                    match frame {
                        Some(Ok(Frame::Data(data))) => {
                            debug!(len = data.len(), "socket -> pty");
                            let mut guard = async_master.writable().await?;
                            match guard.try_io(|inner| {
                                nix::unistd::write(inner, &data).map_err(io::Error::from)
                            }) {
                                // NOTE(review): a short write here (and the
                                // WouldBlock `continue` below) silently drops
                                // the remainder of this frame's bytes — TODO
                                // confirm whether a retry/partial-write loop
                                // is needed.
                                Ok(Ok(_)) => {}
                                Ok(Err(e)) => return Err(e.into()),
                                Err(_would_block) => continue,
                            }
                        }
                        Some(Ok(Frame::Resize { cols, rows })) => {
                            let (cols, rows) = crate::security::clamp_winsize(cols, rows);
                            debug!(cols, rows, "resize pty");
                            let ws = libc::winsize {
                                ws_row: rows,
                                ws_col: cols,
                                ws_xpixel: 0,
                                ws_ypixel: 0,
                            };
                            // SAFETY: TIOCSWINSZ with a pointer to a valid,
                            // live winsize on an open PTY master fd.
                            unsafe {
                                libc::ioctl(
                                    async_master.as_raw_fd(),
                                    libc::TIOCSWINSZ,
                                    &ws as *const _,
                                );
                            }
                            // Nudge the foreground process group so full-screen
                            // apps re-query the window size.
                            if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
                                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
                            }
                        }
                        Some(Ok(Frame::Ping)) => {
                            // Record heartbeat time, then answer with Pong.
                            if let Some(meta) = metadata_slot.get() {
                                let now = std::time::SystemTime::now()
                                    .duration_since(std::time::UNIX_EPOCH)
                                    .unwrap_or_default()
                                    .as_secs();
                                meta.last_heartbeat.store(now, Ordering::Relaxed);
                            }
                            let _ = framed.send(Frame::Pong).await;
                        }
                        Some(Ok(Frame::AgentForward)) => {
                            debug!("agent forwarding enabled by client");
                            agent_forward_enabled = true;
                            // Bind agent socket so SSH_AUTH_SOCK points to a live file
                            if agent_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&agent_socket_path) {
                                    agent_acceptor = Some(spawn_agent_acceptor(listener, agent_event_tx.clone(), next_agent_channel_id.clone()));
                                }
                            }
                        }
                        Some(Ok(Frame::AgentData { channel_id, data })) => {
                            // Client -> local agent connection bytes.
                            if let Some(tx) = agent_channels.get(&channel_id) {
                                let _ = tx.send(data);
                            }
                        }
                        Some(Ok(Frame::AgentClose { channel_id })) => {
                            // Drop the sender, writer task sees closed channel and exits
                            agent_channels.remove(&channel_id);
                        }
                        Some(Ok(Frame::OpenForward)) => {
                            debug!("open forwarding enabled by client");
                            open_forward_enabled = true;
                            if open_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&open_socket_path) {
                                    open_acceptor = Some(spawn_open_acceptor(listener, open_event_tx.clone()));
                                }
                            }
                        }
                        // Client disconnected or sent Exit
                        Some(Ok(Frame::Exit { .. })) | None => {
                            break RelayExit::ClientGone;
                        }
                        // Control frames ignored on session connections
                        Some(Ok(_)) => {}
                        Some(Err(e)) => return Err(e.into()),
                    }
                }

                ready = async_master.readable() => {
                    let mut guard = ready?;
                    match guard.try_io(|inner| {
                        nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                    }) {
                        Ok(Ok(0)) => {
                            debug!("pty EOF");
                            break RelayExit::ShellExited(0);
                        }
                        Ok(Ok(n)) => {
                            debug!(len = n, "pty -> socket");
                            framed.send(Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await?;
                        }
                        Ok(Err(e)) => {
                            // EIO from the PTY master means the slave side
                            // (shell) is gone; treat as exit, not error.
                            if e.raw_os_error() == Some(libc::EIO) {
                                debug!("pty EIO (shell exited)");
                                break RelayExit::ShellExited(0);
                            }
                            return Err(e.into());
                        }
                        Err(_would_block) => continue,
                    }
                }

                // Client takeover via channel
                new_client = client_rx.recv() => {
                    if let Some(new_framed) = new_client {
                        info!("new client via channel, detaching old client");
                        let _ = framed.send(Frame::Detached).await;
                        // Forwarding is per-client: tear it down so the new
                        // client re-negotiates from scratch.
                        teardown_forwarding(
                            &mut agent_channels,
                            &mut agent_forward_enabled,
                            &mut agent_acceptor,
                            &mut open_forward_enabled,
                            &mut open_acceptor,
                        );
                        framed = new_framed;
                    }
                }

                // Agent events from acceptor/connection tasks
                event = agent_event_rx.recv() => {
                    match event {
                        Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
                            if agent_forward_enabled {
                                agent_channels.insert(channel_id, writer_tx);
                                let _ = framed.send(Frame::AgentOpen { channel_id }).await;
                            }
                            // If forwarding not enabled, drop writer_tx (closes the connection)
                        }
                        Some(AgentEvent::Data { channel_id, data }) => {
                            if agent_forward_enabled && agent_channels.contains_key(&channel_id) {
                                let _ = framed.send(Frame::AgentData { channel_id, data }).await;
                            }
                        }
                        Some(AgentEvent::Closed { channel_id }) => {
                            if agent_channels.remove(&channel_id).is_some() {
                                let _ = framed.send(Frame::AgentClose { channel_id }).await;
                            }
                        }
                        None => {
                            // Agent acceptor exited — not fatal
                            debug!("agent event channel closed");
                        }
                    }
                }

                // Open URL events from open acceptor
                event = open_event_rx.recv() => {
                    match event {
                        Some(OpenEvent::Url(url)) => {
                            if open_forward_enabled {
                                let _ = framed.send(Frame::OpenUrl { url }).await;
                            }
                        }
                        None => {
                            debug!("open event channel closed");
                        }
                    }
                }

                status = managed.child.wait() => {
                    let code = status?.code().unwrap_or(1);
                    info!(code, "shell exited");
                    break RelayExit::ShellExited(code);
                }
            }
        };

        match exit {
            RelayExit::ClientGone => {
                if let Some(meta) = metadata_slot.get() {
                    meta.attached.store(false, Ordering::Relaxed);
                }
                teardown_forwarding(
                    &mut agent_channels,
                    &mut agent_forward_enabled,
                    &mut agent_acceptor,
                    &mut open_forward_enabled,
                    &mut open_acceptor,
                );
                info!("client disconnected, waiting for reconnect");
                continue;
            }
            RelayExit::ShellExited(mut code) => {
                // PTY EOF/EIO may fire before child.wait(), giving code=0.
                // Try to get the real exit code from the child.
                if let Ok(Some(status)) = managed.child.try_wait() {
                    code = status.code().unwrap_or(code);
                }
                let _ = framed.send(Frame::Exit { code }).await;
                info!(code, "session ended");
                break;
            }
        }
    }

    cleanup_socket(&agent_socket_path);
    cleanup_socket(&open_socket_path);
    Ok(())
}
567
568fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
569    match crate::security::bind_unix_listener(path) {
570        Ok(listener) => {
571            info!(path = %path.display(), "agent socket listening");
572            Some(listener)
573        }
574        Err(e) => {
575            warn!("failed to bind agent socket at {}: {e}", path.display());
576            None
577        }
578    }
579}
580
/// Best-effort removal of a socket file; a missing file is not an error.
fn cleanup_socket(path: &Path) {
    if let Err(err) = std::fs::remove_file(path) {
        // Ignore: the socket may never have been created or is already gone.
        let _ = err;
    }
}