use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use bytes::Bytes;
use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig};
use quinn::{ClientConfig, Connection, Endpoint, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use crate::server::ZoneServer;
/// Errors returned by the QUIC transport helpers in this file.
#[derive(Debug)]
pub enum QuicError {
/// A DER private key could not be parsed, or a rustls config could not be
/// converted into its QUIC counterpart.
BadCert,
/// rustls rejected the TLS configuration (produced via `From<rustls::Error>`).
Tls(rustls::Error),
/// An address string failed to parse, or the local endpoint could not be created.
Bind(std::io::Error),
/// `Endpoint::connect` refused to start the connection attempt.
Connect(quinn::ConnectError),
/// The connection attempt was started but awaiting it returned an error.
Connection(quinn::ConnectionError),
}
impl std::fmt::Display for QuicError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::BadCert => write!(f, "invalid certificate or key DER"),
Self::Tls(e) => write!(f, "TLS error: {e}"),
Self::Bind(e) => write!(f, "bind error: {e}"),
Self::Connect(e) => write!(f, "connect error: {e}"),
Self::Connection(e) => write!(f, "connection error: {e}"),
}
}
}
impl std::error::Error for QuicError {}
impl From<rustls::Error> for QuicError {
fn from(e: rustls::Error) -> Self {
Self::Tls(e)
}
}
/// QUIC front end that wraps a shared `ZoneServer` and keeps a registry of
/// the connections it has accepted.
pub struct QuicZoneServer {
/// Wrapped zone server; exposed through `inner()`.
inner: Arc<ZoneServer>,
/// Server certificate bytes, converted to `CertificateDer` when the server config is built.
cert_der: Vec<u8>,
/// Private key bytes, converted to `PrivateKeyDer` when the server config is built.
key_der: Vec<u8>,
/// Accepted connections keyed by `Connection::stable_id()`; shared with the
/// per-connection tasks spawned by `listen`.
connections: Arc<RwLock<HashMap<usize, Connection>>>,
}
impl QuicZoneServer {
pub fn new(server: Arc<ZoneServer>, cert_der: Vec<u8>, key_der: Vec<u8>) -> Self {
Self {
inner: server,
cert_der,
key_der,
connections: Arc::new(RwLock::new(HashMap::new())),
}
}
fn build_server_config(&self) -> Result<ServerConfig, QuicError> {
let cert = CertificateDer::from(self.cert_der.clone());
let key = PrivateKeyDer::try_from(self.key_der.clone())
.map_err(|_| QuicError::BadCert)?;
let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?;
let quic_config = QuicServerConfig::try_from(tls_config)
.map_err(|_| QuicError::BadCert)?;
Ok(ServerConfig::with_crypto(Arc::new(quic_config)))
}
pub async fn listen(&self, addr: &str) -> Result<(), QuicError> {
let server_config = self.build_server_config()?;
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| QuicError::Bind(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
let endpoint = Endpoint::server(server_config, socket_addr).map_err(QuicError::Bind)?;
let connections = self.connections.clone();
while let Some(incoming) = endpoint.accept().await {
let conns = connections.clone();
tokio::spawn(async move {
if let Ok(conn) = incoming.await {
let id = conn.stable_id();
conns.write().unwrap().insert(id, conn.clone());
loop {
match conn.accept_bi().await {
Ok((_send, mut recv)) => {
let mut buf = vec![0u8; 4096];
let _ = recv.read(&mut buf).await;
}
Err(_) => {
conns.write().unwrap().remove(&id);
break;
}
}
}
}
});
}
Ok(())
}
pub fn broadcast_datagram(&self, payload: &[u8]) {
let data: Bytes = payload.to_vec().into();
let conns = self.connections.read().unwrap();
for conn in conns.values() {
let _ = conn.send_datagram(data.clone());
}
}
pub fn connection_count(&self) -> usize {
self.connections.read().unwrap().len()
}
pub fn inner(&self) -> &Arc<ZoneServer> {
&self.inner
}
}
/// Opens a QUIC connection to `addr`, presenting `server_name` during the TLS handshake.
///
/// The local endpoint is bound to an ephemeral port on the wildcard address
/// of the same IP family as `addr`, so both IPv4 and IPv6 targets can be
/// reached.
///
/// NOTE(review): the client configuration comes from `build_client_config`,
/// which installs `SkipServerVerification`. The server certificate is
/// therefore NOT validated, even though a `server_name` is supplied. Replacing
/// that verifier needs a trust-root source that this file does not import.
///
/// # Errors
///
/// * `QuicError::Bind` if `addr` is not a valid socket address or the local
///   endpoint cannot be created.
/// * `QuicError::Connect` if the connection attempt cannot be started.
/// * `QuicError::Connection` if the connection attempt fails while awaited.
pub async fn quic_connect(addr: &str, server_name: &str) -> Result<Connection, QuicError> {
    let socket_addr: SocketAddr = addr
        .parse()
        .map_err(|e| QuicError::Bind(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
    let client_config = build_client_config();
    // An endpoint bound to an IPv4 address cannot connect to IPv6 peers, so
    // choose the wildcard address that matches the target's family.
    let local_addr = match socket_addr {
        SocketAddr::V4(_) => SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
        SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
    };
    let mut endpoint = Endpoint::client(local_addr).map_err(QuicError::Bind)?;
    endpoint.set_default_client_config(client_config);
    let conn = endpoint
        .connect(socket_addr, server_name)
        .map_err(QuicError::Connect)?
        .await
        .map_err(QuicError::Connection)?;
    Ok(conn)
}
/// Opens a QUIC connection to `addr` without verifying the server certificate.
///
/// The handshake uses `SkipServerVerification` and the fixed server name
/// `"localhost"`. The local endpoint is bound to an ephemeral port on the
/// wildcard address of the same IP family as `addr`, so both IPv4 and IPv6
/// targets can be reached.
///
/// # Errors
///
/// * `QuicError::Bind` if `addr` is not a valid socket address or the local
///   endpoint cannot be created.
/// * `QuicError::BadCert` if the rustls configuration cannot be converted
///   into a QUIC client configuration.
/// * `QuicError::Connect` if the connection attempt cannot be started.
/// * `QuicError::Connection` if the connection attempt fails while awaited.
pub async fn quic_connect_insecure(addr: &str) -> Result<Connection, QuicError> {
    let socket_addr: SocketAddr = addr
        .parse()
        .map_err(|e| QuicError::Bind(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))?;
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
        .with_no_client_auth();
    let quic_config = QuicClientConfig::try_from(tls_config).map_err(|_| QuicError::BadCert)?;
    let client_config = ClientConfig::new(Arc::new(quic_config));
    // An endpoint bound to an IPv4 address cannot connect to IPv6 peers, so
    // choose the wildcard address that matches the target's family.
    let local_addr = match socket_addr {
        SocketAddr::V4(_) => SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
        SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
    };
    let mut endpoint = Endpoint::client(local_addr).map_err(QuicError::Bind)?;
    endpoint.set_default_client_config(client_config);
    let conn = endpoint
        .connect(socket_addr, "localhost")
        .map_err(QuicError::Connect)?
        .await
        .map_err(QuicError::Connection)?;
    Ok(conn)
}
/// Builds a quinn client configuration that performs no server certificate
/// verification and presents no client certificate.
///
/// # Panics
///
/// Panics with "valid TLS client config" if the rustls configuration cannot
/// be converted into a QUIC client configuration.
fn build_client_config() -> ClientConfig {
    let verifier = Arc::new(SkipServerVerification);
    let tls = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    let crypto = QuicClientConfig::try_from(tls).expect("valid TLS client config");
    ClientConfig::new(Arc::new(crypto))
}
/// Certificate verifier that accepts every server certificate and every
/// handshake signature without inspecting them.
///
/// Installing it disables TLS server authentication for the connection.
#[derive(Debug)]
struct SkipServerVerification;
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
]
}
}
/// Test helper: generates a self-signed certificate for `subject` with rcgen.
///
/// Returns `(certificate DER, private key DER)`; generation errors are
/// propagated to the caller.
#[cfg(test)]
fn generate_self_signed_cert(
    subject: &str,
) -> Result<(Vec<u8>, Vec<u8>), Box<dyn std::error::Error + Send + Sync>> {
    let names = vec![subject.to_string()];
    let generated = rcgen::generate_simple_self_signed(names)?;
    Ok((
        generated.cert.der().to_vec(),
        generated.signing_key.serialize_der(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coord::EnuConverter;
    use crate::octree::OctreeNode;
    use crate::store::ZoneStore;
    use std::time::Duration;

    /// Builds the `ZoneServer` fixture shared by the tests: a store created
    /// from no entries, a fresh octree, and a converter anchored at the origin.
    fn empty_zone_server() -> Arc<ZoneServer> {
        let conv = Arc::new(EnuConverter::new(0.0, 0.0, 0.0));
        let store = Arc::new(RwLock::new(ZoneStore::from_entries(&[], &*conv)));
        let octree = Arc::new(RwLock::new(OctreeNode::new([0.0; 3], 50.0)));
        Arc::new(ZoneServer::new(store, octree, conv))
    }

    /// Broadcasting with an empty registry must neither panic nor change the count.
    #[test]
    fn broadcast_datagram_with_no_clients_is_ok() {
        let (cert, key) = {
            let srv_cert = generate_self_signed_cert("localhost");
            srv_cert.unwrap()
        };
        let qs = QuicZoneServer::new(empty_zone_server(), cert, key);
        assert_eq!(qs.connection_count(), 0);
        qs.broadcast_datagram(&[1, 2, 3]);
        assert_eq!(qs.connection_count(), 0);
    }

    /// The rcgen helper yields non-empty DER for both certificate and key.
    #[test]
    fn cert_generation_works() {
        let (cert, key) = generate_self_signed_cert("localhost").unwrap();
        assert!(!cert.is_empty(), "cert DER should be non-empty");
        assert!(!key.is_empty(), "key DER should be non-empty");
    }

    /// A freshly generated certificate/key pair produces a server config.
    #[test]
    fn server_config_builds_from_valid_cert() {
        let (cert, key) = generate_self_signed_cert("localhost").unwrap();
        let qs = QuicZoneServer::new(empty_zone_server(), cert, key);
        qs.build_server_config()
            .expect("valid cert should build server config");
    }

    /// Garbage bytes in place of DER must be rejected.
    #[test]
    fn bad_cert_produces_error() {
        let garbage = vec![0xFF; 10];
        let qs = QuicZoneServer::new(empty_zone_server(), garbage.clone(), garbage);
        assert!(qs.build_server_config().is_err(), "garbage DER should fail");
    }

    /// Starts `listen` on an ephemeral port, lets it run briefly, then aborts it.
    #[tokio::test]
    async fn quic_server_accepts_connection() {
        let (cert, key) = generate_self_signed_cert("localhost").unwrap();
        let qs = Arc::new(QuicZoneServer::new(empty_zone_server(), cert, key));
        let listener = Arc::clone(&qs);
        let server_task = tokio::spawn(async move {
            let _ = listener.listen("127.0.0.1:0").await;
        });
        tokio::time::sleep(Duration::from_millis(100)).await;
        server_task.abort();
    }

    /// Runs a hand-built server endpoint, connects with the insecure client,
    /// and checks that a datagram sent by the server reaches the client.
    #[tokio::test]
    async fn quic_connect_and_datagram_roundtrip() {
        let srv = empty_zone_server();
        let (cert_der, key_der) = generate_self_signed_cert("localhost").unwrap();

        // Server side: same configuration steps as `build_server_config`,
        // spelled out here so the test can read the bound address.
        let chain = vec![CertificateDer::from(cert_der.clone())];
        let private_key = PrivateKeyDer::try_from(key_der.clone()).unwrap();
        let tls = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, private_key)
            .unwrap();
        let crypto = QuicServerConfig::try_from(tls).unwrap();
        let server_config = ServerConfig::with_crypto(Arc::new(crypto));
        let endpoint = Endpoint::server(server_config, "127.0.0.1:0".parse().unwrap()).unwrap();
        let bound_addr = endpoint.local_addr().unwrap();

        let conns: Arc<RwLock<HashMap<usize, Connection>>> = Arc::default();
        let registry = Arc::clone(&conns);
        tokio::spawn(async move {
            while let Some(incoming) = endpoint.accept().await {
                let Ok(conn) = incoming.await else {
                    continue;
                };
                let id = conn.stable_id();
                registry.write().unwrap().insert(id, conn.clone());
                let per_conn = Arc::clone(&registry);
                tokio::spawn(async move {
                    // Keep accepting streams until the connection errors out,
                    // then drop it from the registry.
                    while conn.accept_bi().await.is_ok() {}
                    per_conn.write().unwrap().remove(&id);
                });
            }
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        let conn = quic_connect_insecure(&bound_addr.to_string())
            .await
            .expect("should connect");
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(conns.read().unwrap().len(), 1, "server should have 1 client");

        let data = Bytes::from(b"hello quic".to_vec());
        conns.read().unwrap().values().for_each(|c| {
            let _ = c.send_datagram(data.clone());
        });
        let received = conn
            .read_datagram()
            .await
            .expect("should receive datagram");
        assert_eq!(received.as_ref(), b"hello quic");

        // A wrapper that never listened has an empty registry of its own.
        let qs = QuicZoneServer::new(srv.clone(), cert_der, key_der);
        assert_eq!(qs.connection_count(), 0);
        qs.broadcast_datagram(b"test");
        assert_eq!(qs.connection_count(), 0);
    }
}