messagepack_serde/value/extension.rs

1use messagepack_core::{Format, extension::ExtensionRef as CoreExtensionRef, io::IoWrite};
2use serde::{
3    Serialize, Serializer,
4    de::Visitor,
5    ser::{self, SerializeSeq},
6};
7
8use crate::ser::Error;
9
10pub(crate) const EXTENSION_STRUCT_NAME: &str = "$__MSGPACK_EXTENSION_STRUCT";
11
12pub(crate) struct SerializeExt<'a, W> {
13    writer: &'a mut W,
14    length: usize,
15}
16
17impl<W> AsMut<Self> for SerializeExt<'_, W> {
18    fn as_mut(&mut self) -> &mut Self {
19        self
20    }
21}
22
23impl<'a, W> SerializeExt<'a, W> {
24    pub fn new(writer: &'a mut W) -> Self {
25        Self { writer, length: 0 }
26    }
27
28    pub(crate) fn length(&self) -> usize {
29        self.length
30    }
31}
32
33impl<W: IoWrite> SerializeExt<'_, W> {
34    fn unexpected(&self) -> Error<W::Error> {
35        ser::Error::custom("unexpected value")
36    }
37}
38
39impl<'a, 'b, W> ser::Serializer for &'a mut SerializeExt<'b, W>
40where
41    'b: 'a,
42    W: IoWrite,
43{
44    type Ok = ();
45
46    type Error = Error<W::Error>;
47
48    type SerializeSeq = SerializeExtSeq<'a, 'b, W>;
49
50    type SerializeTuple = serde::ser::Impossible<Self::Ok, Self::Error>;
51
52    type SerializeTupleStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
53
54    type SerializeTupleVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
55
56    type SerializeMap = serde::ser::Impossible<Self::Ok, Self::Error>;
57
58    type SerializeStruct = serde::ser::Impossible<Self::Ok, Self::Error>;
59
60    type SerializeStructVariant = serde::ser::Impossible<Self::Ok, Self::Error>;
61
62    fn serialize_bool(self, _: bool) -> Result<Self::Ok, Self::Error> {
63        Err(self.unexpected())
64    }
65
66    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
67        self.serialize_bytes(&v.to_be_bytes())
68    }
69
70    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
71        Err(self.unexpected())
72    }
73
74    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
75        Err(self.unexpected())
76    }
77
78    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
79        Err(self.unexpected())
80    }
81
82    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
83        self.serialize_bytes(&v.to_be_bytes())
84    }
85
86    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
87        self.serialize_bytes(&v.to_be_bytes())
88    }
89
90    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
91        self.serialize_bytes(&v.to_be_bytes())
92    }
93
94    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
95        self.serialize_bytes(&v.to_be_bytes())
96    }
97
98    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
99        Err(self.unexpected())
100    }
101
102    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
103        Err(self.unexpected())
104    }
105
106    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
107        Err(self.unexpected())
108    }
109
110    fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
111        Err(self.unexpected())
112    }
113
114    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
115        self.writer
116            .write(v)
117            .map_err(messagepack_core::encode::Error::Io)?;
118        self.length += v.len();
119        Ok(())
120    }
121
122    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
123        Err(self.unexpected())
124    }
125
126    fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
127    where
128        T: ?Sized + Serialize,
129    {
130        Err(self.unexpected())
131    }
132
133    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
134        Err(self.unexpected())
135    }
136
137    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
138        Err(self.unexpected())
139    }
140
141    fn serialize_unit_variant(
142        self,
143        _name: &'static str,
144        _variant_index: u32,
145        _variant: &'static str,
146    ) -> Result<Self::Ok, Self::Error> {
147        Err(self.unexpected())
148    }
149
150    fn serialize_newtype_struct<T>(
151        self,
152        _name: &'static str,
153        value: &T,
154    ) -> Result<Self::Ok, Self::Error>
155    where
156        T: ?Sized + Serialize,
157    {
158        value.serialize(self)
159    }
160
161    fn serialize_newtype_variant<T>(
162        self,
163        _name: &'static str,
164        _variant_index: u32,
165        _variant: &'static str,
166        _value: &T,
167    ) -> Result<Self::Ok, Self::Error>
168    where
169        T: ?Sized + Serialize,
170    {
171        Err(self.unexpected())
172    }
173
174    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
175        Ok(SerializeExtSeq::new(self))
176    }
177
178    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
179        Err(self.unexpected())
180    }
181
182    fn serialize_tuple_struct(
183        self,
184        _name: &'static str,
185        _len: usize,
186    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
187        Err(self.unexpected())
188    }
189
190    fn serialize_tuple_variant(
191        self,
192        _name: &'static str,
193        _variant_index: u32,
194        _variant: &'static str,
195        _len: usize,
196    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
197        Err(self.unexpected())
198    }
199
200    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
201        Err(self.unexpected())
202    }
203
204    fn serialize_struct(
205        self,
206        _name: &'static str,
207        _len: usize,
208    ) -> Result<Self::SerializeStruct, Self::Error> {
209        Err(self.unexpected())
210    }
211
212    fn serialize_struct_variant(
213        self,
214        _name: &'static str,
215        _variant_index: u32,
216        _variant: &'static str,
217        _len: usize,
218    ) -> Result<Self::SerializeStructVariant, Self::Error> {
219        Err(self.unexpected())
220    }
221
222    fn collect_str<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
223    where
224        T: ?Sized + core::fmt::Display,
225    {
226        Err(self.unexpected())
227    }
228}
229
230pub struct SerializeExtSeq<'a, 'b, W> {
231    ser: &'a mut SerializeExt<'b, W>,
232}
233
234impl<'a, 'b, W> SerializeExtSeq<'a, 'b, W> {
235    pub(crate) fn new(ser: &'a mut SerializeExt<'b, W>) -> Self {
236        Self { ser }
237    }
238}
239
240impl<'a, 'b, W> ser::SerializeSeq for SerializeExtSeq<'a, 'b, W>
241where
242    'b: 'a,
243    W: IoWrite,
244{
245    type Ok = ();
246    type Error = Error<W::Error>;
247    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
248    where
249        T: ?Sized + Serialize,
250    {
251        value.serialize(self.ser.as_mut())
252    }
253    fn end(self) -> Result<Self::Ok, Self::Error> {
254        Ok(())
255    }
256}
257
258struct Bytes<'a>(pub &'a [u8]);
259impl ser::Serialize for Bytes<'_> {
260    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
261    where
262        S: Serializer,
263    {
264        serializer.serialize_bytes(self.0)
265    }
266}
267
268struct ExtInner<'a> {
269    kind: i8,
270    data: &'a [u8],
271}
272
273impl ser::Serialize for ExtInner<'_> {
274    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
275    where
276        S: Serializer,
277    {
278        let encoder = CoreExtensionRef::new(self.kind, self.data);
279        let format = encoder
280            .to_format::<core::convert::Infallible>()
281            .map_err(|_| ser::Error::custom("Invalid data length"))?;
282
283        let mut seq = serializer.serialize_seq(Some(4))?;
284
285        seq.serialize_element(&Bytes(&format.as_slice()))?;
286
287        match format {
288            messagepack_core::Format::FixExt1
289            | messagepack_core::Format::FixExt2
290            | messagepack_core::Format::FixExt4
291            | messagepack_core::Format::FixExt8
292            | messagepack_core::Format::FixExt16 => {}
293
294            messagepack_core::Format::Ext8 => {
295                let len = (self.data.len() as u8).to_be_bytes();
296                seq.serialize_element(&Bytes(&len))?;
297            }
298            messagepack_core::Format::Ext16 => {
299                let len = (self.data.len() as u16).to_be_bytes();
300                seq.serialize_element(&Bytes(&len))?;
301            }
302            messagepack_core::Format::Ext32 => {
303                let len = (self.data.len() as u32).to_be_bytes();
304                seq.serialize_element(&Bytes(&len))?;
305            }
306            _ => return Err(ser::Error::custom("unexpected format")),
307        };
308        seq.serialize_element(&Bytes(&self.kind.to_be_bytes()))?;
309        seq.serialize_element(&Bytes(self.data))?;
310
311        seq.end()
312    }
313}
314
315pub(crate) struct DeserializeExt<'de> {
316    data_len: usize,
317    pub(crate) input: &'de [u8],
318}
319
320impl AsMut<Self> for DeserializeExt<'_> {
321    fn as_mut(&mut self) -> &mut Self {
322        self
323    }
324}
325
326impl<'de> DeserializeExt<'de> {
327    pub(crate) fn new(format: Format, input: &'de [u8]) -> Result<Self, crate::de::Error> {
328        let (data_len, rest) = match format {
329            Format::FixExt1 => (1, input),
330            Format::FixExt2 => (2, input),
331            Format::FixExt4 => (4, input),
332            Format::FixExt8 => (8, input),
333            Format::FixExt16 => (16, input),
334            Format::Ext8 => {
335                let (first, rest) = input
336                    .split_first_chunk::<1>()
337                    .ok_or(messagepack_core::decode::Error::EofData)?;
338                let val = u8::from_be_bytes(*first);
339                (val.into(), rest)
340            }
341            Format::Ext16 => {
342                let (first, rest) = input
343                    .split_first_chunk::<2>()
344                    .ok_or(messagepack_core::decode::Error::EofData)?;
345                let val = u16::from_be_bytes(*first);
346                (val.into(), rest)
347            }
348            Format::Ext32 => {
349                let (first, rest) = input
350                    .split_first_chunk::<4>()
351                    .ok_or(messagepack_core::decode::Error::EofData)?;
352                let val = u32::from_be_bytes(*first);
353                (val as usize, rest)
354            }
355            _ => return Err(messagepack_core::decode::Error::UnexpectedFormat.into()),
356        };
357        Ok(DeserializeExt {
358            data_len,
359            input: rest,
360        })
361    }
362}
363
364impl<'de> serde::Deserializer<'de> for &mut DeserializeExt<'de> {
365    type Error = crate::de::Error;
366
367    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
368    where
369        V: Visitor<'de>,
370    {
371        Err(serde::de::Error::custom(
372            "any when deserialize extension is not supported",
373        ))
374    }
375
376    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
377    where
378        V: Visitor<'de>,
379    {
380        let (first, rest) = self
381            .input
382            .split_first_chunk::<1>()
383            .ok_or(messagepack_core::decode::Error::EofData)?;
384
385        let val = i8::from_be_bytes(*first);
386        self.input = rest;
387        visitor.visit_i8(val)
388    }
389
390    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
391    where
392        V: Visitor<'de>,
393    {
394        let (data, rest) = self
395            .input
396            .split_at_checked(self.data_len)
397            .ok_or(messagepack_core::decode::Error::EofData)?;
398        self.input = rest;
399        visitor.visit_borrowed_bytes(data)
400    }
401
402    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
403    where
404        V: Visitor<'de>,
405    {
406        visitor.visit_seq(&mut self)
407    }
408
409    fn deserialize_newtype_struct<V>(
410        self,
411        _name: &'static str,
412        visitor: V,
413    ) -> Result<V::Value, Self::Error>
414    where
415        V: Visitor<'de>,
416    {
417        visitor.visit_newtype_struct(self)
418    }
419
420    serde::forward_to_deserialize_any! {
421        bool i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
422        byte_buf option unit unit_struct tuple
423        tuple_struct map struct enum identifier ignored_any
424    }
425}
426
427impl<'de> serde::de::SeqAccess<'de> for &mut DeserializeExt<'de> {
428    type Error = crate::de::Error;
429    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
430    where
431        T: serde::de::DeserializeSeed<'de>,
432    {
433        seed.deserialize(self.as_mut()).map(Some)
434    }
435}
436
437/// De/Serialize [messagepack_core::extension::ExtensionRef]
438///
439/// ## Example
440///
441/// ```rust
442/// use serde::{Serialize,Deserialize};
443/// use messagepack_core::extension::ExtensionRef;
444///
445/// #[derive(Debug, Serialize, Deserialize, PartialEq)]
446/// #[serde(transparent)]
447/// struct WrapRef<'a>(
448///     #[serde(with = "messagepack_serde::ext_ref", borrow)] ExtensionRef<'a>,
449/// );
450///
451/// # fn main() {
452///
453/// let ext = WrapRef(
454///     ExtensionRef::new(10,&[0,1,2,3,4,5])
455/// );
456/// let mut buf = [0u8; 9];
457/// messagepack_serde::to_slice(&ext, &mut buf).unwrap();
458///
459/// let result = messagepack_serde::from_slice::<WrapRef<'_>>(&buf).unwrap();
460/// assert_eq!(ext,result);
461///
462/// # }
463/// ```
464pub mod ext_ref {
465    use super::*;
466
467    /// Serialize [messagepack_core::extension::ExtensionRef]
468    pub fn serialize<S>(
469        ext: &messagepack_core::extension::ExtensionRef<'_>,
470        serializer: S,
471    ) -> Result<S::Ok, S::Error>
472    where
473        S: serde::Serializer,
474    {
475        serializer.serialize_newtype_struct(
476            EXTENSION_STRUCT_NAME,
477            &ExtInner {
478                kind: ext.r#type,
479                data: ext.data,
480            },
481        )
482    }
483
484    /// Deserialize [messagepack_core::extension::ExtensionRef]
485    pub fn deserialize<'de, D>(
486        deserializer: D,
487    ) -> Result<messagepack_core::extension::ExtensionRef<'de>, D::Error>
488    where
489        D: serde::Deserializer<'de>,
490    {
491        struct ExtensionVisitor;
492
493        impl<'de> Visitor<'de> for ExtensionVisitor {
494            type Value = messagepack_core::extension::ExtensionRef<'de>;
495            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
496                formatter.write_str("expect extension")
497            }
498
499            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
500            where
501                D: serde::Deserializer<'de>,
502            {
503                deserializer.deserialize_seq(self)
504            }
505
506            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
507            where
508                A: serde::de::SeqAccess<'de>,
509            {
510                let kind = seq
511                    .next_element::<i8>()?
512                    .ok_or(serde::de::Error::custom("expect i8"))?;
513
514                let data = seq
515                    .next_element::<&[u8]>()?
516                    .ok_or(serde::de::Error::custom("expect [u8]"))?;
517
518                Ok(messagepack_core::extension::ExtensionRef::new(kind, data))
519            }
520        }
521        deserializer.deserialize_seq(ExtensionVisitor)
522    }
523}
524
525/// De/Serialize [messagepack_core::extension::FixedExtension]
526///
527/// ## Example
528///
529/// ```rust
530/// use serde::{Serialize,Deserialize};
531/// use messagepack_core::extension::FixedExtension;
532///
533/// #[derive(Debug, Serialize, Deserialize, PartialEq)]
534/// #[serde(transparent)]
535/// struct WrapRef(
536///     #[serde(with = "messagepack_serde::ext_fixed")] FixedExtension<16>,
537/// );
538///
539/// # fn main() {
540///
541/// let ext = WrapRef(
542///     FixedExtension::new(10,&[0,1,2,3,4,5]).unwrap()
543/// );
544/// let mut buf = [0u8; 9];
545/// messagepack_serde::to_slice(&ext, &mut buf).unwrap();
546///
547/// let result = messagepack_serde::from_slice::<WrapRef>(&buf).unwrap();
548/// assert_eq!(ext,result);
549///
550/// # }
551/// ```
552pub mod ext_fixed {
553    use serde::de;
554
555    /// Serialize [messagepack_core::extension::FixedExtension]
556    pub fn serialize<const N: usize, S>(
557        ext: &messagepack_core::extension::FixedExtension<N>,
558        serializer: S,
559    ) -> Result<S::Ok, S::Error>
560    where
561        S: serde::Serializer,
562    {
563        super::ext_ref::serialize(&ext.as_ref(), serializer)
564    }
565
566    /// Deserialize [messagepack_core::extension::FixedExtension]
567    pub fn deserialize<'de, const N: usize, D>(
568        deserializer: D,
569    ) -> Result<messagepack_core::extension::FixedExtension<N>, D::Error>
570    where
571        D: serde::Deserializer<'de>,
572    {
573        let r = super::ext_ref::deserialize(deserializer)?;
574
575        let ext = messagepack_core::extension::FixedExtension::new(r.r#type, r.data)
576            .ok_or_else(|| de::Error::custom("extension length is too long"))?;
577        Ok(ext)
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use messagepack_core::extension::{ExtensionRef, FixedExtension};
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapRef<'a>(
        #[serde(with = "ext_ref", borrow)] messagepack_core::extension::ExtensionRef<'a>,
    );

    #[rstest]
    fn encode_ext_ref() {
        let mut buf = [0_u8; 3];

        let kind: i8 = 123;

        let ext = WrapRef(ExtensionRef::new(kind, &[0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    #[rstest]
    fn decode_ext_ref() {
        let buf = [0xd6, 0xff, 0x00, 0x00, 0x00, 0x00]; // timestamp ext type

        let ext = crate::from_slice::<WrapRef<'_>>(&buf).unwrap().0;
        assert_eq!(ext.r#type, -1);
        let seconds = u32::from_be_bytes(ext.data.try_into().unwrap());
        assert_eq!(seconds, 0);
    }

    /// 3-byte payload: no FixExt size matches, so ext 8 (0xc7) with a u8 length is used.
    #[rstest]
    fn roundtrip_ext8_ref() {
        let kind: i8 = 5;
        let data = [1_u8, 2, 3];
        let ext = WrapRef(ExtensionRef::new(kind, &data));

        let mut buf = [0_u8; 6];
        let length = crate::to_slice(&ext, &mut buf).unwrap();
        assert_eq!(length, 6);
        assert_eq!(buf, [0xc7, 0x03, kind.to_be_bytes()[0], 1, 2, 3]);

        let decoded = crate::from_slice::<WrapRef<'_>>(&buf).unwrap().0;
        assert_eq!(decoded.r#type, kind);
        assert_eq!(decoded.data, &data[..]);
    }

    /// 256-byte payload: too long for ext 8, so ext 16 (0xc8) with a u16 length is used.
    #[rstest]
    fn roundtrip_ext16_ref() {
        let kind: i8 = -5;
        let data = [0xab_u8; 256];
        let ext = WrapRef(ExtensionRef::new(kind, &data));

        let mut buf = [0_u8; 260];
        let length = crate::to_slice(&ext, &mut buf).unwrap();
        assert_eq!(length, 260);
        assert_eq!(&buf[..4], &[0xc8, 0x01, 0x00, kind.to_be_bytes()[0]]);
        assert_eq!(&buf[4..], &data[..]);

        let decoded = crate::from_slice::<WrapRef<'_>>(&buf).unwrap().0;
        assert_eq!(decoded.r#type, kind);
        assert_eq!(decoded.data, &data[..]);
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct WrapFixed<const N: usize>(
        #[serde(with = "ext_fixed")] messagepack_core::extension::FixedExtension<N>,
    );

    #[rstest]
    fn encode_ext_fixed() {
        let mut buf = [0u8; 3];
        let kind: i8 = 123;

        let ext = WrapFixed(FixedExtension::new_fixed(kind, [0x12]));
        let length = crate::to_slice(&ext, &mut buf).unwrap();

        assert_eq!(length, 3);
        assert_eq!(buf, [0xd4, kind.to_be_bytes()[0], 0x12]);
    }

    const TIMESTAMP32: &[u8] = &[0xd6, 0xff, 0x00, 0x00, 0x00, 0x00];

    #[rstest]
    fn decode_ext_fixed_bigger_will_success() {
        let ext = crate::from_slice::<WrapFixed<6>>(TIMESTAMP32).unwrap().0;
        assert_eq!(ext.r#type, -1);
        assert_eq!(ext.data(), &TIMESTAMP32[2..])
    }

    /// A 4-byte payload cannot fit a 3-byte `FixedExtension`. Assert the error
    /// explicitly rather than relying on `#[should_panic]`, which would also pass on
    /// an unrelated panic.
    #[rstest]
    fn decode_ext_fixed_smaller_will_failed() {
        let result = crate::from_slice::<WrapFixed<3>>(TIMESTAMP32);
        assert!(result.is_err());
    }
}