1use std::collections::HashMap;
2
3use parsely_rs::*;
4
5#[derive(Debug, PartialEq)]
45pub struct OneByteHeaderExtension {
46 id: u4,
47 data: Bits,
48}
49
50impl OneByteHeaderExtension {
51 pub const TYPE: u16 = 0xBEDE;
52
53 pub fn type_matches(ext_type: u16) -> bool {
54 ext_type == Self::TYPE
55 }
56
57 pub fn new(id: u4, data: Bits) -> Self {
58 Self { id, data }
59 }
60
61 pub fn id(&self) -> u4 {
62 self.id
63 }
64
65 pub fn data(&self) -> &[u8] {
66 self.data.chunk_bytes()
67 }
68}
69
70impl From<OneByteHeaderExtension> for SomeHeaderExtension {
71 fn from(value: OneByteHeaderExtension) -> Self {
72 SomeHeaderExtension::OneByteHeaderExtension(value)
73 }
74}
75
76impl<B: BitBuf> ParselyRead<B> for OneByteHeaderExtension {
77 type Ctx = ();
78
79 fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
80 let id = buf.get_u4().context("id")?;
81
82 let data_length_bytes = match id {
84 i if i == 0 => {
86 let _ = buf.get_u4();
88 0
89 }
90 _ => {
93 let length: usize = buf.get_u4().context("length")?.into();
94 length + 1
97 }
98 };
99
100 if buf.remaining_bytes() < data_length_bytes {
101 bail!(
102 "Header extension length was {data_length_bytes} but buffer only has {} bytes remaining",
103 buf.remaining_bytes()
104 );
105 }
106 let data = Bits::copy_from_bytes(&buf.chunk_bytes()[..data_length_bytes]);
107 buf.advance_bytes(data_length_bytes);
108 Ok(OneByteHeaderExtension { id, data })
109 }
110}
111
112impl<B: BitBufMut> ParselyWrite<B> for OneByteHeaderExtension {
113 type Ctx = ();
114
115 fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
116 buf.put_u4(self.id).context("Writing field 'id'")?;
117 let data_length_bytes = self.data.len_bytes();
118 let length_field = u4::try_from(data_length_bytes - 1).context("fitting length in u4")?;
119 buf.put_u4(length_field).context("Writing field 'length'")?;
120 buf.try_put_slice_bytes(self.data())
121 .context("Writing field 'data'")?;
122
123 Ok(())
124 }
125}
126
127impl_stateless_sync!(OneByteHeaderExtension);
128
129#[derive(Debug, PartialEq)]
173pub struct TwoByteHeaderExtension {
174 id: u8,
175 data: Bits,
176}
177
178impl TwoByteHeaderExtension {
179 const TYPE_MASK: u16 = 0xFFF0;
180 pub const TYPE: u16 = 0x1000;
181
182 pub fn type_matches(ext_type: u16) -> bool {
183 (ext_type & Self::TYPE_MASK) == Self::TYPE
184 }
185
186 pub fn id(&self) -> u8 {
187 self.id
188 }
189
190 pub fn data(&self) -> &[u8] {
191 self.data.chunk_bytes()
192 }
193}
194
195impl From<TwoByteHeaderExtension> for SomeHeaderExtension {
196 fn from(value: TwoByteHeaderExtension) -> Self {
197 SomeHeaderExtension::TwoByteHeaderExtension(value)
198 }
199}
200
201impl<B: BitBuf> ParselyRead<B> for TwoByteHeaderExtension {
202 type Ctx = ();
203
204 fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
206 let id = buf.get_u8().context("id")?;
207 let data_length_bytes = match id {
208 0 => 0,
209 _ => buf.get_u8().context("length")?,
210 } as usize;
211 if buf.remaining_bytes() < data_length_bytes {
212 bail!(
213 "Header extension length was {data_length_bytes} but buffer only has {} bytes remaining",
214 buf.remaining_bytes()
215 );
216 }
217 let data = Bits::copy_from_bytes(&buf.chunk_bytes()[..data_length_bytes]);
218 buf.advance_bytes(data_length_bytes);
219 Ok(TwoByteHeaderExtension { id, data })
220 }
221}
222
223impl<B: BitBufMut> ParselyWrite<B> for TwoByteHeaderExtension {
224 type Ctx = ();
225
226 fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
227 buf.put_u8(self.id()).context("Writing field 'id'")?;
228 let data_length_bytes = self.data().len();
229 buf.put_u8(data_length_bytes as u8)
230 .context("Writing field 'length'")?;
231 buf.try_put_slice_bytes(self.data())
232 .context("Writing field 'data'")?;
233
234 Ok(())
235 }
236}
237
238impl_stateless_sync!(TwoByteHeaderExtension);
239
240#[derive(Debug, PartialEq)]
241pub enum SomeHeaderExtension {
242 OneByteHeaderExtension(OneByteHeaderExtension),
243 TwoByteHeaderExtension(TwoByteHeaderExtension),
244}
245
246impl SomeHeaderExtension {
247 pub fn id(&self) -> u8 {
248 match self {
249 SomeHeaderExtension::OneByteHeaderExtension(e) => e.id().into(),
250 SomeHeaderExtension::TwoByteHeaderExtension(e) => e.id(),
251 }
252 }
253
254 pub fn data(&self) -> &[u8] {
255 match self {
256 SomeHeaderExtension::OneByteHeaderExtension(e) => e.data(),
257 SomeHeaderExtension::TwoByteHeaderExtension(e) => e.data(),
258 }
259 }
260}
261
262impl<B: BitBufMut> ParselyWrite<B> for SomeHeaderExtension {
263 type Ctx = ();
264
265 fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
266 match self {
267 SomeHeaderExtension::OneByteHeaderExtension(he) => he.write::<T>(buf, ()),
268 SomeHeaderExtension::TwoByteHeaderExtension(he) => he.write::<T>(buf, ()),
269 }
270 }
271}
272
273impl_stateless_sync!(SomeHeaderExtension);
274
275#[derive(Debug, Default, PartialEq)]
276pub struct HeaderExtensions(HashMap<u8, SomeHeaderExtension>);
277
278impl HeaderExtensions {
279 pub fn len(&self) -> usize {
281 self.0.len()
282 }
283
284 pub fn is_empty(&self) -> bool {
286 self.len() == 0
287 }
288
289 pub fn has_one_byte(&self) -> bool {
291 self.0
292 .iter()
293 .any(|(_, he)| matches!(he, SomeHeaderExtension::OneByteHeaderExtension(_)))
294 }
295
296 pub fn has_two_byte(&self) -> bool {
298 self.0
299 .iter()
300 .any(|(_, he)| matches!(he, SomeHeaderExtension::TwoByteHeaderExtension(_)))
301 }
302
303 pub fn add_extension<T: Into<SomeHeaderExtension>>(
305 &mut self,
306 ext: T,
307 ) -> Option<SomeHeaderExtension> {
308 let ext: SomeHeaderExtension = ext.into();
309 self.0.insert(ext.id(), ext)
310 }
311
312 pub fn remove_extension_by_id(&mut self, id: u8) -> Option<SomeHeaderExtension> {
315 self.0.remove(&id)
316 }
317
318 pub fn get_by_id(&self, id: u8) -> Option<&SomeHeaderExtension> {
319 self.0.get(&id)
320 }
321}
322
323impl<'a> IntoIterator for &'a HeaderExtensions {
324 type Item = (&'a u8, &'a SomeHeaderExtension);
325
326 type IntoIter = std::collections::hash_map::Iter<'a, u8, SomeHeaderExtension>;
327
328 fn into_iter(self) -> Self::IntoIter {
329 self.0.iter()
330 }
331}
332
333impl<B: BitBuf> ParselyRead<B> for HeaderExtensions {
340 type Ctx = ();
341
342 fn read<T: ByteOrder>(buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<Self> {
343 let mut header_extensions = HashMap::new();
344
345 let ext_type = buf
346 .get_u16::<NetworkOrder>()
347 .context("Reading header extensions profile")?;
348 let ext_length = buf
349 .get_u16::<NetworkOrder>()
350 .context("Reading header extensions length")?;
351
352 let ext_length_bytes = (ext_length * 4) as usize;
354 let mut extensions_buf = buf.take_bytes(ext_length_bytes);
355
356 while extensions_buf.has_remaining_bytes() {
357 let extension = if OneByteHeaderExtension::type_matches(ext_type) {
358 let id = (&extensions_buf.chunk_bits()[..4]).as_u4();
361 if id == 0xF {
362 let _ = extensions_buf.get_u8();
366 let he = TwoByteHeaderExtension::read::<T>(&mut extensions_buf, ())
367 .context("One-byte header extension")?;
368 SomeHeaderExtension::TwoByteHeaderExtension(he)
369 } else {
370 let he = OneByteHeaderExtension::read::<T>(&mut extensions_buf, ())
371 .context("One-byte header extension")?;
372 SomeHeaderExtension::OneByteHeaderExtension(he)
373 }
374 } else if TwoByteHeaderExtension::type_matches(ext_type) {
375 let he = TwoByteHeaderExtension::read::<T>(&mut extensions_buf, ())
376 .context("One-byte header extension")?;
377 SomeHeaderExtension::TwoByteHeaderExtension(he)
378 } else {
379 bail!("Encountered invalid header extension block type: {ext_type:x}");
380 };
381 if extension.id() != 0 {
382 header_extensions.insert(extension.id(), extension);
383 }
384 }
385
386 Ok(HeaderExtensions(header_extensions))
387 }
388}
389
390impl<B: BitBufMut> ParselyWrite<B> for HeaderExtensions {
391 type Ctx = ();
392
393 fn write<T: ByteOrder>(&self, buf: &mut B, _ctx: Self::Ctx) -> ParselyResult<()> {
394 let len_start = buf.remaining_mut_bytes();
395 self.0
396 .values()
397 .map(|he| he.write::<T>(buf, ()))
398 .collect::<ParselyResult<Vec<_>>>()
399 .context("Writing header extensions")?;
400
401 while (len_start - buf.remaining_mut_bytes()) % 4 != 0 {
402 buf.put_u8(0).context("Padding")?;
403 }
404
405 Ok(())
406 }
407}
408
409impl_stateless_sync!(HeaderExtensions);
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_one_byte_header_extension_parse() {
        // 0x10: ID 1 with length nibble 0 (one data byte); 0xFF is the data.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0x10, 0xFF, 0x00, 0x00
        ]);

        let ext = OneByteHeaderExtension::read::<NetworkOrder>(&mut input, ())
            .expect("one-byte element should parse");
        assert_eq!(ext.id(), 1);
        assert_eq!(ext.data(), &[0xFF]);
    }

    #[test]
    fn test_two_byte_header_extension_parse() {
        // ID 0x01, length 0x01, data 0xFF.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0x01, 0x01, 0xFF, 0x00, 0x00
        ]);

        let ext = TwoByteHeaderExtension::read::<NetworkOrder>(&mut input, ())
            .expect("two-byte element should parse");
        assert_eq!(ext.id(), 1);
        assert_eq!(ext.data(), &[0xFF]);
    }

    #[test]
    fn test_header_extensions_parse_all_one_byte() {
        // Type 0xBEDE, length 2 words:
        // - element 1 (1 byte), then two padding bytes;
        // - element 2 (2 bytes), then one padding byte.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0xBE, 0xDE, 0x00, 0x02,
            0x10, 0xFF, 0x00, 0x00,
            0x21, 0xDE, 0xAD, 0x00
        ]);

        let parsed = HeaderExtensions::read::<NetworkOrder>(&mut input, ())
            .expect("one-byte block should parse");
        assert_eq!(parsed.len(), 2);

        let first = parsed.get_by_id(1).expect("element 1 should be present");
        assert_eq!(first.data(), &[0xFF]);

        let second = parsed.get_by_id(2).expect("element 2 should be present");
        assert_eq!(second.data(), &[0xDE, 0xAD]);
    }

    #[test]
    fn test_header_extensions_parse_all_two_byte() {
        // Type 0x1000, length 3 words:
        // - element 7 (4 bytes);
        // - element 4 (1 byte);
        // - three padding bytes.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0x10, 0x00, 0x00, 0x03,
            0x07, 0x04, 0xDE, 0xAD,
            0xBE, 0xEF, 0x04, 0x01,
            0x42, 0x00, 0x00, 0x00,
        ]);

        let parsed = HeaderExtensions::read::<NetworkOrder>(&mut input, ())
            .expect("two-byte block should parse");
        assert_eq!(parsed.len(), 2);

        let seven = parsed.get_by_id(7).expect("element 7 should be present");
        assert_eq!(seven.data(), &[0xDE, 0xAD, 0xBE, 0xEF]);

        let four = parsed.get_by_id(4).expect("element 4 should be present");
        assert_eq!(four.data(), &[0x42]);
    }

    #[test]
    fn test_header_extensions_parse_mixed() {
        // Type 0xBEDE, length 4 words:
        // - one-byte element 1;
        // - 0xF0 escape, then two-byte element 7 (4 bytes);
        // - 0xF0 escape, then two-byte element 4 (1 byte);
        // - one padding byte.
        #[rustfmt::skip]
        let mut input = Bits::from_static_bytes(&[
            0xBE, 0xDE, 0x00, 0x04,
            0x10, 0xFF, 0x00, 0x00,
            0xF0, 0x07, 0x04, 0xDE,
            0xAD, 0xBE, 0xEF, 0xF0,
            0x04, 0x01, 0x42, 0x00,
        ]);

        let parsed = HeaderExtensions::read::<NetworkOrder>(&mut input, ())
            .expect("mixed block should parse");
        assert_eq!(parsed.len(), 3);

        let one = parsed.get_by_id(1).expect("element 1 should be present");
        assert_eq!(one.data(), &[0xFF]);

        let seven = parsed.get_by_id(7).expect("element 7 should be present");
        assert_eq!(seven.data(), &[0xDE, 0xAD, 0xBE, 0xEF]);

        let four = parsed.get_by_id(4).expect("element 4 should be present");
        assert_eq!(four.data(), &[0x42]);
    }
}