Skip to main content

nodedb_cluster/wire_version/
negotiation.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Wire-version range negotiation.
4
5use super::error::WireVersionError;
6use super::types::WireVersion;
7
8/// A contiguous inclusive range of wire-protocol versions supported by one
9/// endpoint.
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11pub struct VersionRange {
12    pub min: WireVersion,
13    pub max: WireVersion,
14}
15
16impl VersionRange {
17    /// Construct a new `VersionRange`. Panics in debug mode if `min > max`.
18    pub fn new(min: WireVersion, max: WireVersion) -> Self {
19        debug_assert!(
20            min <= max,
21            "VersionRange: min ({min}) must be <= max ({max})"
22        );
23        Self { min, max }
24    }
25
26    /// Returns `true` if `version` falls within `[min, max]`.
27    pub fn contains(&self, version: WireVersion) -> bool {
28        version >= self.min && version <= self.max
29    }
30}
31
32/// Negotiate the highest wire version both sides support.
33///
34/// Returns the highest `WireVersion` in the intersection of `local` and
35/// `remote`. If the ranges are disjoint, returns
36/// `WireVersionError::NegotiationFailed` with both ranges captured for
37/// operator diagnostics.
38pub fn negotiate(
39    local: VersionRange,
40    remote: VersionRange,
41) -> Result<WireVersion, WireVersionError> {
42    // Intersection: [max(local.min, remote.min), min(local.max, remote.max)]
43    let intersect_min = local.min.max(remote.min);
44    let intersect_max = local.max.min(remote.max);
45
46    if intersect_min > intersect_max {
47        return Err(WireVersionError::NegotiationFailed {
48            local_min: local.min,
49            local_max: local.max,
50            remote_min: remote.min,
51            remote_max: remote.max,
52        });
53    }
54
55    // Highest common version — prefer new capabilities over old.
56    Ok(intersect_max)
57}
58
/// Handshake message carrying one endpoint's supported wire-version range.
///
/// Both sides exchange their `VersionRange` at the start of every nexar/QUIC
/// connection. The receiving side calls [`negotiate`] to compute the agreed
/// version and rejects the connection if the ranges are disjoint. On success
/// it answers with a [`VersionHandshakeAck`].
///
/// # Transport wiring
///
/// The actual injection into the QUIC connection accept loop lives in
/// `transport/client` (outbound) and `transport/server` (inbound).
/// On connect, the client opens a dedicated bidi stream and sends
/// `VersionHandshake { range: local_range }` before any RPC frames.
/// The server reads the handshake, negotiates, and replies with
/// `VersionHandshakeAck { agreed }`. If ranges are disjoint the server
/// closes the QUIC connection with application error code 0x01.
/// See `handshake_io` and [`negotiate`] for the implementation.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct VersionHandshake {
    /// Advertised inclusive version range as raw `(min, max)` wire values.
    /// `from_range` / `to_range` convert to and from the typed `VersionRange`.
    /// This comes from the remote peer, so `min <= max` is not guaranteed.
    pub range: (u16, u16),
    /// Optional capability bitmask for forward-compatible feature advertisement.
    /// Unknown bits are ignored by the receiver.  Defaults to `0` (no extra
    /// capabilities) so older peers that do not set this field remain compatible.
    ///
    /// NOTE(review): `#[serde(default)]` only affects serde, but the handshake
    /// is encoded with `zerompk`. Confirm that zerompk's decoder also tolerates
    /// a missing `capabilities` field from older peers.
    #[serde(default)]
    pub capabilities: u64,
}
92
/// Server-side acknowledgement returned after negotiation succeeds.
///
/// Carries the version chosen by [`negotiate`] back to the client.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct VersionHandshakeAck {
    /// Agreed wire version as the raw `u16` carried on the wire.
    /// Use `agreed_version` for the typed `WireVersion`.
    pub agreed: u16,
    /// Capability bitmask echoed (or narrowed) by the server.
    /// Unknown bits are ignored by the receiver.  Defaults to `0`.
    ///
    /// NOTE(review): `#[serde(default)]` only affects serde, but this message
    /// is encoded with `zerompk`. Confirm a missing field decodes to `0` there
    /// as well.
    #[serde(default)]
    pub capabilities: u64,
}
111
112impl VersionHandshake {
113    /// Build a handshake from a [`VersionRange`].
114    pub fn from_range(range: VersionRange) -> Self {
115        Self {
116            range: (range.min.0, range.max.0),
117            capabilities: 0,
118        }
119    }
120
121    /// Recover the [`VersionRange`] from wire fields.
122    pub fn to_range(&self) -> VersionRange {
123        VersionRange::new(WireVersion(self.range.0), WireVersion(self.range.1))
124    }
125}
126
127impl VersionHandshakeAck {
128    /// Construct an ack for the given agreed wire version.
129    pub fn new(agreed: WireVersion) -> Self {
130        Self {
131            agreed: agreed.0,
132            capabilities: 0,
133        }
134    }
135
136    /// The agreed wire version as a typed [`WireVersion`].
137    pub fn agreed_version(&self) -> WireVersion {
138        WireVersion(self.agreed)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    fn v(n: u16) -> WireVersion {
        WireVersion(n)
    }

    fn range(min: u16, max: u16) -> VersionRange {
        VersionRange::new(v(min), v(max))
    }

    #[test]
    fn overlapping_ranges_return_highest_common() {
        // local: [1,3]  remote: [2,5]  → intersection [2,3] → pick 3
        let result = negotiate(range(1, 3), range(2, 5)).unwrap();
        assert_eq!(result, v(3));
    }

    #[test]
    fn negotiation_is_symmetric() {
        // Swapping which side is "local" must not change the agreed version.
        assert_eq!(
            negotiate(range(1, 3), range(2, 5)).unwrap(),
            negotiate(range(2, 5), range(1, 3)).unwrap()
        );
    }

    #[test]
    fn nested_range_picks_inner_max() {
        // local: [1,10]  remote: [3,5]  → intersection [3,5] → pick 5
        let result = negotiate(range(1, 10), range(3, 5)).unwrap();
        assert_eq!(result, v(5));
    }

    #[test]
    fn disjoint_ranges_return_negotiation_failed() {
        // local: [1,2]  remote: [3,5]  → no overlap
        let err = negotiate(range(1, 2), range(3, 5)).unwrap_err();
        assert!(
            matches!(
                &err,
                WireVersionError::NegotiationFailed {
                    local_min,
                    local_max,
                    remote_min,
                    remote_max,
                } if (*local_min, *local_max, *remote_min, *remote_max)
                    == (v(1), v(2), v(3), v(5))
            ),
            "expected NegotiationFailed capturing both ranges, got: {err}"
        );
    }

    #[test]
    fn equal_single_version_succeeds() {
        let result = negotiate(range(2, 2), range(2, 2)).unwrap();
        assert_eq!(result, v(2));
    }

    #[test]
    fn ranges_touching_at_one_version_agree_on_it() {
        // local: [1,4]  remote: [4,6]  → intersection [4,4] → pick 4
        let result = negotiate(range(1, 4), range(4, 6)).unwrap();
        assert_eq!(result, v(4));
    }

    #[test]
    fn contains_is_inclusive_on_both_ends() {
        let r = range(2, 4);
        assert!(!r.contains(v(1)));
        assert!(r.contains(v(2)));
        assert!(r.contains(v(3)));
        assert!(r.contains(v(4)));
        assert!(!r.contains(v(5)));
    }

    #[test]
    fn handshake_roundtrip() {
        let r = range(1, 2);
        let hs = VersionHandshake::from_range(r);
        let bytes = zerompk::to_msgpack_vec(&hs).unwrap();
        let decoded: VersionHandshake = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded, hs);
        assert_eq!(decoded.to_range(), r);
    }

    #[test]
    fn handshake_ack_roundtrip() {
        let ack = VersionHandshakeAck::new(v(3));
        let bytes = zerompk::to_msgpack_vec(&ack).unwrap();
        let decoded: VersionHandshakeAck = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded, ack);
        assert_eq!(decoded.agreed_version(), v(3));
    }
}
187        let bytes = zerompk::to_msgpack_vec(&hs).unwrap();
188        let decoded: VersionHandshake = zerompk::from_msgpack(&bytes).unwrap();
189        assert_eq!(decoded.to_range(), r);
190    }
191}