minicbor_serde/
de.rs

1use serde::de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
2
3use minicbor::data::Type;
4use minicbor::decode::{Decoder, Error};
5
6use crate::error::DecodeError;
7
/// Initial byte of the CBOR "break" stop code (major type 7, additional
/// information 31). It marks the end of an indefinite-length array or map.
const BREAK: u8 = 0xFF;
9
10/// Deserialise a type implementing [`serde::Deserialize`] from the given byte slice.
11pub fn from_slice<'de, T: de::Deserialize<'de>>(b: &'de [u8]) -> Result<T, DecodeError> {
12    T::deserialize(&mut Deserializer::new(b))
13}
14
15/// An implementation of [`serde::Deserializer`] using a [`minicbor::Decoder`].
16#[derive(Debug, Clone)]
17pub struct Deserializer<'de> {
18    decoder: Decoder<'de>
19}
20
21impl<'de> Deserializer<'de> {
22    pub fn new(b: &'de [u8]) -> Self {
23        Self::from(Decoder::new(b))
24    }
25
26    pub fn decoder(&self) -> &Decoder<'de> {
27        &self.decoder
28    }
29
30    pub fn decoder_mut(&mut self) -> &mut Decoder<'de> {
31        &mut self.decoder
32    }
33
34    pub fn into_decoder(self) -> Decoder<'de> {
35        self.decoder
36    }
37
38    // Cf. `Decoder::current`
39    fn current(&self) -> Result<u8, Error> {
40        if let Some(b) = self.decoder.input().get(self.decoder.position()) {
41            return Ok(*b)
42        }
43        Err(Error::end_of_input())
44    }
45
46    // Cf. `Decoder::read`
47    fn read(&mut self) -> Result<u8, Error> {
48        let p = self.decoder.position();
49        if let Some(b) = self.decoder.input().get(p) {
50            self.decoder.set_position(p + 1);
51            return Ok(*b)
52        }
53        Err(Error::end_of_input())
54    }
55}
56
57impl<'de> From<Decoder<'de>> for Deserializer<'de> {
58    fn from(d: Decoder<'de>) -> Self {
59        Self { decoder: d }
60    }
61}
62
63impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
64    type Error = DecodeError;
65
66    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
67        match self.decoder.datatype()? {
68            Type::Bool       => self.deserialize_bool(visitor),
69            Type::U8         => self.deserialize_u8(visitor),
70            Type::U16        => self.deserialize_u16(visitor),
71            Type::U32        => self.deserialize_u32(visitor),
72            Type::U64        => self.deserialize_u64(visitor),
73            Type::I8         => self.deserialize_i8(visitor),
74            Type::I16        => self.deserialize_i16(visitor),
75            Type::I32        => self.deserialize_i32(visitor),
76            Type::I64        => self.deserialize_i64(visitor),
77            Type::F32        => self.deserialize_f32(visitor),
78            Type::F64        => self.deserialize_f64(visitor),
79            Type::Bytes      => visitor.visit_borrowed_bytes(self.decoder.bytes()?),
80            Type::String     => visitor.visit_borrowed_str(self.decoder.str()?),
81            Type::Null       => { self.decoder.skip()?; visitor.visit_none() }
82            Type::Array |
83            Type::ArrayIndef => self.deserialize_seq(visitor),
84            Type::Map |
85            Type::MapIndef   => self.deserialize_map(visitor),
86
87            #[cfg(feature = "half")]
88            Type::F16  => visitor.visit_f32(self.decoder.f16()?),
89
90            #[cfg(not(feature = "half"))]
91            Type::F16  => Err(Error::type_mismatch(Type::F16)
92                .with_message("unexpected type")
93                .at(self.decoder.position())
94                .into()),
95
96            #[cfg(feature = "alloc")]
97            Type::BytesIndef => {
98                let mut buf = alloc::vec::Vec::new();
99                for b in self.decoder.bytes_iter()? {
100                    buf.extend_from_slice(b?)
101                }
102                visitor.visit_byte_buf(buf)
103            }
104
105            #[cfg(feature = "alloc")]
106            Type::StringIndef => {
107                let mut buf = alloc::string::String::new();
108                for b in self.decoder.str_iter()? {
109                    buf += b?
110                }
111                visitor.visit_string(buf)
112            }
113
114            #[cfg(not(feature = "alloc"))]
115            t @ (Type::BytesIndef | Type::StringIndef) =>
116                Err(Error::type_mismatch(t).with_message("unexpected type").at(self.decoder.position()).into()),
117
118            t @ (
119                | Type::Undefined
120                | Type::Tag
121                | Type::Int
122                | Type::Simple
123                | Type::Break
124                | Type::Unknown(_)
125            ) => Err(Error::type_mismatch(t).with_message("unexpected type").at(self.decoder.position()).into())
126        }
127    }
128
129    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
130        visitor.visit_bool(self.decoder.bool()?)
131    }
132
133    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
134        visitor.visit_i8(self.decoder.i8()?)
135    }
136
137    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
138        visitor.visit_i16(self.decoder.i16()?)
139    }
140
141    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
142        visitor.visit_i32(self.decoder.i32()?)
143    }
144
145    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
146        visitor.visit_i64(self.decoder.i64()?)
147    }
148
149    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
150        visitor.visit_u8(self.decoder.u8()?)
151    }
152
153    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
154        visitor.visit_u16(self.decoder.u16()?)
155    }
156
157    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
158        visitor.visit_u32(self.decoder.u32()?)
159    }
160
161    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
162        visitor.visit_u64(self.decoder.u64()?)
163    }
164
165    fn deserialize_f32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
166        visitor.visit_f32(self.decoder.f32()?)
167    }
168
169    fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
170        visitor.visit_f64(self.decoder.f64()?)
171    }
172
173    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
174        visitor.visit_char(self.decoder.char()?)
175    }
176
177    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
178        visitor.visit_borrowed_str(self.decoder.str()?)
179    }
180
181    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
182        visitor.visit_str(self.decoder.str()?)
183    }
184
185    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
186        visitor.visit_borrowed_bytes(self.decoder.bytes()?)
187    }
188
189    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
190        visitor.visit_bytes(self.decoder.bytes()?)
191    }
192
193    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
194        if Type::Null == self.decoder.datatype()? {
195            self.decoder.skip()?;
196            visitor.visit_none()
197        } else {
198            visitor.visit_some(self)
199        }
200    }
201
202    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
203        self.decoder.decode::<()>()?;
204        visitor.visit_unit()
205    }
206
207    fn deserialize_unit_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
208    where
209        V: Visitor<'de>
210    {
211        self.deserialize_unit(v)
212    }
213
214    fn deserialize_newtype_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
215    where
216        V: Visitor<'de>
217    {
218        v.visit_newtype_struct(self)
219    }
220
221    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
222        let len = self.decoder.array()?;
223        visitor.visit_seq(Seq::new(self, len))
224    }
225
226    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
227    where
228        V: Visitor<'de>
229    {
230        let p = self.decoder.position();
231        let n = self.decoder.array()?;
232        if Some(len as u64) != n {
233            #[cfg(feature = "alloc")]
234            return Err(Error::message(alloc::format!("invalid length {n:?}, was expecting: {len}")).at(p).into());
235            #[cfg(not(feature = "alloc"))]
236            return Err(Error::message("invalid length").at(p).into());
237        }
238        visitor.visit_seq(Seq::new(self, n))
239    }
240
241    fn deserialize_tuple_struct<V>
242        ( self
243        , _name: &'static str
244        , len: usize
245        , visitor: V
246        ) -> Result<V::Value, Self::Error>
247    where
248        V: Visitor<'de>
249    {
250        self.deserialize_tuple(len, visitor)
251    }
252
253    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
254        let len = self.decoder.map()?;
255        visitor.visit_map(Seq::new(self, len))
256    }
257
258    fn deserialize_struct<V>
259        ( self
260        , _name: &'static str
261        , _fields: &'static [&'static str]
262        , visitor: V
263        ) -> Result<V::Value, Self::Error>
264    where
265        V: Visitor<'de>
266    {
267        self.deserialize_map(visitor)
268    }
269
270    fn deserialize_enum<V>
271        ( self
272        , _name: &'static str
273        , _variants: &'static [&'static str]
274        , visitor: V
275        ) -> Result<V::Value, Self::Error>
276    where
277        V: Visitor<'de>
278    {
279        let p = self.decoder.position();
280        if Type::Map == self.decoder.datatype()? {
281            let m = self.decoder.map()?;
282            if m != Some(1) {
283                return Err(Error::message("invalid enum map length").at(p).into())
284            }
285        }
286        visitor.visit_enum(Enum::new(self))
287    }
288
289    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
290        self.deserialize_str(visitor)
291    }
292
293    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
294        self.decoder.skip()?;
295        visitor.visit_unit() // ignored
296    }
297
298    fn is_human_readable(&self) -> bool {
299        false
300    }
301}
302
303struct Seq<'a, 'de> {
304    deserializer: &'a mut Deserializer<'de>,
305    len: Option<u64>
306}
307
308impl<'a, 'de> Seq<'a, 'de> {
309    fn new(d: &'a mut Deserializer<'de>, len: Option<u64>) -> Self {
310        Self { deserializer: d, len }
311    }
312}
313
314impl<'a, 'de> SeqAccess<'de> for Seq<'a, 'de> {
315    type Error = DecodeError;
316
317    fn size_hint(&self) -> Option<usize> {
318        self.len.and_then(|n| n.try_into().ok())
319    }
320
321    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
322    where
323        T: DeserializeSeed<'de>
324    {
325        match self.len {
326            None => if BREAK == self.deserializer.current()? {
327                self.deserializer.read()?;
328                Ok(None)
329            } else {
330                seed.deserialize(&mut *self.deserializer).map(Some)
331            }
332            Some(0) => Ok(None),
333            Some(n) => {
334                let x = seed.deserialize(&mut *self.deserializer)?;
335                self.len = Some(n - 1);
336                Ok(Some(x))
337            }
338        }
339    }
340}
341
342impl<'a, 'de> MapAccess<'de> for Seq<'a, 'de> {
343    type Error = DecodeError;
344
345    fn size_hint(&self) -> Option<usize> {
346        self.len.and_then(|n| n.try_into().ok())
347    }
348
349    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
350    where
351        K: DeserializeSeed<'de>
352    {
353        match self.len {
354            None => if BREAK == self.deserializer.current()? {
355                self.deserializer.read()?;
356                Ok(None)
357            } else {
358                seed.deserialize(&mut *self.deserializer).map(Some)
359            }
360            Some(0) => Ok(None),
361            Some(_) => seed.deserialize(&mut *self.deserializer).map(Some)
362        }
363    }
364
365    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
366    where
367        V: DeserializeSeed<'de>
368    {
369        if let Some(n) = self.len {
370            let x = seed.deserialize(&mut *self.deserializer)?;
371            self.len = Some(n - 1);
372            Ok(x)
373        } else {
374            seed.deserialize(&mut *self.deserializer)
375        }
376    }
377}
378
379struct Enum<'a, 'de: 'a> {
380    deserializer: &'a mut Deserializer<'de>
381}
382
383impl<'a, 'de> Enum<'a, 'de> {
384    fn new(d: &'a mut Deserializer<'de>) -> Self {
385        Self { deserializer: d }
386    }
387}
388
389impl<'a, 'de> EnumAccess<'de> for Enum<'a, 'de> {
390    type Error = DecodeError;
391    type Variant = Self;
392
393    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
394    where
395        V: DeserializeSeed<'de>
396    {
397        seed.deserialize(&mut *self.deserializer).map(|v| (v, self))
398    }
399}
400
401impl<'a, 'de> VariantAccess<'de> for Enum<'a, 'de> {
402    type Error = DecodeError;
403
404    fn unit_variant(self) -> Result<(), Self::Error> {
405        Ok(())
406    }
407
408    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
409    where
410        T: DeserializeSeed<'de>
411    {
412        seed.deserialize(self.deserializer)
413    }
414
415    fn tuple_variant<V>(self, len: usize, v: V) -> Result<V::Value, Self::Error>
416    where
417        V: Visitor<'de>
418    {
419        de::Deserializer::deserialize_tuple(self.deserializer, len, v)
420    }
421
422    fn struct_variant<V>(self, _fields: &'static [&'static str], v: V) -> Result<V::Value, Self::Error>
423    where
424        V: Visitor<'de>
425    {
426        de::Deserializer::deserialize_map(self.deserializer, v)
427    }
428}