messagepack_serde/value/
extension.rs

1use messagepack_core::{Format, encode::ExtensionEncoder, io::IoWrite};
2use serde::{
3    Deserialize, Serialize, Serializer,
4    de::Visitor,
5    ser::{self, SerializeSeq},
6};
7
8use crate::ser::{CoreError, Error};
9
/// Name handed to `serialize_newtype_struct` when an `ExtensionRef` is serialized.
///
/// NOTE(review): presumably the crate's serializer matches on this name to
/// switch into `ext` encoding — confirm against the serializer implementation.
pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
11
/// Represents `ext` format. This is also available with `no_std` to borrow data.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
pub struct ExtensionRef<'a> {
    /// Extension type identifier.
    pub kind: i8,
    /// Borrowed extension payload.
    pub data: &'a [u8],
}

impl<'a> ExtensionRef<'a> {
    /// Creates an extension value from a type identifier and a borrowed payload.
    ///
    /// This is a `const fn`, so extension values can also be built in `const`
    /// and `static` items.
    pub const fn new(kind: i8, data: &'a [u8]) -> Self {
        Self { kind, data }
    }
}
24
25pub(crate) struct SerializeExt<'a, W> {
26    writer: &'a mut W,
27    length: &'a mut usize,
28}
29
30impl<W> AsMut<Self> for SerializeExt<'_, W> {
31    fn as_mut(&mut self) -> &mut Self {
32        self
33    }
34}
35
36impl<'a, W> SerializeExt<'a, W> {
37    pub fn new(writer: &'a mut W, length: &'a mut usize) -> Self {
38        Self { writer, length }
39    }
40}
41
42impl<W: IoWrite> SerializeExt<'_, W> {
43    fn unexpected(&self) -> Error<W::Error> {
44        ser::Error::custom("unexpected value")
45    }
46}
47
48impl<'a, 'b, W> ser::Serializer for &'a mut SerializeExt<'b, W>
49where
50    'b: 'a,
51    W: IoWrite,
52{
53    type Ok = ();
54
55    type Error = Error<W::Error>;
56
57    type SerializeSeq = SerializeExtSeq<'a, 'b, W>;
58
59    type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
60
61    type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
62
63    type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
64
65    type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
66
67    type SerializeStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
68
69    type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
70
71    fn serialize_bool(self, _: bool) -> Result<Self::Ok, Self::Error> {
72        Err(self.unexpected())
73    }
74
75    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
76        self.serialize_bytes(&v.to_be_bytes())
77    }
78
79    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
80        Err(self.unexpected())
81    }
82
83    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
84        Err(self.unexpected())
85    }
86
87    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
88        Err(self.unexpected())
89    }
90
91    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
92        self.serialize_bytes(&v.to_be_bytes())
93    }
94
95    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
96        self.serialize_bytes(&v.to_be_bytes())
97    }
98
99    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
100        self.serialize_bytes(&v.to_be_bytes())
101    }
102
103    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
104        self.serialize_bytes(&v.to_be_bytes())
105    }
106
107    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
108        Err(self.unexpected())
109    }
110
111    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
112        Err(self.unexpected())
113    }
114
115    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
116        Err(self.unexpected())
117    }
118
119    fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
120        Err(self.unexpected())
121    }
122
123    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
124        self.writer.write_bytes(v).map_err(CoreError::Io)?;
125        *self.length += v.len();
126        Ok(())
127    }
128
129    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
130        Err(self.unexpected())
131    }
132
133    fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
134    where
135        T: ?Sized + Serialize,
136    {
137        Err(self.unexpected())
138    }
139
140    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
141        Err(self.unexpected())
142    }
143
144    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
145        Err(self.unexpected())
146    }
147
148    fn serialize_unit_variant(
149        self,
150        _name: &'static str,
151        _variant_index: u32,
152        _variant: &'static str,
153    ) -> Result<Self::Ok, Self::Error> {
154        Err(self.unexpected())
155    }
156
157    fn serialize_newtype_struct<T>(
158        self,
159        _name: &'static str,
160        value: &T,
161    ) -> Result<Self::Ok, Self::Error>
162    where
163        T: ?Sized + Serialize,
164    {
165        value.serialize(self)
166    }
167
168    fn serialize_newtype_variant<T>(
169        self,
170        _name: &'static str,
171        _variant_index: u32,
172        _variant: &'static str,
173        _value: &T,
174    ) -> Result<Self::Ok, Self::Error>
175    where
176        T: ?Sized + Serialize,
177    {
178        Err(self.unexpected())
179    }
180
181    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
182        Ok(SerializeExtSeq::new(self))
183    }
184
185    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
186        Err(self.unexpected())
187    }
188
189    fn serialize_tuple_struct(
190        self,
191        _name: &'static str,
192        _len: usize,
193    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
194        Err(self.unexpected())
195    }
196
197    fn serialize_tuple_variant(
198        self,
199        _name: &'static str,
200        _variant_index: u32,
201        _variant: &'static str,
202        _len: usize,
203    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
204        Err(self.unexpected())
205    }
206
207    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
208        Err(self.unexpected())
209    }
210
211    fn serialize_struct(
212        self,
213        _name: &'static str,
214        _len: usize,
215    ) -> Result<Self::SerializeStruct, Self::Error> {
216        Err(self.unexpected())
217    }
218
219    fn serialize_struct_variant(
220        self,
221        _name: &'static str,
222        _variant_index: u32,
223        _variant: &'static str,
224        _len: usize,
225    ) -> Result<Self::SerializeStructVariant, Self::Error> {
226        Err(self.unexpected())
227    }
228}
229
230pub struct SerializeExtSeq<'a, 'b, W> {
231    ser: &'a mut SerializeExt<'b, W>,
232}
233
234impl<'a, 'b, W> SerializeExtSeq<'a, 'b, W> {
235    pub(crate) fn new(ser: &'a mut SerializeExt<'b, W>) -> Self {
236        Self { ser }
237    }
238}
239
240impl<'a, 'b, W> ser::SerializeSeq for SerializeExtSeq<'a, 'b, W>
241where
242    'b: 'a,
243    W: IoWrite,
244{
245    type Ok = ();
246    type Error = Error<W::Error>;
247    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
248    where
249        T: ?Sized + Serialize,
250    {
251        value.serialize(self.ser.as_mut())
252    }
253    fn end(self) -> Result<Self::Ok, Self::Error> {
254        Ok(())
255    }
256}
257
258struct ExtInner<'a> {
259    kind: i8,
260    data: &'a [u8],
261}
262
263impl ser::Serialize for ExtInner<'_> {
264    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
265    where
266        S: Serializer,
267    {
268        let encoder = ExtensionEncoder::new(self.kind, self.data);
269        let format = encoder
270            .to_format::<()>()
271            .map_err(|_| ser::Error::custom("Invalid data length"))?;
272
273        let mut seq = serializer.serialize_seq(Some(4))?;
274
275        seq.serialize_element(serde_bytes::Bytes::new(&format.as_slice()))?;
276
277        const EMPTY: &[u8] = &[];
278
279        match format {
280            messagepack_core::Format::FixExt1 => {
281                seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
282            }
283            messagepack_core::Format::FixExt2 => {
284                seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
285            }
286            messagepack_core::Format::FixExt4 => {
287                seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
288            }
289            messagepack_core::Format::FixExt8 => {
290                seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
291            }
292            messagepack_core::Format::FixExt16 => {
293                seq.serialize_element(serde_bytes::Bytes::new(EMPTY))
294            }
295            messagepack_core::Format::Ext8 => {
296                let len = (self.data.len() as u8).to_be_bytes();
297                seq.serialize_element(serde_bytes::Bytes::new(&len))
298            }
299            messagepack_core::Format::Ext16 => {
300                let len = (self.data.len() as u16).to_be_bytes();
301                seq.serialize_element(serde_bytes::Bytes::new(&len))
302            }
303            messagepack_core::Format::Ext32 => {
304                let len = (self.data.len() as u32).to_be_bytes();
305                seq.serialize_element(serde_bytes::Bytes::new(&len))
306            }
307            _ => unreachable!(),
308        }?;
309        seq.serialize_element(serde_bytes::Bytes::new(&self.kind.to_be_bytes()))?;
310        seq.serialize_element(serde_bytes::Bytes::new(self.data))?;
311
312        seq.end()
313    }
314}
315
316impl ser::Serialize for ExtensionRef<'_> {
317    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
318    where
319        S: Serializer,
320    {
321        serializer.serialize_newtype_struct(
322            EXTENSION_STRUCT_NAME,
323            &ExtInner {
324                kind: self.kind,
325                data: self.data,
326            },
327        )
328    }
329}
330
331pub(crate) struct DeserializeExt<'de> {
332    data_len: usize,
333    pub(crate) input: &'de [u8],
334}
335
336impl AsMut<Self> for DeserializeExt<'_> {
337    fn as_mut(&mut self) -> &mut Self {
338        self
339    }
340}
341
342impl<'de> DeserializeExt<'de> {
343    pub(crate) fn new(format: Format, input: &'de [u8]) -> Result<Self, crate::de::Error> {
344        let (data_len, rest) = match format {
345            Format::FixExt1 => (1, input),
346            Format::FixExt2 => (2, input),
347            Format::FixExt4 => (4, input),
348            Format::FixExt8 => (8, input),
349            Format::FixExt16 => (16, input),
350            Format::Ext8 => {
351                let (first, rest) = input
352                    .split_first_chunk::<1>()
353                    .ok_or(messagepack_core::decode::Error::EofData)?;
354                let val = u8::from_be_bytes(*first);
355                (val.into(), rest)
356            }
357            Format::Ext16 => {
358                let (first, rest) = input
359                    .split_first_chunk::<2>()
360                    .ok_or(messagepack_core::decode::Error::EofData)?;
361                let val = u16::from_be_bytes(*first);
362                (val.into(), rest)
363            }
364            Format::Ext32 => {
365                let (first, rest) = input
366                    .split_first_chunk::<4>()
367                    .ok_or(messagepack_core::decode::Error::EofData)?;
368                let val = u32::from_be_bytes(*first);
369                (val as usize, rest)
370            }
371            _ => return Err(messagepack_core::decode::Error::UnexpectedFormat.into()),
372        };
373        Ok(DeserializeExt {
374            data_len,
375            input: rest,
376        })
377    }
378}
379
380impl<'de> serde::Deserializer<'de> for &mut DeserializeExt<'de> {
381    type Error = crate::de::Error;
382
383    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
384    where
385        V: Visitor<'de>,
386    {
387        Err(crate::de::Error::AnyIsUnsupported)
388    }
389
390    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
391    where
392        V: Visitor<'de>,
393    {
394        let (first, rest) = self
395            .input
396            .split_first_chunk::<1>()
397            .ok_or(messagepack_core::decode::Error::EofData)?;
398
399        let val = i8::from_be_bytes(*first);
400        self.input = rest;
401        visitor.visit_i8(val)
402    }
403
404    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
405    where
406        V: Visitor<'de>,
407    {
408        let (data, rest) = self
409            .input
410            .split_at_checked(self.data_len)
411            .ok_or(messagepack_core::decode::Error::EofData)?;
412        self.input = rest;
413        visitor.visit_borrowed_bytes(data)
414    }
415
416    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
417    where
418        V: Visitor<'de>,
419    {
420        visitor.visit_seq(&mut self)
421    }
422
423    fn deserialize_newtype_struct<V>(
424        self,
425        _name: &'static str,
426        visitor: V,
427    ) -> Result<V::Value, Self::Error>
428    where
429        V: Visitor<'de>,
430    {
431        visitor.visit_newtype_struct(self)
432    }
433
434    serde::forward_to_deserialize_any! {
435        bool i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
436        byte_buf option unit unit_struct tuple
437        tuple_struct map struct enum identifier ignored_any
438    }
439}
440
441impl<'de> serde::de::SeqAccess<'de> for &mut DeserializeExt<'de> {
442    type Error = crate::de::Error;
443    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
444    where
445        T: serde::de::DeserializeSeed<'de>,
446    {
447        seed.deserialize(self.as_mut()).map(Some)
448    }
449}
450
451impl<'de> Deserialize<'de> for ExtensionRef<'de> {
452    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
453    where
454        D: serde::Deserializer<'de>,
455    {
456        struct ExtensionVisitor;
457
458        impl<'de> Visitor<'de> for ExtensionVisitor {
459            type Value = ExtensionRef<'de>;
460            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
461                formatter.write_str("expect extension")
462            }
463
464            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
465            where
466                D: serde::Deserializer<'de>,
467            {
468                deserializer.deserialize_seq(self)
469            }
470
471            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
472            where
473                A: serde::de::SeqAccess<'de>,
474            {
475                let kind = seq
476                    .next_element::<i8>()?
477                    .ok_or(serde::de::Error::custom("expect i8"))?;
478
479                let data = seq
480                    .next_element::<&[u8]>()?
481                    .ok_or(serde::de::Error::custom("expect [u8]"))?;
482
483                Ok(ExtensionRef::new(kind, data))
484            }
485        }
486        deserializer.deserialize_any(ExtensionVisitor)
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use messagepack_core::SliceWriter;
    use rstest::rstest;

    /// A one-byte payload is written as three bytes: `0xd4`, the type byte
    /// and the payload byte.
    #[rstest]
    fn encode_ext() {
        const KIND: i8 = 123;

        let mut out = [0_u8; 3];
        let mut written = 0;
        {
            let mut writer = SliceWriter::from_slice(&mut out);
            let mut ser = SerializeExt::new(&mut writer, &mut written);

            ExtensionRef::new(KIND, &[0x12])
                .serialize(&mut ser)
                .unwrap();
        }

        assert_eq!(written, 3);
        assert_eq!(out, [0xd4, KIND.to_be_bytes()[0], 0x12]);
    }

    /// A `0xd6` value with type `-1` decodes into a borrowed four-byte payload.
    #[rstest]
    fn decode_ext() {
        let encoded = [0xd6, 0xff, 0x00, 0x00, 0x00, 0x00]; // timestamp ext type

        let decoded = crate::from_slice::<ExtensionRef>(&encoded).unwrap();
        assert_eq!(decoded.kind, -1);

        let seconds = u32::from_be_bytes(decoded.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }
}