use super::{CongestionController, ConnectionType, display_opt, display_opt_uint};
use crate::protocol::prelude::*;
use crate::{
config::Configuration_Optional,
transport::ThroughputMode,
util::{PortRange, serialization::SerializeEnumAsString},
};
use engineering_repr::EngineeringRepr as _;
use int_enum::IntEnum;
use num_traits::AsPrimitive;
/// Top-level client message, versioned as an enum.
///
/// The message is serialized with BARE, where the variant index is the leading byte on the
/// wire (`V1` encodes as `0x01`, `V2` as `0x02` in the wire-marshalling tests). Variants must
/// therefore never be reordered or removed.
#[derive(
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Debug,
    Default,
    derive_more::Display,
    derive_more::From,
)]
pub enum ClientMessage {
    /// Placeholder default value; it is never serialized (`skip_serializing`).
    #[default]
    #[serde(skip_serializing)]
    ToFollow,
    /// Original fixed-field message; chosen by `ClientMessage::new` when the peer does not
    /// support `Feature::CMSG_SMSG_2`.
    V1(ClientMessageV1),
    /// Attribute-based message; chosen by `ClientMessage::new` when the peer supports
    /// `Feature::CMSG_SMSG_2`.
    V2(ClientMessageV2),
}
// Marker impl: the tests call `to_vec` / `from_slice` on this type, presumably provided by
// `ProtocolMessage`'s default methods (NOTE(review): confirm in the prelude).
impl ProtocolMessage for ClientMessage {}
/// Version 1 of the client message: one field per configuration setting.
///
/// Field order is the BARE wire layout and must not change. The trailing `attributes`
/// vector was appended after the original layout. A decoder that only knows the original
/// layout (see the test-only `OriginalClientMessageV1`) reads the attribute count where it
/// expects its final `extension: u8`.
#[derive(
    Clone, Serialize, Deserialize, PartialEq, Default, derive_more::Debug, derive_more::Display,
)]
#[display(
    "{connection_type}{}{}{}{}{}{}{}, attributes {}",
    display_opt("remote port", port.as_ref()),
    display_opt_uint("bw to client", bandwidth_to_client.as_ref()),
    display_opt_uint("bw to server", bandwidth_to_server.as_ref()),
    display_opt("RTT", rtt.as_ref()),
    display_opt("congestion algorithm ", congestion.as_ref()),
    display_opt_uint("cwnd ", initial_congestion_window.as_ref()),
    display_opt("timeout", timeout.as_ref()),
    display_vec_td(attributes),
)]
pub struct ClientMessageV1 {
    /// Client certificate bytes; omitted from `Debug` output. Converted to X.509
    /// credentials when upgrading to V2.
    #[debug(ignore)]
    pub cert: Vec<u8>,
    /// Connection type requested by the client.
    pub connection_type: ConnectionType,
    /// Requested remote port range (filled from the config's `remote_port`).
    pub port: Option<PortRange>,
    /// Whether the server should output its configuration (filled from `remote_config`).
    pub show_config: bool,
    /// Filled from the config's `tx`; left `None` when `tx` is unset or `0`.
    pub bandwidth_to_server: Option<Uint>,
    /// Filled from the config's `rx`.
    pub bandwidth_to_client: Option<Uint>,
    /// Filled from the config's `rtt` (units: TODO confirm).
    pub rtt: Option<u16>,
    /// Requested congestion controller.
    pub congestion: Option<CongestionController>,
    /// Filled from the config's `initial_congestion_window`.
    pub initial_congestion_window: Option<Uint>,
    /// Filled from the config's `timeout` (units: TODO confirm).
    pub timeout: Option<u16>,
    /// Tagged extension attributes (currently only `DirectionOfTravel` is used).
    pub attributes: Vec<TaggedData<ClientMessageAttributes>>,
}
/// Kind of credentials carried in a `ClientMessageV2`, used as the tag of its
/// `TaggedData` credentials field.
///
/// Discriminants are explicit (`repr(u64)`, starting at 0) and must stay stable.
/// Parsing from strings (strum / clap / enumscribe) is case-insensitive.
#[derive(
    strum_macros::Display,
    Clone,
    Copy,
    Debug,
    Default,
    IntEnum,
    PartialEq,
    Serialize,
    Deserialize,
    clap::ValueEnum,
    strum::AsRefStr,
    strum_macros::EnumString,
    strum_macros::VariantNames,
    enumscribe::TryUnscribe,
)]
#[non_exhaustive]
#[repr(u64)]
#[strum(ascii_case_insensitive)]
#[value(rename_all = "lower")]
#[enumscribe(case_insensitive)]
pub enum CredentialsType {
    /// Discriminant 0. NOTE(review): presumably "no preference" — confirm.
    Any = 0,
    /// X.509 certificate; V1 certificates are converted to this type.
    X509,
    /// Raw public key; this is the `Default` value.
    #[default]
    RawPublicKey,
}
// Allows `CredentialsType` to tag `TaggedData` values (e.g. `CredentialsType::X509.with_bytes(..)`).
impl DataTag for CredentialsType {}
impl SerializeEnumAsString for CredentialsType {}
/// Version 2 of the client message: credentials plus a list of tagged attributes.
///
/// Field order is the BARE wire layout and must not change.
#[derive(
    Clone, Default, Serialize, Deserialize, PartialEq, derive_more::Debug, derive_more::Display,
)]
#[display("{connection_type}, attributes {}", display_vec_td(attributes))]
pub struct ClientMessageV2 {
    /// Client credentials tagged with their kind; omitted from `Debug` output.
    #[debug(ignore)]
    pub credentials: TaggedData<CredentialsType>,
    /// Connection type requested by the client.
    pub connection_type: ConnectionType,
    /// Requested settings, one tagged attribute each.
    pub attributes: Vec<TaggedData<ClientMessage2Attributes>>,
    /// Always 0 when constructed in this file; presumably reserved for future
    /// extension (NOTE(review): confirm).
    pub extension: u8,
}
impl ClientMessageV2 {
    /// Creates a V2 message holding `credentials` and `connection_type`.
    ///
    /// The message starts with no attributes and a zero extension byte; callers add
    /// attributes afterwards (e.g. via `apply_config_attributes`).
    fn new(credentials: TaggedData<CredentialsType>, connection_type: ConnectionType) -> Self {
        let attributes = vec![];
        Self {
            connection_type,
            credentials,
            attributes,
            extension: 0,
        }
    }
}
impl ClientMessageV2 {
    /// Appends one attribute to `self.attributes` for each setting the client wants to send.
    ///
    /// * `remote_config` — when true, adds `ClientMessage2Attributes::OutputConfig`.
    /// * `our_config` — each populated field becomes one attribute:
    ///   `remote_port` → `PortRangeStart` + `PortRangeEnd`, `tx` → `BandwidthToServer`,
    ///   `rx` → `BandwidthToClient`, `rtt` → `RoundTripTime`,
    ///   `congestion` → `CongestionControllerType`,
    ///   `initial_congestion_window` → `InitialCongestionWindow`, `timeout` → `QuicTimeout`.
    ///
    /// A `tx` of `None` or `0` is treated as unset, so no `BandwidthToServer` attribute is
    /// sent. This matches `ClientMessageV1::new`, and therefore the V1 → V2 conversion.
    /// The server thus receives the same request whichever protocol version is negotiated.
    pub(crate) fn apply_config_attributes(
        &mut self,
        remote_config: bool,
        our_config: &Configuration_Optional,
    ) {
        if remote_config {
            self.attributes
                .push(ClientMessage2Attributes::OutputConfig.into());
        }
        if let Some(pr) = our_config.remote_port {
            self.attributes
                .push(ClientMessage2Attributes::PortRangeStart.with_unsigned(pr.begin));
            self.attributes
                .push(ClientMessage2Attributes::PortRangeEnd.with_unsigned(pr.end));
        }
        // `tx == 0` means "unset" (same rule as ClientMessageV1::new); previously this
        // path substituted `rx` here, diverging from the V1 encoding of the same config.
        if let Some(tx) = our_config.tx.filter(|&bw| bw != 0) {
            self.attributes
                .push(ClientMessage2Attributes::BandwidthToServer.with_unsigned(tx));
        }
        if let Some(rx) = our_config.rx {
            self.attributes
                .push(ClientMessage2Attributes::BandwidthToClient.with_unsigned(rx));
        }
        if let Some(rtt) = our_config.rtt {
            self.attributes
                .push(ClientMessage2Attributes::RoundTripTime.with_unsigned(rtt));
        }
        if let Some(cc) = our_config.congestion {
            // Sent as the enum discriminant; decoded with `CongestionController::from_repr`.
            self.attributes
                .push(ClientMessage2Attributes::CongestionControllerType.with_unsigned(cc as u64));
        }
        if let Some(icw) = our_config.initial_congestion_window {
            self.attributes
                .push(ClientMessage2Attributes::InitialCongestionWindow.with_unsigned(icw));
        }
        if let Some(t) = our_config.timeout {
            self.attributes
                .push(ClientMessage2Attributes::QuicTimeout.with_unsigned(t));
        }
    }
}
impl From<ClientMessageV1> for ClientMessageV2 {
fn from(v1: ClientMessageV1) -> Self {
let mut attributes = Vec::new();
if let Some(pr) = v1.port {
attributes.push(ClientMessage2Attributes::PortRangeStart.with_unsigned(pr.begin));
attributes.push(ClientMessage2Attributes::PortRangeEnd.with_unsigned(pr.end));
}
if v1.show_config {
attributes.push(ClientMessage2Attributes::OutputConfig.into());
}
if let Some(Uint(bw)) = v1.bandwidth_to_server {
attributes.push(ClientMessage2Attributes::BandwidthToServer.with_unsigned(bw));
}
if let Some(Uint(bw)) = v1.bandwidth_to_client {
attributes.push(ClientMessage2Attributes::BandwidthToClient.with_unsigned(bw));
}
if let Some(rtt) = v1.rtt {
attributes.push(ClientMessage2Attributes::RoundTripTime.with_unsigned(rtt));
}
if let Some(cc) = v1.congestion {
attributes
.push(ClientMessage2Attributes::CongestionControllerType.with_unsigned(cc as u64));
}
if let Some(Uint(icw)) = v1.initial_congestion_window {
attributes.push(ClientMessage2Attributes::InitialCongestionWindow.with_unsigned(icw));
}
if let Some(t) = v1.timeout {
attributes.push(ClientMessage2Attributes::QuicTimeout.with_unsigned(t));
}
if let Some(v) = v1
.attributes
.find_tag(ClientMessageAttributes::DirectionOfTravel)
{
attributes.push(ClientMessage2Attributes::DirectionOfTravel.with_variant(v.clone()));
}
Self {
credentials: CredentialsType::X509.with_bytes(v1.cert),
connection_type: v1.connection_type,
attributes,
extension: 0,
}
}
}
/// Test-only snapshot of the client message as it existed before `attributes` was added
/// to V1.
///
/// It is used to check that a current V1 message still decodes with the original layout
/// (`clientmessagev1_attrs_backwards_compat`).
#[cfg(test)]
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug, Default)]
pub(super) enum OriginalClientMessage {
    #[default]
    #[serde(skip_serializing)]
    ToFollow,
    V1(OriginalClientMessageV1),
}
#[cfg(test)]
impl ProtocolMessage for OriginalClientMessage {}
/// Test-only snapshot of the original V1 field layout.
///
/// Identical to `ClientMessageV1` up to `timeout`. The final `extension` byte is where a
/// current encoder writes the `attributes` count.
#[cfg(test)]
#[derive(Clone, Serialize, Deserialize, PartialEq, Default, derive_more::Debug)]
pub(super) struct OriginalClientMessageV1 {
    pub(super) cert: Vec<u8>,
    pub(super) connection_type: ConnectionType,
    pub(super) port: Option<PortRange>,
    pub(super) show_config: bool,
    pub(super) bandwidth_to_server: Option<Uint>,
    pub(super) bandwidth_to_client: Option<Uint>,
    pub(super) rtt: Option<u16>,
    pub(super) congestion: Option<CongestionController>,
    pub(super) initial_congestion_window: Option<Uint>,
    pub(super) timeout: Option<u16>,
    pub(super) extension: u8,
}
/// Tags for the extension attributes of a `ClientMessageV1`.
///
/// Discriminants (`repr(u64)`, starting at 0) are written on the wire and must stay stable.
#[derive(strum_macros::Display, Clone, Copy, Debug, IntEnum, PartialEq)]
#[non_exhaustive]
#[repr(u64)]
pub enum ClientMessageAttributes {
    /// Discriminant 0; not produced by the code in this file.
    Invalid = 0,
    /// Data is a `Direction` discriminant (see `ClientMessage::set_direction`).
    DirectionOfTravel,
}
impl DataTag for ClientMessageAttributes {
    /// Formats attribute data for `Debug` output.
    ///
    /// A `DirectionOfTravel` value is shown as its `Direction` name, and unrecognised
    /// values fall back to `Both`. Every other tag shows the raw `Variant` in `Debug` form.
    fn debug_data(&self, data: &Variant) -> String {
        if !matches!(self, ClientMessageAttributes::DirectionOfTravel) {
            return format!("{data:?}");
        }
        Direction::from_repr(data.coerce_unsigned().as_())
            .unwrap_or(Direction::Both)
            .to_string()
    }
}
/// Tags for the attributes of a `ClientMessageV2`.
///
/// Discriminants (`repr(u64)`, starting at 0) are written on the wire, so new tags must
/// only ever be appended.
#[derive(strum_macros::Display, Clone, Copy, Debug, IntEnum, PartialEq)]
#[non_exhaustive]
#[repr(u64)]
pub enum ClientMessage2Attributes {
    /// Discriminant 0; not produced by the config-building code in this file.
    Invalid = 0,
    /// Data is a `Direction` discriminant.
    DirectionOfTravel,
    /// First port of the requested remote port range.
    PortRangeStart,
    /// Last port of the requested remote port range.
    PortRangeEnd,
    /// No data; asks the server to output its configuration.
    OutputConfig,
    /// Client's `tx` bandwidth (units: TODO confirm).
    BandwidthToServer,
    /// Client's `rx` bandwidth (units: TODO confirm).
    BandwidthToClient,
    /// Client's `rtt` setting.
    RoundTripTime,
    /// Data is a `CongestionController` discriminant.
    CongestionControllerType,
    /// Client's `initial_congestion_window` setting.
    InitialCongestionWindow,
    /// Client's `timeout` setting.
    QuicTimeout,
}
impl DataTag for ClientMessage2Attributes {
    /// Formats attribute data for `Debug`/`Display` output.
    ///
    /// * Bandwidths are rendered in engineering notation with 4 significant figures
    ///   (e.g. `123.4M`).
    /// * `DirectionOfTravel` is shown as its `Direction` name; unknown values show `Both`.
    /// * `CongestionControllerType` is shown as the controller name; unknown values show
    ///   the controller's default.
    /// * Anything else shows the raw `Variant` in `Debug` form.
    fn debug_data(&self, data: &Variant) -> String {
        match *self {
            ClientMessage2Attributes::BandwidthToServer
            | ClientMessage2Attributes::BandwidthToClient => {
                data.coerce_unsigned().to_eng(4).to_string()
            }
            ClientMessage2Attributes::CongestionControllerType => {
                let controller =
                    CongestionController::from_repr(data.coerce_unsigned().as_())
                        .unwrap_or_default();
                controller.to_string()
            }
            ClientMessage2Attributes::DirectionOfTravel => {
                let direction = Direction::from_repr(data.coerce_unsigned().as_())
                    .unwrap_or(Direction::Both);
                direction.to_string()
            }
            _ => format!("{data:?}"),
        }
    }
}
/// Direction of data travel for a session.
///
/// Variants use implicit discriminants (`Both` = 0, `ClientToServer` = 1,
/// `ServerToClient` = 2). These values are sent as attribute data, so the order must
/// not change.
#[derive(
    strum_macros::Display, Clone, Copy, Debug, PartialEq, Eq, strum_macros::FromRepr, Default,
)]
#[allow(missing_docs)]
pub enum Direction {
    /// Default; also the fallback when decoding an unknown value.
    #[default]
    Both,
    ClientToServer,
    ServerToClient,
}
/// Encodes a direction as its discriminant in an unsigned `Variant`.
impl From<Direction> for Variant {
    fn from(direction: Direction) -> Self {
        let discriminant = direction as u64;
        Variant::unsigned(discriminant)
    }
}
/// Decodes a direction from a `Variant`, falling back to `Direction::default()` (`Both`)
/// for unrecognised values.
impl From<&Variant> for Direction {
    fn from(variant: &Variant) -> Self {
        match Direction::from_repr(variant.coerce_unsigned().as_()) {
            Some(direction) => direction,
            None => Direction::default(),
        }
    }
}
/// Decodes an optional `Variant`; an absent value yields `Direction::default()` (`Both`).
impl From<Option<&Variant>> for Direction {
    fn from(maybe_variant: Option<&Variant>) -> Self {
        match maybe_variant {
            Some(variant) => Direction::from(variant),
            None => Direction::default(),
        }
    }
}
impl Direction {
    /// Throughput mode for the server side.
    ///
    /// Data travelling client → server is received by the server (`Rx`), and
    /// server → client data is sent by it (`Tx`).
    pub(crate) fn server_mode(self) -> ThroughputMode {
        match self {
            Self::Both => ThroughputMode::Both,
            Self::ClientToServer => ThroughputMode::Rx,
            Self::ServerToClient => ThroughputMode::Tx,
        }
    }

    /// Throughput mode for the client side.
    ///
    /// The client's role is the mirror image of the server's, so this is the server mode
    /// of the reversed direction.
    pub(crate) fn client_mode(self) -> ThroughputMode {
        let mirrored = match self {
            Self::Both => Self::Both,
            Self::ClientToServer => Self::ServerToClient,
            Self::ServerToClient => Self::ClientToServer,
        };
        mirrored.server_mode()
    }
}
impl ClientMessage {
    /// Builds the client message appropriate for the negotiated protocol level.
    ///
    /// * If the peer supports `Feature::CMSG_SMSG_2`, builds a V2 message carrying `cert`
    ///   and one attribute per setting in `my_config`.
    /// * Otherwise, builds a V1 message from the raw certificate bytes.
    ///
    /// # Panics
    /// If `cert` does not carry byte data.
    pub(crate) fn new(
        compat: Compatibility,
        cert: TaggedData<CredentialsType>,
        connection_type: ConnectionType,
        remote_config: bool,
        my_config: &Configuration_Optional,
    ) -> Self {
        assert!(cert.data.is_bytes());
        if !compat.supports(Feature::CMSG_SMSG_2) {
            let raw_cert = cert.data.into_bytes().unwrap_or_default();
            return Self::V1(ClientMessageV1::new(
                &raw_cert,
                connection_type,
                remote_config,
                my_config,
            ));
        }
        let mut v2 = ClientMessageV2::new(cert, connection_type);
        v2.apply_config_attributes(remote_config, my_config);
        Self::V2(v2)
    }

    /// Appends a `DirectionOfTravel` attribute (the direction's discriminant) to the
    /// message. A `ToFollow` placeholder is left untouched.
    pub(crate) fn set_direction(&mut self, direction: Direction) {
        let code = direction as u64;
        match self {
            Self::V1(v1) => v1
                .attributes
                .push(ClientMessageAttributes::DirectionOfTravel.with_unsigned(code)),
            Self::V2(v2) => v2
                .attributes
                .push(ClientMessage2Attributes::DirectionOfTravel.with_unsigned(code)),
            Self::ToFollow => {}
        }
    }
}
impl ClientMessageV1 {
    /// Builds a V1 message from the certificate bytes and the client's configuration.
    ///
    /// `remote_config` becomes `show_config`, and `remote_port` becomes `port`. A `tx` of
    /// `None` or `0` leaves `bandwidth_to_server` unset. `attributes` starts empty.
    pub(super) fn new(
        cert: &[u8],
        connection_type: ConnectionType,
        remote_config: bool,
        my_config: &Configuration_Optional,
    ) -> Self {
        let bandwidth_to_server = my_config.tx.filter(|&bw| bw != 0).map(Uint);
        let bandwidth_to_client = my_config.rx.map(Uint);
        let initial_congestion_window = my_config.initial_congestion_window.map(Uint);
        Self {
            cert: cert.to_vec(),
            connection_type,
            port: my_config.remote_port,
            show_config: remote_config,
            bandwidth_to_server,
            bandwidth_to_client,
            rtt: my_config.rtt,
            congestion: my_config.congestion,
            initial_congestion_window,
            timeout: my_config.timeout,
            attributes: Vec::new(),
        }
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use crate::protocol::prelude::*;
    use assertables::{assert_contains, assert_matches};
    use pretty_assertions::{assert_eq, assert_str_eq};
    use serde_bare::Uint;

    use crate::{
        config::{Configuration_Optional, Manager},
        protocol::control::{
            ClientMessage2Attributes, ClientMessageAttributes, ClientMessageV1, ClientMessageV2,
            Compatibility, CongestionController, ConnectionType, CredentialsType, Direction,
            OriginalClientMessage, OriginalClientMessageV1,
            test::{dummy_cert, dummy_credentials},
        },
        util::PortRange,
    };

    use super::ClientMessage;

    #[test]
    fn serialize_client_message() {
        let config = Configuration_Optional {
            tx: Some(42),
            rx: Some(89),
            rtt: Some(1234),
            congestion: Some(CongestionController::Bbr),
            udp_buffer: Some(456_789),
            initial_congestion_window: Some(12345),
            port: Some(PortRange { begin: 17, end: 98 }),
            remote_port: Some(PortRange {
                begin: 123,
                end: 456,
            }),
            remote_user: None,
            timeout: Some(432),
            packet_threshold: None,
            time_threshold: None,
            initial_mtu: None,
            min_mtu: None,
            max_mtu: None,
            address_family: None,
            ssh: None,
            remote_qcp_binary: None,
            ssh_options: None,
            time_format: None,
            ssh_config: None,
            ssh_subsystem: None,
            color: None,
            tls_auth_type: None,
            aes256: None,
            io_buffer_size: None,
        };
        let cmsg = {
            let cert = CredentialsType::X509.with_bytes(dummy_cert());
            let mut manager = Manager::without_default(None);
            manager.merge_provider(&config);
            let cfg = manager.get::<Configuration_Optional>().unwrap();
            ClientMessage::new(
                Compatibility::Level(1),
                cert,
                ConnectionType::Ipv4,
                false,
                &cfg,
            )
        };
        let ser = cmsg.to_vec().unwrap();
        println!("{cmsg:#?}");
        println!("vec: {ser:?}");
        let deser = ClientMessage::from_slice(&ser).unwrap();
        let disp = format!("{cmsg}");
        eprintln!("{disp}");
        assert!(disp.contains("123-456"));
        // An identifier inside a pattern always introduces a fresh binding; it never compares
        // against an outer variable of the same name. So emptiness of `attributes` must be
        // asserted explicitly rather than by naming a vector in the pattern below.
        match &deser {
            ClientMessage::V1(v1) => assert!(v1.attributes.is_empty()),
            other => panic!("expected ClientMessage::V1, got {other:?}"),
        }
        assert_matches!(
            deser,
            ClientMessage::V1(ClientMessageV1 {
                cert: _,
                connection_type: ConnectionType::Ipv4,
                port: Some(PortRange {
                    begin: 123,
                    end: 456
                }),
                show_config: false,
                bandwidth_to_server: Some(Uint(42)),
                bandwidth_to_client: Some(Uint(89)),
                rtt: Some(1234),
                congestion: Some(CongestionController::Bbr),
                initial_congestion_window: Some(Uint(12345)),
                timeout: Some(432),
                attributes: _,
            })
        );
    }

    #[test]
    fn construct_client_message() {
        let cert = CredentialsType::X509.with_bytes(dummy_cert());
        let mut manager = Manager::without_default(None);
        let config = Configuration_Optional::default();
        manager.merge_provider(&config);
        let cfg = manager.get::<Configuration_Optional>().unwrap();
        let cmsg = ClientMessage::new(
            Compatibility::Level(1),
            cert,
            ConnectionType::Ipv4,
            false,
            &cfg,
        );
        assert_matches!(
            cmsg,
            ClientMessage::V1(ClientMessageV1 {
                bandwidth_to_server: None,
                ..
            })
        );
    }

    #[test]
    fn wire_marshalling_client_message_v1() {
        let cert = dummy_cert();
        let msg = ClientMessage::V1(ClientMessageV1::new(
            &cert,
            ConnectionType::Ipv4,
            false,
            &Configuration_Optional::default(),
        ));
        let wire = msg.to_vec().unwrap();
        let expected = b"\x01\x03\x00\x01\x02\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00".to_vec();
        assert_eq!(wire, expected);
    }

    #[test]
    fn wire_marshalling_client_message_v2() {
        let cert = dummy_credentials();
        let mut msg2 = ClientMessageV2::new(cert, ConnectionType::Ipv6);
        msg2.attributes = vec![
            ClientMessage2Attributes::DirectionOfTravel
                .with_unsigned(Direction::ClientToServer as u64),
            ClientMessage2Attributes::BandwidthToClient.with_unsigned(123_456u32),
        ];
        let msg = ClientMessage::V2(msg2);
        let wire = msg.to_vec().unwrap();
        let expected =
            b"\x02\x01\x05\x03\x00\x01\x02\x06\x02\x01\x03\x01\x06\x03\xc0\xc4\x07\x00".to_vec();
        assert_eq!(wire, expected);
    }

    #[test]
    fn display_clientmessage_attrs() {
        let d = ClientMessageAttributes::DirectionOfTravel
            .with_variant(Direction::ClientToServer.into());
        let cm: ClientMessage = ClientMessage::V1(ClientMessageV1 {
            attributes: vec![d.clone()],
            ..Default::default()
        });
        let s = format!("{d:?}");
        eprintln!("{s}");
        assert_str_eq!(
            s,
            "TaggedData { tag: ClientMessageAttributes::DirectionOfTravel, data: ClientToServer, .. }"
        );
        let s = format!("{cm:?}");
        eprintln!("{s}");
        assert!(s.contains("ClientMessageAttributes::DirectionOfTravel, data: ClientToServer"));
        let s = display_vec_td(&vec![d.clone()]);
        eprintln!("{s}");
        assert_str_eq!(s, "[DirectionOfTravel:ClientToServer]");
        let s = format!("{d}");
        eprintln!("{s}");
        assert_str_eq!(s, "(DirectionOfTravel, ClientToServer)");
        let s = format!("{cm}");
        eprintln!("{s}");
        assert!(s.contains("[DirectionOfTravel:ClientToServer]"));
    }

    #[test]
    fn clientmessagev1_attrs_backwards_compat() {
        let d = ClientMessageAttributes::DirectionOfTravel
            .with_variant(Direction::ClientToServer.into());
        let cm = ClientMessage::V1(ClientMessageV1 {
            attributes: vec![d.clone()],
            ..Default::default()
        });
        let wire = cm.to_vec().unwrap();
        let decode = OriginalClientMessage::from_slice(&wire).unwrap();
        assert_eq!(
            decode,
            OriginalClientMessage::V1(OriginalClientMessageV1 {
                cert: vec![],
                connection_type: ConnectionType::Ipv4,
                port: None,
                show_config: false,
                bandwidth_to_server: None,
                bandwidth_to_client: None,
                rtt: None,
                congestion: None,
                initial_congestion_window: None,
                timeout: None,
                // The original decoder reads the attribute count as its extension byte.
                extension: 1,
            })
        );
    }

    #[test]
    fn client_message_2_debug_attrs() {
        let msg = ClientMessageV2 {
            attributes: vec![
                ClientMessage2Attributes::Invalid.into(),
                ClientMessage2Attributes::DirectionOfTravel
                    .with_unsigned(Direction::ClientToServer as u64),
                ClientMessage2Attributes::CongestionControllerType
                    .with_unsigned(CongestionController::NewReno as u64),
                ClientMessage2Attributes::OutputConfig.into(),
            ],
            ..Default::default()
        };
        let s = format!("{msg:?}");
        assert!(s.contains("Invalid, data: Empty"));
        assert!(s.contains("DirectionOfTravel, data: ClientToServer"));
        assert!(s.contains("CongestionControllerType, data: newreno"));
        assert!(s.contains("OutputConfig, data: Empty"));
    }

    #[test]
    fn client_message_2_display() {
        let msg = ClientMessageV2 {
            attributes: vec![
                ClientMessage2Attributes::BandwidthToClient.with_unsigned(123_456_789u32),
                ClientMessage2Attributes::BandwidthToServer.with_unsigned(32_768u32),
                ClientMessage2Attributes::Invalid.into(),
                ClientMessage2Attributes::DirectionOfTravel
                    .with_unsigned(Direction::ClientToServer as u64),
                ClientMessage2Attributes::CongestionControllerType
                    .with_unsigned(CongestionController::NewReno as u64),
                ClientMessage2Attributes::OutputConfig.into(),
            ],
            ..Default::default()
        };
        let s = format!("{msg}");
        eprintln!("{s}");
        assert_contains!(s, "Invalid:Empty");
        assert_contains!(s, "BandwidthToClient:123.4M");
        assert_contains!(s, "BandwidthToServer:32.76k");
        assert_contains!(s, "DirectionOfTravel:ClientToServer");
        assert_contains!(s, "CongestionControllerType:newreno");
        assert_contains!(s, "OutputConfig:Empty");
    }
}