Skip to main content

liminal/protocol/
version.rs

1use super::error::ProtocolError;
2
3/// Protocol version negotiated during the connection handshake.
4#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
5pub struct ProtocolVersion {
6    /// Major protocol version component.
7    pub major: u16,
8    /// Minor protocol version component.
9    pub minor: u16,
10}
11
12impl ProtocolVersion {
13    /// Number of bytes used by a serialized protocol version.
14    pub const WIRE_LEN: usize = 4;
15
16    /// Create a protocol version from its major and minor components.
17    #[must_use]
18    pub const fn new(major: u16, minor: u16) -> Self {
19        Self { major, minor }
20    }
21
22    /// Serialize this version as big-endian `major` followed by big-endian `minor`.
23    #[must_use]
24    pub const fn serialize(self) -> [u8; Self::WIRE_LEN] {
25        let major = self.major.to_be_bytes();
26        let minor = self.minor.to_be_bytes();
27        [major[0], major[1], minor[0], minor[1]]
28    }
29
30    /// Serialize this version for wire transport.
31    #[must_use]
32    pub const fn to_wire_bytes(self) -> [u8; Self::WIRE_LEN] {
33        self.serialize()
34    }
35
36    /// Deserialize a protocol version from wire bytes.
37    ///
38    /// # Errors
39    ///
40    /// Returns [`ProtocolError::CodecError`] when `bytes` is not exactly four bytes long.
41    pub fn deserialize(bytes: &[u8]) -> Result<Self, ProtocolError> {
42        Self::from_wire_bytes(bytes)
43    }
44
45    /// Deserialize a protocol version from wire bytes.
46    ///
47    /// # Errors
48    ///
49    /// Returns [`ProtocolError::CodecError`] when `bytes` is not exactly four bytes long.
50    pub fn from_wire_bytes(bytes: &[u8]) -> Result<Self, ProtocolError> {
51        match bytes {
52            [major_high, major_low, minor_high, minor_low] => Ok(Self::new(
53                u16::from_be_bytes([*major_high, *major_low]),
54                u16::from_be_bytes([*minor_high, *minor_low]),
55            )),
56            _ => Err(ProtocolError::codec(
57                "protocol version must be exactly 4 bytes",
58            )),
59        }
60    }
61}
62
63/// Select the highest server-supported protocol version within the client's range.
64///
65/// # Errors
66///
67/// Returns [`ProtocolError::VersionMismatch`] when `supported_versions` contains no
68/// version in the inclusive range from `min_version` through `max_version`.
69pub fn negotiate_version(
70    min_version: ProtocolVersion,
71    max_version: ProtocolVersion,
72    supported_versions: &[ProtocolVersion],
73) -> Result<ProtocolVersion, ProtocolError> {
74    supported_versions
75        .iter()
76        .copied()
77        .filter(|version| min_version <= *version && *version <= max_version)
78        .max()
79        .ok_or_else(|| ProtocolError::VersionMismatch {
80            message: Some("no mutually supported protocol version".to_owned()),
81        })
82}
83
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::{ProtocolVersion, negotiate_version};
    use crate::protocol::ProtocolError;

    #[test]
    fn version_trait_bounds_are_available() {
        fn assert_traits<T: Debug + Clone + Copy + PartialEq + Eq + PartialOrd + Ord>() {}

        assert_traits::<ProtocolVersion>();
    }

    #[test]
    fn versions_order_by_major_then_minor() {
        assert!(ProtocolVersion::new(1, 0) < ProtocolVersion::new(1, 1));
        assert!(ProtocolVersion::new(1, 1) < ProtocolVersion::new(2, 0));
    }

    #[test]
    fn version_serializes_to_exactly_four_bytes() {
        let bytes = ProtocolVersion::new(0x0102, 0x0304).serialize();

        assert_eq!(bytes.len(), ProtocolVersion::WIRE_LEN);
        assert_eq!(bytes, [0x01, 0x02, 0x03, 0x04]);
    }

    #[test]
    fn version_round_trips_through_wire_bytes() -> Result<(), ProtocolError> {
        let version = ProtocolVersion::new(2, 7);
        let bytes = version.to_wire_bytes();

        assert_eq!(ProtocolVersion::deserialize(&bytes)?, version);
        assert_eq!(ProtocolVersion::from_wire_bytes(&bytes)?, version);
        Ok(())
    }

    // Boundary values exercise both bytes of each big-endian component.
    #[test]
    fn extreme_versions_round_trip() -> Result<(), ProtocolError> {
        for version in [
            ProtocolVersion::new(0, 0),
            ProtocolVersion::new(u16::MAX, u16::MAX),
        ] {
            assert_eq!(ProtocolVersion::from_wire_bytes(&version.serialize())?, version);
        }
        Ok(())
    }

    #[test]
    fn deserialize_rejects_wrong_length() {
        assert!(matches!(
            ProtocolVersion::deserialize(&[0, 1, 0]),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    // The decoder requires an exact length: empty and over-long input must fail too.
    #[test]
    fn deserialize_rejects_empty_and_over_length_input() {
        let cases: [&[u8]; 2] = [&[], &[0, 1, 0, 2, 0]];
        for bytes in cases {
            assert!(matches!(
                ProtocolVersion::from_wire_bytes(bytes),
                Err(ProtocolError::CodecError { .. })
            ));
        }
    }

    #[test]
    fn negotiation_selects_highest_mutual_version() -> Result<(), ProtocolError> {
        let supported = [
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 1),
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(3, 0),
        ];

        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(2, 0),
            &supported,
        )?;

        assert_eq!(selected, ProtocolVersion::new(2, 0));
        Ok(())
    }

    #[test]
    fn negotiation_reports_version_mismatch() {
        let supported = [ProtocolVersion::new(2, 0), ProtocolVersion::new(3, 0)];

        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 5),
            &supported,
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_selects_exact_single_version() -> Result<(), ProtocolError> {
        let supported = [ProtocolVersion::new(1, 0)];

        let selected = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        )?;

        assert_eq!(selected, ProtocolVersion::new(1, 0));
        Ok(())
    }

    // A client whose minimum exceeds its maximum can never be satisfied.
    #[test]
    fn negotiation_rejects_inverted_range() {
        let supported = [ProtocolVersion::new(1, 0), ProtocolVersion::new(2, 0)];

        let result = negotiate_version(
            ProtocolVersion::new(2, 0),
            ProtocolVersion::new(1, 0),
            &supported,
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn negotiation_with_no_supported_versions_is_mismatch() {
        let result = negotiate_version(
            ProtocolVersion::new(1, 0),
            ProtocolVersion::new(9, 9),
            &[],
        );

        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }
}