nodedb_cluster/wire_version/
negotiation.rs1use super::error::WireVersionError;
6use super::types::WireVersion;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11pub struct VersionRange {
12 pub min: WireVersion,
13 pub max: WireVersion,
14}
15
16impl VersionRange {
17 pub fn new(min: WireVersion, max: WireVersion) -> Self {
19 debug_assert!(
20 min <= max,
21 "VersionRange: min ({min}) must be <= max ({max})"
22 );
23 Self { min, max }
24 }
25
26 pub fn contains(&self, version: WireVersion) -> bool {
28 version >= self.min && version <= self.max
29 }
30}
31
32pub fn negotiate(
39 local: VersionRange,
40 remote: VersionRange,
41) -> Result<WireVersion, WireVersionError> {
42 let intersect_min = local.min.max(remote.min);
44 let intersect_max = local.max.min(remote.max);
45
46 if intersect_min > intersect_max {
47 return Err(WireVersionError::NegotiationFailed {
48 local_min: local.min,
49 local_max: local.max,
50 remote_min: remote.min,
51 remote_max: remote.max,
52 });
53 }
54
55 Ok(intersect_max)
57}
58
59#[derive(
75 Debug,
76 Clone,
77 PartialEq,
78 Eq,
79 serde::Serialize,
80 serde::Deserialize,
81 zerompk::ToMessagePack,
82 zerompk::FromMessagePack,
83)]
84pub struct VersionHandshake {
85 pub range: (u16, u16),
86 #[serde(default)]
90 pub capabilities: u64,
91}
92
93#[derive(
95 Debug,
96 Clone,
97 PartialEq,
98 Eq,
99 serde::Serialize,
100 serde::Deserialize,
101 zerompk::ToMessagePack,
102 zerompk::FromMessagePack,
103)]
104pub struct VersionHandshakeAck {
105 pub agreed: u16,
106 #[serde(default)]
109 pub capabilities: u64,
110}
111
112impl VersionHandshake {
113 pub fn from_range(range: VersionRange) -> Self {
115 Self {
116 range: (range.min.0, range.max.0),
117 capabilities: 0,
118 }
119 }
120
121 pub fn to_range(&self) -> VersionRange {
123 VersionRange::new(WireVersion(self.range.0), WireVersion(self.range.1))
124 }
125}
126
127impl VersionHandshakeAck {
128 pub fn new(agreed: WireVersion) -> Self {
130 Self {
131 agreed: agreed.0,
132 capabilities: 0,
133 }
134 }
135
136 pub fn agreed_version(&self) -> WireVersion {
138 WireVersion(self.agreed)
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    fn v(n: u16) -> WireVersion {
        WireVersion(n)
    }

    fn range(min: u16, max: u16) -> VersionRange {
        VersionRange::new(v(min), v(max))
    }

    #[test]
    fn overlapping_ranges_return_highest_common() {
        let result = negotiate(range(1, 3), range(2, 5)).unwrap();
        assert_eq!(result, v(3));
    }

    #[test]
    fn negotiation_is_symmetric() {
        let pairs = [
            (range(1, 3), range(2, 5)),
            (range(1, 4), range(4, 6)),
            (range(2, 2), range(2, 2)),
        ];
        for (a, b) in pairs {
            assert_eq!(negotiate(a, b).unwrap(), negotiate(b, a).unwrap());
        }
    }

    #[test]
    fn disjoint_ranges_return_negotiation_failed() {
        let err = negotiate(range(1, 2), range(3, 5)).unwrap_err();
        assert!(
            matches!(err, WireVersionError::NegotiationFailed { .. }),
            "expected NegotiationFailed, got: {err}"
        );
    }

    #[test]
    fn inverted_remote_range_fails_negotiation() {
        // Built with a struct literal: `VersionRange::new` would trip its
        // debug_assert before negotiation ever ran.
        let inverted = VersionRange { min: v(4), max: v(2) };
        let err = negotiate(range(1, 5), inverted).unwrap_err();
        assert!(
            matches!(err, WireVersionError::NegotiationFailed { .. }),
            "expected NegotiationFailed, got: {err}"
        );
    }

    #[test]
    fn equal_single_version_succeeds() {
        let result = negotiate(range(2, 2), range(2, 2)).unwrap();
        assert_eq!(result, v(2));
    }

    #[test]
    fn ranges_touching_at_one_version_agree_on_it() {
        // local.max == remote.min: the overlap is exactly one version.
        let result = negotiate(range(1, 4), range(4, 6)).unwrap();
        assert_eq!(result, v(4));
    }

    #[test]
    fn contains_is_inclusive_at_both_ends() {
        let r = range(2, 4);
        assert!(!r.contains(v(1)));
        assert!(r.contains(v(2)));
        assert!(r.contains(v(3)));
        assert!(r.contains(v(4)));
        assert!(!r.contains(v(5)));
    }

    #[test]
    fn handshake_roundtrip() {
        let r = range(1, 2);
        let hs = VersionHandshake::from_range(r);
        let bytes = zerompk::to_msgpack_vec(&hs).unwrap();
        let decoded: VersionHandshake = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.to_range(), r);
    }

    #[test]
    fn handshake_ack_roundtrip() {
        let ack = VersionHandshakeAck::new(v(3));
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: VersionHandshakeAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded, ack);
        assert_eq!(decoded.agreed_version(), v(3));
    }
}