use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use quinn::{
ClientConfig, Connection, Endpoint, RecvStream, SendStream, ServerConfig, TransportConfig,
VarInt,
crypto::rustls::{QuicClientConfig, QuicServerConfig},
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::sync::watch;
use tracing::{debug, error, info};
use crate::stats::StreamStats;
/// Size in bytes (128 KiB) of the buffer written per `send_quic_data` iteration
/// and read into by `receive_quic_data`.
const DEFAULT_BUFFER_SIZE: usize = 128 * 1024;
/// Generates a throwaway self-signed certificate whose subject alternative name is `xfr`.
///
/// Returns the DER-encoded certificate chain (a single certificate) together with the
/// matching private key in PKCS#8 form, ready to hand to `create_server_endpoint`.
///
/// # Errors
/// Propagates any failure from `rcgen` while generating the key pair or certificate.
pub fn generate_self_signed_cert()
-> anyhow::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
    let subject_alt_names = vec![String::from("xfr")];
    let certified = rcgen::generate_simple_self_signed(subject_alt_names)?;

    let private_key = PrivateKeyDer::Pkcs8(certified.signing_key.serialize_der().into());
    let chain = vec![CertificateDer::from(certified.cert)];

    Ok((chain, private_key))
}
/// Creates a QUIC server endpoint bound to `addr`.
///
/// TLS is configured without client authentication, presents `cert`/`key`, and
/// negotiates the `xfr` ALPN protocol. The transport allows up to 129 concurrent
/// bidirectional streams per connection and sends keep-alives every 5 seconds.
// NOTE(review): 129 is presumably one control stream plus 128 data streams — confirm.
///
/// # Errors
/// Fails if the certificate/key pair is rejected by rustls, if the rustls config is
/// not usable for QUIC, or if the UDP socket cannot be bound.
pub fn create_server_endpoint(
    addr: SocketAddr,
    cert: Vec<CertificateDer<'static>>,
    key: PrivateKeyDer<'static>,
) -> anyhow::Result<Endpoint> {
    let mut transport = TransportConfig::default();
    transport
        .max_concurrent_bidi_streams(VarInt::from_u32(129))
        .keep_alive_interval(Some(Duration::from_secs(5)));

    let mut tls = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert, key)?;
    tls.alpn_protocols = vec![b"xfr".to_vec()];

    let mut server_config = ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls)?));
    server_config.transport_config(Arc::new(transport));

    let endpoint = Endpoint::server(server_config, addr)?;
    info!("QUIC server listening on {}", addr);
    Ok(endpoint)
}
/// Creates a QUIC client endpoint for connecting to `remote_addr`.
///
/// The endpoint binds to `local_bind` if given; otherwise it binds to the
/// unspecified address of the same IP family as `remote_addr` with an
/// OS-assigned port. The default client config negotiates the `xfr` ALPN
/// protocol, allows up to 129 concurrent bidirectional streams, and does NOT
/// verify the server certificate (see `SkipServerVerification`).
///
/// # Errors
/// Fails if the rustls config is not usable for QUIC or the UDP socket cannot
/// be bound.
pub fn create_client_endpoint(
    remote_addr: SocketAddr,
    local_bind: Option<SocketAddr>,
) -> anyhow::Result<Endpoint> {
    let mut crypto = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
        .with_no_client_auth();
    crypto.alpn_protocols = vec![b"xfr".to_vec()];

    let mut transport = TransportConfig::default();
    transport.max_concurrent_bidi_streams(VarInt::from_u32(129));

    let quic_crypto = QuicClientConfig::try_from(crypto)?;
    let mut client_config = ClientConfig::new(Arc::new(quic_crypto));
    client_config.transport_config(Arc::new(transport));

    // Match the remote's address family so the socket can actually reach it;
    // build the wildcard address directly instead of parsing a string literal.
    let bind_addr = local_bind.unwrap_or_else(|| {
        let unspecified: std::net::IpAddr = if remote_addr.is_ipv6() {
            std::net::Ipv6Addr::UNSPECIFIED.into()
        } else {
            std::net::Ipv4Addr::UNSPECIFIED.into()
        };
        SocketAddr::new(unspecified, 0)
    });

    let mut endpoint = Endpoint::client(bind_addr)?;
    endpoint.set_default_client_config(client_config);
    Ok(endpoint)
}
/// rustls certificate verifier that accepts every server certificate and every
/// handshake signature.
///
/// SECURITY: this disables server authentication entirely. The connection is
/// encrypted but offers no protection against an active man-in-the-middle.
/// It is installed by `create_client_endpoint` so that servers presenting
/// self-signed certificates are accepted.
#[derive(Debug)]
struct SkipServerVerification;

impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts any certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _: &[u8],
        _: &CertificateDer<'_>,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _: &[u8],
        _: &CertificateDer<'_>,
        _: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Signature schemes advertised to the server.
    ///
    /// rustls sends this list in the ClientHello `signature_algorithms`
    /// extension, so the server can only finish the handshake if its key can
    /// sign with one of these. QUIC always uses TLS 1.3, where RSA keys must use
    /// RSA-PSS. Offering only P-256/Ed25519/RSA-PKCS1 therefore broke servers
    /// with RSA, P-384 or P-521 keys. Since signatures are not checked anyway,
    /// advertise every common scheme.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::ED448,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
        ]
    }
}
/// Streams zero-filled buffers over `send` until the duration elapses or the
/// transfer is cancelled, then finishes the stream.
///
/// * `duration` — how long to send; `Duration::ZERO` means "until cancelled".
/// * `cancel` — once it reads `true`, sending stops. This also interrupts a
///   write that is blocked on flow control.
/// * `pause` — while it reads `true`, sending is suspended via
///   `crate::pause::wait_while_paused`; a `true` return from that helper ends
///   the transfer.
///
/// Each successful write adds the number of bytes accepted to `stats`.
///
/// # Errors
/// Returns the stream error if a write fails, or if finishing the stream fails.
pub async fn send_quic_data(
    mut send: SendStream,
    stats: Arc<StreamStats>,
    duration: Duration,
    mut cancel: watch::Receiver<bool>,
    mut pause: watch::Receiver<bool>,
) -> anyhow::Result<()> {
    let buffer = vec![0u8; DEFAULT_BUFFER_SIZE];
    let deadline = tokio::time::Instant::now() + duration;
    let is_infinite = duration == Duration::ZERO;

    // One timer for the whole transfer so a write stalled on flow control
    // (e.g. the peer stopped reading) cannot run past the deadline.
    let deadline_timer = tokio::time::sleep_until(deadline);
    tokio::pin!(deadline_timer);

    // Cleared once the cancel sender is dropped: `changed()` then resolves with
    // `Err` on every poll and would otherwise keep interrupting the write.
    let mut cancel_open = true;

    loop {
        if *cancel.borrow() {
            debug!("QUIC send cancelled");
            break;
        }
        if *pause.borrow() {
            if crate::pause::wait_while_paused(&mut pause, &mut cancel).await {
                break;
            }
            continue;
        }
        if !is_infinite && tokio::time::Instant::now() >= deadline {
            break;
        }

        // `SendStream::write` is cancel-safe: if another branch wins, no bytes
        // from the dropped write were accepted, so nothing is double-counted.
        tokio::select! {
            result = send.write(&buffer) => match result {
                Ok(n) => stats.add_bytes_sent(n as u64),
                Err(e) => {
                    error!("QUIC send error: {}", e);
                    return Err(e.into());
                }
            },
            changed = cancel.changed(), if cancel_open => {
                // On a real change, loop back and let the top-of-loop check decide.
                if changed.is_err() {
                    cancel_open = false;
                }
            }
            _ = &mut deadline_timer, if !is_infinite => break,
        }
    }

    send.finish()?;
    Ok(())
}
pub async fn receive_quic_data(
mut recv: RecvStream,
stats: Arc<StreamStats>,
mut cancel: watch::Receiver<bool>,
) -> anyhow::Result<()> {
let mut buffer = vec![0u8; DEFAULT_BUFFER_SIZE];
loop {
tokio::select! {
result = recv.read(&mut buffer) => {
match result? {
Some(n) => stats.add_bytes_received(n as u64),
None => break, }
}
_ = cancel.changed() => {
if *cancel.borrow() {
debug!("QUIC receive cancelled");
break;
}
}
}
}
Ok(())
}
/// Opens a QUIC connection to `addr` using the endpoint's default client config.
///
/// The TLS server name sent to the peer is always `"xfr"`.
///
/// # Errors
/// Fails if the connection attempt cannot be started (for example, the endpoint
/// has no default client config) or if the handshake does not complete.
pub async fn connect(endpoint: &Endpoint, addr: SocketAddr) -> anyhow::Result<Connection> {
    let handshake = endpoint.connect(addr, "xfr")?;
    let conn = handshake.await?;
    debug!("QUIC connected to {}", addr);
    Ok(conn)
}