#![cfg(target_os = "linux")]
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpStream, UnixListener};
use tokio::task::JoinHandle;
use tracing::{debug, trace, warn};
/// Environment variable name `KODA_SANDBOX_STAGE2_UDS`.
///
/// NOTE(review): presumably used to hand the bridge socket path (as built by
/// [`proxy_uds_path`]) to the stage-2 process. It is not referenced by the code
/// shown here, so confirm against the consumer.
pub const STAGE2_UDS_ENV_KEY: &str = "KODA_SANDBOX_STAGE2_UDS";
/// Environment variable name `KODA_SANDBOX_STAGE2_REWRITE_KEYS`.
///
/// NOTE(review): presumably names the proxy-related environment variables whose
/// URLs the stage-2 process should rewrite (e.g. with [`rewrite_proxy_url_port`]).
/// It is not referenced by the code shown here, so confirm against the consumer.
pub const STAGE2_REWRITE_KEYS_ENV_KEY: &str = "KODA_SANDBOX_STAGE2_REWRITE_KEYS";
/// Builds the socket path for a UDS bridge inside the system temp directory.
///
/// The file name embeds both `parent_pid` and `proxy_port`. Bridges for
/// different parent processes or different proxy ports therefore never share a
/// path.
pub fn proxy_uds_path(parent_pid: u32, proxy_port: u16) -> PathBuf {
    let file_name = format!("koda-sandbox-proxy-{parent_pid}-{proxy_port}.sock");
    let mut path = std::env::temp_dir();
    path.push(file_name);
    path
}
pub async fn spawn_uds_bridge(uds_path: PathBuf, tcp_port: u16) -> Result<JoinHandle<()>> {
let _ = tokio::fs::remove_file(&uds_path).await;
let listener = UnixListener::bind(&uds_path)
.with_context(|| format!("bind UDS bridge at {}", uds_path.display()))?;
debug!(
"uds bridge listening on {} → 127.0.0.1:{tcp_port}",
uds_path.display()
);
let _bound_path = uds_path;
let task = tokio::spawn(async move {
loop {
let (uds_stream, _) = match listener.accept().await {
Ok(pair) => pair,
Err(e) => {
trace!("uds bridge accept error (likely shutdown): {e}");
return;
}
};
tokio::spawn(async move {
let tcp_stream = match TcpStream::connect(("127.0.0.1", tcp_port)).await {
Ok(s) => s,
Err(e) => {
trace!("uds bridge: TCP connect to 127.0.0.1:{tcp_port} failed: {e}");
return;
}
};
if let Err(e) = bridge_streams(uds_stream, tcp_stream).await {
trace!("uds bridge: pipe error: {e}");
}
});
}
});
Ok(task)
}
pub fn cleanup_uds_path(path: &Path) {
if let Err(e) = std::fs::remove_file(path)
&& e.kind() != std::io::ErrorKind::NotFound
{
warn!(
"uds bridge cleanup: failed to remove {}: {e}",
path.display()
);
}
}
/// Pumps bytes in both directions between `uds_stream` and `tcp_stream` until
/// both directions have finished.
///
/// Each direction copies until its reader reaches EOF or fails. It then shuts
/// down the opposite writer, so a half-close on one side reaches the other side
/// while the reverse direction keeps flowing.
///
/// # Errors
///
/// If either direction's copy fails, that error is returned after both
/// directions have completed. When both fail, the UDS→TCP error is reported.
/// Shutdown failures are ignored because the peer may already be gone.
async fn bridge_streams(
    uds_stream: tokio::net::UnixStream,
    tcp_stream: TcpStream,
) -> std::io::Result<()> {
    let (mut uds_r, mut uds_w) = uds_stream.into_split();
    let (mut tcp_r, mut tcp_w) = tcp_stream.into_split();
    let uds_to_tcp = async move {
        let copied = tokio::io::copy(&mut uds_r, &mut tcp_w).await;
        // Best-effort half-close; the TCP peer may already have reset.
        let _ = tcp_w.shutdown().await;
        copied
    };
    let tcp_to_uds = async move {
        let copied = tokio::io::copy(&mut tcp_r, &mut uds_w).await;
        // Best-effort half-close; the UDS client may already have gone away.
        let _ = uds_w.shutdown().await;
        copied
    };
    // Drive both halves to completion before reporting anything, so one failed
    // direction never cuts the other short.
    let (upstream, downstream) = tokio::join!(uds_to_tcp, tcp_to_uds);
    upstream?;
    downstream?;
    Ok(())
}
/// Replaces the port of a proxy URL such as `http://user:pw@host:1234/path`.
///
/// Scheme, userinfo, host and everything after the authority are kept intact.
/// A port is appended when the URL has none. Bracketed IPv6 hosts
/// (`http://[::1]:8080`, `http://[::1]`) are supported.
///
/// Following RFC 3986 §3.2, the authority ends at the first `/`, `?` or `#`.
/// The last `@` inside it separates userinfo from host.
///
/// Returns `None` when `url` has no `://` separator, when the host is empty, or
/// when a bracketed IPv6 host is unterminated or followed by something other
/// than `:port`.
pub fn rewrite_proxy_url_port(url: &str, new_port: u16) -> Option<String> {
    let scheme_end = url.find("://")? + 3;
    let (prefix, after_scheme) = url.split_at(scheme_end);
    let authority_end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let (authority, rest) = after_scheme.split_at(authority_end);
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);
    // `host_port` is a suffix of `authority`, so this is the userinfo including
    // its trailing '@' (or "" when there is no userinfo).
    let userinfo = &authority[..authority.len() - host_port.len()];
    let host = if host_port.starts_with('[') {
        // IPv6 literal: the colons inside the brackets are not port separators.
        let close = host_port.find(']')?;
        let (bracketed, after) = host_port.split_at(close + 1);
        if !(after.is_empty() || after.starts_with(':')) {
            return None;
        }
        bracketed
    } else {
        host_port.rsplit_once(':').map_or(host_port, |(h, _)| h)
    };
    if host.is_empty() || host == "[]" {
        return None;
    }
    Some(format!("{prefix}{userinfo}{host}:{new_port}{rest}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::os::unix::fs::FileTypeExt;
    use std::time::Duration;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, UnixStream};

    #[test]
    fn proxy_uds_path_includes_pid_and_port() {
        let p = proxy_uds_path(12345, 54321);
        let s = p.to_string_lossy();
        assert!(s.contains("12345"), "pid missing: {s}");
        assert!(s.contains("54321"), "port missing: {s}");
        assert!(s.ends_with(".sock"), "extension wrong: {s}");
        assert!(p.starts_with(std::env::temp_dir()), "not in tmp: {s}");
    }

    #[test]
    fn proxy_uds_path_distinct_for_distinct_inputs() {
        let a = proxy_uds_path(1, 100);
        let b = proxy_uds_path(1, 101);
        assert_ne!(a, b);
        let c = proxy_uds_path(2, 100);
        assert_ne!(a, c);
    }

    #[test]
    fn rewrite_proxy_url_port_handles_basic_http() {
        let out = rewrite_proxy_url_port("http://127.0.0.1:1234", 9999).unwrap();
        assert_eq!(out, "http://127.0.0.1:9999");
    }

    #[test]
    fn rewrite_proxy_url_port_handles_no_existing_port() {
        let out = rewrite_proxy_url_port("http://localhost", 9999).unwrap();
        assert_eq!(out, "http://localhost:9999");
    }

    #[test]
    fn rewrite_proxy_url_port_preserves_path() {
        let out = rewrite_proxy_url_port("http://127.0.0.1:1234/path?q=1", 9999).unwrap();
        assert_eq!(out, "http://127.0.0.1:9999/path?q=1");
    }

    #[test]
    fn rewrite_proxy_url_port_preserves_userinfo() {
        let out = rewrite_proxy_url_port("http://user:pw@127.0.0.1:1234", 9999).unwrap();
        assert_eq!(out, "http://user:pw@127.0.0.1:9999");
    }

    #[test]
    fn rewrite_proxy_url_port_handles_socks_scheme() {
        let out = rewrite_proxy_url_port("socks5h://127.0.0.1:1080", 9999).unwrap();
        assert_eq!(out, "socks5h://127.0.0.1:9999");
    }

    #[test]
    fn rewrite_proxy_url_port_returns_none_for_garbage() {
        assert_eq!(rewrite_proxy_url_port("not a url", 9999), None);
        assert_eq!(rewrite_proxy_url_port("http://", 9999), None);
    }

    /// End-to-end: bytes written to the UDS side reach a TCP echo server and
    /// the echo comes back through the bridge.
    #[tokio::test]
    async fn uds_bridge_forwards_bytes_to_tcp_listener() {
        let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let tcp_port = tcp.local_addr().unwrap().port();
        // One-shot echo server: reads exactly 5 bytes and writes them back.
        tokio::spawn(async move {
            let (mut sock, _) = tcp.accept().await.unwrap();
            let mut buf = [0u8; 5];
            sock.read_exact(&mut buf).await.unwrap();
            sock.write_all(&buf).await.unwrap();
        });
        let uds_dir = tempfile::tempdir().unwrap();
        let uds_path = uds_dir.path().join("bridge.sock");
        let bridge = spawn_uds_bridge(uds_path.clone(), tcp_port).await.unwrap();
        let mut client = UnixStream::connect(&uds_path).await.unwrap();
        client.write_all(b"hello").await.unwrap();
        let mut got = [0u8; 5];
        tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut got))
            .await
            .expect("must echo within 2s")
            .unwrap();
        assert_eq!(&got, b"hello");
        bridge.abort();
        cleanup_uds_path(&uds_path);
    }

    /// A leftover file at the socket path must be replaced by the new socket
    /// rather than making `bind` fail.
    #[tokio::test]
    async fn uds_bridge_removes_stale_socket_file() {
        let uds_dir = tempfile::tempdir().unwrap();
        let uds_path = uds_dir.path().join("stale.sock");
        std::fs::write(&uds_path, b"leftover").unwrap();
        assert!(uds_path.exists());
        // Port 1 is never connected to: this test only exercises the bind path.
        let bridge = spawn_uds_bridge(uds_path.clone(), 1).await.unwrap();
        // `exists()` would also be true for the stale regular file, so check the
        // file type to prove the path now holds the freshly bound socket.
        let meta = std::fs::symlink_metadata(&uds_path).unwrap();
        assert!(
            meta.file_type().is_socket(),
            "stale file was not replaced by a socket"
        );
        bridge.abort();
        cleanup_uds_path(&uds_path);
    }

    /// Cleanup removes the file, and a second call on the now-missing path is a
    /// silent no-op.
    #[test]
    fn cleanup_uds_path_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("gone.sock");
        std::fs::write(&path, b"x").unwrap();
        cleanup_uds_path(&path);
        assert!(!path.exists());
        cleanup_uds_path(&path);
        assert!(!path.exists());
    }
}