runnel-rs 0.2.0

A Rust proxy and tunnel toolbox with WireGuard-style, TUN, SOCKS, and TLS-based transports.
Documentation
use anyhow::{Context, Result, bail};
use hex::encode as hex_encode;
use std::{
    io::{Read, Write},
    os::unix::net::UnixStream,
    path::{Path, PathBuf},
    thread::sleep,
    time::Duration,
};

use super::WgRuntimeConfig;

/// Traffic counters read back from the boringtun UAPI `get` operation.
///
/// `parse_device_stats` fills both fields by summing the per-peer `rx_bytes=` / `tx_bytes=`
/// lines of a single `get` response, so the values are device-wide totals.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub(crate) struct WgDeviceStats {
    /// Total bytes received across all peers.
    pub rx_bytes: u64,
    /// Total bytes transmitted across all peers.
    pub tx_bytes: u64,
}

/// Returns the UAPI control socket for `device_name`, i.e.
/// `/var/run/wireguard/<device_name>.sock`.
pub(crate) fn control_socket_path(device_name: &str) -> PathBuf {
    let mut socket_path = PathBuf::from("/var/run/wireguard");
    socket_path.push(format!("{device_name}.sock"));
    socket_path
}

/// Builds the body of a full-device UAPI `set=1` request (without the `set=1` header line or
/// the terminating blank line).
///
/// The device section comes first: `private_key=` (hex), then `listen_port=` when the runtime
/// reports one. It is followed by the peer section from `peer_config_lines`. Lines are joined
/// with `\n`.
pub(crate) fn build_set_request(runtime: &WgRuntimeConfig) -> String {
    let private_key_line = format!("private_key={}", hex_encode(runtime.private_key));
    let listen_port_line = runtime
        .listen_port()
        .map(|port| format!("listen_port={port}"));

    std::iter::once(private_key_line)
        .chain(listen_port_line)
        .chain(peer_config_lines(runtime))
        .collect::<Vec<_>>()
        .join("\n")
}

/// Builds a UAPI `set=1` body containing only the peer section.
///
/// Device-level keys (`private_key`, `listen_port`) are not included, so applying it does not
/// touch the device's key or listening socket.
pub(crate) fn build_peer_refresh_request(runtime: &WgRuntimeConfig) -> String {
    let peer_lines = peer_config_lines(runtime);
    peer_lines.join("\n")
}

/// Produces the UAPI peer section for the single configured peer, one `key=value` per entry.
///
/// Order:
/// 1. `replace_peers=true`, which drops any previously configured peers first.
/// 2. The peer's hex `public_key=`.
/// 3. Optional `endpoint=` and `persistent_keepalive_interval=`.
/// 4. One `allowed_ip=` per entry in `peer_allowed_ips`.
fn peer_config_lines(runtime: &WgRuntimeConfig) -> Vec<String> {
    let mut out = vec![
        "replace_peers=true".to_owned(),
        format!("public_key={}", hex_encode(runtime.peer_public_key)),
    ];
    // `Option` is iterable, so absent settings simply contribute no line.
    out.extend(runtime.endpoint.map(|addr| format!("endpoint={addr}")));
    out.extend(
        runtime
            .persistent_keepalive_secs
            .map(|secs| format!("persistent_keepalive_interval={secs}")),
    );
    for allowed_ip in runtime.peer_allowed_ips.iter() {
        out.push(format!("allowed_ip={allowed_ip}"));
    }
    out
}

/// Pushes the full device + peer configuration to the UAPI socket at `socket_path`.
///
/// Delivery, retries and errno checking are handled by `send_set_request`.
pub(crate) fn apply_device_config(socket_path: &Path, runtime: &WgRuntimeConfig) -> Result<()> {
    send_set_request(socket_path, &build_set_request(runtime))
}

/// Re-sends only the peer section (no device keys) to the UAPI socket at `socket_path`.
pub(crate) fn refresh_peer_config(socket_path: &Path, runtime: &WgRuntimeConfig) -> Result<()> {
    send_set_request(socket_path, &build_peer_refresh_request(runtime))
}

/// Issues a UAPI `get` and returns the rx/tx byte totals summed over all peers.
pub(crate) fn read_device_stats(socket_path: &Path) -> Result<WgDeviceStats> {
    send_get_request(socket_path).and_then(|response| parse_device_stats(&response))
}

/// Issues a UAPI `get` and returns the smallest `last_handshake_time_sec` value across peers.
///
/// Each value is treated as an elapsed duration. Returns `None` when no peer reports that field.
pub(crate) fn read_last_handshake_age(socket_path: &Path) -> Result<Option<Duration>> {
    send_get_request(socket_path).and_then(|response| parse_last_handshake_age(&response))
}

/// Maximum number of `set=1` round trips attempted before giving up.
const SET_REQUEST_ATTEMPTS: u32 = 20;
/// Pause between consecutive failed `set=1` attempts.
const SET_REQUEST_RETRY_DELAY: Duration = Duration::from_millis(50);

/// Sends a UAPI `set=1` request with `body`, retrying on any failure.
///
/// Every error from `try_send_set_request` is retried. That includes connect/IO errors and a
/// non-zero `errno` in the daemon's reply. The retries presumably cover the socket not existing
/// yet while the device starts up — TODO confirm.
///
/// At most `SET_REQUEST_ATTEMPTS` attempts are made, sleeping `SET_REQUEST_RETRY_DELAY` between
/// them. There is no sleep after the final attempt, so a persistent failure is reported as soon
/// as the last try fails.
///
/// # Errors
/// Returns the error from the last attempt if none succeeded.
fn send_set_request(socket_path: &Path, body: &str) -> Result<()> {
    let mut attempt = 1;
    loop {
        match try_send_set_request(socket_path, body) {
            Ok(()) => return Ok(()),
            Err(err) if attempt >= SET_REQUEST_ATTEMPTS => return Err(err),
            Err(_) => {
                attempt += 1;
                sleep(SET_REQUEST_RETRY_DELAY);
            }
        }
    }
}

/// Performs one UAPI exchange.
///
/// Steps:
/// 1. Connect to `socket_path`.
/// 2. Write the complete `request`, which must already include the command line and the
///    terminating blank line.
/// 3. Read everything the daemon sends back until the stream reaches EOF.
///
/// # Errors
/// Connect, write and read failures are returned with the socket path as context.
fn uapi_round_trip(socket_path: &Path, request: &str) -> Result<String> {
    let mut socket = UnixStream::connect(socket_path).with_context(|| {
        format!(
            "failed to connect boringtun UAPI socket {}",
            socket_path.display()
        )
    })?;
    socket.write_all(request.as_bytes()).with_context(|| {
        format!(
            "failed to write boringtun UAPI socket {}",
            socket_path.display()
        )
    })?;
    // Half-close our side so the daemon sees end-of-input. A UAPI server that would otherwise
    // keep the connection open for further commands then closes it, and `read_to_string` below
    // can reach EOF. Best-effort: if the daemon already answered and hung up, shutdown may
    // fail, and that is harmless because the reply is already buffered for us.
    let _ = socket.shutdown(std::net::Shutdown::Write);
    let mut response = String::new();
    socket.read_to_string(&mut response).with_context(|| {
        format!(
            "failed to read boringtun UAPI socket {}",
            socket_path.display()
        )
    })?;
    Ok(response)
}

/// Sends a single `set=1` request carrying `body` and checks the daemon's `errno=` reply.
///
/// # Errors
/// Returns an error for socket failures, or when the reply's errno is missing, malformed, or
/// non-zero.
fn try_send_set_request(socket_path: &Path, body: &str) -> Result<()> {
    let response = uapi_round_trip(socket_path, &format!("set=1\n{body}\n\n"))?;
    parse_errno(&response)
}

/// Sends a `get=1` request and returns the raw `key=value` dump, including its `errno=` line.
///
/// The errno is NOT checked here; the `parse_*` functions do that.
fn send_get_request(socket_path: &Path) -> Result<String> {
    uapi_round_trip(socket_path, "get=1\n\n")
}

/// Parses a UAPI `get` dump into device-wide byte counters.
///
/// The response's errno is validated first. Every `rx_bytes=` / `tx_bytes=` line (one pair per
/// peer) is then added into the totals with saturating arithmetic; all other keys are ignored.
///
/// # Errors
/// Returns an error on a bad or non-zero errno, or on a counter value that is not a `u64`.
fn parse_device_stats(response: &str) -> Result<WgDeviceStats> {
    parse_errno(response)?;

    let mut totals = WgDeviceStats::default();
    for (key, raw) in response.lines().filter_map(|line| line.split_once('=')) {
        match key {
            "rx_bytes" => {
                totals.rx_bytes = totals
                    .rx_bytes
                    .saturating_add(parse_stat(raw, "rx_bytes")?);
            }
            "tx_bytes" => {
                totals.tx_bytes = totals
                    .tx_bytes
                    .saturating_add(parse_stat(raw, "tx_bytes")?);
            }
            _ => {}
        }
    }
    Ok(totals)
}

/// Extracts the most recent handshake age from a UAPI `get` dump.
///
/// The errno is validated first. Then every `last_handshake_time_sec=` value (one per peer
/// that has handshaked) is read as whole seconds, and the smallest one is returned. The
/// `_nsec` companion lines are ignored. Returns `None` when no peer reports the field.
///
/// NOTE(review): the generic WireGuard UAPI defines this field as a Unix timestamp. Treating it
/// as an elapsed age relies on the daemon reporting time-since-handshake — confirm for the
/// boringtun version in use.
///
/// # Errors
/// Returns an error on a bad or non-zero errno, or on a value that is not a `u64`.
fn parse_last_handshake_age(response: &str) -> Result<Option<Duration>> {
    parse_errno(response)?;

    let handshake_secs = response
        .lines()
        .filter_map(|line| line.strip_prefix("last_handshake_time_sec="));

    let mut freshest: Option<Duration> = None;
    for raw in handshake_secs {
        let current = Duration::from_secs(parse_stat(raw, "last_handshake_time_sec")?);
        freshest = match freshest {
            Some(previous) if previous <= current => Some(previous),
            _ => Some(current),
        };
    }
    Ok(freshest)
}

/// Parses a UAPI numeric field value as `u64`.
///
/// # Errors
/// On failure the error names `field` and echoes the offending `value`.
fn parse_stat(value: &str, field: &str) -> Result<u64> {
    let number: u64 = value
        .parse()
        .with_context(|| format!("invalid boringtun {field} field: {value}"))?;
    Ok(number)
}

/// Validates the `errno=` status line that terminates every UAPI response.
///
/// The first line starting with `errno=` is parsed as an `i32`; `0` means success.
///
/// # Errors
/// - the response has no `errno=` line;
/// - the value is not a valid `i32`;
/// - the value is non-zero.
///
/// For a non-zero errno, the message carries only the errno value. The raw response is
/// deliberately not echoed: per the WireGuard UAPI format, a `get` dump can contain secret key
/// lines (e.g. `private_key=` / `preshared_key=`), and this error is likely to end up in logs.
/// For `set` replies the raw text was only a duplicate of `errno=N` plus blank lines anyway.
fn parse_errno(response: &str) -> Result<()> {
    let errno = response
        .lines()
        .find_map(|line| line.strip_prefix("errno="))
        .context("boringtun UAPI response did not include errno")?;
    let errno: i32 = errno
        .parse()
        .with_context(|| format!("invalid boringtun errno field: {errno}"))?;
    if errno != 0 {
        bail!("boringtun UAPI returned errno={errno}");
    }
    Ok(())
}

// Unit tests for UAPI request building and response parsing. Only the pure helpers are
// exercised here; no test opens a UAPI socket.
#[cfg(test)]
mod tests {
    use super::{
        WgDeviceStats, build_peer_refresh_request, build_set_request, control_socket_path,
        parse_device_stats, parse_last_handshake_age,
    };
    use crate::wg::{WgRuntimeConfig, default_client_allowed_ips, default_server_allowed_ips};
    use std::{
        net::{IpAddr, Ipv4Addr, SocketAddr},
        time::Duration,
    };

    // The control socket lives under /var/run/wireguard and is named after the device.
    #[test]
    fn control_socket_path_uses_wireguard_run_dir() {
        assert_eq!(
            control_socket_path("utun123"),
            std::path::Path::new("/var/run/wireguard/utun123.sock")
        );
    }

    // A client config (endpoint + keepalive) serializes the device key, listen port and peer
    // settings. The default client allowed-IPs include 0.0.0.0/0 but not ::/0.
    #[test]
    fn client_set_request_contains_endpoint_keepalive_and_ipv4_full_tunnel_allowed_ips() {
        let runtime = WgRuntimeConfig {
            bind: SocketAddr::from(([0, 0, 0, 0], 51820)),
            endpoint: Some(SocketAddr::from(([198, 51, 100, 10], 51820))),
            tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 2)),
            peer_tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 1)),
            mtu: 1420,
            persistent_keepalive_secs: Some(25),
            private_key: [0x11; 32],
            peer_public_key: [0x22; 32],
            peer_allowed_ips: default_client_allowed_ips(),
            excluded_ips: Vec::new(),
        };

        let request = build_set_request(&runtime);
        assert!(request.contains(
            "private_key=1111111111111111111111111111111111111111111111111111111111111111"
        ));
        assert!(request.contains("listen_port=51820"));
        assert!(request.contains("endpoint=198.51.100.10:51820"));
        assert!(request.contains("persistent_keepalive_interval=25"));
        assert!(request.contains("allowed_ip=0.0.0.0/0"));
        assert!(!request.contains("allowed_ip=::/0"));
    }

    // A server config has no endpoint or keepalive, so those keys are omitted. The peer is
    // restricted to a /32 host route for its tunnel IP.
    #[test]
    fn server_set_request_defaults_to_host_route_for_peer_tunnel_ip() {
        let runtime = WgRuntimeConfig {
            bind: SocketAddr::from(([0, 0, 0, 0], 51820)),
            endpoint: None,
            tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 1)),
            peer_tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 2)),
            mtu: 1420,
            persistent_keepalive_secs: None,
            private_key: [0x33; 32],
            peer_public_key: [0x44; 32],
            peer_allowed_ips: default_server_allowed_ips(IpAddr::V4(Ipv4Addr::new(10, 8, 0, 2))),
            excluded_ips: Vec::new(),
        };

        let request = build_set_request(&runtime);
        assert!(request.contains("replace_peers=true"));
        assert!(request.contains(
            "public_key=4444444444444444444444444444444444444444444444444444444444444444"
        ));
        assert!(request.contains("allowed_ip=10.8.0.2/32"));
        assert!(!request.contains("endpoint="));
        assert!(!request.contains("persistent_keepalive_interval="));
    }

    // A peer refresh carries only peer keys; device-level private_key/listen_port are absent.
    #[test]
    fn peer_refresh_request_updates_peer_without_rebinding_device() {
        let runtime = WgRuntimeConfig {
            bind: SocketAddr::from(([0, 0, 0, 0], 51820)),
            endpoint: Some(SocketAddr::from(([198, 51, 100, 10], 51820))),
            tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 2)),
            peer_tunnel_ip: IpAddr::V4(Ipv4Addr::new(10, 8, 0, 1)),
            mtu: 1420,
            persistent_keepalive_secs: Some(25),
            private_key: [0x11; 32],
            peer_public_key: [0x22; 32],
            peer_allowed_ips: default_client_allowed_ips(),
            excluded_ips: Vec::new(),
        };

        let request = build_peer_refresh_request(&runtime);
        assert!(request.contains("replace_peers=true"));
        assert!(request.contains(
            "public_key=2222222222222222222222222222222222222222222222222222222222222222"
        ));
        assert!(request.contains("endpoint=198.51.100.10:51820"));
        assert!(request.contains("persistent_keepalive_interval=25"));
        assert!(request.contains("allowed_ip=0.0.0.0/0"));
        assert!(!request.contains("private_key="));
        assert!(!request.contains("listen_port="));
    }

    // Non-counter keys in a `get` dump are ignored; rx/tx_bytes are picked up.
    #[test]
    fn parses_device_stats_from_uapi_get_response() {
        let response = "\
own_public_key=aaaa
listen_port=51820
public_key=bbbb
last_handshake_time_sec=7
last_handshake_time_nsec=123
allowed_ip=10.8.0.2/32
rx_bytes=123
tx_bytes=456
errno=0

";

        assert_eq!(
            parse_device_stats(response).unwrap(),
            WgDeviceStats {
                rx_bytes: 123,
                tx_bytes: 456,
            }
        );
    }

    // With several peers, the smallest reported handshake age wins (5s beats 12s); the
    // nanosecond lines do not affect the result.
    #[test]
    fn parses_last_handshake_age_from_uapi_get_response() {
        let response = "\
public_key=bbbb
last_handshake_time_sec=12
last_handshake_time_nsec=123
rx_bytes=10
tx_bytes=20
public_key=cccc
last_handshake_time_sec=5
last_handshake_time_nsec=456
rx_bytes=30
tx_bytes=40
errno=0

";

        assert_eq!(
            parse_last_handshake_age(response).unwrap(),
            Some(Duration::from_secs(5))
        );
    }

    // No last_handshake_time_sec line at all yields None rather than an error.
    #[test]
    fn missing_last_handshake_age_means_no_successful_handshake_yet() {
        let response = "\
public_key=bbbb
rx_bytes=0
tx_bytes=0
errno=0

";

        assert_eq!(parse_last_handshake_age(response).unwrap(), None);
    }

    // Byte counters are summed across all peers in the dump.
    #[test]
    fn sums_device_stats_across_peers() {
        let response = "\
public_key=bbbb
rx_bytes=10
tx_bytes=20
public_key=cccc
rx_bytes=30
tx_bytes=40
errno=0

";

        assert_eq!(
            parse_device_stats(response).unwrap(),
            WgDeviceStats {
                rx_bytes: 40,
                tx_bytes: 60,
            }
        );
    }

    // A non-zero errno becomes an error whose message names the errno value.
    #[test]
    fn rejects_nonzero_uapi_errno_when_parsing_stats() {
        let err = parse_device_stats("errno=22\n\n").unwrap_err().to_string();
        assert!(err.contains("errno=22"));
    }
}