embedded_msgpack/decode/serde/
mod.rs

1use crate::marker::Marker;
2use core::fmt;
3use paste::paste;
4use serde::de::{self, Visitor};
5
6use self::{enum_::UnitVariantAccess, map::MapAccess, seq::SeqAccess};
7
8mod enum_;
9mod map;
10mod seq;
11
12use super::Error;
13
14type Result<T> = core::result::Result<T, Error>;
15
/// Test-only trace: prints `prefix` + `function_name`, the type parameter `T`, and up to the
/// next 10 unread input bytes (hex) of `de`.
#[cfg(test)]
fn print_debug<T>(prefix: &str, function_name: &str, de: &Deserializer) {
    #[cfg(not(feature = "std"))]
    extern crate std;
    #[cfg(not(feature = "std"))]
    use std::println;

    // Clamp the preview window to the end of the input.
    let window_end = core::cmp::min(de.slice.len(), de.index + 10);
    let upcoming = &de.slice[de.index..window_end];
    let type_label = core::any::type_name::<T>();
    println!("{}{}<{}> ({:02x?})", prefix, function_name, type_label, upcoming);
}
30
/// Test-only trace: prints `function_name`, the type parameter `T`, the decoded `value`, and up
/// to the next 10 unread input bytes (hex) of `de`.
#[cfg(test)]
fn print_debug_value<T, V: core::fmt::Debug>(function_name: &str, de: &Deserializer, value: &V) {
    #[cfg(not(feature = "std"))]
    extern crate std;
    #[cfg(not(feature = "std"))]
    use std::println;

    // Clamp the preview window to the end of the input.
    let window_end = core::cmp::min(de.slice.len(), de.index + 10);
    let upcoming = &de.slice[de.index..window_end];
    let type_label = core::any::type_name::<T>();
    println!("{}<{}> => {:?}   ({:02x?})", function_name, type_label, value, upcoming);
}
45#[cfg(not(test))]
46fn print_debug<T>(_prefix: &str, _function_name: &str, _de: &Deserializer) {}
47#[cfg(not(test))]
48fn print_debug_value<T, V: core::fmt::Debug>(_function_name: &str, _de: &Deserializer, _value: &V) {}
49
/// Decoding mode of the [`Deserializer`].
enum State {
    /// Ordinary msgpack decoding: every value starts with a marker byte.
    Normal,
    /// Positioned inside an ext value whose payload is this many bytes long; the ext type byte
    /// and the payload are read raw instead of as msgpack-encoded values.
    Ext(usize),
}

/// msgpack deserializer over a borrowed byte slice; decoded `str`/bytes borrow from the input.
pub(crate) struct Deserializer<'b> {
    /// Whole input buffer.
    slice: &'b [u8],
    /// Offset of the next unread byte in `slice`.
    index: usize,
    /// Current decoding mode.
    state: State,
}
60
61impl<'a> Deserializer<'a> {
62    pub const fn new(slice: &'a [u8]) -> Deserializer<'_> {
63        Deserializer {
64            slice,
65            index: 0,
66            state: State::Normal,
67        }
68    }
69
70    fn eat_byte(&mut self) { self.index += 1; }
71
72    fn peek(&mut self) -> Option<Marker> { Some(Marker::from_u8(*self.slice.get(self.index)?)) }
73}
74
// NOTE(deserialize_*signed) we avoid parsing into u64 and then casting to a smaller integer, which
// is what upstream does, to avoid pulling in 64-bit compiler intrinsics, which waste a few KBs of
// Flash, when targeting non 64-bit architectures
/// Generates `fn deserialize_<ty>` for a `de::Deserializer` impl: decodes the value with the
/// matching `super::read_<ty>` helper, advances past the bytes it consumed and passes the value
/// to `visitor.visit_<ty>`.
macro_rules! deserialize_primitive {
    ($ty:ident) => {
        paste! {
            fn [<deserialize_ $ty>]<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
                print_debug::<V>("Deserializer::deserialize_", stringify!($ty), &self);
                // the outer `paste!` already expands every `[< >]` group, including this one
                let (value, len) = super::[<read_ $ty>](&self.slice[self.index..])?;
                self.index += len;
                // `concat!` yields e.g. "Deserializer::deserialize_u8"; stringifying a
                // `concat_idents!` call would print the macro invocation text instead
                print_debug_value::<$ty, $ty>(
                    concat!("Deserializer::deserialize_", stringify!($ty)),
                    &self,
                    &value,
                );
                visitor.[<visit_ $ty>](value)
            }
        }
    };
}
/// Expands `deserialize_primitive!` once for each listed type.
macro_rules! deserialize_primitives {
    ($($ty:ident),*) => { $( deserialize_primitive!($ty); )* };
}
94
95impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> {
96    type Error = Error;
97
98    deserialize_primitives!(bool, u8, u16, u32, u64, i16, i32, i64, f32, f64);
99
100    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
101    where V: Visitor<'de> {
102        print_debug::<V>("Deserializer::deserialize_", "i8", &self);
103        let (value, len) = match self.state {
104            State::Normal => super::read_i8(&self.slice[self.index..])?,
105            // read the ext type as raw byte and not encoded as a normal i8
106            #[cfg(feature = "ext")]
107            State::Ext(_) => (self.slice[self.index] as i8, 1),
108        };
109        self.index += len;
110        print_debug_value::<i8, i8>("Deserializer::deserialize_i8", &self, &value);
111        visitor.visit_i8(value)
112    }
113
114    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
115        print_debug::<V>("Deserializer::deserialize_", "str", &self);
116        let (s, len) = super::read_str(&self.slice[self.index..])?;
117        self.index += len;
118        visitor.visit_borrowed_str(s)
119    }
120
121    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
122        print_debug::<V>("Deserializer::deserialize_", "bytes", &self);
123        let (value, len) = match self.state {
124            State::Normal => super::read_bin(&self.slice[self.index..])?,
125            // read the ext type as raw byte and not encoded as a normal i8
126            #[cfg(feature = "ext")]
127            State::Ext(len) => {
128                self.state = State::Normal;
129                (&self.slice[self.index..self.index + len], len)
130            }
131        };
132        self.index += len;
133        visitor.visit_borrowed_bytes(value)
134    }
135
136    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
137        print_debug::<V>("Deserializer::deserialize_", "byte_buf", &self);
138        self.deserialize_bytes(visitor)
139    }
140
141    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
142        print_debug::<V>("Deserializer::deserialize_", "option", &self);
143        let marker = self.peek().ok_or(Error::EndOfBuffer)?;
144        match marker {
145            Marker::Null => {
146                self.eat_byte();
147                visitor.visit_none()
148            }
149            _ => visitor.visit_some(self),
150        }
151    }
152
153    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
154        print_debug::<V>("Deserializer::deserialize_", "seq", &self);
155        let (len, header_len) = crate::decode::read_array_len(&self.slice[self.index..])?;
156        self.index += header_len;
157        visitor.visit_seq(SeqAccess::new(self, len))
158    }
159
160    fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
161        print_debug::<V>("Deserializer::deserialize_", "tuple", &self);
162        self.deserialize_seq(visitor)
163    }
164
165    fn deserialize_tuple_struct<V: Visitor<'de>>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value> {
166        print_debug::<V>("Deserializer::deserialize_", "tuple_struct", &self);
167        self.deserialize_seq(visitor)
168    }
169
170    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
171        print_debug::<V>("Deserializer::deserialize_", "map", &self);
172        let (len, header_len) = crate::decode::read_map_len(&self.slice[self.index..])?;
173        self.index += header_len;
174        visitor.visit_map(MapAccess::new(self, len))
175    }
176
177    fn deserialize_struct<V: Visitor<'de>>(self, name: &'static str, _fields: &'static [&'static str], visitor: V) -> Result<V::Value> {
178        print_debug::<V>("Deserializer::deserialize_", "struct", &self);
179        match name {
180            #[cfg(feature = "ext")]
181            crate::ext::TYPE_NAME | crate::timestamp::TYPE_NAME => {
182                if let Some(marker) = self.peek() {
183                    match marker {
184                        Marker::FixExt1
185                        | Marker::FixExt2
186                        | Marker::FixExt4
187                        | Marker::FixExt8
188                        | Marker::FixExt16
189                        | Marker::Ext8
190                        | Marker::Ext16
191                        | Marker::Ext32 => {
192                            let (header_len, data_len) = crate::ext::read_ext_len(&self.slice[self.index..])?;
193                            self.index += header_len - 1; // move forward minus 1 byte for the ext type (header_len includes the type byte)
194                            self.state = State::Ext(data_len);
195                            visitor.visit_seq(SeqAccess::new(self, 2))
196                        }
197                        _ => Err(Error::InvalidType),
198                    }
199                } else {
200                    Err(Error::EndOfBuffer)
201                }
202            }
203            _ => self.deserialize_map(visitor),
204        }
205    }
206
207    fn deserialize_enum<V: Visitor<'de>>(self, _name: &'static str, _variants: &'static [&'static str], visitor: V) -> Result<V::Value> {
208        print_debug::<V>("Deserializer::deserialize_", "enum", &self);
209        visitor.visit_enum(UnitVariantAccess::new(self))
210    }
211
212    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
213        print_debug::<V>("Deserializer::deserialize_", "identifier", &self);
214        self.deserialize_str(visitor)
215    }
216
217    /// Unsupported. Can’t parse a value without knowing its expected type.
218    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
219        print_debug::<V>("Deserializer::deserialize_", "any", &self);
220        let (_, n) = super::skip_any(&self.slice[self.index..])?;
221        self.index += n;
222        visitor.visit_unit()
223    }
224
225    /// Used to throw out fields that we don’t want to keep in our structs.
226    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
227        print_debug::<V>("Deserializer::deserialize_", "ignored_any", &self);
228        self.deserialize_any(visitor)
229    }
230
231    /// Unsupported. Use a more specific deserialize_* method
232    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
233        print_debug::<V>("Deserializer::deserialize_", "unit", &self);
234        let marker = self.peek().ok_or(Error::EndOfBuffer)?;
235        match marker {
236            Marker::Null | Marker::FixArray(0) => {
237                self.eat_byte();
238                visitor.visit_unit()
239            }
240            _ => Err(Error::InvalidType),
241        }
242    }
243
244    /// Unsupported. Use a more specific deserialize_* method
245    fn deserialize_unit_struct<V: Visitor<'de>>(self, _name: &'static str, visitor: V) -> Result<V::Value> {
246        print_debug::<V>("Deserializer::deserialize_", "unit_struct", &self);
247        self.deserialize_unit(visitor)
248    }
249
250    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
251        print_debug::<V>("Deserializer::deserialize_", "char", &self);
252        //TODO Need to decide how to encode this. Probably as a str?
253        self.deserialize_str(visitor)
254    }
255
256    fn deserialize_newtype_struct<V: Visitor<'de>>(self, _name: &'static str, visitor: V) -> Result<V::Value> {
257        print_debug::<V>("Deserializer::deserialize_", "newtype_struct", &self);
258        visitor.visit_newtype_struct(self)
259    }
260
261    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
262        print_debug::<V>("Deserializer::deserialize_", "string", &self);
263        self.deserialize_str(visitor)
264    }
265}
266
267impl ::serde::de::StdError for Error {}
268impl de::Error for Error {
269    #[cfg_attr(not(feature = "custom-error-messages"), allow(unused_variables))]
270    fn custom<T>(msg: T) -> Self
271    where T: fmt::Display {
272        #[cfg(not(feature = "custom-error-messages"))]
273        {
274            Error::CustomError
275        }
276        #[cfg(feature = "custom-error-messages")]
277        {
278            use core::fmt::Write;
279
280            let mut string = heapless::String::new();
281            write!(string, "{:.64}", msg).unwrap();
282            Error::CustomErrorWithMessage(string)
283        }
284    }
285}
286
287impl fmt::Display for Error {
288    #[cfg(debug_assertions)]
289    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290        write!(
291            f,
292            "{}",
293            match self {
294                Error::InvalidType => "Unexpected type encountered.",
295                Error::OutOfBounds => "Index out of bounds.",
296                Error::EndOfBuffer => "EOF while parsing.",
297                Error::CustomError => "Did not match deserializer's expected format.",
298                #[cfg(feature = "custom-error-messages")]
299                Error::CustomErrorWithMessage(msg) => msg.as_str(),
300                Error::NotAscii => "String contains non-ascii chars.",
301                Error::InvalidBoolean => "Invalid boolean marker.",
302                Error::InvalidBinType => "Invalid binary marker.",
303                Error::InvalidStringType => "Invalid string marker.",
304                Error::InvalidArrayType => "Invalid array marker.",
305                Error::InvalidMapType => "Invalid map marker.",
306                Error::InvalidNewTypeLength => "Invalid array length for newtype.",
307            }
308        )
309    }
310    #[cfg(not(debug_assertions))]
311    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { Ok(()) }
312}