soth-mitm 0.2.0

Rust intercepting proxy crate with deterministic handler/event contracts for SOTH.
Documentation
use super::SidecarConfig;
use std::io;
use tokio::net::{TcpListener, TcpStream};

pub(crate) async fn bind_listener_with_socket_hardening(
    config: &SidecarConfig,
) -> io::Result<TcpListener> {
    const LISTENER_BIND_RETRY_ATTEMPTS: u32 = 8;

    let resolved = tokio::net::lookup_host((config.listen_addr.as_str(), config.listen_port))
        .await
        .map_err(|error| {
            io::Error::new(
                io::ErrorKind::AddrNotAvailable,
                format!(
                    "failed to resolve listen address {}:{}: {error}",
                    config.listen_addr, config.listen_port
                ),
            )
        })?;
    let mut listen_addrs: Vec<std::net::SocketAddr> = resolved.collect();
    if listen_addrs.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            format!(
                "no resolved socket address for {}:{}",
                config.listen_addr, config.listen_port
            ),
        ));
    }
    order_listen_addrs_for_dual_stack(&mut listen_addrs);

    let retry_delay = std::time::Duration::from_millis(config.accept_retry_backoff_ms.max(1));
    let mut last_error: Option<io::Error> = None;
    for attempt in 1..=LISTENER_BIND_RETRY_ATTEMPTS {
        let mut saw_retryable_error = false;
        for listen_addr in listen_addrs.iter().copied() {
            match bind_single_listener_socket(listen_addr) {
                Ok(listener) => return Ok(listener),
                Err(error) => {
                    if should_retry_listener_bind(&error) {
                        saw_retryable_error = true;
                    }
                    last_error = Some(error);
                }
            }
        }

        if !saw_retryable_error || attempt >= LISTENER_BIND_RETRY_ATTEMPTS {
            break;
        }
        tracing::warn!(
            listen_addr = %config.listen_addr,
            listen_port = config.listen_port,
            attempt,
            max_attempts = LISTENER_BIND_RETRY_ATTEMPTS,
            backoff_ms = retry_delay.as_millis() as u64,
            "transient listener bind failure; retrying"
        );
        tokio::time::sleep(retry_delay).await;
    }
    Err(last_error.unwrap_or_else(|| {
        io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            format!(
                "failed to bind resolved socket addresses for {}:{}",
                config.listen_addr, config.listen_port
            ),
        )
    }))
}

fn should_retry_listener_bind(error: &io::Error) -> bool {
    matches!(
        error.kind(),
        io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
    )
}

fn bind_single_listener_socket(listen_addr: std::net::SocketAddr) -> io::Result<TcpListener> {
    if is_dual_stack_candidate(&listen_addr) {
        match bind_dual_stack_listener_socket(listen_addr) {
            Ok(listener) => return Ok(listener),
            Err(error) => {
                tracing::debug!(
                    addr = %listen_addr,
                    error = %error,
                    "dual-stack bind path failed; falling back to default bind"
                );
            }
        }
    }
    bind_listener_with_tokio_socket(listen_addr)
}

fn bind_listener_with_tokio_socket(listen_addr: std::net::SocketAddr) -> io::Result<TcpListener> {
    let socket = if listen_addr.is_ipv4() {
        tokio::net::TcpSocket::new_v4()?
    } else {
        tokio::net::TcpSocket::new_v6()?
    };
    let _ = socket.set_reuseaddr(true);
    socket.bind(listen_addr)?;
    socket.listen(1024)
}

fn bind_dual_stack_listener_socket(listen_addr: std::net::SocketAddr) -> io::Result<TcpListener> {
    let socket = socket2::Socket::new(
        socket2::Domain::IPV6,
        socket2::Type::STREAM,
        Some(socket2::Protocol::TCP),
    )?;
    socket.set_reuse_address(true)?;
    let _ = socket.set_only_v6(false);
    socket.bind(&socket2::SockAddr::from(listen_addr))?;
    socket.listen(1024)?;
    socket.set_nonblocking(true)?;
    let std_listener: std::net::TcpListener = socket.into();
    TcpListener::from_std(std_listener)
}

fn is_dual_stack_candidate(listen_addr: &std::net::SocketAddr) -> bool {
    matches!(listen_addr, std::net::SocketAddr::V6(v6) if v6.ip().is_unspecified())
}

fn order_listen_addrs_for_dual_stack(listen_addrs: &mut [std::net::SocketAddr]) {
    fn priority(addr: &std::net::SocketAddr) -> u8 {
        match addr {
            std::net::SocketAddr::V6(v6) if v6.ip().is_unspecified() => 0,
            std::net::SocketAddr::V4(v4) if v4.ip().is_unspecified() => 1,
            std::net::SocketAddr::V6(_) => 2,
            std::net::SocketAddr::V4(_) => 3,
        }
    }
    listen_addrs.sort_by_key(priority);
}

#[cfg(unix)]
pub(crate) async fn bind_unix_listener_with_socket_hardening(
    socket_path: &str,
) -> io::Result<tokio::net::UnixListener> {
    let path = std::path::Path::new(socket_path);
    if let Some(parent) = path.parent() {
        if !parent.as_os_str().is_empty() {
            std::fs::create_dir_all(parent)?;
        }
    }
    if path.exists() {
        std::fs::remove_file(path)?;
    }
    tokio::net::UnixListener::bind(path)
}

pub(crate) fn apply_per_connection_socket_hardening(stream: &TcpStream) {
    let _ = stream.set_nodelay(true);
}

pub(crate) fn is_benign_socket_close_error(error: &io::Error) -> bool {
    matches!(
        error.kind(),
        io::ErrorKind::UnexpectedEof
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::NotConnected
    )
}

#[cfg(test)]
mod socket_hardening_tests {
    use super::{
        is_benign_socket_close_error, is_dual_stack_candidate, order_listen_addrs_for_dual_stack,
        should_retry_listener_bind,
    };
    use std::io;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

    #[test]
    fn benign_socket_close_error_kinds_are_classified() {
        assert!(is_benign_socket_close_error(&io::Error::new(
            io::ErrorKind::NotConnected,
            "not connected",
        )));
        assert!(is_benign_socket_close_error(&io::Error::new(
            io::ErrorKind::ConnectionReset,
            "reset",
        )));
        assert!(!is_benign_socket_close_error(&io::Error::new(
            io::ErrorKind::InvalidData,
            "invalid",
        )));
    }

    #[test]
    fn dual_stack_listener_prefers_ipv6_unspecified_first() {
        let mut addrs = vec![
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 8080)),
            SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 8080, 0, 0)),
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080)),
        ];
        order_listen_addrs_for_dual_stack(&mut addrs);
        assert!(matches!(addrs[0], SocketAddr::V6(v6) if v6.ip().is_unspecified()));
        assert!(matches!(addrs[1], SocketAddr::V4(v4) if v4.ip().is_unspecified()));
    }

    #[test]
    fn dual_stack_candidate_only_matches_ipv6_unspecified() {
        let ipv6_unspecified = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 8080, 0, 0));
        let ipv6_loopback = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8080, 0, 0));
        let ipv4_unspecified = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 8080));
        assert!(is_dual_stack_candidate(&ipv6_unspecified));
        assert!(!is_dual_stack_candidate(&ipv6_loopback));
        assert!(!is_dual_stack_candidate(&ipv4_unspecified));
    }

    #[test]
    fn bind_retry_classifier_only_marks_transient_bind_errors() {
        assert!(should_retry_listener_bind(&io::Error::new(
            io::ErrorKind::PermissionDenied,
            "operation not permitted",
        )));
        assert!(should_retry_listener_bind(&io::Error::new(
            io::ErrorKind::AddrInUse,
            "address in use",
        )));
        assert!(!should_retry_listener_bind(&io::Error::new(
            io::ErrorKind::InvalidInput,
            "invalid address",
        )));
    }
}