// ant-quic 0.26.7
//
// QUIC transport protocol with advanced NAT traversal for P2P networks.
//! Channel binding tests using ML-DSA-65 Pure PQC signatures.
//!
//! v0.2.0+: Updated for Pure PQC - uses ML-DSA-65 only, no Ed25519.

#![allow(clippy::unwrap_used, clippy::expect_used)]

use ant_quic::{
    config::{ClientConfig, ServerConfig},
    crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey},
    crypto::raw_public_keys::pqc::{create_subject_public_key_info, generate_ml_dsa_keypair},
    high_level::Endpoint,
    nat_traversal_api::PeerId,
    trust::{self, EventCollector, FsPinStore, TransportPolicy},
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use std::{net::SocketAddr, sync::Arc};
use tempfile::TempDir;
use tokio::time::{timeout, Duration};

/// Build a throwaway self-signed certificate for `localhost`.
///
/// Returns a one-element certificate chain plus the matching PKCS#8 private
/// key, ready to feed into `ServerConfig::with_single_cert`.
///
/// # Panics
/// Panics if rcgen fails to generate the certificate.
fn gen_self_signed_cert() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
    let subject_alt_names = vec!["localhost".to_string()];
    let certified =
        rcgen::generate_simple_self_signed(subject_alt_names).expect("generate self-signed");
    let private_key = PrivateKeyDer::Pkcs8(certified.signing_key.serialize_der().into());
    let chain = vec![CertificateDer::from(certified.cert)];
    (chain, private_key)
}

/// Produce a fresh ML-DSA-65 `(public, secret)` keypair for a test identity.
///
/// # Panics
/// Panics if key generation reports an error.
fn ml_dsa_keypair() -> (MlDsaPublicKey, MlDsaSecretKey) {
    let generated = generate_ml_dsa_keypair();
    generated.expect("ML-DSA-65 keypair generation")
}

/// Encode an ML-DSA-65 public key as DER SubjectPublicKeyInfo bytes.
///
/// # Panics
/// Panics if the SPKI encoding fails.
fn spki_from_pk(pk: &MlDsaPublicKey) -> Vec<u8> {
    let encoded = create_subject_public_key_info(pk);
    encoded.expect("SPKI creation")
}

/// Derive a test `PeerId` as the SHA-256 digest of the given SPKI bytes.
fn peer_id_from_spki(spki: &[u8]) -> PeerId {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(spki);
    let mut bytes = [0u8; 32];
    bytes.copy_from_slice(digest.as_slice());
    PeerId(bytes)
}

async fn loopback_pair() -> (ant_quic::Connection, ant_quic::Connection) {
    let (chain, key) = gen_self_signed_cert();
    let server_cfg = ServerConfig::with_single_cert(chain.clone(), key).expect("server cfg");
    let server = Endpoint::server(server_cfg, ([127, 0, 0, 1], 0).into()).expect("server ep");
    let addr: SocketAddr = server.local_addr().unwrap();

    let accept = tokio::spawn(async move {
        let inc = timeout(Duration::from_secs(10), server.accept())
            .await
            .unwrap()
            .unwrap();
        timeout(Duration::from_secs(10), inc).await.unwrap().unwrap()
    });

    let mut roots = rustls::RootCertStore::empty();
    for c in chain {
        roots.add(c).unwrap();
    }
    let client_cfg = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
    let mut client = Endpoint::client(([127, 0, 0, 1], 0).into()).expect("client ep");
    client.set_default_client_config(client_cfg);
    let c_conn = timeout(
        Duration::from_secs(10),
        client.connect(addr, "localhost").expect("start"),
    )
    .await
    .unwrap()
    .unwrap();
    let s_conn = accept.await.unwrap();
    (c_conn, s_conn)
}

/// Happy path for ML-DSA-65 channel binding.
///
/// Both sides pin each other's SPKI, both derive the connection exporter
/// (which must match), and the client sends a binding signed with its secret
/// key. The server must verify it and report the verification through its
/// event sink.
#[tokio::test]
async fn binding_success_with_pinned_key() {
    let (client_conn, server_conn) = loopback_pair().await;

    // Generate ML-DSA-65 keys for client and server identity. Only the client
    // signs in this test; the server key exists so the client has a pin.
    let (c_pk, c_sk) = ml_dsa_keypair();
    let (s_pk, _s_sk) = ml_dsa_keypair();
    let c_spki = spki_from_pk(&c_pk);
    let s_spki = spki_from_pk(&s_pk);
    let _c_peer = peer_id_from_spki(&c_spki);
    let _s_peer = peer_id_from_spki(&s_spki);

    // Pin each other's keys in independent on-disk stores.
    let client_store_dir = TempDir::new().unwrap();
    let server_store_dir = TempDir::new().unwrap();
    let c_store = FsPinStore::new(client_store_dir.path());
    let s_store = FsPinStore::new(server_store_dir.path());

    // Client pins server
    trust::register_first_seen(&c_store, &TransportPolicy::default(), &s_spki).unwrap();
    // Server pins client
    trust::register_first_seen(&s_store, &TransportPolicy::default(), &c_spki).unwrap();

    // Only the server verifies here, so only the server needs an event sink.
    let s_events = Arc::new(EventCollector::default());
    let s_policy = TransportPolicy::default().with_event_sink(s_events.clone());

    // Both ends of one connection must derive the same exporter value.
    let exp_client = trust::derive_exporter(&client_conn).unwrap();
    let exp_server = trust::derive_exporter(&server_conn).unwrap();
    assert_eq!(exp_client, exp_server);

    // Server waits to receive; client sends. The store and policy are moved into
    // the task, while the connection is cloned so `server_conn` stays alive here
    // for the whole test.
    let s_conn = server_conn.clone();
    let recv_task = tokio::spawn(async move {
        trust::recv_verify_binding(&s_conn, &s_store, &s_policy).await
    });
    trust::send_binding(&client_conn, &exp_client, &c_sk, &c_spki)
        .await
        .expect("send ok");
    recv_task.await.unwrap().expect("verify ok");
    assert!(s_events.binding_verified_called());
}

/// Negative path for ML-DSA-65 channel binding.
///
/// The server-side verifier is handed a pin store whose only entry is a
/// *different* key than the one the client signs with, so verification must
/// fail with a trust error rather than succeed.
#[tokio::test]
async fn binding_reject_on_key_mismatch() {
    let (client_conn, server_conn) = loopback_pair().await;
    let (c_pk, c_sk) = ml_dsa_keypair();
    let (s_pk, _s_sk) = ml_dsa_keypair();
    let c_spki = spki_from_pk(&c_pk);
    let wrong_spki = spki_from_pk(&s_pk); // not the client's key

    let c_store_dir = TempDir::new().unwrap();
    let s_store_dir = TempDir::new().unwrap();
    let c_store = FsPinStore::new(c_store_dir.path());
    let s_store = FsPinStore::new(s_store_dir.path());
    // `c_store` only knows `wrong_spki`. `s_store` holds the correct client pin
    // but is deliberately NOT given to the verifier below.
    trust::register_first_seen(&c_store, &TransportPolicy::default(), &wrong_spki).unwrap();
    trust::register_first_seen(&s_store, &TransportPolicy::default(), &c_spki).unwrap();

    let exp = trust::derive_exporter(&client_conn).unwrap();
    // The server verifies against the store that pins the wrong key, so it
    // should reject. The connection is cloned to keep `server_conn` alive here.
    let s_conn_clone = server_conn.clone();
    let policy_owned = TransportPolicy::default();
    let recv_task = tokio::spawn(async move {
        trust::recv_verify_binding(&s_conn_clone, &c_store, &policy_owned).await
    });
    trust::send_binding(&client_conn, &exp, &c_sk, &c_spki)
        .await
        .expect("send ok");
    let err = recv_task.await.unwrap().expect_err("should reject");
    // Either variant counts as a rejection. Which one is reported presumably
    // depends on how the store looks up pins (no entry for the client vs. an
    // entry that does not match) — TODO confirm against trust::recv_verify_binding.
    match err {
        trust::TrustError::ChannelBinding(_) | trust::TrustError::NotPinned => {}
        other => panic!("unexpected err: {other:?}"),
    }
}