liminal/protocol/
version.rs1use super::error::ProtocolError;
2
/// A protocol version identified by a `major.minor` pair.
///
/// Versions order lexicographically: `major` first, then `minor`. That ordering
/// comes from the derived [`PartialOrd`]/[`Ord`] impls, which compare fields in
/// declaration order — so the field order below is significant and must not be
/// swapped. `Hash` is derived alongside `Eq` so versions can key hashed
/// collections (e.g. a `HashSet` of advertised versions).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProtocolVersion {
    /// Major component; compared before `minor`.
    pub major: u16,
    /// Minor component; only consulted when the `major` components are equal.
    pub minor: u16,
}
11
12impl ProtocolVersion {
13 pub const WIRE_LEN: usize = 4;
15
16 #[must_use]
18 pub const fn new(major: u16, minor: u16) -> Self {
19 Self { major, minor }
20 }
21
22 #[must_use]
24 pub const fn serialize(self) -> [u8; Self::WIRE_LEN] {
25 let major = self.major.to_be_bytes();
26 let minor = self.minor.to_be_bytes();
27 [major[0], major[1], minor[0], minor[1]]
28 }
29
30 #[must_use]
32 pub const fn to_wire_bytes(self) -> [u8; Self::WIRE_LEN] {
33 self.serialize()
34 }
35
36 pub fn deserialize(bytes: &[u8]) -> Result<Self, ProtocolError> {
42 Self::from_wire_bytes(bytes)
43 }
44
45 pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, ProtocolError> {
51 match bytes {
52 [major_high, major_low, minor_high, minor_low] => Ok(Self::new(
53 u16::from_be_bytes([*major_high, *major_low]),
54 u16::from_be_bytes([*minor_high, *minor_low]),
55 )),
56 _ => Err(ProtocolError::codec(
57 "protocol version must be exactly 4 bytes",
58 )),
59 }
60 }
61}
62
63pub fn negotiate_version(
70 min_version: ProtocolVersion,
71 max_version: ProtocolVersion,
72 supported_versions: &[ProtocolVersion],
73) -> Result<ProtocolVersion, ProtocolError> {
74 supported_versions
75 .iter()
76 .copied()
77 .filter(|version| min_version <= *version && *version <= max_version)
78 .max()
79 .ok_or_else(|| ProtocolError::VersionMismatch {
80 message: Some("no mutually supported protocol version".to_owned()),
81 })
82}
83
#[cfg(test)]
mod tests {
    // Unit tests for version ordering, wire encoding, and negotiation,
    // including boundary cases (empty/oversized input, inverted bounds).
    use std::fmt::Debug;

    use super::{ProtocolVersion, negotiate_version};
    use crate::protocol::ProtocolError;

    #[test]
    fn version_trait_bounds_are_available() {
        fn assert_traits<T: Debug + Clone + Copy + PartialEq + Eq + PartialOrd + Ord>() {}

        assert_traits::<ProtocolVersion>();
    }

    #[test]
    fn versions_order_by_major_then_minor() {
        assert!(ProtocolVersion::new(1, 0) < ProtocolVersion::new(1, 1));
        assert!(ProtocolVersion::new(1, 1) < ProtocolVersion::new(2, 0));
    }

    #[test]
    fn version_serializes_to_exactly_four_bytes() {
        let bytes = ProtocolVersion::new(0x0102, 0x0304).serialize();

        assert_eq!(bytes.len(), ProtocolVersion::WIRE_LEN);
        assert_eq!(bytes, [0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn version_round_trips_through_wire_bytes() -> Result<(), ProtocolError> {
        let version = ProtocolVersion::new(2, 7);
        let bytes = version.to_wire_bytes();

        assert_eq!(ProtocolVersion::deserialize(&bytes)?, version);
        assert_eq!(ProtocolVersion::from_wire_bytes(&bytes)?, version);
        Ok(())
    }

    #[test]
    fn version_round_trips_extreme_values() -> Result<(), ProtocolError> {
        for version in [
            ProtocolVersion::new(0, 0),
            ProtocolVersion::new(u16::MAX, 0),
            ProtocolVersion::new(0, u16::MAX),
            ProtocolVersion::new(u16::MAX, u16::MAX),
        ] {
            assert_eq!(
                ProtocolVersion::from_wire_bytes(&version.serialize())?,
                version
            );
        }
        Ok(())
    }

    #[test]
    fn deserialize_rejects_wrong_length() {
        assert!(matches!(
            ProtocolVersion::deserialize(&[0, 1, 0]),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    #[test]
    fn deserialize_rejects_empty_and_oversized_input() {
        assert!(matches!(
            ProtocolVersion::deserialize(&[]),
            Err(ProtocolError::CodecError { .. })
        ));
        assert!(matches!(
            ProtocolVersion::deserialize(&[0, 1, 0, 2, 0]),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    #[test]
    fn negotiation_selects_highest_mutual_version() -> Result<(), ProtocolError> {
        let supported = [
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(3, 0),
        ];

        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 0),
            &supported,
        )?;

        assert_eq!(selected, ProtocolVersion::new(2, 0));
        Ok(())
    }

    #[test]
    fn negotiation_ignores_input_order() -> Result<(), ProtocolError> {
        let supported = [
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(1, 3),
            ProtocolVersion::new(1, 0),
        ];

        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 5),
            &supported,
        )?;

        assert_eq!(selected, ProtocolVersion::new(1, 3));
        Ok(())
    }

    #[test]
    fn negotiation_reports_version_mismatch() {
        let supported = [ProtocolVersion::new(2, 0), ProtocolVersion::new(3, 0)];

        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 5),
            &supported,
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_fails_with_no_supported_versions() {
        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(9, 0),
            &[],
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_fails_when_bounds_are_inverted() {
        let supported = [ProtocolVersion::new(1, 0), ProtocolVersion::new(2, 0)];

        // min > max describes an empty range, so nothing can be selected.
        let result = negotiate_version(
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_selects_exact_single_version() -> Result<(), ProtocolError> {
        let supported = [ProtocolVersion::new(1, 0)];

        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        )?;

        assert_eq!(selected, ProtocolVersion::new(1, 0));
        Ok(())
    }
}