Skip to main content

gritty/
server.rs

1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::io;
6use std::os::fd::{AsRawFd, OwnedFd};
7use std::process::Stdio;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::{Arc, OnceLock};
10use tokio::io::unix::AsyncFd;
11use tokio::net::UnixStream;
12use tokio::process::Command;
13use tokio::sync::mpsc;
14use tokio_util::codec::Framed;
15use tracing::{debug, info};
16
/// Information about one PTY session, shared with other tasks once the shell is running
/// (`run` publishes it through a `OnceLock`).
#[derive(Debug)]
pub struct SessionMetadata {
    /// Path of the PTY slave device (`run` stores an empty string if `ttyname` fails).
    pub pty_path: String,
    /// PID of the session's shell (`run` stores 0 if the PID was unavailable).
    pub shell_pid: u32,
    /// Session creation time, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Set while a client is attached to the session.
    pub attached: AtomicBool,
    /// Unix time in seconds of the last client `Ping` (`run` initializes it to 0).
    pub last_heartbeat: AtomicU64,
}
24
25/// Wraps a child process and its process group ID.
26/// On drop, sends SIGHUP to the entire process group.
27struct ManagedChild {
28    child: tokio::process::Child,
29    pgid: nix::unistd::Pid,
30}
31
32impl ManagedChild {
33    fn new(child: tokio::process::Child) -> Self {
34        let pid = child.id().expect("child should have pid") as i32;
35        Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
36    }
37}
38
39impl Drop for ManagedChild {
40    fn drop(&mut self) {
41        let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
42        let _ = self.child.try_wait();
43    }
44}
45
/// Outcome of one pass through the client <-> PTY relay loop in `run`.
enum RelayExit {
    /// The shell is gone; carries the exit code to report (0 when only PTY EOF/EIO was
    /// observed). The session ends.
    ShellExited(i32),
    /// The attached client went away; `run` goes back to waiting for another client.
    ClientGone,
}
53
54pub async fn run(
55    mut client_rx: mpsc::UnboundedReceiver<Framed<UnixStream, FrameCodec>>,
56    metadata_slot: Arc<OnceLock<SessionMetadata>>,
57) -> anyhow::Result<()> {
58    // Allocate PTY (once, before accept loop)
59    let pty = openpty(None, None)?;
60    let master: OwnedFd = pty.master;
61    let slave: OwnedFd = pty.slave;
62
63    // Get PTY slave name before we drop the slave fd
64    let pty_path =
65        nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();
66
67    // Dup slave fds for shell stdio (before dropping slave)
68    let slave_fd = slave.as_raw_fd();
69    let stdin_fd = crate::security::checked_dup(slave_fd)?;
70    let stdout_fd = crate::security::checked_dup(slave_fd)?;
71    let stderr_fd = crate::security::checked_dup(slave_fd)?;
72    let raw_stdin = stdin_fd.as_raw_fd();
73    drop(slave);
74
75    // Set master to non-blocking for AsyncFd
76    let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
77    let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
78    oflags |= nix::fcntl::OFlag::O_NONBLOCK;
79    nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;
80
81    let async_master = AsyncFd::new(master)?;
82    let mut buf = vec![0u8; 4096];
83
84    // Wait for first client before spawning shell (so we can read Env frame)
85    let mut framed = match client_rx.recv().await {
86        Some(framed) => {
87            info!("first client connected via channel");
88            framed
89        }
90        None => {
91            info!("client channel closed before first client");
92            return Ok(());
93        }
94    };
95
96    // Read optional Env frame from first client (100ms timeout)
97    let env_vars =
98        match tokio::time::timeout(std::time::Duration::from_millis(100), framed.next()).await {
99            Ok(Some(Ok(Frame::Env { vars }))) => {
100                debug!(count = vars.len(), "received env vars from client");
101                vars
102            }
103            _ => Vec::new(),
104        };
105
106    // Spawn login shell on slave PTY
107    let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
108    let home = std::env::var("HOME").ok();
109    let mut cmd = Command::new(&shell);
110    cmd.arg("-l");
111    if let Some(ref dir) = home {
112        cmd.current_dir(dir);
113    }
114    for (k, v) in &env_vars {
115        cmd.env(k, v);
116    }
117    let mut managed = ManagedChild::new(unsafe {
118        cmd.pre_exec(move || {
119            nix::unistd::setsid().map_err(io::Error::other)?;
120            libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
121            Ok(())
122        })
123        .stdin(Stdio::from(stdin_fd))
124        .stdout(Stdio::from(stdout_fd))
125        .stderr(Stdio::from(stderr_fd))
126        .spawn()?
127    });
128
129    let shell_pid = managed.child.id().unwrap_or(0);
130    let created_at = std::time::SystemTime::now()
131        .duration_since(std::time::UNIX_EPOCH)
132        .unwrap_or_default()
133        .as_secs();
134
135    let _ = metadata_slot.set(SessionMetadata {
136        pty_path,
137        shell_pid,
138        created_at,
139        attached: AtomicBool::new(false),
140        last_heartbeat: AtomicU64::new(0),
141    });
142
143    // First client is already connected — enter relay directly
144    metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);
145
146    // Outer loop: accept clients via channel. PTY persists across reconnects.
147    // First iteration skips client-wait (first client already connected above).
148    let mut first_client = true;
149    loop {
150        if !first_client {
151            framed = tokio::select! {
152                client = client_rx.recv() => {
153                    match client {
154                        Some(f) => {
155                            info!("client connected via channel");
156                            f
157                        }
158                        None => {
159                            info!("client channel closed");
160                            break;
161                        }
162                    }
163                }
164                status = managed.child.wait() => {
165                    let code = status?.code().unwrap_or(1);
166                    info!(code, "shell exited while awaiting client");
167                    break;
168                }
169            };
170
171            if let Some(meta) = metadata_slot.get() {
172                meta.attached.store(true, Ordering::Relaxed);
173            }
174        }
175        first_client = false;
176
177        // Inner loop: relay between socket and PTY
178        let exit = loop {
179            tokio::select! {
180                frame = framed.next() => {
181                    match frame {
182                        Some(Ok(Frame::Data(data))) => {
183                            debug!(len = data.len(), "socket -> pty");
184                            let mut guard = async_master.writable().await?;
185                            match guard.try_io(|inner| {
186                                nix::unistd::write(inner, &data).map_err(io::Error::from)
187                            }) {
188                                Ok(Ok(_)) => {}
189                                Ok(Err(e)) => return Err(e.into()),
190                                Err(_would_block) => continue,
191                            }
192                        }
193                        Some(Ok(Frame::Resize { cols, rows })) => {
194                            let (cols, rows) = crate::security::clamp_winsize(cols, rows);
195                            debug!(cols, rows, "resize pty");
196                            let ws = libc::winsize {
197                                ws_row: rows,
198                                ws_col: cols,
199                                ws_xpixel: 0,
200                                ws_ypixel: 0,
201                            };
202                            unsafe {
203                                libc::ioctl(
204                                    async_master.as_raw_fd(),
205                                    libc::TIOCSWINSZ,
206                                    &ws as *const _,
207                                );
208                            }
209                            if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
210                                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
211                            }
212                        }
213                        Some(Ok(Frame::Ping)) => {
214                            if let Some(meta) = metadata_slot.get() {
215                                let now = std::time::SystemTime::now()
216                                    .duration_since(std::time::UNIX_EPOCH)
217                                    .unwrap_or_default()
218                                    .as_secs();
219                                meta.last_heartbeat.store(now, Ordering::Relaxed);
220                            }
221                            let _ = framed.send(Frame::Pong).await;
222                        }
223                        // Client disconnected or sent Exit
224                        Some(Ok(Frame::Exit { .. })) | None => {
225                            break RelayExit::ClientGone;
226                        }
227                        // Control frames ignored on session connections
228                        Some(Ok(_)) => {}
229                        Some(Err(e)) => return Err(e.into()),
230                    }
231                }
232
233                ready = async_master.readable() => {
234                    let mut guard = ready?;
235                    match guard.try_io(|inner| {
236                        nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
237                    }) {
238                        Ok(Ok(0)) => {
239                            debug!("pty EOF");
240                            break RelayExit::ShellExited(0);
241                        }
242                        Ok(Ok(n)) => {
243                            debug!(len = n, "pty -> socket");
244                            framed.send(Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await?;
245                        }
246                        Ok(Err(e)) => {
247                            if e.raw_os_error() == Some(libc::EIO) {
248                                debug!("pty EIO (shell exited)");
249                                break RelayExit::ShellExited(0);
250                            }
251                            return Err(e.into());
252                        }
253                        Err(_would_block) => continue,
254                    }
255                }
256
257                // Client takeover via channel
258                new_client = client_rx.recv() => {
259                    if let Some(new_framed) = new_client {
260                        info!("new client via channel, detaching old client");
261                        let _ = framed.send(Frame::Detached).await;
262                        framed = new_framed;
263                    }
264                }
265
266                status = managed.child.wait() => {
267                    let code = status?.code().unwrap_or(1);
268                    info!(code, "shell exited");
269                    break RelayExit::ShellExited(code);
270                }
271            }
272        };
273
274        match exit {
275            RelayExit::ClientGone => {
276                if let Some(meta) = metadata_slot.get() {
277                    meta.attached.store(false, Ordering::Relaxed);
278                }
279                info!("client disconnected, waiting for reconnect");
280                continue;
281            }
282            RelayExit::ShellExited(mut code) => {
283                // PTY EOF/EIO may fire before child.wait(), giving code=0.
284                // Try to get the real exit code from the child.
285                if let Ok(Some(status)) = managed.child.try_wait() {
286                    code = status.code().unwrap_or(code);
287                }
288                let _ = framed.send(Frame::Exit { code }).await;
289                info!(code, "session ended");
290                break;
291            }
292        }
293    }
294
295    Ok(())
296}