messagepack_core/encode/extension.rs

1use super::{Encode, Error, Result};
2use crate::{
3    formats::{Format, U8_MAX, U16_MAX, U32_MAX},
4    io::IoWrite,
5};
6
/// Encoder for a MessagePack extension value: a signed type tag paired with
/// a borrowed byte payload.
///
/// Fields are declared tag-first, so the derived `PartialOrd`/`Ord` compare
/// the type tag before comparing the payload bytes.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct ExtensionEncoder<'data> {
    /// Extension type tag, emitted as a single byte in the header.
    r#type: i8,
    /// Extension payload, written as-is after the header.
    data: &'data [u8],
}
12
13impl<'data> ExtensionEncoder<'data> {
14    pub fn new(r#type: i8, data: &'data [u8]) -> Self {
15        Self { r#type, data }
16    }
17
18    pub fn to_format<E>(&self) -> Result<Format, E> {
19        let format = match self.data.len() {
20            1 => Format::FixExt1,
21            2 => Format::FixExt2,
22            4 => Format::FixExt4,
23            8 => Format::FixExt8,
24            16 => Format::FixExt16,
25            0..U8_MAX => Format::Ext8,
26            U8_MAX..U16_MAX => Format::Ext16,
27            U16_MAX..U32_MAX => Format::Ext32,
28            _ => return Err(Error::InvalidFormat),
29        };
30        Ok(format)
31    }
32}
33
34impl<W: IoWrite> Encode<W> for ExtensionEncoder<'_> {
35    fn encode(&self, writer: &mut W) -> Result<usize, W::Error> {
36        let data_len = self.data.len();
37        let type_byte = self.r#type.to_be_bytes()[0];
38
39        match data_len {
40            1 => {
41                writer.write_bytes(&[Format::FixExt1.as_byte(), type_byte])?;
42                writer.write_bytes(self.data)?;
43
44                Ok(2 + data_len)
45            }
46            2 => {
47                writer.write_bytes(&[Format::FixExt2.as_byte(), type_byte])?;
48                writer.write_bytes(self.data)?;
49
50                Ok(2 + data_len)
51            }
52            4 => {
53                writer.write_bytes(&[Format::FixExt4.as_byte(), type_byte])?;
54                writer.write_bytes(self.data)?;
55                Ok(2 + data_len)
56            }
57            8 => {
58                writer.write_bytes(&[Format::FixExt8.as_byte(), type_byte])?;
59                writer.write_bytes(self.data)?;
60
61                Ok(2 + data_len)
62            }
63            16 => {
64                writer.write_bytes(&[Format::FixExt16.as_byte(), type_byte])?;
65                writer.write_bytes(self.data)?;
66
67                Ok(2 + data_len)
68            }
69            0..=0xff => {
70                let cast = data_len as u8;
71                writer.write_bytes(&[Format::Ext8.as_byte(), cast, type_byte])?;
72                writer.write_bytes(self.data)?;
73
74                Ok(3 + data_len)
75            }
76            0x100..=U16_MAX => {
77                let cast = (data_len as u16).to_be_bytes();
78                writer.write_bytes(&[Format::Ext16.as_byte(), cast[0], cast[1], type_byte])?;
79                writer.write_bytes(self.data)?;
80
81                Ok(4 + data_len)
82            }
83            0x10000..=U32_MAX => {
84                let cast = (data_len as u32).to_be_bytes();
85                writer.write_bytes(&[
86                    Format::Ext32.as_byte(),
87                    cast[0],
88                    cast[1],
89                    cast[2],
90                    cast[3],
91                    type_byte,
92                ])?;
93                writer.write_bytes(self.data)?;
94
95                Ok(6 + data_len)
96            }
97            _ => Err(Error::InvalidFormat),
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // `fixext N` payloads: marker, type byte, then exactly N data bytes
    // (no length field).
    #[rstest]
    #[case(0xd4, 123, [0x12])]
    #[case(0xd5, 123, [0x12, 0x34])]
    #[case(0xd6, 123, [0x12, 0x34, 0x56, 0x78])]
    #[case(0xd7, 123, [0x12; 8])]
    #[case(0xd8, 123, [0x12; 16])]
    // Negative type tags are written as their two's-complement byte.
    #[case(0xd4, -1, [0x12])]
    #[case(0xd8, -128, [0x12; 16])]
    fn encode_ext_fixed<D: AsRef<[u8]>>(#[case] marker: u8, #[case] ty: i8, #[case] data: D) {
        let expected = marker
            .to_be_bytes()
            .iter()
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .copied()
            .collect::<Vec<_>>();

        let encoder = ExtensionEncoder::new(ty, data.as_ref());

        let mut buf = Vec::new();
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }

    // `ext 8/16/32` payloads: marker, big-endian length, type byte, data.
    #[rstest]
    // An empty payload has no fixext form and falls back to `ext 8`.
    #[case(0xc7_u8.to_be_bytes(), 123, 0u8.to_be_bytes(), [0u8; 0])]
    #[case(0xc7_u8.to_be_bytes(), 123, 3u8.to_be_bytes(), [0x12; 3])]
    #[case(0xc7_u8.to_be_bytes(), 123, 5u8.to_be_bytes(), [0x12; 5])]
    // Upper bound of `ext 8`: 255 bytes still fits a one-byte length.
    #[case(0xc7_u8.to_be_bytes(), 123, 255u8.to_be_bytes(), [0x12; 255])]
    // Lower bound of `ext 16`.
    #[case(0xc8_u8.to_be_bytes(), 123, 256u16.to_be_bytes(), [0x34; 256])]
    // Upper bound of `ext 16`.
    #[case(0xc8_u8.to_be_bytes(), 123, 65535_u16.to_be_bytes(), [0x34; 65535])]
    // Lower bound of `ext 32`.
    #[case(0xc9_u8.to_be_bytes(), 123, 65536_u32.to_be_bytes(), [0x56; 65536])]
    // Negative type tag with a sized header.
    #[case(0xc7_u8.to_be_bytes(), -5, 3u8.to_be_bytes(), [0x12; 3])]
    fn encode_ext_sized<M: AsRef<[u8]>, S: AsRef<[u8]>, D: AsRef<[u8]>>(
        #[case] marker: M,
        #[case] ty: i8,
        #[case] size: S,
        #[case] data: D,
    ) {
        let expected = marker
            .as_ref()
            .iter()
            .chain(size.as_ref())
            .chain(ty.to_be_bytes().iter())
            .chain(data.as_ref())
            .copied()
            .collect::<Vec<_>>();

        let encoder = ExtensionEncoder::new(ty, data.as_ref());

        let mut buf = Vec::new();
        let n = encoder.encode(&mut buf).unwrap();

        assert_eq!(&buf, &expected);
        assert_eq!(n, expected.len());
    }
}