1use crate::protocol::{Frame, FrameCodec};
2use bytes::Bytes;
3use futures_util::{SinkExt, StreamExt};
4use nix::pty::openpty;
5use std::collections::{HashMap, VecDeque};
6use std::io;
7use std::ops::ControlFlow;
8use std::os::fd::{AsRawFd, OwnedFd};
9use std::path::{Path, PathBuf};
10use std::process::Stdio;
11use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
12use std::sync::{Arc, OnceLock};
13use std::time::Duration;
14use tokio::io::AsyncReadExt;
15use tokio::io::unix::AsyncFd;
16use tokio::net::{UnixListener, UnixStream};
17use tokio::process::Command;
18use tokio::sync::{broadcast, mpsc};
19use tokio_util::codec::Framed;
20use tracing::{debug, info, warn};
21
/// A client connection handed to the session task through its `client_rx` channel.
pub enum ClientConn {
    /// Interactive client; the first one received by `run` becomes the framed
    /// stream the session serves.
    Active(Framed<UnixStream, FrameCodec>),
    /// Follower client; `run` passes it to `spawn_tail`, which replays the ring
    /// buffer and then streams broadcast output.
    Tail(Framed<UnixStream, FrameCodec>),
    /// Raw stream that `run` passes to `handle_send_stream`.
    Send(UnixStream),
}
29
/// Event delivered to tail clients over the broadcast channel.
#[derive(Clone)]
enum TailEvent {
    /// Output chunk; `tail_relay` forwards it as `Frame::Data`.
    Data(Bytes),
    /// Forwarded as `Frame::Exit`, after which `tail_relay` stops.
    Exit { code: i32 },
}
36
/// Per-session metadata shared through an `Arc<OnceLock<SessionMetadata>>` slot.
pub struct SessionMetadata {
    /// NOTE(review): presumably the slave PTY device path (`run` computes one
    /// with `ttyname`) — confirm where this is populated.
    pub pty_path: String,
    /// NOTE(review): presumably the PID of the spawned shell — confirm.
    pub shell_pid: u32,
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm the unit.
    pub created_at: u64,
    /// NOTE(review): presumably whether an interactive client is attached — confirm.
    pub attached: AtomicBool,
    /// Unix time in seconds of the most recent `Frame::Ping` handled by
    /// `ServerRelay::handle_client_frame`.
    pub last_heartbeat: AtomicU64,
}
44
45struct ManagedChild {
48 child: tokio::process::Child,
49 pgid: nix::unistd::Pid,
50}
51
52impl ManagedChild {
53 fn new(child: tokio::process::Child) -> Self {
54 let pid = child.id().expect("child should have pid") as i32;
55 Self { child, pgid: nix::unistd::Pid::from_raw(pid) }
56 }
57}
58
59impl Drop for ManagedChild {
60 fn drop(&mut self) {
61 let _ = nix::sys::signal::killpg(self.pgid, nix::sys::signal::Signal::SIGHUP);
62 let _ = self.child.try_wait();
63 }
64}
65
/// Reason the client relay stopped.
enum RelayExit {
    /// The client sent `Frame::Exit` or its stream ended
    /// (see `ServerRelay::handle_client_frame`).
    ClientGone,
    /// NOTE(review): presumably carries the shell's exit code — the producer is
    /// not visible in this part of the file; confirm.
    ShellExited(i32),
}
73
/// Events emitted by the agent-socket acceptor and its per-connection relays,
/// consumed by `ServerRelay::handle_agent_event`.
enum AgentEvent {
    /// A connection was accepted; `writer_tx` is the sender returned by
    /// `spawn_channel_relay` for that connection.
    Accepted { channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    /// Bytes reported by the connection's relay; forwarded as `Frame::AgentData`.
    Data { channel_id: u32, data: Bytes },
    /// The connection's relay reported closure; forwarded as `Frame::AgentClose`.
    Closed { channel_id: u32 },
}
80
/// Events produced by `OpenUrl` requests on the service socket.
enum OpenEvent {
    /// Trimmed, non-empty URL text read from the request.
    Url(String),
}
85
/// Events for tunnel channels, consumed by `ServerRelay::handle_tunnel_event`.
enum TunnelEvent {
    /// The TCP connect to `127.0.0.1:<tunnel port>` started by `Frame::TunnelOpen` succeeded.
    Connected { channel_id: u32, stream: tokio::net::TcpStream },
    /// That connect failed; the client is answered with `Frame::TunnelClose`.
    ConnectFailed { channel_id: u32 },
    /// Bytes read from the local TCP stream; forwarded as `Frame::TunnelData`.
    Data { channel_id: u32, data: Bytes },
    /// The local TCP stream hit EOF or a read error.
    Closed { channel_id: u32 },
}
93
/// Events for port forwards, consumed by `ServerRelay::handle_pf_event`.
enum PortForwardEvent {
    /// A `PortForward` request arrived on the service socket. `direction == 0`
    /// makes the server bind `127.0.0.1:listen_port` itself; any other value
    /// makes it ask the client to listen via `Frame::PortForwardListen`.
    Requested { stream: UnixStream, direction: u8, listen_port: u16, target_port: u16 },
    /// The server-side TCP listener of `forward_id` accepted a connection.
    Accepted { forward_id: u32, channel_id: u32, writer_tx: mpsc::Sender<Bytes> },
    /// The connect to `127.0.0.1:<target_port>` started by `Frame::PortForwardOpen` succeeded.
    Connected { forward_id: u32, channel_id: u32, stream: tokio::net::TcpStream },
    /// That connect failed; the client is answered with `Frame::PortForwardClose`.
    ConnectFailed { channel_id: u32 },
    /// Bytes reported by a channel relay; forwarded as `Frame::PortForwardData`.
    Data { channel_id: u32, data: Bytes },
    /// A channel relay reported closure.
    Closed { channel_id: u32 },
    /// The service-socket watcher's read completed (see `spawn_pf_svc_watcher`).
    Stopped { forward_id: u32 },
}
111
/// State kept for one port forward.
struct PortForwardState {
    /// Acceptor task when the server owns the listener; `None` when the client
    /// was asked to listen instead.
    listener_handle: Option<tokio::task::JoinHandle<()>>,
    /// Writers for the channels that belong to this forward, keyed by channel id.
    channels: HashMap<u32, mpsc::Sender<Bytes>>,
    /// Task watching the requester's service-socket stream, once it has been started.
    stop_handle: Option<tokio::task::JoinHandle<()>>,
    /// Port sent to the client in `Frame::PortForwardOpen` for accepted connections.
    target_port: u16,
}
122
123struct AgentForwardState {
125 enabled: bool,
126 channels: HashMap<u32, mpsc::Sender<Bytes>>,
127 acceptor: Option<tokio::task::JoinHandle<()>>,
128 next_channel_id: Arc<AtomicU32>,
129 socket_path: PathBuf,
130}
131
132impl AgentForwardState {
133 fn new(socket_path: PathBuf) -> Self {
134 Self {
135 enabled: false,
136 channels: HashMap::new(),
137 acceptor: None,
138 next_channel_id: Arc::new(AtomicU32::new(0)),
139 socket_path,
140 }
141 }
142
143 fn teardown(&mut self) {
144 self.channels.clear();
145 self.enabled = false;
146 if let Some(handle) = self.acceptor.take() {
147 handle.abort();
148 }
149 cleanup_socket(&self.socket_path);
150 }
151}
152
153struct TunnelRelayState {
155 port: Option<u16>,
156 channels: HashMap<u32, mpsc::Sender<Bytes>>,
157 idle_deadline: Option<tokio::time::Instant>,
158 idle_timeout: Duration,
159}
160
161impl TunnelRelayState {
162 fn new(idle_timeout: Duration) -> Self {
163 Self { port: None, channels: HashMap::new(), idle_deadline: None, idle_timeout }
164 }
165
166 fn teardown(&mut self) {
167 self.channels.clear();
168 self.port = None;
169 self.idle_deadline = None;
170 }
171}
172
173struct PortForwardTable {
175 forwards: HashMap<u32, PortForwardState>,
176 channels: HashMap<u32, (u32, mpsc::Sender<Bytes>)>,
177 pending_remote: HashMap<u32, UnixStream>,
178 next_forward_id: u32,
179 next_channel_id: Arc<AtomicU32>,
180}
181
182impl PortForwardTable {
183 fn new() -> Self {
184 Self {
185 forwards: HashMap::new(),
186 channels: HashMap::new(),
187 pending_remote: HashMap::new(),
188 next_forward_id: 0,
189 next_channel_id: Arc::new(AtomicU32::new(0)),
190 }
191 }
192
193 fn teardown(&mut self) {
194 for (_, pf) in self.forwards.drain() {
195 if let Some(h) = pf.listener_handle {
196 h.abort();
197 }
198 if let Some(h) = pf.stop_handle {
199 h.abort();
200 }
201 }
202 self.channels.clear();
203 self.pending_remote.clear();
204 }
205}
206
/// List of files announced by a sender, as `(file name, size in bytes)` pairs.
struct FileManifest {
    files: Vec<(String, u64)>,
}

impl FileManifest {
    /// Returns the combined size of all files in bytes.
    ///
    /// The sizes come straight off the wire, so the sum saturates at
    /// `u64::MAX` rather than overflowing (a plain sum would panic in debug
    /// builds and wrap in release builds).
    fn total_bytes(&self) -> u64 {
        self.files.iter().fold(0u64, |acc, (_, size)| acc.saturating_add(*size))
    }
}
217
/// File-transfer arrivals reported by the service-socket acceptor.
enum SendEvent {
    /// A `Send` request whose manifest was parsed successfully.
    SenderArrived { stream: UnixStream, manifest: FileManifest },
    /// A `Receive` request whose destination line was read successfully.
    ReceiverArrived { stream: UnixStream },
}
223
/// Rendezvous state of the session's file transfer.
///
/// NOTE(review): the transitions live in `handle_send_event`, which is not
/// visible in this part of the file; the variant notes below are inferred from
/// the field names and should be confirmed there.
enum TransferState {
    /// No transfer in progress; `ServerRelay::handle_send_notification` resets
    /// to this on `Frame::SendDone` / `Frame::SendCancel`.
    Idle,
    /// Presumably: a sender is parked until a receiver connects.
    WaitingForReceiver { sender_stream: UnixStream, manifest: FileManifest },
    /// Presumably: a receiver is parked until a sender connects.
    WaitingForSender { receiver_stream: UnixStream },
    /// Presumably: a relay task is copying data between the two.
    Active { relay_handle: tokio::task::JoinHandle<()> },
}
231
/// Reduces a sender-supplied file name to a bare, safe basename.
///
/// Returns the final path component of `name`, or `None` when no safe
/// component exists: an empty name, `.`/`..`, a path that ends in `..` or is
/// just a root (`file_name()` yields nothing), or a component containing NUL,
/// a backslash or a slash.
///
/// Earlier versions fell back to the raw input when `file_name()` returned
/// `None`, which let strings such as `"foo/.."`, `"../.."` and `"/"` through
/// with their separators intact.
fn sanitize_filename(name: &str) -> Option<String> {
    // No fallback to `name`: a path without a final normal component is rejected.
    let basename = std::path::Path::new(name).file_name()?.to_str()?;
    if basename.is_empty() || basename == "." || basename == ".." {
        return None;
    }
    // `file_name()` never yields '/', but the result becomes a path on the
    // receiving side, so keep an explicit guard.
    if basename.contains(|c: char| matches!(c, '\0' | '\\' | '/')) {
        return None;
    }
    Some(basename.to_string())
}
243
244fn extract_redirect_port(url_str: &str) -> Option<u16> {
247 let parsed = url::Url::parse(url_str).ok()?;
248 for (key, value) in parsed.query_pairs() {
249 if key != "redirect_uri" && key != "redirect_url" {
250 continue;
251 }
252 let redirect = url::Url::parse(&value).ok()?;
253 match redirect.host_str()? {
254 "localhost" | "127.0.0.1" => return redirect.port(),
255 _ => {}
256 }
257 }
258 None
259}
260
/// Reports whether binding a TCP listener on `127.0.0.1:port` fails.
///
/// Any bind error counts as "in use"; a successful probe listener is dropped
/// immediately.
fn port_in_use(port: u16) -> bool {
    match std::net::TcpListener::bind(("127.0.0.1", port)) {
        Ok(_probe) => false,
        Err(_) => true,
    }
}
265
266fn spawn_agent_acceptor(
269 listener: UnixListener,
270 event_tx: mpsc::UnboundedSender<AgentEvent>,
271 next_channel_id: Arc<AtomicU32>,
272) -> tokio::task::JoinHandle<()> {
273 tokio::spawn(async move {
274 loop {
275 let (stream, _) = match listener.accept().await {
276 Ok(conn) => conn,
277 Err(e) => {
278 debug!("agent listener accept error: {e}");
279 break;
280 }
281 };
282
283 if let Err(e) = crate::security::verify_peer_uid(&stream) {
284 warn!("agent socket: {e}");
285 continue;
286 }
287
288 let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
289
290 let (read_half, write_half) = stream.into_split();
291 let data_tx = event_tx.clone();
292 let close_tx = event_tx.clone();
293 let writer_tx = crate::spawn_channel_relay(
294 channel_id,
295 read_half,
296 write_half,
297 move |id, data| data_tx.send(AgentEvent::Data { channel_id: id, data }).is_ok(),
298 move |id| {
299 let _ = close_tx.send(AgentEvent::Closed { channel_id: id });
300 },
301 );
302
303 if event_tx.send(AgentEvent::Accepted { channel_id, writer_tx }).is_err() {
305 break; }
307 }
308 })
309}
310
311fn spawn_pf_tcp_acceptor(
314 listener: tokio::net::TcpListener,
315 forward_id: u32,
316 next_channel_id: Arc<AtomicU32>,
317 event_tx: mpsc::UnboundedSender<PortForwardEvent>,
318) -> tokio::task::JoinHandle<()> {
319 tokio::spawn(async move {
320 loop {
321 let (stream, _) = match listener.accept().await {
322 Ok(conn) => conn,
323 Err(e) => {
324 debug!(forward_id, "pf tcp listener accept error: {e}");
325 break;
326 }
327 };
328
329 let channel_id = next_channel_id.fetch_add(1, Ordering::Relaxed);
330 let (read_half, write_half) = stream.into_split();
331 let data_tx = event_tx.clone();
332 let close_tx = event_tx.clone();
333 let writer_tx = crate::spawn_channel_relay(
334 channel_id,
335 read_half,
336 write_half,
337 move |id, data| {
338 data_tx.send(PortForwardEvent::Data { channel_id: id, data }).is_ok()
339 },
340 move |id| {
341 let _ = close_tx.send(PortForwardEvent::Closed { channel_id: id });
342 },
343 );
344
345 if event_tx
346 .send(PortForwardEvent::Accepted { forward_id, channel_id, writer_tx })
347 .is_err()
348 {
349 break;
350 }
351 }
352 })
353}
354
355fn spawn_pf_svc_watcher(
357 stream: UnixStream,
358 forward_id: u32,
359 event_tx: mpsc::UnboundedSender<PortForwardEvent>,
360) -> tokio::task::JoinHandle<()> {
361 tokio::spawn(async move {
362 let mut stream = stream;
363 let mut buf = [0u8; 1];
364 let _ = stream.read(&mut buf).await;
366 let _ = event_tx.send(PortForwardEvent::Stopped { forward_id });
367 })
368}
369
/// Size in bytes of the buffer used to read an `OpenUrl` request; reading stops
/// once this many bytes have been received.
const URL_MAX_LEN: usize = 4096;
372
373async fn parse_sender_manifest(stream: &mut UnixStream) -> io::Result<FileManifest> {
377 let mut buf4 = [0u8; 4];
378 stream.read_exact(&mut buf4).await?;
379 let file_count = u32::from_be_bytes(buf4);
380 if file_count == 0 || file_count > 10_000 {
381 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid file count"));
382 }
383 let mut files = Vec::with_capacity(file_count as usize);
384 for _ in 0..file_count {
385 let mut buf2 = [0u8; 2];
386 stream.read_exact(&mut buf2).await?;
387 let name_len = u16::from_be_bytes(buf2) as usize;
388 if name_len == 0 || name_len > 4096 {
389 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid filename length"));
390 }
391 let mut name_buf = vec![0u8; name_len];
392 stream.read_exact(&mut name_buf).await?;
393 let name = String::from_utf8(name_buf)
394 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
395 let name = match sanitize_filename(&name) {
396 Some(n) => n,
397 None => {
398 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid filename"));
399 }
400 };
401 let mut buf8 = [0u8; 8];
402 stream.read_exact(&mut buf8).await?;
403 let file_size = u64::from_be_bytes(buf8);
404 files.push((name, file_size));
405 }
406 Ok(FileManifest { files })
407}
408
409async fn parse_receiver_dest(stream: &mut UnixStream) -> io::Result<String> {
412 let mut buf = Vec::with_capacity(256);
413 let mut byte = [0u8; 1];
414 loop {
415 match stream.read_exact(&mut byte).await {
416 Ok(_) => {
417 if byte[0] == b'\n' {
418 break;
419 }
420 buf.push(byte[0]);
421 if buf.len() > 4096 {
422 return Err(io::Error::new(io::ErrorKind::InvalidData, "dest dir too long"));
423 }
424 }
425 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
426 Err(e) => return Err(e),
427 }
428 }
429 String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
430}
431
432fn spawn_svc_acceptor(
435 listener: UnixListener,
436 open_event_tx: mpsc::UnboundedSender<OpenEvent>,
437 send_event_tx: mpsc::UnboundedSender<SendEvent>,
438 pf_event_tx: mpsc::UnboundedSender<PortForwardEvent>,
439) -> tokio::task::JoinHandle<()> {
440 tokio::spawn(async move {
441 loop {
442 let (mut stream, _) = match listener.accept().await {
443 Ok(conn) => conn,
444 Err(e) => {
445 debug!("svc listener accept error: {e}");
446 break;
447 }
448 };
449
450 match crate::security::verify_peer_uid(&stream) {
454 Ok(()) => {}
455 Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
456 warn!("svc socket: {e}");
457 continue;
458 }
459 Err(e) => {
460 debug!("svc socket peer_cred unavailable: {e}");
461 }
462 }
463
464 let otx = open_event_tx.clone();
465 let stx = send_event_tx.clone();
466 let ptx = pf_event_tx.clone();
467 tokio::spawn(async move {
468 let mut disc = [0u8; 1];
470 if stream.read_exact(&mut disc).await.is_err() {
471 return;
472 }
473 match crate::protocol::SvcRequest::from_byte(disc[0]) {
474 Some(crate::protocol::SvcRequest::OpenUrl) => {
475 let mut buf = vec![0u8; URL_MAX_LEN];
476 let mut total = 0;
477 loop {
478 match stream.read(&mut buf[total..]).await {
479 Ok(0) => break,
480 Ok(n) => {
481 total += n;
482 if buf[..total].contains(&b'\n') || total >= buf.len() {
483 break;
484 }
485 }
486 Err(_) => return,
487 }
488 }
489 let s = String::from_utf8_lossy(&buf[..total]);
490 let url = s.trim();
491 if !url.is_empty() {
492 let _ = otx.send(OpenEvent::Url(url.to_string()));
493 }
494 }
495 Some(crate::protocol::SvcRequest::Send) => {
496 match parse_sender_manifest(&mut stream).await {
497 Ok(manifest) => {
498 let _ = stx.send(SendEvent::SenderArrived { stream, manifest });
499 }
500 Err(e) => debug!("svc socket: bad sender manifest: {e}"),
501 }
502 }
503 Some(crate::protocol::SvcRequest::Receive) => {
504 match parse_receiver_dest(&mut stream).await {
505 Ok(_) => {
506 let _ = stx.send(SendEvent::ReceiverArrived { stream });
507 }
508 Err(e) => debug!("svc socket: bad receiver dest: {e}"),
509 }
510 }
511 Some(crate::protocol::SvcRequest::PortForward) => {
512 let mut hdr = [0u8; 5];
514 if stream.read_exact(&mut hdr).await.is_err() {
515 return;
516 }
517 let direction = hdr[0];
518 let listen_port = u16::from_be_bytes([hdr[1], hdr[2]]);
519 let target_port = u16::from_be_bytes([hdr[3], hdr[4]]);
520 let _ = ptx.send(PortForwardEvent::Requested {
521 stream,
522 direction,
523 listen_port,
524 target_port,
525 });
526 }
527 None => {
528 debug!("svc socket: unknown request byte: 0x{:02x}", disc[0]);
529 }
530 }
531 });
532 }
533 })
534}
535
536fn spawn_transfer_relay(
539 mut sender: UnixStream,
540 mut receiver: UnixStream,
541 manifest: FileManifest,
542 notify_tx: mpsc::UnboundedSender<Frame>,
543) -> tokio::task::JoinHandle<()> {
544 tokio::spawn(async move {
545 use tokio::io::AsyncWriteExt;
546
547 let file_count = manifest.files.len() as u32;
548 let total_bytes = manifest.total_bytes();
549
550 let _ = notify_tx.send(Frame::SendOffer { file_count, total_bytes });
552
553 if sender.write_all(&[0x01]).await.is_err() {
555 let _ = notify_tx.send(Frame::SendCancel { reason: "sender disconnected".into() });
556 return;
557 }
558
559 if receiver.write_all(&file_count.to_be_bytes()).await.is_err() {
561 let _ = notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
562 return;
563 }
564
565 let mut buf = vec![0u8; 64 * 1024];
567 for (name, size) in &manifest.files {
568 let name_bytes = name.as_bytes();
570 let mut hdr = Vec::with_capacity(2 + name_bytes.len() + 8);
571 hdr.extend_from_slice(&(name_bytes.len() as u16).to_be_bytes());
572 hdr.extend_from_slice(name_bytes);
573 hdr.extend_from_slice(&size.to_be_bytes());
574 if receiver.write_all(&hdr).await.is_err() {
575 let _ =
576 notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
577 return;
578 }
579
580 let mut remaining = *size;
582 while remaining > 0 {
583 let to_read = (remaining as usize).min(buf.len());
584 match sender.read_exact(&mut buf[..to_read]).await {
585 Ok(_) => {
586 if receiver.write_all(&buf[..to_read]).await.is_err() {
587 let _ = notify_tx
588 .send(Frame::SendCancel { reason: "receiver disconnected".into() });
589 return;
590 }
591 remaining -= to_read as u64;
592 }
593 Err(_) => {
594 let _ = notify_tx
595 .send(Frame::SendCancel { reason: "sender disconnected".into() });
596 return;
597 }
598 }
599 }
600 }
601
602 if receiver.write_all(&[0u8; 2]).await.is_err() {
604 let _ = notify_tx.send(Frame::SendCancel { reason: "receiver disconnected".into() });
605 return;
606 }
607
608 let _ = notify_tx.send(Frame::SendDone);
609 })
610}
611
612async fn tail_relay(
614 mut framed: Framed<UnixStream, FrameCodec>,
615 mut rx: broadcast::Receiver<TailEvent>,
616) {
617 loop {
618 tokio::select! {
619 event = rx.recv() => match event {
620 Ok(TailEvent::Data(chunk)) => {
621 if framed.send(Frame::Data(chunk)).await.is_err() { break; }
622 }
623 Ok(TailEvent::Exit { code }) => {
624 let _ = framed.send(Frame::Exit { code }).await;
625 break;
626 }
627 Err(broadcast::error::RecvError::Lagged(_)) => continue,
628 Err(broadcast::error::RecvError::Closed) => break,
629 },
630 frame = framed.next() => match frame {
631 Some(Ok(Frame::Ping)) => { let _ = framed.send(Frame::Pong).await; }
632 _ => break,
633 },
634 }
635 }
636}
637
638fn spawn_tail(
640 mut framed: Framed<UnixStream, FrameCodec>,
641 ring_buf: &VecDeque<Bytes>,
642 tail_tx: &broadcast::Sender<TailEvent>,
643) {
644 let rx = tail_tx.subscribe();
645 let chunks: Vec<Bytes> = ring_buf.iter().cloned().collect();
646 tokio::spawn(async move {
647 for chunk in chunks {
648 if framed.send(Frame::Data(chunk)).await.is_err() {
649 return;
650 }
651 }
652 tail_relay(framed, rx).await;
653 });
654}
655
/// Borrowed view of the session state that the per-event handlers in
/// `impl ServerRelay` read and update.
struct ServerRelay<'a> {
    /// PTY master that client data is written to and resizes are applied to.
    async_master: &'a AsyncFd<OwnedFd>,
    /// Agent-forwarding state (enable flag, channels, acceptor).
    agent: &'a mut AgentForwardState,
    /// Reverse-tunnel state (port, channels, idle deadline).
    tunnel: &'a mut TunnelRelayState,
    /// Port-forward table.
    pf: &'a mut PortForwardTable,
    /// File-transfer state; reset to `Idle` when a transfer finishes or is cancelled.
    transfer_state: &'a mut TransferState,
    /// Set when the client sends `Frame::OpenForward`; gates URL forwarding.
    open_forward_enabled: &'a mut bool,
    /// Broadcast sender for tail clients (not used by the handlers in this impl).
    tail_tx: &'a broadcast::Sender<TailEvent>,
    /// Shared metadata slot; `last_heartbeat` is refreshed on `Frame::Ping`.
    metadata_slot: &'a Arc<OnceLock<SessionMetadata>>,
    /// Sender cloned into the agent acceptor when forwarding is enabled.
    agent_event_tx: &'a mpsc::UnboundedSender<AgentEvent>,
    /// Sender cloned into tunnel connect and reader tasks.
    tunnel_event_tx: &'a mpsc::UnboundedSender<TunnelEvent>,
    /// Sender cloned into port-forward acceptors, watchers and relays.
    pf_event_tx: &'a mpsc::UnboundedSender<PortForwardEvent>,
    /// Sender for transfer notifications (not used by the handlers in this impl).
    send_notify_tx: &'a mpsc::UnboundedSender<Frame>,
}
673
674impl ServerRelay<'_> {
675 async fn handle_client_frame(
676 &mut self,
677 framed: &mut Framed<UnixStream, FrameCodec>,
678 frame: Option<Result<Frame, io::Error>>,
679 ) -> Result<ControlFlow<RelayExit>, anyhow::Error> {
680 match frame {
681 Some(Ok(Frame::Data(data))) => {
682 debug!(len = data.len(), "socket -> pty");
683 let mut written = 0;
684 while written < data.len() {
685 let mut guard = self.async_master.writable().await?;
686 match guard.try_io(|inner| {
687 nix::unistd::write(inner, &data[written..]).map_err(io::Error::from)
688 }) {
689 Ok(Ok(n)) => written += n,
690 Ok(Err(e)) => return Err(e.into()),
691 Err(_would_block) => continue,
692 }
693 }
694 }
695 Some(Ok(Frame::Resize { cols, rows })) => {
696 let (cols, rows) = crate::security::clamp_winsize(cols, rows);
697 debug!(cols, rows, "resize pty");
698 let ws = libc::winsize { ws_row: rows, ws_col: cols, ws_xpixel: 0, ws_ypixel: 0 };
699 unsafe {
700 libc::ioctl(self.async_master.as_raw_fd(), libc::TIOCSWINSZ, &ws as *const _);
701 }
702 if let Ok(pgid) = nix::unistd::tcgetpgrp(self.async_master) {
703 let _ = nix::sys::signal::killpg(pgid, nix::sys::signal::Signal::SIGWINCH);
704 }
705 }
706 Some(Ok(Frame::Ping)) => {
707 if let Some(meta) = self.metadata_slot.get() {
708 let now = std::time::SystemTime::now()
709 .duration_since(std::time::UNIX_EPOCH)
710 .unwrap_or_default()
711 .as_secs();
712 meta.last_heartbeat.store(now, Ordering::Relaxed);
713 }
714 let _ = framed.send(Frame::Pong).await;
715 }
716 Some(Ok(Frame::AgentForward)) => {
717 debug!("agent forwarding enabled by client");
718 self.agent.enabled = true;
719 if self.agent.acceptor.is_none() {
720 if let Some(listener) = bind_agent_listener(&self.agent.socket_path) {
721 self.agent.acceptor = Some(spawn_agent_acceptor(
722 listener,
723 self.agent_event_tx.clone(),
724 self.agent.next_channel_id.clone(),
725 ));
726 }
727 }
728 }
729 Some(Ok(Frame::AgentData { channel_id, data })) => {
730 if let Some(tx) = self.agent.channels.get(&channel_id) {
731 let _ = tx.send(data).await;
732 }
733 }
734 Some(Ok(Frame::AgentClose { channel_id })) => {
735 self.agent.channels.remove(&channel_id);
736 }
737 Some(Ok(Frame::OpenForward)) => {
738 debug!("open forwarding enabled by client");
739 *self.open_forward_enabled = true;
740 }
741 Some(Ok(Frame::TunnelOpen { channel_id })) => {
742 if let Some(port) = self.tunnel.port {
743 self.tunnel.idle_deadline = None;
744 let tx = self.tunnel_event_tx.clone();
745 tokio::spawn(async move {
746 match tokio::net::TcpStream::connect(("127.0.0.1", port)).await {
747 Ok(stream) => {
748 let _ = tx.send(TunnelEvent::Connected { channel_id, stream });
749 }
750 Err(_) => {
751 let _ = tx.send(TunnelEvent::ConnectFailed { channel_id });
752 }
753 }
754 });
755 }
756 }
757 Some(Ok(Frame::TunnelData { channel_id, data })) => {
758 if let Some(tx) = self.tunnel.channels.get(&channel_id) {
759 let _ = tx.send(data).await;
760 }
761 }
762 Some(Ok(Frame::TunnelClose { channel_id })) => {
763 self.tunnel.channels.remove(&channel_id);
764 if self.tunnel.channels.is_empty() && self.tunnel.port.is_some() {
765 self.tunnel.idle_deadline =
766 Some(tokio::time::Instant::now() + self.tunnel.idle_timeout);
767 }
768 }
769 Some(Ok(Frame::PortForwardReady { forward_id })) => {
770 if let Some(mut svc_stream) = self.pf.pending_remote.remove(&forward_id) {
771 use tokio::io::AsyncWriteExt;
772 let _ = svc_stream.write_all(&[0x01]).await;
773 let stop_handle =
774 spawn_pf_svc_watcher(svc_stream, forward_id, self.pf_event_tx.clone());
775 if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
776 fwd.stop_handle = Some(stop_handle);
777 }
778 }
779 }
780 Some(Ok(Frame::PortForwardOpen { forward_id, channel_id, target_port })) => {
781 if self.pf.forwards.contains_key(&forward_id) {
782 let tx = self.pf_event_tx.clone();
783 tokio::spawn(async move {
784 match tokio::net::TcpStream::connect(("127.0.0.1", target_port)).await {
785 Ok(stream) => {
786 let _ = tx.send(PortForwardEvent::Connected {
787 forward_id,
788 channel_id,
789 stream,
790 });
791 }
792 Err(_) => {
793 let _ = tx.send(PortForwardEvent::ConnectFailed { channel_id });
794 }
795 }
796 });
797 }
798 }
799 Some(Ok(Frame::PortForwardData { channel_id, data })) => {
800 if let Some((_, tx)) = self.pf.channels.get(&channel_id) {
801 let _ = tx.send(data).await;
802 }
803 }
804 Some(Ok(Frame::PortForwardClose { channel_id })) => {
805 if let Some((forward_id, _)) = self.pf.channels.remove(&channel_id) {
806 if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
807 fwd.channels.remove(&channel_id);
808 }
809 }
810 }
811 Some(Ok(Frame::PortForwardStop { forward_id })) => {
812 if let Some(mut svc_stream) = self.pf.pending_remote.remove(&forward_id) {
813 use tokio::io::AsyncWriteExt;
814 let _ = svc_stream.write_all(&[0x02]).await;
815 let _ = svc_stream.write_all(b"client declined forward").await;
816 }
817 if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
818 if let Some(h) = fwd.listener_handle {
819 h.abort();
820 }
821 if let Some(h) = fwd.stop_handle {
822 h.abort();
823 }
824 for ch_id in fwd.channels.keys() {
825 self.pf.channels.remove(ch_id);
826 }
827 }
828 }
829 Some(Ok(Frame::Exit { .. })) | None => {
830 return Ok(ControlFlow::Break(RelayExit::ClientGone));
831 }
832 Some(Ok(_)) => {}
833 Some(Err(e)) => return Err(e.into()),
834 }
835 Ok(ControlFlow::Continue(()))
836 }
837
838 async fn handle_agent_event(
839 &mut self,
840 framed: &mut Framed<UnixStream, FrameCodec>,
841 event: Option<AgentEvent>,
842 ) {
843 match event {
844 Some(AgentEvent::Accepted { channel_id, writer_tx }) => {
845 if self.agent.enabled {
846 self.agent.channels.insert(channel_id, writer_tx);
847 let _ = framed.send(Frame::AgentOpen { channel_id }).await;
848 }
849 }
850 Some(AgentEvent::Data { channel_id, data }) => {
851 if self.agent.enabled && self.agent.channels.contains_key(&channel_id) {
852 let _ = framed.send(Frame::AgentData { channel_id, data }).await;
853 }
854 }
855 Some(AgentEvent::Closed { channel_id }) => {
856 if self.agent.channels.remove(&channel_id).is_some() {
857 let _ = framed.send(Frame::AgentClose { channel_id }).await;
858 }
859 }
860 None => {
861 debug!("agent event channel closed");
862 }
863 }
864 }
865
866 async fn handle_open_event(
867 &mut self,
868 framed: &mut Framed<UnixStream, FrameCodec>,
869 event: Option<OpenEvent>,
870 ) {
871 match event {
872 Some(OpenEvent::Url(url)) => {
873 if *self.open_forward_enabled {
874 if let Some(port) = extract_redirect_port(&url) {
875 if port_in_use(port) {
876 debug!(port, "setting up reverse tunnel for OAuth callback");
877 self.tunnel.port = Some(port);
878 let _ = framed.send(Frame::TunnelListen { port }).await;
879 }
880 }
881 let _ = framed.send(Frame::OpenUrl { url }).await;
882 }
883 }
884 None => {
885 debug!("open event channel closed");
886 }
887 }
888 }
889
890 async fn handle_tunnel_event(
891 &mut self,
892 framed: &mut Framed<UnixStream, FrameCodec>,
893 event: Option<TunnelEvent>,
894 ) {
895 match event {
896 Some(TunnelEvent::Connected { channel_id, stream }) => {
897 if self.tunnel.port.is_some() {
898 debug!(channel_id, "tunnel channel connected");
899 self.tunnel.idle_deadline = None;
900 let (mut read_half, write_half) = stream.into_split();
901 let (writer_tx, mut writer_rx) =
902 mpsc::channel::<Bytes>(crate::CHANNEL_RELAY_BUFFER);
903 self.tunnel.channels.insert(channel_id, writer_tx);
904
905 tokio::spawn(async move {
907 use tokio::io::AsyncWriteExt;
908 let mut writer = write_half;
909 while let Some(data) = writer_rx.recv().await {
910 if writer.write_all(&data).await.is_err() {
911 break;
912 }
913 }
914 });
915
916 let tx = self.tunnel_event_tx.clone();
918 tokio::spawn(async move {
919 let mut buf = vec![0u8; 4096];
920 loop {
921 match read_half.read(&mut buf).await {
922 Ok(0) | Err(_) => {
923 let _ = tx.send(TunnelEvent::Closed { channel_id });
924 break;
925 }
926 Ok(n) => {
927 let data = Bytes::copy_from_slice(&buf[..n]);
928 if tx.send(TunnelEvent::Data { channel_id, data }).is_err() {
929 break;
930 }
931 }
932 }
933 }
934 });
935 }
936 }
937 Some(TunnelEvent::ConnectFailed { channel_id }) => {
938 let _ = framed.send(Frame::TunnelClose { channel_id }).await;
939 }
940 Some(TunnelEvent::Data { channel_id, data }) => {
941 let _ = framed.send(Frame::TunnelData { channel_id, data }).await;
942 }
943 Some(TunnelEvent::Closed { channel_id }) => {
944 self.tunnel.channels.remove(&channel_id);
945 let _ = framed.send(Frame::TunnelClose { channel_id }).await;
946 if self.tunnel.channels.is_empty() && self.tunnel.port.is_some() {
947 self.tunnel.idle_deadline =
948 Some(tokio::time::Instant::now() + self.tunnel.idle_timeout);
949 }
950 }
951 None => {}
952 }
953 }
954
955 async fn handle_send_notification(
956 &mut self,
957 framed: &mut Framed<UnixStream, FrameCodec>,
958 notification: Option<Frame>,
959 ) {
960 if let Some(frame) = notification {
961 if matches!(frame, Frame::SendDone | Frame::SendCancel { .. }) {
962 *self.transfer_state = TransferState::Idle;
963 }
964 let _ = framed.send(frame).await;
965 }
966 }
967
968 async fn handle_pf_event(
969 &mut self,
970 framed: &mut Framed<UnixStream, FrameCodec>,
971 event: Option<PortForwardEvent>,
972 ) {
973 match event {
974 Some(PortForwardEvent::Requested { stream, direction, listen_port, target_port }) => {
975 use tokio::io::AsyncWriteExt;
976 let fwd_id = self.pf.next_forward_id;
977 self.pf.next_forward_id += 1;
978 if direction == 0 {
979 match tokio::net::TcpListener::bind(("127.0.0.1", listen_port)).await {
981 Ok(listener) => {
982 debug!(fwd_id, listen_port, target_port, "local-forward: bound");
983 let handle = spawn_pf_tcp_acceptor(
984 listener,
985 fwd_id,
986 self.pf.next_channel_id.clone(),
987 self.pf_event_tx.clone(),
988 );
989 let mut s = stream;
990 let _ = s.write_all(&[0x01]).await;
991 let stream = s;
992 let stop_handle =
993 spawn_pf_svc_watcher(stream, fwd_id, self.pf_event_tx.clone());
994 self.pf.forwards.insert(
995 fwd_id,
996 PortForwardState {
997 listener_handle: Some(handle),
998 channels: HashMap::new(),
999 stop_handle: Some(stop_handle),
1000 target_port,
1001 },
1002 );
1003 }
1004 Err(e) => {
1005 debug!(listen_port, "local-forward: bind failed: {e}");
1006 let mut s = stream;
1007 let msg = format!("bind failed: {e}");
1008 let _ = s.write_all(&[0x02]).await;
1009 let _ = s.write_all(msg.as_bytes()).await;
1010 }
1011 }
1012 } else {
1013 let _ = framed
1015 .send(Frame::PortForwardListen {
1016 forward_id: fwd_id,
1017 listen_port,
1018 target_port,
1019 })
1020 .await;
1021 self.pf.pending_remote.insert(fwd_id, stream);
1022 self.pf.forwards.insert(
1023 fwd_id,
1024 PortForwardState {
1025 listener_handle: None,
1026 channels: HashMap::new(),
1027 stop_handle: None,
1028 target_port,
1029 },
1030 );
1031 }
1032 }
1033 Some(PortForwardEvent::Accepted { forward_id, channel_id, writer_tx }) => {
1034 if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
1035 fwd.channels.insert(channel_id, writer_tx.clone());
1036 self.pf.channels.insert(channel_id, (forward_id, writer_tx));
1037 let _ = framed
1038 .send(Frame::PortForwardOpen {
1039 forward_id,
1040 channel_id,
1041 target_port: fwd.target_port,
1042 })
1043 .await;
1044 }
1045 }
1046 Some(PortForwardEvent::Connected { forward_id, channel_id, stream }) => {
1047 if self.pf.forwards.contains_key(&forward_id) {
1048 let (read_half, write_half) = stream.into_split();
1049 let data_tx = self.pf_event_tx.clone();
1050 let close_tx = self.pf_event_tx.clone();
1051 let writer_tx = crate::spawn_channel_relay(
1052 channel_id,
1053 read_half,
1054 write_half,
1055 move |id, data| {
1056 data_tx.send(PortForwardEvent::Data { channel_id: id, data }).is_ok()
1057 },
1058 move |id| {
1059 let _ = close_tx.send(PortForwardEvent::Closed { channel_id: id });
1060 },
1061 );
1062 self.pf.channels.insert(channel_id, (forward_id, writer_tx.clone()));
1063 if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
1064 fwd.channels.insert(channel_id, writer_tx);
1065 }
1066 }
1067 }
1068 Some(PortForwardEvent::ConnectFailed { channel_id }) => {
1069 let _ = framed.send(Frame::PortForwardClose { channel_id }).await;
1070 }
1071 Some(PortForwardEvent::Data { channel_id, data }) => {
1072 if self.pf.channels.contains_key(&channel_id) {
1073 let _ = framed.send(Frame::PortForwardData { channel_id, data }).await;
1074 }
1075 }
1076 Some(PortForwardEvent::Closed { channel_id }) => {
1077 if let Some((forward_id, _)) = self.pf.channels.remove(&channel_id) {
1078 let _ = framed.send(Frame::PortForwardClose { channel_id }).await;
1079 if let Some(fwd) = self.pf.forwards.get_mut(&forward_id) {
1080 fwd.channels.remove(&channel_id);
1081 }
1082 }
1083 }
1084 Some(PortForwardEvent::Stopped { forward_id }) => {
1085 debug!(forward_id, "port forward stopped (svc socket dropped)");
1086 if let Some(fwd) = self.pf.forwards.remove(&forward_id) {
1087 if let Some(h) = fwd.listener_handle {
1088 h.abort();
1089 }
1090 if let Some(h) = fwd.stop_handle {
1091 h.abort();
1092 }
1093 for ch_id in fwd.channels.keys() {
1094 self.pf.channels.remove(ch_id);
1095 }
1096 let _ = framed.send(Frame::PortForwardStop { forward_id }).await;
1097 }
1098 }
1099 None => {}
1100 }
1101 }
1102}
1103
1104#[allow(clippy::too_many_arguments)]
1105pub async fn run(
1106 mut client_rx: mpsc::UnboundedReceiver<ClientConn>,
1107 metadata_slot: Arc<OnceLock<SessionMetadata>>,
1108 agent_socket_path: PathBuf,
1109 svc_socket_path: PathBuf,
1110 session_id: u32,
1111 session_name: Option<String>,
1112 command: Option<String>,
1113 ring_buffer_cap: usize,
1114 oauth_tunnel_idle_timeout: u64,
1115) -> anyhow::Result<()> {
1116 let pty = openpty(None, None)?;
1118 let master: OwnedFd = pty.master;
1119 let slave: OwnedFd = pty.slave;
1120
1121 let pty_path =
1123 nix::unistd::ttyname(&slave).map(|p| p.display().to_string()).unwrap_or_default();
1124
1125 let slave_fd = slave.as_raw_fd();
1127 let stdin_fd = crate::security::checked_dup(slave_fd)?;
1128 let stdout_fd = crate::security::checked_dup(slave_fd)?;
1129 let stderr_fd = crate::security::checked_dup(slave_fd)?;
1130 let raw_stdin = stdin_fd.as_raw_fd();
1131 drop(slave);
1132
1133 let flags = nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_GETFL)?;
1135 let mut oflags = nix::fcntl::OFlag::from_bits_truncate(flags);
1136 oflags |= nix::fcntl::OFlag::O_NONBLOCK;
1137 nix::fcntl::fcntl(&master, nix::fcntl::FcntlArg::F_SETFL(oflags))?;
1138
1139 let async_master = AsyncFd::new(master)?;
1140 let mut buf = vec![0u8; 4096];
1141 let mut ring_buf: VecDeque<Bytes> = VecDeque::new();
1142 let mut ring_buf_size: usize = 0;
1143 let mut ring_buf_dropped: usize = 0;
1144
1145 let mut agent = AgentForwardState::new(agent_socket_path);
1147 let (agent_event_tx, mut agent_event_rx) = mpsc::unbounded_channel::<AgentEvent>();
1148
1149 let mut open_forward_enabled = false;
1151
1152 let (tunnel_event_tx, mut tunnel_event_rx) = mpsc::unbounded_channel::<TunnelEvent>();
1154 let mut tunnel = TunnelRelayState::new(Duration::from_secs(oauth_tunnel_idle_timeout));
1155
1156 let (tail_tx, _) = broadcast::channel::<TailEvent>(256);
1158
1159 let (open_event_tx, mut open_event_rx) = mpsc::unbounded_channel::<OpenEvent>();
1161 let (send_event_tx, mut send_event_rx) = mpsc::unbounded_channel::<SendEvent>();
1162 let (send_notify_tx, mut send_notify_rx) = mpsc::unbounded_channel::<Frame>();
1163 let mut transfer_state = TransferState::Idle;
1164 let mut svc_acceptor: Option<tokio::task::JoinHandle<()>> = None;
1165
1166 let (pf_event_tx, mut pf_event_rx) = mpsc::unbounded_channel::<PortForwardEvent>();
1168 let mut pf = PortForwardTable::new();
1169
1170 if let Some(listener) = bind_agent_listener(&svc_socket_path) {
1172 svc_acceptor = Some(spawn_svc_acceptor(
1173 listener,
1174 open_event_tx.clone(),
1175 send_event_tx.clone(),
1176 pf_event_tx.clone(),
1177 ));
1178 }
1179
1180 let mut framed = loop {
1184 tokio::select! {
1185 client = client_rx.recv() => match client {
1186 Some(ClientConn::Active(f)) => {
1187 info!("first client connected via channel");
1188 break f;
1189 }
1190 Some(ClientConn::Tail(f)) => {
1191 info!("tail client connected before shell spawn");
1192 spawn_tail(f, &ring_buf, &tail_tx);
1193 continue;
1194 }
1195 Some(ClientConn::Send(stream)) => {
1196 handle_send_stream(stream, &send_event_tx);
1197 continue;
1198 }
1199 None => {
1200 info!("client channel closed before first client");
1201 cleanup_socket(&agent.socket_path);
1202 cleanup_socket(&svc_socket_path);
1203 return Ok(());
1204 }
1205 },
1206 event = send_event_rx.recv() => {
1207 if let Some(event) = event {
1208 handle_send_event(event, &mut transfer_state, &send_notify_tx);
1209 }
1210 continue;
1211 }
1212 }
1213 };
1214
1215 let env_vars =
1217 match tokio::time::timeout(std::time::Duration::from_secs(2), framed.next()).await {
1218 Ok(Some(Ok(Frame::Env { vars }))) => {
1219 debug!(count = vars.len(), "received env vars from client");
1220 vars
1221 }
1222 _ => Vec::new(),
1223 };
1224
1225 let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
1227 let home = std::env::var("HOME").ok();
1228 let mut cmd = Command::new(&shell);
1229 if let Some(ref user_cmd) = command {
1230 cmd.arg("-c").arg(user_cmd);
1231 } else {
1232 cmd.arg("-l");
1233 }
1234 if let Some(ref dir) = home {
1235 cmd.current_dir(dir);
1236 }
1237 const ALLOWED_ENV_KEYS: &[&str] = &["TERM", "LANG", "COLORTERM"];
1238 for (k, v) in &env_vars {
1239 if ALLOWED_ENV_KEYS.contains(&k.as_str()) {
1240 cmd.env(k, v);
1241 } else if k == "BROWSER" {
1242 let exe = std::env::current_exe()
1244 .ok()
1245 .and_then(|p| p.to_str().map(String::from))
1246 .unwrap_or_else(|| "gritty".into());
1247 cmd.env("BROWSER", format!("{exe} open"));
1248 } else {
1249 warn!(key = k, "ignoring disallowed env var from client");
1250 }
1251 }
1252 cmd.env("SSH_AUTH_SOCK", &agent.socket_path);
1254 cmd.env("GRITTY_SOCK", &svc_socket_path);
1256 cmd.env("GRITTY_SESSION", session_id.to_string());
1258 if let Some(ref name) = session_name {
1259 cmd.env("GRITTY_SESSION_NAME", name);
1260 }
1261 let mut managed = ManagedChild::new(unsafe {
1262 cmd.pre_exec(move || {
1263 nix::unistd::setsid().map_err(io::Error::other)?;
1264 libc::ioctl(raw_stdin, libc::TIOCSCTTY as libc::c_ulong, 0);
1265 Ok(())
1266 })
1267 .stdin(Stdio::from(stdin_fd))
1268 .stdout(Stdio::from(stdout_fd))
1269 .stderr(Stdio::from(stderr_fd))
1270 .spawn()?
1271 });
1272
1273 let shell_pid = managed.child.id().unwrap_or(0);
1274 let created_at = std::time::SystemTime::now()
1275 .duration_since(std::time::UNIX_EPOCH)
1276 .unwrap_or_default()
1277 .as_secs();
1278
1279 let _ = metadata_slot.set(SessionMetadata {
1280 pty_path,
1281 shell_pid,
1282 created_at,
1283 attached: AtomicBool::new(false),
1284 last_heartbeat: AtomicU64::new(0),
1285 });
1286
1287 metadata_slot.get().unwrap().attached.store(true, Ordering::Relaxed);
1289
1290 let mut first_client = true;
1293 loop {
1294 if !first_client {
1295 let got_client = 'drain: loop {
1296 tokio::select! {
1297 client = client_rx.recv() => {
1298 match client {
1299 Some(ClientConn::Active(f)) => {
1300 info!("client connected via channel");
1301 framed = f;
1302 break 'drain true;
1303 }
1304 Some(ClientConn::Tail(f)) => {
1305 info!("tail client connected while disconnected");
1306 spawn_tail(f, &ring_buf, &tail_tx);
1307 continue;
1308 }
1309 Some(ClientConn::Send(stream)) => {
1310 handle_send_stream(stream, &send_event_tx);
1311 continue;
1312 }
1313 None => {
1314 info!("client channel closed");
1315 break 'drain false;
1316 }
1317 }
1318 }
1319 status = managed.child.wait() => {
1320 let code = status?.code().unwrap_or(1);
1321 info!(code, "shell exited while awaiting client");
1322 break 'drain false;
1323 }
1324 ready = async_master.readable() => {
1325 let mut guard = ready?;
1326 match guard.try_io(|inner| {
1327 nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
1328 }) {
1329 Ok(Ok(0)) => {
1330 debug!("pty EOF while disconnected");
1331 break 'drain false;
1332 }
1333 Ok(Ok(n)) => {
1334 let chunk = Bytes::copy_from_slice(&buf[..n]);
1335 if tail_tx.receiver_count() > 0 {
1336 let _ = tail_tx.send(TailEvent::Data(chunk.clone()));
1337 }
1338 ring_buf_size += chunk.len();
1339 ring_buf.push_back(chunk);
1340 while ring_buf_size > ring_buffer_cap {
1341 if let Some(old) = ring_buf.pop_front() {
1342 ring_buf_size -= old.len();
1343 ring_buf_dropped += old.len();
1344 }
1345 }
1346 }
1347 Ok(Err(e)) => {
1348 if e.raw_os_error() == Some(libc::EIO) {
1349 debug!("pty EIO while disconnected");
1350 break 'drain false;
1351 }
1352 return Err(e.into());
1353 }
1354 Err(_would_block) => continue,
1355 }
1356 }
1357 event = send_event_rx.recv() => {
1358 if let Some(event) = event {
1359 handle_send_event(event, &mut transfer_state, &send_notify_tx);
1360 }
1361 continue;
1362 }
1363 }
1364 };
1365 if !got_client {
1366 break;
1367 }
1368
1369 if let Some(meta) = metadata_slot.get() {
1370 meta.attached.store(true, Ordering::Relaxed);
1371 }
1372 }
1373 first_client = false;
1374
1375 if !ring_buf.is_empty() {
1377 debug!(
1378 chunks = ring_buf.len(),
1379 bytes = ring_buf_size,
1380 dropped = ring_buf_dropped,
1381 "flushing ring buffer"
1382 );
1383 if ring_buf_dropped > 0 {
1384 let msg = format!(
1385 "\r\n\x1b[2;33m[gritty: {} bytes of output dropped]\x1b[0m\r\n",
1386 ring_buf_dropped
1387 );
1388 framed.send(Frame::Data(Bytes::from(msg))).await?;
1389 ring_buf_dropped = 0;
1390 }
1391 while let Some(chunk) = ring_buf.pop_front() {
1392 framed.send(Frame::Data(chunk)).await?;
1393 }
1394 ring_buf_size = 0;
1395 }
1396
1397 let exit = {
1401 let mut relay = ServerRelay {
1402 async_master: &async_master,
1403 agent: &mut agent,
1404 tunnel: &mut tunnel,
1405 pf: &mut pf,
1406 transfer_state: &mut transfer_state,
1407 open_forward_enabled: &mut open_forward_enabled,
1408 tail_tx: &tail_tx,
1409 metadata_slot: &metadata_slot,
1410 agent_event_tx: &agent_event_tx,
1411 tunnel_event_tx: &tunnel_event_tx,
1412 pf_event_tx: &pf_event_tx,
1413 send_notify_tx: &send_notify_tx,
1414 };
1415 loop {
1416 tokio::select! {
1417 frame = framed.next() => {
1418 if let ControlFlow::Break(exit) = relay.handle_client_frame(&mut framed, frame).await? {
1419 break exit;
1420 }
1421 }
1422
1423 ready = relay.async_master.readable() => {
1424 let mut guard = ready?;
1425 match guard.try_io(|inner| {
1426 nix::unistd::read(inner, &mut buf).map_err(io::Error::from)
1427 }) {
1428 Ok(Ok(0)) => {
1429 debug!("pty EOF");
1430 break RelayExit::ShellExited(0);
1431 }
1432 Ok(Ok(n)) => {
1433 debug!(len = n, "pty -> socket");
1434 let chunk = Bytes::copy_from_slice(&buf[..n]);
1435 if relay.tail_tx.receiver_count() > 0 {
1436 let _ = relay.tail_tx.send(TailEvent::Data(chunk.clone()));
1437 }
1438 framed.send(Frame::Data(chunk)).await?;
1439 }
1440 Ok(Err(e)) => {
1441 if e.raw_os_error() == Some(libc::EIO) {
1442 debug!("pty EIO (shell exited)");
1443 break RelayExit::ShellExited(0);
1444 }
1445 return Err(e.into());
1446 }
1447 Err(_would_block) => continue,
1448 }
1449 }
1450
1451 new_client = client_rx.recv() => {
1452 match new_client {
1453 Some(ClientConn::Active(new_framed)) => {
1454 info!("new client via channel, detaching old client");
1455 let _ = framed.send(Frame::Detached).await;
1456 relay.agent.teardown();
1457 relay.tunnel.teardown();
1458 relay.pf.teardown();
1459 *relay.open_forward_enabled = false;
1460 let was_attached = relay.metadata_slot.get()
1462 .map(|m| m.attached.load(Ordering::Relaxed))
1463 .unwrap_or(false);
1464 framed = new_framed;
1465 if was_attached {
1466 let hb_age = relay.metadata_slot.get()
1467 .and_then(|m| {
1468 let hb = m.last_heartbeat.load(Ordering::Relaxed);
1469 if hb == 0 { return None; }
1470 let now = std::time::SystemTime::now()
1471 .duration_since(std::time::UNIX_EPOCH)
1472 .unwrap_or_default()
1473 .as_secs();
1474 Some(now.saturating_sub(hb))
1475 });
1476 let hb_str = match hb_age {
1477 Some(s) => format!("{s}s ago"),
1478 None => "n/a".to_string(),
1479 };
1480 let msg = format!(
1481 "\r\n\x1b[2;33m[gritty: took over session (was active, heartbeat {hb_str})]\x1b[0m\r\n"
1482 );
1483 let _ = framed.send(Frame::Data(Bytes::from(msg))).await;
1484 }
1485 }
1486 Some(ClientConn::Tail(f)) => {
1487 info!("tail client connected while active");
1488 spawn_tail(f, &ring_buf, relay.tail_tx);
1489 }
1490 Some(ClientConn::Send(stream)) => {
1491 handle_send_stream(stream, &send_event_tx);
1492 }
1493 None => {}
1494 }
1495 }
1496
1497 event = agent_event_rx.recv() => {
1498 relay.handle_agent_event(&mut framed, event).await;
1499 }
1500
1501 event = open_event_rx.recv() => {
1502 relay.handle_open_event(&mut framed, event).await;
1503 }
1504
1505 event = tunnel_event_rx.recv() => {
1506 relay.handle_tunnel_event(&mut framed, event).await;
1507 }
1508
1509 _ = async {
1510 match relay.tunnel.idle_deadline {
1511 Some(deadline) => tokio::time::sleep_until(deadline).await,
1512 None => std::future::pending().await,
1513 }
1514 } => {
1515 if relay.tunnel.channels.is_empty() {
1516 debug!("tunnel idle timeout, tearing down");
1517 relay.tunnel.teardown();
1518 }
1519 }
1520
1521 event = send_event_rx.recv() => {
1522 if let Some(event) = event {
1523 handle_send_event(event, relay.transfer_state, relay.send_notify_tx);
1524 }
1525 }
1526
1527 notification = send_notify_rx.recv() => {
1528 relay.handle_send_notification(&mut framed, notification).await;
1529 }
1530
1531 event = pf_event_rx.recv() => {
1532 relay.handle_pf_event(&mut framed, event).await;
1533 }
1534
1535 status = managed.child.wait() => {
1536 let code = status?.code().unwrap_or(1);
1537 info!(code, "shell exited");
1538 let _ = relay.tail_tx.send(TailEvent::Exit { code });
1539 break RelayExit::ShellExited(code);
1540 }
1541 }
1542 }
1543 };
1544
1545 match exit {
1546 RelayExit::ClientGone => {
1547 if let Some(meta) = metadata_slot.get() {
1548 meta.attached.store(false, Ordering::Relaxed);
1549 }
1550 agent.teardown();
1551 tunnel.teardown();
1552 pf.teardown();
1553 open_forward_enabled = false;
1554 info!("client disconnected, waiting for reconnect");
1555 continue;
1556 }
1557 RelayExit::ShellExited(mut code) => {
1558 if let Ok(Ok(status)) = tokio::time::timeout(
1562 std::time::Duration::from_millis(500),
1563 managed.child.wait(),
1564 )
1565 .await
1566 {
1567 code = status.code().unwrap_or(code);
1568 }
1569 let _ = tail_tx.send(TailEvent::Exit { code });
1570 let _ = framed.send(Frame::Exit { code }).await;
1571 info!(code, "session ended");
1572 break;
1573 }
1574 }
1575 }
1576
1577 cleanup_socket(&agent.socket_path);
1578 cleanup_socket(&svc_socket_path);
1579 if let Some(handle) = svc_acceptor.take() {
1580 handle.abort();
1581 }
1582 Ok(())
1583}
1584
1585fn handle_send_stream(mut stream: UnixStream, send_event_tx: &mpsc::UnboundedSender<SendEvent>) {
1588 let etx = send_event_tx.clone();
1589 tokio::spawn(async move {
1590 let mut disc = [0u8; 1];
1591 if stream.read_exact(&mut disc).await.is_err() {
1592 return;
1593 }
1594 match crate::protocol::SvcRequest::from_byte(disc[0]) {
1595 Some(crate::protocol::SvcRequest::Send) => {
1596 match parse_sender_manifest(&mut stream).await {
1597 Ok(manifest) => {
1598 let _ = etx.send(SendEvent::SenderArrived { stream, manifest });
1599 }
1600 Err(e) => debug!("send stream: bad sender manifest: {e}"),
1601 }
1602 }
1603 Some(crate::protocol::SvcRequest::Receive) => {
1604 match parse_receiver_dest(&mut stream).await {
1605 Ok(_) => {
1606 let _ = etx.send(SendEvent::ReceiverArrived { stream });
1607 }
1608 Err(e) => debug!("send stream: bad receiver dest: {e}"),
1609 }
1610 }
1611 _ => {}
1612 }
1613 });
1614}
1615
1616fn stream_is_dead(stream: &UnixStream) -> bool {
1619 let mut probe = [0u8; 1];
1620 match stream.try_read(&mut probe) {
1621 Ok(0) => true, Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => false, Err(_) => true, Ok(_) => false, }
1626}
1627
1628fn handle_send_event(
1630 event: SendEvent,
1631 state: &mut TransferState,
1632 notify_tx: &mpsc::UnboundedSender<Frame>,
1633) {
1634 match event {
1635 SendEvent::SenderArrived { stream, manifest } => {
1636 let old = std::mem::replace(state, TransferState::Idle);
1637 match old {
1638 TransferState::WaitingForSender { receiver_stream }
1639 if !stream_is_dead(&receiver_stream) =>
1640 {
1641 info!(
1642 files = manifest.files.len(),
1643 bytes = manifest.total_bytes(),
1644 "transfer: sender+receiver paired"
1645 );
1646 let handle =
1647 spawn_transfer_relay(stream, receiver_stream, manifest, notify_tx.clone());
1648 *state = TransferState::Active { relay_handle: handle };
1649 }
1650 _ => {
1651 if let TransferState::Active { relay_handle } = old {
1652 let _ = notify_tx.send(Frame::SendCancel {
1653 reason: "superseded by new sender".to_string(),
1654 });
1655 relay_handle.abort();
1656 }
1657 info!(files = manifest.files.len(), "transfer: sender waiting for receiver");
1658 *state = TransferState::WaitingForReceiver { sender_stream: stream, manifest };
1659 }
1660 }
1661 }
1662 SendEvent::ReceiverArrived { stream } => {
1663 let old = std::mem::replace(state, TransferState::Idle);
1664 match old {
1665 TransferState::WaitingForReceiver { sender_stream, manifest }
1666 if !stream_is_dead(&sender_stream) =>
1667 {
1668 info!(
1669 files = manifest.files.len(),
1670 bytes = manifest.total_bytes(),
1671 "transfer: receiver+sender paired"
1672 );
1673 let handle =
1674 spawn_transfer_relay(sender_stream, stream, manifest, notify_tx.clone());
1675 *state = TransferState::Active { relay_handle: handle };
1676 }
1677 _ => {
1678 if let TransferState::Active { relay_handle } = old {
1679 let _ = notify_tx.send(Frame::SendCancel {
1680 reason: "superseded by new receiver".to_string(),
1681 });
1682 relay_handle.abort();
1683 }
1684 info!("transfer: receiver waiting for sender");
1685 *state = TransferState::WaitingForSender { receiver_stream: stream };
1686 }
1687 }
1688 }
1689 }
1690}
1691
1692fn bind_agent_listener(path: &Path) -> Option<UnixListener> {
1693 match crate::security::bind_unix_listener(path) {
1694 Ok(listener) => {
1695 info!(path = %path.display(), "agent socket listening");
1696 Some(listener)
1697 }
1698 Err(e) => {
1699 warn!("failed to bind agent socket at {}: {e}", path.display());
1700 None
1701 }
1702 }
1703}
1704
/// Removes the file at `path`, ignoring any error (for example when the file
/// does not exist).
fn cleanup_socket(path: &Path) {
    std::fs::remove_file(path).ok();
}
1708
#[cfg(test)]
mod tests {
    use super::*;

    /// `redirect_uri` pointing at `localhost` with an explicit port.
    #[test]
    fn extract_redirect_port_basic() {
        assert_eq!(
            extract_redirect_port(
                "https://accounts.google.com/o/oauth2/auth?redirect_uri=http://localhost:8080/callback&client_id=xyz"
            ),
            Some(8080)
        );
    }

    /// `redirect_uri` pointing at `127.0.0.1` with an explicit port.
    #[test]
    fn extract_redirect_port_127() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_uri=http://127.0.0.1:9090/cb"
            ),
            Some(9090)
        );
    }

    /// Percent-encoded `redirect_uri` value.
    #[test]
    fn extract_redirect_port_url_encoded() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A3000%2Fcallback"
            ),
            Some(3000)
        );
    }

    /// `localhost` redirect without an explicit port yields `None`.
    #[test]
    fn extract_redirect_port_no_port() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_uri=http://localhost/callback"
            ),
            None
        );
    }

    /// URL without any redirect parameter yields `None`.
    #[test]
    fn extract_redirect_port_no_redirect_uri() {
        assert_eq!(extract_redirect_port("https://example.com/page?foo=bar"), None);
    }

    /// Redirect to a non-local host yields `None` even with a port.
    #[test]
    fn extract_redirect_port_non_localhost() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_uri=https://example.com:8080/callback"
            ),
            None
        );
    }

    /// `https` scheme on the `localhost` redirect.
    #[test]
    fn extract_redirect_port_https_redirect() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_uri=https://localhost:4443/callback"
            ),
            Some(4443)
        );
    }

    /// The `redirect_url` parameter spelling is accepted as well.
    #[test]
    fn extract_redirect_url_variant() {
        assert_eq!(
            extract_redirect_port(
                "https://auth.example.com/authorize?redirect_url=http://localhost:5000/cb"
            ),
            Some(5000)
        );
    }
}