1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::collections::{HashMap, VecDeque};
6use std::io;
7use std::os::fd::{AsRawFd, OwnedFd};
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
11use std::sync::{Arc, OnceLock};
12use tokio::io::AsyncReadExt;
13use tokio::io::unix::AsyncFd;
14use tokio::net::{UnixListener, UnixStream};
15use tokio::process::Command;
16use tokio::sync::{broadcast, mpsc};
17use tokio_util::codec::Framed;
18use tracing::{debug, info, warn};
19
/// A client connection handed to the session loop by the accept path.
pub enum ClientConn {
    /// Full read/write client: drives the PTY (input, resize, forwarding).
    Active(Framed<UnixStream, FrameCodec>),
    /// Read-only observer: receives scrollback replay plus the live output feed.
    Tail(Framed<UnixStream, FrameCodec>),
}
25
/// Events broadcast to tail (read-only) clients.
#[derive(Clone)]
enum TailEvent {
    /// A chunk of PTY output.
    Data(Bytes),
    /// The shell exited with `code`; tail relays stop after forwarding this.
    Exit { code: i32 },
}
32
/// Metadata about a running session, published into an `OnceLock` once the
/// shell is spawned so it can be read outside the session loop.
pub struct SessionMetadata {
    /// Path of the PTY slave device; empty string if `ttyname` failed.
    pub pty_path: String,
    /// Pid of the spawned shell (0 if unavailable).
    pub shell_pid: u32,
    /// Unix timestamp (seconds) when the shell was spawned.
    pub created_at: u64,
    /// Whether an active client is currently attached.
    pub attached: AtomicBool,
    /// Unix timestamp (seconds) of the most recent client `Ping`.
    pub last_heartbeat: AtomicU64,
}
40
/// The shell child process plus the process-group id used to signal its
/// whole group; `Drop` sends SIGHUP to the group.
struct ManagedChild {
    child: tokio::process::Child,
    // Group id equals the child's pid: the child calls setsid() in pre_exec,
    // making it the leader of a new session/process group.
    pgid: nix::unistd::Pid,
}
47
48impl ManagedChild {
49 fn new(child: tokio::process::Child) -> Self {
50 let pid = child.id().expect("child should have pid") as i32;
51 Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
52 }
53}
54
impl Drop for ManagedChild {
    /// Best-effort teardown: SIGHUP the shell's whole process group, then
    /// reap the child if it has already exited. Errors are ignored — the
    /// group may be gone already.
    fn drop(&mut self) {
        let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
        let _ = self.child.try_wait();
    }
}
61
/// Why the attached relay loop stopped.
enum RelayExit {
    /// Client disconnected (or sent `Exit`); the session stays alive detached.
    ClientGone,
    /// The shell terminated with this exit code; the session ends.
    ShellExited(i32),
}
69
/// Events from the SSH-agent forwarding acceptor/relay tasks back to the
/// session loop.
enum AgentEvent {
    /// A new agent connection was accepted; `writer_tx` feeds bytes back to it.
    Accepted { channel_id: u32, writer_tx: mpsc::UnboundedSender<Bytes> },
    /// Bytes read from the local agent connection, to be forwarded to the client.
    Data { channel_id: u32, data: Bytes },
    /// The local agent connection closed.
    Closed { channel_id: u32 },
}
76
/// Events from the URL-open forwarding socket.
enum OpenEvent {
    /// A URL a shell-side program asked the client to open.
    Url(String),
}
81
82fn spawn_agent_acceptor(
85 listener: UnixListener,
86 event_tx: mpsc::UnboundedSender<AgentEvent>,
87 next_channel_id: Arc<AtomicU32>,
88) -> tokio::task::JoinHandle<()> {
89 tokio::spawn(async move {
90 loop {
91 let (stream, _) = match listener.accept().await {
92 Ok(conn) => conn,
93 Err(e) => {
94 debug!("agent listener accept error: {e}");
95 break;
96 }
97 };
98
99 if let Err(e) = crate::security::verify_peer_uid(&stream) {
100 warn!("agent socket: {e}");
101 continue;
102 }
103
104 let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
105
106 let (read_half, write_half) = stream.into_split();
107 let data_tx = event_tx.clone();
108 let close_tx = event_tx.clone();
109 let writer_tx = crate::spawn_channel_relay(
110 channel_id,
111 read_half,
112 write_half,
113 move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
114 move |id| {
115 let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
116 },
117 );
118
119 if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
121 break; }
123 }
124 })
125}
126
127fn spawn_open_acceptor(
130 listener: UnixListener,
131 event_tx: mpsc::UnboundedSender<OpenEvent>,
132) -> tokio::task::JoinHandle<()> {
133 tokio::spawn(async move {
134 loop {
135 let (mut stream, _) = match listener.accept().await {
136 Ok(conn) => conn,
137 Err(e) => {
138 debug!("open listener accept error: {e}");
139 break;
140 }
141 };
142
143 match crate::security::verify_peer_uid(&stream) {
148 Ok(()) => {}
149 Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
150 warn!("open socket: {e}");
151 continue;
152 }
153 Err(e) => {
154 debug!("open socket peer_cred unavailable: {e}");
155 }
156 }
157
158 let etx = event_tx.clone();
159 tokio::spawn(async move {
160 let mut buf = vec![0u8; 8192];
161 let mut total = 0;
162 loop {
163 match stream.read(&mut buf[total..]).await {
164 Ok(0) => break,
165 Ok(n) => {
166 total += n;
167 if buf[..total].contains(&b'\n') || total >= buf.len() {
169 break;
170 }
171 }
172 Err(_) => return,
173 }
174 }
175 let s = String::from_utf8_lossy(&buf[..total]);
176 let url = s.trim();
177 if !url.is_empty() {
178 let _ = etx.send(OpenEvent::Url(url.to_string()));
179 }
180 });
181 }
182 })
183}
184
/// Streams live `TailEvent`s to a tail client until either side goes away.
///
/// A lagged broadcast receiver skips ahead (dropped chunks are tolerated for
/// read-only observers). The only inbound frame honoured is `Ping`, answered
/// with `Pong`; any other frame, EOF, or decode error ends the relay.
async fn tail_relay(
    mut framed: Framed<UnixStream, FrameCodec>,
    mut rx: broadcast::Receiver<TailEvent>,
) {
    loop {
        tokio::select! {
            event = rx.recv() => match event {
                Ok(TailEvent::Data(chunk)) => {
                    if framed.send(Frame::Data(chunk)).await.is_err() { break; }
                }
                Ok(TailEvent::Exit { code }) => {
                    // Forward the final exit notification, then stop.
                    let _ = framed.send(Frame::Exit { code }).await;
                    break;
                }
                // Fell behind the broadcast ring: output was dropped; carry on.
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => break,
            },
            frame = framed.next() => match frame {
                Some(Ok(Frame::Ping)) => { let _ = framed.send(Frame::Pong).await; }
                _ => break,
            },
        }
    }
}
210
211fn spawn_tail(
213 mut framed: Framed<UnixStream, FrameCodec>,
214 ring_buf: &VecDeque<Bytes>,
215 tail_tx: &broadcast::Sender<TailEvent>,
216) {
217 let rx = tail_tx.subscribe();
218 let chunks: Vec<Bytes> = ring_buf.iter().cloned().collect();
219 tokio::spawn(async move {
220 for chunk in chunks {
221 if framed.send(Frame::Data(chunk)).await.is_err() {
222 return;
223 }
224 }
225 tail_relay(framed, rx).await;
226 });
227}
228
/// Core session loop for one terminal session.
///
/// Owns the PTY master, spawns the user's shell on the slave side, and
/// relays bytes between the shell and whichever active client is attached.
/// Also services read-only tail clients, an SSH-agent forwarding socket and
/// a URL-open forwarding socket (both enabled on client request).
///
/// Returns when the shell exits, or when the client channel closes before
/// any shell was spawned. Forwarding socket files are removed on exit.
pub async fn run(
    mut client_rx: mpsc::UnboundedReceiver<ClientConn>,
    metadata_slot: Arc<OnceLock<SessionMetadata>>,
    agent_socket_path: PathBuf,
    open_socket_path: PathBuf,
) -> anyhow::Result<()> {
    // Allocate the PTY pair: the slave becomes the shell's controlling
    // terminal, the master stays here for relaying.
    let pty = openpty(None, None)?;
    let master: OwnedFd = pty.master;
    let slave: OwnedFd = pty.slave;

    // Device path of the slave; empty string if ttyname fails.
    let pty_path =
        nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();

    // Dup the slave once per stdio stream so each Stdio owns its own fd.
    let slave_fd = slave.as_raw_fd();
    let stdin_fd = crate::security::checked_dup(slave_fd)?;
    let stdout_fd = crate::security::checked_dup(slave_fd)?;
    let stderr_fd = crate::security::checked_dup(slave_fd)?;
    let raw_stdin = stdin_fd.as_raw_fd();
    drop(slave);

    // Switch the master to non-blocking so AsyncFd can drive it.
    let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
    let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
    oflags |= nix::fcntl::OFlag::O_NONBLOCK;
    nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;

    let async_master = AsyncFd::new(master)?;
    let mut buf = vec![0u8; 4096];
    // Scrollback captured while no active client is attached; replayed on
    // reconnect and to new tail clients. Byte total capped at RING_BUF_CAP.
    let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
    let mut ring_buf_size: usize = 0;
    const RING_BUF_CAP: usize = 1 << 20; let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();

    // Live feed for tail (read-only) clients.
    let (tail_tx, _) = broadcast::channel::<TailEvent>(256);

    // Hold off spawning the shell until the first active client arrives;
    // early tail clients just subscribe to the (still empty) feed.
    let mut framed = loop {
        match client_rx.recv().await {
            Some(ClientConn::Active(f)) => {
                info!("first client connected via channel");
                break f;
            }
            Some(ClientConn::Tail(f)) => {
                info!("tail client connected before shell spawn");
                spawn_tail(f, &ring_buf, &tail_tx);
                continue;
            }
            None => {
                info!("client channel closed before first client");
                // NOTE(review): only the agent socket file is removed on this
                // early-exit path; the open socket file is not — confirm.
                cleanup_socket(&agent_socket_path);
                return Ok(());
            }
        }
    };

    // Give the client a 100ms window to send its Env frame; fall back to an
    // empty list otherwise.
    let env_vars =
        match tokio::time::timeout(std::time::Duration::from_millis(100), framed.next()).await {
            Ok(Some(Ok(Frame::Env { vars }))) => {
                debug!(count = vars.len(), "received env vars from client");
                vars
            }
            _ => Vec::new(),
        };

    // Spawn the user's login shell, in $HOME when known.
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
    let home = std::env::var("HOME").ok();
    let mut cmd = Command::new(&shell);
    cmd.arg("-l");
    if let Some(ref dir) = home {
        cmd.current_dir(dir);
    }
    // Only a small allowlist of client-supplied env vars is passed through.
    const ALLOWED_ENV_KEYS: &[&str] = &["TERM", "LANG", "COLORTERM", "BROWSER"];
    for (k, v) in &env_vars {
        if ALLOWED_ENV_KEYS.contains(&k.as_str()) {
            cmd.env(k, v);
        } else {
            warn!(key = k, "ignoring disallowed env var from client");
        }
    }
    cmd.env("SSH_AUTH_SOCK", &agent_socket_path);
    cmd.env("GRITTY_OPEN_SOCK", &open_socket_path);
    let mut managed = ManagedChild::new(unsafe {
        cmd.pre_exec(move || {
            // New session so the shell leads its own process group; then make
            // the PTY its controlling terminal (ioctl result is best-effort).
            nix::unistd::setsid().map_err(io::Error::other)?;
            libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
            Ok(())
        })
        .stdin(Stdio::from(stdin_fd))
        .stdout(Stdio::from(stdout_fd))
        .stderr(Stdio::from(stderr_fd))
        .spawn()?
    });

    let shell_pid = managed.child.id().unwrap_or(0);
    let created_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    // Publish session metadata; set() only fails if already populated.
    let _ = metadata_slot.set(SessionMetadata {
        pty_path,
        shell_pid,
        created_at,
        attached: AtomicBool::new(false),
        last_heartbeat: AtomicU64::new(0),
    });

    // The slot is populated at this point, so get() cannot be None.
    metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);

    // SSH-agent forwarding state: open channels keyed by id, plus the
    // acceptor task handle (spawned lazily on Frame::AgentForward).
    let mut agent_forward_enabled = false;
    let mut agent_channels: HashMap<u32, mpsc::UnboundedSender<Bytes>> = HashMap::new();
    let mut agent_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let next_agent_channel_id = Arc::new(AtomicU32::new(0));

    // URL-open forwarding state (spawned lazily on Frame::OpenForward).
    let mut open_forward_enabled = false;
    let mut open_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();

    // Disables both forwarding services and removes their socket files.
    // Forwarding is per-client, so this runs whenever the client goes away.
    let teardown_forwarding =
        |agent_channels: &mut HashMap<u32, mpsc::UnboundedSender<Bytes>>,
         agent_forward_enabled: &mut bool,
         agent_acceptor: &mut Option<tokio::task::JoinHandle<()>>,
         open_forward_enabled: &mut bool,
         open_acceptor: &mut Option<tokio::task::JoinHandle<()>>| {
            agent_channels.clear();
            *agent_forward_enabled = false;
            if let Some(handle) = agent_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&agent_socket_path);
            *open_forward_enabled = false;
            if let Some(handle) = open_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&open_socket_path);
        };

    // Outer loop: one iteration per attached-client "era".
    let mut first_client = true;
    loop {
        if !first_client {
            // Detached phase: buffer shell output into the ring buffer while
            // waiting for the next active client (or shell exit).
            let got_client = 'drain: loop {
                tokio::select! {
                    client = client_rx.recv() => {
                        match client {
                            Some(ClientConn::Active(f)) => {
                                info!("client connected via channel");
                                framed = f;
                                break 'drain true;
                            }
                            Some(ClientConn::Tail(f)) => {
                                info!("tail client connected while disconnected");
                                spawn_tail(f, &ring_buf, &tail_tx);
                                continue;
                            }
                            None => {
                                info!("client channel closed");
                                break 'drain false;
                            }
                        }
                    }
                    status = managed.child.wait() => {
                        let code = status?.code().unwrap_or(1);
                        info!(code, "shell exited while awaiting client");
                        break 'drain false;
                    }
                    ready = async_master.readable() => {
                        let mut guard = ready?;
                        match guard.try_io(|inner| {
                            nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                        }) {
                            Ok(Ok(0)) => {
                                debug!("pty EOF while disconnected");
                                break 'drain false;
                            }
                            Ok(Ok(n)) => {
                                // Feed tail clients and append to scrollback,
                                // evicting oldest chunks once over the cap.
                                let chunk = Bytes::copy_from_slice(&buf[..n]);
                                let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
                                ring_buf_size += chunk.len();
                                ring_buf.push_back(chunk);
                                while ring_buf_size > RING_BUF_CAP {
                                    if let Some(old) = ring_buf.pop_front() {
                                        ring_buf_size -= old.len();
                                    }
                                }
                            }
                            Ok(Err(e)) => {
                                // EIO from a PTY master: the slave side (the
                                // shell) is gone.
                                if e.raw_os_error() == Some(libc::EIO) {
                                    debug!("pty EIO while disconnected");
                                    break 'drain false;
                                }
                                return Err(e.into());
                            }
                            Err(_would_block) => continue,
                        }
                    }
                }
            };
            if !got_client {
                break;
            }

            if let Some(meta) = metadata_slot.get() {
                meta.attached.store(true, Ordering::Relaxed);
            }
        }
        first_client = false;

        // Replay buffered scrollback to the newly attached client.
        if !ring_buf.is_empty() {
            debug!(chunks = ring_buf.len(), bytes = ring_buf_size, "flushing ring buffer");
            while let Some(chunk) = ring_buf.pop_front() {
                framed.send(Frame::Data(chunk)).await?;
            }
            ring_buf_size = 0;
        }

        // Attached relay loop: runs until the client or the shell goes away.
        let exit = loop {
            tokio::select! {
                frame = framed.next() => {
                    match frame {
                        Some(Ok(Frame::Data(data))) => {
                            debug!(len = data.len(), "socket -> pty");
                            let mut guard = async_master.writable().await?;
                            match guard.try_io(|inner| {
                                nix::unistd::write(inner, &data).map_err(io::Error::from)
                            }) {
                                // NOTE(review): a short write ignores the
                                // unwritten tail and WouldBlock drops the whole
                                // frame — client input can be lost under PTY
                                // backpressure; confirm this is acceptable or
                                // add a retry loop.
                                Ok(Ok(_)) => {}
                                Ok(Err(e)) => return Err(e.into()),
                                Err(_would_block) => continue,
                            }
                        }
                        Some(Ok(Frame::Resize { cols, rows })) => {
                            // Clamp to sane bounds, apply to the PTY, then wake
                            // the foreground process group with SIGWINCH.
                            let (cols, rows) = crate::security::clamp_winsize(cols, rows);
                            debug!(cols, rows, "resize pty");
                            let ws = libc::winsize {
                                ws_row: rows,
                                ws_col: cols,
                                ws_xpixel: 0,
                                ws_ypixel: 0,
                            };
                            unsafe {
                                libc::ioctl(
                                    async_master.as_raw_fd(),
                                    libc::TIOCSWINSZ,
                                    &ws as *const _,
                                );
                            }
                            if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
                                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
                            }
                        }
                        Some(Ok(Frame::Ping)) => {
                            // Record heartbeat time, then answer with Pong.
                            if let Some(meta) = metadata_slot.get() {
                                let now = std::time::SystemTime::now()
                                    .duration_since(std::time::UNIX_EPOCH)
                                    .unwrap_or_default()
                                    .as_secs();
                                meta.last_heartbeat.store(now, Ordering::Relaxed);
                            }
                            let _ = framed.send(Frame::Pong).await;
                        }
                        Some(Ok(Frame::AgentForward)) => {
                            debug!("agent forwarding enabled by client");
                            agent_forward_enabled = true;
                            // Lazily bind the socket and spawn the acceptor.
                            if agent_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&agent_socket_path) {
                                    agent_acceptor = Some(spawn_agent_acceptor(listener, agent_event_tx.clone(), next_agent_channel_id.clone()));
                                }
                            }
                        }
                        Some(Ok(Frame::AgentData { channel_id, data })) => {
                            // Client -> local agent connection.
                            if let Some(tx) = agent_channels.get(&channel_id) {
                                let _ = tx.send(data);
                            }
                        }
                        Some(Ok(Frame::AgentClose { channel_id })) => {
                            // Dropping the sender tears down the relay's writer.
                            agent_channels.remove(&channel_id);
                        }
                        Some(Ok(Frame::OpenForward)) => {
                            debug!("open forwarding enabled by client");
                            open_forward_enabled = true;
                            if open_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&open_socket_path) {
                                    open_acceptor = Some(spawn_open_acceptor(listener, open_event_tx.clone()));
                                }
                            }
                        }
                        Some(Ok(Frame::Exit { .. })) | None => {
                            break RelayExit::ClientGone;
                        }
                        Some(Ok(_)) => {}
                        Some(Err(e)) => return Err(e.into()),
                    }
                }

                ready = async_master.readable() => {
                    let mut guard = ready?;
                    match guard.try_io(|inner| {
                        nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                    }) {
                        Ok(Ok(0)) => {
                            debug!("pty EOF");
                            break RelayExit::ShellExited(0);
                        }
                        Ok(Ok(n)) => {
                            // Shell output goes both to the tail feed and the
                            // active client.
                            debug!(len = n, "pty -> socket");
                            let chunk = Bytes::copy_from_slice(&buf[..n]);
                            let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
                            framed.send(Frame::Data(chunk)).await?;
                        }
                        Ok(Err(e)) => {
                            if e.raw_os_error() == Some(libc::EIO) {
                                debug!("pty EIO (shell exited)");
                                break RelayExit::ShellExited(0);
                            }
                            return Err(e.into());
                        }
                        Err(_would_block) => continue,
                    }
                }

                new_client = client_rx.recv() => {
                    match new_client {
                        Some(ClientConn::Active(new_framed)) => {
                            // Takeover: notify the old client and reset
                            // forwarding (it belonged to the old client).
                            info!("new client via channel, detaching old client");
                            let _ = framed.send(Frame::Detached).await;
                            teardown_forwarding(
                                &mut agent_channels,
                                &mut agent_forward_enabled,
                                &mut agent_acceptor,
                                &mut open_forward_enabled,
                                &mut open_acceptor,
                            );
                            framed = new_framed;
                        }
                        Some(ClientConn::Tail(f)) => {
                            info!("tail client connected while active");
                            spawn_tail(f, &ring_buf, &tail_tx);
                        }
                        None => {}
                    }
                }

                event = agent_event_rx.recv() => {
                    match event {
                        Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
                            // Ignore acceptances that race a teardown.
                            if agent_forward_enabled {
                                agent_channels.insert(channel_id, writer_tx);
                                let _ = framed.send(Frame::AgentOpen { channel_id }).await;
                            }
                        }
                        Some(AgentEvent::Data { channel_id, data }) => {
                            if agent_forward_enabled && agent_channels.contains_key(&channel_id) {
                                let _ = framed.send(Frame::AgentData { channel_id, data }).await;
                            }
                        }
                        Some(AgentEvent::Closed { channel_id }) => {
                            if agent_channels.remove(&channel_id).is_some() {
                                let _ = framed.send(Frame::AgentClose { channel_id }).await;
                            }
                        }
                        None => {
                            debug!("agent event channel closed");
                        }
                    }
                }

                event = open_event_rx.recv() => {
                    match event {
                        Some(OpenEvent::Url(url)) => {
                            if open_forward_enabled {
                                let _ = framed.send(Frame::OpenUrl { url }).await;
                            }
                        }
                        None => {
                            debug!("open event channel closed");
                        }
                    }
                }

                status = managed.child.wait() => {
                    let code = status?.code().unwrap_or(1);
                    info!(code, "shell exited");
                    break RelayExit::ShellExited(code);
                }
            }
        };

        match exit {
            RelayExit::ClientGone => {
                // Stay alive detached; forwarding dies with the client.
                if let Some(meta) = metadata_slot.get() {
                    meta.attached.store(false, Ordering::Relaxed);
                }
                teardown_forwarding(
                    &mut agent_channels,
                    &mut agent_forward_enabled,
                    &mut agent_acceptor,
                    &mut open_forward_enabled,
                    &mut open_acceptor,
                );
                info!("client disconnected, waiting for reconnect");
                continue;
            }
            RelayExit::ShellExited(mut code) => {
                // Prefer the real exit status when the child is reapable;
                // the PTY EOF/EIO paths above assumed 0.
                if let Ok(Some(status)) = managed.child.try_wait() {
                    code = status.code().unwrap_or(code);
                }
                let _ = tail_tx.send(TailEvent::Exit { code });
                let _ = framed.send(Frame::Exit { code }).await;
                info!(code, "session ended");
                break;
            }
        }
    }

    cleanup_socket(&agent_socket_path);
    cleanup_socket(&open_socket_path);
    Ok(())
}
677
678fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
679 match crate::security::bind_unix_listener(path) {
680 Ok(listener) => {
681 info!(path = %path.display(), "agent socket listening");
682 Some(listener)
683 }
684 Err(e) => {
685 warn!("failed to bind agent socket at {}: {e}", path.display());
686 None
687 }
688 }
689}
690
/// Best-effort removal of a socket file; errors (e.g. the file never
/// existed) are deliberately ignored.
fn cleanup_socket(path: &Path) {
    std::fs::remove_file(path).ok();
}