use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use rkyv::rancor;
/// Highest protocol version this build can speak (inclusive).
pub const MAX_LOCAL_PROTOCOL_VERSION: u16 = 1;
/// Lowest protocol version this build still accepts (inclusive).
pub const MIN_LOCAL_PROTOCOL_VERSION: u16 = 1;

// Reject an inverted local window at compile time: with MIN > MAX, a handshake
// built from these constants could never pass `VersionHandshake::negotiate`
// against any peer, and the failure would only surface at runtime.
const _: () = assert!(
    MIN_LOCAL_PROTOCOL_VERSION <= MAX_LOCAL_PROTOCOL_VERSION,
    "MIN_LOCAL_PROTOCOL_VERSION must not exceed MAX_LOCAL_PROTOCOL_VERSION"
);
/// Version window and build identity that one endpoint advertises in the handshake.
///
/// Serialized with rkyv; `#[rkyv(derive(Debug))]` also gives the archived form,
/// `ArchivedVersionHandshake`, a `Debug` impl. Field order determines the archived
/// layout, so reordering fields changes what goes on the wire.
#[derive(Archive, RkyvSerialize, RkyvDeserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[rkyv(derive(Debug))]
pub struct VersionHandshake {
    /// Highest protocol version this endpoint can speak (inclusive upper bound).
    pub max_version: u16,
    /// Lowest protocol version this endpoint still accepts (inclusive lower bound).
    pub min_version: u16,
    /// Identifier of the sending build, carried as 32 raw bytes; version
    /// negotiation does not look at it.
    pub build_id: [u8; 32],
}
/// Outcome of a successful `VersionHandshake::negotiate` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NegotiatedVersion {
    /// Protocol version both peers agreed to speak.
    pub version: u16,
}
impl VersionHandshake {
    /// Builds the handshake this build advertises: the inclusive window
    /// `MIN_LOCAL_PROTOCOL_VERSION..=MAX_LOCAL_PROTOCOL_VERSION`, tagged with `build_id`.
    #[must_use]
    pub fn local(build_id: [u8; 32]) -> Self {
        Self {
            build_id,
            min_version: MIN_LOCAL_PROTOCOL_VERSION,
            max_version: MAX_LOCAL_PROTOCOL_VERSION,
        }
    }

    /// Picks the highest protocol version that lies inside both endpoints' windows.
    ///
    /// Each window is the inclusive range `min_version..=max_version`. The result
    /// is the top of the intersection of the two ranges, so swapping `self` and
    /// `peer` gives the same answer. `build_id` is not consulted.
    ///
    /// Returns `None` when the windows do not overlap. That includes the case
    /// where either window is inverted (`min_version > max_version`).
    #[must_use]
    pub fn negotiate(&self, peer: &VersionHandshake) -> Option<NegotiatedVersion> {
        // The highest version both sides can speak is the smaller of the two maxima...
        let highest_shared = self.max_version.min(peer.max_version);
        // ...and it is only usable if neither side's lower bound excludes it.
        let lowest_acceptable = self.min_version.max(peer.min_version);

        (highest_shared >= lowest_acceptable).then_some(NegotiatedVersion {
            version: highest_shared,
        })
    }

    /// Validates `bytes` as an archived `VersionHandshake` and returns a
    /// zero-copy view borrowed from `bytes`.
    ///
    /// The buffer should be aligned for the archived type. A buffer produced by
    /// `rkyv::to_bytes` satisfies this; an arbitrary `&[u8]` may not.
    ///
    /// # Errors
    ///
    /// Returns a [`rancor::Error`] when `bytes` fails rkyv's validation, for
    /// example because it is truncated or malformed.
    pub fn access(bytes: &[u8]) -> Result<&ArchivedVersionHandshake, rancor::Error> {
        rkyv::access::<ArchivedVersionHandshake, rancor::Error>(bytes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A distinct 32-byte build id whose first byte is `id`.
    fn build(id: u8) -> [u8; 32] {
        let mut a = [0u8; 32];
        a[0] = id;
        a
    }

    /// Shorthand for a handshake advertising the inclusive window `min_version..=max_version`.
    fn window(min_version: u16, max_version: u16, id: u8) -> VersionHandshake {
        VersionHandshake {
            max_version,
            min_version,
            build_id: build(id),
        }
    }

    #[test]
    fn local_advertises_crate_window() {
        let h = VersionHandshake::local(build(9));
        assert_eq!(h.min_version, MIN_LOCAL_PROTOCOL_VERSION);
        assert_eq!(h.max_version, MAX_LOCAL_PROTOCOL_VERSION);
        assert_eq!(h.build_id, build(9));
    }

    #[test]
    fn equal_windows_pick_the_max() {
        let a = VersionHandshake::local(build(1));
        let b = VersionHandshake::local(build(2));
        let neg = a.negotiate(&b).unwrap();
        assert_eq!(neg.version, MAX_LOCAL_PROTOCOL_VERSION);
    }

    #[test]
    fn overlapping_window_picks_the_intersection_max() {
        let a = window(3, 5, 1);
        let b = window(2, 4, 2);
        let neg = a.negotiate(&b).unwrap();
        assert_eq!(neg.version, 4);
    }

    #[test]
    fn negotiation_is_symmetric() {
        let a = window(3, 5, 1);
        let b = window(2, 4, 2);
        assert_eq!(a.negotiate(&b), b.negotiate(&a));
    }

    #[test]
    fn single_point_overlap_is_accepted() {
        // The windows share exactly one version; both bounds are inclusive.
        let a = window(1, 3, 1);
        let b = window(3, 5, 2);
        assert_eq!(a.negotiate(&b), Some(NegotiatedVersion { version: 3 }));
        assert_eq!(b.negotiate(&a), Some(NegotiatedVersion { version: 3 }));
    }

    #[test]
    fn nested_window_picks_inner_max() {
        let outer = window(1, 10, 1);
        let inner = window(4, 6, 2);
        assert_eq!(outer.negotiate(&inner).map(|n| n.version), Some(6));
        assert_eq!(inner.negotiate(&outer).map(|n| n.version), Some(6));
    }

    #[test]
    fn disjoint_windows_refuse() {
        let old = window(1, 2, 1);
        let modern = window(4, 5, 2);
        assert!(old.negotiate(&modern).is_none());
        assert!(modern.negotiate(&old).is_none());
    }

    #[test]
    fn inverted_window_refuses() {
        // A malformed peer (min > max) must never produce a version.
        let broken = window(5, 3, 1);
        let wide = window(1, 10, 2);
        assert!(broken.negotiate(&wide).is_none());
        assert!(wide.negotiate(&broken).is_none());
    }

    #[test]
    fn handshake_roundtrips_via_rkyv() {
        let h = VersionHandshake::local(build(7));
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&h).unwrap();
        let arc = VersionHandshake::access(&bytes).unwrap();
        assert_eq!(arc.max_version, h.max_version);
        assert_eq!(arc.min_version, h.min_version);
        assert_eq!(arc.build_id, h.build_id);
    }

    #[test]
    fn access_rejects_empty_input() {
        // There is no room for an archived handshake, so validation must fail rather than panic.
        assert!(VersionHandshake::access(&[]).is_err());
    }
}