nomad_protocol/extensions/
negotiation.rs

1//! Extension negotiation
2//!
3//! Implements TLV-based extension negotiation during handshake.
4//! See 4-EXTENSIONS.md for specification.
5
6use thiserror::Error;
7
/// Well-known extension type identifiers.
///
/// Reserved ranges:
/// - 0x0001-0x00FF: Core protocol extensions
/// - 0x0100-0x0FFF: Application-specific extensions
/// - 0xF000-0xFFFF: Experimental/private extensions
///
/// NOTE(review): 0x1000-0xEFFF is not covered by any range listed here — confirm its
/// status against 4-EXTENSIONS.md.
pub mod ext_type {
    /// Payload compression using zstd.
    pub const COMPRESSION: u16 = 0x0001;

    /// Priority levels for updates (critical vs cosmetic).
    pub const PRIORITY: u16 = 0x0002;

    /// Coalescing of several small updates into one frame.
    pub const BATCHING: u16 = 0x0003;

    /// Server-provided hints about an acceptable update frequency.
    pub const RATE_HINTS: u16 = 0x0004;

    /// Synchronisation restricted to selected regions of the state.
    pub const SELECTIVE_SYNC: u16 = 0x0005;

    /// Full state checkpoint used for recovery or the initial sync.
    pub const CHECKPOINT: u16 = 0x0006;

    /// Attached metadata such as timestamps, causality, or user info.
    pub const METADATA: u16 = 0x0007;
}
30
31/// Errors from extension negotiation.
32#[derive(Debug, Error, Clone, PartialEq, Eq)]
33pub enum NegotiationError {
34    /// Input buffer is too short to contain valid extension data.
35    #[error("buffer too short: expected {expected}, got {actual}")]
36    TooShort {
37        /// Minimum bytes required.
38        expected: usize,
39        /// Actual bytes available.
40        actual: usize,
41    },
42
43    /// Extension data is malformed or invalid.
44    #[error("invalid extension data")]
45    InvalidData,
46
47    /// Requested extension type is not supported by this implementation.
48    #[error("extension not supported: 0x{0:04x}")]
49    NotSupported(u16),
50
51    /// Output buffer is too small to hold encoded extension.
52    #[error("buffer too small for encoding")]
53    BufferTooSmall,
54}
55
/// A single extension in Type-Length-Value (TLV) form.
///
/// Wire format:
/// ```text
/// +0   Extension Type (2 bytes LE16)
/// +2   Extension Length (2 bytes LE16)
/// +4   Extension Data (variable)
/// ```
///
/// The length is not stored in the struct; it is taken from `data.len()` when encoding.
/// Because the wire length field is two bytes, `data` should not exceed 65535 bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Extension {
    /// Extension type identifier, e.g. one of the constants in the `ext_type` module.
    pub ext_type: u16,
    /// Raw payload that follows the 4-byte header on the wire.
    pub data: Vec<u8>,
}
71
/// Size in bytes of the TLV header that precedes every extension payload:
/// a 2-byte type followed by a 2-byte length.
pub const EXTENSION_HEADER_SIZE: usize = 2 + 2;
74
75impl Extension {
76    /// Create a new extension
77    pub fn new(ext_type: u16, data: Vec<u8>) -> Self {
78        Self { ext_type, data }
79    }
80
81    /// Create an empty extension (no data)
82    pub fn empty(ext_type: u16) -> Self {
83        Self {
84            ext_type,
85            data: Vec::new(),
86        }
87    }
88
89    /// Create compression extension with level
90    pub fn compression(level: u8) -> Self {
91        Self {
92            ext_type: ext_type::COMPRESSION,
93            data: vec![level],
94        }
95    }
96
97    /// Get compression level if this is a compression extension
98    pub fn compression_level(&self) -> Option<u8> {
99        if self.ext_type == ext_type::COMPRESSION && !self.data.is_empty() {
100            Some(self.data[0])
101        } else {
102            None
103        }
104    }
105
106    /// Total wire size
107    pub fn wire_size(&self) -> usize {
108        EXTENSION_HEADER_SIZE + self.data.len()
109    }
110
111    /// Encode to bytes
112    pub fn encode(&self) -> Vec<u8> {
113        let mut buf = Vec::with_capacity(self.wire_size());
114        buf.extend_from_slice(&self.ext_type.to_le_bytes());
115        buf.extend_from_slice(&(self.data.len() as u16).to_le_bytes());
116        buf.extend_from_slice(&self.data);
117        buf
118    }
119
120    /// Encode into buffer, returns bytes written
121    pub fn encode_into(&self, buf: &mut [u8]) -> Result<usize, NegotiationError> {
122        let size = self.wire_size();
123        if buf.len() < size {
124            return Err(NegotiationError::BufferTooSmall);
125        }
126
127        buf[0..2].copy_from_slice(&self.ext_type.to_le_bytes());
128        buf[2..4].copy_from_slice(&(self.data.len() as u16).to_le_bytes());
129        buf[4..size].copy_from_slice(&self.data);
130
131        Ok(size)
132    }
133
134    /// Decode from bytes
135    pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
136        if data.len() < EXTENSION_HEADER_SIZE {
137            return Err(NegotiationError::TooShort {
138                expected: EXTENSION_HEADER_SIZE,
139                actual: data.len(),
140            });
141        }
142
143        let ext_type = u16::from_le_bytes(data[0..2].try_into().unwrap());
144        let ext_len = u16::from_le_bytes(data[2..4].try_into().unwrap()) as usize;
145
146        if data.len() < EXTENSION_HEADER_SIZE + ext_len {
147            return Err(NegotiationError::TooShort {
148                expected: EXTENSION_HEADER_SIZE + ext_len,
149                actual: data.len(),
150            });
151        }
152
153        let ext_data = data[EXTENSION_HEADER_SIZE..EXTENSION_HEADER_SIZE + ext_len].to_vec();
154
155        Ok(Self {
156            ext_type,
157            data: ext_data,
158        })
159    }
160
161    /// Decode from bytes, returning extension and bytes consumed
162    pub fn decode_with_length(data: &[u8]) -> Result<(Self, usize), NegotiationError> {
163        let ext = Self::decode(data)?;
164        let consumed = ext.wire_size();
165        Ok((ext, consumed))
166    }
167}
168
169/// Extension set for negotiation
170#[derive(Debug, Clone, Default)]
171pub struct ExtensionSet {
172    extensions: Vec<Extension>,
173}
174
175impl ExtensionSet {
176    /// Create an empty extension set
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    /// Add an extension
182    pub fn add(&mut self, ext: Extension) {
183        // Replace if already exists
184        if let Some(existing) = self.extensions.iter_mut().find(|e| e.ext_type == ext.ext_type) {
185            *existing = ext;
186        } else {
187            self.extensions.push(ext);
188        }
189    }
190
191    /// Add compression extension
192    pub fn add_compression(&mut self, level: u8) {
193        self.add(Extension::compression(level));
194    }
195
196    /// Get extension by type
197    pub fn get(&self, ext_type: u16) -> Option<&Extension> {
198        self.extensions.iter().find(|e| e.ext_type == ext_type)
199    }
200
201    /// Check if extension is present
202    pub fn has(&self, ext_type: u16) -> bool {
203        self.extensions.iter().any(|e| e.ext_type == ext_type)
204    }
205
206    /// Check if compression is enabled
207    pub fn has_compression(&self) -> bool {
208        self.has(ext_type::COMPRESSION)
209    }
210
211    /// Get compression level if enabled
212    pub fn compression_level(&self) -> Option<u8> {
213        self.get(ext_type::COMPRESSION)
214            .and_then(|e| e.compression_level())
215    }
216
217    /// Get all extensions
218    pub fn iter(&self) -> impl Iterator<Item = &Extension> {
219        self.extensions.iter()
220    }
221
222    /// Number of extensions
223    pub fn len(&self) -> usize {
224        self.extensions.len()
225    }
226
227    /// Check if empty
228    pub fn is_empty(&self) -> bool {
229        self.extensions.is_empty()
230    }
231
232    /// Total wire size
233    pub fn wire_size(&self) -> usize {
234        self.extensions.iter().map(|e| e.wire_size()).sum()
235    }
236
237    /// Encode all extensions
238    pub fn encode(&self) -> Vec<u8> {
239        let mut buf = Vec::with_capacity(self.wire_size());
240        for ext in &self.extensions {
241            buf.extend_from_slice(&ext.encode());
242        }
243        buf
244    }
245
246    /// Decode all extensions from buffer
247    pub fn decode(mut data: &[u8]) -> Result<Self, NegotiationError> {
248        let mut set = Self::new();
249
250        while !data.is_empty() {
251            let (ext, consumed) = Extension::decode_with_length(data)?;
252            set.add(ext);
253            data = &data[consumed..];
254        }
255
256        Ok(set)
257    }
258
259    /// Remove an extension by type
260    pub fn remove(&mut self, ext_type: u16) -> Option<Extension> {
261        if let Some(pos) = self.extensions.iter().position(|e| e.ext_type == ext_type) {
262            Some(self.extensions.remove(pos))
263        } else {
264            None
265        }
266    }
267
268    /// Clear all extensions
269    pub fn clear(&mut self) {
270        self.extensions.clear();
271    }
272}
273
274/// Negotiate extensions between client and server offers
275///
276/// Returns the intersection of supported extensions.
277pub fn negotiate(offered: &ExtensionSet, supported: &ExtensionSet) -> ExtensionSet {
278    let mut result = ExtensionSet::new();
279
280    for ext in offered.iter() {
281        if let Some(supported_ext) = supported.get(ext.ext_type) {
282            // For compression, use the lower level
283            if ext.ext_type == ext_type::COMPRESSION {
284                let offered_level = ext.compression_level().unwrap_or(3);
285                let supported_level = supported_ext.compression_level().unwrap_or(3);
286                result.add(Extension::compression(offered_level.min(supported_level)));
287            } else {
288                // For other extensions, use the offered version
289                result.add(ext.clone());
290            }
291        }
292    }
293
294    result
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `ext` and decode the resulting bytes, panicking if decoding fails.
    fn roundtrip(ext: &Extension) -> Extension {
        Extension::decode(&ext.encode()).expect("freshly encoded extension must decode")
    }

    #[test]
    fn test_extension_encode_decode() {
        let original = Extension::new(0x1234, vec![1, 2, 3, 4]);

        assert_eq!(original.encode().len(), EXTENSION_HEADER_SIZE + 4);
        assert_eq!(roundtrip(&original), original);
    }

    #[test]
    fn test_compression_extension() {
        let ext = Extension::compression(5);

        assert_eq!(ext.ext_type, ext_type::COMPRESSION);
        assert_eq!(ext.compression_level(), Some(5));
        // The level must survive a trip through the wire format.
        assert_eq!(roundtrip(&ext).compression_level(), Some(5));
    }

    #[test]
    fn test_empty_extension() {
        let ext = Extension::empty(0xFFFF);

        assert!(ext.data.is_empty());
        assert_eq!(ext.wire_size(), EXTENSION_HEADER_SIZE);
        assert_eq!(roundtrip(&ext), ext);
    }

    #[test]
    fn test_decode_too_short() {
        // Two bytes cannot even hold the header.
        let outcome = Extension::decode(&[0u8; 2]);
        assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));
    }

    #[test]
    fn test_decode_data_truncated() {
        // Type 0x0001 declares 10 payload bytes, but only 2 follow the header.
        let bytes = [0x01, 0x00, 0x0A, 0x00, 0x01, 0x02];
        assert!(matches!(
            Extension::decode(&bytes),
            Err(NegotiationError::TooShort { .. })
        ));
    }

    #[test]
    fn test_extension_set() {
        let mut set = ExtensionSet::new();
        set.add_compression(5);
        set.add(Extension::empty(0x1234));

        assert_eq!(set.len(), 2);
        assert!(set.has_compression());
        assert!(set.has(0x1234));
        assert!(!set.has(0x9999));
        assert_eq!(set.compression_level(), Some(5));
    }

    #[test]
    fn test_extension_set_encode_decode() {
        let mut set = ExtensionSet::new();
        set.add_compression(3);
        set.add(Extension::new(0x0100, vec![0xAA, 0xBB]));

        let decoded = ExtensionSet::decode(&set.encode()).unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded.compression_level(), Some(3));
        assert!(decoded.has(0x0100));
    }

    #[test]
    fn test_extension_set_replace() {
        let mut set = ExtensionSet::new();

        set.add_compression(3);
        assert_eq!(set.compression_level(), Some(3));

        // A second compression entry overwrites the first instead of being appended.
        set.add_compression(10);
        assert_eq!(set.compression_level(), Some(10));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_negotiate_extensions() {
        let mut client = ExtensionSet::new();
        client.add_compression(10);
        client.add(Extension::empty(0x1234));

        // The server supports compression only, not 0x1234.
        let mut server = ExtensionSet::new();
        server.add_compression(5);

        let agreed = negotiate(&client, &server);

        assert_eq!(agreed.len(), 1);
        assert!(agreed.has_compression());
        assert_eq!(agreed.compression_level(), Some(5)); // min(10, 5)
        assert!(!agreed.has(0x1234));
    }

    #[test]
    fn test_negotiate_no_overlap() {
        let mut client = ExtensionSet::new();
        client.add(Extension::empty(0x1111));

        let mut server = ExtensionSet::new();
        server.add(Extension::empty(0x2222));

        assert!(negotiate(&client, &server).is_empty());
    }

    #[test]
    fn test_extension_set_remove() {
        let mut set = ExtensionSet::new();
        set.add_compression(5);
        set.add(Extension::empty(0x1234));
        assert_eq!(set.len(), 2);

        let removed = set.remove(ext_type::COMPRESSION);

        assert!(removed.is_some());
        assert_eq!(set.len(), 1);
        assert!(!set.has_compression());
    }

    #[test]
    fn test_encode_into() {
        let ext = Extension::new(0x1234, vec![1, 2, 3]);
        let mut buf = [0u8; 100];

        let written = ext.encode_into(&mut buf).unwrap();
        assert_eq!(written, EXTENSION_HEADER_SIZE + 3);

        let decoded = Extension::decode(&buf[..written]).unwrap();
        assert_eq!(decoded, ext);
    }

    #[test]
    fn test_encode_into_too_small() {
        let ext = Extension::new(0x1234, vec![1, 2, 3, 4, 5]);
        // Room for the header only, not the 5-byte payload.
        let mut buf = [0u8; 4];

        assert!(matches!(
            ext.encode_into(&mut buf),
            Err(NegotiationError::BufferTooSmall)
        ));
    }
}