messagepack_core/
extension.rs

1//! MessagePack extension helpers.
2
3use crate::decode::{self, DecodeBorrowed, NbyteReader};
4use crate::encode;
5use crate::{
6    Encode,
7    formats::Format,
8    io::{IoRead, IoWrite},
9};
10
// Payload-length boundaries between the `Ext8`, `Ext16` and `Ext32` headers.
// They are `usize` so they can be used directly as range-pattern bounds when
// matching on a slice length.

/// Largest payload length that fits the one-byte length field of `Ext8`.
const U8_MAX: usize = u8::MAX as usize;
/// Smallest payload length that needs the two-byte length field of `Ext16`.
const U8_MAX_PLUS_ONE: usize = 1 << u8::BITS;
/// Largest payload length that fits the two-byte length field of `Ext16`.
const U16_MAX: usize = u16::MAX as usize;
/// Smallest payload length that needs the four-byte length field of `Ext32`.
const U16_MAX_PLUS_ONE: usize = 1 << u16::BITS;
/// Largest payload length that fits the four-byte length field of `Ext32`.
const U32_MAX: usize = u32::MAX as usize;
16
/// A zero-copy MessagePack extension value: a type code plus a borrowed payload.
///
/// The wire header is not stored. It is derived from the payload length at
/// encode time (`FixExtN` for exactly 1, 2, 4, 8 or 16 bytes, `Ext8/16/32`
/// otherwise). See [`ExtensionRef::to_format`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ExtensionRef<'a> {
    /// Extension type code chosen by the application.
    pub r#type: i8,
    /// Payload bytes, borrowed for `'a`.
    pub data: &'a [u8],
}
28
29impl<'a> ExtensionRef<'a> {
30    /// Create a borrowed reference to extension data with the given type code.
31    pub fn new(r#type: i8, data: &'a [u8]) -> Self {
32        Self { r#type, data }
33    }
34
35    /// Decide the MessagePack format to use given the payload length.
36    ///
37    /// - If `data.len()` is exactly 1, 2, 4, 8 or 16, `FixExtN` is selected.
38    /// - Otherwise, `Ext8`/`Ext16`/`Ext32` is selected based on the byte length.
39    pub fn to_format<E>(&self) -> core::result::Result<Format, encode::Error<E>> {
40        let format = match self.data.len() {
41            1 => Format::FixExt1,
42            2 => Format::FixExt2,
43            4 => Format::FixExt4,
44            8 => Format::FixExt8,
45            16 => Format::FixExt16,
46            0..=U8_MAX => Format::Ext8,
47            U8_MAX_PLUS_ONE..=U16_MAX => Format::Ext16,
48            U16_MAX_PLUS_ONE..=U32_MAX => Format::Ext32,
49            _ => return Err(encode::Error::InvalidFormat),
50        };
51        Ok(format)
52    }
53}
54
55impl<'a, W: IoWrite> Encode<W> for ExtensionRef<'a> {
56    fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
57        let data_len = self.data.len();
58        let type_byte = self.r#type.to_be_bytes()[0];
59
60        match data_len {
61            1 => {
62                writer.write(&[Format::FixExt1.as_byte(), type_byte])?;
63                writer.write(self.data)?;
64                Ok(2 + data_len)
65            }
66            2 => {
67                writer.write(&[Format::FixExt2.as_byte(), type_byte])?;
68                writer.write(self.data)?;
69                Ok(2 + data_len)
70            }
71            4 => {
72                writer.write(&[Format::FixExt4.as_byte(), type_byte])?;
73                writer.write(self.data)?;
74                Ok(2 + data_len)
75            }
76            8 => {
77                writer.write(&[Format::FixExt8.as_byte(), type_byte])?;
78                writer.write(self.data)?;
79                Ok(2 + data_len)
80            }
81            16 => {
82                writer.write(&[Format::FixExt16.as_byte(), type_byte])?;
83                writer.write(self.data)?;
84                Ok(2 + data_len)
85            }
86            0..=0xff => {
87                let cast = data_len as u8;
88                writer.write(&[Format::Ext8.as_byte(), cast, type_byte])?;
89                writer.write(self.data)?;
90                Ok(3 + data_len)
91            }
92            0x100..=U16_MAX => {
93                let cast = (data_len as u16).to_be_bytes();
94                writer.write(&[Format::Ext16.as_byte(), cast[0], cast[1], type_byte])?;
95                writer.write(self.data)?;
96                Ok(4 + data_len)
97            }
98            0x10000..=U32_MAX => {
99                let cast = (data_len as u32).to_be_bytes();
100                writer.write(&[
101                    Format::Ext32.as_byte(),
102                    cast[0],
103                    cast[1],
104                    cast[2],
105                    cast[3],
106                    type_byte,
107                ])?;
108                writer.write(self.data)?;
109                Ok(6 + data_len)
110            }
111            _ => Err(encode::Error::InvalidFormat),
112        }
113    }
114}
115
116impl<'de> DecodeBorrowed<'de> for ExtensionRef<'de> {
117    type Value = ExtensionRef<'de>;
118
119    fn decode_borrowed_with_format<R>(
120        format: Format,
121        reader: &mut R,
122    ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
123    where
124        R: IoRead<'de>,
125    {
126        let len = match format {
127            Format::FixExt1 => 1,
128            Format::FixExt2 => 2,
129            Format::FixExt4 => 4,
130            Format::FixExt8 => 8,
131            Format::FixExt16 => 16,
132            Format::Ext8 => NbyteReader::<1>::read(reader)?,
133            Format::Ext16 => NbyteReader::<2>::read(reader)?,
134            Format::Ext32 => NbyteReader::<4>::read(reader)?,
135            _ => return Err(decode::Error::UnexpectedFormat),
136        };
137        let ext_type: [u8; 1] = reader
138            .read_slice(1)
139            .map_err(decode::Error::Io)?
140            .as_bytes()
141            .try_into()
142            .map_err(|_| decode::Error::UnexpectedEof)?;
143        let ext_type = ext_type[0] as i8;
144
145        let data_ref = reader.read_slice(len).map_err(decode::Error::Io)?;
146        let data = match data_ref {
147            crate::io::Reference::Borrowed(b) => b,
148            crate::io::Reference::Copied(_) => return Err(decode::Error::InvalidData),
149        };
150        Ok(ExtensionRef {
151            r#type: ext_type,
152            data,
153        })
154    }
155}
156
/// An extension value stored inline in a buffer of `N` bytes.
///
/// `N` is only the capacity of the backing buffer; it says nothing about the
/// MessagePack header. The header emitted at encode time is chosen from the
/// current payload length:
/// - `len == 1, 2, 4, 8, 16` → `FixExtN`
/// - otherwise (0..=255, 256..=65535, 65536..=u32::MAX) → `Ext8/16/32`
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct FixedExtension<const N: usize> {
    /// Extension type code chosen by the application.
    pub r#type: i8,
    /// Number of leading bytes of `data` that make up the payload.
    len: usize,
    /// Backing storage; only `data[..len]` is exposed as the payload.
    data: [u8; N],
}
171
172impl<const N: usize> FixedExtension<N> {
173    /// Construct from a slice whose length must be `<= N`.
174    ///
175    /// The chosen MessagePack format when encoding still follows the rules
176    /// described in the type-level documentation above.
177    pub fn new(r#type: i8, data: &[u8]) -> Option<Self> {
178        if data.len() > N {
179            return None;
180        }
181        let mut buf = [0u8; N];
182        buf[..data.len()].copy_from_slice(data);
183        Some(Self {
184            r#type,
185            len: data.len(),
186            data: buf,
187        })
188    }
189
190    /// Construct with an exact `N`-byte payload.
191    ///
192    /// Note: Even when constructed with a fixed-size buffer, the encoder will
193    /// emit `FixExtN` only if `N` is one of {1, 2, 4, 8, 16}. For any other
194    /// `N`, the encoder uses `Ext8/16/32` as appropriate.
195    pub fn new_fixed(r#type: i8, len: usize, data: [u8; N]) -> Self {
196        Self { r#type, len, data }
197    }
198
199    /// Borrow as [`ExtensionRef`] for encoding.
200    pub fn as_ref(&self) -> ExtensionRef<'_> {
201        ExtensionRef {
202            r#type: self.r#type,
203            data: &self.data[..self.len],
204        }
205    }
206
207    /// Current payload length in bytes.
208    pub fn len(&self) -> usize {
209        self.len
210    }
211
212    /// Returns `true` if the payload is empty.
213    pub fn is_empty(&self) -> bool {
214        self.len == 0
215    }
216
217    /// Extract a slice
218    pub fn as_slice(&self) -> &[u8] {
219        &self.data[..self.len]
220    }
221
222    /// Extract a mutable slice
223    pub fn as_mut_slice(&mut self) -> &mut [u8] {
224        &mut self.data[..self.len]
225    }
226}
227
/// Error returned when converting an [`ExtensionRef`] into a
/// [`FixedExtension`] fails because the payload is longer than the target
/// capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TryFromExtensionRefError(());

impl core::fmt::Display for TryFromExtensionRefError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("extension data exceeds capacity")
    }
}

impl core::error::Error for TryFromExtensionRefError {}
239
240impl<const N: usize> TryFrom<ExtensionRef<'_>> for FixedExtension<N> {
241    type Error = TryFromExtensionRefError;
242
243    fn try_from(value: ExtensionRef<'_>) -> Result<Self, Self::Error> {
244        if value.data.len() > N {
245            return Err(TryFromExtensionRefError(()));
246        }
247        let mut buf = [0u8; N];
248        buf[..value.data.len()].copy_from_slice(value.data);
249        Ok(Self {
250            r#type: value.r#type,
251            len: value.data.len(),
252            data: buf,
253        })
254    }
255}
256
257impl<const N: usize, W: IoWrite> Encode<W> for FixedExtension<N> {
258    fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
259        self.as_ref().encode(writer)
260    }
261}
262
263impl<'de, const N: usize> DecodeBorrowed<'de> for FixedExtension<N> {
264    type Value = FixedExtension<N>;
265
266    fn decode_borrowed_with_format<R>(
267        format: Format,
268        reader: &mut R,
269    ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
270    where
271        R: IoRead<'de>,
272    {
273        let ext = ExtensionRef::decode_borrowed_with_format(format, reader)?;
274        if ext.data.len() > N {
275            return Err(decode::Error::InvalidData);
276        }
277        let mut buf_arr = [0u8; N];
278        buf_arr[..ext.data.len()].copy_from_slice(ext.data);
279        Ok(FixedExtension {
280            r#type: ext.r#type,
281            len: ext.data.len(),
282            data: buf_arr,
283        })
284    }
285}
286
#[cfg(feature = "alloc")]
mod owned {
    use super::*;

    /// An extension value that owns its payload on the heap.
    #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    pub struct ExtensionOwned {
        /// Extension type code chosen by the application.
        pub r#type: i8,
        /// Owned payload bytes.
        pub data: alloc::vec::Vec<u8>,
    }

    impl ExtensionOwned {
        /// Build an owned extension from a type code and a payload vector.
        pub fn new(r#type: i8, data: alloc::vec::Vec<u8>) -> Self {
            ExtensionOwned { r#type, data }
        }

        /// Borrow as [`ExtensionRef`] for encoding.
        pub fn as_ref(&self) -> ExtensionRef<'_> {
            ExtensionRef::new(self.r#type, self.data.as_slice())
        }
    }

    /// Copies the borrowed payload into a new vector.
    impl<'a> From<ExtensionRef<'a>> for ExtensionOwned {
        fn from(value: ExtensionRef<'a>) -> Self {
            let ExtensionRef { r#type, data } = value;
            Self::new(r#type, data.to_vec())
        }
    }

    /// Copies only the payload bytes (not the unused buffer tail) into a new
    /// vector.
    impl<const N: usize> From<FixedExtension<N>> for ExtensionOwned {
        fn from(value: FixedExtension<N>) -> Self {
            Self::new(value.r#type, value.as_slice().to_vec())
        }
    }

    impl<W: IoWrite> Encode<W> for ExtensionOwned {
        /// Encode exactly as the equivalent [`ExtensionRef`] would.
        fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
            let borrowed = self.as_ref();
            borrowed.encode(writer)
        }
    }

    impl<'de> DecodeBorrowed<'de> for ExtensionOwned {
        type Value = ExtensionOwned;

        /// Decode a borrowed [`ExtensionRef`] and copy its payload into an
        /// owned vector; decoding errors are propagated unchanged.
        fn decode_borrowed_with_format<R>(
            format: Format,
            reader: &mut R,
        ) -> core::result::Result<Self::Value, decode::Error<R::Error>>
        where
            R: crate::io::IoRead<'de>,
        {
            let borrowed = ExtensionRef::decode_borrowed_with_format(format, reader)?;
            Ok(ExtensionOwned::from(borrowed))
        }
    }
}
357
358#[cfg(feature = "alloc")]
359pub use owned::ExtensionOwned;
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::Decode;
    use rstest::rstest;

    /// Payloads of exactly 1/2/4/8/16 bytes are written as
    /// `[marker][type][data..]` with the matching FixExt marker.
    #[rstest]
    #[case(0xd4, 123, [0x12])]
    #[case(0xd5, 123, [0x12, 0x34])]
    #[case(0xd6, 123, [0x12, 0x34, 0x56, 0x78])]
    #[case(0xd7, 123, [0x12; 8])]
    #[case(0xd8, 123, [0x12; 16])]
    fn encode_ext_fixed<D: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: D) {
        let payload = data.as_ref();
        let mut expected = vec![marker, ty as u8];
        expected.extend_from_slice(payload);

        let mut written: Vec<u8> = Vec::new();
        let count = ExtensionRef::new(ty, payload)
            .encode(&mut written)
            .unwrap();

        assert_eq!(written, expected);
        assert_eq!(count, expected.len());
    }

    /// Other payload lengths are written as
    /// `[marker][big-endian length][type][data..]`.
    #[rstest]
    #[case(0xc7_u8.to_be_bytes(), 123, 5u8.to_be_bytes(), [0x12; 5])]
    #[case(0xc8_u8.to_be_bytes(), 123, 65535_u16.to_be_bytes(), [0x34; 65535])]
    #[case(0xc9_u8.to_be_bytes(), 123, 65536_u32.to_be_bytes(), [0x56; 65536])]
    fn encode_ext_sized<M: AsRef<[u8]>, S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] marker: M,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let payload = data.as_ref();
        let mut expected: Vec<u8> = Vec::new();
        expected.extend_from_slice(marker.as_ref());
        expected.extend_from_slice(size.as_ref());
        expected.push(ty as u8);
        expected.extend_from_slice(payload);

        let mut written: Vec<u8> = Vec::new();
        let count = ExtensionRef::new(ty, payload)
            .encode(&mut written)
            .unwrap();

        assert_eq!(written, expected);
        assert_eq!(count, expected.len());
    }

    /// FixExt input `[marker][type][data..]` decodes to the same type code
    /// and payload and consumes the whole buffer.
    #[rstest]
    #[case(Format::FixExt1.as_byte(), 5_i8, [0x12])]
    #[case(Format::FixExt2.as_byte(), -1_i8, [0x34, 0x56])]
    #[case(Format::FixExt4.as_byte(), 42_i8, [0xde, 0xad, 0xbe, 0xef])]
    #[case(Format::FixExt8.as_byte(), -7_i8, [0xAA; 8])]
    #[case(Format::FixExt16.as_byte(), 7_i8, [0x55; 16])]
    fn decode_ext_fixed<E: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: E) {
        let payload = data.as_ref();
        let mut input = vec![marker, ty as u8];
        input.extend_from_slice(payload);

        let mut reader = crate::io::SliceReader::new(&input);
        let decoded = ExtensionRef::decode(&mut reader).unwrap();

        assert_eq!(decoded.r#type, ty);
        assert_eq!(decoded.data, payload);
        assert!(reader.rest().is_empty());
    }

    /// Variable-length input `[format][length][type][data..]` decodes to the
    /// same type code and payload and consumes the whole buffer.
    #[rstest]
    #[case(Format::Ext8, 42_i8, 5u8.to_be_bytes(), [0x11; 5])] // small: Ext8
    #[case(Format::Ext16, -7_i8, 300u16.to_be_bytes(), [0xAA; 300])] // medium: Ext16 (>255)
    #[case(Format::Ext32, 7_i8, 70000u32.to_be_bytes(), [0x55; 70000])] // large: Ext32 (>65535)
    fn decode_ext_sized<S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] format: Format,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let payload = data.as_ref();
        let mut input: Vec<u8> = format.as_slice().iter().copied().collect();
        input.extend_from_slice(size.as_ref());
        input.push(ty as u8);
        input.extend_from_slice(payload);

        let mut reader = crate::io::SliceReader::new(&input);
        let decoded = ExtensionRef::decode(&mut reader).unwrap();

        assert_eq!(decoded.r#type, ty);
        assert_eq!(decoded.data, payload);
        assert!(reader.rest().is_empty());
    }

    /// A `FixedExtension` survives an encode/decode round trip unchanged.
    #[rstest]
    fn fixed_extension_roundtrip() {
        let payload = [1u8, 2, 3, 4];
        let original = FixedExtension::<8>::new(5, &payload).unwrap();

        let mut encoded: Vec<u8> = Vec::new();
        original.encode(&mut encoded).unwrap();

        let mut reader = crate::io::SliceReader::new(&encoded);
        let decoded = FixedExtension::<8>::decode(&mut reader).unwrap();

        assert_eq!(decoded.r#type, 5);
        assert_eq!(decoded.as_slice(), &payload);
        assert!(reader.rest().is_empty());
    }
}
478}