use crate::tls::build_native_cert_store;
use crate::tls::Identity;
use quinn::EndpointConfig;
use quinn::TransportConfig;
use socket2::Domain as SocketDomain;
use socket2::Protocol as SocketProtocol;
use socket2::Socket;
use socket2::Type as SocketType;
use std::fmt::Debug;
use std::fmt::Display;
use std::future::Future;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::SocketAddrV6;
use std::net::UdpSocket;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
/// Alias for the TLS server configuration type, `crate::tls::rustls::ServerConfig`.
pub type TlsServerConfig = crate::tls::rustls::ServerConfig;
/// Alias for the TLS client configuration type, `crate::tls::rustls::ClientConfig`.
pub type TlsClientConfig = crate::tls::rustls::ClientConfig;
/// Alias for the QUIC transport configuration type, `crate::quinn::TransportConfig`.
///
/// Only available with the `quinn` feature.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicTransportConfig = crate::quinn::TransportConfig;
/// Alias for the QUIC server configuration type, `crate::quinn::ServerConfig`.
///
/// Only available with the `quinn` feature.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicServerConfig = crate::quinn::ServerConfig;
/// Alias for the QUIC client configuration type, `crate::quinn::ClientConfig`.
///
/// Only available with the `quinn` feature.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicClientConfig = crate::quinn::ClientConfig;
/// Preset choices for the local IP address an endpoint binds to.
///
/// Every preset maps to one IP address and one [`Ipv6DualStackConfig`]
/// through the conversion helpers implemented on this type.
#[derive(Debug, Copy, Clone)]
pub enum IpBindConfig {
/// IPv4 loopback (`Ipv4Addr::LOCALHOST`); the dual-stack option is left at the OS default.
LocalV4,
/// IPv6 loopback (`Ipv6Addr::LOCALHOST`) with dual-stack denied.
LocalV6,
/// IPv6 loopback (`Ipv6Addr::LOCALHOST`) with dual-stack allowed.
LocalDual,
/// IPv4 unspecified address (`Ipv4Addr::UNSPECIFIED`); the dual-stack option is left at the OS default.
InAddrAnyV4,
/// IPv6 unspecified address (`Ipv6Addr::UNSPECIFIED`) with dual-stack denied.
InAddrAnyV6,
/// IPv6 unspecified address (`Ipv6Addr::UNSPECIFIED`) with dual-stack allowed.
InAddrAnyDual,
}
impl IpBindConfig {
fn into_ip(self) -> IpAddr {
match self {
IpBindConfig::LocalV4 => Ipv4Addr::LOCALHOST.into(),
IpBindConfig::LocalV6 => Ipv6Addr::LOCALHOST.into(),
IpBindConfig::LocalDual => Ipv6Addr::LOCALHOST.into(),
IpBindConfig::InAddrAnyV4 => Ipv4Addr::UNSPECIFIED.into(),
IpBindConfig::InAddrAnyV6 => Ipv6Addr::UNSPECIFIED.into(),
IpBindConfig::InAddrAnyDual => Ipv6Addr::UNSPECIFIED.into(),
}
}
fn into_dual_stack_config(self) -> Ipv6DualStackConfig {
match self {
IpBindConfig::LocalV4 | IpBindConfig::InAddrAnyV4 => Ipv6DualStackConfig::OsDefault,
IpBindConfig::LocalV6 | IpBindConfig::InAddrAnyV6 => Ipv6DualStackConfig::Deny,
IpBindConfig::LocalDual | IpBindConfig::InAddrAnyDual => Ipv6DualStackConfig::Allow,
}
}
}
/// Policy for the IPv6-only option of a socket that is about to be bound.
#[derive(Debug, Copy, Clone)]
pub enum Ipv6DualStackConfig {
/// Leave the socket option untouched.
OsDefault,
/// Set the socket to IPv6-only (`set_only_v6(true)`).
Deny,
/// Clear the IPv6-only flag (`set_only_v6(false)`).
Allow,
}
/// Error returned by the builders' `max_idle_timeout` methods when the given
/// duration cannot be converted into a `quinn::IdleTimeout`.
pub struct InvalidIdleTimeout;
/// Server-side configuration: how to obtain the local socket plus the QUIC
/// endpoint and server settings.
///
/// Create one through `ServerConfig::builder()`.
#[derive(Debug)]
pub struct ServerConfig {
/// How the local UDP socket is obtained (address to bind, or a ready socket).
pub(crate) bind_address_config: BindAddressConfig,
/// QUIC endpoint-level configuration.
pub(crate) endpoint_config: quinn::EndpointConfig,
/// QUIC server configuration.
pub(crate) quic_config: quinn::ServerConfig,
}
impl ServerConfig {
    /// Starts building a [`ServerConfig`]; the first step selects the bind address.
    pub fn builder() -> ServerConfigBuilder<states::WantsBindAddress> {
        ServerConfigBuilder::<states::WantsBindAddress>::default()
    }

    /// Shared access to the QUIC endpoint configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_endpoint_config(&self) -> &quinn::EndpointConfig {
        &self.endpoint_config
    }

    /// Mutable access to the QUIC endpoint configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_endpoint_config_mut(&mut self) -> &mut quinn::EndpointConfig {
        &mut self.endpoint_config
    }

    /// Shared access to the QUIC server configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_config(&self) -> &quinn::ServerConfig {
        &self.quic_config
    }

    /// Mutable access to the QUIC server configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_config_mut(&mut self) -> &mut quinn::ServerConfig {
        &mut self.quic_config
    }
}
/// Type-state builder for [`ServerConfig`].
///
/// The `State` parameter (one of the types in [`states`]) records which step
/// has been reached and holds the values collected so far.
#[must_use]
pub struct ServerConfigBuilder<State>(State);
impl ServerConfigBuilder<states::WantsBindAddress> {
pub fn with_bind_default(
self,
listening_port: u16,
) -> ServerConfigBuilder<states::WantsIdentity> {
self.with_bind_config(IpBindConfig::InAddrAnyDual, listening_port)
}
pub fn with_bind_config(
self,
ip_bind_config: IpBindConfig,
listening_port: u16,
) -> ServerConfigBuilder<states::WantsIdentity> {
let ip_address: IpAddr = ip_bind_config.into_ip();
match ip_address {
IpAddr::V4(ip) => self.with_bind_address(SocketAddr::new(ip.into(), listening_port)),
IpAddr::V6(ip) => self.with_bind_address_v6(
SocketAddrV6::new(ip, listening_port, 0, 0),
ip_bind_config.into_dual_stack_config(),
),
}
}
pub fn with_bind_address(
self,
address: SocketAddr,
) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address_config: BindAddressConfig::from(address),
})
}
pub fn with_bind_address_v6(
self,
address: SocketAddrV6,
dual_stack_config: Ipv6DualStackConfig,
) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address_config: BindAddressConfig::AddressV6(address, dual_stack_config),
})
}
pub fn with_bind_socket(self, socket: UdpSocket) -> ServerConfigBuilder<states::WantsIdentity> {
ServerConfigBuilder(states::WantsIdentity {
bind_address_config: BindAddressConfig::Socket(socket),
})
}
}
impl ServerConfigBuilder<states::WantsIdentity> {
    /// Uses the default TLS configuration built from `identity`, with default
    /// endpoint and transport settings.
    pub fn with_identity(
        self,
        identity: Identity,
    ) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
        use crate::tls::server::build_default_tls_config;

        self.with(
            build_default_tls_config(identity),
            EndpointConfig::default(),
            TransportConfig::default(),
        )
    }

    /// Uses a caller-provided TLS configuration, with default endpoint and
    /// transport settings.
    pub fn with_custom_tls(
        self,
        tls_config: TlsServerConfig,
    ) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
        self.with(
            tls_config,
            EndpointConfig::default(),
            TransportConfig::default(),
        )
    }

    /// Uses the default TLS configuration built from `identity` together with
    /// a caller-provided QUIC transport configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn with_custom_transport(
        self,
        identity: Identity,
        quic_transport_config: QuicTransportConfig,
    ) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
        use crate::tls::server::build_default_tls_config;

        self.with(
            build_default_tls_config(identity),
            EndpointConfig::default(),
            quic_transport_config,
        )
    }

    /// Uses caller-provided TLS and QUIC transport configurations.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn with_custom_tls_and_transport(
        self,
        tls_config: TlsServerConfig,
        quic_transport_config: QuicTransportConfig,
    ) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
        self.with(tls_config, EndpointConfig::default(), quic_transport_config)
    }

    /// Finishes the build immediately with a complete QUIC server
    /// configuration and a default endpoint configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn build_with_quic_config(self, quic_config: QuicServerConfig) -> ServerConfig {
        let states::WantsIdentity {
            bind_address_config,
        } = self.0;

        ServerConfig {
            bind_address_config,
            endpoint_config: EndpointConfig::default(),
            quic_config,
        }
    }

    /// Moves to the transport-configuration state, carrying the bind address
    /// along. Connection migration starts out enabled.
    fn with(
        self,
        tls_config: TlsServerConfig,
        endpoint_config: EndpointConfig,
        transport_config: TransportConfig,
    ) -> ServerConfigBuilder<states::WantsTransportConfigServer> {
        let states::WantsIdentity {
            bind_address_config,
        } = self.0;

        ServerConfigBuilder(states::WantsTransportConfigServer {
            bind_address_config,
            tls_config,
            endpoint_config,
            transport_config,
            migration: true,
        })
    }
}
impl ServerConfigBuilder<states::WantsTransportConfigServer> {
    /// Assembles the final [`ServerConfig`].
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration cannot be converted into a QUIC crypto
    /// configuration.
    #[must_use]
    pub fn build(self) -> ServerConfig {
        let states::WantsTransportConfigServer {
            bind_address_config,
            tls_config,
            endpoint_config,
            transport_config,
            migration,
        } = self.0;

        let crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
            .expect("CipherSuite::TLS13_AES_128_GCM_SHA256 missing");
        let crypto: Arc<quinn::crypto::rustls::QuicServerConfig> = Arc::new(crypto);

        let mut quic_config = quinn::ServerConfig::with_crypto(crypto);
        quic_config.transport_config(Arc::new(transport_config));
        quic_config.migration(migration);

        ServerConfig {
            bind_address_config,
            endpoint_config,
            quic_config,
        }
    }

    /// Sets the maximum idle timeout; `None` is forwarded unchanged to the
    /// transport configuration.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidIdleTimeout`] when the duration cannot be converted
    /// into a `quinn::IdleTimeout`.
    pub fn max_idle_timeout(
        mut self,
        idle_timeout: Option<Duration>,
    ) -> Result<Self, InvalidIdleTimeout> {
        let timeout = match idle_timeout {
            Some(duration) => {
                Some(quinn::IdleTimeout::try_from(duration).map_err(|_| InvalidIdleTimeout)?)
            }
            None => None,
        };

        self.0.transport_config.max_idle_timeout(timeout);
        Ok(self)
    }

    /// Forwards `interval` to the transport configuration's keep-alive setting.
    pub fn keep_alive_interval(mut self, interval: Option<Duration>) -> Self {
        let transport = &mut self.0.transport_config;
        transport.keep_alive_interval(interval);
        self
    }

    /// Records whether connection migration is allowed; the value is applied
    /// to the QUIC server configuration in `build`.
    pub fn allow_migration(mut self, value: bool) -> Self {
        self.0.migration = value;
        self
    }
}
/// Client-side configuration: how to obtain the local socket, the QUIC
/// endpoint and client settings, and the DNS resolver.
///
/// Create one through `ClientConfig::builder()` or `ClientConfig::default()`.
#[derive(Debug)]
pub struct ClientConfig {
/// How the local UDP socket is obtained (address to bind, or a ready socket).
pub(crate) bind_address_config: BindAddressConfig,
/// QUIC endpoint-level configuration.
pub(crate) endpoint_config: quinn::EndpointConfig,
/// QUIC client configuration.
pub(crate) quic_config: quinn::ClientConfig,
/// Resolver used to turn host names into socket addresses.
pub(crate) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
}
impl ClientConfig {
    /// Starts building a [`ClientConfig`]; the first step selects the bind address.
    pub fn builder() -> ClientConfigBuilder<states::WantsBindAddress> {
        ClientConfigBuilder::<states::WantsBindAddress>::default()
    }

    /// Replaces the DNS resolver stored in this configuration.
    pub fn set_dns_resolver<R>(&mut self, dns_resolver: R)
    where
        R: DnsResolver + Send + Sync + 'static,
    {
        let resolver: Arc<dyn DnsResolver + Send + Sync> = Arc::new(dns_resolver);
        self.dns_resolver = resolver;
    }

    /// Shared access to the QUIC endpoint configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_endpoint_config(&self) -> &quinn::EndpointConfig {
        &self.endpoint_config
    }

    /// Mutable access to the QUIC endpoint configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_endpoint_config_mut(&mut self) -> &mut quinn::EndpointConfig {
        &mut self.endpoint_config
    }

    /// Shared access to the QUIC client configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_config(&self) -> &quinn::ClientConfig {
        &self.quic_config
    }

    /// Mutable access to the QUIC client configuration.
    #[cfg(feature = "quinn")]
    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
    pub fn quic_config_mut(&mut self) -> &mut quinn::ClientConfig {
        &mut self.quic_config
    }
}
impl Default for ClientConfig {
fn default() -> Self {
ClientConfig::builder()
.with_bind_default()
.with_native_certs()
.build()
}
}
/// Type-state builder for [`ClientConfig`].
///
/// The `State` parameter (one of the types in [`states`]) records which step
/// has been reached and holds the values collected so far.
#[must_use]
pub struct ClientConfigBuilder<State>(State);
impl ClientConfigBuilder<states::WantsBindAddress> {
pub fn with_bind_default(self) -> ClientConfigBuilder<states::WantsRootStore> {
self.with_bind_config(IpBindConfig::InAddrAnyDual)
}
pub fn with_bind_config(
self,
ip_bind_config: IpBindConfig,
) -> ClientConfigBuilder<states::WantsRootStore> {
let ip_address: IpAddr = ip_bind_config.into_ip();
match ip_address {
IpAddr::V4(ip) => self.with_bind_address(SocketAddr::new(ip.into(), 0)),
IpAddr::V6(ip) => self.with_bind_address_v6(
SocketAddrV6::new(ip, 0, 0, 0),
ip_bind_config.into_dual_stack_config(),
),
}
}
pub fn with_bind_address(
self,
address: SocketAddr,
) -> ClientConfigBuilder<states::WantsRootStore> {
ClientConfigBuilder(states::WantsRootStore {
bind_address_config: BindAddressConfig::from(address),
})
}
pub fn with_bind_address_v6(
self,
address: SocketAddrV6,
dual_stack_config: Ipv6DualStackConfig,
) -> ClientConfigBuilder<states::WantsRootStore> {
ClientConfigBuilder(states::WantsRootStore {
bind_address_config: BindAddressConfig::AddressV6(address, dual_stack_config),
})
}
pub fn with_bind_socket(
self,
socket: UdpSocket,
) -> ClientConfigBuilder<states::WantsRootStore> {
ClientConfigBuilder(states::WantsRootStore {
bind_address_config: BindAddressConfig::Socket(socket),
})
}
}
impl ClientConfigBuilder<states::WantsRootStore> {
pub fn with_native_certs(self) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
use crate::tls::client::build_default_tls_config;
let tls_config = build_default_tls_config(Arc::new(build_native_cert_store()), None);
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();
self.with(tls_config, endpoint_config, transport_config)
}
#[cfg(feature = "dangerous-configuration")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous-configuration")))]
pub fn with_no_cert_validation(
self,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
use crate::tls::client::build_default_tls_config;
use crate::tls::client::NoServerVerification;
use rustls::RootCertStore;
let tls_config = build_default_tls_config(
Arc::new(RootCertStore::empty()),
Some(Arc::new(NoServerVerification::new())),
);
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();
self.with(tls_config, endpoint_config, transport_config)
}
pub fn with_server_certificate_hashes<I>(
self,
hashes: I,
) -> ClientConfigBuilder<states::WantsTransportConfigClient>
where
I: IntoIterator<Item = crate::tls::Sha256Digest>,
{
use crate::tls::client::build_default_tls_config;
use crate::tls::client::ServerHashVerification;
use rustls::RootCertStore;
let tls_config = build_default_tls_config(
Arc::new(RootCertStore::empty()),
Some(Arc::new(ServerHashVerification::new(hashes))),
);
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();
self.with(tls_config, endpoint_config, transport_config)
}
pub fn with_custom_tls(
self,
tls_config: TlsClientConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
let endpoint_config = EndpointConfig::default();
let transport_config = TransportConfig::default();
self.with(tls_config, endpoint_config, transport_config)
}
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn with_custom_transport(
self,
quic_transport_config: QuicTransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
use crate::tls::client::build_default_tls_config;
let tls_config = build_default_tls_config(Arc::new(build_native_cert_store()), None);
let quic_endpoint_config = EndpointConfig::default();
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn with_custom_tls_and_transport(
self,
tls_config: TlsClientConfig,
quic_transport_config: QuicTransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
let quic_endpoint_config = EndpointConfig::default();
self.with(tls_config, quic_endpoint_config, quic_transport_config)
}
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn build_with_quic_config(self, quic_config: QuicClientConfig) -> ClientConfig {
ClientConfig {
bind_address_config: self.0.bind_address_config,
endpoint_config: EndpointConfig::default(),
quic_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
}
}
fn with(
self,
tls_config: TlsClientConfig,
endpoint_config: EndpointConfig,
transport_config: TransportConfig,
) -> ClientConfigBuilder<states::WantsTransportConfigClient> {
ClientConfigBuilder(states::WantsTransportConfigClient {
bind_address_config: self.0.bind_address_config,
tls_config,
endpoint_config,
transport_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
})
}
}
impl ClientConfigBuilder<states::WantsTransportConfigClient> {
    /// Assembles the final [`ClientConfig`].
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration cannot be converted into a QUIC crypto
    /// configuration.
    #[must_use]
    pub fn build(self) -> ClientConfig {
        let states::WantsTransportConfigClient {
            bind_address_config,
            tls_config,
            endpoint_config,
            transport_config,
            dns_resolver,
        } = self.0;

        let crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
            .expect("CipherSuite::TLS13_AES_128_GCM_SHA256 missing");

        let mut quic_config = quinn::ClientConfig::new(Arc::new(crypto));
        quic_config.transport_config(Arc::new(transport_config));

        ClientConfig {
            bind_address_config,
            endpoint_config,
            quic_config,
            dns_resolver,
        }
    }

    /// Sets the maximum idle timeout; `None` is forwarded unchanged to the
    /// transport configuration.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidIdleTimeout`] when the duration cannot be converted
    /// into a `quinn::IdleTimeout`.
    pub fn max_idle_timeout(
        mut self,
        idle_timeout: Option<Duration>,
    ) -> Result<Self, InvalidIdleTimeout> {
        let timeout = match idle_timeout {
            Some(duration) => {
                Some(quinn::IdleTimeout::try_from(duration).map_err(|_| InvalidIdleTimeout)?)
            }
            None => None,
        };

        self.0.transport_config.max_idle_timeout(timeout);
        Ok(self)
    }

    /// Forwards `interval` to the transport configuration's keep-alive setting.
    pub fn keep_alive_interval(mut self, interval: Option<Duration>) -> Self {
        let transport = &mut self.0.transport_config;
        transport.keep_alive_interval(interval);
        self
    }

    /// Replaces the DNS resolver that the built configuration will use.
    pub fn dns_resolver<R>(mut self, dns_resolver: R) -> Self
    where
        R: DnsResolver + Send + Sync + 'static,
    {
        let resolver: Arc<dyn DnsResolver + Send + Sync> = Arc::new(dns_resolver);
        self.0.dns_resolver = resolver;
        self
    }
}
impl Default for ServerConfigBuilder<states::WantsBindAddress> {
fn default() -> Self {
Self(states::WantsBindAddress {})
}
}
impl Default for ClientConfigBuilder<states::WantsBindAddress> {
fn default() -> Self {
Self(states::WantsBindAddress {})
}
}
/// Describes how the local UDP socket of an endpoint is obtained.
#[derive(Debug)]
pub(crate) enum BindAddressConfig {
/// Bind a new socket to this IPv4 address.
AddressV4(SocketAddrV4),
/// Bind a new socket to this IPv6 address, applying the given dual-stack policy first.
AddressV6(SocketAddrV6, Ipv6DualStackConfig),
/// Use this already created socket as-is.
Socket(UdpSocket),
}
impl BindAddressConfig {
    /// Produces the UDP socket described by this configuration.
    ///
    /// A pre-made socket is returned untouched. Otherwise a new UDP socket is
    /// created for the address family, the dual-stack policy is applied, and
    /// the socket is bound to the address.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from creating, configuring or binding the
    /// socket.
    pub(crate) fn bind_socket(self) -> std::io::Result<UdpSocket> {
        let (bind_address, domain, dual_stack_config) = match self {
            Self::Socket(ready) => return Ok(ready),
            Self::AddressV4(v4) => (
                SocketAddr::V4(v4),
                SocketDomain::IPV4,
                Ipv6DualStackConfig::OsDefault,
            ),
            Self::AddressV6(v6, policy) => (SocketAddr::V6(v6), SocketDomain::IPV6, policy),
        };

        let socket = Socket::new(domain, SocketType::DGRAM, Some(SocketProtocol::UDP))?;

        // `None` means the option is not touched at all.
        let only_v6 = match dual_stack_config {
            Ipv6DualStackConfig::OsDefault => None,
            Ipv6DualStackConfig::Deny => Some(true),
            Ipv6DualStackConfig::Allow => Some(false),
        };
        if let Some(flag) = only_v6 {
            socket.set_only_v6(flag)?;
        }

        socket.bind(&bind_address.into())?;
        Ok(UdpSocket::from(socket))
    }
}
impl From<SocketAddr> for BindAddressConfig {
    /// Wraps a socket address; IPv6 addresses get
    /// [`Ipv6DualStackConfig::OsDefault`].
    fn from(value: SocketAddr) -> Self {
        match value {
            SocketAddr::V4(v4) => Self::AddressV4(v4),
            SocketAddr::V6(v6) => Self::AddressV6(v6, Ipv6DualStackConfig::OsDefault),
        }
    }
}
/// Type-state markers used as the `State` parameter of the configuration builders.
pub mod states {
use super::*;
/// Initial state of both builders: no bind address has been chosen yet.
pub struct WantsBindAddress {}
/// Server builder state after the bind address is chosen; TLS material is supplied next.
pub struct WantsIdentity {
pub(super) bind_address_config: BindAddressConfig,
}
/// Client builder state after the bind address is chosen; certificate validation is configured next.
pub struct WantsRootStore {
pub(super) bind_address_config: BindAddressConfig,
}
/// Server builder state holding everything needed to build, while transport options can still be adjusted.
pub struct WantsTransportConfigServer {
pub(super) bind_address_config: BindAddressConfig,
pub(super) tls_config: TlsServerConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
// Whether connection migration is allowed; applied to the QUIC server config at build time.
pub(super) migration: bool,
}
/// Client builder state holding everything needed to build, while transport options and the DNS resolver can still be adjusted.
pub struct WantsTransportConfigClient {
pub(super) bind_address_config: BindAddressConfig,
pub(super) tls_config: TlsClientConfig,
pub(super) endpoint_config: quinn::EndpointConfig,
pub(super) transport_config: quinn::TransportConfig,
pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
}
}
/// A sendable future that resolves to the outcome of a DNS lookup:
/// `Ok(Some(address))`, `Ok(None)` when no address is produced, or an I/O error.
pub trait DnsLookupFuture: Future<Output = std::io::Result<Option<SocketAddr>>> + Send {}

/// Every `Send` future with the matching output type is a [`DnsLookupFuture`].
impl<F: Future<Output = std::io::Result<Option<SocketAddr>>> + Send> DnsLookupFuture for F {}
/// Asynchronous host name resolution.
pub trait DnsResolver: Debug {
/// Starts resolving `host` and returns the boxed future producing the result.
fn resolve(&self, host: &str) -> Pin<Box<dyn DnsLookupFuture>>;
}
/// [`DnsResolver`] backed by `tokio::net::lookup_host`.
///
/// This is the resolver the client builder installs by default.
#[derive(Default)]
pub struct TokioDnsResolver;
impl DnsResolver for TokioDnsResolver {
    /// Looks up `host` with `tokio::net::lookup_host` and yields the first
    /// address returned, if any.
    fn resolve(&self, host: &str) -> Pin<Box<dyn DnsLookupFuture>> {
        // The future must not borrow `host`, so it takes an owned copy.
        let owned_host = host.to_owned();

        Box::pin(async move {
            let mut addresses = tokio::net::lookup_host(owned_host).await?;
            Ok(addresses.next())
        })
    }
}
impl Debug for TokioDnsResolver {
    /// Prints the bare type name; the resolver has no fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("TokioDnsResolver")
    }
}
// Marks `InvalidIdleTimeout` as a standard error type; no methods are overridden.
impl std::error::Error for InvalidIdleTimeout {}
impl Debug for InvalidIdleTimeout {
    /// Writes a fixed human-readable message instead of the type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "idle timeout value configuration is invalid")
    }
}
impl Display for InvalidIdleTimeout {
    /// Produces the same text as the `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, f)
    }
}