nomad_protocol/extensions/
negotiation.rs1use thiserror::Error;
7
/// Type codes that identify an extension on the wire.
///
/// Each value is what goes into the `ext_type` field of an [`Extension`].
pub mod ext_type {
    /// Compression extension; its first payload byte is read as the compression level.
    pub const COMPRESSION: u16 = 0x0001;

    /// Type code for the priority extension.
    pub const PRIORITY: u16 = 0x0002;

    /// Type code for the batching extension.
    pub const BATCHING: u16 = 0x0003;

    /// Type code for the rate-hints extension.
    pub const RATE_HINTS: u16 = 0x0004;

    /// Type code for the selective-sync extension.
    pub const SELECTIVE_SYNC: u16 = 0x0005;

    /// Type code for the checkpoint extension.
    pub const CHECKPOINT: u16 = 0x0006;

    /// Type code for the metadata extension.
    pub const METADATA: u16 = 0x0007;
}
30
31#[derive(Debug, Error, Clone, PartialEq, Eq)]
33pub enum NegotiationError {
34 #[error("buffer too short: expected {expected}, got {actual}")]
36 TooShort {
37 expected: usize,
39 actual: usize,
41 },
42
43 #[error("invalid extension data")]
45 InvalidData,
46
47 #[error("extension not supported: 0x{0:04x}")]
49 NotSupported(u16),
50
51 #[error("buffer too small for encoding")]
53 BufferTooSmall,
54}
55
/// One protocol extension: a 16-bit type code paired with an opaque payload.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Extension {
    /// Type code identifying the extension (known values live in the `ext_type` module).
    pub ext_type: u16,
    /// Raw payload bytes; how they are interpreted depends on `ext_type`.
    pub data: Vec<u8>,
}
71
/// Size in bytes of the fixed extension header: a `u16` type code followed by
/// a `u16` payload length.
pub const EXTENSION_HEADER_SIZE: usize = 2 * std::mem::size_of::<u16>();
74
75impl Extension {
76 pub fn new(ext_type: u16, data: Vec<u8>) -> Self {
78 Self { ext_type, data }
79 }
80
81 pub fn empty(ext_type: u16) -> Self {
83 Self {
84 ext_type,
85 data: Vec::new(),
86 }
87 }
88
89 pub fn compression(level: u8) -> Self {
91 Self {
92 ext_type: ext_type::COMPRESSION,
93 data: vec![level],
94 }
95 }
96
97 pub fn compression_level(&self) -> Option<u8> {
99 if self.ext_type == ext_type::COMPRESSION && !self.data.is_empty() {
100 Some(self.data[0])
101 } else {
102 None
103 }
104 }
105
106 pub fn wire_size(&self) -> usize {
108 EXTENSION_HEADER_SIZE + self.data.len()
109 }
110
111 pub fn encode(&self) -> Vec<u8> {
113 let mut buf = Vec::with_capacity(self.wire_size());
114 buf.extend_from_slice(&self.ext_type.to_le_bytes());
115 buf.extend_from_slice(&(self.data.len() as u16).to_le_bytes());
116 buf.extend_from_slice(&self.data);
117 buf
118 }
119
120 pub fn encode_into(&self, buf: &mut [u8]) -> Result<usize, NegotiationError> {
122 let size = self.wire_size();
123 if buf.len() < size {
124 return Err(NegotiationError::BufferTooSmall);
125 }
126
127 buf[0..2].copy_from_slice(&self.ext_type.to_le_bytes());
128 buf[2..4].copy_from_slice(&(self.data.len() as u16).to_le_bytes());
129 buf[4..size].copy_from_slice(&self.data);
130
131 Ok(size)
132 }
133
134 pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
136 if data.len() < EXTENSION_HEADER_SIZE {
137 return Err(NegotiationError::TooShort {
138 expected: EXTENSION_HEADER_SIZE,
139 actual: data.len(),
140 });
141 }
142
143 let ext_type = u16::from_le_bytes(data[0..2].try_into().unwrap());
144 let ext_len = u16::from_le_bytes(data[2..4].try_into().unwrap()) as usize;
145
146 if data.len() < EXTENSION_HEADER_SIZE + ext_len {
147 return Err(NegotiationError::TooShort {
148 expected: EXTENSION_HEADER_SIZE + ext_len,
149 actual: data.len(),
150 });
151 }
152
153 let ext_data = data[EXTENSION_HEADER_SIZE..EXTENSION_HEADER_SIZE + ext_len].to_vec();
154
155 Ok(Self {
156 ext_type,
157 data: ext_data,
158 })
159 }
160
161 pub fn decode_with_length(data: &[u8]) -> Result<(Self, usize), NegotiationError> {
163 let ext = Self::decode(data)?;
164 let consumed = ext.wire_size();
165 Ok((ext, consumed))
166 }
167}
168
169#[derive(Debug, Clone, Default)]
171pub struct ExtensionSet {
172 extensions: Vec<Extension>,
173}
174
175impl ExtensionSet {
176 pub fn new() -> Self {
178 Self::default()
179 }
180
181 pub fn add(&mut self, ext: Extension) {
183 if let Some(existing) = self.extensions.iter_mut().find(|e| e.ext_type == ext.ext_type) {
185 *existing = ext;
186 } else {
187 self.extensions.push(ext);
188 }
189 }
190
191 pub fn add_compression(&mut self, level: u8) {
193 self.add(Extension::compression(level));
194 }
195
196 pub fn get(&self, ext_type: u16) -> Option<&Extension> {
198 self.extensions.iter().find(|e| e.ext_type == ext_type)
199 }
200
201 pub fn has(&self, ext_type: u16) -> bool {
203 self.extensions.iter().any(|e| e.ext_type == ext_type)
204 }
205
206 pub fn has_compression(&self) -> bool {
208 self.has(ext_type::COMPRESSION)
209 }
210
211 pub fn compression_level(&self) -> Option<u8> {
213 self.get(ext_type::COMPRESSION)
214 .and_then(|e| e.compression_level())
215 }
216
217 pub fn iter(&self) -> impl Iterator<Item = &Extension> {
219 self.extensions.iter()
220 }
221
222 pub fn len(&self) -> usize {
224 self.extensions.len()
225 }
226
227 pub fn is_empty(&self) -> bool {
229 self.extensions.is_empty()
230 }
231
232 pub fn wire_size(&self) -> usize {
234 self.extensions.iter().map(|e| e.wire_size()).sum()
235 }
236
237 pub fn encode(&self) -> Vec<u8> {
239 let mut buf = Vec::with_capacity(self.wire_size());
240 for ext in &self.extensions {
241 buf.extend_from_slice(&ext.encode());
242 }
243 buf
244 }
245
246 pub fn decode(mut data: &[u8]) -> Result<Self, NegotiationError> {
248 let mut set = Self::new();
249
250 while !data.is_empty() {
251 let (ext, consumed) = Extension::decode_with_length(data)?;
252 set.add(ext);
253 data = &data[consumed..];
254 }
255
256 Ok(set)
257 }
258
259 pub fn remove(&mut self, ext_type: u16) -> Option<Extension> {
261 if let Some(pos) = self.extensions.iter().position(|e| e.ext_type == ext_type) {
262 Some(self.extensions.remove(pos))
263 } else {
264 None
265 }
266 }
267
268 pub fn clear(&mut self) {
270 self.extensions.clear();
271 }
272}
273
274pub fn negotiate(offered: &ExtensionSet, supported: &ExtensionSet) -> ExtensionSet {
278 let mut result = ExtensionSet::new();
279
280 for ext in offered.iter() {
281 if let Some(supported_ext) = supported.get(ext.ext_type) {
282 if ext.ext_type == ext_type::COMPRESSION {
284 let offered_level = ext.compression_level().unwrap_or(3);
285 let supported_level = supported_ext.compression_level().unwrap_or(3);
286 result.add(Extension::compression(offered_level.min(supported_level)));
287 } else {
288 result.add(ext.clone());
290 }
291 }
292 }
293
294 result
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// A single extension survives an encode/decode round trip.
    #[test]
    fn test_extension_encode_decode() {
        let original = Extension::new(0x1234, vec![1, 2, 3, 4]);

        let wire = original.encode();
        assert_eq!(wire.len(), EXTENSION_HEADER_SIZE + 4);

        let parsed = Extension::decode(&wire).unwrap();
        assert_eq!(parsed, original);
    }

    /// The compression helper stores the level and it survives the wire format.
    #[test]
    fn test_compression_extension() {
        let compression = Extension::compression(5);

        assert_eq!(compression.ext_type, ext_type::COMPRESSION);
        assert_eq!(compression.compression_level(), Some(5));

        let wire = compression.encode();
        let parsed = Extension::decode(&wire).unwrap();
        assert_eq!(parsed.compression_level(), Some(5));
    }

    /// An extension without payload is exactly one header long.
    #[test]
    fn test_empty_extension() {
        let bare = Extension::empty(0xFFFF);

        assert_eq!(bare.wire_size(), EXTENSION_HEADER_SIZE);
        assert!(bare.data.is_empty());

        let wire = bare.encode();
        let parsed = Extension::decode(&wire).unwrap();
        assert_eq!(parsed, bare);
    }

    /// Fewer bytes than a header is rejected.
    #[test]
    fn test_decode_too_short() {
        let half_header = [0u8; 2];
        let outcome = Extension::decode(&half_header);
        assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));
    }

    /// A header announcing 10 payload bytes followed by only 2 is rejected.
    #[test]
    fn test_decode_data_truncated() {
        let truncated = [0x01, 0x00, 0x0A, 0x00, 0x01, 0x02];
        let outcome = Extension::decode(&truncated);
        assert!(matches!(outcome, Err(NegotiationError::TooShort { .. })));
    }

    /// Basic membership queries on a set.
    #[test]
    fn test_extension_set() {
        let mut extensions = ExtensionSet::new();

        extensions.add_compression(5);
        extensions.add(Extension::empty(0x1234));

        assert_eq!(extensions.len(), 2);
        assert!(extensions.has_compression());
        assert!(extensions.has(0x1234));
        assert!(!extensions.has(0x9999));

        assert_eq!(extensions.compression_level(), Some(5));
    }

    /// A whole set survives an encode/decode round trip.
    #[test]
    fn test_extension_set_encode_decode() {
        let mut extensions = ExtensionSet::new();
        extensions.add_compression(3);
        extensions.add(Extension::new(0x0100, vec![0xAA, 0xBB]));

        let wire = extensions.encode();
        let parsed = ExtensionSet::decode(&wire).unwrap();

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.compression_level(), Some(3));
        assert!(parsed.has(0x0100));
    }

    /// Adding the same type twice replaces the entry instead of duplicating it.
    #[test]
    fn test_extension_set_replace() {
        let mut extensions = ExtensionSet::new();

        extensions.add_compression(3);
        assert_eq!(extensions.compression_level(), Some(3));

        extensions.add_compression(10);
        assert_eq!(extensions.compression_level(), Some(10));
        assert_eq!(extensions.len(), 1);
    }

    /// Negotiation keeps only shared extensions and the lower compression level.
    #[test]
    fn test_negotiate_extensions() {
        let mut client = ExtensionSet::new();
        client.add_compression(10);
        client.add(Extension::empty(0x1234));

        let mut server = ExtensionSet::new();
        server.add_compression(5);

        let agreed = negotiate(&client, &server);

        assert_eq!(agreed.len(), 1);
        assert!(agreed.has_compression());
        assert_eq!(agreed.compression_level(), Some(5));
        assert!(!agreed.has(0x1234));
    }

    /// Disjoint offers negotiate to an empty set.
    #[test]
    fn test_negotiate_no_overlap() {
        let mut client = ExtensionSet::new();
        client.add(Extension::empty(0x1111));

        let mut server = ExtensionSet::new();
        server.add(Extension::empty(0x2222));

        let agreed = negotiate(&client, &server);
        assert!(agreed.is_empty());
    }

    /// Removing an extension returns it and shrinks the set.
    #[test]
    fn test_extension_set_remove() {
        let mut extensions = ExtensionSet::new();
        extensions.add_compression(5);
        extensions.add(Extension::empty(0x1234));

        assert_eq!(extensions.len(), 2);

        let taken = extensions.remove(ext_type::COMPRESSION);
        assert!(taken.is_some());
        assert_eq!(extensions.len(), 1);
        assert!(!extensions.has_compression());
    }

    /// Encoding into a caller-provided buffer reports the bytes written.
    #[test]
    fn test_encode_into() {
        let original = Extension::new(0x1234, vec![1, 2, 3]);
        let mut scratch = [0u8; 100];

        let written = original.encode_into(&mut scratch).unwrap();
        assert_eq!(written, EXTENSION_HEADER_SIZE + 3);

        let parsed = Extension::decode(&scratch[..written]).unwrap();
        assert_eq!(parsed, original);
    }

    /// A buffer that only fits the header is rejected for a 5-byte payload.
    #[test]
    fn test_encode_into_too_small() {
        let original = Extension::new(0x1234, vec![1, 2, 3, 4, 5]);
        let mut scratch = [0u8; 4];

        let outcome = original.encode_into(&mut scratch);
        assert!(matches!(outcome, Err(NegotiationError::BufferTooSmall)));
    }
}