#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::sync::{Arc, Mutex};
use tokio::time::{Duration, timeout};
use ant_quic as quic;
use ant_quic::crypto::raw_public_keys::pqc::{
create_subject_public_key_info, generate_ml_dsa_keypair,
};
use ant_quic::{
TokenStore,
config::{ClientConfig, ServerConfig},
high_level::Endpoint,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
/// Builds a throwaway self-signed certificate whose subject alt name is `localhost`.
///
/// Returns a one-element certificate chain together with the matching
/// PKCS#8-encoded private key, both owned (`'static`).
///
/// # Panics
/// Panics if certificate generation fails.
fn gen_self_signed_cert() -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
    let subject_alt_names = vec!["localhost".to_string()];
    let certified = rcgen::generate_simple_self_signed(subject_alt_names)
        .expect("generate self-signed");
    let leaf = CertificateDer::from(certified.cert);
    let pkcs8_bytes = certified.signing_key.serialize_der();
    let private_key = PrivateKeyDer::Pkcs8(pkcs8_bytes.into());
    (vec![leaf], private_key)
}
/// Creates a client configuration that trusts exactly the certificates in
/// `chain` as root anchors.
///
/// # Panics
/// Panics if the root store rejects any certificate, or if the client
/// configuration cannot be built.
fn mk_client_config(chain: &[CertificateDer<'static>]) -> ClientConfig {
    let trust_anchors = chain
        .iter()
        .fold(rustls::RootCertStore::empty(), |mut store, der| {
            store.add(der.clone()).expect("add root");
            store
        });
    ClientConfig::with_root_certificates(Arc::new(trust_anchors)).expect("client cfg")
}
/// Starts a loopback (127.0.0.1, ephemeral port) server endpoint backed by a
/// fresh self-signed certificate and a freshly generated token key.
///
/// Returns `(endpoint, bound address, certificate chain, token key)`.
/// - The chain lets a client be configured to trust this server.
/// - The token key is the same key installed on the server config, so callers
///   can decode tokens the server issues.
///
/// # Panics
/// Panics if the server config or endpoint cannot be created.
async fn mk_server() -> (
    Endpoint,
    std::net::SocketAddr,
    Vec<CertificateDer<'static>>,
    quic::token_v2::TokenKey,
) {
    // Installing the process-wide default provider returns an error when one is
    // already installed (e.g. by an earlier test); that is acceptable here, so
    // the result is intentionally discarded.
    let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

    let (cert_chain, private_key) = gen_self_signed_cert();
    let key = quic::token_v2::test_key_from_rng(&mut rand::thread_rng());

    let mut config =
        ServerConfig::with_single_cert(cert_chain.clone(), private_key).expect("server cfg");
    config.token_key(key);

    let endpoint = Endpoint::server(config, ([127, 0, 0, 1], 0).into()).expect("server ep");
    let bound_addr = endpoint.local_addr().expect("server addr");
    (endpoint, bound_addr, cert_chain, key)
}
/// A [`TokenStore`] that records every token it is given and never hands one back.
///
/// Clones share the same underlying vector (field `0`), so a test can keep one
/// clone and inspect what the client's copy received.
#[derive(Clone, Default)]
struct CollectingTokenStore(Arc<Mutex<Vec<bytes::Bytes>>>);

impl TokenStore for CollectingTokenStore {
    /// Appends `new_token` to the shared list; the server name is ignored.
    fn insert(&self, _server_name: &str, new_token: bytes::Bytes) {
        let mut received = self.0.lock().unwrap();
        received.push(new_token);
    }

    /// Always returns `None`: this store only observes tokens and never supplies one.
    fn take(&self, _server_name: &str) -> Option<bytes::Bytes> {
        None
    }
}
/// End-to-end check that, once the peer binding is verified, the server issues
/// a NEW_TOKEN whose payload decodes as a v2 validation token.
///
/// The client's [`CollectingTokenStore`] captures every token the server sends.
/// The first captured token is decoded with the server's token key and must be
/// bound to the client's address.
///
/// NOTE(review): this test installs process-wide trust state via
/// `quic::trust::set_global_runtime`, as does `auto_binding_rejects_on_mismatch`.
/// The default test harness runs tests concurrently, so when both are run with
/// `--ignored` one may overwrite the other's runtime mid-test. Run them with
/// `--test-threads=1` (or serialize them behind a shared lock).
#[tokio::test]
#[ignore = "Uses X.509 certs which deadlock with PQC-only crypto provider. Needs rewrite to use PQC auth path."]
async fn auto_binding_emits_new_token_v2() {
    let (server, server_addr, chain, token_key) = mk_server().await;

    // Local ML-DSA identity, registered as first-seen in a temporary pin store.
    let (public_key, secret_key) = generate_ml_dsa_keypair().expect("keygen");
    let spki = create_subject_public_key_info(&public_key).expect("spki");
    // `tmp` must outlive the test body: the pin store lives in this directory.
    let tmp = tempfile::tempdir().expect("tempdir");
    let store = quic::trust::FsPinStore::new(tmp.path());
    let events = Arc::new(quic::trust::EventCollector::default());
    let policy = quic::trust::TransportPolicy::default().with_event_sink(events.clone());
    let _peer_id = quic::trust::register_first_seen(&store, &policy, &spki).expect("pin ok");
    quic::trust::set_global_runtime(Arc::new(quic::trust::GlobalTrustRuntime {
        store: Arc::new(store.clone()),
        policy: policy.clone(),
        local_public_key: Arc::new(public_key),
        local_secret_key: Arc::new(secret_key),
        local_spki: Arc::new(spki.clone()),
    }));

    // Server: accept one connection and hold it for 500 ms, presumably long
    // enough for post-handshake frames (binding, NEW_TOKEN) to be exchanged.
    let server_task = tokio::spawn(async move {
        let inc = timeout(Duration::from_secs(10), server.accept())
            .await
            .expect("accept wait")
            .expect("incoming");
        let _conn = timeout(Duration::from_secs(10), inc)
            .await
            .expect("hs wait")
            .expect("server hs ok");
        tokio::time::sleep(Duration::from_millis(500)).await;
    });

    let mut client = Endpoint::client(([127, 0, 0, 1], 0).into()).expect("client ep");
    // The address the server sees for this client; validation tokens are bound to it.
    let client_addr = client.local_addr().expect("client addr");
    let mut client_cfg = mk_client_config(&chain);
    let collector = CollectingTokenStore::default();
    client_cfg.token_store(Arc::new(collector.clone()));
    client.set_default_client_config(client_cfg);
    let connecting = client
        .connect(server_addr, "localhost")
        .expect("connect start");
    let conn = timeout(Duration::from_secs(10), connecting)
        .await
        .expect("client wait")
        .expect("client ok");

    // Give the server time to verify the binding and send NEW_TOKEN.
    tokio::time::sleep(Duration::from_millis(400)).await;
    assert!(
        events.binding_verified_called(),
        "binding should be verified"
    );

    let tokens = collector.0.lock().unwrap().clone();
    let tok = tokens.first().expect("expected at least one NEW_TOKEN");
    let dec = quic::token_v2::decode_validation_token(&token_key, tok).expect("decode v2");
    // Address-validation tokens identify the *client's* address (RFC 9000 §8.1.3),
    // so compare against the client endpoint rather than the server. On loopback
    // both are 127.0.0.1; this pins the intended invariant rather than a coincidence.
    assert_eq!(dec.ip, client_addr.ip());

    conn.close(0u32.into(), b"done");
    server_task.await.expect("server");
}
/// End-to-end check that a pinned-key mismatch causes the connection to close.
///
/// The pin store is seeded with an unrelated key (`wrong_spki`), while the global
/// runtime advertises `spki` as the local identity. The test expects both
/// handshakes to succeed and the connection to then be closed within
/// `CLIENT_CLOSE_WINDOW`.
///
/// NOTE(review): this test installs process-wide trust state via
/// `quic::trust::set_global_runtime`, as does `auto_binding_emits_new_token_v2`.
/// The default test harness runs tests concurrently, so when both are run with
/// `--ignored` one may overwrite the other's runtime mid-test. Run them with
/// `--test-threads=1` (or serialize them behind a shared lock).
#[tokio::test]
#[ignore = "Uses X.509 certs which deadlock with PQC-only crypto provider. Needs rewrite to use PQC auth path."]
async fn auto_binding_rejects_on_mismatch() {
    // How long the client waits for the binding failure to close the connection.
    const CLIENT_CLOSE_WINDOW: Duration = Duration::from_secs(3);
    // The server must keep its connection handle alive for longer than the
    // client's window. Dropping the last handle (quinn-style API) closes the
    // connection, and the client would mistake that for a binding-failure close.
    // Previously the server gave up after 2 s (< 3 s), so the test passed even
    // when binding never failed.
    const SERVER_HOLD: Duration = Duration::from_secs(10);

    let (server, server_addr, chain, _token_key) = mk_server().await;
    let (public_key, secret_key) = generate_ml_dsa_keypair().expect("keygen");
    let spki = create_subject_public_key_info(&public_key).expect("spki");
    // A second, unrelated identity; pinning it guarantees a mismatch with `spki`.
    let (wrong_pk, _wrong_sk) = generate_ml_dsa_keypair().expect("wrong keygen");
    let wrong_spki = create_subject_public_key_info(&wrong_pk).expect("wrong spki");
    // `tmp` must outlive the test body: the pin store lives in this directory.
    let tmp = tempfile::tempdir().expect("tempdir");
    let store = quic::trust::FsPinStore::new(tmp.path());
    let policy = quic::trust::TransportPolicy::default();
    quic::trust::register_first_seen(&store, &policy, &wrong_spki).expect("pin wrong ok");
    quic::trust::set_global_runtime(Arc::new(quic::trust::GlobalTrustRuntime {
        store: Arc::new(store.clone()),
        policy: policy.clone(),
        local_public_key: Arc::new(public_key),
        local_secret_key: Arc::new(secret_key),
        local_spki: Arc::new(spki.clone()),
    }));

    let server_task = tokio::spawn(async move {
        let inc = timeout(Duration::from_secs(10), server.accept())
            .await
            .expect("accept wait")
            .expect("incoming");
        let conn = timeout(Duration::from_secs(10), inc)
            .await
            .expect("hs wait")
            .expect("server hs ok");
        // Hold the handle until the connection closes (or well past the client's
        // window), so only a genuine close can satisfy the client's assertion.
        let _ = timeout(SERVER_HOLD, conn.closed()).await;
    });

    let mut client = Endpoint::client(([127, 0, 0, 1], 0).into()).expect("client ep");
    let client_cfg = mk_client_config(&chain);
    client.set_default_client_config(client_cfg);
    let connecting = client
        .connect(server_addr, "localhost")
        .expect("connect start");
    let conn = timeout(Duration::from_secs(10), connecting)
        .await
        .expect("client wait")
        .expect("client ok");

    let closed = timeout(CLIENT_CLOSE_WINDOW, conn.closed()).await;
    assert!(closed.is_ok(), "connection should close on binding failure");
    server_task.await.expect("server");
}