innisfree 0.4.3

Exposes local services on a public IPv4 address via a cloud server.
Documentation
//! Core network proxy logic for forwarding TCP and UDP traffic
//! between two sockets.
//!
//! Two parallel handlers live here: [`proxy_handler`] for TCP and
//! [`udp_handler`] for UDP. [`crate::manager::run_proxy`] dispatches
//! to the right one based on each [`ServicePort::protocol`]. The
//! functions exposed here are low-level — see
//! [`crate::manager::TunnelManager`] for the user-facing wrapper.

use anyhow::{Context, Result};
use futures::FutureExt;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpStream, UdpSocket};

/// How long a UDP session may go without a reply from the destination
/// before it is torn down. The per-session reply task waits at most this
/// long for each reply; when a wait elapses it exits and removes the
/// session from the map, and the next datagram from the same client
/// allocates a fresh outbound socket.
const UDP_SESSION_IDLE: Duration = Duration::from_secs(60);

/// Buffer size for a single UDP recv: one byte more than `u16::MAX`.
/// A UDP payload can never exceed 65 507 bytes over IPv4 (65 535-byte
/// packet limit minus the 20-byte IP and 8-byte UDP headers), nor 65 527
/// bytes over IPv6 without jumbograms, so any non-jumbogram datagram fits
/// in this buffer without truncation.
const UDP_BUF_BYTES: usize = 65_536;

// Taken from Tokio proxy example (MIT license):
// https://github.com/tokio-rs/tokio/blob/a08ce0d3e06d650361283dc87c8fe14b146df15d/examples/proxy.rs
/// Splice an accepted client connection onto a fresh TCP connection to
/// `proxy_addr`, copying bytes in both directions concurrently.
///
/// When one direction reaches EOF, the write half it was feeding is shut
/// down so that peer observes a half-close, while the opposite direction
/// keeps flowing. Failing to connect to `proxy_addr`, or the first I/O
/// error in either direction, ends the transfer and is returned.
pub async fn transfer(mut inbound: TcpStream, proxy_addr: SocketAddr) -> Result<()> {
    let mut outbound = TcpStream::connect(proxy_addr).await?;

    let (mut client_rx, mut client_tx) = inbound.split();
    let (mut dest_rx, mut dest_tx) = outbound.split();

    // Client -> destination, then half-close the destination's write side.
    let upstream = async {
        tokio::io::copy(&mut client_rx, &mut dest_tx).await?;
        dest_tx.shutdown().await
    };

    // Destination -> client, then half-close the client's write side.
    let downstream = async {
        tokio::io::copy(&mut dest_rx, &mut client_tx).await?;
        client_tx.shutdown().await
    };

    let ((), ()) = tokio::try_join!(upstream, downstream)?;
    Ok(())
}

/// Create a blocking service proxy that passes TCP traffic
/// between two sockets.
pub async fn proxy_handler(listen_addr: SocketAddr, dest_addr: SocketAddr) -> Result<()> {
    tracing::debug!("Proxying TCP traffic: {} -> {}", listen_addr, dest_addr);
    let listener = tokio::net::TcpListener::bind(&listen_addr).await?;
    while let Ok((inbound, _)) = listener.accept().await {
        let transfer = transfer(inbound, dest_addr).map(|r| {
            if let Err(e) = r {
                tracing::warn!("Proxy connection dropped, creating new handler: {}", e);
            }
        });
        tokio::spawn(transfer);
    }
    Ok(())
}

/// Shared per-client UDP session table.
///
/// UDP has no connections, so the proxy keeps its own per-client state in
/// this table to know where replies from the destination must go. It is
/// guarded by a plain [`std::sync::Mutex`] because every access is short
/// and the lock is always released before the next `await`.
type UdpSessions = Arc<Mutex<SessionTable>>;

/// Client source address → the outbound socket allocated for that client.
type SessionTable = HashMap<SocketAddr, Arc<UdpSocket>>;

/// Bind a UDP socket on `listen_addr` and forward each client's
/// datagrams to `dest_addr` over a per-client outbound socket, with
/// replies routed back to the originating client. Mirrors
/// [`proxy_handler`] for the UDP case.
///
/// Per-client outbound sockets are cached in a session map; idle
/// sessions are reaped after [`UDP_SESSION_IDLE`] of silence from
/// the destination. Datagrams that race a session being torn down
/// may be dropped — UDP semantics tolerate that, and the next packet
/// from the same client recreates the session.
pub async fn udp_handler(listen_addr: SocketAddr, dest_addr: SocketAddr) -> Result<()> {
    tracing::debug!("Proxying UDP traffic: {} -> {}", listen_addr, dest_addr);
    let listener = Arc::new(
        UdpSocket::bind(listen_addr)
            .await
            .with_context(|| format!("binding UDP listener {listen_addr}"))?,
    );
    let sessions: UdpSessions = Arc::new(Mutex::new(HashMap::new()));

    let mut buf = vec![0u8; UDP_BUF_BYTES];
    loop {
        let (len, src) = match listener.recv_from(&mut buf).await {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!("UDP listener recv on {listen_addr}: {e}");
                continue;
            }
        };

        // Cheap fast-path: existing session.
        let cached = sessions.lock().unwrap().get(&src).cloned();
        let outbound = match cached {
            Some(s) => s,
            None => {
                // First packet from this client. Allocate an
                // ephemeral outbound socket connect()-ed to dest so
                // we can use plain `send` / `recv` on it, and spawn
                // a reply pump for it.
                let s = UdpSocket::bind("0.0.0.0:0")
                    .await
                    .context("binding ephemeral UDP outbound")?;
                s.connect(dest_addr)
                    .await
                    .with_context(|| format!("connecting UDP outbound to {dest_addr}"))?;
                let s = Arc::new(s);
                sessions.lock().unwrap().insert(src, s.clone());
                spawn_udp_reply_pump(s.clone(), listener.clone(), src, sessions.clone());
                s
            }
        };

        if let Err(e) = outbound.send(&buf[..len]).await {
            tracing::warn!("UDP forward {src} -> {dest_addr}: {e}");
            // Drop the session so the next packet retries fresh
            // (compare-and-remove so we don't yank a session that
            // was just rotated by the reply pump's reaper path).
            let mut guard = sessions.lock().unwrap();
            if let Some(current) = guard.get(&src) {
                if Arc::ptr_eq(current, &outbound) {
                    guard.remove(&src);
                }
            }
        }
    }
}

/// Spawn the per-session reply pump: read replies from `outbound`
/// (connected to dest) and forward them back to `client_src` via
/// the shared `listener`. Exits when the session goes idle for
/// [`UDP_SESSION_IDLE`] or hits a socket error, then removes itself
/// from `sessions` (compare-and-remove, so a concurrent
/// rotation doesn't get clobbered).
fn spawn_udp_reply_pump(
    outbound: Arc<UdpSocket>,
    listener: Arc<UdpSocket>,
    client_src: SocketAddr,
    sessions: UdpSessions,
) {
    tokio::spawn(async move {
        let mut buf = vec![0u8; UDP_BUF_BYTES];
        loop {
            match tokio::time::timeout(UDP_SESSION_IDLE, outbound.recv(&mut buf)).await {
                Ok(Ok(n)) => {
                    if let Err(e) = listener.send_to(&buf[..n], client_src).await {
                        tracing::warn!("UDP reply to {client_src}: {e}");
                        break;
                    }
                }
                Ok(Err(e)) => {
                    tracing::warn!("UDP outbound recv from session {client_src}: {e}");
                    break;
                }
                Err(_) => {
                    tracing::trace!("UDP session {client_src} idle, reaping");
                    break;
                }
            }
        }
        // Compare-and-remove: only clear the entry if it's still
        // the same Arc we pumped, otherwise we'd kick out a fresh
        // session that replaced ours after a transient hiccup.
        let mut guard = sessions.lock().unwrap();
        if let Some(current) = guard.get(&client_src) {
            if Arc::ptr_eq(current, &outbound) {
                guard.remove(&client_src);
            }
        }
    });
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    #[tokio::test]
    async fn udp_handler_round_trips_a_datagram() {
        // Echo server: bind ephemeral, reply to the first sender
        // with the same payload.
        let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            let (n, src) = server.recv_from(&mut buf).await.unwrap();
            server.send_to(&buf[..n], src).await.unwrap();
        });

        // Probe an ephemeral port for the proxy listener, then drop
        // it and pass the address to udp_handler. Brief race window
        // between drop and rebind is acceptable for a unit test.
        let probe = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let listen_addr = probe.local_addr().unwrap();
        drop(probe);
        let proxy_task = tokio::spawn(udp_handler(listen_addr, server_addr));
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Client sends to the proxy, expects the echoed payload.
        let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        client.send_to(b"hello-udp", listen_addr).await.unwrap();
        let mut buf = vec![0u8; 1500];
        let (n, _) = timeout(Duration::from_secs(2), client.recv_from(&mut buf))
            .await
            .expect("timed out waiting for UDP echo")
            .unwrap();
        assert_eq!(&buf[..n], b"hello-udp");

        proxy_task.abort();
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn udp_handler_isolates_sessions_per_client() {
        // Echo server prefixes each reply with the source port it saw
        // (i.e. the proxy's outbound socket), so we can confirm that each
        // client gets its own outbound socket and its own reply path.
        let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            for _ in 0..2 {
                let (n, src) = server.recv_from(&mut buf).await.unwrap();
                let mut reply = format!("from-{}: ", src.port()).into_bytes();
                reply.extend_from_slice(&buf[..n]);
                server.send_to(&reply, src).await.unwrap();
            }
        });

        let probe = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let listen_addr = probe.local_addr().unwrap();
        drop(probe);
        let proxy_task = tokio::spawn(udp_handler(listen_addr, server_addr));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let client_a = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let client_b = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        client_a.send_to(b"A", listen_addr).await.unwrap();
        client_b.send_to(b"B", listen_addr).await.unwrap();

        let mut a_buf = vec![0u8; 1500];
        let mut b_buf = vec![0u8; 1500];
        let (an, _) = timeout(Duration::from_secs(2), client_a.recv_from(&mut a_buf))
            .await
            .expect("client A timed out")
            .unwrap();
        let (bn, _) = timeout(Duration::from_secs(2), client_b.recv_from(&mut b_buf))
            .await
            .expect("client B timed out")
            .unwrap();
        let a_reply = &a_buf[..an];
        let b_reply = &b_buf[..bn];

        // Each client's reply ends with its own payload — proves we
        // didn't cross the streams.
        assert!(a_reply.ends_with(b"A"), "A got: {:?}", a_reply);
        assert!(b_reply.ends_with(b"B"), "B got: {:?}", b_reply);

        // Distinct "from-<port>" prefixes prove the server saw two
        // different proxy-side sockets, i.e. the clients did not share
        // one outbound session.
        assert_ne!(
            outbound_tag(a_reply),
            outbound_tag(b_reply),
            "clients A and B shared an outbound socket"
        );

        proxy_task.abort();
        server_task.await.unwrap();
    }

    /// The `from-<port>` prefix (everything before the first `:`) of a
    /// reply produced by the echo server in
    /// `udp_handler_isolates_sessions_per_client`.
    fn outbound_tag(reply: &[u8]) -> &[u8] {
        reply.split(|&b| b == b':').next().unwrap_or(reply)
    }
}