use super::error::WireVersionError;
use super::types::WireVersion;
/// An inclusive span of wire-protocol versions, `min..=max`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VersionRange {
    /// Lowest version in the span (inclusive).
    pub min: WireVersion,
    /// Highest version in the span (inclusive).
    pub max: WireVersion,
}

impl VersionRange {
    /// Builds a range from its inclusive bounds.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `min > max`. Builds
    /// without debug assertions store the bounds unchecked; an inverted
    /// range then contains no versions.
    pub fn new(min: WireVersion, max: WireVersion) -> Self {
        debug_assert!(
            min <= max,
            "VersionRange: min ({min}) must be <= max ({max})"
        );
        Self { min, max }
    }

    /// Reports whether `version` lies within `min..=max` (both ends inclusive).
    pub fn contains(&self, version: WireVersion) -> bool {
        (self.min..=self.max).contains(&version)
    }
}
/// Picks the highest wire version supported by both peers.
///
/// The overlap of the two inclusive ranges is
/// `max(local.min, remote.min)..=min(local.max, remote.max)`; the upper end
/// of that overlap is the agreed version.
///
/// # Errors
///
/// Returns [`WireVersionError::NegotiationFailed`], carrying all four
/// bounds, when the overlap is empty. An inverted range (`min > max`) on
/// either side always produces an empty overlap, so it is rejected here
/// instead of yielding a version.
pub fn negotiate(
    local: VersionRange,
    remote: VersionRange,
) -> Result<WireVersion, WireVersionError> {
    let floor = std::cmp::max(local.min, remote.min);
    let ceiling = std::cmp::min(local.max, remote.max);

    if floor <= ceiling {
        Ok(ceiling)
    } else {
        Err(WireVersionError::NegotiationFailed {
            local_min: local.min,
            local_max: local.max,
            remote_min: remote.min,
            remote_max: remote.max,
        })
    }
}
/// Handshake message announcing the sender's supported wire-version range.
///
/// `range` carries the raw `(min, max)` version numbers. Nothing in this type
/// enforces `min <= max`, so a decoded value may hold an inverted pair.
///
/// NOTE(review): the `zerompk` derives likely encode fields positionally;
/// treat field order as part of the wire format — confirm before reordering.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub struct VersionHandshake {
// Raw `(min, max)` bounds of the sender's supported version range.
pub range: (u16, u16),
// Capability bit flags; `VersionHandshake::from_range` sets this to 0.
// `#[serde(default)]` lets serde decoders accept messages that omit the field.
// NOTE(review): confirm whether the `zerompk` derive honors this attribute.
#[serde(default)]
pub capabilities: u64,
}
/// Reply to a [`VersionHandshake`] carrying the version both sides agreed on.
///
/// NOTE(review): the `zerompk` derives likely encode fields positionally;
/// treat field order as part of the wire format — confirm before reordering.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub struct VersionHandshakeAck {
// Raw agreed version number; `VersionHandshakeAck::agreed_version` wraps it.
pub agreed: u16,
// Capability bit flags; `VersionHandshakeAck::new` sets this to 0.
// `#[serde(default)]` lets serde decoders accept messages that omit the field.
// NOTE(review): confirm whether the `zerompk` derive honors this attribute.
#[serde(default)]
pub capabilities: u64,
}
impl VersionHandshake {
pub fn from_range(range: VersionRange) -> Self {
Self {
range: (range.min.0, range.max.0),
capabilities: 0,
}
}
pub fn to_range(&self) -> VersionRange {
VersionRange::new(WireVersion(self.range.0), WireVersion(self.range.1))
}
}
impl VersionHandshakeAck {
pub fn new(agreed: WireVersion) -> Self {
Self {
agreed: agreed.0,
capabilities: 0,
}
}
pub fn agreed_version(&self) -> WireVersion {
WireVersion(self.agreed)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn v(n: u16) -> WireVersion {
        WireVersion(n)
    }

    fn range(min: u16, max: u16) -> VersionRange {
        VersionRange::new(v(min), v(max))
    }

    #[test]
    fn overlapping_ranges_return_highest_common() {
        let result = negotiate(range(1, 3), range(2, 5)).unwrap();
        assert_eq!(result, v(3));
    }

    #[test]
    fn negotiation_is_symmetric() {
        assert_eq!(
            negotiate(range(1, 3), range(2, 5)).unwrap(),
            negotiate(range(2, 5), range(1, 3)).unwrap()
        );
    }

    #[test]
    fn disjoint_ranges_return_negotiation_failed() {
        let err = negotiate(range(1, 2), range(3, 5)).unwrap_err();
        // The error must carry both sides' bounds, in local/remote order.
        assert!(
            matches!(
                err,
                WireVersionError::NegotiationFailed {
                    local_min,
                    local_max,
                    remote_min,
                    remote_max,
                } if local_min == v(1)
                    && local_max == v(2)
                    && remote_min == v(3)
                    && remote_max == v(5)
            ),
            "expected NegotiationFailed carrying both ranges, got: {err}"
        );
    }

    #[test]
    fn inverted_remote_range_is_rejected() {
        // Built with a struct literal because `VersionRange::new` would
        // debug_assert on an inverted range.
        let inverted = VersionRange { min: v(5), max: v(1) };
        let err = negotiate(range(1, 6), inverted).unwrap_err();
        assert!(
            matches!(err, WireVersionError::NegotiationFailed { .. }),
            "expected NegotiationFailed, got: {err}"
        );
    }

    #[test]
    fn equal_single_version_succeeds() {
        let result = negotiate(range(2, 2), range(2, 2)).unwrap();
        assert_eq!(result, v(2));
    }

    #[test]
    fn ranges_touching_at_one_version_agree_on_it() {
        // local.max == remote.min, so the overlap is exactly {4}.
        let result = negotiate(range(1, 4), range(4, 6)).unwrap();
        assert_eq!(result, v(4));
    }

    #[test]
    fn contains_is_inclusive_at_both_ends() {
        let r = range(2, 4);
        assert!(!r.contains(v(1)));
        assert!(r.contains(v(2)));
        assert!(r.contains(v(3)));
        assert!(r.contains(v(4)));
        assert!(!r.contains(v(5)));
    }

    #[test]
    fn handshake_roundtrip() {
        let r = range(1, 2);
        let hs = VersionHandshake::from_range(r);
        let bytes = zerompk::to_msgpack_vec(&hs).unwrap();
        let decoded: VersionHandshake = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded, hs);
        assert_eq!(decoded.to_range(), r);
    }

    #[test]
    fn handshake_ack_roundtrip() {
        let ack = VersionHandshakeAck::new(v(3));
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: VersionHandshakeAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded, ack);
        assert_eq!(decoded.agreed_version(), v(3));
    }
}