1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::collections::{HashMap, VecDeque};
6use std::io;
7use std::os::fd::{AsRawFd, OwnedFd};
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
11use std::sync::{Arc, OnceLock};
12use tokio::io::AsyncReadExt;
13use tokio::io::unix::AsyncFd;
14use tokio::net::{UnixListener, UnixStream};
15use tokio::process::Command;
16use tokio::sync::{broadcast, mpsc};
17use tokio_util::codec::Framed;
18use tracing::{debug, info, warn};
19
/// A client connection handed to the session loop by the connection listener.
pub enum ClientConn {
    /// Interactive client: drives the PTY (input, resize, forwarding control).
    Active(Framed<UnixStream, FrameCodec>),
    /// Read-only observer: receives buffered replay plus the live output stream.
    Tail(Framed<UnixStream, FrameCodec>),
}
25
/// Events broadcast to read-only tail clients.
#[derive(Clone)]
enum TailEvent {
    /// One chunk of PTY output (Bytes clones are cheap refcount bumps).
    Data(Bytes),
    /// The shell exited with `code`; terminates every tail stream.
    Exit { code: i32 },
}
32
/// Observable facts about a running session, published once into a `OnceLock`.
pub struct SessionMetadata {
    /// Path of the PTY slave device (empty string if `ttyname` failed).
    pub pty_path: String,
    /// Pid of the spawned shell (0 if unavailable).
    pub shell_pid: u32,
    /// Unix timestamp (seconds) taken at shell spawn.
    pub created_at: u64,
    /// Whether an interactive client is currently attached.
    pub attached: AtomicBool,
    /// Unix timestamp (seconds) of the most recent client `Ping`.
    pub last_heartbeat: AtomicU64,
}
40
/// The shell child plus the process-group id to signal when the session dies.
struct ManagedChild {
    child: tokio::process::Child,
    // Equal to the child's own pid: the child calls setsid() in pre_exec
    // (see `run`), making it the leader of a fresh process group.
    pgid: nix::unistd::Pid,
}
47
48impl ManagedChild {
49 fn new(child: tokio::process::Child) -> Self {
50 let pid = child.id().expect("child should have pid") as i32;
51 Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
52 }
53}
54
impl Drop for ManagedChild {
    /// Best-effort teardown: SIGHUP the whole process group (shell plus any
    /// foreground jobs), then reap the child if it has already exited.
    /// Both results are deliberately ignored — the group may be gone already.
    fn drop(&mut self) {
        let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
        let _ = self.child.try_wait();
    }
}
61
/// Why the inner relay loop in `run` stopped.
enum RelayExit {
    /// The client disconnected (or sent `Exit`); keep the session alive and
    /// wait for a reattach.
    ClientGone,
    /// The shell terminated with this exit code; end the session.
    ShellExited(i32),
}
69
/// Events flowing from the SSH-agent acceptor and its per-connection relays
/// into the main session loop.
enum AgentEvent {
    /// A new agent connection; `writer_tx` carries reply bytes back to it.
    Accepted { channel_id: u32, writer_tx: mpsc::UnboundedSender<Bytes> },
    /// Bytes read from the agent connection, to forward to the client.
    Data { channel_id: u32, data: Bytes },
    /// The agent connection closed on the local side.
    Closed { channel_id: u32 },
}
76
/// Events from the URL-open forwarding socket.
enum OpenEvent {
    /// A URL that a shell-side tool asked the client to open.
    Url(String),
}
81
/// Accept loop for the per-session SSH-agent forwarding socket.
///
/// Each accepted connection is uid-checked, assigned a unique channel id, and
/// wired to a byte relay (`crate::spawn_channel_relay`) whose events funnel
/// into `event_tx` for the main session loop. Returns the spawned task's
/// handle so the caller can abort the acceptor on teardown.
fn spawn_agent_acceptor(
    listener: UnixListener,
    event_tx: mpsc::UnboundedSender<AgentEvent>,
    next_channel_id: Arc<AtomicU32>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            let (stream, _) = match listener.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    // NOTE(review): any accept error — including transient
                    // ones like EMFILE — permanently stops this acceptor;
                    // confirm that is intended rather than retrying.
                    debug!("agent listener accept error: {e}");
                    break;
                }
            };

            // Reject peers that are not the same uid as this daemon.
            if let Err(e) = crate::security::verify_peer_uid(&stream) {
                warn!("agent socket: {e}");
                continue;
            }

            // Relaxed suffices: the counter only needs uniqueness.
            let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);

            let (read_half, write_half) = stream.into_split();
            let data_tx = event_tx.clone();
            let close_tx = event_tx.clone();
            let writer_tx = crate::spawn_channel_relay(
                channel_id,
                read_half,
                write_half,
                // Returning false tells the relay to stop once the event
                // channel is gone.
                move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
                move |id| {
                    let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
                },
            );

            // NOTE(review): the relay is spawned before `Accepted` is sent on
            // the same channel, so an early `Data { channel_id }` could be
            // enqueued ahead of `Accepted { channel_id }` and be dropped by
            // the session loop (which ignores data for unknown channels) —
            // confirm spawn_channel_relay cannot emit before this send.
            if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
                break;
            }
        }
    })
}
126
127const URL_MAX_LEN: usize = 4096;
129
130fn spawn_open_acceptor(
133 listener: UnixListener,
134 event_tx: mpsc::UnboundedSender<OpenEvent>,
135) -> tokio::task::JoinHandle<()> {
136 tokio::spawn(async move {
137 loop {
138 let (mut stream, _) = match listener.accept().await {
139 Ok(conn) => conn,
140 Err(e) => {
141 debug!("open listener accept error: {e}");
142 break;
143 }
144 };
145
146 match crate::security::verify_peer_uid(&stream) {
151 Ok(()) => {}
152 Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
153 warn!("open socket: {e}");
154 continue;
155 }
156 Err(e) => {
157 debug!("open socket peer_cred unavailable: {e}");
163 }
164 }
165
166 let etx = event_tx.clone();
167 tokio::spawn(async move {
168 let mut buf = vec![0u8; URL_MAX_LEN];
169 let mut total = 0;
170 loop {
171 match stream.read(&mut buf[total..]).await {
172 Ok(0) => break,
173 Ok(n) => {
174 total += n;
175 if buf[..total].contains(&b'\n') || total >= buf.len() {
177 break;
178 }
179 }
180 Err(_) => return,
181 }
182 }
183 let s = String::from_utf8_lossy(&buf[..total]);
184 let url = s.trim();
185 if !url.is_empty() {
186 let _ = etx.send(OpenEvent::Url(url.to_string()));
187 }
188 });
189 }
190 })
191}
192
/// Streams broadcast session output to one tail client until either side ends.
///
/// Sends `Frame::Data` for each broadcast chunk and a final `Frame::Exit`
/// when the shell terminates. Frames arriving *from* the tail client are
/// ignored except `Ping`, which is answered with `Pong` (tails are
/// read-only). A lagged broadcast receiver skips the overwritten chunks and
/// keeps going, so tail output may have gaps under load.
async fn tail_relay(
    mut framed: Framed<UnixStream, FrameCodec>,
    mut rx: broadcast::Receiver<TailEvent>,
) {
    loop {
        tokio::select! {
            event = rx.recv() => match event {
                Ok(TailEvent::Data(chunk)) => {
                    if framed.send(Frame::Data(chunk)).await.is_err() { break; }
                }
                Ok(TailEvent::Exit { code }) => {
                    let _ = framed.send(Frame::Exit { code }).await;
                    break;
                }
                // Lagged: older chunks were overwritten in the broadcast ring.
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                // Closed: the session loop (sender) is gone.
                Err(broadcast::error::RecvError::Closed) => break,
            },
            frame = framed.next() => match frame {
                Some(Ok(Frame::Ping)) => { let _ = framed.send(Frame::Pong).await; }
                // Any other frame, a decode error, or EOF ends the tail.
                _ => break,
            },
        }
    }
}
218
219fn spawn_tail(
221 mut framed: Framed<UnixStream, FrameCodec>,
222 ring_buf: &VecDeque<Bytes>,
223 tail_tx: &broadcast::Sender<TailEvent>,
224) {
225 let rx = tail_tx.subscribe();
226 let chunks: Vec<Bytes> = ring_buf.iter().cloned().collect();
227 tokio::spawn(async move {
228 for chunk in chunks {
229 if framed.send(Frame::Data(chunk)).await.is_err() {
230 return;
231 }
232 }
233 tail_relay(framed, rx).await;
234 });
235}
236
/// Drives one terminal session end-to-end: PTY allocation, shell spawn,
/// client attach/detach cycles, output buffering, and agent/open forwarding.
///
/// Flow as implemented below:
/// 1. Block until the first `ClientConn::Active` arrives; tail clients may
///    attach at any time and only observe output.
/// 2. Optionally receive one leading `Frame::Env` from that client (2 s
///    timeout), whitelist-filter it, and spawn `$SHELL -l` on the PTY slave.
/// 3. Relay bytes between the client framing and the PTY master until either
///    side goes away; while detached, buffer PTY output in a capped ring
///    buffer and replay it (plus a dropped-bytes notice) on reconnect.
/// 4. On shell exit, broadcast/send `Frame::Exit` and remove both forwarding
///    socket paths.
pub async fn run(
    mut client_rx: mpsc::UnboundedReceiver<ClientConn>,
    metadata_slot: Arc<OnceLock<SessionMetadata>>,
    agent_socket_path: PathBuf,
    open_socket_path: PathBuf,
) -> anyhow::Result<()> {
    // --- PTY setup -------------------------------------------------------
    let pty = openpty(None, None)?;
    let master: OwnedFd = pty.master;
    let slave: OwnedFd = pty.slave;

    // Slave device path for metadata; empty string when ttyname fails.
    let pty_path =
        nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();

    // Three independent dups of the slave become the shell's stdio; the
    // original slave fd is dropped so only the child keeps the slave open.
    let slave_fd = slave.as_raw_fd();
    let stdin_fd = crate::security::checked_dup(slave_fd)?;
    let stdout_fd = crate::security::checked_dup(slave_fd)?;
    let stderr_fd = crate::security::checked_dup(slave_fd)?;
    let raw_stdin = stdin_fd.as_raw_fd();
    drop(slave);

    // The master must be O_NONBLOCK for AsyncFd readiness-based I/O.
    let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
    let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
    oflags |= nix::fcntl::OFlag::O_NONBLOCK;
    nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;

    let async_master = AsyncFd::new(master)?;
    let mut buf = vec![0u8; 4096];
    // Output captured while no active client is attached. `ring_buf_size`
    // tracks the byte total of queued chunks; `ring_buf_dropped` counts
    // bytes evicted past the cap (reported to the client on replay).
    let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
    let mut ring_buf_size: usize = 0;
    let mut ring_buf_dropped: usize = 0;
    const RING_BUF_CAP: usize = 1 << 20;
    let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();

    // Feed for tail clients; slow receivers lag (skip chunks) rather than
    // exert backpressure on the session.
    let (tail_tx, _) = broadcast::channel::<TailEvent>(256);

    // --- Wait for the first interactive client before spawning the shell --
    let mut framed = loop {
        match client_rx.recv().await {
            Some(ClientConn::Active(f)) => {
                info!("first client connected via channel");
                break f;
            }
            Some(ClientConn::Tail(f)) => {
                // Ring buffer is still empty; the tail just streams live.
                info!("tail client connected before shell spawn");
                spawn_tail(f, &ring_buf, &tail_tx);
                continue;
            }
            None => {
                info!("client channel closed before first client");
                cleanup_socket(&agent_socket_path);
                return Ok(());
            }
        }
    };

    // The client may lead with an Env frame; anything else — or a 2 s
    // timeout — means no environment is forwarded.
    let env_vars =
        match tokio::time::timeout(std::time::Duration::from_secs(2), framed.next()).await {
            Ok(Some(Ok(Frame::Env { vars }))) => {
                debug!(count = vars.len(), "received env vars from client");
                vars
            }
            _ => Vec::new(),
        };

    // --- Spawn `$SHELL -l` in $HOME with a filtered environment -----------
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
    let home = std::env::var("HOME").ok();
    let mut cmd = Command::new(&shell);
    cmd.arg("-l");
    if let Some(ref dir) = home {
        cmd.current_dir(dir);
    }
    // Only these client-supplied vars are trusted; everything else is logged
    // and dropped.
    const ALLOWED_ENV_KEYS: &[&str] = &["TERM", "LANG", "COLORTERM", "BROWSER"];
    for (k, v) in &env_vars {
        if ALLOWED_ENV_KEYS.contains(&k.as_str()) {
            cmd.env(k, v);
        } else {
            warn!(key = k, "ignoring disallowed env var from client");
        }
    }
    cmd.env("SSH_AUTH_SOCK", &agent_socket_path);
    cmd.env("GRITTY_OPEN_SOCK", &open_socket_path);
    // pre_exec runs in the forked child before exec: setsid makes it a
    // session leader, TIOCSCTTY adopts the PTY slave as controlling terminal.
    // NOTE(review): the ioctl result is ignored — consider checking it.
    let mut managed = ManagedChild::new(unsafe {
        cmd.pre_exec(move || {
            nix::unistd::setsid().map_err(io::Error::other)?;
            libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
            Ok(())
        })
        .stdin(Stdio::from(stdin_fd))
        .stdout(Stdio::from(stdout_fd))
        .stderr(Stdio::from(stderr_fd))
        .spawn()?
    });

    let shell_pid = managed.child.id().unwrap_or(0);
    let created_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    // Publish metadata; a OnceLock ignores a second set, so the `_` is fine.
    let _ = metadata_slot.set(SessionMetadata {
        pty_path,
        shell_pid,
        created_at,
        attached: AtomicBool::new(false),
        last_heartbeat: AtomicU64::new(0),
    });

    metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);

    // --- Forwarding state --------------------------------------------------
    // Torn down whenever the active client changes or disconnects; each new
    // client must re-enable forwarding explicitly.
    let mut agent_forward_enabled = false;
    let mut agent_channels: HashMap<u32, mpsc::UnboundedSender<Bytes>> = HashMap::new();
    let mut agent_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let next_agent_channel_id = Arc::new(AtomicU32::new(0));

    let mut open_forward_enabled = false;
    let mut open_acceptor: Option<tokio::task::JoinHandle<()>> = None;
    let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();

    // Shared teardown closure, used both on detach and on client takeover.
    let teardown_forwarding =
        |agent_channels: &mut HashMap<u32, mpsc::UnboundedSender<Bytes>>,
         agent_forward_enabled: &mut bool,
         agent_acceptor: &mut Option<tokio::task::JoinHandle<()>>,
         open_forward_enabled: &mut bool,
         open_acceptor: &mut Option<tokio::task::JoinHandle<()>>| {
            agent_channels.clear();
            *agent_forward_enabled = false;
            if let Some(handle) = agent_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&agent_socket_path);
            *open_forward_enabled = false;
            if let Some(handle) = open_acceptor.take() {
                handle.abort();
            }
            cleanup_socket(&open_socket_path);
        };

    // --- Outer loop: one iteration per attached-client episode -------------
    let mut first_client = true;
    loop {
        if !first_client {
            // Detached: keep draining the PTY into the ring buffer while
            // waiting for the next active client (or shell exit).
            let got_client = 'drain: loop {
                tokio::select! {
                    client = client_rx.recv() => {
                        match client {
                            Some(ClientConn::Active(f)) => {
                                info!("client connected via channel");
                                framed = f;
                                break 'drain true;
                            }
                            Some(ClientConn::Tail(f)) => {
                                info!("tail client connected while disconnected");
                                spawn_tail(f, &ring_buf, &tail_tx);
                                continue;
                            }
                            None => {
                                info!("client channel closed");
                                break 'drain false;
                            }
                        }
                    }
                    status = managed.child.wait() => {
                        let code = status?.code().unwrap_or(1);
                        info!(code, "shell exited while awaiting client");
                        break 'drain false;
                    }
                    ready = async_master.readable() => {
                        let mut guard = ready?;
                        match guard.try_io(|inner| {
                            nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                        }) {
                            Ok(Ok(0)) => {
                                debug!("pty EOF while disconnected");
                                break 'drain false;
                            }
                            Ok(Ok(n)) => {
                                // Mirror to tails and queue for replay,
                                // evicting the oldest chunks past the cap.
                                let chunk = Bytes::copy_from_slice(&buf[..n]);
                                let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
                                ring_buf_size += chunk.len();
                                ring_buf.push_back(chunk);
                                while ring_buf_size > RING_BUF_CAP {
                                    if let Some(old) = ring_buf.pop_front() {
                                        ring_buf_size -= old.len();
                                        ring_buf_dropped += old.len();
                                    }
                                }
                            }
                            Ok(Err(e)) => {
                                // EIO from a PTY master means the slave side
                                // (the shell) is gone.
                                if e.raw_os_error() == Some(libc::EIO) {
                                    debug!("pty EIO while disconnected");
                                    break 'drain false;
                                }
                                return Err(e.into());
                            }
                            Err(_would_block) => continue,
                        }
                    }
                }
            };
            if !got_client {
                break;
            }

            if let Some(meta) = metadata_slot.get() {
                meta.attached.store(true, Ordering::Relaxed);
            }
        }
        first_client = false;

        // Replay buffered output to the newly attached client, preceded by a
        // notice if the ring buffer overflowed while detached.
        if !ring_buf.is_empty() {
            debug!(
                chunks = ring_buf.len(),
                bytes = ring_buf_size,
                dropped = ring_buf_dropped,
                "flushing ring buffer"
            );
            if ring_buf_dropped > 0 {
                let msg = format!("\r\n[gritty: {} bytes of output dropped]\r\n", ring_buf_dropped);
                framed.send(Frame::Data(Bytes::from(msg))).await?;
                ring_buf_dropped = 0;
            }
            while let Some(chunk) = ring_buf.pop_front() {
                framed.send(Frame::Data(chunk)).await?;
            }
            ring_buf_size = 0;
        }

        // --- Inner relay loop: active client <-> PTY plus forwarding -------
        let exit = loop {
            tokio::select! {
                frame = framed.next() => {
                    match frame {
                        Some(Ok(Frame::Data(data))) => {
                            debug!(len = data.len(), "socket -> pty");
                            // Write the whole chunk, re-awaiting writability
                            // across partial writes.
                            let mut written = 0;
                            while written < data.len() {
                                let mut guard = async_master.writable().await?;
                                match guard.try_io(|inner| {
                                    nix::unistd::write(inner, &data[written..]).map_err(io::Error::from)
                                }) {
                                    Ok(Ok(n)) => written += n,
                                    Ok(Err(e)) => return Err(e.into()),
                                    Err(_would_block) => continue,
                                }
                            }
                        }
                        Some(Ok(Frame::Resize { cols, rows })) => {
                            let (cols, rows) = crate::security::clamp_winsize(cols, rows);
                            debug!(cols, rows, "resize pty");
                            let ws = libc::winsize {
                                ws_row: rows,
                                ws_col: cols,
                                ws_xpixel: 0,
                                ws_ypixel: 0,
                            };
                            unsafe {
                                libc::ioctl(
                                    async_master.as_raw_fd(),
                                    libc::TIOCSWINSZ,
                                    &ws as *const _,
                                );
                            }
                            // Nudge the foreground process group so
                            // full-screen apps repaint at the new size.
                            if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
                                let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
                            }
                        }
                        Some(Ok(Frame::Ping)) => {
                            // Heartbeat: record the time, answer with Pong.
                            if let Some(meta) = metadata_slot.get() {
                                let now = std::time::SystemTime::now()
                                    .duration_since(std::time::UNIX_EPOCH)
                                    .unwrap_or_default()
                                    .as_secs();
                                meta.last_heartbeat.store(now, Ordering::Relaxed);
                            }
                            let _ = framed.send(Frame::Pong).await;
                        }
                        Some(Ok(Frame::AgentForward)) => {
                            debug!("agent forwarding enabled by client");
                            agent_forward_enabled = true;
                            // Lazily bind the listener the first time only.
                            if agent_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&agent_socket_path) {
                                    agent_acceptor = Some(spawn_agent_acceptor(listener, agent_event_tx.clone(), next_agent_channel_id.clone()));
                                }
                            }
                        }
                        Some(Ok(Frame::AgentData { channel_id, data })) => {
                            if let Some(tx) = agent_channels.get(&channel_id) {
                                let _ = tx.send(data);
                            }
                        }
                        Some(Ok(Frame::AgentClose { channel_id })) => {
                            // Dropping the sender ends the relay's writer side.
                            agent_channels.remove(&channel_id);
                        }
                        Some(Ok(Frame::OpenForward)) => {
                            debug!("open forwarding enabled by client");
                            open_forward_enabled = true;
                            if open_acceptor.is_none() {
                                if let Some(listener) = bind_agent_listener(&open_socket_path) {
                                    open_acceptor = Some(spawn_open_acceptor(listener, open_event_tx.clone()));
                                }
                            }
                        }
                        Some(Ok(Frame::Exit { .. })) | None => {
                            break RelayExit::ClientGone;
                        }
                        Some(Ok(_)) => {}
                        Some(Err(e)) => return Err(e.into()),
                    }
                }

                ready = async_master.readable() => {
                    let mut guard = ready?;
                    match guard.try_io(|inner| {
                        nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
                    }) {
                        Ok(Ok(0)) => {
                            debug!("pty EOF");
                            break RelayExit::ShellExited(0);
                        }
                        Ok(Ok(n)) => {
                            debug!(len = n, "pty -> socket");
                            let chunk = Bytes::copy_from_slice(&buf[..n]);
                            let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
                            framed.send(Frame::Data(chunk)).await?;
                        }
                        Ok(Err(e)) => {
                            // EIO from a PTY master: the shell side is gone.
                            if e.raw_os_error() == Some(libc::EIO) {
                                debug!("pty EIO (shell exited)");
                                break RelayExit::ShellExited(0);
                            }
                            return Err(e.into());
                        }
                        Err(_would_block) => continue,
                    }
                }

                new_client = client_rx.recv() => {
                    match new_client {
                        Some(ClientConn::Active(new_framed)) => {
                            // Takeover: tell the old client it was detached
                            // and reset all forwarding state for the newcomer.
                            info!("new client via channel, detaching old client");
                            let _ = framed.send(Frame::Detached).await;
                            teardown_forwarding(
                                &mut agent_channels,
                                &mut agent_forward_enabled,
                                &mut agent_acceptor,
                                &mut open_forward_enabled,
                                &mut open_acceptor,
                            );
                            framed = new_framed;
                        }
                        Some(ClientConn::Tail(f)) => {
                            info!("tail client connected while active");
                            spawn_tail(f, &ring_buf, &tail_tx);
                        }
                        // NOTE(review): once client_rx is closed, recv()
                        // resolves to None immediately on every select pass,
                        // which busy-spins this loop — confirm the sender
                        // side outlives the session.
                        None => {}
                    }
                }

                event = agent_event_rx.recv() => {
                    match event {
                        Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
                            if agent_forward_enabled {
                                agent_channels.insert(channel_id, writer_tx);
                                let _ = framed.send(Frame::AgentOpen { channel_id }).await;
                            }
                        }
                        Some(AgentEvent::Data { channel_id, data }) => {
                            // Only forward on channels the client was told
                            // about; data for unknown channels is dropped.
                            if agent_forward_enabled && agent_channels.contains_key(&channel_id) {
                                let _ = framed.send(Frame::AgentData { channel_id, data }).await;
                            }
                        }
                        Some(AgentEvent::Closed { channel_id }) => {
                            if agent_channels.remove(&channel_id).is_some() {
                                let _ = framed.send(Frame::AgentClose { channel_id }).await;
                            }
                        }
                        None => {
                            debug!("agent event channel closed");
                        }
                    }
                }

                event = open_event_rx.recv() => {
                    match event {
                        Some(OpenEvent::Url(url)) => {
                            if open_forward_enabled {
                                let _ = framed.send(Frame::OpenUrl { url }).await;
                            }
                        }
                        None => {
                            debug!("open event channel closed");
                        }
                    }
                }

                status = managed.child.wait() => {
                    let code = status?.code().unwrap_or(1);
                    info!(code, "shell exited");
                    break RelayExit::ShellExited(code);
                }
            }
        };

        match exit {
            RelayExit::ClientGone => {
                if let Some(meta) = metadata_slot.get() {
                    meta.attached.store(false, Ordering::Relaxed);
                }
                teardown_forwarding(
                    &mut agent_channels,
                    &mut agent_forward_enabled,
                    &mut agent_acceptor,
                    &mut open_forward_enabled,
                    &mut open_acceptor,
                );
                info!("client disconnected, waiting for reconnect");
                continue;
            }
            RelayExit::ShellExited(mut code) => {
                // PTY EOF/EIO can race the child's actual exit status; give
                // wait() a short window to supply the real exit code.
                if let Ok(Ok(status)) = tokio::time::timeout(
                    std::time::Duration::from_millis(500),
                    managed.child.wait(),
                )
                .await
                {
                    code = status.code().unwrap_or(code);
                }
                let _ = tail_tx.send(TailEvent::Exit { code });
                let _ = framed.send(Frame::Exit { code }).await;
                info!(code, "session ended");
                break;
            }
        }
    }

    // Final cleanup of both forwarding socket paths.
    cleanup_socket(&agent_socket_path);
    cleanup_socket(&open_socket_path);
    Ok(())
}
706
707fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
708 match crate::security::bind_unix_listener(path) {
709 Ok(listener) => {
710 info!(path = %path.display(), "agent socket listening");
711 Some(listener)
712 }
713 Err(e) => {
714 warn!("failed to bind agent socket at {}: {e}", path.display());
715 None
716 }
717 }
718}
719
/// Best-effort removal of a Unix-socket path.
///
/// Errors are ignored on purpose: the socket may never have been bound, or
/// may already have been removed by an earlier teardown.
fn cleanup_socket(path: &Path) {
    std::fs::remove_file(path).ok();
}