use super::error::ProtocolError;
/// A protocol version made of a `major` and a `minor` component.
///
/// Versions compare lexicographically: first by `major`, then by `minor`.
/// The derived `PartialOrd`/`Ord` compare fields in declaration order, so
/// `major` must stay declared before `minor` for that ordering to hold.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ProtocolVersion {
    /// Major version component; compared first.
    pub major: u16,
    /// Minor version component; only decides ordering when `major` is equal.
    pub minor: u16,
}
impl ProtocolVersion {
    /// Length in bytes of the wire encoding: `major` then `minor`, each as a
    /// big-endian `u16`.
    pub const WIRE_LEN: usize = 4;

    /// Builds a version from its major and minor components.
    #[must_use]
    pub const fn new(major: u16, minor: u16) -> Self {
        Self { major, minor }
    }

    /// Encodes the version as `[major_hi, major_lo, minor_hi, minor_lo]`
    /// (both components big-endian).
    #[must_use]
    pub const fn serialize(self) -> [u8; Self::WIRE_LEN] {
        let [major_hi, major_lo] = self.major.to_be_bytes();
        let [minor_hi, minor_lo] = self.minor.to_be_bytes();
        [major_hi, major_lo, minor_hi, minor_lo]
    }

    /// Same as [`ProtocolVersion::serialize`]; kept as an alias.
    #[must_use]
    pub const fn to_wire_bytes(self) -> [u8; Self::WIRE_LEN] {
        Self::serialize(self)
    }

    /// Same as [`ProtocolVersion::from_wire_bytes`]; kept as an alias.
    ///
    /// # Errors
    ///
    /// Returns the codec error from [`ProtocolVersion::from_wire_bytes`] when
    /// `bytes` is not exactly [`Self::WIRE_LEN`] bytes long.
    pub fn deserialize(bytes: &[u8]) -> Result<Self, ProtocolError> {
        Self::from_wire_bytes(bytes)
    }

    /// Decodes a version from its big-endian wire form.
    ///
    /// # Errors
    ///
    /// Returns a codec error (built via `ProtocolError::codec`) unless `bytes`
    /// is exactly [`Self::WIRE_LEN`] bytes long. Shorter and longer inputs are
    /// both rejected.
    pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, ProtocolError> {
        let &[major_hi, major_lo, minor_hi, minor_lo] = bytes else {
            return Err(ProtocolError::codec(
                "protocol version must be exactly 4 bytes",
            ));
        };
        Ok(Self::new(
            u16::from_be_bytes([major_hi, major_lo]),
            u16::from_be_bytes([minor_hi, minor_lo]),
        ))
    }
}
/// Picks the highest entry of `supported_versions` that lies within the
/// inclusive range `min_version..=max_version`.
///
/// The order of `supported_versions` does not matter.
///
/// # Errors
///
/// Returns `ProtocolError::VersionMismatch` if no supported version falls in
/// the range. This includes an empty `supported_versions` and the case
/// `min_version > max_version`, where the range is empty.
pub fn negotiate_version(
    min_version: ProtocolVersion,
    max_version: ProtocolVersion,
    supported_versions: &[ProtocolVersion],
) -> Result<ProtocolVersion, ProtocolError> {
    let acceptable = min_version..=max_version;
    let best = supported_versions
        .iter()
        .filter(|&candidate| acceptable.contains(candidate))
        .max()
        .copied();
    best.ok_or_else(|| ProtocolError::VersionMismatch {
        message: Some(String::from("no mutually supported protocol version")),
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `ProtocolVersion` wire codec and for
    //! `negotiate_version`, including length and range edge cases.
    use std::fmt::Debug;

    use super::{ProtocolVersion, negotiate_version};
    use crate::protocol::ProtocolError;

    #[test]
    fn version_trait_bounds_are_available() {
        fn assert_traits<T: Debug + Clone + Copy + PartialEq + Eq + PartialOrd + Ord>() {}
        assert_traits::<ProtocolVersion>();
    }

    #[test]
    fn versions_order_by_major_then_minor() {
        assert!(ProtocolVersion::new(1, 0) < ProtocolVersion::new(1, 1));
        assert!(ProtocolVersion::new(1, 1) < ProtocolVersion::new(2, 0));
        // A larger minor must never outrank a larger major.
        assert!(ProtocolVersion::new(1, u16::MAX) < ProtocolVersion::new(2, 0));
    }

    #[test]
    fn version_serializes_to_exactly_four_bytes() {
        let bytes = ProtocolVersion::new(0x0102, 0x0304).serialize();
        assert_eq!(bytes.len(), ProtocolVersion::WIRE_LEN);
        assert_eq!(bytes, [0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn version_serializes_extreme_values_big_endian() {
        assert_eq!(
            ProtocolVersion::new(u16::MAX, 0).serialize(),
            [0xFF, 0xFF, 0x00, 0x00]
        );
        assert_eq!(
            ProtocolVersion::new(0, u16::MAX).serialize(),
            [0x00, 0x00, 0xFF, 0xFF]
        );
    }

    #[test]
    fn version_round_trips_through_wire_bytes() -> Result<(), ProtocolError> {
        let version = ProtocolVersion::new(2, 7);
        let bytes = version.to_wire_bytes();
        assert_eq!(ProtocolVersion::deserialize(&bytes)?, version);
        assert_eq!(ProtocolVersion::from_wire_bytes(&bytes)?, version);
        Ok(())
    }

    #[test]
    fn version_round_trips_extreme_values() -> Result<(), ProtocolError> {
        let version = ProtocolVersion::new(u16::MAX, u16::MAX);
        assert_eq!(ProtocolVersion::from_wire_bytes(&version.serialize())?, version);
        Ok(())
    }

    #[test]
    fn deserialize_rejects_wrong_length() {
        assert!(matches!(
            ProtocolVersion::deserialize(&[0, 1, 0]),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    #[test]
    fn deserialize_rejects_too_long_and_empty_input() {
        // The codec requires exactly WIRE_LEN bytes, so trailing bytes are an error too.
        assert!(matches!(
            ProtocolVersion::deserialize(&[0, 1, 0, 2, 0]),
            Err(ProtocolError::CodecError { .. })
        ));
        assert!(matches!(
            ProtocolVersion::deserialize(&[]),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    #[test]
    fn negotiation_selects_highest_mutual_version() -> Result<(), ProtocolError> {
        let supported = [
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(3, 0),
        ];
        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 0),
            &supported,
        )?;
        assert_eq!(selected, ProtocolVersion::new(2, 0));
        Ok(())
    }

    #[test]
    fn negotiation_ignores_input_order() -> Result<(), ProtocolError> {
        let supported = [
            ProtocolVersion::new(3, 0),
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(0, 9),
        ];
        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 5),
            &supported,
        )?;
        assert_eq!(selected, ProtocolVersion::new(2, 0));
        Ok(())
    }

    #[test]
    fn negotiation_range_is_inclusive_at_both_ends() -> Result<(), ProtocolError> {
        let only_lower = [ProtocolVersion::new(1, 1)];
        let selected = negotiate_version(
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(5, 0),
            &only_lower,
        )?;
        assert_eq!(selected, ProtocolVersion::new(1, 1));

        let upper_and_beyond = [ProtocolVersion::new(2, 3), ProtocolVersion::new(2, 4)];
        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 3),
            &upper_and_beyond,
        )?;
        assert_eq!(selected, ProtocolVersion::new(2, 3));
        Ok(())
    }

    #[test]
    fn negotiation_reports_version_mismatch() {
        let supported = [ProtocolVersion::new(2, 0), ProtocolVersion::new(3, 0)];
        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 5),
            &supported,
        );
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_with_no_supported_versions_is_a_mismatch() {
        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 0),
            &[],
        );
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_with_inverted_range_is_a_mismatch() {
        // min > max describes an empty range, so nothing can be selected.
        let supported = [
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 5),
            ProtocolVersion::new(2, 0),
        ];
        let result = negotiate_version(
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        );
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_selects_exact_single_version() -> Result<(), ProtocolError> {
        let supported = [ProtocolVersion::new(1, 0)];
        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        )?;
        assert_eq!(selected, ProtocolVersion::new(1, 0));
        Ok(())
    }
}