// wiretun 0.5.0
//
// WireGuard Library
// Documentation
mod support;

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;

use tokio::time;

use support::*;
use wiretun::noise::protocol::{HandshakeInitiation, TransportData};
use wiretun::*;

/// A device whose only peer has no endpoint configured is expected to stay
/// idle: after waiting 30 seconds, neither the transport stub nor the tun
/// stub may have recorded any traffic in either direction.
#[tokio::test]
async fn test_noop_when_no_endpoint() {
    let local_secret = TestKit::gen_local_secret();
    let stub_tun = StubTun::new();
    let stub_transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
        .await
        .unwrap();

    // The peer is described by its public key only; no endpoint is set.
    let peer_public_key = TestKit::gen_local_secret().public_key().to_bytes();
    let peer_cfg = PeerConfig::default().public_key(peer_public_key);
    let device_cfg = DeviceConfig::default()
        .private_key(local_secret.private_key().to_bytes())
        .peer(peer_cfg);

    let device = Device::with_transport(stub_tun.clone(), stub_transport.clone(), device_cfg)
        .await
        .unwrap();

    // Keep a control handle alive for the duration of the test.
    let _control = device.control();

    time::sleep(Duration::from_secs(30)).await;

    // Nothing went through the transport...
    assert_eq!(
        (stub_transport.inbound_sent(), stub_transport.outbound_sent()),
        (0, 0)
    );
    // ...and nothing went through the tun either.
    assert_eq!((stub_tun.inbound_sent(), stub_tun.outbound_sent()), (0, 0));
}

#[tokio::test]
async fn test_keep_initiation_when_no_response() {
    let secret = TestKit::gen_local_secret();
    let tun = StubTun::new();
    let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
        .await
        .unwrap();
    let peer_pub = TestKit::gen_local_secret().public_key().to_bytes();
    let peer_endpoint = "10.0.0.1:80".parse().unwrap();
    let cfg = DeviceConfig::default()
        .private_key(secret.private_key().to_bytes())
        .peer(
            PeerConfig::default()
                .public_key(peer_pub)
                .endpoint(peer_endpoint),
        );
    let device = Device::with_transport(tun.clone(), transport.clone(), cfg)
        .await
        .unwrap();

    let _ctrl = device.control();

    time::sleep(Duration::from_secs(30)).await;

    assert_eq!(transport.inbound_sent(), 0);
    assert!(transport.outbound_sent() > 0);

    assert_eq!(tun.inbound_sent(), 0);
    assert_eq!(tun.outbound_sent(), 0);

    for _ in 0..transport.outbound_sent() {
        let (endpoint, data) = transport.fetch_outbound().await;
        assert_eq!(endpoint.dst(), peer_endpoint);
        let ret = HandshakeInitiation::try_from(data.as_slice());
        assert!(ret.is_ok());
    }
}

/// Wires two in-memory devices together and checks that a handshake completes.
///
/// Device 1 is configured with device 2's endpoint; device 2 knows device 1
/// only by public key and allowed IP. Two background tasks shuttle every
/// outbound packet of one transport stub into the inbound side of the other.
///
/// After 30 seconds the test expects:
/// - no traffic on either tun stub,
/// - traffic in both directions on both transport stubs,
/// - at least one packet parsing as `TransportData` among device 1's
///   outbound packets, and none among device 2's.
#[tokio::test]
async fn test_complete_handshake() {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
    // `init()` panics if a global subscriber is already installed (for example
    // by another test in the same binary). Logging is only a debugging aid
    // here, so use `try_init()` and ignore an "already set" error.
    let _ = tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .try_init();

    let secret1 = TestKit::gen_local_secret();
    let endpoint1 = "10.10.0.1:6789".parse::<SocketAddr>().unwrap();
    let endpoint2 = "10.10.0.2:1245".parse::<SocketAddr>().unwrap();
    let secret2 = TestKit::gen_local_secret();

    // Device 1: knows where to reach device 2.
    let (_device1, tun1, transport1) = {
        let tun = StubTun::new();
        let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
            .await
            .unwrap();
        let cfg = DeviceConfig::default()
            .private_key(secret1.private_key().to_bytes())
            .peer(
                PeerConfig::default()
                    .public_key(secret2.public_key().to_bytes())
                    .allowed_ip(endpoint2.ip())
                    .endpoint(endpoint2),
            );
        let device = Device::with_transport(tun.clone(), transport.clone(), cfg)
            .await
            .unwrap();
        (device, tun, transport)
    };

    // Device 2: has no endpoint configured for device 1.
    let (_device2, tun2, transport2) = {
        let tun = StubTun::new();
        let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0)
            .await
            .unwrap();
        let cfg = DeviceConfig::default()
            .private_key(secret2.private_key().to_bytes())
            .peer(
                PeerConfig::default()
                    .public_key(secret1.public_key().to_bytes())
                    .allowed_ip(endpoint1.ip()),
            );
        let device = Device::with_transport(tun.clone(), transport.clone(), cfg)
            .await
            .unwrap();
        (device, tun, transport)
    };

    // Relay tasks: forward each outbound packet of one transport to the other
    // transport's inbound side, tagged with the sender's address. The tasks
    // are detached and loop for the rest of the test.
    {
        let (t1, t2) = (transport1.clone(), transport2.clone());
        tokio::spawn(async move {
            loop {
                let (endpoint, data) = t1.fetch_outbound().await;
                assert_eq!(endpoint.dst(), endpoint2);
                let endpoint = Endpoint::new(t2.clone(), endpoint1);
                t2.send_inbound(&data, &endpoint).await;
            }
        });
        let (t1, t2) = (transport1.clone(), transport2.clone());
        tokio::spawn(async move {
            loop {
                let (endpoint, data) = t2.fetch_outbound().await;
                assert_eq!(endpoint.dst(), endpoint1);
                let endpoint = Endpoint::new(t1.clone(), endpoint2);
                t1.send_inbound(&data, &endpoint).await;
            }
        });
    }

    time::sleep(Duration::from_secs(30)).await;

    // No tun traffic is expected on either side.
    assert_eq!(tun1.inbound_sent(), 0);
    assert_eq!(tun1.outbound_sent(), 0);
    assert_eq!(tun2.inbound_sent(), 0);
    assert_eq!(tun2.outbound_sent(), 0);

    // Both transports must have seen traffic in both directions.
    assert!(transport1.inbound_sent() > 0);
    assert!(transport1.outbound_sent() > 0);
    assert!(transport2.inbound_sent() > 0);
    assert!(transport2.outbound_sent() > 0);

    // Look for a packet that parses as transport data among each device's
    // outbound packets.
    let d1_completed = transport1
        .outbound_recording()
        .into_iter()
        .any(|(_, data)| TransportData::try_from(data.as_slice()).is_ok());
    let d2_completed = transport2
        .outbound_recording()
        .into_iter()
        .any(|(_, data)| TransportData::try_from(data.as_slice()).is_ok());

    // NOTE(review): presumably only the handshake initiator (device 1) sends
    // a transport packet once the handshake is done — confirm against the
    // protocol implementation.
    assert!(d1_completed);
    assert!(!d2_completed);
}