#![allow(clippy::unwrap_used, clippy::expect_used)]
use ant_quic::{
config::{ClientConfig, ServerConfig},
crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey},
crypto::raw_public_keys::pqc::{create_subject_public_key_info, generate_ml_dsa_keypair},
high_level::Endpoint,
nat_traversal_api::PeerId,
trust::{self, EventCollector, FsPinStore, TransportPolicy},
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use std::{net::SocketAddr, sync::Arc};
use tempfile::TempDir;
use tokio::time::{timeout, Duration};
/// Builds a throwaway self-signed certificate for `localhost`.
///
/// Returns a one-element certificate chain together with its private key
/// wrapped as PKCS#8 DER, ready to hand to the QUIC server config.
fn gen_self_signed_cert() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
    let generated = rcgen::generate_simple_self_signed(vec![String::from("localhost")])
        .expect("generate self-signed");
    let private_key = PrivateKeyDer::Pkcs8(generated.signing_key.serialize_der().into());
    let chain = vec![CertificateDer::from(generated.cert)];
    (chain, private_key)
}
/// Generates a fresh ML-DSA keypair, panicking if generation fails.
fn ml_dsa_keypair() -> (MlDsaPublicKey, MlDsaSecretKey) {
    let generated = generate_ml_dsa_keypair();
    generated.expect("ML-DSA-65 keypair generation")
}
/// Encodes an ML-DSA public key as SubjectPublicKeyInfo bytes, panicking on failure.
fn spki_from_pk(pk: &MlDsaPublicKey) -> Vec<u8> {
    let encoded = create_subject_public_key_info(pk);
    encoded.expect("SPKI creation")
}
/// Derives a `PeerId` from the SHA-256 digest of the given SPKI bytes.
///
/// NOTE(review): presumably mirrors the library's own peer-id derivation;
/// confirm before asserting equality against values the library returns.
fn peer_id_from_spki(spki: &[u8]) -> PeerId {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(spki);
    let mut bytes = [0u8; 32];
    bytes.copy_from_slice(&digest[..]);
    PeerId(bytes)
}
/// Establishes a QUIC connection over 127.0.0.1 and returns both ends.
///
/// The server presents a fresh self-signed `localhost` certificate that the
/// client trusts as its only root. Accepting, the server handshake, and the
/// client connect are each bounded by a 10-second timeout; any failure panics.
///
/// Returns `(client_connection, server_connection)`.
async fn loopback_pair() -> (ant_quic::Connection, ant_quic::Connection) {
    let (certs, private_key) = gen_self_signed_cert();
    let limit = Duration::from_secs(10);

    // Server side: ephemeral loopback port, presenting the self-signed cert.
    let server_config =
        ServerConfig::with_single_cert(certs.clone(), private_key).expect("server cfg");
    let server_endpoint =
        Endpoint::server(server_config, ([127, 0, 0, 1], 0).into()).expect("server ep");
    let server_addr: SocketAddr = server_endpoint.local_addr().unwrap();

    // Accept in the background so the client connect below can progress.
    let server_side = tokio::spawn(async move {
        let incoming = timeout(limit, server_endpoint.accept())
            .await
            .unwrap()
            .unwrap();
        timeout(limit, incoming).await.unwrap().unwrap()
    });

    // Client side: trust exactly the certificate the server presents.
    let mut trust_anchors = rustls::RootCertStore::empty();
    certs
        .into_iter()
        .for_each(|der| trust_anchors.add(der).unwrap());
    let client_config = ClientConfig::with_root_certificates(Arc::new(trust_anchors)).unwrap();
    let mut client_endpoint = Endpoint::client(([127, 0, 0, 1], 0).into()).expect("client ep");
    client_endpoint.set_default_client_config(client_config);

    let connecting = client_endpoint
        .connect(server_addr, "localhost")
        .expect("start");
    let client_conn = timeout(limit, connecting).await.unwrap().unwrap();
    let server_conn = server_side.await.unwrap();
    (client_conn, server_conn)
}
#[tokio::test]
async fn binding_success_with_pinned_key() {
let (client_conn, server_conn) = loopback_pair().await;
let (c_pk, c_sk) = ml_dsa_keypair();
let (s_pk, _s_sk) = ml_dsa_keypair();
let c_spki = spki_from_pk(&c_pk);
let s_spki = spki_from_pk(&s_pk);
let _c_peer = peer_id_from_spki(&c_spki);
let _s_peer = peer_id_from_spki(&s_spki);
let client_store_dir = TempDir::new().unwrap();
let server_store_dir = TempDir::new().unwrap();
let c_store = FsPinStore::new(client_store_dir.path());
let s_store = FsPinStore::new(server_store_dir.path());
trust::register_first_seen(&c_store, &TransportPolicy::default(), &s_spki).unwrap();
trust::register_first_seen(&s_store, &TransportPolicy::default(), &c_spki).unwrap();
let c_events = Arc::new(EventCollector::default());
let s_events = Arc::new(EventCollector::default());
let _c_policy = TransportPolicy::default().with_event_sink(c_events.clone());
let s_policy = TransportPolicy::default().with_event_sink(s_events.clone());
let exp_client = trust::derive_exporter(&client_conn).unwrap();
let exp_server = trust::derive_exporter(&server_conn).unwrap();
assert_eq!(exp_client, exp_server);
let s_store_owned = s_store.clone();
let s_policy_owned = s_policy.clone();
let s_conn_clone = server_conn.clone();
let recv_task = tokio::spawn(async move {
trust::recv_verify_binding(&s_conn_clone, &s_store_owned, &s_policy_owned).await
});
trust::send_binding(&client_conn, &exp_client, &c_sk, &c_spki)
.await
.expect("send ok");
let pid = recv_task.await.unwrap().expect("verify ok");
assert!(s_events.binding_verified_called());
let _ = pid;
let _ = exp_server; }
#[tokio::test]
async fn binding_reject_on_key_mismatch() {
let (client_conn, server_conn) = loopback_pair().await;
let (c_pk, c_sk) = ml_dsa_keypair();
let (s_pk, _s_sk) = ml_dsa_keypair();
let c_spki = spki_from_pk(&c_pk);
let wrong_spki = spki_from_pk(&s_pk);
let c_store_dir = TempDir::new().unwrap();
let s_store_dir = TempDir::new().unwrap();
let c_store = FsPinStore::new(c_store_dir.path());
let s_store = FsPinStore::new(s_store_dir.path());
trust::register_first_seen(&c_store, &TransportPolicy::default(), &wrong_spki).unwrap();
trust::register_first_seen(&s_store, &TransportPolicy::default(), &c_spki).unwrap();
let exp = trust::derive_exporter(&client_conn).unwrap();
let s_conn_clone = server_conn.clone();
let c_store_owned = c_store.clone();
let policy_owned = TransportPolicy::default();
let recv_task = tokio::spawn(async move {
trust::recv_verify_binding(&s_conn_clone, &c_store_owned, &policy_owned).await
});
trust::send_binding(&client_conn, &exp, &c_sk, &c_spki)
.await
.expect("send ok");
let err = recv_task.await.unwrap().expect_err("should reject");
match err {
ant_quic::trust::TrustError::ChannelBinding(_) | ant_quic::trust::TrustError::NotPinned => {
}
_ => panic!("unexpected err"),
}
}