1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::collections::{HashMap, VecDeque};
6use std::io;
7use std::os::fd::{AsRawFd, OwnedFd};
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
11use std::sync::{Arc, OnceLock};
12use tokio::io::AsyncReadExt;
13use tokio::io::unix::AsyncFd;
14use tokio::net::{UnixListener, UnixStream};
15use tokio::process::Command;
16use tokio::sync::mpsc;
17use tokio_util::codec::Framed;
18use tracing::{debug, info, warn};
19
/// Descriptive information about a running session, published once its
/// shell has been spawned.
pub struct SessionMetadata {
    /// Process id of the shell (0 if it was not available at spawn time).
    pub shell_pid: u32,
    /// Path of the PTY slave device; empty if it could not be resolved.
    pub pty_path: String,
    /// Session creation time, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Unix time (seconds) of the most recent client heartbeat; 0 until one
    /// has been received.
    pub last_heartbeat: AtomicU64,
    /// Whether a client is currently attached to the session.
    pub attached: AtomicBool,
}
27
28struct ManagedChild {
31 child: tokio::process::Child,
32 pgid: nix::unistd::Pid,
33}
34
35impl ManagedChild {
36 fn new(child: tokio::process::Child) -> Self {
37 let pid = child.id().expect("child should have pid") as i32;
38 Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
39 }
40}
41
42impl Drop for ManagedChild {
43 fn drop(&mut self) {
44 let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
45 let _ = self.child.try_wait();
46 }
47}
48
/// Why the relay loop for the currently attached client stopped.
enum RelayExit {
    /// The client detached or disconnected; the shell keeps running.
    ClientGone,
    /// The shell side ended; carries the exit code to report.
    ShellExited(i32),
}
56
57enum AgentEvent {
59 Accepted { channel_id: u32, writer_tx: mpsc::UnboundedSender<Bytes> },
60 Data { channel_id: u32, data: Bytes },
61 Closed { channel_id: u32 },
62}
63
/// Events produced by the open-URL socket acceptor.
enum OpenEvent {
    /// A URL received from a local process (first line of its request,
    /// trimmed and non-empty).
    Url(String),
}
68
69fn spawn_agent_acceptor(
72 listener: UnixListener,
73 event_tx: mpsc::UnboundedSender<AgentEvent>,
74 next_channel_id: Arc<AtomicU32>,
75) -> tokio::task::JoinHandle<()> {
76 tokio::spawn(async move {
77 loop {
78 let (stream, _) = match listener.accept().await {
79 Ok(conn) => conn,
80 Err(e) => {
81 debug!("agent listener accept error: {e}");
82 break;
83 }
84 };
85
86 let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
87
88 let (read_half, write_half) = stream.into_split();
89 let data_tx = event_tx.clone();
90 let close_tx = event_tx.clone();
91 let writer_tx = crate::spawn_channel_relay(
92 channel_id,
93 read_half,
94 write_half,
95 move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
96 move |id| {
97 let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
98 },
99 );
100
101 if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
103 break; }
105 }
106 })
107}
108
109fn spawn_open_acceptor(
112 listener: UnixListener,
113 event_tx: mpsc::UnboundedSender<OpenEvent>,
114) -> tokio::task::JoinHandle<()> {
115 tokio::spawn(async move {
116 loop {
117 let (mut stream, _) = match listener.accept().await {
118 Ok(conn) => conn,
119 Err(e) => {
120 debug!("open listener accept error: {e}");
121 break;
122 }
123 };
124
125 let etx = event_tx.clone();
126 tokio::spawn(async move {
127 let mut buf = vec![0u8; 8192];
128 let mut total = 0;
129 loop {
130 match stream.read(&mut buf[total..]).await {
131 Ok(0) => break,
132 Ok(n) => {
133 total += n;
134 if buf[..total].contains(&b'\n') || total >= buf.len() {
136 break;
137 }
138 }
139 Err(_) => return,
140 }
141 }
142 let s = String::from_utf8_lossy(&buf[..total]);
143 let url = s.trim();
144 if !url.is_empty() {
145 let _ = etx.send(OpenEvent::Url(url.to_string()));
146 }
147 });
148 }
149 })
150}
151
152pub async fn run(
153 mut client_rx: mpsc::UnboundedReceiver<Framed<UnixStream, FrameCodec>>,
154 metadata_slot: Arc<OnceLock<SessionMetadata>>,
155 agent_socket_path: PathBuf,
156 open_socket_path: PathBuf,
157) -> anyhow::Result<()> {
158 let pty = openpty(None, None)?;
160 let master: OwnedFd = pty.master;
161 let slave: OwnedFd = pty.slave;
162
163 let pty_path =
165 nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();
166
167 let slave_fd = slave.as_raw_fd();
169 let stdin_fd = crate::security::checked_dup(slave_fd)?;
170 let stdout_fd = crate::security::checked_dup(slave_fd)?;
171 let stderr_fd = crate::security::checked_dup(slave_fd)?;
172 let raw_stdin = stdin_fd.as_raw_fd();
173 drop(slave);
174
175 let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
177 let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
178 oflags |= nix::fcntl::OFlag::O_NONBLOCK;
179 nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;
180
181 let async_master = AsyncFd::new(master)?;
182 let mut buf = vec![0u8; 4096];
183 let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
184 let mut ring_buf_size: usize = 0;
185 const RING_BUF_CAP: usize = 1 << 20; let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();
189
190 let mut framed = match client_rx.recv().await {
192 Some(framed) => {
193 info!("first client connected via channel");
194 framed
195 }
196 None => {
197 info!("client channel closed before first client");
198 cleanup_socket(&agent_socket_path);
199 return Ok(());
200 }
201 };
202
203 let env_vars =
205 match tokio::time::timeout(std::time::Duration::from_millis(100), framed.next()).await {
206 Ok(Some(Ok(Frame::Env { vars }))) => {
207 debug!(count = vars.len(), "received env vars from client");
208 vars
209 }
210 _ => Vec::new(),
211 };
212
213 let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
215 let home = std::env::var("HOME").ok();
216 let mut cmd = Command::new(&shell);
217 cmd.arg("-l");
218 if let Some(ref dir) = home {
219 cmd.current_dir(dir);
220 }
221 for (k, v) in &env_vars {
222 cmd.env(k, v);
223 }
224 cmd.env("SSH_AUTH_SOCK", &agent_socket_path);
226 cmd.env("GRITTY_OPEN_SOCK", &open_socket_path);
228 let mut managed = ManagedChild::new(unsafe {
229 cmd.pre_exec(move || {
230 nix::unistd::setsid().map_err(io::Error::other)?;
231 libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
232 Ok(())
233 })
234 .stdin(Stdio::from(stdin_fd))
235 .stdout(Stdio::from(stdout_fd))
236 .stderr(Stdio::from(stderr_fd))
237 .spawn()?
238 });
239
240 let shell_pid = managed.child.id().unwrap_or(0);
241 let created_at = std::time::SystemTime::now()
242 .duration_since(std::time::UNIX_EPOCH)
243 .unwrap_or_default()
244 .as_secs();
245
246 let _ = metadata_slot.set(SessionMetadata {
247 pty_path,
248 shell_pid,
249 created_at,
250 attached: AtomicBool::new(false),
251 last_heartbeat: AtomicU64::new(0),
252 });
253
254 metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);
256
257 let mut agent_forward_enabled = false;
259 let mut agent_channels: HashMap<u32, mpsc::UnboundedSender<Bytes>> = HashMap::new();
260 let mut agent_acceptor: Option<tokio::task::JoinHandle<()>> = None;
261 let next_agent_channel_id = Arc::new(AtomicU32::new(0));
262
263 let mut open_forward_enabled = false;
265 let mut open_acceptor: Option<tokio::task::JoinHandle<()>> = None;
266 let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();
267
268 let teardown_forwarding =
269 |agent_channels: &mut HashMap<u32, mpsc::UnboundedSender<Bytes>>,
270 agent_forward_enabled: &mut bool,
271 agent_acceptor: &mut Option<tokio::task::JoinHandle<()>>,
272 open_forward_enabled: &mut bool,
273 open_acceptor: &mut Option<tokio::task::JoinHandle<()>>| {
274 agent_channels.clear();
275 *agent_forward_enabled = false;
276 if let Some(handle) = agent_acceptor.take() {
277 handle.abort();
278 }
279 cleanup_socket(&agent_socket_path);
280 *open_forward_enabled = false;
281 if let Some(handle) = open_acceptor.take() {
282 handle.abort();
283 }
284 cleanup_socket(&open_socket_path);
285 };
286
287 let mut first_client = true;
290 loop {
291 if !first_client {
292 let got_client = 'drain: loop {
293 tokio::select! {
294 client = client_rx.recv() => {
295 match client {
296 Some(f) => {
297 info!("client connected via channel");
298 framed = f;
299 break 'drain true;
300 }
301 None => {
302 info!("client channel closed");
303 break 'drain false;
304 }
305 }
306 }
307 status = managed.child.wait() => {
308 let code = status?.code().unwrap_or(1);
309 info!(code, "shell exited while awaiting client");
310 break 'drain false;
311 }
312 ready = async_master.readable() => {
313 let mut guard = ready?;
314 match guard.try_io(|inner| {
315 nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
316 }) {
317 Ok(Ok(0)) => {
318 debug!("pty EOF while disconnected");
319 break 'drain false;
320 }
321 Ok(Ok(n)) => {
322 let chunk = Bytes::copy_from_slice(&buf[..n]);
323 ring_buf_size += chunk.len();
324 ring_buf.push_back(chunk);
325 while ring_buf_size > RING_BUF_CAP {
326 if let Some(old) = ring_buf.pop_front() {
327 ring_buf_size -= old.len();
328 }
329 }
330 }
331 Ok(Err(e)) => {
332 if e.raw_os_error() == Some(libc::EIO) {
333 debug!("pty EIO while disconnected");
334 break 'drain false;
335 }
336 return Err(e.into());
337 }
338 Err(_would_block) => continue,
339 }
340 }
341 }
342 };
343 if !got_client {
344 break;
345 }
346
347 if let Some(meta) = metadata_slot.get() {
348 meta.attached.store(true, Ordering::Relaxed);
349 }
350 }
351 first_client = false;
352
353 if !ring_buf.is_empty() {
355 debug!(chunks = ring_buf.len(), bytes = ring_buf_size, "flushing ring buffer");
356 while let Some(chunk) = ring_buf.pop_front() {
357 framed.send(Frame::Data(chunk)).await?;
358 }
359 ring_buf_size = 0;
360 }
361
362 let exit = loop {
364 tokio::select! {
365 frame = framed.next() => {
366 match frame {
367 Some(Ok(Frame::Data(data))) => {
368 debug!(len = data.len(), "socket -> pty");
369 let mut guard = async_master.writable().await?;
370 match guard.try_io(|inner| {
371 nix::unistd::write(inner, &data).map_err(io::Error::from)
372 }) {
373 Ok(Ok(_)) => {}
374 Ok(Err(e)) => return Err(e.into()),
375 Err(_would_block) => continue,
376 }
377 }
378 Some(Ok(Frame::Resize { cols, rows })) => {
379 let (cols, rows) = crate::security::clamp_winsize(cols, rows);
380 debug!(cols, rows, "resize pty");
381 let ws = libc::winsize {
382 ws_row: rows,
383 ws_col: cols,
384 ws_xpixel: 0,
385 ws_ypixel: 0,
386 };
387 unsafe {
388 libc::ioctl(
389 async_master.as_raw_fd(),
390 libc::TIOCSWINSZ,
391 &ws as *const _,
392 );
393 }
394 if let Ok(pgid) = nix::unistd::tcgetpgrp(&async_master) {
395 let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
396 }
397 }
398 Some(Ok(Frame::Ping)) => {
399 if let Some(meta) = metadata_slot.get() {
400 let now = std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .unwrap_or_default()
403 .as_secs();
404 meta.last_heartbeat.store(now, Ordering::Relaxed);
405 }
406 let _ = framed.send(Frame::Pong).await;
407 }
408 Some(Ok(Frame::AgentForward)) => {
409 debug!("agent forwarding enabled by client");
410 agent_forward_enabled = true;
411 if agent_acceptor.is_none() {
413 if let Some(listener) = bind_agent_listener(&agent_socket_path) {
414 agent_acceptor = Some(spawn_agent_acceptor(listener, agent_event_tx.clone(), next_agent_channel_id.clone()));
415 }
416 }
417 }
418 Some(Ok(Frame::AgentData { channel_id, data })) => {
419 if let Some(tx) = agent_channels.get(&channel_id) {
420 let _ = tx.send(data);
421 }
422 }
423 Some(Ok(Frame::AgentClose { channel_id })) => {
424 agent_channels.remove(&channel_id);
426 }
427 Some(Ok(Frame::OpenForward)) => {
428 debug!("open forwarding enabled by client");
429 open_forward_enabled = true;
430 if open_acceptor.is_none() {
431 if let Some(listener) = bind_agent_listener(&open_socket_path) {
432 open_acceptor = Some(spawn_open_acceptor(listener, open_event_tx.clone()));
433 }
434 }
435 }
436 Some(Ok(Frame::Exit { .. })) | None => {
438 break RelayExit::ClientGone;
439 }
440 Some(Ok(_)) => {}
442 Some(Err(e)) => return Err(e.into()),
443 }
444 }
445
446 ready = async_master.readable() => {
447 let mut guard = ready?;
448 match guard.try_io(|inner| {
449 nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
450 }) {
451 Ok(Ok(0)) => {
452 debug!("pty EOF");
453 break RelayExit::ShellExited(0);
454 }
455 Ok(Ok(n)) => {
456 debug!(len = n, "pty -> socket");
457 framed.send(Frame::Data(Bytes::copy_from_slice(&buf[..n]))).await?;
458 }
459 Ok(Err(e)) => {
460 if e.raw_os_error() == Some(libc::EIO) {
461 debug!("pty EIO (shell exited)");
462 break RelayExit::ShellExited(0);
463 }
464 return Err(e.into());
465 }
466 Err(_would_block) => continue,
467 }
468 }
469
470 new_client = client_rx.recv() => {
472 if let Some(new_framed) = new_client {
473 info!("new client via channel, detaching old client");
474 let _ = framed.send(Frame::Detached).await;
475 teardown_forwarding(
476 &mut agent_channels,
477 &mut agent_forward_enabled,
478 &mut agent_acceptor,
479 &mut open_forward_enabled,
480 &mut open_acceptor,
481 );
482 framed = new_framed;
483 }
484 }
485
486 event = agent_event_rx.recv() => {
488 match event {
489 Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
490 if agent_forward_enabled {
491 agent_channels.insert(channel_id, writer_tx);
492 let _ = framed.send(Frame::AgentOpen { channel_id }).await;
493 }
494 }
496 Some(AgentEvent::Data { channel_id, data }) => {
497 if agent_forward_enabled && agent_channels.contains_key(&channel_id) {
498 let _ = framed.send(Frame::AgentData { channel_id, data }).await;
499 }
500 }
501 Some(AgentEvent::Closed { channel_id }) => {
502 if agent_channels.remove(&channel_id).is_some() {
503 let _ = framed.send(Frame::AgentClose { channel_id }).await;
504 }
505 }
506 None => {
507 debug!("agent event channel closed");
509 }
510 }
511 }
512
513 event = open_event_rx.recv() => {
515 match event {
516 Some(OpenEvent::Url(url)) => {
517 if open_forward_enabled {
518 let _ = framed.send(Frame::OpenUrl { url }).await;
519 }
520 }
521 None => {
522 debug!("open event channel closed");
523 }
524 }
525 }
526
527 status = managed.child.wait() => {
528 let code = status?.code().unwrap_or(1);
529 info!(code, "shell exited");
530 break RelayExit::ShellExited(code);
531 }
532 }
533 };
534
535 match exit {
536 RelayExit::ClientGone => {
537 if let Some(meta) = metadata_slot.get() {
538 meta.attached.store(false, Ordering::Relaxed);
539 }
540 teardown_forwarding(
541 &mut agent_channels,
542 &mut agent_forward_enabled,
543 &mut agent_acceptor,
544 &mut open_forward_enabled,
545 &mut open_acceptor,
546 );
547 info!("client disconnected, waiting for reconnect");
548 continue;
549 }
550 RelayExit::ShellExited(mut code) => {
551 if let Ok(Some(status)) = managed.child.try_wait() {
554 code = status.code().unwrap_or(code);
555 }
556 let _ = framed.send(Frame::Exit { code }).await;
557 info!(code, "session ended");
558 break;
559 }
560 }
561 }
562
563 cleanup_socket(&agent_socket_path);
564 cleanup_socket(&open_socket_path);
565 Ok(())
566}
567
568fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
569 match crate::security::bind_unix_listener(path) {
570 Ok(listener) => {
571 info!(path = %path.display(), "agent socket listening");
572 Some(listener)
573 }
574 Err(e) => {
575 warn!("failed to bind agent socket at {}: {e}", path.display());
576 None
577 }
578 }
579}
580
/// Best-effort removal of a socket file. Any error, such as the file
/// already being gone, is ignored.
fn cleanup_socket(path: &Path) {
    std::fs::remove_file(path).ok();
}