use crate::util::{SerializeEnumAsString, ToStringForFigment};
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use serde_bare::Uint;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::net::{IpAddr, SocketAddr};
/// Current server banner string (`qcp-server-2`), newline-terminated.
///
/// NOTE(review): presumably emitted by the server when a session starts and
/// checked by the client — confirm against the greeting code.
pub const BANNER: &str = "qcp-server-2\n";
/// Banner string of the previous generation (`qcp-server-1`), newline-terminated.
///
/// NOTE(review): presumably kept so that older peers can still be recognised — confirm.
pub const OLD_BANNER: &str = "qcp-server-1\n";
/// Numeric compatibility level implemented by this build.
///
/// `Compatibility::from(u16)` maps any wire value greater than this to
/// [`Compatibility::Newer`].
pub(crate) const OUR_COMPATIBILITY_NUMERIC: u16 = 4;
/// [`OUR_COMPATIBILITY_NUMERIC`] wrapped as a [`Compatibility::Level`].
pub const OUR_COMPATIBILITY_LEVEL: Compatibility = Compatibility::Level(OUR_COMPATIBILITY_NUMERIC);
mod client_msg;
pub use client_msg::*;
mod server_msg;
pub use server_msg::*;
mod greetings;
pub use greetings::*;
mod closedown;
pub use closedown::*;
use engineering_repr::EngineeringQuantity as EQ;
/// Renders an optional [`Uint`] as `", {label}: {value}"`.
///
/// The number is formatted through `EngineeringQuantity<u64>`. When the value is
/// absent, an empty string is returned so the result can be appended unconditionally.
fn display_opt_uint(label: &str, bandwidth: Option<&Uint>) -> String {
    match bandwidth {
        Some(raw) => {
            let quantity = EQ::<u64>::from(raw.0);
            format!(", {label}: {quantity}")
        }
        None => String::new(),
    }
}
/// Renders an optional value as `", {label}: {value}"`.
///
/// Returns an empty string when `value` is `None`, so the result can be appended
/// unconditionally. `T` may be unsized (e.g. `str`), so `Option<&str>` is accepted
/// as well as references to sized types.
fn display_opt<T: std::fmt::Display + ?Sized>(label: &str, value: Option<&T>) -> String {
    // `Option<&T>` is `Copy` and can be consumed directly; an `as_ref()` here would
    // only add a second layer of reference (`Option<&&T>`).
    value.map_or_else(String::new, |v| format!(", {label}: {v}"))
}
/// Protocol compatibility level.
///
/// `From<u16>` maps wire values above [`OUR_COMPATIBILITY_NUMERIC`] to
/// [`Compatibility::Newer`] and everything else to [`Compatibility::Level`].
/// `From<Compatibility> for u16` maps `Unknown` and `Newer` to `0`.
///
/// NOTE(review): the serde attributes and the variant order influence the
/// serialized form; changing either may change the wire encoding.
#[derive(Clone, Copy, Debug, Default, derive_more::Display, PartialEq, Serialize, Deserialize)]
pub enum Compatibility {
/// Level not known. This is the `Default`.
/// Marked `skip_serializing`: serde treats an attempt to serialize it as an error.
#[default]
#[serde(skip_serializing)]
Unknown,
/// A level higher than this build implements (see `From<u16>`).
/// Marked `skip_serializing`, like `Unknown`.
#[serde(skip_serializing)]
Newer,
/// A specific numeric level. Marked `untagged`, so serde serializes only the
/// inner `u16` with no variant tag.
#[serde(untagged)]
Level(u16),
}
/// Converts a [`Compatibility`] into its numeric wire value.
///
/// Only [`Compatibility::Level`] carries a number; `Unknown` and `Newer` both
/// collapse to `0`.
impl From<Compatibility> for u16 {
    fn from(compat: Compatibility) -> Self {
        match compat {
            Compatibility::Unknown | Compatibility::Newer => 0,
            Compatibility::Level(level) => level,
        }
    }
}
/// Interprets a numeric wire value as a [`Compatibility`].
///
/// Values up to and including [`OUR_COMPATIBILITY_NUMERIC`] become
/// [`Compatibility::Level`]; anything higher becomes [`Compatibility::Newer`].
impl From<u16> for Compatibility {
    fn from(wire: u16) -> Self {
        match wire {
            level if level <= OUR_COMPATIBILITY_NUMERIC => Compatibility::Level(level),
            _ => Compatibility::Newer,
        }
    }
}
/// IP address family of a connection.
///
/// Serialized through `serde_repr` as its `#[repr(u8)]` discriminant, so the
/// explicit values `4` and `6` below form part of the wire format.
#[derive(
Serialize_repr,
Deserialize_repr,
PartialEq,
Eq,
Debug,
Default,
Clone,
Copy,
strum_macros::Display,
)]
#[repr(u8)]
pub enum ConnectionType {
/// IPv4 (discriminant 4). This is the `Default`.
#[default]
Ipv4 = 4,
/// IPv6 (discriminant 6).
Ipv6 = 6,
}
/// Classifies an [`IpAddr`] by its address family.
impl From<IpAddr> for ConnectionType {
    fn from(addr: IpAddr) -> Self {
        if addr.is_ipv6() {
            ConnectionType::Ipv6
        } else {
            ConnectionType::Ipv4
        }
    }
}
/// Classifies a [`SocketAddr`] by the address family of the socket address.
impl From<SocketAddr> for ConnectionType {
    fn from(sock: SocketAddr) -> Self {
        if sock.is_ipv4() {
            ConnectionType::Ipv4
        } else {
            ConnectionType::Ipv6
        }
    }
}
/// Congestion control algorithm selection.
///
/// Over serde this travels as a [`Uint`] that holds the variant's implicit
/// discriminant (`Cubic` = 0, `Bbr` = 1, `NewReno` = 2), via the `From` /
/// `TryFrom<Uint>` impls that follow. Reordering the variants therefore changes
/// the wire encoding.
///
/// In text form (strum, clap, enumscribe), names are lowercase and are parsed
/// case-insensitively.
#[derive(
Copy,
Clone,
Debug,
Default,
PartialEq,
Eq,
Serialize,
Deserialize,
strum_macros::Display,
strum_macros::EnumString,
strum_macros::FromRepr,
strum_macros::VariantNames,
strum::AsRefStr,
clap::ValueEnum,
enumscribe::TryUnscribe,
enumscribe::ScribeString,
)]
#[serde(try_from = "Uint")]
#[serde(into = "Uint")]
#[strum(ascii_case_insensitive)]
#[strum(serialize_all = "lowercase")]
#[value(rename_all = "lower")]
#[enumscribe(case_insensitive)]
pub enum CongestionController {
/// The default algorithm (wire value 0).
#[default]
Cubic,
/// Wire value 1.
Bbr,
/// Wire value 2.
NewReno,
}
// Opt CongestionController in to the crate's helper traits from `crate::util`.
// NOTE(review): from their names, these presumably make configuration handling
// (serde / figment) see the enum in its string form — confirm in `crate::util`.
impl SerializeEnumAsString for CongestionController {}
impl ToStringForFigment for CongestionController {}
impl From<CongestionController> for Uint {
fn from(value: CongestionController) -> Self {
Self(value as u64)
}
}
/// Decodes a wire [`Uint`] back into a [`CongestionController`].
///
/// # Errors
/// Returns an error if the value is not the discriminant of any variant. This
/// includes values too large to fit in a `usize`, which report the same
/// "invalid congestioncontroller enum" error.
impl TryFrom<Uint> for CongestionController {
    type Error = anyhow::Error;
    fn try_from(value: Uint) -> anyhow::Result<Self> {
        usize::try_from(value.0)
            .ok()
            .and_then(CongestionController::from_repr)
            // Build the error lazily. `ok_or(anyhow!(..))` allocated an error on
            // every call, including successful ones.
            .ok_or_else(|| anyhow!("invalid congestioncontroller enum"))
    }
}
/// Presents a [`CongestionController`] to figment as its `Display` string.
impl From<CongestionController> for figment::value::Value {
    fn from(controller: CongestionController) -> Self {
        let name: String = controller.to_string();
        name.into()
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use std::{
        io::Cursor,
        net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    };

    use pretty_assertions::assert_eq;
    use serde::{Deserialize, Serialize};
    use serde_bare::Uint;

    use crate::protocol::{
        DataTag as _, TaggedData,
        common::ProtocolMessage,
        control::{
            Compatibility, CongestionController, ConnectionType, CredentialsType,
            OUR_COMPATIBILITY_NUMERIC,
        },
    };

    /// Placeholder certificate bytes. Visibility is `pub(crate)`, so this is
    /// presumably shared with other test modules.
    pub(crate) fn dummy_cert() -> Vec<u8> {
        vec![0, 1, 2]
    }

    /// Placeholder X509 credentials wrapping the bytes `[0, 1, 2]`.
    pub(crate) fn dummy_credentials() -> TaggedData<CredentialsType> {
        CredentialsType::X509.with_bytes(vec![0, 1, 2])
    }

    #[test]
    fn convert_connection_type() {
        let ip4 = IpAddr::from(Ipv4Addr::LOCALHOST);
        let ct4 = ConnectionType::from(ip4);
        assert_eq!(ct4, ConnectionType::Ipv4);
        let ip6 = IpAddr::from(Ipv6Addr::LOCALHOST);
        let ct6 = ConnectionType::from(ip6);
        assert_eq!(ct6, ConnectionType::Ipv6);
        let sa4: SocketAddr = "127.0.0.1:1234".parse().unwrap();
        let ct4 = ConnectionType::from(sa4);
        assert_eq!(ct4, ConnectionType::Ipv4);
        let sa6: SocketAddr = "[::1]:4321".parse().unwrap();
        let ct6 = ConnectionType::from(sa6);
        assert_eq!(ct6, ConnectionType::Ipv6);
    }

    // Two message shapes that share a prefix (`i`). They are used to check that
    // framed decoding tolerates a peer whose struct has a different tail.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct Test1 {
        i: i32,
        extension: u8,
    }
    impl ProtocolMessage for Test1 {}

    #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct Test2 {
        i: i32,
        whatever: Option<u64>,
    }
    impl ProtocolMessage for Test2 {}

    #[test]
    fn forwards_compatibility() {
        let t1 = Test1 {
            i: 42,
            extension: 0,
        };
        let mut buf = Vec::<u8>::new();
        t1.to_writer_framed(&mut buf).unwrap();
        // NOTE(review): the `extension: 0` byte is presumably read back as the
        // "absent" tag of `whatever`.
        let decoded = Test2::from_reader_framed(&mut Cursor::new(buf)).unwrap();
        assert_eq!(decoded.i, t1.i);
        assert!(decoded.whatever.is_none());
    }

    #[test]
    fn backwards_compatibility() {
        let t2 = Test2 {
            i: 78,
            whatever: Some(12345),
        };
        let mut buf = Vec::<u8>::new();
        t2.to_writer_framed(&mut buf).unwrap();
        let decoded = Test1::from_reader_framed(&mut Cursor::new(buf)).unwrap();
        assert_eq!(decoded.i, t2.i);
        // NOTE(review): `extension` reads the byte after `i`, which is presumably
        // the `Some` presence tag of `whatever` (1); the remaining bytes are ignored.
        assert_eq!(decoded.extension, 1);
    }

    #[test]
    fn compat_level_from_wire() {
        let cases = &[
            (0u16, Compatibility::Level(0)),
            (1, Compatibility::Level(1)),
            (2, Compatibility::Level(2)),
            // Boundary: our own level is still a Level; one above is Newer.
            (
                OUR_COMPATIBILITY_NUMERIC,
                Compatibility::Level(OUR_COMPATIBILITY_NUMERIC),
            ),
            (OUR_COMPATIBILITY_NUMERIC + 1, Compatibility::Newer),
            (32768, Compatibility::Newer),
            (65535, Compatibility::Newer),
        ];
        for (wire, compat) in cases {
            let level: Compatibility = (*wire).into();
            assert_eq!(
                level, *compat,
                "wire {wire} should be {compat:?} but got {level}"
            );
            let wire2 = u16::from(*compat);
            if *compat == Compatibility::Newer {
                assert_eq!(wire2, 0, "compat Newer should be wire 0");
            } else {
                assert_eq!(
                    wire2, *wire,
                    "compat {compat:?} failed to convert back (expected {wire})"
                );
            }
        }
        // `Unknown` is the default and carries no level of its own.
        assert_eq!(Compatibility::default(), Compatibility::Unknown);
        assert_eq!(u16::from(Compatibility::Unknown), 0);
    }

    #[test]
    fn congestion_controller_wire_roundtrip() {
        for cc in [
            CongestionController::Cubic,
            CongestionController::Bbr,
            CongestionController::NewReno,
        ] {
            let wire = Uint::from(cc);
            let back = CongestionController::try_from(wire).unwrap();
            assert_eq!(back, cc);
        }
        // Values past the last discriminant must be rejected, not mapped.
        assert!(CongestionController::try_from(Uint(3)).is_err());
        assert!(CongestionController::try_from(Uint(u64::MAX)).is_err());
    }
}