use super::config::{ClientAuth, QuicConfig};
use super::connection::QuicConnection;
use super::error::QuicError;
use crate::cx::Cx;
use std::net::SocketAddr;
use std::sync::Arc;
/// QUIC endpoint handle wrapping a [`quinn::Endpoint`].
///
/// Instances are created with [`QuicEndpoint::client`] or [`QuicEndpoint::server`];
/// the remaining methods delegate to the wrapped quinn endpoint.
#[derive(Debug, Clone)]
pub struct QuicEndpoint {
/// Underlying quinn endpoint that all operations are forwarded to.
inner: quinn::Endpoint,
}
impl QuicEndpoint {
/// Creates a client endpoint bound to an ephemeral port on `0.0.0.0`.
///
/// The TLS configuration is derived from `config`:
///
/// * When `config.insecure_skip_verify` is set, server certificates are accepted without
///   any verification and no root store is built at all.
/// * Otherwise `config.root_certs` is used when present; when absent, a store is
///   assembled from the native roots (only with the `tls-native-roots` feature) plus the
///   webpki roots.
///
/// ALPN protocols are copied from `config` when the list is non-empty, and the transport
/// parameters come from `config.to_transport_config()`.
///
/// NOTE(review): the socket is bound to the IPv4 wildcard address, so dialing IPv6 peers
/// presumably fails on hosts without dual-stack sockets — confirm against callers.
///
/// # Errors
///
/// * Whatever `cx.checkpoint()` reports.
/// * [`QuicError::TlsConfig`] if loading native roots fails or quinn rejects the rustls
///   configuration.
/// * [`QuicError::Config`] if verification is enabled but the root store is empty.
/// * The converted I/O error if the UDP socket cannot be bound.
// `if let … else` is kept instead of `map_or_else` because the fallback branch needs `?`
// and a `#[cfg]`-gated block, which do not read well inside a closure.
#[allow(clippy::option_if_let_else)]
pub fn client(cx: &Cx, config: &QuicConfig) -> Result<Self, QuicError> {
    cx.checkpoint()?;

    let mut crypto = if config.insecure_skip_verify {
        // Verification is disabled, so trust roots are irrelevant. Skipping their
        // construction avoids failing on a broken native trust store that would never
        // have been consulted.
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
            .with_no_client_auth()
    } else {
        let root_certs = if let Some(store) = config.root_certs.clone() {
            store
        } else {
            let mut store = crate::tls::RootCertStore::empty();
            #[cfg(feature = "tls-native-roots")]
            {
                store
                    .extend_from_native_roots()
                    .map_err(|e| QuicError::TlsConfig(e.to_string()))?;
            }
            store.extend_from_webpki_roots();
            store
        };

        // Refuse to build a verifying client that could never trust any server.
        if root_certs.is_empty() {
            return Err(QuicError::Config(
                "no root certificates configured; enable tls-native-roots/tls-webpki-roots or provide RootCertStore".into(),
            ));
        }

        rustls::ClientConfig::builder()
            .with_root_certificates(root_certs.into_inner())
            .with_no_client_auth()
    };

    // An empty list leaves the rustls default (no ALPN) untouched.
    if !config.alpn_protocols.is_empty() {
        crypto.alpn_protocols.clone_from(&config.alpn_protocols);
    }

    let transport = config.to_transport_config();
    let mut client_config = quinn::ClientConfig::new(Arc::new(
        quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
            .map_err(|e| QuicError::TlsConfig(e.to_string()))?,
    ));
    client_config.transport_config(Arc::new(transport));

    // Port 0 lets the OS pick an ephemeral local port.
    let bind_addr = SocketAddr::from(([0, 0, 0, 0], 0));
    let mut endpoint = quinn::Endpoint::client(bind_addr)?;
    endpoint.set_default_client_config(client_config);
    Ok(Self { inner: endpoint })
}
/// Creates a server endpoint listening on `addr`.
///
/// `config` must carry both a certificate chain and a private key and must pass
/// `QuicConfig::is_valid_for_server`. Client authentication follows `config.client_auth`:
/// `None` disables it, `Optional` admits unauthenticated clients, and `Required` demands
/// a certificate chaining to `config.client_auth_roots`.
///
/// # Errors
///
/// * Whatever `cx.checkpoint()` reports.
/// * [`QuicError::Config`] if the identity material is missing, the config is not valid
///   for a server, or client auth is enabled without (non-empty) roots.
/// * [`QuicError::TlsConfig`] if the private key, client verifier, certificate pairing or
///   QUIC crypto configuration is rejected.
/// * The converted I/O error if the socket cannot be bound.
pub fn server(cx: &Cx, addr: SocketAddr, config: &QuicConfig) -> Result<Self, QuicError> {
    /// Wraps any displayable TLS-layer failure into `QuicError::TlsConfig`.
    fn tls_err<E: ToString>(e: E) -> QuicError {
        QuicError::TlsConfig(e.to_string())
    }

    cx.checkpoint()?;

    // Both identity halves must exist before the config's own validation is consulted.
    let (chain_src, key_src) = config
        .cert_chain
        .as_ref()
        .zip(config.private_key.as_ref())
        .filter(|_| config.is_valid_for_server())
        .ok_or_else(|| {
            QuicError::Config(
                "server requires cert_chain and private_key; client_auth_roots must also be provided if client_auth is enabled".into(),
            )
        })?;

    let cert_chain: Vec<_> = chain_src
        .iter()
        .map(|der| rustls::pki_types::CertificateDer::from(der.clone()))
        .collect();
    let private_key = match rustls::pki_types::PrivateKeyDer::try_from(key_src.clone()) {
        Ok(key) => key,
        Err(e) => return Err(QuicError::TlsConfig(format!("invalid private key: {e}"))),
    };

    let tls_builder = rustls::ServerConfig::builder();
    let tls_builder = match config.client_auth {
        ClientAuth::None => tls_builder.with_no_client_auth(),
        ClientAuth::Optional | ClientAuth::Required => {
            let roots = match config.client_auth_roots.clone() {
                Some(roots) => roots,
                None => {
                    return Err(QuicError::Config(
                        "client_auth_roots required when client_auth is Optional/Required".into(),
                    ));
                }
            };
            if roots.is_empty() {
                return Err(QuicError::Config(
                    "client_auth_roots must be non-empty when client_auth is enabled".into(),
                ));
            }

            // The two modes differ only in whether anonymous clients are tolerated.
            let verifier_builder =
                rustls::server::WebPkiClientVerifier::builder(Arc::new(roots.into_inner()));
            let verifier_builder = if matches!(config.client_auth, ClientAuth::Optional) {
                verifier_builder.allow_unauthenticated()
            } else {
                verifier_builder
            };
            let verifier = verifier_builder.build().map_err(tls_err)?;
            tls_builder.with_client_cert_verifier(verifier)
        }
    };

    let mut crypto = tls_builder
        .with_single_cert(cert_chain, private_key)
        .map_err(tls_err)?;
    if !config.alpn_protocols.is_empty() {
        crypto.alpn_protocols.clone_from(&config.alpn_protocols);
    }

    let transport = Arc::new(config.to_transport_config());
    let quic_crypto =
        quinn::crypto::rustls::QuicServerConfig::try_from(crypto).map_err(tls_err)?;
    let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_crypto));
    server_config.transport_config(transport);

    let inner = quinn::Endpoint::server(server_config, addr)?;
    Ok(Self { inner })
}
/// Dials `addr` using the endpoint's default client configuration.
///
/// `server_name` is handed to quinn together with the address when the attempt is
/// started.
///
/// # Errors
///
/// Returns whatever `cx.checkpoint()` reports, the converted error if quinn refuses to
/// start the attempt, or the converted error if the connection cannot be established.
pub async fn connect(
    &self,
    cx: &Cx,
    addr: SocketAddr,
    server_name: &str,
) -> Result<QuicConnection, QuicError> {
    cx.checkpoint()?;
    // First `?`: starting the attempt; second `?`: completing the handshake.
    let established = self.inner.connect(addr, server_name)?.await?;
    Ok(QuicConnection::new(established))
}
/// Dials `addr` using an explicit quinn client configuration instead of the endpoint's
/// default one.
///
/// # Errors
///
/// Returns whatever `cx.checkpoint()` reports, the converted error if quinn refuses to
/// start the attempt, or the converted error if the connection cannot be established.
pub async fn connect_with(
    &self,
    cx: &Cx,
    addr: SocketAddr,
    server_name: &str,
    config: quinn::ClientConfig,
) -> Result<QuicConnection, QuicError> {
    cx.checkpoint()?;
    // First `?`: starting the attempt; second `?`: completing the handshake.
    let established = self
        .inner
        .connect_with(config, addr, server_name)?
        .await?;
    Ok(QuicConnection::new(established))
}
/// Waits for the next inbound connection attempt on this endpoint.
///
/// The returned [`QuicIncoming`] has been admitted but its handshake has not been
/// driven yet; call `QuicIncoming::handshake` to complete it.
///
/// # Errors
///
/// * Whatever `cx.checkpoint()` reports.
/// * [`QuicError::EndpointClosed`] when the endpoint yields no further connections.
/// * The converted connection error if the attempt cannot be admitted.
pub async fn accept(&self, cx: &Cx) -> Result<QuicIncoming, QuicError> {
    cx.checkpoint()?;
    let incoming = self.inner.accept().await.ok_or(QuicError::EndpointClosed)?;
    // `Endpoint::accept` yields a `quinn::Incoming`; it has to be admitted explicitly to
    // obtain the `quinn::Connecting` that `QuicIncoming` stores.
    let connecting = incoming.accept()?;
    Ok(QuicIncoming { inner: connecting })
}
/// Returns the local socket address reported by the underlying quinn endpoint.
///
/// # Errors
///
/// Returns the converted error if the address cannot be queried.
pub fn local_addr(&self) -> Result<SocketAddr, QuicError> {
    match self.inner.local_addr() {
        Ok(bound) => Ok(bound),
        Err(e) => Err(QuicError::from(e)),
    }
}
pub fn close(&self, code: u32, reason: &[u8]) {
self.inner.close(code.into(), reason);
}
/// Resolves once the underlying quinn endpoint reports that it is idle.
pub async fn wait_idle(&self) {
    let idle = self.inner.wait_idle();
    idle.await;
}
#[must_use]
pub fn inner(&self) -> &quinn::Endpoint {
&self.inner
}
}
/// An inbound connection attempt whose handshake has not been awaited yet.
///
/// Produced by `QuicEndpoint::accept`; consumed by `handshake`.
#[derive(Debug)]
pub struct QuicIncoming {
/// The in-progress quinn connection that `handshake` awaits.
inner: quinn::Connecting,
}
impl QuicIncoming {
    /// Awaits the pending connection and wraps the result in a [`QuicConnection`].
    ///
    /// # Errors
    ///
    /// Returns the converted connection error if the handshake does not complete.
    pub async fn handshake(self) -> Result<QuicConnection, QuicError> {
        let Self { inner: pending } = self;
        let established = pending.await?;
        Ok(QuicConnection::new(established))
    }

    /// Returns the peer address as reported by the pending quinn connection.
    pub fn remote_address(&self) -> SocketAddr {
        let Self { inner: pending } = self;
        pending.remote_address()
    }
}
#[derive(Debug)]
struct SkipServerVerification;
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
]
}
}