use std::{
net::{Ipv4Addr, SocketAddr},
time::Duration,
};
use rustls_pki_types::CertificateDer;
use tokio::time::timeout;
use x509_certificate::EcdsaCurve;
use qcp::{
Configuration,
protocol::{
compat::Feature,
control::{Compatibility, ConnectionType},
},
transport::ThroughputMode,
util::Credentials,
};
use qcp::{
control::create_endpoint,
protocol::{DataTag, control::CredentialsType},
};
async fn run_endpoint_connection<F, G>(
modify_certs_fn: F,
check_fn: G,
compat: Compatibility,
) -> anyhow::Result<()>
where
F: FnOnce(&Credentials, &Credentials) -> (CertificateDer<'static>, CertificateDer<'static>),
G: FnOnce(
anyhow::Result<quinn::Connection>,
anyhow::Result<quinn::Connection>,
) -> anyhow::Result<()>,
{
let client_credentials = Credentials::generate()?;
let server_credentials = Credentials::generate()?;
let (cli_cert_messed, srv_cert_messed) =
modify_certs_fn(&client_credentials, &server_credentials);
let cli_cert_messed = if compat.supports(Feature::CMSG_SMSG_2) {
CredentialsType::RawPublicKey.with_bytes(cli_cert_messed)
} else {
CredentialsType::X509.with_bytes(cli_cert_messed)
};
let srv_cert_messed = if compat.supports(Feature::CMSG_SMSG_2) {
CredentialsType::RawPublicKey.with_bytes(srv_cert_messed)
} else {
CredentialsType::X509.with_bytes(srv_cert_messed)
};
let (server_endpoint, _) = create_endpoint(
&server_credentials,
&cli_cert_messed,
ConnectionType::Ipv4,
Configuration::system_default(),
ThroughputMode::Both,
true,
compat,
)?;
let conn_addr = server_endpoint.local_addr()?;
eprintln!("Server bound to {conn_addr:?}");
let conn_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), conn_addr.port());
let srv_name = server_credentials.hostname.clone();
let (client_endpoint, _) = create_endpoint(
&client_credentials,
&srv_cert_messed,
ConnectionType::Ipv4,
Configuration::system_default(),
ThroughputMode::Both,
false,
compat,
)?;
eprintln!("Client bound to {:?}", client_endpoint.local_addr()?);
let srv_hdl = tokio::spawn(async move {
eprintln!("SERVER: accepting");
let connecting = timeout(Duration::from_secs(5), server_endpoint.accept())
.await?
.ok_or(anyhow::anyhow!("server ended"))
.and_then(|i| Ok(i.accept()?));
if let Ok(c) = connecting {
Ok(c.await?)
} else {
anyhow::bail!("server accept failed");
}
});
let cli_hdl = tokio::spawn(async move {
eprintln!("CLIENT: connecting to {conn_addr:?}");
timeout(
Duration::from_secs(5),
client_endpoint.connect(conn_addr, &srv_name)?,
)
.await
.map_err(|_| anyhow::anyhow!("client connect timed out"))?
.map_err(|e| anyhow::anyhow!("client connect failed: {e}"))
});
tokio::pin!(srv_hdl, cli_hdl);
let res = tokio::join!(srv_hdl, cli_hdl);
let (srv_res, cli_res) = res;
let srv_res = srv_res.unwrap();
let cli_res = cli_res.unwrap();
check_fn(cli_res, srv_res)
}
/// Happy path for X.509 credentials at compatibility level 1. Each side expects its
/// peer's genuine certificate, so both ends must connect successfully.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_x509_ok() {
    run_endpoint_connection(
        |cli, srv| (cli.certificate().to_owned(), srv.certificate().to_owned()),
        |cli_res, srv_res| {
            // Print any error before asserting, so a failure shows its cause.
            assert!(cli_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            assert!(srv_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            Ok(())
        },
        Compatibility::Level(1),
    )
    .await
    .unwrap();
}
/// Builds a DER-encoded X.509 certificate that impersonates `der`.
///
/// The new certificate has the same subject common name as `der`, but uses a freshly
/// generated ECDSA P-256 key pair, so only the key material differs from the genuine one.
///
/// # Panics
/// Panics if `der` cannot be parsed, has no subject CN, or the new certificate cannot
/// be built or encoded. All of these indicate a broken test fixture.
fn replace_certificate(der: &CertificateDer<'static>) -> Vec<u8> {
    use x509_certificate::{KeyAlgorithm, X509Certificate, X509CertificateBuilder};
    let parsed = X509Certificate::from_der(der).expect("original certificate should parse");
    let common_name = parsed
        .subject_common_name()
        .expect("original certificate should have a subject common name");
    let mut builder = X509CertificateBuilder::default();
    // Do not ignore this result: a silent failure would yield a certificate without a CN,
    // and the test would then pass or fail for the wrong reason.
    builder
        .subject()
        .append_common_name_utf8_string(&common_name)
        .expect("common name should be encodable");
    let (newcert, _keypair) = builder
        .create_with_random_keypair(KeyAlgorithm::Ecdsa(EcdsaCurve::Secp256r1))
        .expect("replacement certificate should be created");
    newcert
        .encode_der()
        .expect("replacement certificate should encode to DER")
}
/// The server must reject a client whose X.509 certificate differs from the expected
/// one (same common name, different key). The client side still reports success.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_client_x509_mismatch() {
    run_endpoint_connection(
        |cli, srv| {
            (
                replace_certificate(cli.certificate()).into(),
                srv.certificate().to_owned(),
            )
        },
        |cli_res, srv_res| {
            assert!(cli_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            let err = srv_res.expect_err("server should reject the client certificate");
            // Print before checking the content, so a failing assertion still shows it.
            eprintln!("Server result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected server error: {err}"
            );
            Ok(())
        },
        Compatibility::Level(1),
    )
    .await
    .unwrap();
}
/// The client must reject a server whose X.509 certificate differs from the expected
/// one (same common name, different key). Both sides then report the handshake failure.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_server_x509_mismatch() {
    run_endpoint_connection(
        |cli, srv| {
            (
                cli.certificate().to_owned(),
                replace_certificate(srv.certificate()).into(),
            )
        },
        |cli_res, srv_res| {
            // Each error is printed before its content is checked, so a failing
            // assertion still shows what actually happened.
            let err = cli_res.expect_err("client should reject the server certificate");
            eprintln!("Client result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected client error: {err}"
            );
            let err = srv_res.expect_err("server handshake should fail");
            eprintln!("Server result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected server error: {err}"
            );
            Ok(())
        },
        Compatibility::Level(1),
    )
    .await
    .unwrap();
}
/// Happy path for raw-public-key credentials at compatibility level 3. Each endpoint
/// expects its peer's genuine key (as returned by `as_raw_public_key`), so both
/// ends must connect.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_rpk_ok() {
    let outcome = run_endpoint_connection(
        |cli, srv| {
            let client_rpk = cli.as_raw_public_key().unwrap();
            let server_rpk = srv.as_raw_public_key().unwrap();
            (client_rpk.cert[0].clone(), server_rpk.cert[0].clone())
        },
        |cli_res, srv_res| {
            assert!(cli_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            assert!(srv_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            Ok(())
        },
        Compatibility::Level(3),
    )
    .await;
    outcome.unwrap();
}
/// Returns the raw-public-key certificate of a brand-new, unrelated set of credentials.
/// Tests use it as a deliberately wrong "expected" peer key.
///
/// # Panics
/// Panics if credential generation or the raw-public-key conversion fails.
fn replace_rpk() -> CertificateDer<'static> {
    Credentials::generate()
        .unwrap()
        .as_raw_public_key()
        .unwrap()
        .cert[0]
        .clone()
}
/// The server must reject a client presenting a raw public key other than the expected
/// one. The client side still reports success.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_client_rpk_mismatch() {
    run_endpoint_connection(
        |_cli, srv| {
            (
                replace_rpk(),
                srv.as_raw_public_key().unwrap().cert[0].clone(),
            )
        },
        |cli_res, srv_res| {
            assert!(cli_res.inspect_err(|e| eprintln!("{e}")).is_ok());
            let err = srv_res.expect_err("server should reject the client's public key");
            // Print before checking the content, so a failing assertion still shows it.
            eprintln!("Server result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected server error: {err}"
            );
            Ok(())
        },
        Compatibility::Level(3),
    )
    .await
    .unwrap();
}
/// The client must reject a server presenting a raw public key other than the expected
/// one. Both sides then report the handshake failure.
#[cfg_attr(
    all(target_os = "windows", target_env = "gnu"),
    ignore = "Doesn't work with the mingw cross-compile test runner"
)]
#[tokio::test]
async fn test_server_rpk_mismatch() {
    run_endpoint_connection(
        |cli, _srv| {
            (
                cli.as_raw_public_key().unwrap().cert[0].clone(),
                replace_rpk(),
            )
        },
        |cli_res, srv_res| {
            // Each error is printed before its content is checked, so a failing
            // assertion still shows what actually happened.
            let err = cli_res.expect_err("client should reject the server's public key");
            eprintln!("Client result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected client error: {err}"
            );
            let err = srv_res.expect_err("server handshake should fail");
            eprintln!("Server result: {err}");
            assert!(
                err.to_string().contains("invalid peer certificate"),
                "unexpected server error: {err}"
            );
            Ok(())
        },
        Compatibility::Level(3),
    )
    .await
    .unwrap();
}