use std::{convert::TryInto, fmt, num::TryFromIntError, sync::Arc, time::Duration};
use rand::RngCore;
use thiserror::Error;
#[cfg(feature = "rustls")]
use crate::crypto::types::{Certificate, CertificateChain, PrivateKey};
use crate::{
cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
congestion,
crypto::{self, ClientConfig as _, HandshakeTokenKey as _, HmacKey as _, ServerConfig as _},
VarInt, VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS,
};
/// Transport-level tuning parameters: stream limits, flow-control windows, loss-recovery
/// thresholds, timers, datagram buffers and the congestion controller factory.
///
/// Configure through the builder-style setters; the `Default` impl supplies starting values.
pub struct TransportConfig {
/// Concurrent bidirectional stream limit (default 100); the setter rejects values that do not fit a `VarInt`.
pub(crate) max_concurrent_bidi_streams: VarInt,
/// Concurrent unidirectional stream limit (default 100); the setter rejects values that do not fit a `VarInt`.
pub(crate) max_concurrent_uni_streams: VarInt,
/// Idle timeout, or `None` (default 10 s); the setter rejects durations whose millisecond count exceeds `VarInt::MAX`.
pub(crate) max_idle_timeout: Option<Duration>,
/// Per-stream receive window (default 1_250_000).
pub(crate) stream_receive_window: VarInt,
/// Receive window (default `VarInt::MAX`) — presumably connection-wide as opposed to per-stream; confirm.
pub(crate) receive_window: VarInt,
/// Send window (default 8 × the per-stream receive window); a plain `u64`, not a `VarInt`.
pub(crate) send_window: u64,
/// Maximum tail-loss probes (default 2) — NOTE(review): "TLP" expansion inferred from the name.
pub(crate) max_tlps: u32,
/// Loss-detection packet threshold (default 3, the RFC 9002 kPacketThreshold value).
pub(crate) packet_threshold: u32,
/// Loss-detection time threshold multiplier (default 9/8, the RFC 9002 kTimeThreshold value).
pub(crate) time_threshold: f32,
/// Initial RTT estimate (default 333 ms).
pub(crate) initial_rtt: Duration,
/// Persistent congestion threshold (default 3).
pub(crate) persistent_congestion_threshold: u32,
/// Keep-alive interval; `None` by default.
pub(crate) keep_alive_interval: Option<Duration>,
/// Crypto buffer size (default 16 * 1024).
pub(crate) crypto_buffer_size: usize,
/// Whether the spin bit is allowed (default `true`) — NOTE(review): meaning inferred from the name.
pub(crate) allow_spin: bool,
/// Datagram receive buffer size; `Some(1_250_000)` by default — what `None` means must be confirmed at the use site.
pub(crate) datagram_receive_buffer_size: Option<usize>,
/// Datagram send buffer size (default 1024 * 1024).
pub(crate) datagram_send_buffer_size: usize,
/// Factory for congestion controllers (default NewReno); a boxed trait object without a `Debug` bound.
pub(crate) congestion_controller_factory: Box<dyn congestion::ControllerFactory + Send + Sync>,
}
impl TransportConfig {
pub fn max_concurrent_bidi_streams(&mut self, value: u64) -> Result<&mut Self, ConfigError> {
self.max_concurrent_bidi_streams = value.try_into()?;
Ok(self)
}
pub fn max_concurrent_uni_streams(&mut self, value: u64) -> Result<&mut Self, ConfigError> {
self.max_concurrent_uni_streams = value.try_into()?;
Ok(self)
}
pub fn max_idle_timeout(&mut self, value: Option<Duration>) -> Result<&mut Self, ConfigError> {
if value.map_or(false, |x| x.as_millis() > VarInt::MAX.0 as u128) {
return Err(ConfigError::OutOfBounds);
}
self.max_idle_timeout = value;
Ok(self)
}
pub fn stream_receive_window(&mut self, value: u64) -> Result<&mut Self, ConfigError> {
self.stream_receive_window = value.try_into()?;
Ok(self)
}
pub fn receive_window(&mut self, value: u64) -> Result<&mut Self, ConfigError> {
self.receive_window = value.try_into()?;
Ok(self)
}
pub fn send_window(&mut self, value: u64) -> &mut Self {
self.send_window = value;
self
}
pub fn max_tlps(&mut self, value: u32) -> &mut Self {
self.max_tlps = value;
self
}
pub fn packet_threshold(&mut self, value: u32) -> &mut Self {
self.packet_threshold = value;
self
}
pub fn time_threshold(&mut self, value: f32) -> &mut Self {
self.time_threshold = value;
self
}
pub fn initial_rtt(&mut self, value: Duration) -> &mut Self {
self.initial_rtt = value;
self
}
pub fn persistent_congestion_threshold(&mut self, value: u32) -> &mut Self {
self.persistent_congestion_threshold = value;
self
}
pub fn keep_alive_interval(&mut self, value: Option<Duration>) -> &mut Self {
self.keep_alive_interval = value;
self
}
pub fn crypto_buffer_size(&mut self, value: usize) -> &mut Self {
self.crypto_buffer_size = value;
self
}
pub fn allow_spin(&mut self, value: bool) -> &mut Self {
self.allow_spin = value;
self
}
pub fn datagram_receive_buffer_size(&mut self, value: Option<usize>) -> &mut Self {
self.datagram_receive_buffer_size = value;
self
}
pub fn datagram_send_buffer_size(&mut self, value: usize) -> &mut Self {
self.datagram_send_buffer_size = value;
self
}
pub fn congestion_controller_factory(
&mut self,
factory: impl congestion::ControllerFactory + Send + Sync + 'static,
) -> &mut Self {
self.congestion_controller_factory = Box::new(factory);
self
}
}
impl Default for TransportConfig {
    /// Builds the default transport configuration.
    ///
    /// The per-stream receive window is a bandwidth-delay product:
    /// `12_500_000 / 1000 * 100 = 1_250_000`. Presumably this is bytes per second times an RTT
    /// in milliseconds, with `/ 1000` converting between the two units. The send window is eight
    /// such windows. The datagram receive buffer is one window.
    fn default() -> Self {
        // Expected round-trip time (presumably milliseconds).
        const EXPECTED_RTT: u32 = 100;
        // Target per-stream throughput (presumably bytes per second).
        const MAX_STREAM_BANDWIDTH: u32 = 12500 * 1000;
        // Bandwidth-delay product for one stream: 1_250_000.
        const STREAM_RWND: u32 = MAX_STREAM_BANDWIDTH / 1000 * EXPECTED_RTT;

        Self {
            max_concurrent_bidi_streams: VarInt::from(100u32),
            max_concurrent_uni_streams: VarInt::from(100u32),
            max_idle_timeout: Some(Duration::from_secs(10)),
            stream_receive_window: VarInt::from(STREAM_RWND),
            receive_window: VarInt::MAX,
            send_window: u64::from(8 * STREAM_RWND),
            // Loss-recovery constants.
            max_tlps: 2,
            packet_threshold: 3,
            time_threshold: 9.0 / 8.0,
            initial_rtt: Duration::from_millis(333),
            persistent_congestion_threshold: 3,
            keep_alive_interval: None,
            crypto_buffer_size: 16 * 1024,
            allow_spin: true,
            datagram_receive_buffer_size: Some(STREAM_RWND as usize),
            datagram_send_buffer_size: 1024 * 1024,
            congestion_controller_factory: Box::new(Arc::new(congestion::NewRenoConfig::default())),
        }
    }
}
impl fmt::Debug for TransportConfig {
    /// Formats every field of the configuration under the type name `TransportConfig`.
    ///
    /// The congestion controller factory is a boxed trait object with no `Debug` bound, so it
    /// is printed as the placeholder `"[ opaque ]"`.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("TransportConfig")
            .field(
                "max_concurrent_bidi_streams",
                &self.max_concurrent_bidi_streams,
            )
            .field(
                "max_concurrent_uni_streams",
                &self.max_concurrent_uni_streams,
            )
            .field("max_idle_timeout", &self.max_idle_timeout)
            .field("stream_receive_window", &self.stream_receive_window)
            .field("receive_window", &self.receive_window)
            .field("send_window", &self.send_window)
            .field("max_tlps", &self.max_tlps)
            .field("packet_threshold", &self.packet_threshold)
            .field("time_threshold", &self.time_threshold)
            .field("initial_rtt", &self.initial_rtt)
            .field(
                "persistent_congestion_threshold",
                &self.persistent_congestion_threshold,
            )
            .field("keep_alive_interval", &self.keep_alive_interval)
            .field("crypto_buffer_size", &self.crypto_buffer_size)
            .field("allow_spin", &self.allow_spin)
            .field(
                "datagram_receive_buffer_size",
                &self.datagram_receive_buffer_size,
            )
            .field("datagram_send_buffer_size", &self.datagram_send_buffer_size)
            .field("congestion_controller_factory", &"[ opaque ]")
            .finish()
    }
}
/// Endpoint-wide settings, generic over the crypto session type `S`, which supplies the HMAC
/// key type used for `reset_key`.
pub struct EndpointConfig<S>
where
S: crypto::Session,
{
/// HMAC key for reset tokens — presumably stateless reset; held in an `Arc` so clones share it.
pub(crate) reset_key: Arc<S::HmacKey>,
/// Maximum UDP payload size (1480 when built with `new`).
pub(crate) max_udp_payload_size: VarInt,
/// Factory returning a new boxed connection ID generator on each call; shared across clones via `Arc`.
pub(crate) connection_id_generator_factory:
Arc<dyn Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync>,
/// Accepted version numbers; the `supported_versions` setter requires this to contain `initial_version`.
pub(crate) supported_versions: Vec<u32>,
/// Version used for the first attempt — NOTE(review): presumably for outgoing connections; confirm.
pub(crate) initial_version: u32,
}
impl<S> EndpointConfig<S>
where
    S: crypto::Session,
{
    /// Creates a configuration from `reset_key`.
    ///
    /// Other settings: a 1480-byte maximum UDP payload, a factory producing
    /// [`RandomConnectionIdGenerator`]s, `DEFAULT_SUPPORTED_VERSIONS` as the supported versions,
    /// and the first of those as the initial version.
    pub fn new(reset_key: S::HmacKey) -> Self {
        let cid_factory: fn() -> Box<dyn ConnectionIdGenerator> =
            || Box::new(RandomConnectionIdGenerator::default());
        Self {
            reset_key: Arc::new(reset_key),
            max_udp_payload_size: 1480u32.into(),
            connection_id_generator_factory: Arc::new(cid_factory),
            initial_version: DEFAULT_SUPPORTED_VERSIONS[0],
            supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
        }
    }

    /// Replaces the factory that creates connection ID generators.
    pub fn cid_generator<F: Fn() -> Box<dyn ConnectionIdGenerator> + Send + Sync + 'static>(
        &mut self,
        factory: F,
    ) -> &mut Self {
        self.connection_id_generator_factory = Arc::new(factory);
        self
    }

    /// Replaces the reset key with one built from `value` by `S::HmacKey::new`.
    ///
    /// # Errors
    /// Propagates the key constructor's error, converted into [`ConfigError`].
    pub fn reset_key(&mut self, value: &[u8]) -> Result<&mut Self, ConfigError> {
        self.reset_key = Arc::new(S::HmacKey::new(value)?);
        Ok(self)
    }

    /// Sets the maximum UDP payload size.
    ///
    /// # Errors
    /// [`ConfigError::OutOfBounds`] if `value` is below 1200 or does not fit in a [`VarInt`].
    /// RFC 9000 §18.2 declares `max_udp_payload_size` values below 1200 invalid. On error the
    /// stored value is unchanged.
    pub fn max_udp_payload_size(&mut self, value: u64) -> Result<&mut Self, ConfigError> {
        // Smallest value RFC 9000 permits for the max_udp_payload_size transport parameter.
        const MIN_MAX_UDP_PAYLOAD_SIZE: u64 = 1200;
        if value < MIN_MAX_UDP_PAYLOAD_SIZE {
            return Err(ConfigError::OutOfBounds);
        }
        self.max_udp_payload_size = value.try_into()?;
        Ok(self)
    }

    /// Returns the configured maximum UDP payload size as a `u64`.
    #[doc(hidden)]
    pub fn get_max_udp_payload_size(&self) -> u64 {
        self.max_udp_payload_size.into()
    }

    /// Sets the supported versions and the version used for the first attempt.
    ///
    /// # Errors
    /// [`ConfigError::OutOfBounds`] if `supported_versions` does not contain `initial_version`.
    /// An empty list is therefore always rejected. Nothing is modified on error.
    pub fn supported_versions(
        &mut self,
        supported_versions: Vec<u32>,
        initial_version: u32,
    ) -> Result<&mut Self, ConfigError> {
        if !supported_versions.contains(&initial_version) {
            return Err(ConfigError::OutOfBounds);
        }
        self.supported_versions = supported_versions;
        self.initial_version = initial_version;
        Ok(self)
    }
}
impl<S: crypto::Session> fmt::Debug for EndpointConfig<S> {
    /// Formats the configuration; the reset key and the connection ID generator factory are
    /// printed as `"[ elided ]"` placeholders.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("EndpointConfig");
        out.field("reset_key", &"[ elided ]");
        out.field("max_udp_payload_size", &self.max_udp_payload_size);
        out.field("cid_generator_factory", &"[ elided ]");
        out.field("supported_versions", &self.supported_versions);
        out.field("initial_version", &self.initial_version);
        out.finish()
    }
}
impl<S: crypto::Session> Default for EndpointConfig<S> {
    /// Creates a configuration whose reset key is built from `S::HmacKey::KEY_LEN` random bytes
    /// from the thread-local RNG.
    ///
    /// # Panics
    /// Panics if `S::HmacKey::new` rejects the random key material.
    fn default() -> Self {
        let mut rng = rand::thread_rng();
        let mut key_material = vec![0u8; S::HmacKey::KEY_LEN];
        rng.fill_bytes(&mut key_material);
        let key = S::HmacKey::new(&key_material)
            .expect("HMAC key rejected random bytes; use EndpointConfig::new instead");
        Self::new(key)
    }
}
impl<S: crypto::Session> Clone for EndpointConfig<S> {
fn clone(&self) -> Self {
Self {
reset_key: self.reset_key.clone(),
max_udp_payload_size: self.max_udp_payload_size,
connection_id_generator_factory: self.connection_id_generator_factory.clone(),
supported_versions: self.supported_versions.clone(),
initial_version: self.initial_version,
}
}
}
/// Server-side configuration, generic over the crypto session type `S`.
pub struct ServerConfig<S>
where
S: crypto::Session,
{
/// Transport configuration, shared via `Arc`.
pub transport: Arc<TransportConfig>,
/// Crypto-layer server configuration for session type `S`.
pub crypto: S::ServerConfig,
/// Handshake token key; set through `token_key`, or built from random bytes by `Default`.
pub(crate) token_key: Arc<S::HandshakeTokenKey>,
/// Whether stateless retry is used (default `false`).
pub(crate) use_stateless_retry: bool,
/// Retry token lifetime (default 15_000_000) — NOTE(review): unit not visible here, presumably microseconds; confirm.
pub(crate) retry_token_lifetime: u64,
/// Connection limit (default 100_000) — presumably concurrent incoming connections; confirm at use site.
pub(crate) concurrent_connections: u32,
/// Whether connection migration is allowed (default `true`).
pub(crate) migration: bool,
}
impl<S> ServerConfig<S>
where
    S: crypto::Session,
{
    /// Creates a server configuration that uses `prk` as the handshake token key.
    ///
    /// Other settings: a default [`TransportConfig`], `S::ServerConfig::new()` for crypto,
    /// stateless retry disabled, a retry token lifetime of 15_000_000, a limit of 100_000
    /// concurrent connections, and migration enabled.
    pub fn new(prk: S::HandshakeTokenKey) -> Self {
        Self {
            transport: Arc::new(TransportConfig::default()),
            crypto: S::ServerConfig::new(),
            token_key: Arc::new(prk),
            use_stateless_retry: false,
            retry_token_lifetime: 15_000_000,
            concurrent_connections: 100_000,
            migration: true,
        }
    }

    /// Replaces the handshake token key with one built from `master_key` by
    /// `S::HandshakeTokenKey::from_secret`.
    ///
    /// Currently this always returns `Ok`.
    pub fn token_key(&mut self, master_key: &[u8]) -> Result<&mut Self, ConfigError> {
        self.token_key = Arc::new(S::HandshakeTokenKey::from_secret(master_key));
        Ok(self)
    }

    /// Enables or disables stateless retry.
    pub fn use_stateless_retry(&mut self, value: bool) -> &mut Self {
        self.use_stateless_retry = value;
        self
    }

    /// Sets the retry token lifetime. The unit is not visible here; it is presumably
    /// microseconds, consistent with the 15_000_000 default.
    pub fn retry_token_lifetime(&mut self, value: u64) -> &mut Self {
        self.retry_token_lifetime = value;
        self
    }

    /// Sets the concurrent connection limit.
    pub fn concurrent_connections(&mut self, value: u32) -> &mut Self {
        self.concurrent_connections = value;
        self
    }

    /// Enables or disables connection migration.
    pub fn migration(&mut self, value: bool) -> &mut Self {
        self.migration = value;
        self
    }
}
#[cfg(feature = "rustls")]
impl ServerConfig<crypto::rustls::TlsSession> {
    /// Installs `cert_chain` and `key` as the single certificate of the rustls server config.
    ///
    /// `Arc::make_mut` copies the rustls config first if it is shared, so other holders of the
    /// previous config do not see the change.
    ///
    /// # Errors
    /// Propagates any error from rustls `set_single_cert`.
    pub fn certificate(
        &mut self,
        cert_chain: CertificateChain,
        key: PrivateKey,
    ) -> Result<&mut Self, rustls::TLSError> {
        let tls_config = Arc::make_mut(&mut self.crypto);
        tls_config.set_single_cert(cert_chain.certs, key.inner)?;
        Ok(self)
    }
}
impl<S> fmt::Debug for ServerConfig<S>
where
    S: crypto::Session,
{
    /// Formats the configuration; the crypto config and the token key are printed as
    /// placeholder strings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ServerConfig<T>");
        out.field("transport", &self.transport)
            .field("crypto", &"ServerConfig { elided }")
            .field("token_key", &"[ elided ]");
        out.field("use_stateless_retry", &self.use_stateless_retry)
            .field("retry_token_lifetime", &self.retry_token_lifetime)
            .field("concurrent_connections", &self.concurrent_connections)
            .field("migration", &self.migration);
        out.finish()
    }
}
impl<S> Default for ServerConfig<S>
where
    S: crypto::Session,
{
    /// Creates a server configuration whose handshake token key is built by
    /// `S::HandshakeTokenKey::from_secret` from 64 bytes of thread-local randomness.
    fn default() -> Self {
        let mut secret = [0u8; 64];
        rand::thread_rng().fill_bytes(&mut secret);
        let prk = S::HandshakeTokenKey::from_secret(&secret);
        Self::new(prk)
    }
}
impl<S> Clone for ServerConfig<S>
where
S: crypto::Session,
S::ServerConfig: Clone,
{
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
crypto: self.crypto.clone(),
token_key: self.token_key.clone(),
use_stateless_retry: self.use_stateless_retry,
retry_token_lifetime: self.retry_token_lifetime,
concurrent_connections: self.concurrent_connections,
migration: self.migration,
}
}
}
/// Client-side configuration, generic over the crypto session type `S`.
pub struct ClientConfig<S>
where
S: crypto::Session,
{
/// Transport configuration, shared via `Arc`.
pub transport: Arc<TransportConfig>,
/// Crypto-layer client configuration for session type `S`.
pub crypto: S::ClientConfig,
}
#[cfg(feature = "rustls")]
impl ClientConfig<crypto::rustls::TlsSession> {
    /// Adds `cert` to the rustls root store as a trust anchor for verifying servers.
    ///
    /// `Arc::make_mut` copies the rustls config first if it is shared.
    ///
    /// # Errors
    /// Returns the webpki error if `cert` cannot be parsed as a trust anchor.
    pub fn add_certificate_authority(
        &mut self,
        cert: Certificate,
    ) -> Result<&mut Self, webpki::Error> {
        let der = &cert.inner.0;
        let anchor = webpki::trust_anchor_util::cert_der_as_trust_anchor(der)?;
        let trusted = [anchor];
        let anchors = webpki::TLSServerTrustAnchors(&trusted);
        let tls_config = Arc::make_mut(&mut self.crypto);
        tls_config.root_store.add_server_trust_anchors(&anchors);
        Ok(self)
    }
}
impl<S> Default for ClientConfig<S>
where
S: crypto::Session,
{
fn default() -> Self {
Self {
transport: Default::default(),
crypto: S::ClientConfig::new(),
}
}
}
impl<S> Clone for ClientConfig<S>
where
S: crypto::Session,
S::ClientConfig: Clone,
{
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
crypto: self.crypto.clone(),
}
}
}
impl<S> fmt::Debug for ClientConfig<S>
where
    S: crypto::Session,
{
    /// Formats the transport configuration; the crypto configuration is printed as a
    /// placeholder string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let transport = &self.transport;
        f.debug_struct("ClientConfig<T>")
            .field("transport", transport)
            .field("crypto", &"ClientConfig { elided }")
            .finish()
    }
}
/// Error produced when a configuration value is rejected.
///
/// Integer narrowing failures (`TryFromIntError`) and `VarInt` overflows
/// (`VarIntBoundsExceeded`) are both mapped to `OutOfBounds` by the `From` impls below.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
/// The value lies outside the range the setting supports.
#[error("value exceeds supported bounds")]
OutOfBounds,
}
impl From<TryFromIntError> for ConfigError {
fn from(_: TryFromIntError) -> Self {
ConfigError::OutOfBounds
}
}
impl From<VarIntBoundsExceeded> for ConfigError {
fn from(_: VarIntBoundsExceeded) -> Self {
ConfigError::OutOfBounds
}
}