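// comnoq 0.2.3 — QUIC for compio with the noq backend.
//
// Integration tests for the multipath path API and NAT-traversal updates;
// see the crate documentation for the API under test.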
use std::{
    net::{Ipv4Addr, SocketAddr},
    time::Duration,
};

use comnoq::{
    Endpoint, PathError, PathEvent, PathId, PathStatus, TransportConfig, n0_nat_traversal,
};
use futures_util::{FutureExt, StreamExt, join};

mod common;
use common::{config_pair, subscribe};

/// Binds a server endpoint and then a client endpoint on IPv4 loopback with
/// ephemeral ports, using the configs produced by `config_pair(Some(transport))`.
///
/// The client's `default_client_config` is set so `connect(..., None)` works.
/// Returns `(client, server)`.
async fn endpoint_pair_with_transport(transport: TransportConfig) -> (Endpoint, Endpoint) {
    let (server_cfg, client_cfg) = config_pair(Some(transport));
    let server = Endpoint::server("127.0.0.1:0", server_cfg).await.unwrap();
    let client = {
        let mut endpoint = Endpoint::client("127.0.0.1:0").await.unwrap();
        endpoint.default_client_config = Some(client_cfg);
        endpoint
    };
    (client, server)
}

/// Like `endpoint_pair_with_transport`, but binds both endpoints to the IPv6
/// wildcard address `[::]` with ephemeral ports (only built under `linux_all`).
///
/// NOTE(review): callers reach these sockets over IPv4 too, which presumably
/// relies on the platform not setting IPV6_V6ONLY by default — confirm.
/// Returns `(client, server)`.
#[cfg(linux_all)]
async fn dual_stack_endpoint_pair_with_transport(
    transport: TransportConfig,
) -> (Endpoint, Endpoint) {
    let (server_cfg, client_cfg) = config_pair(Some(transport));
    let server = Endpoint::server("[::]:0", server_cfg).await.unwrap();
    let client = {
        let mut endpoint = Endpoint::client("[::]:0").await.unwrap();
        endpoint.default_client_config = Some(client_cfg);
        endpoint
    };
    (client, server)
}

/// Exercises the per-path API on the initial path (`PathId::ZERO`).
///
/// Checks:
/// - handle equality for repeated lookups;
/// - the path's remote address matching the connection's;
/// - the idle-timeout and keep-alive setters, each of which returns the
///   value it replaced;
/// - `ping`;
/// - the observed-external-address stream not ending on its first poll.
#[compio::test]
async fn path_api() {
    let _guard = subscribe();

    let (client_endpoint, server_endpoint) =
        endpoint_pair_with_transport(TransportConfig::default()).await;
    let server_addr = server_endpoint.local_addr().unwrap();

    let dial = async {
        let connecting = client_endpoint
            .connect(server_addr, "localhost", None)
            .unwrap();
        connecting.await.unwrap()
    };
    let accept = async {
        let incoming = server_endpoint.wait_incoming().await.unwrap();
        incoming.await.unwrap()
    };
    let (client, server) = join!(dial, accept);

    let path = client.path(PathId::ZERO).expect("path zero exists");
    // A second lookup of the same id must yield an equal handle.
    assert_eq!(path, client.path(PathId::ZERO).unwrap());
    assert_eq!(path.remote_address().unwrap(), client.remote_address());

    // Set, overwrite (checking the replaced value), then restore the original.
    let (idle_first, idle_second) = (Duration::from_millis(250), Duration::from_millis(500));
    let idle_original = path.set_max_idle_timeout(Some(idle_first)).unwrap();
    let idle_replaced = path.set_max_idle_timeout(Some(idle_second)).unwrap();
    assert_eq!(idle_replaced, Some(idle_first));
    path.set_max_idle_timeout(idle_original).unwrap();

    // Same set/overwrite/restore sequence for the keep-alive interval.
    let (keep_alive_first, keep_alive_second) =
        (Duration::from_millis(100), Duration::from_millis(200));
    let keep_alive_original = path.set_keep_alive_interval(Some(keep_alive_first)).unwrap();
    let keep_alive_replaced = path
        .set_keep_alive_interval(Some(keep_alive_second))
        .unwrap();
    assert_eq!(keep_alive_replaced, Some(keep_alive_first));
    path.set_keep_alive_interval(keep_alive_original).unwrap();

    path.ping().unwrap();

    // A single non-blocking poll: `Some(None)` would mean the stream already ended.
    let mut observed = path.observed_external_addr().unwrap();
    let first_poll = observed.next().now_or_never();
    if let Some(None) = first_poll {
        panic!("address discovery stream should not terminate immediately");
    }

    // Release all connection-side handles before shutting the endpoints down.
    drop(observed);
    drop(path);
    drop(server);
    drop(client);
    client_endpoint.shutdown().await.unwrap();
    server_endpoint.shutdown().await.unwrap();
}

/// Opens a second path on a multipath connection and checks its events.
///
/// After the handshake is confirmed with multipath enabled (up to two
/// concurrent paths), opening a second path to the server yields a non-zero
/// `PathId` with stats available, and a matching `PathEvent::Opened` is
/// delivered on the client's path-event channel.
#[compio::test]
async fn handshake_confirmed_and_open_path_event() {
    // Cap on `open_path` retries (~5 s at 10 ms each). If the error never
    // clears, the test fails instead of hanging forever.
    const MAX_OPEN_PATH_ATTEMPTS: u32 = 500;

    let _guard = subscribe();

    let mut transport = TransportConfig::default();
    transport.max_concurrent_multipath_paths(2);
    let (client_endpoint, server_endpoint) = endpoint_pair_with_transport(transport).await;
    let server_addr = server_endpoint.local_addr().unwrap();

    let (client, server) = join!(
        async {
            client_endpoint
                .connect(server_addr, "localhost", None)
                .unwrap()
                .await
                .unwrap()
        },
        async {
            server_endpoint
                .wait_incoming()
                .await
                .unwrap()
                .await
                .unwrap()
        },
    );

    client.handshake_confirmed().await.unwrap();
    assert!(client.is_multipath_enabled());

    // Subscribe before opening the path so its `Opened` event is observed.
    let path_events = client.path_events();

    // `RemoteCidsExhausted` is treated as transient (presumably until the
    // peer supplies more connection IDs); retry it, but only a bounded
    // number of times.
    let mut attempts = 0;
    let path = loop {
        match client.open_path(server_addr, PathStatus::Available).await {
            Ok(path) => break path,
            Err(PathError::RemoteCidsExhausted) => {
                attempts += 1;
                assert!(
                    attempts < MAX_OPEN_PATH_ATTEMPTS,
                    "open_path still failing with RemoteCidsExhausted after {attempts} attempts"
                );
                compio::runtime::time::sleep(Duration::from_millis(10)).await;
            }
            Err(err) => panic!("unexpected open_path error: {err:?}"),
        }
    };

    assert_ne!(path.id(), PathId::ZERO);
    assert!(path.stats().is_some());

    // Skip unrelated events until the one for the newly opened path arrives.
    loop {
        let event = path_events.recv_async().await.unwrap();
        if matches!(event, PathEvent::Opened { id } if id == path.id()) {
            break;
        }
    }

    drop(path);
    drop(server);
    drop(client);
    client_endpoint.shutdown().await.unwrap();
    server_endpoint.shutdown().await.unwrap();
}

#[compio::test]
async fn discarded_path_stats_are_retained() {
    let _guard = subscribe();

    let mut transport = TransportConfig::default();
    transport.max_concurrent_multipath_paths(2);
    let (client_endpoint, server_endpoint) = endpoint_pair_with_transport(transport).await;
    let server_addr = server_endpoint.local_addr().unwrap();

    let (client, server) = join!(
        async {
            client_endpoint
                .connect(server_addr, "localhost", None)
                .unwrap()
                .await
                .unwrap()
        },
        async {
            server_endpoint
                .wait_incoming()
                .await
                .unwrap()
                .await
                .unwrap()
        },
    );

    client.handshake_confirmed().await.unwrap();
    assert!(client.is_multipath_enabled());

    let path_events = client.path_events();
    let path = loop {
        match client.open_path(server_addr, PathStatus::Available).await {
            Ok(path) => break path,
            Err(PathError::RemoteCidsExhausted) => {
                compio::runtime::time::sleep(Duration::from_millis(10)).await;
            }
            Err(err) => panic!("unexpected open_path error: {err:?}"),
        }
    };

    loop {
        let event = path_events.recv_async().await.unwrap();
        if matches!(event, PathEvent::Opened { id } if id == path.id()) {
            break;
        }
    }

    path.close().unwrap();

    loop {
        let event = path_events.recv_async().await.unwrap();
        if matches!(event, PathEvent::Discarded { id, .. } if id == path.id()) {
            break;
        }
    }

    assert!(
        path.stats().is_some(),
        "path handle should retain final stats after discard"
    );
    assert!(
        client.all_path_stats().contains_key(&path.id()),
        "aggregated path stats should include discarded paths with retained final stats"
    );
    assert!(
        !client.live_path_stats().contains_key(&path.id()),
        "live path stats should exclude discarded paths"
    );

    drop(path);
    drop(server);
    drop(client);
    client_endpoint.shutdown().await.unwrap();
    server_endpoint.shutdown().await.unwrap();
}

/// Dual-stack variant of `handshake_confirmed_and_open_path_event`.
///
/// The client connects over IPv4 loopback to a server bound on `[::]`, then
/// opens a second path to the same port over IPv6 loopback (`::1`). The test
/// checks that the new path has a non-zero id, has stats, and produces a
/// matching `PathEvent::Opened`.
#[cfg(linux_all)]
#[compio::test]
async fn handshake_confirmed_and_open_path_event_dual_stack() {
    // Cap on `open_path` retries (~5 s at 10 ms each). If the error never
    // clears, the test fails instead of hanging forever.
    const MAX_OPEN_PATH_ATTEMPTS: u32 = 500;

    let _guard = subscribe();

    let mut transport = TransportConfig::default();
    transport.max_concurrent_multipath_paths(2);
    let (client_endpoint, server_endpoint) =
        dual_stack_endpoint_pair_with_transport(transport).await;
    let mut ipv4_server_addr = server_endpoint.local_addr().unwrap();
    ipv4_server_addr.set_ip("127.0.0.1".parse().unwrap());

    let (client, server) = join!(
        async {
            client_endpoint
                .connect(ipv4_server_addr, "localhost", None)
                .unwrap()
                .await
                .unwrap()
        },
        async {
            server_endpoint
                .wait_incoming()
                .await
                .unwrap()
                .await
                .unwrap()
        },
    );

    client.handshake_confirmed().await.unwrap();
    assert!(client.is_multipath_enabled());

    // Subscribe before opening the path so its `Opened` event is observed.
    let path_events = client.path_events();
    // Same port as the IPv4 address, but reached over IPv6 loopback.
    let mut ipv6_server_addr = ipv4_server_addr;
    ipv6_server_addr.set_ip("::1".parse().unwrap());

    // `RemoteCidsExhausted` is treated as transient (presumably until the
    // peer supplies more connection IDs); retry it, but only a bounded
    // number of times.
    let mut attempts = 0;
    let path = loop {
        match client
            .open_path(ipv6_server_addr, PathStatus::Available)
            .await
        {
            Ok(path) => break path,
            Err(PathError::RemoteCidsExhausted) => {
                attempts += 1;
                assert!(
                    attempts < MAX_OPEN_PATH_ATTEMPTS,
                    "open_path still failing with RemoteCidsExhausted after {attempts} attempts"
                );
                compio::runtime::time::sleep(Duration::from_millis(10)).await;
            }
            Err(err) => panic!("unexpected open_path error: {err:?}"),
        }
    };

    assert_ne!(path.id(), PathId::ZERO);
    assert!(path.stats().is_some());

    // Skip unrelated events until the one for the newly opened path arrives.
    loop {
        let event = path_events.recv_async().await.unwrap();
        if matches!(event, PathEvent::Opened { id } if id == path.id()) {
            break;
        }
    }

    drop(path);
    drop(server);
    drop(client);
    client_endpoint.shutdown().await.unwrap();
    server_endpoint.shutdown().await.unwrap();
}

/// Checks that a NAT-traversal address added by the server reaches the client.
///
/// The client first receives `n0_nat_traversal::Event::AddressAdded` for that
/// address. Afterwards it reports the address as its only remote
/// NAT-traversal address.
#[compio::test]
async fn nat_traversal_updates_are_forwarded() {
    let _guard = subscribe();

    let mut transport = TransportConfig::default();
    transport.set_max_remote_nat_traversal_addresses(2);
    let (client_endpoint, server_endpoint) = endpoint_pair_with_transport(transport).await;
    let server_addr = server_endpoint.local_addr().unwrap();

    let dial = async {
        let connecting = client_endpoint
            .connect(server_addr, "localhost", None)
            .unwrap();
        connecting.await.unwrap()
    };
    let accept = async {
        let incoming = server_endpoint.wait_incoming().await.unwrap();
        incoming.await.unwrap()
    };
    let (client, server) = join!(dial, accept);

    client.handshake_confirmed().await.unwrap();

    // Subscribe on the client before the server announces the address.
    let updates = client.nat_traversal_updates();
    let added_addr: SocketAddr = (Ipv4Addr::LOCALHOST, 9).into();
    server.add_nat_traversal_address(added_addr).unwrap();

    let event = updates.recv_async().await.unwrap();
    assert!(matches!(event, n0_nat_traversal::Event::AddressAdded(addr) if addr == added_addr));
    let remote_addrs = client.get_remote_nat_traversal_addresses().unwrap();
    assert_eq!(remote_addrs, vec![added_addr]);

    drop(server);
    drop(client);
    client_endpoint.shutdown().await.unwrap();
    server_endpoint.shutdown().await.unwrap();
}