Skip to main content

ts_runtime/
fallback_tcp.rs

1//! Fallback TCP handler registry (`tsnet.Server.RegisterFallbackTCPHandler` parity).
2//!
3//! Go `tsnet` lets an embedder register a callback consulted for every inbound TCP flow that
4//! matches **no** explicit `Listener`. The callback inspects the `(src, dst)` tuple and either
5//! declines (`intercept = false`, try the next handler) or claims the flow (`intercept = true`),
6//! optionally returning a per-connection handler. This module is the faithful equivalent on the
7//! **application** netstack.
8//!
9//! ## How an unmatched flow is observed
10//!
11//! smoltcp RSTs an inbound SYN to a port with no matching listener *inside* its ingress loop,
12//! before any of our code runs. The single lever it gives us is the same one
13//! `ts_forwarder::all_port` uses: a `raw` `(Ipv4, Tcp)` socket whose `accepts()` sets
14//! `handled_by_raw_socket = true`, which **suppresses that RST** and hands us a copy of every
15//! inbound TCP packet. We read each SYN's destination port and lazily materialize a per-port
16//! any-IP listener; the peer's SYN retransmit is then accepted by that listener and dispatched to
17//! the registered handlers.
18//!
19//! ## The observer runs **only** while a handler is registered
20//!
21//! Because the raw observer suppresses the unmatched-SYN RST for the whole netstack, it must not
22//! be running when there are no fallback handlers — otherwise a node with zero handlers would stop
23//! RSTing unrouted SYNs (silently swallowing them) instead of cleanly refusing. So the observer is
24//! started on the *first* registration and torn down on the *last* deregistration, leaving the
25//! default fail-closed RST behavior pristine whenever no handler is installed.
26//!
27//! ## Anti-leak
28//!
29//! The raw observer never creates a host socket and never dials anything; it only learns ports.
30//! Every accepted flow is handed to the embedder's own handler over the overlay netstack — never a
31//! host socket. Ports already owned by an explicit `tcp_listen`er are skipped (queried read-only
32//! via `CreateSocket::bound_tcp_ports`) so a fallback listener never competes with a real one. A
33//! flow no handler claims is closed (fail-closed), never direct-dialed. IPv4-only.
34
35use std::{
36    collections::{BTreeMap, HashMap},
37    future::Future,
38    net::{Ipv4Addr, SocketAddr},
39    pin::Pin,
40    sync::{Arc, Mutex, Weak},
41    time::{Duration, Instant},
42};
43
44use netstack::{
45    CreateSocket,
46    netcore::{
47        Channel,
48        smoltcp::wire::{IpProtocol, Ipv4Packet, TcpPacket},
49    },
50    netsock::TcpStream,
51};
52use tokio::sync::Semaphore;
53
/// Maximum number of distinct ports that may have a live on-demand fallback listener at once.
///
/// Mirrors `ts_forwarder::all_port`'s cap: without it a remote could scan all 65,535 ports and
/// permanently materialize that many tasks + netstack sockets (remote FD/memory-exhaustion DoS).
/// Over the cap, SYNs to *new* ports are dropped (no listener spawned) until a port is evicted.
const MAX_PORTS: usize = 1024;

/// How long an on-demand per-port listener may go without a new connection-initiating SYN for
/// its port before it is reaped (its task aborted and the port freed so a later SYN can
/// re-trigger it).
///
/// Only SYNs refresh this timer; packets on already-established flows do not. Reaping aborts only
/// the listener task itself — accepted flows are serviced on separately spawned tasks.
const PORT_IDLE: Duration = Duration::from_secs(120);

/// How often the idle-port reaper runs.
///
/// An idle listener is reaped at most `PORT_IDLE + PORT_REAP_INTERVAL` after its last SYN, so
/// keeping this at half of [`PORT_IDLE`] bounds the worst-case dormant lifetime to 1.5×
/// `PORT_IDLE` instead of 2×.
const PORT_REAP_INTERVAL: Duration = Duration::from_secs(60);

/// Max concurrent in-flight handled flows per fallback port listener. Bounds the per-flow spawn
/// fan-out so a flood of accepts cannot grow tasks without limit; saturated => drop (fail-closed).
const MAX_INFLIGHT: usize = 512;
72
73/// The future returned by a [`FallbackConnHandler`]; spawned to service one accepted flow.
74pub type FallbackConnFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
75
76/// Per-connection handler returned by a fallback callback that claims a flow. Consumes the
77/// accepted overlay [`TcpStream`] and returns a future the manager spawns. Mirrors the
78/// `func(net.Conn)` Go `tsnet` returns from its fallback callback.
79pub type FallbackConnHandler = Box<dyn FnOnce(TcpStream) -> FallbackConnFuture + Send>;
80
81/// A fallback callback's decision for one `(src, dst)` flow: an optional per-connection handler
82/// and whether this callback intercepts the flow. Matches Go's `(handler func(net.Conn), intercept
83/// bool)`:
84/// - `(_, false)` — decline; the manager tries the next registered callback.
85/// - `(Some(h), true)` — claim the flow; `h` services the connection.
86/// - `(None, true)` — claim the flow and reject it (the connection is closed).
87pub type FallbackDecision = (Option<FallbackConnHandler>, bool);
88
89/// A registered fallback callback. Invoked per unmatched inbound TCP flow with `(src, dst)`.
90type Handler = Arc<dyn Fn(SocketAddr, SocketAddr) -> FallbackDecision + Send + Sync>;
91
92/// Bookkeeping for one on-demand per-port fallback listener.
93struct PortEntry {
94    /// Aborts the listener task on eviction / observer drop.
95    handle: tokio::task::AbortHandle,
96    /// Last time an inbound packet for this port was observed (for idle eviction).
97    last: Instant,
98}
99
100impl Drop for PortEntry {
101    fn drop(&mut self) {
102        self.handle.abort();
103    }
104}
105
106/// Shared manager state behind a single lock.
107struct Inner {
108    /// Registered callbacks keyed by monotonic id. Iteration order (ascending id ≈ registration
109    /// order) is the dispatch order; the first callback to intercept wins.
110    handlers: BTreeMap<u64, Handler>,
111    /// Next callback id to hand out.
112    next_id: u64,
113    /// The running raw-SYN observer task, present iff `handlers` is non-empty.
114    observer: Option<tokio::task::AbortHandle>,
115    /// Application-netstack channel the observer and per-port listeners run on.
116    channel: Channel,
117}
118
119/// Manages the fallback-TCP handler registry and the lifecycle of the raw-SYN observer.
120///
121/// Built once from the application netstack channel and held by the runtime. Registering the first
122/// handler starts the observer; dropping the last [`FallbackTcpHandle`] stops it.
123pub struct FallbackTcpManager {
124    inner: Arc<Mutex<Inner>>,
125}
126
127impl FallbackTcpManager {
128    /// Build a manager bound to the application netstack `channel`. The observer is not started
129    /// until the first handler is registered.
130    pub fn new(channel: Channel) -> Self {
131        Self {
132            inner: Arc::new(Mutex::new(Inner {
133                handlers: BTreeMap::new(),
134                next_id: 0,
135                observer: None,
136                channel,
137            })),
138        }
139    }
140
141    /// Register a fallback callback, returning a RAII handle that deregisters it on drop.
142    ///
143    /// The first registration starts the raw-SYN observer; the last deregistration stops it.
144    pub fn register(&self, cb: Handler) -> FallbackTcpHandle {
145        // Recover from a poisoned lock rather than cascading a panic across flows (matches the
146        // reliability posture of the rest of the dataplane).
147        let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
148        let id = inner.next_id;
149        inner.next_id += 1;
150        inner.handlers.insert(id, cb);
151
152        if inner.observer.is_none() {
153            let channel = inner.channel.clone();
154            let weak = Arc::downgrade(&self.inner);
155            let task = tokio::spawn(async move {
156                if let Err(e) = run_observer(channel, weak).await {
157                    tracing::warn!(error = %e, "fallback-tcp observer exited");
158                }
159            });
160            inner.observer = Some(task.abort_handle());
161            tracing::debug!("fallback-tcp: started raw SYN observer (first handler registered)");
162        }
163
164        FallbackTcpHandle {
165            id,
166            inner: Arc::downgrade(&self.inner),
167        }
168    }
169}
170
171/// RAII deregistration handle for a fallback callback (mirrors the `unregister func()` Go returns).
172///
173/// Dropping it removes the callback; dropping the last handle also tears down the raw observer, so
174/// the netstack's default fail-closed RST behavior returns when no handler is installed.
175#[must_use = "dropping the handle immediately deregisters the fallback handler"]
176pub struct FallbackTcpHandle {
177    id: u64,
178    inner: Weak<Mutex<Inner>>,
179}
180
181impl FallbackTcpHandle {
182    /// Explicitly deregister the handler now. Equivalent to dropping the handle.
183    pub fn unregister(self) {
184        // Drop runs the deregistration.
185    }
186}
187
188impl Drop for FallbackTcpHandle {
189    fn drop(&mut self) {
190        let Some(inner) = self.inner.upgrade() else {
191            return;
192        };
193        let mut g = inner.lock().unwrap_or_else(|e| e.into_inner());
194        g.handlers.remove(&self.id);
195        if g.handlers.is_empty()
196            && let Some(observer) = g.observer.take()
197        {
198            // Last handler gone: stop suppressing the unmatched-SYN RST. Aborting the observer
199            // drops its per-port `PortEntry`s, which abort the per-port listener tasks.
200            observer.abort();
201            tracing::debug!("fallback-tcp: stopped raw SYN observer (last handler deregistered)");
202        }
203    }
204}
205
206/// Observe inbound SYNs via a raw socket and lazily start a per-port any-IP listener for each new
207/// destination port that is not already served by an explicit listener.
208async fn run_observer(
209    channel: Channel,
210    inner: Weak<Mutex<Inner>>,
211) -> Result<(), netstack::netcore::Error> {
212    // The raw observer both suppresses the unmatched-SYN RST and reveals each SYN's dst port.
213    let raw = channel.raw_open(true, IpProtocol::Tcp).await?;
214
215    // A per-port listener task sends its port back when it exits so the observer removes it from
216    // the active set (so a retransmit re-triggers it).
217    let (exit_tx, mut exit_rx) = tokio::sync::mpsc::unbounded_channel::<u16>();
218    let mut ports: HashMap<u16, PortEntry> = HashMap::new();
219    let mut reap = tokio::time::interval(PORT_REAP_INTERVAL);
220    reap.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
221
222    loop {
223        tokio::select! {
224            packet = raw.recv_bytes() => {
225                let packet = packet?;
226                let Some(port) = syn_dst_port(&packet) else {
227                    continue;
228                };
229                if let Some(entry) = ports.get_mut(&port) {
230                    entry.last = Instant::now();
231                    continue;
232                }
233                if ports.len() >= MAX_PORTS {
234                    tracing::warn!(%port, "fallback-tcp: at max active ports ({MAX_PORTS}); dropping new port");
235                    continue;
236                }
237                // Cold path only: skip ports an explicit listener already owns so a fallback
238                // listener never competes with a real one. Read-only registry query.
239                match channel.bound_tcp_ports().await {
240                    Ok(bound) if bound.contains(&port) => continue,
241                    Ok(_) => {}
242                    Err(e) => {
243                        tracing::warn!(%port, error = %e, "fallback-tcp: bound-ports query failed; skipping port");
244                        continue;
245                    }
246                }
247                let Some(inner) = inner.upgrade() else {
248                    // Manager dropped; nothing left to serve.
249                    return Ok(());
250                };
251                tracing::debug!(%port, "fallback-tcp: starting listener on demand");
252                let channel = channel.clone();
253                let exit_tx = exit_tx.clone();
254                let handle = tokio::spawn(async move {
255                    if let Err(e) = run_port(channel, port, inner).await {
256                        tracing::warn!(%port, error = %e, "fallback-tcp listener exited");
257                    }
258                    let _ = exit_tx.send(port);
259                })
260                .abort_handle();
261                ports.insert(port, PortEntry { handle, last: Instant::now() });
262            }
263            Some(port) = exit_rx.recv() => {
264                ports.remove(&port);
265            }
266            _ = reap.tick() => {
267                let before = ports.len();
268                ports.retain(|_, e| e.last.elapsed() < PORT_IDLE);
269                let reaped = before - ports.len();
270                if reaped > 0 {
271                    tracing::debug!(reaped, "fallback-tcp: reaped idle listeners");
272                }
273            }
274        }
275    }
276}
277
278/// Accept flows on `0.0.0.0:port` of the application netstack and dispatch each to the registered
279/// fallback callbacks in order; the first to intercept wins.
280async fn run_port(
281    channel: Channel,
282    port: u16,
283    inner: Arc<Mutex<Inner>>,
284) -> Result<(), netstack::netcore::Error> {
285    let listen_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port);
286    let listener = channel.tcp_listen(listen_addr).await?;
287    tracing::debug!(%port, "fallback-tcp listener accepting");
288
289    let inflight = Arc::new(Semaphore::new(MAX_INFLIGHT));
290
291    loop {
292        let overlay = listener.accept().await?;
293        let Ok(permit) = inflight.clone().try_acquire_owned() else {
294            tracing::warn!(
295                %port,
296                peer = %overlay.remote_addr(),
297                "fallback-tcp drop: at max in-flight flows ({MAX_INFLIGHT})"
298            );
299            // Dropping `overlay` closes the flow; fail-closed, never direct-dialed.
300            continue;
301        };
302
303        // Snapshot the callbacks under the lock, then release it before invoking them.
304        let handlers: Vec<Handler> = {
305            let g = inner.lock().unwrap_or_else(|e| e.into_inner());
306            g.handlers.values().cloned().collect()
307        };
308
309        let src = overlay.remote_addr();
310        let dst = overlay.local_addr();
311
312        match dispatch(&handlers, src, dst) {
313            Some(conn_handler) => {
314                tokio::spawn(async move {
315                    let _permit = permit; // released when the handler future completes
316                    conn_handler(overlay).await;
317                });
318            }
319            // No handler claimed with a connection handler: either every handler declined, or one
320            // intercepted to reject (intercept=true, handler=None). Either way the flow is closed
321            // by dropping `overlay`. Fail-closed.
322            None => {
323                drop(overlay);
324            }
325        }
326    }
327}
328
329/// Consult `handlers` in order for the flow `(src, dst)` and return the per-connection handler of
330/// the first callback that intercepts, if any.
331///
332/// Mirrors Go `tsnet`: the first callback returning `intercept = true` wins; a `true` with no
333/// connection handler (reject) and an exhausted handler list (decline) both yield `None`, which the
334/// caller treats as "close the flow".
335fn dispatch(handlers: &[Handler], src: SocketAddr, dst: SocketAddr) -> Option<FallbackConnHandler> {
336    for handler in handlers {
337        let (conn_handler, intercept) = handler(src, dst);
338        if intercept {
339            return conn_handler;
340        }
341    }
342    None
343}
344
345/// Parse a raw IPv4 packet and return its TCP destination port iff it is a connection-initiating
346/// SYN (SYN set, ACK clear). Non-TCP, malformed, or non-SYN packets yield `None`.
347fn syn_dst_port(packet: &[u8]) -> Option<u16> {
348    let ip = Ipv4Packet::new_checked(packet).ok()?;
349    if ip.next_header() != IpProtocol::Tcp {
350        return None;
351    }
352    let tcp = TcpPacket::new_checked(ip.payload()).ok()?;
353    if tcp.syn() && !tcp.ack() {
354        Some(tcp.dst_port())
355    } else {
356        None
357    }
358}
359
#[cfg(test)]
mod tests {
    use netstack::netcore::smoltcp::wire::Ipv4Address;

    use super::*;

    fn ipv4(proto: IpProtocol, payload: &[u8]) -> Vec<u8> {
        const IHL: usize = 20;
        let total = IHL + payload.len();
        let mut buf = vec![0u8; total];
        let mut ip = Ipv4Packet::new_unchecked(&mut buf);
        ip.set_version(4);
        ip.set_header_len(IHL as u8);
        ip.set_total_len(total as u16);
        ip.set_hop_limit(64);
        ip.set_next_header(proto);
        ip.set_src_addr(Ipv4Address::new(10, 0, 0, 1));
        ip.set_dst_addr(Ipv4Address::new(10, 0, 0, 2));
        ip.payload_mut().copy_from_slice(payload);
        buf
    }

    fn tcp_segment(dst_port: u16, syn: bool, ack: bool) -> Vec<u8> {
        let mut seg = vec![0u8; 20];
        let mut tcp = TcpPacket::new_unchecked(&mut seg);
        tcp.set_src_port(12345);
        tcp.set_dst_port(dst_port);
        tcp.set_header_len(20);
        tcp.set_syn(syn);
        tcp.set_ack(ack);
        seg
    }

    #[test]
    fn syn_dst_port_reads_connection_initiating_syn() {
        let pkt = ipv4(IpProtocol::Tcp, &tcp_segment(8443, true, false));
        assert_eq!(syn_dst_port(&pkt), Some(8443));
    }

    #[test]
    fn syn_dst_port_ignores_syn_ack_and_non_syn() {
        let synack = ipv4(IpProtocol::Tcp, &tcp_segment(8443, true, true));
        assert_eq!(syn_dst_port(&synack), None);
        let ack = ipv4(IpProtocol::Tcp, &tcp_segment(8443, false, true));
        assert_eq!(syn_dst_port(&ack), None);
    }

    #[test]
    fn syn_dst_port_ignores_malformed() {
        assert_eq!(syn_dst_port(&[0u8; 4]), None);
    }

    #[test]
    fn syn_dst_port_ignores_non_tcp_protocol() {
        // A TCP-SYN-shaped payload inside a non-TCP IPv4 packet must not be treated as a SYN.
        let pkt = ipv4(IpProtocol::Udp, &tcp_segment(8443, true, false));
        assert_eq!(syn_dst_port(&pkt), None);
    }

    #[test]
    fn syn_dst_port_ignores_truncated_tcp_header() {
        // Four bytes of TCP payload are too short for a TCP header.
        let pkt = ipv4(IpProtocol::Tcp, &[0u8; 4]);
        assert_eq!(syn_dst_port(&pkt), None);
    }

    #[test]
    fn caps_are_bounded() {
        assert_eq!(MAX_PORTS, 1024);
        assert!(PORT_REAP_INTERVAL <= PORT_IDLE / 2);
        assert_eq!(MAX_INFLIGHT, 512);
    }

    fn addr(port: u16) -> SocketAddr {
        SocketAddr::new(Ipv4Addr::new(100, 64, 0, 1).into(), port)
    }

    /// Wrap a decision-producing closure as a [`Handler`] that ignores `(src, dst)`.
    fn handler(decision: impl Fn() -> FallbackDecision + Send + Sync + 'static) -> Handler {
        Arc::new(move |_src, _dst| decision())
    }

    #[test]
    fn dispatch_declines_when_no_handler_intercepts() {
        let handlers = vec![handler(|| (None, false)), handler(|| (None, false))];
        assert!(dispatch(&handlers, addr(1), addr(8443)).is_none());
    }

    #[test]
    fn dispatch_empty_handler_list_yields_none() {
        assert!(dispatch(&[], addr(1), addr(8443)).is_none());
    }

    #[test]
    fn dispatch_intercept_with_handler_is_returned() {
        let handlers = vec![handler(|| {
            let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
            (Some(h), true)
        })];
        assert!(dispatch(&handlers, addr(1), addr(8443)).is_some());
    }

    #[test]
    fn dispatch_intercept_reject_yields_none_and_stops() {
        // The first handler intercepts to reject (handler=None, intercept=true). The second would
        // intercept with a handler, but it must NOT be consulted — the first intercept wins.
        let second_consulted = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let flag = second_consulted.clone();
        let handlers = vec![
            handler(|| (None, true)),
            Arc::new(move |_s: SocketAddr, _d: SocketAddr| {
                flag.store(true, std::sync::atomic::Ordering::SeqCst);
                let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
                (Some(h), true)
            }) as Handler,
        ];
        assert!(
            dispatch(&handlers, addr(1), addr(8443)).is_none(),
            "intercept=true with no handler must reject (None)"
        );
        assert!(
            !second_consulted.load(std::sync::atomic::Ordering::SeqCst),
            "first intercept must win; later handlers must not be consulted"
        );
    }

    #[test]
    fn dispatch_first_interceptor_wins_over_later() {
        // A declining handler is skipped; the first that intercepts (here the second) wins.
        let handlers = vec![
            handler(|| (None, false)),
            handler(|| {
                let h: FallbackConnHandler = Box::new(|_stream| Box::pin(async {}));
                (Some(h), true)
            }),
        ];
        assert!(dispatch(&handlers, addr(1), addr(8443)).is_some());
    }
}