use std::{fmt::Debug, sync::Arc, time::Duration};
use anyhow::Result;
use human_repr::HumanCount as _;
use num_traits::ToPrimitive as _;
use quinn::{
MtuDiscoveryConfig, TransportConfig, VarInt,
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
};
use tracing::{debug, trace};
use crate::protocol::Variant;
use crate::{
config::{self, Configuration, Configuration_Optional, Manager},
protocol::{
FindTag as _,
compat::Feature,
control::{ClientMessage2Attributes, ClientMessageV2, Compatibility, CongestionController},
},
util::PortRange,
};
pub trait DebugControllerFactory: quinn::congestion::ControllerFactory + std::fmt::Debug {}
impl DebugControllerFactory for BbrConfig {}
impl DebugControllerFactory for CubicConfig {}
impl DebugControllerFactory for NewRenoConfig {}
/// Keep-alive interval handed to quinn's `TransportConfig::keep_alive_interval`
/// by `create_config` (five seconds).
pub(crate) const PROTOCOL_KEEPALIVE: Duration = Duration::from_millis(5_000);

/// Label of the configuration source holding values taken from the client's request.
const META_CLIENT: &str = "requested by client";
/// Label of the configuration source holding values computed during negotiation.
const META_NEGOTIATED: &str = "config resolution logic";
/// Which traffic direction(s) `create_config` should tune for.
#[derive(Copy, Clone, Debug, Default)]
#[derive(PartialEq, strum_macros::Display)]
pub enum ThroughputMode {
    /// Sending only: the send window and datagram send buffer are configured.
    Tx,
    /// Receiving only: the receive windows and datagram receive buffer are configured.
    Rx,
    /// Both directions are configured. This is the default.
    #[default]
    Both,
}
/// Builds the quinn transport configuration for one end of a connection.
///
/// The following settings are always applied:
/// - stream limits (one bidirectional stream, no unidirectional streams);
/// - the keep-alive interval and the spin bit;
/// - the packet and time loss thresholds;
/// - the MTU settings;
/// - an initial RTT estimate of twice the configured RTT.
///
/// Flow-control windows and datagram buffer sizes are applied only for the
/// direction(s) selected by `mode`.
///
/// The congestion controller named by `params.congestion` is installed on the
/// returned config and is also returned on its own, type-erased, so callers can
/// log or inspect it. An `initial_congestion_window` of `0` leaves the
/// controller's own default initial window in place.
///
/// # Errors
///
/// - `params.udp_buffer` does not fit in a `usize`;
/// - the receive window (receive modes only) cannot be represented as a QUIC `VarInt`;
/// - NewReno was requested but `compat` does not support `Feature::NEW_RENO`.
pub fn create_config(
    params: &Configuration,
    mode: ThroughputMode,
    compat: Compatibility,
) -> Result<(Arc<TransportConfig>, Arc<dyn DebugControllerFactory>)> {
    let mut mtu_cfg = MtuDiscoveryConfig::default();
    let _ = mtu_cfg.upper_bound(params.max_mtu);

    let mut config = TransportConfig::default();
    let _ = config
        .max_concurrent_bidi_streams(1u8.into())
        .max_concurrent_uni_streams(0u8.into())
        .keep_alive_interval(Some(PROTOCOL_KEEPALIVE))
        .allow_spin(true)
        .initial_rtt(2 * params.rtt_duration())
        .packet_threshold(params.packet_threshold)
        .time_threshold(params.time_threshold)
        .min_mtu(params.min_mtu)
        .initial_mtu(params.initial_mtu)
        .mtu_discovery_config(Some(mtu_cfg));

    // Build the error lazily. `ok_or(anyhow!(..))` would allocate an error
    // (and, with backtraces enabled, capture one) on every call, even when
    // the conversion succeeds.
    let udp_buf = params
        .udp_buffer
        .to_usize()
        .ok_or_else(|| anyhow::anyhow!("udp_buffer size overflowed usize"))?;

    // Exhaustive matches so that adding a mode variant forces a revisit here.
    match mode {
        ThroughputMode::Tx | ThroughputMode::Both => {
            let _ = config
                .send_window(params.send_window())
                .datagram_send_buffer_size(udp_buf);
        }
        ThroughputMode::Rx => (),
    }
    match mode {
        ThroughputMode::Rx | ThroughputMode::Both => {
            let rwnd: VarInt = params.recv_window().try_into()?;
            let _ = config
                .receive_window(rwnd)
                .stream_receive_window(rwnd)
                .datagram_receive_buffer_size(Some(udp_buf));
        }
        ThroughputMode::Tx => (),
    }

    let window = params.initial_congestion_window;
    let congestion: Arc<dyn DebugControllerFactory> = match params.congestion {
        CongestionController::Cubic => {
            let mut cubic = CubicConfig::default();
            if window != 0 {
                let _ = cubic.initial_window(window);
            }
            install_congestion_factory(&mut config, cubic)
        }
        CongestionController::Bbr => {
            let mut bbr = BbrConfig::default();
            if window != 0 {
                let _ = bbr.initial_window(window);
            }
            install_congestion_factory(&mut config, bbr)
        }
        CongestionController::NewReno => {
            anyhow::ensure!(
                compat.supports(Feature::NEW_RENO),
                "Remote host does not support NewReno"
            );
            let mut newreno = NewRenoConfig::default();
            if window != 0 {
                let _ = newreno.initial_window(window);
            }
            install_congestion_factory(&mut config, newreno)
        }
    };

    debug!(
        "Final network configuration: {}",
        params.format_transport_config()
    );
    trace!("Quinn network configuration: {config:?}");
    let send_data = if mode == ThroughputMode::Rx {
        ""
    } else {
        &format!(
            "; send window {}, send buffer {}",
            params.send_window().human_count_bytes(),
            udp_buf.human_count_bytes()
        )
    };
    let recv_data = if mode == ThroughputMode::Tx {
        ""
    } else {
        &format!(
            "; recv window {}, recv buffer {}",
            params.recv_window().human_count_bytes(),
            udp_buf.human_count_bytes()
        )
    };
    debug!("Buffer configuration: mode {mode}{send_data}{recv_data}");
    Ok((config.into(), congestion))
}

/// Installs `factory` as `config`'s congestion controller and returns the
/// same factory as a type-erased, `Debug`-capable handle.
fn install_congestion_factory<F>(
    config: &mut TransportConfig,
    factory: F,
) -> Arc<dyn DebugControllerFactory>
where
    F: DebugControllerFactory + Send + Sync + 'static,
{
    let factory = Arc::new(factory);
    let _ = config.congestion_controller_factory(factory.clone());
    factory
}
/// Verdict of a conflict-resolution function, used by `negotiate_v3` when the
/// client and the server both specify a value for the same key and the values differ.
#[derive(Debug)]
enum CombinationResponse<T> {
    /// Keep the server's value (nothing extra is recorded).
    Server,
    /// Use the client's value.
    Client,
    /// Use this newly computed value instead of either side's.
    Combined(T),
    /// The values cannot be reconciled; negotiation fails with this error.
    Failure(anyhow::Error),
}
/// Reconciles one configuration key between the client's request and the
/// server's own (optional) setting.
///
/// - Neither side set the key: nothing is recorded.
/// - Only the server set it: nothing is recorded, so the server's value stands.
/// - Only the client set it: the client value is added to `client_out`.
/// - Both set it, and the client value (converted to the server type) equals
///   the server value: nothing is recorded.
/// - Both set it and they differ: `resolve_conflict` decides.
///   - `Server` records nothing.
///   - `Client` adds the client value to `client_out`.
///   - `Combined(v)` adds `v` to `resolved_out`.
///   - `Failure(e)` returns `e`.
///
/// # Errors
///
/// Only the error carried by a [`CombinationResponse::Failure`].
fn negotiate_v3<ClientType, ServerType, BaseType>(
    client: Option<ClientType>,
    server: Option<ServerType>,
    resolve_conflict: fn(BaseType, BaseType) -> CombinationResponse<BaseType>,
    client_out: &mut config::Source,
    resolved_out: &mut config::Source,
    key: &str,
) -> Result<()>
where
    BaseType: From<ClientType> + From<ServerType>,
    ClientType: Clone + Into<figment::value::Value> + Into<BaseType> + Into<ServerType>,
    figment::value::Value: From<BaseType>,
    ServerType: std::cmp::PartialEq,
{
    // No client preference: whatever the server has (if anything) stands.
    let Some(requested) = client else {
        return Ok(());
    };
    // No server preference: take the client's value as-is.
    let Some(ours) = server else {
        client_out.add(key, requested.into());
        return Ok(());
    };

    // Both sides agree, so there is nothing to override.
    let requested_as_server = <ClientType as Into<ServerType>>::into(requested.clone());
    if requested_as_server == ours {
        return Ok(());
    }

    let verdict = resolve_conflict(requested.clone().into(), ours.into());
    match verdict {
        CombinationResponse::Server => {}
        CombinationResponse::Client => {
            client_out.add(key, requested.into());
        }
        CombinationResponse::Combined(merged) => {
            resolved_out.add(key, merged.into());
        }
        CombinationResponse::Failure(err) => return Err(err),
    }
    Ok(())
}
/// Conflict resolver for keys where `0` means "no preference".
///
/// Rules, in order:
/// - A client value of zero defers to the server. This also covers the case
///   where both values are zero.
/// - Otherwise, a server value of zero defers to the client.
/// - Otherwise, the smaller of the two values is used.
fn min_ignoring_zero(cli: u64, srv: u64) -> CombinationResponse<u64> {
    if cli == 0 {
        CombinationResponse::Server
    } else if srv == 0 {
        CombinationResponse::Client
    } else {
        CombinationResponse::Combined(cli.min(srv))
    }
}
#[cfg(doc)]
use crate::protocol::control::ClientMessageV1;
pub fn combine_bandwidth_configurations(
manager: &mut Manager,
client: &ClientMessageV2,
) -> Result<Configuration> {
use num_traits::AsPrimitive as _;
let server: Configuration_Optional = manager.get::<Configuration_Optional>()?;
let mut client_picks = config::Source::new(META_CLIENT);
let mut negotiated = config::Source::new(META_NEGOTIATED);
macro_rules! negotiate {
($cli:expr, $ser:expr, $resolve:expr, $key:expr) => {
negotiate_v3(
$cli,
$ser,
$resolve,
&mut client_picks,
&mut negotiated,
$key,
)
};
}
let ca = &client.attributes;
negotiate!(
ca.find_tag(ClientMessage2Attributes::BandwidthToServer)
.map(Variant::coerce_unsigned),
server.rx,
min_ignoring_zero,
"rx"
)?;
negotiate!(
ca.find_tag(ClientMessage2Attributes::BandwidthToClient)
.map(Variant::coerce_unsigned),
server.tx,
min_ignoring_zero,
"tx"
)?;
negotiate!(
ca.find_tag(ClientMessage2Attributes::RoundTripTime)
.map(|v| (v.coerce_unsigned() & 0xffff) as u16),
server.rtt,
|_: u16, _| CombinationResponse::Client,
"rtt"
)?;
let cctrl = ca
.find_tag(ClientMessage2Attributes::CongestionControllerType)
.map(Variant::coerce_unsigned)
.and_then(|v| CongestionController::from_repr(v.as_()));
negotiate!(
cctrl,
server.congestion,
|_: CongestionController, _| CombinationResponse::Failure(anyhow::anyhow!(
"server and client have incompatible congestion algorithm requirements"
)),
"congestion"
)?;
negotiate!(
ca.find_tag(ClientMessage2Attributes::InitialCongestionWindow)
.map(Variant::coerce_unsigned),
server.initial_congestion_window,
|_: u64, _| CombinationResponse::Server,
"initial_congestion_window"
)?;
let client_pr = match (
ca.find_tag(ClientMessage2Attributes::PortRangeStart)
.map(|v| v.coerce_unsigned().as_()),
ca.find_tag(ClientMessage2Attributes::PortRangeEnd)
.map(|v| v.coerce_unsigned().as_()),
) {
(None, None) => None,
(Some(it), None) | (None, Some(it)) => Some(PortRange { begin: it, end: it }),
(Some(begin), Some(end)) => Some(PortRange { begin, end }),
};
negotiate!(
client_pr,
server.port,
|a, b| crate::util::PortRange::combine(a, b)
.map_or_else(CombinationResponse::Failure, CombinationResponse::Combined),
"port"
)?;
negotiate!(
ca.find_tag(ClientMessage2Attributes::QuicTimeout)
.map(|v| (v.coerce_unsigned() & 0xffff) as u16),
server.timeout,
|_: u16, _| CombinationResponse::Client,
"timeout"
)?;
make_dict_human_friendly(client_picks.borrow());
make_dict_human_friendly(negotiated.borrow());
manager.merge_provider(client_picks);
manager.merge_provider(negotiated);
manager.apply_system_default();
manager.get::<Configuration>()?.validate()
}
/// Rewrites a numeric entry into an engineering-notation string.
///
/// If `entry` is occupied by a number that converts to `u128`, the number is
/// replaced with its `to_eng(0)` rendering. Any other value, or a vacant
/// entry, is left untouched.
fn make_entry_human_friendly(
    entry: std::collections::btree_map::Entry<'_, String, figment::value::Value>,
) {
    use engineering_repr::EngineeringRepr as _;
    use figment::value::Value;

    let _ = entry.and_modify(|slot| {
        let Value::Num(_, num) = slot else {
            return;
        };
        let Some(magnitude) = num.to_u128() else {
            return;
        };
        *slot = Value::from(magnitude.to_eng(0).to_string());
    });
}
/// Rewrites the bandwidth keys (`rx`, then `tx`) of `dict` into
/// engineering-notation strings, wherever they are present and numeric.
fn make_dict_human_friendly(dict: &mut figment::value::Dict) {
    for key in ["rx", "tx"] {
        make_entry_human_friendly(dict.entry(key.to_owned()));
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assertables::{assert_contains, assert_matches};
    use pretty_assertions::assert_eq;
    use crate::{
        Configuration,
        config::Manager,
        protocol::{
            DataTag, FindTag,
            control::{
                ClientMessage, ClientMessage2Attributes, Compatibility, ConnectionType,
                OUR_COMPATIBILITY_LEVEL, ServerMessage,
            },
        },
        transport::{Configuration_Optional, combine_bandwidth_configurations},
        util::{Credentials, PortRange},
    };
    use super::{ThroughputMode, create_config};

    /// Runs `create_config` at compatibility level 2 and returns the pretty
    /// `Debug` renderings of the transport config and of the congestion factory.
    /// Tests then assert on individual fields by substring.
    fn process_config(cfg: &Configuration, mode: ThroughputMode) -> (String, String) {
        let (tc, congestion) = create_config(cfg, mode, Compatibility::Level(2)).unwrap();
        (format!("{tc:#?}"), format!("{congestion:#?}"))
    }

    // Each mode configures only its own direction; `create_config` leaves the
    // other direction's windows at `TransportConfig::default()` values.
    #[test]
    fn modes() {
        let mut cfg = Configuration::system_default().clone();
        cfg.tx = 1_000_000;
        cfg.rx = cfg.tx;
        cfg.rtt = 456;
        // Rx: the receive window is configured; the send window is quinn's default.
        let (str, _) = process_config(&cfg, ThroughputMode::Rx);
        assert_contains!(str, "stream_receive_window: 456000");
        assert_contains!(str, "send_window: 10000000");
        // Tx: the send window is configured; the receive window is quinn's default.
        let (str, _) = process_config(&cfg, ThroughputMode::Tx);
        assert_contains!(str, "stream_receive_window: 1250000");
        assert_contains!(str, "send_window: 3648000"); }

    // Every selectable algorithm yields its own factory type, and a non-zero
    // initial congestion window is passed through to it.
    #[test]
    fn congestion_config() {
        use crate::protocol::control::CongestionController::*;
        let mut cfg = Configuration::system_default().clone();
        cfg.congestion = Cubic;
        cfg.initial_congestion_window = 1000;
        let (_, str) = process_config(&cfg, ThroughputMode::Both);
        assert_contains!(str, "CubicConfig");
        assert_contains!(str, "initial_window: 1000");
        cfg.congestion = Bbr;
        let (_, str) = process_config(&cfg, ThroughputMode::Both);
        assert_contains!(str, "BbrConfig");
        assert_contains!(str, "initial_window: 1000");
        cfg.congestion = NewReno;
        let (_, str) = process_config(&cfg, ThroughputMode::Both);
        assert_contains!(str, "NewRenoConfig");
        assert_contains!(str, "initial_window: 1000");
    }

    // NewReno must be refused when the peer's compatibility level (1 here)
    // lacks the feature.
    #[test]
    fn congestion_config_compat() {
        let mut cfg = Configuration::system_default().clone();
        cfg.congestion = crate::protocol::control::CongestionController::NewReno;
        let e = create_config(&cfg, ThroughputMode::Both, Compatibility::Level(1)).unwrap_err();
        eprintln!("{e}");
        assert_contains!(e.to_string(), "Remote host does not support NewReno");
    }

    // The server insists on NewReno and the client asks for Cubic, so
    // negotiation must fail rather than silently pick one.
    #[test]
    fn congestion_config_incompat() {
        let server_cfg = Configuration_Optional {
            congestion: Some(crate::protocol::control::CongestionController::NewReno),
            ..Default::default()
        };
        let mut mgr = Manager::without_files(None);
        mgr.merge_provider(server_cfg);
        let attributes = vec![
            ClientMessage2Attributes::CongestionControllerType
                .with_unsigned(crate::protocol::control::CongestionController::Cubic as u64),
        ];
        let mp = crate::protocol::control::ClientMessageV2 {
            attributes,
            ..Default::default()
        };
        let e = combine_bandwidth_configurations(&mut mgr, &mp).unwrap_err();
        assert_contains!(
            e.to_string(),
            "server and client have incompatible congestion algorithm requirements"
        );
    }

    // Mixed negotiation:
    // - port ranges [1000, 2000] (server) and [500, 1500] (client) combine to [1000, 1500];
    // - the client's RTT wins;
    // - the server's initial congestion window wins.
    #[test]
    fn negotiation() {
        let server_cfg = Configuration_Optional {
            port: Some(PortRange {
                begin: 1000,
                end: 2000,
            }),
            initial_congestion_window: Some(1000),
            ..Default::default()
        };
        let mut mgr = Manager::new(None, false, false);
        mgr.merge_provider(server_cfg);
        let attributes = vec![
            ClientMessage2Attributes::PortRangeStart.with_unsigned(500u64),
            ClientMessage2Attributes::PortRangeEnd.with_unsigned(1500u64),
            ClientMessage2Attributes::InitialCongestionWindow.with_unsigned(500u64),
            ClientMessage2Attributes::RoundTripTime.with_unsigned(1234u64),
        ];
        let mp = crate::protocol::control::ClientMessageV2 {
            attributes,
            ..Default::default()
        };
        let c = combine_bandwidth_configurations(&mut mgr, &mp).unwrap();
        assert_matches!(
            c,
            Configuration {
                port: PortRange {
                    begin: 1000,
                    end: 1500
                },
                rtt: 1234,
                initial_congestion_window: 1000,
                ..
            }
        );
    }

    // A zero value defers to the other side; otherwise the minimum is taken.
    #[test]
    fn test_min_ignoring_zero() {
        use super::{CombinationResponse, min_ignoring_zero};
        assert_matches!(min_ignoring_zero(0, 0), CombinationResponse::Server);
        assert_matches!(min_ignoring_zero(0, 100), CombinationResponse::Server);
        assert_matches!(min_ignoring_zero(100, 0), CombinationResponse::Client);
        assert_matches!(
            min_ignoring_zero(100, 200),
            CombinationResponse::Combined(100)
        );
        assert_matches!(
            min_ignoring_zero(200, 100),
            CombinationResponse::Combined(100)
        );
    }

    // Regression test for issue 169. With only `rx` configured, the assertions
    // below require `tx()` to report the same value. Both the client message and
    // the server message must then carry that value in both directions.
    #[test]
    fn issue169_bandwidth_config() {
        let server_cfg = Configuration_Optional {
            rx: Some(123_456),
            ..Default::default()
        };
        let mut mgr = Manager::without_files(None);
        mgr.merge_provider(server_cfg);
        let cfg = mgr.get::<Configuration>().unwrap();
        assert_eq!(cfg.rx(), 123_456);
        assert_eq!(cfg.tx(), 123_456);
        let creds = Credentials::generate().unwrap();
        let cert = creds.to_tagged_data(OUR_COMPATIBILITY_LEVEL, None).unwrap();
        let cfg_o = mgr.get::<Configuration_Optional>().unwrap();
        let cmsg = ClientMessage::new(
            OUR_COMPATIBILITY_LEVEL,
            cert.clone(),
            ConnectionType::Ipv4,
            false,
            &cfg_o,
        );
        if let ClientMessage::V2(cmsg) = cmsg {
            let tx = cmsg
                .attributes
                .find_tag(ClientMessage2Attributes::BandwidthToServer)
                .unwrap();
            let rx = cmsg
                .attributes
                .find_tag(ClientMessage2Attributes::BandwidthToClient)
                .unwrap();
            assert_eq!(rx.coerce_unsigned(), 123_456);
            assert_eq!(tx.coerce_unsigned(), 123_456);
        } else {
            panic!();
        }
        let smsg = ServerMessage::new(
            OUR_COMPATIBILITY_LEVEL,
            &cfg,
            1234,
            cert,
            "test".into(),
            String::new(),
        );
        if let ServerMessage::V2(smsg) = smsg {
            let tx = smsg.bandwidth_to_client.0;
            let rx = smsg.bandwidth_to_server.0;
            assert_eq!(rx, 123_456);
            assert_eq!(tx, 123_456);
        } else {
            panic!();
        }
    }

    // With no bandwidth configured, the client message must omit the bandwidth attributes.
    #[test]
    fn issue169_dont_send_default_bandwidth_in_client_message() {
        let mgr = Manager::without_default(None);
        let creds = Credentials::generate().unwrap();
        let cert = creds.to_tagged_data(OUR_COMPATIBILITY_LEVEL, None).unwrap();
        let cfg_o = mgr.get::<Configuration_Optional>().unwrap();
        let cmsg = ClientMessage::new(
            OUR_COMPATIBILITY_LEVEL,
            cert.clone(),
            ConnectionType::Ipv4,
            false,
            &cfg_o,
        );
        if let ClientMessage::V2(cmsg) = cmsg {
            assert!(
                cmsg.attributes
                    .find_tag(ClientMessage2Attributes::BandwidthToServer)
                    .is_none()
            );
            assert!(
                cmsg.attributes
                    .find_tag(ClientMessage2Attributes::BandwidthToClient)
                    .is_none()
            );
        } else {
            panic!();
        }
    }

    // Explicitly configured bandwidth is sent even when it equals the system default.
    #[test]
    fn issue169_curveball_do_send_default_bandwidth_if_explicitly_set() {
        let def = Configuration::system_default();
        let cfg = Configuration_Optional {
            rx: Some(def.rx()),
            tx: Some(def.tx()),
            ..Default::default()
        };
        let mut mgr = Manager::without_files(None);
        mgr.merge_provider(cfg);
        let creds = Credentials::generate().unwrap();
        let cert = creds.to_tagged_data(OUR_COMPATIBILITY_LEVEL, None).unwrap();
        let cfg_o = mgr.get::<Configuration_Optional>().unwrap();
        let cmsg = ClientMessage::new(
            OUR_COMPATIBILITY_LEVEL,
            cert.clone(),
            ConnectionType::Ipv4,
            false,
            &cfg_o,
        );
        if let ClientMessage::V2(cmsg) = cmsg {
            assert_eq!(
                cmsg.attributes
                    .find_tag(ClientMessage2Attributes::BandwidthToServer)
                    .unwrap()
                    .coerce_unsigned(),
                def.tx(),
            );
            assert_eq!(
                cmsg.attributes
                    .find_tag(ClientMessage2Attributes::BandwidthToClient)
                    .unwrap()
                    .coerce_unsigned(),
                def.rx(),
            );
        } else {
            panic!();
        }
    }

    // Each direction takes the smaller non-zero value:
    // - server tx = min(client BandwidthToClient 987_654, server 333_444);
    // - server rx = min(client BandwidthToServer 123_456, server 222_111).
    #[test]
    fn bandwidth_negotiation() {
        let server_cfg = Configuration_Optional {
            rx: Some(222_111),
            tx: Some(333_444),
            ..Default::default()
        };
        let mut mgr = Manager::new(None, false, false);
        mgr.merge_provider(server_cfg);
        // First attribute negotiates the server's tx, the second its rx.
        let attributes = vec![
            ClientMessage2Attributes::BandwidthToClient.with_unsigned(987_654u32), ClientMessage2Attributes::BandwidthToServer.with_unsigned(123_456u32), ];
        let cmsg = crate::protocol::control::ClientMessageV2 {
            attributes,
            ..Default::default()
        };
        let c = combine_bandwidth_configurations(&mut mgr, &cmsg).unwrap();
        assert_eq!(c.tx, 333_444);
        assert_eq!(c.rx, 123_456);
    }
}