use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use quinn::{Connection, Endpoint, ReadExactError, RecvStream, SendStream};
use rcgen::generate_simple_self_signed;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::runtime::Runtime;
use crate::error::{RepError, Result};
use crate::net::channel::Channel;
#[derive(Debug)]
struct SkipCertVerification(Arc<rustls::crypto::CryptoProvider>);
impl rustls::client::danger::ServerCertVerifier for SkipCertVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<
rustls::client::danger::ServerCertVerified,
rustls::Error,
> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dh: &rustls::DigitallySignedStruct,
) -> std::result::Result<
rustls::client::danger::HandshakeSignatureValid,
rustls::Error,
> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dh,
&self.0.signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dh: &rustls::DigitallySignedStruct,
) -> std::result::Result<
rustls::client::danger::HandshakeSignatureValid,
rustls::Error,
> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dh,
&self.0.signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.0.signature_verification_algorithms.supported_schemes()
}
}
/// Generates a fresh self-signed certificate whose only subject alternative
/// name is `localhost`.
///
/// Returns the certificate chain (a single DER certificate) together with the
/// matching PKCS#8-encoded private key.
///
/// # Errors
/// `RepError::NetworkError` if `rcgen` fails to generate the certificate.
fn self_signed_cert()
-> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
    let subject_alt_names = vec!["localhost".to_string()];
    let certified = generate_simple_self_signed(subject_alt_names)
        .map_err(|e| RepError::NetworkError(format!("rcgen: {e}")))?;
    let pkcs8 = PrivatePkcs8KeyDer::from(certified.key_pair.serialize_der());
    let cert_der = CertificateDer::from(certified.cert.der().to_vec());
    Ok((vec![cert_der], PrivateKeyDer::Pkcs8(pkcs8)))
}
/// Builds a QUIC server configuration backed by a freshly generated
/// self-signed certificate for `localhost`.
///
/// Client certificates are not requested, and path-MTU discovery is disabled
/// in the transport configuration.
///
/// # Errors
/// `RepError::NetworkError` if certificate generation fails, if rustls rejects
/// the certificate/key pair, or if the rustls configuration cannot be turned
/// into a QUIC crypto configuration.
pub fn default_server_config() -> Result<quinn::ServerConfig> {
    let (chain, private_key) = self_signed_cert()?;
    let rustls_cfg = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| RepError::NetworkError(format!("TLS: {e}")))?;
    let crypto = quinn::crypto::rustls::QuicServerConfig::try_from(rustls_cfg)
        .map_err(|e| RepError::NetworkError(format!("QUIC server config: {e}")))?;
    let transport = {
        let mut settings = quinn::TransportConfig::default();
        settings.mtu_discovery_config(None);
        Arc::new(settings)
    };
    let mut server = quinn::ServerConfig::with_crypto(Arc::new(crypto));
    server.transport_config(transport);
    Ok(server)
}
/// Builds a QUIC client configuration that does **not** authenticate the
/// server.
///
/// Any server certificate is accepted (see `SkipCertVerification`). Handshake
/// signatures are still checked with the signature algorithms of the `ring`
/// crypto provider. No client certificate is presented, and path-MTU
/// discovery is disabled in the transport configuration.
///
/// # Security
/// Connections made with this configuration are open to an active
/// man-in-the-middle; use it only for tests or trusted networks.
///
/// # Panics
/// Panics (`"valid insecure client config"`) if the rustls configuration
/// cannot be converted into a QUIC client configuration.
pub fn insecure_client_config() -> quinn::ClientConfig {
    let ring = Arc::new(rustls::crypto::ring::default_provider());
    let accept_any = Arc::new(SkipCertVerification(Arc::clone(&ring)));
    let rustls_cfg = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(accept_any)
        .with_no_client_auth();
    let crypto = quinn::crypto::rustls::QuicClientConfig::try_from(rustls_cfg)
        .expect("valid insecure client config");
    let mut transport = quinn::TransportConfig::default();
    transport.mtu_discovery_config(None);
    let mut client = quinn::ClientConfig::new(Arc::new(crypto));
    client.transport_config(Arc::new(transport));
    client
}
/// Four-byte preamble (`"NXUR"`) a client writes as soon as it opens its
/// bidirectional stream; the listener rejects any connection whose first
/// four stream bytes differ.
const MAGIC: &[u8; 4] = &[b'N', b'X', b'U', b'R'];
/// A [`Channel`] carried over a single bidirectional QUIC stream.
///
/// Each message is framed as a 4-byte little-endian length prefix followed by
/// the payload. The synchronous `Channel` methods drive the async quinn
/// streams by calling `block_on` on `runtime`.
///
/// NOTE: Rust drops fields in declaration order, so the quinn handles below
/// are released before the `runtime` Arc; keep that order when editing.
pub struct QuicChannel {
    // Client-side endpoint, held so it lives as long as the channel;
    // `None` for channels built with `from_streams` (server side).
    _endpoint: Option<Endpoint>,
    // Never read directly; held so the connection handle lives as long as
    // the channel.
    _connection: Connection,
    // Send half of the stream. The async mutex serialises `send` and `close`.
    send: Arc<tokio::sync::Mutex<SendStream>>,
    // Receive half of the stream. The async mutex serialises `receive` calls.
    recv: Arc<tokio::sync::Mutex<RecvStream>>,
    // Runtime used to `block_on` stream I/O. For accepted channels it is the
    // listener's runtime (shared via `Arc`).
    runtime: Arc<Runtime>,
    // `true` until `close` runs; checked by `send` and `receive`.
    open: AtomicBool,
}
impl QuicChannel {
pub fn from_streams(
connection: Connection,
send: SendStream,
recv: RecvStream,
runtime: Arc<Runtime>,
) -> Self {
Self {
_endpoint: None,
_connection: connection,
send: Arc::new(tokio::sync::Mutex::new(send)),
recv: Arc::new(tokio::sync::Mutex::new(recv)),
runtime,
open: AtomicBool::new(true),
}
}
pub fn connect(addr: SocketAddr, server_name: &str) -> Result<Self> {
Self::connect_with_config(addr, server_name, insecure_client_config())
}
pub fn connect_host(
host: &str,
port: u16,
server_name: &str,
) -> Result<Self> {
let addrs: Vec<SocketAddr> = (host, port)
.to_socket_addrs()
.map_err(|e| {
RepError::NetworkError(format!(
"DNS resolution failed for {host}:{port}: {e}"
))
})?
.collect();
if addrs.is_empty() {
return Err(RepError::NetworkError(format!(
"no addresses resolved for {host}:{port}"
)));
}
let mut sorted = addrs;
sorted.sort_by_key(|a| if a.is_ipv6() { 0u8 } else { 1u8 });
let mut last_err = None;
for addr in &sorted {
match Self::connect(*addr, server_name) {
Ok(ch) => return Ok(ch),
Err(e) => last_err = Some(e),
}
}
Err(last_err.unwrap_or_else(|| {
RepError::NetworkError(format!(
"could not connect to {host}:{port}"
))
}))
}
pub fn connect_with_config(
addr: SocketAddr,
server_name: &str,
client_cfg: quinn::ClientConfig,
) -> Result<Self> {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.map_err(|e| RepError::NetworkError(format!("tokio: {e}")))?,
);
let server_name = server_name.to_string();
let (endpoint, conn, send, recv) = runtime.block_on(async move {
let mut endpoint =
Endpoint::client("0.0.0.0:0".parse().expect("valid bind addr"))
.map_err(|e| RepError::NetworkError(e.to_string()))?;
endpoint.set_default_client_config(client_cfg);
let conn = endpoint
.connect(addr, &server_name)
.map_err(|e| RepError::NetworkError(e.to_string()))?
.await
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let (mut send, recv) = conn
.open_bi()
.await
.map_err(|e| RepError::NetworkError(e.to_string()))?;
send.write_all(MAGIC)
.await
.map_err(|e| RepError::NetworkError(e.to_string()))?;
Ok::<_, RepError>((endpoint, conn, send, recv))
})?;
Ok(Self {
_endpoint: Some(endpoint),
_connection: conn,
send: Arc::new(tokio::sync::Mutex::new(send)),
recv: Arc::new(tokio::sync::Mutex::new(recv)),
runtime,
open: AtomicBool::new(true),
})
}
}
impl Channel for QuicChannel {
    /// Sends `data` as one frame: a 4-byte little-endian `u32` length prefix
    /// followed by the payload bytes.
    ///
    /// # Errors
    /// - `RepError::ChannelClosed` if the channel has been closed locally.
    /// - `RepError::ProtocolError` if `data` is longer than `u32::MAX` bytes,
    ///   which the length prefix cannot represent.
    /// - `RepError::NetworkError` if writing to the QUIC stream fails.
    fn send(&self, data: &[u8]) -> Result<()> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "QuicChannel is closed".into(),
            ));
        }
        // A plain `as u32` cast would silently truncate the length, and the
        // receiver would then parse the tail of this payload as the next frame.
        let len = u32::try_from(data.len()).map_err(|_| {
            RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                data.len(),
                u32::MAX
            ))
        })?;
        let len_prefix = len.to_le_bytes();
        // `block_on` accepts non-`'static` futures, so the caller's slice is
        // written directly instead of being copied into a fresh `Vec`.
        self.runtime.block_on(async {
            let mut stream = self.send.lock().await;
            stream
                .write_all(&len_prefix)
                .await
                .map_err(|e| RepError::NetworkError(e.to_string()))?;
            stream
                .write_all(data)
                .await
                .map_err(|e| RepError::NetworkError(e.to_string()))
        })
    }

    /// Receives one frame, waiting at most `timeout` for it to start.
    ///
    /// Returns `Ok(None)` if no byte of a new frame arrives within `timeout`.
    /// Once the first byte has arrived, the rest of the frame is read without
    /// a time limit, so a frame is never left half-consumed.
    ///
    /// # Errors
    /// - `RepError::ChannelClosed` if the channel was closed locally, or the
    ///   peer finished the stream before or in the middle of a frame.
    /// - `RepError::ProtocolError` if the announced payload length exceeds
    ///   `MAX_FRAME_PAYLOAD`.
    /// - `RepError::NetworkError` on any other stream read error.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "QuicChannel is closed".into(),
            ));
        }
        self.runtime.block_on(async {
            let mut stream = self.recv.lock().await;
            let mut len_buf = [0u8; 4];
            // Only the wait for the start of a frame is raced against the
            // timeout. `RecvStream::read` is cancel-safe: if it is dropped
            // while pending, no data is consumed. Cancelling `read_exact`
            // after a partial read would discard those bytes and
            // desynchronise every subsequent frame.
            let first = match tokio::time::timeout(
                timeout,
                stream.read(&mut len_buf),
            )
            .await
            {
                Err(_elapsed) => return Ok(None),
                Ok(Ok(Some(n))) => n,
                Ok(Ok(None)) => {
                    return Err(RepError::ChannelClosed(
                        "QUIC stream closed by peer".into(),
                    ));
                }
                Ok(Err(e)) => {
                    return Err(RepError::NetworkError(e.to_string()));
                }
            };
            if first < len_buf.len() {
                stream.read_exact(&mut len_buf[first..]).await.map_err(
                    |e| match e {
                        ReadExactError::FinishedEarly(_) => {
                            RepError::ChannelClosed(
                                "QUIC stream closed by peer".into(),
                            )
                        }
                        ReadExactError::ReadError(re) => {
                            RepError::NetworkError(re.to_string())
                        }
                    },
                )?;
            }
            let payload_len = u32::from_le_bytes(len_buf) as usize;
            // Reject before allocating so a hostile peer cannot force a huge
            // allocation.
            if payload_len > crate::net::channel::MAX_FRAME_PAYLOAD {
                return Err(RepError::ProtocolError(format!(
                    "frame payload too large: {} > {}",
                    payload_len,
                    crate::net::channel::MAX_FRAME_PAYLOAD
                )));
            }
            let mut payload = vec![0u8; payload_len];
            stream.read_exact(&mut payload).await.map_err(|e| match e {
                ReadExactError::FinishedEarly(_) => RepError::ChannelClosed(
                    "QUIC stream closed mid-payload".into(),
                ),
                ReadExactError::ReadError(re) => {
                    RepError::NetworkError(re.to_string())
                }
            })?;
            Ok(Some(payload))
        })
    }

    /// Closes the channel locally.
    ///
    /// Marks the channel closed, finishes the send stream, and then waits (at
    /// most `FLUSH_TIMEOUT`) for the peer to acknowledge the buffered data.
    /// The wait matters because the connection may be torn down as soon as
    /// this channel is dropped, which could discard unacknowledged data.
    ///
    /// Idempotent: only the first call does any work. Shutdown is
    /// best-effort, so this always returns `Ok(())`.
    fn close(&self) -> Result<()> {
        // Upper bound on how long `close` (and therefore `Drop`) may block
        // waiting for the peer's acknowledgement.
        const FLUSH_TIMEOUT: Duration = Duration::from_secs(1);
        if !self.open.swap(false, Ordering::SeqCst) {
            return Ok(());
        }
        self.runtime.block_on(async {
            let mut stream = self.send.lock().await;
            // `finish` fails only if the stream is already closed; in that
            // case there is nothing left to flush.
            if stream.finish().is_ok() {
                // Resolves once the peer has acknowledged all stream data,
                // stopped the stream, or the connection is lost. The outcome
                // itself is irrelevant here.
                let _ = tokio::time::timeout(FLUSH_TIMEOUT, stream.stopped())
                    .await;
            }
        });
        Ok(())
    }

    /// Returns `false` once [`Channel::close`] has been called.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
impl Drop for QuicChannel {
    /// Best-effort graceful shutdown: if the caller never closed the channel,
    /// close it now. Any error is discarded because `drop` cannot report it.
    ///
    /// NOTE(review): `close` calls `Runtime::block_on`, which tokio documents
    /// as panicking when invoked from within an async execution context. A
    /// `QuicChannel` should therefore be dropped from synchronous code.
    fn drop(&mut self) {
        if !self.is_open() {
            return;
        }
        let _ = self.close();
    }
}
/// Accepts incoming QUIC connections and turns each one into a [`QuicChannel`].
///
/// The endpoint is created inside this listener's own tokio runtime, and every
/// accepted channel receives a clone of the runtime `Arc`.
///
/// NOTE: fields drop in declaration order, so the endpoint is released before
/// this listener's handle on the runtime.
pub struct QuicChannelListener {
    // quinn server endpoint bound by `bind` / `with_server_config`.
    endpoint: Endpoint,
    // Runtime used to `block_on` endpoint operations; shared with accepted
    // channels.
    runtime: Arc<Runtime>,
}
impl QuicChannelListener {
pub fn bind(addr: SocketAddr) -> Result<Self> {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.map_err(|e| RepError::NetworkError(format!("tokio: {e}")))?,
);
let server_cfg = default_server_config()?;
let endpoint = runtime.block_on(async move {
Endpoint::server(server_cfg, addr)
.map_err(|e| RepError::NetworkError(e.to_string()))
})?;
Ok(Self { endpoint, runtime })
}
pub fn with_server_config(
addr: SocketAddr,
server_cfg: quinn::ServerConfig,
) -> Result<Self> {
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.map_err(|e| RepError::NetworkError(format!("tokio: {e}")))?,
);
let endpoint = runtime.block_on(async move {
Endpoint::server(server_cfg, addr)
.map_err(|e| RepError::NetworkError(e.to_string()))
})?;
Ok(Self { endpoint, runtime })
}
pub fn bind_with_tls_and_allowlist(
addr: SocketAddr,
tls: &crate::tls::TlsConfig,
allowlist: crate::auth::PeerAllowlist,
) -> Result<Self> {
let server_cfg =
tls.to_quinn_server_config_with_allowlist(allowlist)?;
Self::with_server_config(addr, server_cfg)
}
pub fn local_addr(&self) -> Result<SocketAddr> {
self.endpoint
.local_addr()
.map_err(|e| RepError::NetworkError(e.to_string()))
}
pub fn accept(&self) -> Result<QuicChannel> {
self.runtime.block_on(async {
let incoming = self.endpoint.accept().await.ok_or_else(|| {
RepError::NetworkError("QUIC endpoint closed".into())
})?;
let conn = incoming
.await
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let (send, mut recv) = conn
.accept_bi()
.await
.map_err(|e| RepError::NetworkError(e.to_string()))?;
let mut magic = [0u8; 4];
recv.read_exact(&mut magic).await.map_err(|e| {
RepError::NetworkError(format!("handshake read: {e}"))
})?;
if &magic != MAGIC {
return Err(RepError::NetworkError(format!(
"invalid QUIC handshake magic: {magic:02x?}"
)));
}
Ok(QuicChannel::from_streams(
conn,
send,
recv,
Arc::clone(&self.runtime),
))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Binds a listener on an OS-assigned loopback port.
    fn loopback_listener() -> QuicChannelListener {
        QuicChannelListener::bind("127.0.0.1:0".parse().unwrap())
            .expect("bind QUIC listener")
    }

    /// One request/response round trip in each direction.
    #[test]
    fn test_quic_basic_send_receive() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"hello quic".to_vec()));
            ch.send(b"world").unwrap();
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        client.send(b"hello quic").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"world".to_vec()));
        server_thread.join().unwrap();
    }

    /// A zero-length payload is delivered as an empty message, not a timeout.
    #[test]
    fn test_quic_empty_message() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(vec![]));
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        client.send(b"").unwrap();
        server_thread.join().unwrap();
    }

    /// A 64 KiB payload survives framing intact.
    #[test]
    fn test_quic_large_message() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let payload: Vec<u8> = (0u32..65536).map(|i| (i % 256) as u8).collect();
        let expected = payload.clone();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
            assert_eq!(msg, expected);
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        client.send(&payload).unwrap();
        server_thread.join().unwrap();
    }

    /// Messages arrive in the order they were sent.
    #[test]
    fn test_quic_multiple_messages_fifo() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            for i in 0u8..5 {
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, vec![i]);
            }
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        for i in 0u8..5 {
            client.send(&[i]).unwrap();
        }
        server_thread.join().unwrap();
    }

    /// With nothing sent, `receive` returns `Ok(None)` once the timeout expires.
    #[test]
    fn test_quic_receive_timeout() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        // The server thread hands back its channel so the stream stays open
        // (no FIN) while the client waits.
        let server_thread =
            std::thread::spawn(move || listener.accept().unwrap());
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        let result = client.receive(Duration::from_millis(300)).unwrap();
        assert_eq!(result, None, "expected timeout → None");
        drop(server_thread.join().unwrap());
    }

    /// `close` flips `is_open` from true to false.
    #[test]
    fn test_quic_is_open_and_close() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            listener.accept().unwrap();
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        assert!(client.is_open());
        client.close().unwrap();
        assert!(!client.is_open());
        server_thread.join().unwrap();
    }

    /// Sending on a locally closed channel is an error.
    #[test]
    fn test_quic_send_after_close_fails() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            listener.accept().unwrap();
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        client.close().unwrap();
        assert!(client.send(b"should fail").is_err());
        server_thread.join().unwrap();
    }

    /// Interleaved traffic in both directions on the same stream pair.
    #[test]
    fn test_quic_bidirectional() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let from_client = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(from_client, Some(b"ping".to_vec()));
            ch.send(b"pong").unwrap();
            let from_client2 = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(from_client2, Some(b"done".to_vec()));
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        client.send(b"ping").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"pong".to_vec()));
        client.send(b"done").unwrap();
        server_thread.join().unwrap();
    }

    /// Binding to port 0 yields a concrete loopback address and port.
    #[test]
    fn test_quic_local_addr() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert_ne!(addr.port(), 0);
    }

    /// Both ends work when used through `dyn Channel`.
    #[test]
    fn test_quic_channel_implements_channel_trait() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let accepted = listener.accept().unwrap();
            // Receive through a trait object to exercise dynamic dispatch.
            let ch: &dyn Channel = &accepted;
            ch.receive(Duration::from_secs(5)).unwrap()
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        let ch: Box<dyn Channel> = Box::new(client);
        ch.send(b"trait test").unwrap();
        let msg = server_thread.join().unwrap();
        assert_eq!(msg, Some(b"trait test".to_vec()));
    }

    /// A frame announcing more than `MAX_FRAME_PAYLOAD` bytes is rejected with
    /// a `ProtocolError` before the payload is read.
    #[test]
    fn test_quic_rejects_oversize_frame() {
        let listener = loopback_listener();
        let addr = listener.local_addr().unwrap();
        let server_thread = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let oversized =
                vec![0u8; crate::net::channel::MAX_FRAME_PAYLOAD + 1];
            // The client aborts after the header, so this write may fail.
            let _ = ch.send(&oversized);
        });
        let client = QuicChannel::connect(addr, "localhost").unwrap();
        let result = client.receive(Duration::from_secs(10));
        let _ = client.close();
        let err = result.expect_err("oversize QUIC frame must be rejected");
        match err {
            RepError::ProtocolError(msg) => {
                assert!(
                    msg.contains("frame payload too large"),
                    "unexpected protocol-error message: {}",
                    msg
                );
            }
            other => panic!("expected ProtocolError, got {:?}", other),
        }
        let _ = server_thread.join();
    }
}