use std::{
collections::{BTreeMap, HashMap},
future::Future,
net::{Ipv4Addr, SocketAddr},
pin::Pin,
sync::{Arc, Mutex, Weak},
time::{Duration, Instant},
};
use netstack::{
CreateSocket,
netcore::{
Channel,
smoltcp::wire::{IpProtocol, Ipv4Packet, TcpPacket},
},
netsock::TcpStream,
};
use tokio::sync::Semaphore;
/// Upper bound on the number of on-demand port listeners the observer tracks at once;
/// SYNs for additional ports are ignored (with a warning) while at this limit.
const MAX_PORTS: usize = 1024;
/// A tracked listener whose port has not seen a new SYN for at least this long is reaped.
const PORT_IDLE: Duration = Duration::from_secs(120);
/// Period of the observer's scan for idle listeners.
const PORT_REAP_INTERVAL: Duration = Duration::from_secs(60);
/// Number of permits in the semaphore that each port listener creates to bound how many
/// of its accepted connections may be in flight at the same time.
const MAX_INFLIGHT: usize = 512;
/// Boxed, `Send` future returned by a connection handler; it resolves to `()`.
pub type FallbackConnFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// One-shot callback that takes ownership of an accepted [`TcpStream`] and returns the
/// future that serves it.
pub type FallbackConnHandler = Box<dyn FnOnce(TcpStream) -> FallbackConnFuture + Send>;
/// A handler's verdict for one connection, as `(conn_handler, intercept)`.
///
/// When `intercept` is `false` the next registered handler is consulted. When it is
/// `true` the verdict is final: `Some(conn_handler)` serves the connection, `None`
/// causes the connection to be dropped.
pub type FallbackDecision = (Option<FallbackConnHandler>, bool);
/// Shared decision callback, invoked with the connection's `(source, destination)`
/// socket addresses.
type Handler = Arc<dyn Fn(SocketAddr, SocketAddr) -> FallbackDecision + Send + Sync>;
/// Bookkeeping the observer keeps for one on-demand port listener task.
struct PortEntry {
/// Abort handle of the spawned listener task for this port.
handle: tokio::task::AbortHandle,
/// When a SYN for this port was last observed (set at creation, refreshed on later SYNs).
last: Instant,
}
/// Ties the listener task's lifetime to its map entry: removing or dropping the entry
/// requests cancellation of the task.
impl Drop for PortEntry {
fn drop(&mut self) {
// Aborting an already-finished task is harmless, so no liveness check is needed.
self.handle.abort();
}
}
/// Mutable state shared (behind a mutex) by the manager, its handles and the background tasks.
struct Inner {
/// Registered handlers keyed by registration id. A `BTreeMap` iterates in ascending
/// key order, so handlers are consulted in registration order.
handlers: BTreeMap<u64, Handler>,
/// Id that the next registration will receive.
next_id: u64,
/// Abort handle of the SYN observer task, if one has been started and not yet stopped.
observer: Option<tokio::task::AbortHandle>,
/// Channel handed to the observer task, which clones it for each port listener it spawns.
channel: Channel,
}
/// Registry of fallback TCP handlers.
///
/// Registering a handler starts a background observer task when none is running; the
/// observer is stopped again when the last registration handle is dropped.
pub struct FallbackTcpManager {
/// Shared state; handles and the observer hold weak references to it.
inner: Arc<Mutex<Inner>>,
}
impl FallbackTcpManager {
    /// Creates a manager with no handlers registered and no observer running.
    ///
    /// `channel` is stored and cloned into the observer task once a handler is registered.
    pub fn new(channel: Channel) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                handlers: BTreeMap::new(),
                next_id: 0,
                observer: None,
                channel,
            })),
        }
    }

    /// Registers `cb` as a fallback handler and returns the handle that keeps it registered.
    ///
    /// The observer task is started when no observer is currently running. That covers
    /// both the very first registration and the case where a previously started observer
    /// has already exited by itself (for example because opening or reading the raw
    /// socket failed); without the second check a dead observer would never be replaced.
    ///
    /// Dropping the returned handle removes the handler again.
    ///
    /// # Panics
    ///
    /// Starting the observer uses `tokio::spawn`, which panics when called outside a
    /// Tokio runtime.
    pub fn register(&self, cb: Handler) -> FallbackTcpHandle {
        // A poisoned lock is recovered rather than propagated: the state stays usable.
        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let id = inner.next_id;
        inner.next_id += 1;
        inner.handlers.insert(id, cb);

        // A stored handle alone does not prove the observer is alive: the task may have
        // returned on its own, leaving a stale handle behind.
        let observer_running = inner
            .observer
            .as_ref()
            .is_some_and(|task| !task.is_finished());
        if !observer_running {
            let restarted = inner.observer.is_some();
            let channel = inner.channel.clone();
            // The observer only holds a weak reference so it cannot keep the state alive.
            let weak = Arc::downgrade(&self.inner);
            let task = tokio::spawn(async move {
                if let Err(e) = run_observer(channel, weak).await {
                    tracing::warn!(error = %e, "fallback-tcp observer exited");
                }
            });
            inner.observer = Some(task.abort_handle());
            if restarted {
                tracing::debug!(
                    "fallback-tcp: restarted raw SYN observer (previous observer had exited)"
                );
            } else {
                tracing::debug!(
                    "fallback-tcp: started raw SYN observer (first handler registered)"
                );
            }
        }

        FallbackTcpHandle {
            id,
            inner: Arc::downgrade(&self.inner),
        }
    }
}
/// Registration guard returned by `FallbackTcpManager::register`.
///
/// The handler stays registered for as long as this value is alive; dropping it removes
/// the handler.
#[must_use = "dropping the handle immediately deregisters the fallback handler"]
pub struct FallbackTcpHandle {
/// Registration id of the handler this handle owns.
id: u64,
/// Weak reference to the shared state, so a handle does not keep the manager's state alive.
inner: Weak<Mutex<Inner>>,
}
impl FallbackTcpHandle {
    /// Deregisters the handler right away.
    ///
    /// This is the explicit spelling of dropping the handle: all of the work happens in
    /// the `Drop` implementation, which runs when `self` is consumed here.
    pub fn unregister(self) {
        drop(self);
    }
}
/// Removes this handle's handler and, when it was the last one, stops the observer task.
impl Drop for FallbackTcpHandle {
    fn drop(&mut self) {
        // If the shared state is already gone there is nothing left to deregister from.
        if let Some(shared) = self.inner.upgrade() {
            let mut state = shared
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            state.handlers.remove(&self.id);
            if !state.handlers.is_empty() {
                // Other handlers still need the observer.
                return;
            }
            if let Some(task) = state.observer.take() {
                task.abort();
                tracing::debug!(
                    "fallback-tcp: stopped raw SYN observer (last handler deregistered)"
                );
            }
        }
    }
}
/// Observer loop: watches incoming packets for connection-initiating SYNs and starts a
/// listener task on demand for each destination port that has none yet.
///
/// The loop multiplexes three event sources:
/// * packets from the raw handle opened via `channel.raw_open`,
/// * exit notifications sent by listener tasks when they finish,
/// * a periodic tick that reaps listeners idle for `PORT_IDLE` or longer.
///
/// Returns `Ok(())` once the shared state behind `inner` has been dropped (detected the
/// next time a listener would be started).
///
/// # Errors
///
/// Propagates errors from opening the raw handle and from receiving packets on it.
async fn run_observer(
channel: Channel,
inner: Weak<Mutex<Inner>>,
) -> Result<(), netstack::netcore::Error> {
// NOTE(review): the meaning of the `true` flag is defined by `Channel::raw_open`,
// which is not visible in this file — confirm against its documentation.
let raw = channel.raw_open(true, IpProtocol::Tcp).await?;
// Listener tasks report their port here when they finish on their own.
let (exit_tx, mut exit_rx) = tokio::sync::mpsc::unbounded_channel::<u16>();
// Active listeners by port. Dropping an entry aborts its task (see `PortEntry`'s `Drop`),
// so dropping this map — e.g. when this task is aborted — stops all listeners.
let mut ports: HashMap<u16, PortEntry> = HashMap::new();
let mut reap = tokio::time::interval(PORT_REAP_INTERVAL);
// Skip missed ticks instead of firing them back-to-back after a stall.
reap.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
packet = raw.recv_bytes() => {
let packet = packet?;
// Anything that is not a connection-initiating SYN is ignored.
let Some(port) = syn_dst_port(&packet) else {
continue;
};
// Already listening on this port: just refresh its idle timestamp.
if let Some(entry) = ports.get_mut(&port) {
entry.last = Instant::now();
continue;
}
if ports.len() >= MAX_PORTS {
tracing::warn!(%port, "fallback-tcp: at max active ports ({MAX_PORTS}); dropping new port");
continue;
}
// Skip ports that the channel reports as already bound.
match channel.bound_tcp_ports().await {
Ok(bound) if bound.contains(&port) => continue,
Ok(_) => {}
Err(e) => {
tracing::warn!(%port, error = %e, "fallback-tcp: bound-ports query failed; skipping port");
continue;
}
}
// The shared state is gone, so nobody can handle connections any more: stop observing.
let Some(inner) = inner.upgrade() else {
return Ok(());
};
tracing::debug!(%port, "fallback-tcp: starting listener on demand");
let channel = channel.clone();
let exit_tx = exit_tx.clone();
let handle = tokio::spawn(async move {
if let Err(e) = run_port(channel, port, inner).await {
tracing::warn!(%port, error = %e, "fallback-tcp listener exited");
}
// Tell the observer to forget this port; a failed send means the observer is gone.
let _ = exit_tx.send(port);
})
.abort_handle();
ports.insert(port, PortEntry { handle, last: Instant::now() });
}
// A listener finished by itself: drop its entry so a later SYN can start a new one.
Some(port) = exit_rx.recv() => {
ports.remove(&port);
}
_ = reap.tick() => {
let before = ports.len();
// Removing an entry drops it, which aborts the corresponding listener task.
ports.retain(|_, e| e.last.elapsed() < PORT_IDLE);
let reaped = before - ports.len();
if reaped > 0 {
tracing::debug!(reaped, "fallback-tcp: reaped idle listeners");
}
}
}
}
}
async fn run_port(
channel: Channel,
port: u16,
inner: Arc<Mutex<Inner>>,
) -> Result<(), netstack::netcore::Error> {
let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port);
let listener = channel.tcp_listen(listen_addr).await?;
tracing::debug!(%port, "fallback-tcp listener accepting");
let inflight = Arc::new(Semaphore::new(MAX_INFLIGHT));
loop {
let overlay = listener.accept().await?;
let Ok(permit) = inflight.clone().try_acquire_owned() else {
tracing::warn!(
%port,
peer = %overlay.remote_addr(),
"fallback-tcp drop: at max in-flight flows ({MAX_INFLIGHT})"
);
continue;
};
let handlers: Vec<Handler> = {
let g = inner.lock().unwrap_or_else(|e| e.into_inner());
g.handlers.values().cloned().collect()
};
let src = overlay.remote_addr();
let dst = overlay.local_addr();
match dispatch(&handlers, src, dst) {
Some(conn_handler) => {
tokio::spawn(async move {
let _permit = permit; conn_handler(overlay).await;
});
}
None => {
drop(overlay);
}
}
}
}
/// Offers a connection from `src` to `dst` to each handler in order and returns the
/// verdict of the first one that intercepts.
///
/// Handlers after the first interceptor are not consulted. The result is `None` both
/// when no handler intercepts and when the intercepting handler supplies no connection
/// handler (a rejection).
fn dispatch(handlers: &[Handler], src: SocketAddr, dst: SocketAddr) -> Option<FallbackConnHandler> {
    handlers
        .iter()
        .find_map(|candidate| {
            let (conn_handler, intercept) = candidate(src, dst);
            // `Some(verdict)` stops the search; the verdict itself may still be `None`.
            intercept.then_some(conn_handler)
        })
        .flatten()
}
/// Extracts the TCP destination port from a connection-initiating SYN.
///
/// `packet` is parsed as an IPv4 packet. Returns `Some(port)` only when it is a
/// well-formed IPv4 packet carrying a well-formed TCP header with SYN set and ACK
/// clear. Returns `None` for everything else: malformed input, non-TCP payloads,
/// non-initial IP fragments, SYN-ACKs and segments without SYN.
fn syn_dst_port(packet: &[u8]) -> Option<u16> {
    let ip = Ipv4Packet::new_checked(packet).ok()?;
    if ip.next_header() != IpProtocol::Tcp {
        return None;
    }
    // Only the first fragment of a datagram starts with the TCP header. The payload of
    // a later fragment is arbitrary segment data and must not be interpreted as flags
    // and ports, otherwise data bytes could masquerade as a SYN for a random port.
    if ip.frag_offset() != 0 {
        return None;
    }
    let tcp = TcpPacket::new_checked(ip.payload()).ok()?;
    // SYN without ACK is the opening segment of a handshake; SYN+ACK is a reply.
    (tcp.syn() && !tcp.ack()).then(|| tcp.dst_port())
}
// Unit tests for the pure helpers: SYN parsing (`syn_dst_port`), the configured limits,
// and handler selection (`dispatch`).
#[cfg(test)]
mod tests {
use netstack::netcore::smoltcp::wire::Ipv4Address;
use super::*;
// Builds an IPv4 packet with a 20-byte header (no options), the given protocol number
// and `payload` as its body. The checksum and all unset fields are left zero.
fn ipv4(proto: IpProtocol, payload: &[u8]) -> Vec<u8> {
const IHL: usize = 20;
let total = IHL + payload.len();
let mut buf = vec![0u8; total];
let mut ip = Ipv4Packet::new_unchecked(&mut buf);
ip.set_version(4);
ip.set_header_len(IHL as u8);
ip.set_total_len(total as u16);
ip.set_hop_limit(64);
ip.set_next_header(proto);
ip.set_src_addr(Ipv4Address::new(10, 0, 0, 1));
ip.set_dst_addr(Ipv4Address::new(10, 0, 0, 2));
ip.payload_mut().copy_from_slice(payload);
buf
}
// Builds a bare 20-byte TCP header addressed to `dst_port` with the SYN/ACK flags as given.
fn tcp_segment(dst_port: u16, syn: bool, ack: bool) -> Vec<u8> {
let mut seg = vec![0u8; 20];
let mut tcp = TcpPacket::new_unchecked(&mut seg);
tcp.set_src_port(12345);
tcp.set_dst_port(dst_port);
tcp.set_header_len(20);
tcp.set_syn(syn);
tcp.set_ack(ack);
seg
}
// A plain SYN (no ACK) yields its destination port.
#[test]
fn syn_dst_port_reads_connection_initiating_syn() {
let pkt = ipv4(IpProtocol::Tcp, &tcp_segment(8443, true, false));
assert_eq!(syn_dst_port(&pkt), Some(8443));
}
// SYN-ACK replies and pure ACKs do not open a connection and must be ignored.
#[test]
fn syn_dst_port_ignores_syn_ack_and_non_syn() {
let synack = ipv4(IpProtocol::Tcp, &tcp_segment(8443, true, true));
assert_eq!(syn_dst_port(&synack), None);
let ack = ipv4(IpProtocol::Tcp, &tcp_segment(8443, false, true));
assert_eq!(syn_dst_port(&ack), None);
}
// Input too short to be an IPv4 header is rejected rather than panicking.
#[test]
fn syn_dst_port_ignores_malformed() {
assert_eq!(syn_dst_port(&[0u8; 4]), None);
}
// Pins the configured limits and checks that reaping runs at least twice per idle period.
#[test]
fn caps_are_bounded() {
assert_eq!(MAX_PORTS, 1024);
assert!(PORT_REAP_INTERVAL <= PORT_IDLE / 2);
assert_eq!(MAX_INFLIGHT, 512);
}
// Arbitrary socket address with the given port, used as src/dst in dispatch tests.
fn addr(port: u16) -> SocketAddr {
SocketAddr::new(Ipv4Addr::new(100, 64, 0, 1).into(), port)
}
// Wraps a closure that ignores the addresses into a `Handler`.
fn handler(decision: impl Fn() -> FallbackDecision + Send + Sync + 'static) -> Handler {
Arc::new(move |_src, _dst| decision())
}
// When every handler declines, no connection handler is selected.
#[test]
fn dispatch_declines_when_no_handler_intercepts() {
let handlers = vec![handler(|| (None, false)), handler(|| (None, false))];
assert!(dispatch(&handlers, addr(1), addr(8443)).is_none());
}
// With no handlers registered there is nothing to select.
#[test]
fn dispatch_empty_handler_list_yields_none() {
assert!(dispatch(&[], addr(1), addr(8443)).is_none());
}
// An intercepting handler that supplies a connection handler gets it returned.
#[test]
fn dispatch_intercept_with_handler_is_returned() {
let handlers = vec![handler(|| {
let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
(Some(h), true)
})];
assert!(dispatch(&handlers, addr(1), addr(8443)).is_some());
}
// `(None, true)` is a final rejection: the result is `None` and later handlers are
// never invoked, even if they would have accepted.
#[test]
fn dispatch_intercept_reject_yields_none_and_stops() {
let second_consulted = Arc::new(std::sync::atomic::AtomicBool::new(false));
let flag = second_consulted.clone();
let handlers = vec![
handler(|| (None, true)),
Arc::new(move |_s: SocketAddr, _d: SocketAddr| {
flag.store(true, std::sync::atomic::Ordering::SeqCst);
let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
(Some(h), true)
}) as Handler,
];
assert!(
dispatch(&handlers, addr(1), addr(8443)).is_none(),
"intercept=true with no handler must reject (None)"
);
assert!(
!second_consulted.load(std::sync::atomic::Ordering::SeqCst),
"first intercept must win; later handlers must not be consulted"
);
}
// A declining handler is skipped and the next handler that intercepts is used.
#[test]
fn dispatch_first_interceptor_wins_over_later() {
let handlers = vec![
handler(|| (None, false)),
handler(|| {
let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
(Some(h), true)
}),
];
assert!(dispatch(&handlers, addr(1), addr(8443)).is_some());
}
}