use std::{
collections::{BTreeMap, BTreeSet},
net::{Ipv4Addr, SocketAddr},
sync::{Arc, Mutex},
};
use netstack::{CreateSocket, netcore::Channel, netsock::TcpStream as OverlayStream};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
sync::{Semaphore, mpsc},
};
use ts_control::{ServeState, ServeTarget, tls::TlsAcceptor};
/// Maximum number of connections a single served port may have in flight at
/// once. The same figure sizes the channel that carries [`ServeAccepted`]
/// connections back to the caller of `ServeManager::set`.
const MAX_SERVE_CONNS_PER_PORT: usize = 256;

/// A TLS-terminated connection on a port whose target is `ServeTarget::Accept`,
/// handed to the application instead of being served internally.
pub struct ServeAccepted {
    /// Overlay port the connection arrived on.
    pub port: u16,
    /// Decrypted, type-erased byte stream for the connection.
    pub stream: Box<dyn AsyncReadWrite>,
}

/// Object-safe umbrella trait for "an async byte stream that can move between
/// tasks", so differently typed streams can be boxed behind one `dyn` type.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {}

impl<T> AsyncReadWrite for T
where
    T: AsyncRead + AsyncWrite + Send + Unpin,
{
}

/// Receiving half of the channel returned by `ServeManager::set`.
pub type ServeAcceptedReceiver = mpsc::Receiver<ServeAccepted>;
/// Listener configuration for one served port.
pub struct ResolvedPort {
    /// What to do with each connection accepted on the port.
    pub target: ServeTarget,
    /// TLS acceptor used by every target other than `ServeTarget::TcpForward`;
    /// connections on such ports are dropped when it is `None`.
    pub acceptor: Option<TlsAcceptor>,
}

/// Mutable state shared behind the [`ServeManager`] mutex.
struct Inner {
    /// Configuration most recently passed to `ServeManager::set`.
    state: ServeState,
    /// Abort handle of the listener task for each currently served port.
    ports: BTreeMap<u16, tokio::task::AbortHandle>,
}

impl Drop for Inner {
    /// Cancels every listener task when the shared state is destroyed.
    fn drop(&mut self) {
        self.ports
            .values()
            .for_each(tokio::task::AbortHandle::abort);
    }
}

/// Owns the overlay listener tasks that implement the serve configuration.
pub struct ServeManager {
    inner: Arc<Mutex<Inner>>,
    /// Netstack handle used to open overlay listeners.
    channel: Channel,
    /// Overlay address every listener binds to.
    self_ipv4: Ipv4Addr,
}
impl ServeManager {
pub fn new(channel: Channel, self_ipv4: Ipv4Addr) -> Self {
Self {
inner: Arc::new(Mutex::new(Inner {
state: ServeState::default(),
ports: BTreeMap::new(),
})),
channel,
self_ipv4,
}
}
pub fn get(&self) -> ServeState {
self.inner
.lock()
.unwrap_or_else(|e| e.into_inner())
.state
.clone()
}
pub fn set(
&self,
state: ServeState,
resolved: BTreeMap<u16, ResolvedPort>,
) -> ServeAcceptedReceiver {
let (accept_tx, accept_rx) = mpsc::channel::<ServeAccepted>(MAX_SERVE_CONNS_PER_PORT);
let mut new_ports: BTreeMap<u16, tokio::task::AbortHandle> = BTreeMap::new();
for (port, rp) in resolved {
let channel = self.channel.clone();
let self_ipv4 = self.self_ipv4;
let accept_tx = accept_tx.clone();
let handle = tokio::spawn(async move {
if let Err(e) = run_port(channel, self_ipv4, port, rp, accept_tx).await {
tracing::warn!(%port, error = %e, "serve listener exited");
}
})
.abort_handle();
new_ports.insert(port, handle);
}
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.state = state;
let old = std::mem::replace(&mut inner.ports, new_ports);
drop(inner);
for h in old.values() {
h.abort();
}
accept_rx
}
}
/// Computes which ports must be (re)started and which must be stopped when the
/// per-port configuration changes from `current` to `next`.
///
/// Returns `(to_add, to_remove)`:
/// - `to_add`: ports in `next` that are absent from `current` or whose value differs;
/// - `to_remove`: ports in `current` that are absent from `next` or whose value differs.
///
/// A port whose value changed therefore appears in both sets (stop the old
/// listener, start the new one), and an unchanged port appears in neither.
/// Generic over the per-port value so it can be used with `ServeTarget` or any
/// other comparable configuration.
#[cfg_attr(not(test), allow(dead_code))]
fn pure_reconcile<T: PartialEq>(
    current: &BTreeMap<u16, T>,
    next: &BTreeMap<u16, T>,
) -> (BTreeSet<u16>, BTreeSet<u16>) {
    let to_add: BTreeSet<u16> = next
        .iter()
        .filter(|&(port, target)| current.get(port) != Some(target))
        .map(|(port, _)| *port)
        .collect();
    // Iterate pairs directly: re-looking up `port` in `current` is redundant.
    let to_remove: BTreeSet<u16> = current
        .iter()
        .filter(|&(port, target)| next.get(port) != Some(target))
        .map(|(port, _)| *port)
        .collect();
    (to_add, to_remove)
}
async fn run_port(
channel: Channel,
self_ipv4: Ipv4Addr,
port: u16,
rp: ResolvedPort,
accept_tx: mpsc::Sender<ServeAccepted>,
) -> Result<(), netstack::netcore::Error> {
let listen_addr = SocketAddr::new(self_ipv4.into(), port);
let listener = channel.tcp_listen(listen_addr).await?;
tracing::debug!(%port, "serve listener accepting");
let rp = Arc::new(rp);
let inflight = Arc::new(Semaphore::new(MAX_SERVE_CONNS_PER_PORT));
loop {
let Ok(permit) = inflight.clone().acquire_owned().await else {
return Ok(());
};
let overlay = listener.accept().await?;
let rp = rp.clone();
let accept_tx = accept_tx.clone();
tokio::spawn(async move {
let _permit = permit; dispatch_conn(port, overlay, rp, accept_tx).await;
});
}
}
/// Serves one accepted overlay connection according to `rp.target`.
///
/// `ServeTarget::TcpForward` splices the raw overlay stream to its backend
/// without touching TLS. Every other target first completes a TLS handshake
/// with `rp.acceptor`; the connection is dropped if the acceptor is missing or
/// the handshake fails. The decrypted stream is then handed off (`Accept`),
/// proxied, answered with text or a redirect, or routed by path.
async fn dispatch_conn(
    port: u16,
    overlay: OverlayStream,
    rp: Arc<ResolvedPort>,
    accept_tx: mpsc::Sender<ServeAccepted>,
) {
    // Raw TCP forwarding bypasses TLS entirely.
    if let ServeTarget::TcpForward { to } = &rp.target {
        forward_to_backend(port, overlay, to).await;
        return;
    }

    // Every remaining target terminates TLS first.
    let Some(acceptor) = rp.acceptor.as_ref() else {
        tracing::warn!(%port, "serve: missing TLS acceptor for TLS port; dropping conn");
        return;
    };
    let stream = match acceptor.accept(overlay).await {
        Ok(stream) => stream,
        Err(e) => {
            tracing::debug!(%port, error = %e, "serve: TLS handshake failed; dropping conn");
            return;
        }
    };

    match &rp.target {
        ServeTarget::Accept => {
            let handoff = ServeAccepted {
                port,
                stream: Box::new(stream),
            };
            // A closed receiver means nobody consumes accepted connections any
            // more; the stream is dropped together with the send error.
            if accept_tx.send(handoff).await.is_err() {
                tracing::debug!(%port, "serve: accept receiver dropped; closing conn");
            }
        }
        ServeTarget::Proxy { to } => proxy_to_backend(port, stream, to).await,
        ServeTarget::Text { body } => write_text(port, stream, body).await,
        ServeTarget::Redirect { to, status } => serve_redirect(port, stream, to, *status).await,
        ServeTarget::Path { handlers } => serve_path(port, stream, handlers).await,
        other => {
            debug_assert!(
                !other.terminates_tls(),
                "TLS-terminating ServeTarget reached fall-through arm"
            );
            tracing::warn!(%port, "serve: unhandled ServeTarget on TLS port; dropping conn");
        }
    }
}
/// Splices `tls` to the TCP backend at `to`, replaying nothing first.
///
/// Equivalent to `proxy_to_backend_with_prefix` with an empty prefix.
async fn proxy_to_backend<S>(port: u16, tls: S, to: &str)
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    const NO_PREFIX: &[u8] = b"";
    proxy_to_backend_with_prefix(port, tls, to, NO_PREFIX).await
}
/// Dials `to` over the host network, writes `prefix` to it, and then copies
/// bytes in both directions between `tls` and the backend until either side
/// closes.
///
/// `prefix` carries bytes the caller already consumed from `tls` (for example
/// a parsed request head) that the backend must still receive. An empty
/// prefix writes nothing. Every failure is logged at debug level and ends the
/// connection.
async fn proxy_to_backend_with_prefix<S>(port: u16, mut tls: S, to: &str, prefix: &[u8])
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let dial = tokio::net::TcpStream::connect(to).await;
    let mut backend = match dial {
        Err(e) => {
            tracing::debug!(%port, %to, error = %e, "serve proxy: backend dial failed; dropping conn");
            return;
        }
        Ok(stream) => stream,
    };
    // Replay already-consumed bytes so the backend sees the stream exactly as
    // the client sent it.
    let replay = if prefix.is_empty() {
        Ok(())
    } else {
        backend.write_all(prefix).await
    };
    if let Err(e) = replay {
        tracing::debug!(%port, %to, error = %e, "serve proxy: prefix replay failed; dropping conn");
        return;
    }
    if let Some(e) = tokio::io::copy_bidirectional(&mut tls, &mut backend)
        .await
        .err()
    {
        tracing::debug!(%port, %to, error = %e, "serve proxy: splice ended");
    }
}
/// Splices the raw overlay stream (TLS is not terminated here) to the TCP
/// backend at `to` until either side closes. Failures are logged at debug level.
async fn forward_to_backend(port: u16, mut overlay: OverlayStream, to: &str) {
    let dial = tokio::net::TcpStream::connect(to).await;
    let mut backend = match dial {
        Err(e) => {
            tracing::debug!(%port, %to, error = %e, "serve forward: backend dial failed; dropping conn");
            return;
        }
        Ok(stream) => stream,
    };
    if let Some(e) = tokio::io::copy_bidirectional(&mut overlay, &mut backend)
        .await
        .err()
    {
        tracing::debug!(%port, %to, error = %e, "serve forward: splice ended");
    }
}
/// Writes `body` verbatim to the stream, with no HTTP framing, then closes it.
///
/// A write failure is logged and the stream is dropped without shutdown. A
/// flush failure is logged but shutdown is still attempted (best effort).
async fn write_text<S>(port: u16, mut tls: S, body: &str)
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    match tls.write_all(body.as_bytes()).await {
        Err(e) => {
            tracing::debug!(%port, error = %e, "serve text: write failed");
            return;
        }
        Ok(()) => {}
    }
    if let Some(e) = tls.flush().await.err() {
        tracing::debug!(%port, error = %e, "serve text: flush failed");
    }
    // The connection ends here either way, so the shutdown result is irrelevant.
    let _ = tls.shutdown().await;
}
/// Upper bound, in bytes, on how much of a request `read_http_head` buffers
/// while looking for the end of the HTTP head.
const MAX_HTTP_HEAD: usize = 8 * 1024;

/// Reads from `stream` until a complete HTTP request head has been buffered.
///
/// Returns `Some((buf, end))`, where `buf[..end]` is the head (as delimited by
/// `crate::peerapi_doh::find_header_end`) and `buf[end..]` holds any bytes that
/// were read past it. `buf` never grows beyond [`MAX_HTTP_HEAD`] bytes.
///
/// Returns `None` in three cases: EOF before the head is complete, a read
/// error, or [`MAX_HTTP_HEAD`] bytes buffered without finding the end of the
/// head.
async fn read_http_head<S>(stream: &mut S) -> Option<(Vec<u8>, usize)>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::AsyncReadExt;
    let mut buf = Vec::with_capacity(1024);
    let mut tmp = [0u8; 1024];
    loop {
        if let Some(end) = crate::peerapi_doh::find_header_end(&buf) {
            return Some((buf, end));
        }
        // Read no more than the remaining budget. Otherwise `buf` could
        // overshoot MAX_HTTP_HEAD by up to `tmp.len() - 1` bytes, and a head
        // ending past the limit would be accepted. `room` must also be checked
        // for zero, because reading into an empty slice returns Ok(0), which
        // would be indistinguishable from EOF.
        let room = MAX_HTTP_HEAD.saturating_sub(buf.len()).min(tmp.len());
        if room == 0 {
            return None;
        }
        match stream.read(&mut tmp[..room]).await {
            Ok(0) | Err(_) => return None,
            Ok(n) => buf.extend_from_slice(&tmp[..n]),
        }
    }
}
/// Parses the request line in `buf` and returns its request-target with any
/// query string (`?…`) removed.
///
/// Returns `None` when `httparse` rejects `buf` (including when it carries more
/// than 32 headers), or when no path has been parsed yet.
fn request_path(buf: &[u8]) -> Option<String> {
    let mut header_slots = [httparse::EMPTY_HEADER; 32];
    let mut request = httparse::Request::new(&mut header_slots);
    request.parse(buf).ok()?;
    let target = request.path?;
    let without_query = match target.split_once('?') {
        Some((path, _query)) => path,
        None => target,
    };
    Some(without_query.to_owned())
}
/// Reason phrase used in the status line of a redirect response.
///
/// Statuses 301, 302, 303, 307 and 308 get their standard phrase; any other
/// code gets the generic `"Redirect"`.
fn redirect_reason(status: u16) -> &'static str {
    const KNOWN: [(u16, &str); 5] = [
        (301, "Moved Permanently"),
        (302, "Found"),
        (303, "See Other"),
        (307, "Temporary Redirect"),
        (308, "Permanent Redirect"),
    ];
    KNOWN
        .iter()
        .find(|(code, _)| *code == status)
        .map_or("Redirect", |&(_, reason)| reason)
}
async fn serve_redirect<S>(port: u16, mut tls: S, to: &str, status: u16)
where
S: AsyncRead + AsyncWrite + Unpin,
{
let head = format!(
"HTTP/1.1 {status} {reason}\r\nLocation: {to}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
reason = redirect_reason(status),
);
if let Err(e) = tls.write_all(head.as_bytes()).await {
tracing::debug!(%port, error = %e, "serve redirect: write failed");
return;
}
if let Err(e) = tls.flush().await {
tracing::debug!(%port, error = %e, "serve redirect: flush failed");
}
drop(tls.shutdown().await);
}
/// Writes a bodiless `HTTP/1.1 {status}` response (e.g. `"404 Not Found"`)
/// with `Connection: close`, then closes the stream.
///
/// Only a failed write is logged. Flush and shutdown errors are ignored, and
/// both steps are skipped when the write fails.
async fn write_http_status<S>(port: u16, mut tls: S, status: &str)
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let response =
        format!("HTTP/1.1 {status}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n");
    match tls.write_all(response.as_bytes()).await {
        Ok(()) => {
            let _ = tls.flush().await;
            let _ = tls.shutdown().await;
        }
        Err(e) => {
            tracing::debug!(%port, error = %e, "serve path: status write failed");
        }
    }
}
/// Serves one HTTP request by routing on its path to one of `handlers`.
///
/// The request head is read (bounded by `MAX_HTTP_HEAD`). Its path, without
/// the query string, is compared against every handler key as a plain string
/// prefix, and the longest matching key wins. Outcomes:
/// - incomplete/oversized head: connection dropped without a response;
/// - unparseable request: `400 Bad Request`;
/// - no matching key: `404 Not Found`;
/// - matched `Proxy`: the consumed bytes are replayed to the backend before
///   splicing the rest of the stream;
/// - matched `Text` / `Redirect`: served directly;
/// - any other nested target: warning logged, `404 Not Found`.
async fn serve_path<S>(port: u16, mut tls: S, handlers: &BTreeMap<String, ServeTarget>)
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let Some((head, _)) = read_http_head(&mut tls).await else {
        tracing::debug!(%port, "serve path: incomplete/oversized request head; dropping conn");
        return;
    };
    let path = match request_path(&head) {
        Some(path) => path,
        None => {
            write_http_status(port, tls, "400 Bad Request").await;
            return;
        }
    };
    // Longest matching prefix. Two distinct keys that both prefix `path` must
    // differ in length, so there are no ties to break.
    let chosen = handlers
        .iter()
        .filter_map(|(prefix, target)| {
            path.starts_with(prefix.as_str())
                .then_some((prefix.len(), target))
        })
        .max_by_key(|&(len, _)| len)
        .map(|(_, target)| target);
    match chosen {
        None => write_http_status(port, tls, "404 Not Found").await,
        Some(ServeTarget::Proxy { to }) => {
            proxy_to_backend_with_prefix(port, tls, to, &head).await
        }
        Some(ServeTarget::Text { body }) => write_text(port, tls, body).await,
        Some(ServeTarget::Redirect { to, status }) => {
            serve_redirect(port, tls, to, *status).await
        }
        Some(_) => {
            tracing::warn!(%port, "serve path: unsupported nested target; dropping conn");
            write_http_status(port, tls, "404 Not Found").await;
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for port reconciliation, HTTP head parsing/routing and the
    //! small response writers.
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Shorthand for a `ServeTarget::Proxy` to `to`.
    fn proxy(to: &str) -> ServeTarget {
        ServeTarget::Proxy { to: to.into() }
    }

    #[test]
    fn cap_is_bounded() {
        assert_eq!(MAX_SERVE_CONNS_PER_PORT, 256);
    }

    #[test]
    fn reconcile_adds_new_ports() {
        let current = BTreeMap::new();
        let mut next = BTreeMap::new();
        next.insert(443u16, ServeTarget::Accept);
        next.insert(8443u16, proxy("127.0.0.1:8080"));
        let (add, remove) = pure_reconcile(&current, &next);
        assert_eq!(add, BTreeSet::from([443, 8443]));
        assert!(remove.is_empty());
    }

    #[test]
    fn reconcile_removes_dropped_ports() {
        let mut current = BTreeMap::new();
        current.insert(443u16, ServeTarget::Accept);
        current.insert(8443u16, proxy("127.0.0.1:8080"));
        let mut next = BTreeMap::new();
        next.insert(443u16, ServeTarget::Accept);
        let (add, remove) = pure_reconcile(&current, &next);
        assert!(add.is_empty());
        assert_eq!(remove, BTreeSet::from([8443]));
    }

    #[test]
    fn reconcile_changed_port_is_remove_and_add() {
        let mut current = BTreeMap::new();
        current.insert(443u16, proxy("127.0.0.1:8080"));
        let mut next = BTreeMap::new();
        next.insert(443u16, proxy("127.0.0.1:9090"));
        let (add, remove) = pure_reconcile(&current, &next);
        assert_eq!(add, BTreeSet::from([443]));
        assert_eq!(remove, BTreeSet::from([443]));
    }

    #[test]
    fn reconcile_unchanged_port_is_noop() {
        let mut current = BTreeMap::new();
        current.insert(443u16, ServeTarget::Accept);
        let next = current.clone();
        let (add, remove) = pure_reconcile(&current, &next);
        assert!(add.is_empty());
        assert!(remove.is_empty());
    }

    #[test]
    fn terminates_tls_matches_dispatch_arm() {
        assert!(ServeTarget::Accept.terminates_tls());
        assert!(proxy("127.0.0.1:8080").terminates_tls());
        assert!(ServeTarget::Text { body: "ok".into() }.terminates_tls());
        assert!(
            ServeTarget::Redirect {
                to: "/elsewhere".into(),
                status: 302,
            }
            .terminates_tls()
        );
        let mut handlers = BTreeMap::new();
        handlers.insert("/".to_string(), proxy("127.0.0.1:8080"));
        assert!(ServeTarget::Path { handlers }.terminates_tls());
        assert!(
            !ServeTarget::TcpForward {
                to: "127.0.0.1:5000".into()
            }
            .terminates_tls()
        );
    }

    #[test]
    fn find_header_end_shared_with_peerapi_doh() {
        assert_eq!(
            crate::peerapi_doh::find_header_end(b"GET / HTTP/1.1\r\n\r\n"),
            Some(18)
        );
        assert_eq!(
            crate::peerapi_doh::find_header_end(b"GET / HTTP/1.1\r\n"),
            None
        );
    }

    #[test]
    fn request_path_strips_query() {
        assert_eq!(
            request_path(b"GET /api/v1?x=1 HTTP/1.1\r\nHost: h\r\n\r\n").as_deref(),
            Some("/api/v1")
        );
        assert_eq!(
            request_path(b"GET / HTTP/1.1\r\n\r\n").as_deref(),
            Some("/")
        );
        assert_eq!(request_path(b"not a request").as_deref(), None);
    }

    #[test]
    fn request_path_none_on_malformed_request_line() {
        assert_eq!(request_path(b"GARBAGE\r\n\r\n").as_deref(), None);
        assert_eq!(request_path(b"").as_deref(), None);
    }

    #[test]
    fn longest_prefix_wins() {
        let mut handlers: BTreeMap<String, ServeTarget> = BTreeMap::new();
        handlers.insert("/".to_string(), proxy("127.0.0.1:1"));
        handlers.insert("/api".to_string(), proxy("127.0.0.1:2"));
        handlers.insert("/api/v2".to_string(), proxy("127.0.0.1:3"));
        // Mirrors the selection logic in `serve_path`.
        let pick = |path: &str| -> Option<&ServeTarget> {
            handlers
                .iter()
                .filter(|(prefix, _)| path.starts_with(prefix.as_str()))
                .max_by_key(|(prefix, _)| prefix.len())
                .map(|(_, target)| target)
        };
        assert_eq!(pick("/api/v2/x"), Some(&proxy("127.0.0.1:3")));
        assert_eq!(pick("/api/v1"), Some(&proxy("127.0.0.1:2")));
        assert_eq!(pick("/other"), Some(&proxy("127.0.0.1:1")));
    }

    #[test]
    fn redirect_reason_known_statuses() {
        assert_eq!(redirect_reason(301), "Moved Permanently");
        assert_eq!(redirect_reason(308), "Permanent Redirect");
        assert_eq!(redirect_reason(399), "Redirect");
    }

    /// Reads everything the server side writes until it closes, as UTF-8.
    async fn drain_to_string(mut client: tokio::io::DuplexStream) -> String {
        let mut out = Vec::new();
        drop(client.read_to_end(&mut out).await);
        String::from_utf8(out).expect("server emitted valid utf8")
    }

    #[tokio::test]
    async fn serve_redirect_emits_exact_response() {
        let (client, server) = tokio::io::duplex(4096);
        let t = tokio::spawn(async move {
            serve_redirect(443, server, "/elsewhere", 302).await;
        });
        let got = drain_to_string(client).await;
        t.await.unwrap();
        assert_eq!(
            got,
            "HTTP/1.1 302 Found\r\nLocation: /elsewhere\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn write_http_status_emits_status_line() {
        let (client, server) = tokio::io::duplex(4096);
        let t = tokio::spawn(async move {
            write_http_status(443, server, "404 Not Found").await;
        });
        let got = drain_to_string(client).await;
        t.await.unwrap();
        assert_eq!(
            got,
            "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        );
        let (client, server) = tokio::io::duplex(4096);
        let t = tokio::spawn(async move {
            write_http_status(443, server, "400 Bad Request").await;
        });
        let got = drain_to_string(client).await;
        t.await.unwrap();
        assert_eq!(
            got,
            "HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn read_http_head_reads_terminated_head() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        client
            .write_all(b"GET /api HTTP/1.1\r\nHost: h\r\n\r\nBODY")
            .await
            .unwrap();
        drop(client);
        let (buf, end) = read_http_head(&mut server).await.expect("complete head");
        assert_eq!(&buf[..end], b"GET /api HTTP/1.1\r\nHost: h\r\n\r\n");
        assert_eq!(&buf[end..], b"BODY");
    }

    #[tokio::test]
    async fn read_http_head_none_on_early_eof() {
        let (mut client, mut server) = tokio::io::duplex(4096);
        client.write_all(b"GET / HTTP/1.1\r\n").await.unwrap();
        drop(client);
        assert!(read_http_head(&mut server).await.is_none());
    }

    #[tokio::test]
    async fn read_http_head_none_on_oversized_head() {
        let (mut client, mut server) = tokio::io::duplex(64 * 1024);
        let oversized = vec![b'a'; MAX_HTTP_HEAD + 1024];
        client.write_all(&oversized).await.unwrap();
        drop(client);
        assert!(read_http_head(&mut server).await.is_none());
    }

    #[tokio::test]
    async fn read_http_head_never_exceeds_max_head() {
        let (mut client, mut server) = tokio::io::duplex(MAX_HTTP_HEAD + 16);
        let mut head = vec![b'a'; MAX_HTTP_HEAD - 4];
        head.extend_from_slice(b"\r\n\r\n");
        assert_eq!(head.len(), MAX_HTTP_HEAD);
        client.write_all(&head).await.unwrap();
        drop(client);
        let (buf, end) = read_http_head(&mut server).await.expect("head at bound");
        assert_eq!(end, MAX_HTTP_HEAD);
        assert!(buf.len() <= MAX_HTTP_HEAD);
    }

    #[tokio::test]
    async fn proxy_with_prefix_writes_prefix_before_bidi_copy() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let backend_addr = listener.local_addr().unwrap();
        let prefix = b"GET /api HTTP/1.1\r\nHost: h\r\n\r\n";
        let body = b"trailing-body-bytes";
        let backend = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut head = vec![0u8; prefix.len()];
            sock.read_exact(&mut head).await.unwrap();
            let mut rest = vec![0u8; body.len()];
            sock.read_exact(&mut rest).await.unwrap();
            (head, rest)
        });
        let (mut client, server) = tokio::io::duplex(4096);
        let to = backend_addr.to_string();
        let proxy_task = tokio::spawn(async move {
            proxy_to_backend_with_prefix(443, server, &to, prefix).await;
        });
        client.write_all(body).await.unwrap();
        drop(client);
        let (head, rest) = backend.await.unwrap();
        proxy_task.await.unwrap();
        assert_eq!(
            head, prefix,
            "prefix (consumed head) replayed to backend first"
        );
        assert_eq!(rest, body, "remaining stream spliced after the prefix");
    }

    #[tokio::test]
    async fn serve_path_proxy_replays_consumed_head_to_backend() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let backend_addr = listener.local_addr().unwrap();
        let request = b"GET /api/v2/x HTTP/1.1\r\nHost: h\r\n\r\n";
        let backend = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            let mut head = vec![0u8; request.len()];
            sock.read_exact(&mut head).await.unwrap();
            head
        });
        let mut handlers: BTreeMap<String, ServeTarget> = BTreeMap::new();
        handlers.insert("/".to_string(), proxy("127.0.0.1:1"));
        handlers.insert("/api/v2".to_string(), proxy(&backend_addr.to_string()));
        let (mut client, server) = tokio::io::duplex(4096);
        let path_task = tokio::spawn(async move {
            serve_path(443, server, &handlers).await;
        });
        client.write_all(request).await.unwrap();
        drop(client);
        let head = backend.await.unwrap();
        path_task.await.unwrap();
        assert_eq!(
            head, request,
            "serve_path routed to the longest-prefix Proxy and replayed the consumed head"
        );
    }

    #[tokio::test]
    async fn serve_path_text_target_emits_body() {
        let mut handlers: BTreeMap<String, ServeTarget> = BTreeMap::new();
        handlers.insert(
            "/".to_string(),
            ServeTarget::Text {
                body: "root".into(),
            },
        );
        handlers.insert(
            "/hello".to_string(),
            ServeTarget::Text {
                body: "hello-body".into(),
            },
        );
        let (mut client, server) = tokio::io::duplex(4096);
        let t = tokio::spawn(async move {
            serve_path(443, server, &handlers).await;
        });
        client
            .write_all(b"GET /hello/world HTTP/1.1\r\nHost: h\r\n\r\n")
            .await
            .unwrap();
        let got = drain_to_string(client).await;
        t.await.unwrap();
        assert_eq!(got, "hello-body");
    }
}