Skip to main content

messagepack_serde/extension/
mod.rs

//! Serde helpers for (de)serializing MessagePack extension types.
2
3pub(crate) mod de;
4pub(crate) mod ser;
5
6use serde::{Serialize, Serializer, de::Visitor};
/// Sentinel newtype-struct name that [`ext_ref::serialize`] passes to
/// `serialize_newtype_struct` when encoding an extension value.
///
/// NOTE(review): presumably the crate's `ser`/`de` modules match on this name to
/// switch into extension handling — confirm there.
pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
8
9#[cfg(feature = "alloc")]
10mod owned;
11#[cfg(feature = "alloc")]
12pub use owned::ext_owned;
13
14mod timestamp;
15pub use timestamp::{timestamp32, timestamp64, timestamp96};
16
17struct Bytes<'a>(pub &'a [u8]);
18impl Serialize for Bytes<'_> {
19    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
20    where
21        S: Serializer,
22    {
23        serializer.serialize_bytes(self.0)
24    }
25}
26
27struct ExtInner<'a> {
28    kind: i8,
29    data: &'a [u8],
30}
31
32impl Serialize for ExtInner<'_> {
33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: Serializer,
36    {
37        use messagepack_core::extension::ExtensionRef;
38        use serde::ser::{self, SerializeSeq};
39        let encoder = ExtensionRef::new(self.kind, self.data);
40        let format = encoder
41            .to_format::<core::convert::Infallible>()
42            .map_err(|_| ser::Error::custom("Invalid data length"))?;
43
44        let mut seq = serializer.serialize_seq(None)?;
45
46        seq.serialize_element(&Bytes(&format.as_slice()))?;
47
48        match format {
49            messagepack_core::Format::FixExt1
50            | messagepack_core::Format::FixExt2
51            | messagepack_core::Format::FixExt4
52            | messagepack_core::Format::FixExt8
53            | messagepack_core::Format::FixExt16 => {}
54
55            messagepack_core::Format::Ext8 => {
56                let len = self.data.len() as u8;
57                seq.serialize_element(&len)?;
58            }
59            messagepack_core::Format::Ext16 => {
60                let len = self.data.len() as u16;
61                seq.serialize_element(&len)?;
62            }
63            messagepack_core::Format::Ext32 => {
64                let len = self.data.len() as u32;
65                seq.serialize_element(&len)?;
66            }
67            _ => return Err(ser::Error::custom("unexpected format")),
68        };
69        seq.serialize_element(&self.kind)?;
70        seq.serialize_element(&Bytes(self.data))?;
71
72        seq.end()
73    }
74}
75
76/// De/Serialize [messagepack_core::extension::ExtensionRef]
77///
78/// ## Example
79///
80/// ```rust
81/// use serde::{Serialize,Deserialize};
82/// use messagepack_core::extension::ExtensionRef;
83///
84/// #[derive(Debug, Serialize, Deserialize, PartialEq)]
85/// #[serde(transparent)]
86/// struct WrapRef<'a>(
87///     #[serde(with = "messagepack_serde::extension::ext_ref", borrow)] ExtensionRef<'a>,
88/// );
89///
90/// # fn main() {
91///
92/// let ext = WrapRef(
93///     ExtensionRef::new(10,&[0,1,2,3,4,5])
94/// );
95/// let mut buf = [0u8; 9];
96/// messagepack_serde::to_slice(&ext, &mut buf).unwrap();
97///
98/// let result = messagepack_serde::from_slice::<WrapRef<'_>>(&buf).unwrap();
99/// assert_eq!(ext,result);
100///
101/// # }
102/// ```
103pub mod ext_ref {
104    use super::*;
105    use serde::de;
106
107    /// Serialize [messagepack_core::extension::ExtensionRef]
108    pub fn serialize<S>(
109        ext: &messagepack_core::extension::ExtensionRef<'_>,
110        serializer: S,
111    ) -> Result<S::Ok, S::Error>
112    where
113        S: serde::Serializer,
114    {
115        serializer.serialize_newtype_struct(
116            EXTENSION_STRUCT_NAME,
117            &ExtInner {
118                kind: ext.r#type,
119                data: ext.data,
120            },
121        )
122    }
123
124    /// Deserialize [messagepack_core::extension::ExtensionRef]
125    pub fn deserialize<'de, D>(
126        deserializer: D,
127    ) -> Result<messagepack_core::extension::ExtensionRef<'de>, D::Error>
128    where
129        D: serde::Deserializer<'de>,
130    {
131        struct ExtensionVisitor;
132
133        impl<'de> Visitor<'de> for ExtensionVisitor {
134            type Value = messagepack_core::extension::ExtensionRef<'de>;
135            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
136                formatter.write_str("expect extension")
137            }
138
139            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
140            where
141                D: de::Deserializer<'de>,
142            {
143                deserializer.deserialize_seq(self)
144            }
145
146            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
147            where
148                A: serde::de::SeqAccess<'de>,
149            {
150                let kind = seq
151                    .next_element::<i8>()?
152                    .ok_or(de::Error::missing_field("extension type missing"))?;
153
154                let data = seq
155                    .next_element::<&[u8]>()?
156                    .ok_or(de::Error::missing_field("extension data missing"))?;
157
158                Ok(messagepack_core::extension::ExtensionRef::new(kind, data))
159            }
160        }
161        deserializer.deserialize_seq(ExtensionVisitor)
162    }
163}
164
165/// De/Serialize [messagepack_core::extension::FixedExtension]
166///
167/// ## Example
168///
169/// ```rust
170/// use serde::{Serialize,Deserialize};
171/// use messagepack_core::extension::FixedExtension;
172///
173/// #[derive(Debug, Serialize, Deserialize, PartialEq)]
174/// #[serde(transparent)]
175/// struct WrapRef(
176///     #[serde(with = "messagepack_serde::extension::ext_fixed")] FixedExtension<16>,
177/// );
178///
179/// # fn main() {
180///
181/// let ext = WrapRef(
182///     FixedExtension::new(10,&[0,1,2,3,4,5]).unwrap()
183/// );
184/// let mut buf = [0u8; 9];
185/// messagepack_serde::to_slice(&ext, &mut buf).unwrap();
186///
187/// let result = messagepack_serde::from_slice::<WrapRef>(&buf).unwrap();
188/// assert_eq!(ext,result);
189///
190/// # }
191/// ```
192pub mod ext_fixed {
193    use super::*;
194    use serde::{Deserialize, de};
195
196    /// Serialize [messagepack_core::extension::FixedExtension]
197    pub fn serialize<const N: usize, S>(
198        ext: &messagepack_core::extension::FixedExtension<N>,
199        serializer: S,
200    ) -> Result<S::Ok, S::Error>
201    where
202        S: serde::Serializer,
203    {
204        super::ext_ref::serialize(&ext.as_ref(), serializer)
205    }
206
207    /// Deserialize [messagepack_core::extension::FixedExtension]
208    pub fn deserialize<'de, const N: usize, D>(
209        deserializer: D,
210    ) -> Result<messagepack_core::extension::FixedExtension<N>, D::Error>
211    where
212        D: serde::Deserializer<'de>,
213    {
214        struct Data<const N: usize> {
215            len: usize,
216            buf: [u8; N],
217        }
218        impl<'de, const N: usize> Deserialize<'de> for Data<N> {
219            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
220            where
221                D: de::Deserializer<'de>,
222            {
223                struct DataVisitor<const N: usize>;
224                impl<'de, const N: usize> Visitor<'de> for DataVisitor<N> {
225                    type Value = Data<N>;
226                    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
227                        formatter.write_str("expect extension")
228                    }
229
230                    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
231                    where
232                        E: de::Error,
233                    {
234                        let len = v.len();
235
236                        if len > N {
237                            return Err(de::Error::invalid_length(len, &self));
238                        }
239
240                        let mut buf = [0; N];
241                        buf[..len].copy_from_slice(v);
242                        Ok(Data { len, buf })
243                    }
244                }
245                deserializer.deserialize_bytes(DataVisitor)
246            }
247        }
248
249        struct ExtensionVisitor<const N: usize>;
250        impl<'de, const N: usize> Visitor<'de> for ExtensionVisitor<N> {
251            type Value = messagepack_core::extension::FixedExtension<N>;
252            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
253                formatter.write_str("expect extension")
254            }
255
256            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
257            where
258                D: de::Deserializer<'de>,
259            {
260                deserializer.deserialize_seq(self)
261            }
262
263            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
264            where
265                A: de::SeqAccess<'de>,
266            {
267                let kind = seq
268                    .next_element::<i8>()?
269                    .ok_or(serde::de::Error::missing_field("extension type missing"))?;
270                let data = seq
271                    .next_element::<Data<N>>()?
272                    .ok_or(de::Error::missing_field("extension data missing"))?;
273
274                let ext = messagepack_core::extension::FixedExtension::new_fixed_with_prefix(
275                    kind, data.len, data.buf,
276                )
277                .map_err(|_| de::Error::invalid_length(data.len, &"length is too long"))?;
278                Ok(ext)
279            }
280        }
281
282        deserializer.deserialize_seq(ExtensionVisitor)
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    use messagepack_core::extension::{ExtensionRef, FixedExtension};
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    /// fixext 4 frame (0xd6) with extension type -1 (0xff), i.e. a timestamp 32
    /// whose four seconds bytes are all zero.
    const TIMESTAMP32: &[u8] = &[0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapRef<'a>(
        #[serde(with = "ext_ref", borrow)] messagepack_core::extension::ExtensionRef<'a>,
    );

    #[rstest]
    fn encode_ext_ref() {
        let mut buf = [0_u8; 3];

        let kind: i8 = 123;

        let ext = WrapRef(ExtensionRef::new(kind, &[0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn decode_ext_ref() {
        let ext = crate::from_slice::<WrapRef<'_>>(TIMESTAMP32).unwrap().0;
        assert_eq!(ext.r#type, -1);
        let seconds = u32::from_be_bytes(ext.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }

    #[rstest]
    fn ext_ref_roundtrip_ext8() {
        // 3 is not a fixext size (1/2/4/8/16), so this goes through the
        // explicit-length Ext8 branch: marker + u8 length + type + 3 data bytes.
        let ext = WrapRef(ExtensionRef::new(7, &[1, 2, 3]));
        let mut buf = [0_u8; 6];
        let length = crate::to_slice(&ext, &mut buf).unwrap();
        assert_eq!(length, 6);

        let decoded = crate::from_slice::<WrapRef<'_>>(&buf).unwrap();
        assert_eq!(decoded.0, ext.0);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapFixed<const N: usize>(
        #[serde(with = "ext_fixed")] messagepack_core::extension::FixedExtension<N>,
    );

    #[rstest]
    fn encode_ext_fixed() {
        let mut buf = [0u8; 3];
        let kind: i8 = 123;

        let ext = WrapFixed(FixedExtension::new_fixed(kind, [0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn decode_ext_fixed_bigger_will_succeed() {
        let ext = crate::from_slice::<WrapFixed<6>>(TIMESTAMP32).unwrap().0;
        assert_eq!(ext.r#type, -1);
        assert_eq!(ext.as_slice(), &TIMESTAMP32[2..])
    }

    #[rstest]
    fn decode_ext_fixed_smaller_will_fail() {
        // Four payload bytes cannot fit in a FixedExtension<3>. Assert the
        // error explicitly rather than relying on any panic via `unwrap`.
        assert!(crate::from_slice::<WrapFixed<3>>(TIMESTAMP32).is_err());
    }
}