1use crate::decode::{self, NbyteReader};
4use crate::encode;
5use crate::{Decode, Encode, formats::Format, io::IoWrite};
6
/// `u8::MAX` as a `usize`: the largest payload length that `to_format` maps to `Ext8`.
const U8_MAX: usize = u8::MAX as usize;
/// `u16::MAX` as a `usize`: the largest payload length that `to_format` maps to `Ext16`.
const U16_MAX: usize = u16::MAX as usize;
/// `u32::MAX` as a `usize`: the largest payload length that can be encoded at all.
const U32_MAX: usize = u32::MAX as usize;
/// `2^8`: the first payload length that no longer fits `Ext8`.
const U8_MAX_PLUS_ONE: usize = 1 << u8::BITS;
/// `2^16`: the first payload length that no longer fits `Ext16`.
const U16_MAX_PLUS_ONE: usize = 1 << u16::BITS;
12
/// A MessagePack extension value that borrows its payload.
///
/// Holds the extension type tag and a slice of the payload bytes; nothing is
/// copied. The derived ordering compares `r#type` first and then `data`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct ExtensionRef<'a> {
    /// Extension type tag, written as a single byte right before the payload.
    pub r#type: i8,
    /// Payload bytes that follow the type byte on the wire.
    pub data: &'a [u8],
}
24
25impl<'a> ExtensionRef<'a> {
26 pub fn new(r#type: i8, data: &'a [u8]) -> Self {
28 Self { r#type, data }
29 }
30
31 pub fn to_format<E>(&self) -> core::result::Result<Format, encode::Error<E>> {
36 let format = match self.data.len() {
37 1 => Format::FixExt1,
38 2 => Format::FixExt2,
39 4 => Format::FixExt4,
40 8 => Format::FixExt8,
41 16 => Format::FixExt16,
42 0..=U8_MAX => Format::Ext8,
43 U8_MAX_PLUS_ONE..=U16_MAX => Format::Ext16,
44 U16_MAX_PLUS_ONE..=U32_MAX => Format::Ext32,
45 _ => return Err(encode::Error::InvalidFormat),
46 };
47 Ok(format)
48 }
49}
50
51impl<'a, W: IoWrite> Encode<W> for ExtensionRef<'a> {
52 fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
53 let data_len = self.data.len();
54 let type_byte = self.r#type.to_be_bytes()[0];
55
56 match data_len {
57 1 => {
58 writer.write(&[Format::FixExt1.as_byte(), type_byte])?;
59 writer.write(self.data)?;
60 Ok(2 + data_len)
61 }
62 2 => {
63 writer.write(&[Format::FixExt2.as_byte(), type_byte])?;
64 writer.write(self.data)?;
65 Ok(2 + data_len)
66 }
67 4 => {
68 writer.write(&[Format::FixExt4.as_byte(), type_byte])?;
69 writer.write(self.data)?;
70 Ok(2 + data_len)
71 }
72 8 => {
73 writer.write(&[Format::FixExt8.as_byte(), type_byte])?;
74 writer.write(self.data)?;
75 Ok(2 + data_len)
76 }
77 16 => {
78 writer.write(&[Format::FixExt16.as_byte(), type_byte])?;
79 writer.write(self.data)?;
80 Ok(2 + data_len)
81 }
82 0..=0xff => {
83 let cast = data_len as u8;
84 writer.write(&[Format::Ext8.as_byte(), cast, type_byte])?;
85 writer.write(self.data)?;
86 Ok(3 + data_len)
87 }
88 0x100..=U16_MAX => {
89 let cast = (data_len as u16).to_be_bytes();
90 writer.write(&[Format::Ext16.as_byte(), cast[0], cast[1], type_byte])?;
91 writer.write(self.data)?;
92 Ok(4 + data_len)
93 }
94 0x10000..=U32_MAX => {
95 let cast = (data_len as u32).to_be_bytes();
96 writer.write(&[
97 Format::Ext32.as_byte(),
98 cast[0],
99 cast[1],
100 cast[2],
101 cast[3],
102 type_byte,
103 ])?;
104 writer.write(self.data)?;
105 Ok(6 + data_len)
106 }
107 _ => Err(encode::Error::InvalidFormat),
108 }
109 }
110}
111
112impl<'a> Decode<'a> for ExtensionRef<'a> {
113 type Value = ExtensionRef<'a>;
114
115 fn decode(buf: &'a [u8]) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
116 let (format, buf) = Format::decode(buf)?;
117 match format {
118 Format::FixExt1
119 | Format::FixExt2
120 | Format::FixExt4
121 | Format::FixExt8
122 | Format::FixExt16
123 | Format::Ext8
124 | Format::Ext16
125 | Format::Ext32 => Self::decode_with_format(format, buf),
126 _ => Err(decode::Error::UnexpectedFormat),
127 }
128 }
129
130 fn decode_with_format(
131 format: Format,
132 buf: &'a [u8],
133 ) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
134 let (len, buf) = match format {
135 Format::FixExt1 => (1, buf),
136 Format::FixExt2 => (2, buf),
137 Format::FixExt4 => (4, buf),
138 Format::FixExt8 => (8, buf),
139 Format::FixExt16 => (16, buf),
140 Format::Ext8 => NbyteReader::<1>::read(buf)?,
141 Format::Ext16 => NbyteReader::<2>::read(buf)?,
142 Format::Ext32 => NbyteReader::<4>::read(buf)?,
143 _ => return Err(decode::Error::UnexpectedFormat),
144 };
145 let (ext_type, buf) = buf.split_first().ok_or(decode::Error::EofData)?;
146 let (data, rest) = buf.split_at_checked(len).ok_or(decode::Error::EofData)?;
147 let ext = ExtensionRef {
148 r#type: (*ext_type) as i8,
149 data,
150 };
151 Ok((ext, rest))
152 }
153}
154
/// An owned MessagePack extension value whose payload lives in an inline `N`-byte buffer.
///
/// Only the first `len` bytes of `data` belong to the payload. The constructors and
/// decoders in this module zero the remaining bytes. The comparison traits are
/// derived, so they also look at those trailing bytes and compare `r#type`, then
/// `len`, then the whole buffer.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct FixedExtension<const N: usize> {
    /// Extension type tag.
    pub r#type: i8,
    /// Number of meaningful bytes at the start of `data`; kept `<= N` by the constructors.
    len: usize,
    /// Inline payload storage; bytes at indices `len..N` are padding.
    data: [u8; N],
}
169
170impl<const N: usize> FixedExtension<N> {
171 pub fn new(r#type: i8, data: &[u8]) -> Option<Self> {
176 if data.len() > N {
177 return None;
178 }
179 let mut buf = [0u8; N];
180 buf[..data.len()].copy_from_slice(data);
181 Some(Self {
182 r#type,
183 len: data.len(),
184 data: buf,
185 })
186 }
187
188 pub fn new_fixed(r#type: i8, data: [u8; N]) -> Self {
194 Self {
195 r#type,
196 len: N,
197 data,
198 }
199 }
200
201 pub fn as_ref(&self) -> ExtensionRef<'_> {
203 ExtensionRef {
204 r#type: self.r#type,
205 data: &self.data[..self.len],
206 }
207 }
208
209 pub fn len(&self) -> usize {
211 self.len
212 }
213
214 pub fn is_empty(&self) -> bool {
216 self.len == 0
217 }
218
219 pub fn data(&self) -> &[u8] {
221 &self.data[..self.len]
222 }
223}
224
225impl<const N: usize, W: IoWrite> Encode<W> for FixedExtension<N> {
226 fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
227 self.as_ref().encode(writer)
228 }
229}
230
231impl<'a, const N: usize> Decode<'a> for FixedExtension<N> {
232 type Value = FixedExtension<N>;
233
234 fn decode(buf: &'a [u8]) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
235 let (ext, rest) = ExtensionRef::decode(buf)?;
236 if ext.data.len() > N {
237 return Err(decode::Error::InvalidData);
238 }
239 let mut buf_arr = [0u8; N];
240 buf_arr[..ext.data.len()].copy_from_slice(ext.data);
241 Ok((
242 FixedExtension {
243 r#type: ext.r#type,
244 len: ext.data.len(),
245 data: buf_arr,
246 },
247 rest,
248 ))
249 }
250
251 fn decode_with_format(
252 format: Format,
253 buf: &'a [u8],
254 ) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
255 let (ext, rest) = ExtensionRef::decode_with_format(format, buf)?;
256 if ext.data.len() > N {
257 return Err(decode::Error::InvalidData);
258 }
259 let mut buf_arr = [0u8; N];
260 buf_arr[..ext.data.len()].copy_from_slice(ext.data);
261 Ok((
262 FixedExtension {
263 r#type: ext.r#type,
264 len: ext.data.len(),
265 data: buf_arr,
266 },
267 rest,
268 ))
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case(0xd4, 123, [0x12])]
    #[case(0xd5, 123, [0x12, 0x34])]
    #[case(0xd6, 123, [0x12, 0x34, 0x56, 0x78])]
    #[case(0xd7, 123, [0x12; 8])]
    #[case(0xd8, 123, [0x12; 16])]
    fn encode_ext_fixed<D: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: D) {
        let expected = marker
            .to_be_bytes()
            .iter()
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let encoder = ExtensionRef::new(ty, data.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case(0xc7_u8.to_be_bytes(), 123, 5u8.to_be_bytes(), [0x12; 5])]
    #[case(0xc8_u8.to_be_bytes(), 123, 65535_u16.to_be_bytes(), [0x34; 65535])]
    #[case(0xc9_u8.to_be_bytes(), 123, 65536_u32.to_be_bytes(), [0x56; 65536])]
    fn encode_ext_sized<M: AsRef<[u8]>, S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] marker: M,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let expected = marker
            .as_ref()
            .iter()
            .chain(size.as_ref())
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let encoder = ExtensionRef::new(ty, data.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case(Format::FixExt1.as_byte(), 5_i8, [0x12])]
    #[case(Format::FixExt2.as_byte(), -1_i8, [0x34, 0x56])]
    #[case(Format::FixExt4.as_byte(), 42_i8, [0xde, 0xad, 0xbe, 0xef])]
    #[case(Format::FixExt8.as_byte(), -7_i8, [0xAA; 8])]
    #[case(Format::FixExt16.as_byte(), 7_i8, [0x55; 16])]
    fn decode_ext_fixed<E: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: E) {
        let buf = core::iter::once(marker)
            .chain(core::iter::once(ty as u8))
            .chain(data.as_ref().iter().cloned())
            .collect::<Vec<u8>>();

        let (ext, rest) = ExtensionRef::decode(&buf).unwrap();
        assert_eq!(ext.r#type, ty);
        assert_eq!(ext.data, data.as_ref());
        assert!(rest.is_empty());
    }

    #[rstest]
    #[case(Format::Ext8, 42_i8, 5u8.to_be_bytes(), [0x11; 5])]
    #[case(Format::Ext16, -7_i8, 300u16.to_be_bytes(), [0xAA; 300])]
    #[case(Format::Ext32, 7_i8, 70000u32.to_be_bytes(), [0x55; 70000])]
    fn decode_ext_sized<S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] format: Format,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let buf = format
            .as_slice()
            .iter()
            .chain(size.as_ref())
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let (ext, rest) = ExtensionRef::decode(&buf).unwrap();
        assert_eq!(ext.r#type, ty);
        assert_eq!(ext.data, data.as_ref());
        assert!(rest.is_empty());
    }

    #[rstest]
    fn fixed_extension_roundtrip() {
        let data = [1u8, 2, 3, 4];
        let ext = FixedExtension::<8>::new(5, &data).unwrap();
        let mut buf = vec![];
        ext.encode(&mut buf).unwrap();
        let (decoded, rest) = FixedExtension::<8>::decode(&buf).unwrap();
        assert_eq!(decoded.r#type, 5);
        assert_eq!(decoded.data(), &data);
        assert!(rest.is_empty());
    }

    #[rstest]
    fn empty_payload_roundtrips_as_ext8() {
        // A zero-length payload matches none of the fixext sizes, so it is
        // written as `ext 8` with a zero length byte.
        let ext = ExtensionRef::new(3, &[]);
        let mut buf: Vec<u8> = Vec::new();
        let n = ext.encode(&mut buf).unwrap();
        assert_eq!(buf, [Format::Ext8.as_byte(), 0x00, 0x03]);
        assert_eq!(n, 3);

        let (decoded, rest) = ExtensionRef::decode(&buf).unwrap();
        assert_eq!(decoded.r#type, 3);
        assert!(decoded.data.is_empty());
        assert!(rest.is_empty());
    }

    #[rstest]
    fn decode_ext_truncated_payload_is_eof() {
        // fixext 2 announces two payload bytes, but only one is present.
        let buf = [Format::FixExt2.as_byte(), 0x01, 0x12];
        assert!(matches!(
            ExtensionRef::decode(&buf),
            Err(decode::Error::EofData)
        ));
    }

    #[rstest]
    fn fixed_extension_rejects_oversized_payload() {
        assert!(FixedExtension::<2>::new(1, &[1, 2, 3]).is_none());

        let buf = [Format::FixExt4.as_byte(), 0x01, 0xde, 0xad, 0xbe, 0xef];
        assert!(matches!(
            FixedExtension::<2>::decode(&buf),
            Err(decode::Error::InvalidData)
        ));
    }
}