messagepack_core/
extension.rs

1//! MessagePack extension helpers.
2
3use crate::decode::{self, NbyteReader};
4use crate::encode;
5use crate::{Decode, Encode, formats::Format, io::IoWrite};
6
/// Largest payload length that fits the one-byte length field of `Ext8`.
const U8_MAX: usize = 0xff;
/// Largest payload length that fits the two-byte length field of `Ext16`.
const U16_MAX: usize = 0xffff;
/// Largest payload length that fits the four-byte length field of `Ext32`.
const U32_MAX: usize = u32::MAX as usize;
/// Smallest payload length that needs `Ext16` instead of `Ext8`.
const U8_MAX_PLUS_ONE: usize = 0x100;
/// Smallest payload length that needs `Ext32` instead of `Ext16`.
const U16_MAX_PLUS_ONE: usize = 0x1_0000;
12
/// A MessagePack extension value that borrows its payload.
///
/// The value is a pair of an application-defined type code and the raw
/// payload bytes. Which header is written (FixExt vs Ext8/16/32) is not stored
/// here; it is derived from the payload length at encode time, see
/// [`ExtensionRef::to_format`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ExtensionRef<'a> {
    /// Extension type code chosen by the application.
    pub r#type: i8,
    /// Payload bytes, borrowed for `'a`.
    pub data: &'a [u8],
}
24
25impl<'a> ExtensionRef<'a> {
26    /// Create a borrowed reference to extension data with the given type code.
27    pub fn new(r#type: i8, data: &'a [u8]) -> Self {
28        Self { r#type, data }
29    }
30
31    /// Decide the MessagePack format to use given the payload length.
32    ///
33    /// - If `data.len()` is exactly 1, 2, 4, 8 or 16, `FixExtN` is selected.
34    /// - Otherwise, `Ext8`/`Ext16`/`Ext32` is selected based on the byte length.
35    pub fn to_format<E>(&self) -> core::result::Result<Format, encode::Error<E>> {
36        let format = match self.data.len() {
37            1 => Format::FixExt1,
38            2 => Format::FixExt2,
39            4 => Format::FixExt4,
40            8 => Format::FixExt8,
41            16 => Format::FixExt16,
42            0..=U8_MAX => Format::Ext8,
43            U8_MAX_PLUS_ONE..=U16_MAX => Format::Ext16,
44            U16_MAX_PLUS_ONE..=U32_MAX => Format::Ext32,
45            _ => return Err(encode::Error::InvalidFormat),
46        };
47        Ok(format)
48    }
49}
50
51impl<'a, W: IoWrite> Encode<W> for ExtensionRef<'a> {
52    fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
53        let data_len = self.data.len();
54        let type_byte = self.r#type.to_be_bytes()[0];
55
56        match data_len {
57            1 => {
58                writer.write(&[Format::FixExt1.as_byte(), type_byte])?;
59                writer.write(self.data)?;
60                Ok(2 + data_len)
61            }
62            2 => {
63                writer.write(&[Format::FixExt2.as_byte(), type_byte])?;
64                writer.write(self.data)?;
65                Ok(2 + data_len)
66            }
67            4 => {
68                writer.write(&[Format::FixExt4.as_byte(), type_byte])?;
69                writer.write(self.data)?;
70                Ok(2 + data_len)
71            }
72            8 => {
73                writer.write(&[Format::FixExt8.as_byte(), type_byte])?;
74                writer.write(self.data)?;
75                Ok(2 + data_len)
76            }
77            16 => {
78                writer.write(&[Format::FixExt16.as_byte(), type_byte])?;
79                writer.write(self.data)?;
80                Ok(2 + data_len)
81            }
82            0..=0xff => {
83                let cast = data_len as u8;
84                writer.write(&[Format::Ext8.as_byte(), cast, type_byte])?;
85                writer.write(self.data)?;
86                Ok(3 + data_len)
87            }
88            0x100..=U16_MAX => {
89                let cast = (data_len as u16).to_be_bytes();
90                writer.write(&[Format::Ext16.as_byte(), cast[0], cast[1], type_byte])?;
91                writer.write(self.data)?;
92                Ok(4 + data_len)
93            }
94            0x10000..=U32_MAX => {
95                let cast = (data_len as u32).to_be_bytes();
96                writer.write(&[
97                    Format::Ext32.as_byte(),
98                    cast[0],
99                    cast[1],
100                    cast[2],
101                    cast[3],
102                    type_byte,
103                ])?;
104                writer.write(self.data)?;
105                Ok(6 + data_len)
106            }
107            _ => Err(encode::Error::InvalidFormat),
108        }
109    }
110}
111
112impl<'a> Decode<'a> for ExtensionRef<'a> {
113    type Value = ExtensionRef<'a>;
114
115    fn decode(buf: &'a [u8]) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
116        let (format, buf) = Format::decode(buf)?;
117        match format {
118            Format::FixExt1
119            | Format::FixExt2
120            | Format::FixExt4
121            | Format::FixExt8
122            | Format::FixExt16
123            | Format::Ext8
124            | Format::Ext16
125            | Format::Ext32 => Self::decode_with_format(format, buf),
126            _ => Err(decode::Error::UnexpectedFormat),
127        }
128    }
129
130    fn decode_with_format(
131        format: Format,
132        buf: &'a [u8],
133    ) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
134        let (len, buf) = match format {
135            Format::FixExt1 => (1, buf),
136            Format::FixExt2 => (2, buf),
137            Format::FixExt4 => (4, buf),
138            Format::FixExt8 => (8, buf),
139            Format::FixExt16 => (16, buf),
140            Format::Ext8 => NbyteReader::<1>::read(buf)?,
141            Format::Ext16 => NbyteReader::<2>::read(buf)?,
142            Format::Ext32 => NbyteReader::<4>::read(buf)?,
143            _ => return Err(decode::Error::UnexpectedFormat),
144        };
145        let (ext_type, buf) = buf.split_first().ok_or(decode::Error::EofData)?;
146        let (data, rest) = buf.split_at_checked(len).ok_or(decode::Error::EofData)?;
147        let ext = ExtensionRef {
148            r#type: (*ext_type) as i8,
149            data,
150        };
151        Ok((ext, rest))
152    }
153}
154
/// An owned extension value stored inline in an `N`-byte buffer.
///
/// `N` is the capacity of the backing array, not a MessagePack header kind.
/// The header written at encode time is chosen from the current payload
/// length:
/// - `len == 1, 2, 4, 8, 16` → `FixExtN`
/// - otherwise (0..=255, 256..=65535, 65536..=u32::MAX) → `Ext8/16/32`
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct FixedExtension<const N: usize> {
    /// Extension type code chosen by the application.
    pub r#type: i8,
    /// Number of leading bytes of `data` that hold the payload.
    len: usize,
    /// Backing storage; only `data[..len]` is payload.
    data: [u8; N],
}
169
170impl<const N: usize> FixedExtension<N> {
171    /// Construct from a slice whose length must be `<= N`.
172    ///
173    /// The chosen MessagePack format when encoding still follows the rules
174    /// described in the type-level documentation above.
175    pub fn new(r#type: i8, data: &[u8]) -> Option<Self> {
176        if data.len() > N {
177            return None;
178        }
179        let mut buf = [0u8; N];
180        buf[..data.len()].copy_from_slice(data);
181        Some(Self {
182            r#type,
183            len: data.len(),
184            data: buf,
185        })
186    }
187
188    /// Construct with an exact `N`-byte payload.
189    ///
190    /// Note: Even when constructed with a fixed-size buffer, the encoder will
191    /// emit `FixExtN` only if `N` is one of {1, 2, 4, 8, 16}. For any other
192    /// `N`, the encoder uses `Ext8/16/32` as appropriate.
193    pub fn new_fixed(r#type: i8, data: [u8; N]) -> Self {
194        Self {
195            r#type,
196            len: N,
197            data,
198        }
199    }
200
201    /// Borrow as [`ExtensionRef`] for encoding.
202    pub fn as_ref(&self) -> ExtensionRef<'_> {
203        ExtensionRef {
204            r#type: self.r#type,
205            data: &self.data[..self.len],
206        }
207    }
208
209    /// Current payload length in bytes.
210    pub fn len(&self) -> usize {
211        self.len
212    }
213
214    /// Returns `true` if the payload is empty.
215    pub fn is_empty(&self) -> bool {
216        self.len == 0
217    }
218
219    /// Borrow the payload slice.
220    pub fn data(&self) -> &[u8] {
221        &self.data[..self.len]
222    }
223}
224
225impl<const N: usize, W: IoWrite> Encode<W> for FixedExtension<N> {
226    fn encode(&self, writer: &mut W) -> core::result::Result<usize, encode::Error<W::Error>> {
227        self.as_ref().encode(writer)
228    }
229}
230
231impl<'a, const N: usize> Decode<'a> for FixedExtension<N> {
232    type Value = FixedExtension<N>;
233
234    fn decode(buf: &'a [u8]) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
235        let (ext, rest) = ExtensionRef::decode(buf)?;
236        if ext.data.len() > N {
237            return Err(decode::Error::InvalidData);
238        }
239        let mut buf_arr = [0u8; N];
240        buf_arr[..ext.data.len()].copy_from_slice(ext.data);
241        Ok((
242            FixedExtension {
243                r#type: ext.r#type,
244                len: ext.data.len(),
245                data: buf_arr,
246            },
247            rest,
248        ))
249    }
250
251    fn decode_with_format(
252        format: Format,
253        buf: &'a [u8],
254    ) -> core::result::Result<(Self::Value, &'a [u8]), decode::Error> {
255        let (ext, rest) = ExtensionRef::decode_with_format(format, buf)?;
256        if ext.data.len() > N {
257            return Err(decode::Error::InvalidData);
258        }
259        let mut buf_arr = [0u8; N];
260        buf_arr[..ext.data.len()].copy_from_slice(ext.data);
261        Ok((
262            FixedExtension {
263                r#type: ext.r#type,
264                len: ext.data.len(),
265                data: buf_arr,
266            },
267            rest,
268        ))
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Payloads of 1, 2, 4, 8 and 16 bytes are written as `[marker][type][data]`.
    #[rstest]
    #[case(0xd4,123,[0x12])]
    #[case(0xd5,123,[0x12,0x34])]
    #[case(0xd6,123,[0x12,0x34,0x56,0x78])]
    #[case(0xd7,123,[0x12;8])]
    #[case(0xd8,123,[0x12;16])]
    fn encode_ext_fixed<D: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: D) {
        let expected = marker
            .to_be_bytes()
            .iter()
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let encoder = ExtensionRef::new(ty, data.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    /// Every other length is written as `[marker][length][type][data]`; the
    /// cases sit on the lengths where the header kind changes.
    #[rstest]
    #[case(0xc7_u8.to_be_bytes(),123,0u8.to_be_bytes(),[0u8;0])] // empty payload: Ext8
    #[case(0xc7_u8.to_be_bytes(),123,3u8.to_be_bytes(),[0x12;3])] // not a FixExt size: Ext8
    #[case(0xc7_u8.to_be_bytes(),123,5u8.to_be_bytes(),[0x12;5])]
    #[case(0xc7_u8.to_be_bytes(),123,255u8.to_be_bytes(),[0x12;255])] // largest Ext8
    #[case(0xc8_u8.to_be_bytes(),123,256_u16.to_be_bytes(),[0x34;256])] // smallest Ext16
    #[case(0xc8_u8.to_be_bytes(),123,65535_u16.to_be_bytes(),[0x34;65535])]
    #[case(0xc9_u8.to_be_bytes(),123,65536_u32.to_be_bytes(),[0x56;65536])]
    fn encode_ext_sized<M: AsRef<[u8]>, S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] marker: M,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let expected = marker
            .as_ref()
            .iter()
            .chain(size.as_ref())
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let encoder = ExtensionRef::new(ty, data.as_ref());

        let mut buf = vec![];
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    #[rstest]
    #[case(Format::FixExt1.as_byte(),  5_i8, [0x12])]
    #[case(Format::FixExt2.as_byte(), -1_i8, [0x34, 0x56])]
    #[case(Format::FixExt4.as_byte(), 42_i8, [0xde, 0xad, 0xbe, 0xef])]
    #[case(Format::FixExt8.as_byte(), -7_i8, [0xAA; 8])]
    #[case(Format::FixExt16.as_byte(), 7_i8, [0x55; 16])]
    fn decode_ext_fixed<E: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: E) {
        // Buffer: [FixExtN marker][type][data..]
        let buf = core::iter::once(marker)
            .chain(core::iter::once(ty as u8))
            .chain(data.as_ref().iter().cloned())
            .collect::<Vec<u8>>();

        let (ext, rest) = ExtensionRef::decode(&buf).unwrap();
        assert_eq!(ext.r#type, ty);
        assert_eq!(ext.data, data.as_ref());
        assert!(rest.is_empty());
    }

    #[rstest]
    #[case(Format::Ext8, 1_i8, 0u8.to_be_bytes(), [0u8;0])] // empty payload
    #[case(Format::Ext8, 42_i8, 5u8.to_be_bytes(), [0x11;5])] // small: Ext8
    #[case(Format::Ext8, 42_i8, 255u8.to_be_bytes(), [0x11;255])] // largest Ext8
    #[case(Format::Ext16, -7_i8, 256u16.to_be_bytes(), [0xAA;256])] // smallest Ext16
    #[case(Format::Ext16, -7_i8,   300u16.to_be_bytes(), [0xAA;300])] // medium: Ext16 (>255)
    #[case(Format::Ext32, 7_i8, 70000u32.to_be_bytes(), [0x55;70000])] // large: Ext32 (>65535)
    fn decode_ext_sized<S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] format: Format,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        // MessagePack ext variable-length layout: [format][length][type][data]
        let buf = format
            .as_slice()
            .iter()
            .chain(size.as_ref())
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .cloned()
            .collect::<Vec<_>>();

        let (ext, rest) = ExtensionRef::decode(&buf).unwrap();
        assert_eq!(ext.r#type, ty);
        assert_eq!(ext.data, data.as_ref());
        assert!(rest.is_empty());
    }

    /// Input that ends before the type byte or inside the payload is `EofData`.
    #[test]
    fn decode_ext_truncated_is_eof() {
        // FixExt4 promises four payload bytes but only one follows the type.
        let short_payload = [Format::FixExt4.as_byte(), 5, 0x01];
        assert!(matches!(
            ExtensionRef::decode(&short_payload[..]),
            Err(decode::Error::EofData)
        ));

        // The marker is present but the type byte is missing.
        let no_type = [Format::FixExt1.as_byte()];
        assert!(matches!(
            ExtensionRef::decode(&no_type[..]),
            Err(decode::Error::EofData)
        ));
    }

    #[test]
    fn fixed_extension_roundtrip() {
        let data = [1u8, 2, 3, 4];
        let ext = FixedExtension::<8>::new(5, &data).unwrap();
        let mut buf = vec![];
        ext.encode(&mut buf).unwrap();
        let (decoded, rest) = FixedExtension::<8>::decode(&buf).unwrap();
        assert_eq!(decoded.r#type, 5);
        assert_eq!(decoded.data(), &data);
        assert!(rest.is_empty());
    }

    /// Length bookkeeping for an empty payload and for a full buffer.
    #[test]
    fn fixed_extension_accessors() {
        let empty = FixedExtension::<4>::new(1, &[]).unwrap();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.data().is_empty());

        let full = FixedExtension::new_fixed(1, [9u8; 4]);
        assert!(!full.is_empty());
        assert_eq!(full.len(), 4);
        assert_eq!(full.data(), &[9u8; 4]);
    }

    /// A payload larger than the capacity is refused by both `new` and `decode`.
    #[test]
    fn fixed_extension_rejects_oversized_payload() {
        assert!(FixedExtension::<2>::new(5, &[1, 2, 3]).is_none());

        let mut buf: Vec<u8> = vec![];
        ExtensionRef::new(5, &[1, 2, 3, 4])
            .encode(&mut buf)
            .unwrap();
        assert!(matches!(
            FixedExtension::<2>::decode(&buf),
            Err(decode::Error::InvalidData)
        ));
    }
}
385}