use clap::Parser;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::path::PathBuf;
use tracing::{error, info};
use tracing_subscriber::EnvFilter;
// Command-line options for the interop runner. Plain `//` comments are used on
// purpose: clap's derive turns `///` doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Path to the YAML test matrix that `main` reads and parses.
#[arg(short, long, default_value = "tests/interop/interop-matrix.yaml")]
matrix: PathBuf,
// Directory created at startup; `report.html` / `report.json` are written here.
#[arg(short, long, default_value = "interop-results")]
output: PathBuf,
// Only test the implementation whose key in `implementations` equals this name.
#[arg(short, long)]
implementation: Option<String>,
// Only test implementations that list this category under
// `expected_outcomes.ant_quic_client.<impl>`; it must also be a key of
// `test_categories`.
#[arg(short, long)]
category: Option<String>,
// Write `report.html` into the output directory.
#[arg(long)]
html: bool,
// Write `report.json` into the output directory.
#[arg(long)]
json: bool,
// Timeout in seconds applied to each endpoint's connection attempt.
// NOTE(review): a value of 0 will most likely make every handshake time out.
#[arg(short, long, default_value = "30")]
timeout: u64,
}
/// Entry point: loads the interop matrix, connects to every selected endpoint,
/// and optionally writes HTML/JSON reports into the output directory.
///
/// # Errors
///
/// Returns an error if logging directives fail to parse, the output directory
/// cannot be created, the matrix file is missing or is not valid YAML, the
/// matrix lacks an `implementations` mapping, the `--category` or
/// `--implementation` filter names something the matrix does not define, or a
/// report cannot be written. Individual connection failures are logged and
/// counted, not returned.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::from_default_env()
                .add_directive("ant_quic=debug".parse()?)
                .add_directive("interop_test=info".parse()?),
        )
        .with_target(false)
        .with_thread_ids(true)
        .with_file(true)
        .with_line_number(true)
        .init();

    let args = Args::parse();

    info!("QUIC Interoperability Test Runner");
    info!("=================================");
    info!("Matrix file: {:?}", args.matrix);
    info!("Output directory: {:?}", args.output);

    std::fs::create_dir_all(&args.output)?;

    if !args.matrix.exists() {
        error!("Matrix file not found: {:?}", args.matrix);
        error!("Please ensure the interop-matrix.yaml file exists at the specified path");
        return Err("Matrix file not found".into());
    }

    let matrix_content = std::fs::read_to_string(&args.matrix)?;
    info!("Loaded test matrix: {} bytes", matrix_content.len());
    let matrix: serde_yaml::Value = serde_yaml::from_str(&matrix_content)?;

    let implementations = matrix["implementations"]
        .as_mapping()
        .ok_or("Invalid matrix format: missing implementations")?;
    let category = validate_category_filter(&matrix, args.category.as_deref())?;
    // Reject an unknown implementation name up front, the same way an unknown
    // category is rejected, instead of silently running zero tests.
    let implementation_filter =
        validate_implementation_filter(implementations, args.implementation.as_deref())?;

    info!("Found {} implementations to test", implementations.len());
    if let Some(category) = category {
        info!("Filtering tests by category: {}", category);
    }

    let mut succeeded: usize = 0;
    let mut failed: usize = 0;
    for (impl_name, impl_data) in implementations {
        let name = impl_name.as_str().unwrap_or("unknown");
        if !should_test_implementation(&matrix, name, implementation_filter, category) {
            continue;
        }
        info!("Testing implementation: {}", name);
        if let Some(endpoints) = impl_data["endpoints"].as_sequence() {
            // Non-string endpoint entries are skipped, as before.
            for endpoint_str in endpoints.iter().filter_map(serde_yaml::Value::as_str) {
                info!(" Endpoint: {}", endpoint_str);
                match test_endpoint(endpoint_str, args.timeout).await {
                    Ok(duration) => {
                        succeeded += 1;
                        info!(" ✓ Connected successfully in {:?}", duration);
                    }
                    Err(e) => {
                        failed += 1;
                        error!(" ✗ Failed to connect: {}", e);
                    }
                }
            }
        }
    }
    info!("Connection results: {} succeeded, {} failed", succeeded, failed);

    if args.html || args.json {
        info!("Generating reports...");
        if args.html {
            let html_path = args.output.join("report.html");
            std::fs::write(&html_path, generate_html_report())?;
            info!("HTML report written to: {:?}", html_path);
        }
        if args.json {
            let json_path = args.output.join("report.json");
            let json_report = serde_json::json!({
                "version": "1.0",
                "test_date": chrono::Utc::now().to_rfc3339(),
                "summary": "Interoperability test results",
                "connections_succeeded": succeeded,
                "connections_failed": failed
            });
            std::fs::write(&json_path, serde_json::to_string_pretty(&json_report)?)?;
            info!("JSON report written to: {:?}", json_path);
        }
    }

    info!("Interoperability tests completed");
    Ok(())
}

/// Checks that an `--implementation` filter names a key of the matrix's
/// `implementations` mapping, mirroring how `--category` is validated.
///
/// Returns `Ok(None)` when no filter was given, `Ok(Some(name))` when the name
/// is known, and an `Unknown implementation: <name>` error otherwise.
fn validate_implementation_filter<'a>(
    implementations: &serde_yaml::Mapping,
    implementation: Option<&'a str>,
) -> Result<Option<&'a str>, Box<dyn std::error::Error>> {
    let Some(implementation) = implementation else {
        return Ok(None);
    };
    if implementations
        .keys()
        .any(|name| name.as_str() == Some(implementation))
    {
        Ok(Some(implementation))
    } else {
        Err(format!("Unknown implementation: {implementation}").into())
    }
}
async fn test_endpoint(
endpoint_str: &str,
timeout_secs: u64,
) -> Result<std::time::Duration, Box<dyn std::error::Error>> {
use ant_quic::high_level::Endpoint;
use std::sync::Arc;
use std::time::Instant;
let resolved = resolve_endpoint(endpoint_str)?;
let start = Instant::now();
let socket = std::net::UdpSocket::bind(bind_addr_for_remote(resolved.addr))?;
let runtime = ant_quic::high_level::default_runtime()
.ok_or_else(|| std::io::Error::other("No compatible async runtime found"))?;
let endpoint = Endpoint::new(ant_quic::EndpointConfig::default(), None, socket, runtime)?;
#[cfg(feature = "platform-verifier")]
let client_config = ant_quic::ClientConfig::try_with_platform_verifier().unwrap_or_else(|_| {
let roots = rustls::RootCertStore::empty();
let crypto = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
#[allow(clippy::unwrap_used)]
ant_quic::ClientConfig::new(Arc::new(
ant_quic::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
))
});
#[cfg(not(feature = "platform-verifier"))]
let client_config = {
let roots = rustls::RootCertStore::empty();
let crypto = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
ant_quic::ClientConfig::new(Arc::new(
ant_quic::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
))
};
let connect_future = endpoint.connect_with(client_config, resolved.addr, &resolved.server_name);
let connection = tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), async {
match connect_future {
Ok(connecting) => connecting.await.map_err(|e| e.into()),
Err(e) => Err(Box::new(e) as Box<dyn std::error::Error>),
}
})
.await??;
let duration = start.elapsed();
connection.close(0u32.into(), b"test complete");
Ok(duration)
}
fn validate_category_filter<'a>(
matrix: &serde_yaml::Value,
category: Option<&'a str>,
) -> Result<Option<&'a str>, Box<dyn std::error::Error>> {
let Some(category) = category else {
return Ok(None);
};
let categories = matrix["test_categories"]
.as_mapping()
.ok_or("Invalid matrix format: missing test_categories")?;
if categories
.keys()
.any(|category_name| category_name.as_str() == Some(category))
{
Ok(Some(category))
} else {
Err(format!("Unknown test category: {category}").into())
}
}
/// Decides whether `implementation_name` passes the CLI filters.
///
/// An implementation filter, when present, must equal the name exactly. A
/// category filter, when present, must appear among the implementation's
/// expected outcomes. An absent filter always passes, and the category lookup
/// only runs after the name check succeeds.
fn should_test_implementation(
    matrix: &serde_yaml::Value,
    implementation_name: &str,
    implementation_filter: Option<&str>,
    category_filter: Option<&str>,
) -> bool {
    let name_passes = implementation_filter.is_none_or(|target| target == implementation_name);
    name_passes
        && category_filter.is_none_or(|category| {
            implementation_has_expected_outcome(matrix, implementation_name, category)
        })
}
/// Reports whether `expected_outcomes.ant_quic_client.<implementation_name>`
/// is a mapping with a key equal to `category`.
///
/// Any missing level or non-mapping value yields `false`.
fn implementation_has_expected_outcome(
    matrix: &serde_yaml::Value,
    implementation_name: &str,
    category: &str,
) -> bool {
    let per_implementation = &matrix["expected_outcomes"]["ant_quic_client"][implementation_name];
    match per_implementation.as_mapping() {
        Some(by_category) => by_category.keys().any(|key| key.as_str() == Some(category)),
        None => false,
    }
}
/// An endpoint whose host has been resolved to a concrete socket address.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ResolvedEndpoint {
    /// Socket address to connect to (the first one resolution produced).
    addr: SocketAddr,
    /// Host text from the endpoint string, passed to `connect_with` as the server name.
    server_name: String,
}
/// Host and port split out of an endpoint string, before any DNS resolution.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ParsedEndpoint {
    /// Host text, with the brackets removed for `[addr]:port` endpoints.
    host: String,
    /// Port number parsed from the text after the final separator.
    port: u16,
}
fn resolve_endpoint(endpoint_str: &str) -> Result<ResolvedEndpoint, Box<dyn std::error::Error>> {
let parsed = parse_endpoint(endpoint_str)?;
let addr = (parsed.host.as_str(), parsed.port)
.to_socket_addrs()?
.next()
.ok_or_else(|| std::io::Error::other(format!("No address resolved for {endpoint_str}")))?;
Ok(ResolvedEndpoint {
addr,
server_name: parsed.host,
})
}
/// Splits a trimmed `host:port` / `[ipv6]:port` string into host text and a numeric port.
///
/// No DNS lookup is performed.
///
/// # Errors
///
/// Propagates syntax errors from `split_endpoint_host_port`. Also fails when the
/// port text is not a valid `u16`.
fn parse_endpoint(endpoint_str: &str) -> Result<ParsedEndpoint, Box<dyn std::error::Error>> {
    let (host, port_text) = split_endpoint_host_port(endpoint_str.trim())?;
    let port = port_text.parse::<u16>()?;
    Ok(ParsedEndpoint {
        host: host.to_owned(),
        port,
    })
}
fn split_endpoint_host_port(endpoint: &str) -> Result<(&str, &str), std::io::Error> {
if endpoint.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"endpoint is empty",
));
}
if let Some(rest) = endpoint.strip_prefix('[') {
let Some((host, port)) = rest.split_once("]:") else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"bracketed IPv6 endpoint must use [addr]:port syntax",
));
};
return validate_endpoint_parts(host, port);
}
let Some((host, port)) = endpoint.rsplit_once(':') else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"endpoint must include a port",
));
};
if host.contains(':') {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"IPv6 endpoint must use [addr]:port syntax",
));
}
validate_endpoint_parts(host, port)
}
/// Ensures neither the host nor the port component of an endpoint is empty.
///
/// Returns the two parts unchanged on success.
///
/// # Errors
///
/// Returns an `InvalidInput` error naming the first empty component. The host
/// is checked before the port.
fn validate_endpoint_parts<'a>(
    host: &'a str,
    port: &'a str,
) -> Result<(&'a str, &'a str), std::io::Error> {
    let empty_component = if host.is_empty() {
        Some("endpoint host is empty")
    } else if port.is_empty() {
        Some("endpoint port is empty")
    } else {
        None
    };
    match empty_component {
        Some(message) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            message,
        )),
        None => Ok((host, port)),
    }
}
/// Picks a local wildcard bind address (ephemeral port 0) in the same IP
/// family as `remote_addr`.
///
/// IPv4 targets bind `0.0.0.0:0`; IPv6 targets bind `[::]:0`.
fn bind_addr_for_remote(remote_addr: SocketAddr) -> SocketAddr {
    match remote_addr {
        SocketAddr::V4(_) => SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)),
        SocketAddr::V6(_) => SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)),
    }
}
/// Renders the static placeholder HTML report, stamped with the current UTC
/// time (via its `Display` form).
fn generate_html_report() -> String {
    let generated_at = chrono::Utc::now();
    format!(
        r#"<!DOCTYPE html>
<html>
<head>
<title>QUIC Interoperability Test Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
h1 {{ color: #333; }}
.summary {{ background: #f0f0f0; padding: 10px; margin: 20px 0; }}
</style>
</head>
<body>
<h1>QUIC Interoperability Test Report</h1>
<div class="summary">
<p>Generated: {generated_at}</p>
<p>This is a placeholder report. Full implementation coming soon.</p>
</div>
</body>
</html>"#
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv6Addr};

    #[test]
    fn parses_hostname_endpoint_without_resolution() -> Result<(), Box<dyn std::error::Error>> {
        let parsed = parse_endpoint("www.google.com:443")?;
        assert_eq!(
            parsed,
            ParsedEndpoint {
                host: "www.google.com".to_string(),
                port: 443,
            }
        );
        Ok(())
    }

    #[test]
    fn parses_bracketed_ipv6_endpoint_without_truncating_server_name()
    -> Result<(), Box<dyn std::error::Error>> {
        let resolved = resolve_endpoint("[2001:db8::1]:4433")?;
        assert_eq!(resolved.server_name, "2001:db8::1");
        assert_eq!(
            resolved.addr,
            SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                4433
            )
        );
        Ok(())
    }

    /// Every malformed shape the parser is meant to reject must produce an error.
    #[test]
    fn rejects_malformed_endpoints() {
        for endpoint in [
            "",
            "   ",
            "example.com",
            "example.com:",
            ":443",
            "example.com:http",
            "example.com:99999",
            "2001:db8::1:443",
            "[2001:db8::1]",
            "[2001:db8::1]:",
        ] {
            assert!(
                parse_endpoint(endpoint).is_err(),
                "expected {endpoint:?} to be rejected"
            );
        }
    }

    #[test]
    fn unbracketed_ipv6_endpoint_reports_bracket_syntax() -> Result<(), Box<dyn std::error::Error>>
    {
        let Err(error) = split_endpoint_host_port("2001:db8::1:443") else {
            return Err("expected unbracketed IPv6 endpoint to be rejected".into());
        };
        assert_eq!(error.kind(), std::io::ErrorKind::InvalidInput);
        assert_eq!(error.to_string(), "IPv6 endpoint must use [addr]:port syntax");
        Ok(())
    }

    #[test]
    fn uses_ipv6_bind_address_for_ipv6_targets() {
        let remote = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 4433);
        assert_eq!(
            bind_addr_for_remote(remote),
            SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
        );
    }

    #[test]
    fn uses_ipv4_bind_address_for_ipv4_targets() {
        let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4433);
        assert_eq!(
            bind_addr_for_remote(remote),
            SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
        );
    }

    #[test]
    fn category_filter_selects_matching_implementations() -> Result<(), Box<dyn std::error::Error>>
    {
        let matrix: serde_yaml::Value =
            serde_yaml::from_str(include_str!("../../tests/interop/interop-matrix.yaml"))?;
        let category = validate_category_filter(&matrix, Some("nat_traversal"))?;
        assert!(should_test_implementation(
            &matrix, "picoquic", None, category
        ));
        assert!(!should_test_implementation(
            &matrix, "google", None, category
        ));
        let Some(error) = validate_category_filter(&matrix, Some("missing")).err() else {
            return Err("expected unknown category error".into());
        };
        assert_eq!(error.to_string(), "Unknown test category: missing");
        Ok(())
    }

    /// Uses a small inline matrix so the filter logic is pinned independently of
    /// the checked-in interop matrix contents.
    #[test]
    fn implementation_and_category_filters_combine() -> Result<(), Box<dyn std::error::Error>> {
        let matrix: serde_yaml::Value = serde_yaml::from_str(
            "expected_outcomes:\n  ant_quic_client:\n    alpha:\n      handshake: pass\n",
        )?;
        assert!(should_test_implementation(&matrix, "alpha", Some("alpha"), None));
        assert!(!should_test_implementation(&matrix, "beta", Some("alpha"), None));
        assert!(should_test_implementation(&matrix, "alpha", None, Some("handshake")));
        assert!(!should_test_implementation(
            &matrix,
            "alpha",
            None,
            Some("nat_traversal")
        ));
        assert!(!should_test_implementation(&matrix, "beta", None, Some("handshake")));
        assert!(!should_test_implementation(
            &matrix,
            "alpha",
            Some("alpha"),
            Some("nat_traversal")
        ));

        // With no category requested the matrix is not inspected; with one, a
        // missing `test_categories` section is reported explicitly.
        assert_eq!(validate_category_filter(&matrix, None)?, None);
        let Some(error) = validate_category_filter(&matrix, Some("handshake")).err() else {
            return Err("expected missing test_categories error".into());
        };
        assert_eq!(
            error.to_string(),
            "Invalid matrix format: missing test_categories"
        );
        Ok(())
    }
}