1use crate::decode::{self, DecodeBorrowed, NbyteReader};
4use crate::encode;
5use crate::{
6 Encode,
7 formats::Format,
8 io::{IoRead, IoWrite},
9};
10
// Payload-length boundaries between the one-, two- and four-byte length prefixes.
// Each `*_PLUS_ONE` constant is the first length that no longer fits the smaller prefix.

/// Largest length that fits in a `u8`.
const U8_MAX: usize = u8::MAX as usize;
/// First length that needs more than a `u8`.
const U8_MAX_PLUS_ONE: usize = U8_MAX + 1;

/// Largest length that fits in a `u16`.
const U16_MAX: usize = u16::MAX as usize;
/// First length that needs more than a `u16`.
const U16_MAX_PLUS_ONE: usize = U16_MAX + 1;

/// Largest length that fits in a `u32`.
const U32_MAX: usize = u32::MAX as usize;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
22pub struct ExtensionRef<'a> {
23 pub r#type: i8,
25 pub data: &'a [u8],
27}
28
29impl<'a> ExtensionRef<'a> {
30 pub fn new(r#type: i8, data: &'a [u8]) -> Self {
32 Self { r#type, data }
33 }
34
35 pub fn to_format<E>(&self) -> core::result::Result<Format, encode::Error<E>> {
40 let format = match self.data.len() {
41 1 => Format::FixExt1,
42 2 => Format::FixExt2,
43 4 => Format::FixExt4,
44 8 => Format::FixExt8,
45 16 => Format::FixExt16,
46 0..=U8_MAX => Format::Ext8,
47 U8_MAX_PLUS_ONE..=U16_MAX => Format::Ext16,
48 U16_MAX_PLUS_ONE..=U32_MAX => Format::Ext32,
49 _ => return Err(encode::Error::InvalidFormat),
50 };
51 Ok(format)
52 }
53}
54
55impl<'a, W: IoWrite> Encode<W> for ExtensionRef<'a> {
56 fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
57 let data_len = self.data.len();
58 let type_byte = self.r#type.to_be_bytes()[0];
59
60 match data_len {
61 1 => {
62 writer.write(&[Format::FixExt1.as_byte(), type_byte])?;
63 writer.write(self.data)?;
64 Ok(2 + data_len)
65 }
66 2 => {
67 writer.write(&[Format::FixExt2.as_byte(), type_byte])?;
68 writer.write(self.data)?;
69 Ok(2 + data_len)
70 }
71 4 => {
72 writer.write(&[Format::FixExt4.as_byte(), type_byte])?;
73 writer.write(self.data)?;
74 Ok(2 + data_len)
75 }
76 8 => {
77 writer.write(&[Format::FixExt8.as_byte(), type_byte])?;
78 writer.write(self.data)?;
79 Ok(2 + data_len)
80 }
81 16 => {
82 writer.write(&[Format::FixExt16.as_byte(), type_byte])?;
83 writer.write(self.data)?;
84 Ok(2 + data_len)
85 }
86 0..=0xff => {
87 let cast = data_len as u8;
88 writer.write(&[Format::Ext8.as_byte(), cast, type_byte])?;
89 writer.write(self.data)?;
90 Ok(3 + data_len)
91 }
92 0x100..=U16_MAX => {
93 let cast = (data_len as u16).to_be_bytes();
94 writer.write(&[Format::Ext16.as_byte(), cast[0], cast[1], type_byte])?;
95 writer.write(self.data)?;
96 Ok(4 + data_len)
97 }
98 0x10000..=U32_MAX => {
99 let cast = (data_len as u32).to_be_bytes();
100 writer.write(&[
101 Format::Ext32.as_byte(),
102 cast[0],
103 cast[1],
104 cast[2],
105 cast[3],
106 type_byte,
107 ])?;
108 writer.write(self.data)?;
109 Ok(6 + data_len)
110 }
111 _ => Err(encode::Error::InvalidFormat),
112 }
113 }
114}
115
116impl<'de> DecodeBorrowed<'de> for ExtensionRef<'de> {
117 type Value = ExtensionRef<'de>;
118
119 fn decode_borrowed_with_format<R>(
120 format: Format,
121 reader: &mut R,
122 ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
123 where
124 R: IoRead<'de>,
125 {
126 let len = match format {
127 Format::FixExt1 => 1,
128 Format::FixExt2 => 2,
129 Format::FixExt4 => 4,
130 Format::FixExt8 => 8,
131 Format::FixExt16 => 16,
132 Format::Ext8 => NbyteReader::<1>::read(reader)?,
133 Format::Ext16 => NbyteReader::<2>::read(reader)?,
134 Format::Ext32 => NbyteReader::<4>::read(reader)?,
135 _ => return Err(decode::Error::UnexpectedFormat),
136 };
137 let ext_type: [u8; 1] = reader
138 .read_slice(1)
139 .map_err(decode::Error::Io)?
140 .as_bytes()
141 .try_into()
142 .map_err(|_| decode::Error::UnexpectedEof)?;
143 let ext_type = ext_type[0] as i8;
144
145 let data_ref = reader.read_slice(len).map_err(decode::Error::Io)?;
146 let data = match data_ref {
147 crate::io::Reference::Borrowed(b) => b,
148 crate::io::Reference::Copied(_) => return Err(decode::Error::InvalidData),
149 };
150 Ok(ExtensionRef {
151 r#type: ext_type,
152 data,
153 })
154 }
155}
156
157#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
165pub struct FixedExtension<const N: usize> {
166 pub r#type: i8,
168 len: usize,
169 data: [u8; N],
170}
171
172impl<const N: usize> FixedExtension<N> {
173 pub fn new(r#type: i8, data: &[u8]) -> Option<Self> {
178 if data.len() > N {
179 return None;
180 }
181 let mut buf = [0u8; N];
182 buf[..data.len()].copy_from_slice(data);
183 Some(Self {
184 r#type,
185 len: data.len(),
186 data: buf,
187 })
188 }
189
190 pub fn new_fixed(r#type: i8, len: usize, data: [u8; N]) -> Self {
196 Self { r#type, len, data }
197 }
198
199 pub fn as_ref(&self) -> ExtensionRef<'_> {
201 ExtensionRef {
202 r#type: self.r#type,
203 data: &self.data[..self.len],
204 }
205 }
206
207 pub fn len(&self) -> usize {
209 self.len
210 }
211
212 pub fn is_empty(&self) -> bool {
214 self.len == 0
215 }
216
217 pub fn as_slice(&self) -> &[u8] {
219 &self.data[..self.len]
220 }
221
222 pub fn as_mut_slice(&mut self) -> &mut [u8] {
224 &mut self.data[..self.len]
225 }
226}
227
/// Error produced when extension data does not fit into the target capacity.
///
/// The private unit field keeps the type from being constructed outside this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TryFromExtensionRefError(());

impl core::fmt::Display for TryFromExtensionRefError {
    /// Emits a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("extension data exceeds capacity")
    }
}

impl core::error::Error for TryFromExtensionRefError {}
239
240impl<const N: usize> TryFrom<ExtensionRef<'_>> for FixedExtension<N> {
241 type Error = TryFromExtensionRefError;
242
243 fn try_from(value: ExtensionRef<'_>) -> Result<Self, Self::Error> {
244 if value.data.len() > N {
245 return Err(TryFromExtensionRefError(()));
246 }
247 let mut buf = [0u8; N];
248 buf[..value.data.len()].copy_from_slice(value.data);
249 Ok(Self {
250 r#type: value.r#type,
251 len: value.data.len(),
252 data: buf,
253 })
254 }
255}
256
257impl<const N: usize, W: IoWrite> Encode<W> for FixedExtension<N> {
258 fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
259 self.as_ref().encode(writer)
260 }
261}
262
263impl<'de, const N: usize> DecodeBorrowed<'de> for FixedExtension<N> {
264 type Value = FixedExtension<N>;
265
266 fn decode_borrowed_with_format<R>(
267 format: Format,
268 reader: &mut R,
269 ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
270 where
271 R: IoRead<'de>,
272 {
273 let ext = ExtensionRef::decode_borrowed_with_format(format, reader)?;
274 if ext.data.len() > N {
275 return Err(decode::Error::InvalidData);
276 }
277 let mut buf_arr = [0u8; N];
278 buf_arr[..ext.data.len()].copy_from_slice(ext.data);
279 Ok(FixedExtension {
280 r#type: ext.r#type,
281 len: ext.data.len(),
282 data: buf_arr,
283 })
284 }
285}
286
#[cfg(feature = "alloc")]
mod owned {
    use super::*;

    /// An extension value that owns its payload in a heap-allocated vector.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
    pub struct ExtensionOwned {
        /// Signed extension type tag.
        pub r#type: i8,
        /// Owned payload bytes.
        pub data: alloc::vec::Vec<u8>,
    }

    impl ExtensionOwned {
        /// Creates an owned extension from a type tag and a payload vector.
        pub fn new(r#type: i8, data: alloc::vec::Vec<u8>) -> Self {
            ExtensionOwned { r#type, data }
        }

        /// Borrows the value as an [`ExtensionRef`].
        pub fn as_ref(&self) -> ExtensionRef<'_> {
            ExtensionRef::new(self.r#type, self.data.as_slice())
        }
    }

    impl<'a> From<ExtensionRef<'a>> for ExtensionOwned {
        /// Copies the borrowed payload into a new vector.
        fn from(value: ExtensionRef<'a>) -> Self {
            ExtensionOwned::new(value.r#type, value.data.to_vec())
        }
    }

    impl<const N: usize> From<FixedExtension<N>> for ExtensionOwned {
        /// Copies only the valid payload bytes of the fixed-capacity value.
        fn from(value: FixedExtension<N>) -> Self {
            ExtensionOwned::from(value.as_ref())
        }
    }

    impl<W: IoWrite> Encode<W> for ExtensionOwned {
        /// Encodes the value by delegating to the [`ExtensionRef`] encoder.
        fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
            let view = self.as_ref();
            view.encode(writer)
        }
    }

    impl<'de> DecodeBorrowed<'de> for ExtensionOwned {
        type Value = ExtensionOwned;

        /// Decodes a borrowed extension and copies its payload into an owned value.
        fn decode_borrowed_with_format<R>(
            format: Format,
            reader: &mut R,
        ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
        where
            R: crate::io::IoRead<'de>,
        {
            let borrowed = ExtensionRef::decode_borrowed_with_format(format, reader)?;
            Ok(ExtensionOwned::from(borrowed))
        }
    }
}
357
358#[cfg(feature = "alloc")]
359pub use owned::ExtensionOwned;
360
361#[cfg(test)]
362mod tests {
363 use super::*;
364 use crate::decode::Decode;
365 use rstest::rstest;
366
367 #[rstest]
368 #[case(0xd4,123,[0x12])]
369 #[case(0xd5,123,[0x12,0x34])]
370 #[case(0xd6,123,[0x12,0x34,0x56,0x78])]
371 #[case(0xd7,123,[0x12;8])]
372 #[case(0xd8,123,[0x12;16])]
373 fn encode_ext_fixed<D: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: D) {
374 let expected = marker
375 .to_be_bytes()
376 .iter()
377 .chain(ty.to_be_bytes().iter())
378 .chain(data.as_ref())
379 .cloned()
380 .collect::<Vec<_>>();
381
382 let encoder = ExtensionRef::new(ty, data.as_ref());
383
384 let mut buf = vec![];
385 let n = encoder.encode(&mut buf).unwrap();
386
387 assert_eq!(&buf, &expected);
388 assert_eq!(n, expected.len());
389 }
390
391 #[rstest]
392 #[case(0xc7_u8.to_be_bytes(),123,5u8.to_be_bytes(),[0x12;5])]
393 #[case(0xc8_u8.to_be_bytes(),123,65535_u16.to_be_bytes(),[0x34;65535])]
394 #[case(0xc9_u8.to_be_bytes(),123,65536_u32.to_be_bytes(),[0x56;65536])]
395 fn encode_ext_sized<M: AsRef<[u8]>, S: AsRef<[u8]>, D: AsRef<[u8]>>(
396 #[case] marker: M,
397 #[case] ty: i8,
398 #[case] size: S,
399 #[case] data: D,
400 ) {
401 let expected = marker
402 .as_ref()
403 .iter()
404 .chain(size.as_ref())
405 .chain(ty.to_be_bytes().iter())
406 .chain(data.as_ref())
407 .cloned()
408 .collect::<Vec<_>>();
409
410 let encoder = ExtensionRef::new(ty, data.as_ref());
411
412 let mut buf = vec![];
413 let n = encoder.encode(&mut buf).unwrap();
414
415 assert_eq!(&buf, &expected);
416 assert_eq!(n, expected.len());
417 }
418
419 #[rstest]
420 #[case(Format::FixExt1.as_byte(), 5_i8, [0x12])]
421 #[case(Format::FixExt2.as_byte(), -1_i8, [0x34, 0x56])]
422 #[case(Format::FixExt4.as_byte(), 42_i8, [0xde, 0xad, 0xbe, 0xef])]
423 #[case(Format::FixExt8.as_byte(), -7_i8, [0xAA; 8])]
424 #[case(Format::FixExt16.as_byte(), 7_i8, [0x55; 16])]
425 fn decode_ext_fixed<E: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: E) {
426 let buf = core::iter::once(marker)
428 .chain(core::iter::once(ty as u8))
429 .chain(data.as_ref().iter().cloned())
430 .collect::<Vec<u8>>();
431
432 let mut r = crate::io::SliceReader::new(&buf);
433 let ext = ExtensionRef::decode(&mut r).unwrap();
434 assert_eq!(ext.r#type, ty);
435 assert_eq!(ext.data, data.as_ref());
436 assert!(r.rest().is_empty());
437 }
438
439 #[rstest]
440 #[case(Format::Ext8, 42_i8, 5u8.to_be_bytes(), [0x11;5])] #[case(Format::Ext16, -7_i8, 300u16.to_be_bytes(), [0xAA;300])] #[case(Format::Ext32, 7_i8, 70000u32.to_be_bytes(), [0x55;70000])] fn decode_ext_sized<S: AsRef<[u8]>, D: AsRef<[u8]>>(
444 #[case] format: Format,
445 #[case] ty: i8,
446 #[case] size: S,
447 #[case] data: D,
448 ) {
449 let buf = format
451 .as_slice()
452 .iter()
453 .chain(size.as_ref())
454 .chain(ty.to_be_bytes().iter())
455 .chain(data.as_ref())
456 .cloned()
457 .collect::<Vec<_>>();
458
459 let mut r = crate::io::SliceReader::new(&buf);
460 let ext = ExtensionRef::decode(&mut r).unwrap();
461 assert_eq!(ext.r#type, ty);
462 assert_eq!(ext.data, data.as_ref());
463 assert!(r.rest().is_empty());
464 }
465
466 #[rstest]
467 fn fixed_extension_roundtrip() {
468 let data = [1u8, 2, 3, 4];
469 let ext = FixedExtension::<8>::new(5, &data).unwrap();
470 let mut buf = vec![];
471 ext.encode(&mut buf).unwrap();
472 let mut r = crate::io::SliceReader::new(&buf);
473 let decoded = FixedExtension::<8>::decode(&mut r).unwrap();
474 assert_eq!(decoded.r#type, 5);
475 assert_eq!(decoded.as_slice(), &data);
476 assert!(r.rest().is_empty());
477 }
478}