1use serde::de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
2
3use minicbor::data::Type;
4use minicbor::decode::{Decoder, Error};
5
6use crate::error::DecodeError;
7
/// Initial byte of the CBOR "break" stop code (major type 7, additional
/// information 31), which terminates indefinite-length arrays and maps.
const BREAK: u8 = 0xFF;
9
10pub fn from_slice<'de, T: de::Deserialize<'de>>(b: &'de [u8]) -> Result<T, DecodeError> {
12 T::deserialize(&mut Deserializer::new(b))
13}
14
15#[derive(Debug, Clone)]
17pub struct Deserializer<'de> {
18 decoder: Decoder<'de>
19}
20
21impl<'de> Deserializer<'de> {
22 pub fn new(b: &'de [u8]) -> Self {
23 Self::from(Decoder::new(b))
24 }
25
26 pub fn decoder(&self) -> &Decoder<'de> {
27 &self.decoder
28 }
29
30 pub fn decoder_mut(&mut self) -> &mut Decoder<'de> {
31 &mut self.decoder
32 }
33
34 pub fn into_decoder(self) -> Decoder<'de> {
35 self.decoder
36 }
37
38 fn current(&self) -> Result<u8, Error> {
40 if let Some(b) = self.decoder.input().get(self.decoder.position()) {
41 return Ok(*b)
42 }
43 Err(Error::end_of_input())
44 }
45
46 fn read(&mut self) -> Result<u8, Error> {
48 let p = self.decoder.position();
49 if let Some(b) = self.decoder.input().get(p) {
50 self.decoder.set_position(p + 1);
51 return Ok(*b)
52 }
53 Err(Error::end_of_input())
54 }
55}
56
57impl<'de> From<Decoder<'de>> for Deserializer<'de> {
58 fn from(d: Decoder<'de>) -> Self {
59 Self { decoder: d }
60 }
61}
62
63impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
64 type Error = DecodeError;
65
66 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
67 match self.decoder.datatype()? {
68 Type::Bool => self.deserialize_bool(visitor),
69 Type::U8 => self.deserialize_u8(visitor),
70 Type::U16 => self.deserialize_u16(visitor),
71 Type::U32 => self.deserialize_u32(visitor),
72 Type::U64 => self.deserialize_u64(visitor),
73 Type::I8 => self.deserialize_i8(visitor),
74 Type::I16 => self.deserialize_i16(visitor),
75 Type::I32 => self.deserialize_i32(visitor),
76 Type::I64 => self.deserialize_i64(visitor),
77 Type::F32 => self.deserialize_f32(visitor),
78 Type::F64 => self.deserialize_f64(visitor),
79 Type::Bytes => visitor.visit_borrowed_bytes(self.decoder.bytes()?),
80 Type::String => visitor.visit_borrowed_str(self.decoder.str()?),
81 Type::Null => { self.decoder.skip()?; visitor.visit_none() }
82 Type::Array |
83 Type::ArrayIndef => self.deserialize_seq(visitor),
84 Type::Map |
85 Type::MapIndef => self.deserialize_map(visitor),
86
87 #[cfg(feature = "half")]
88 Type::F16 => visitor.visit_f32(self.decoder.f16()?),
89
90 #[cfg(not(feature = "half"))]
91 Type::F16 => Err(Error::type_mismatch(Type::F16)
92 .with_message("unexpected type")
93 .at(self.decoder.position())
94 .into()),
95
96 #[cfg(feature = "alloc")]
97 Type::BytesIndef => {
98 let mut buf = alloc::vec::Vec::new();
99 for b in self.decoder.bytes_iter()? {
100 buf.extend_from_slice(b?)
101 }
102 visitor.visit_byte_buf(buf)
103 }
104
105 #[cfg(feature = "alloc")]
106 Type::StringIndef => {
107 let mut buf = alloc::string::String::new();
108 for b in self.decoder.str_iter()? {
109 buf += b?
110 }
111 visitor.visit_string(buf)
112 }
113
114 #[cfg(not(feature = "alloc"))]
115 t @ (Type::BytesIndef | Type::StringIndef) =>
116 Err(Error::type_mismatch(t).with_message("unexpected type").at(self.decoder.position()).into()),
117
118 t @ (
119 | Type::Undefined
120 | Type::Tag
121 | Type::Int
122 | Type::Simple
123 | Type::Break
124 | Type::Unknown(_)
125 ) => Err(Error::type_mismatch(t).with_message("unexpected type").at(self.decoder.position()).into())
126 }
127 }
128
129 fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
130 visitor.visit_bool(self.decoder.bool()?)
131 }
132
133 fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
134 visitor.visit_i8(self.decoder.i8()?)
135 }
136
137 fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
138 visitor.visit_i16(self.decoder.i16()?)
139 }
140
141 fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
142 visitor.visit_i32(self.decoder.i32()?)
143 }
144
145 fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
146 visitor.visit_i64(self.decoder.i64()?)
147 }
148
149 fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
150 visitor.visit_u8(self.decoder.u8()?)
151 }
152
153 fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
154 visitor.visit_u16(self.decoder.u16()?)
155 }
156
157 fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
158 visitor.visit_u32(self.decoder.u32()?)
159 }
160
161 fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
162 visitor.visit_u64(self.decoder.u64()?)
163 }
164
165 fn deserialize_f32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
166 visitor.visit_f32(self.decoder.f32()?)
167 }
168
169 fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
170 visitor.visit_f64(self.decoder.f64()?)
171 }
172
173 fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
174 visitor.visit_char(self.decoder.char()?)
175 }
176
177 fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
178 visitor.visit_borrowed_str(self.decoder.str()?)
179 }
180
181 fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
182 visitor.visit_str(self.decoder.str()?)
183 }
184
185 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
186 visitor.visit_borrowed_bytes(self.decoder.bytes()?)
187 }
188
189 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
190 visitor.visit_bytes(self.decoder.bytes()?)
191 }
192
193 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
194 if Type::Null == self.decoder.datatype()? {
195 self.decoder.skip()?;
196 visitor.visit_none()
197 } else {
198 visitor.visit_some(self)
199 }
200 }
201
202 fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
203 self.decoder.decode::<()>()?;
204 visitor.visit_unit()
205 }
206
207 fn deserialize_unit_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
208 where
209 V: Visitor<'de>
210 {
211 self.deserialize_unit(v)
212 }
213
214 fn deserialize_newtype_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
215 where
216 V: Visitor<'de>
217 {
218 v.visit_newtype_struct(self)
219 }
220
221 fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
222 let len = self.decoder.array()?;
223 visitor.visit_seq(Seq::new(self, len))
224 }
225
226 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
227 where
228 V: Visitor<'de>
229 {
230 let p = self.decoder.position();
231 let n = self.decoder.array()?;
232 if Some(len as u64) != n {
233 #[cfg(feature = "alloc")]
234 return Err(Error::message(alloc::format!("invalid length {n:?}, was expecting: {len}")).at(p).into());
235 #[cfg(not(feature = "alloc"))]
236 return Err(Error::message("invalid length").at(p).into());
237 }
238 visitor.visit_seq(Seq::new(self, n))
239 }
240
241 fn deserialize_tuple_struct<V>
242 ( self
243 , _name: &'static str
244 , len: usize
245 , visitor: V
246 ) -> Result<V::Value, Self::Error>
247 where
248 V: Visitor<'de>
249 {
250 self.deserialize_tuple(len, visitor)
251 }
252
253 fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
254 let len = self.decoder.map()?;
255 visitor.visit_map(Seq::new(self, len))
256 }
257
258 fn deserialize_struct<V>
259 ( self
260 , _name: &'static str
261 , _fields: &'static [&'static str]
262 , visitor: V
263 ) -> Result<V::Value, Self::Error>
264 where
265 V: Visitor<'de>
266 {
267 self.deserialize_map(visitor)
268 }
269
270 fn deserialize_enum<V>
271 ( self
272 , _name: &'static str
273 , _variants: &'static [&'static str]
274 , visitor: V
275 ) -> Result<V::Value, Self::Error>
276 where
277 V: Visitor<'de>
278 {
279 let p = self.decoder.position();
280 if Type::Map == self.decoder.datatype()? {
281 let m = self.decoder.map()?;
282 if m != Some(1) {
283 return Err(Error::message("invalid enum map length").at(p).into())
284 }
285 }
286 visitor.visit_enum(Enum::new(self))
287 }
288
289 fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
290 self.deserialize_str(visitor)
291 }
292
293 fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
294 self.decoder.skip()?;
295 visitor.visit_unit() }
297
298 fn is_human_readable(&self) -> bool {
299 false
300 }
301}
302
303struct Seq<'a, 'de> {
304 deserializer: &'a mut Deserializer<'de>,
305 len: Option<u64>
306}
307
308impl<'a, 'de> Seq<'a, 'de> {
309 fn new(d: &'a mut Deserializer<'de>, len: Option<u64>) -> Self {
310 Self { deserializer: d, len }
311 }
312}
313
314impl<'a, 'de> SeqAccess<'de> for Seq<'a, 'de> {
315 type Error = DecodeError;
316
317 fn size_hint(&self) -> Option<usize> {
318 self.len.and_then(|n| n.try_into().ok())
319 }
320
321 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
322 where
323 T: DeserializeSeed<'de>
324 {
325 match self.len {
326 None => if BREAK == self.deserializer.current()? {
327 self.deserializer.read()?;
328 Ok(None)
329 } else {
330 seed.deserialize(&mut *self.deserializer).map(Some)
331 }
332 Some(0) => Ok(None),
333 Some(n) => {
334 let x = seed.deserialize(&mut *self.deserializer)?;
335 self.len = Some(n - 1);
336 Ok(Some(x))
337 }
338 }
339 }
340}
341
342impl<'a, 'de> MapAccess<'de> for Seq<'a, 'de> {
343 type Error = DecodeError;
344
345 fn size_hint(&self) -> Option<usize> {
346 self.len.and_then(|n| n.try_into().ok())
347 }
348
349 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
350 where
351 K: DeserializeSeed<'de>
352 {
353 match self.len {
354 None => if BREAK == self.deserializer.current()? {
355 self.deserializer.read()?;
356 Ok(None)
357 } else {
358 seed.deserialize(&mut *self.deserializer).map(Some)
359 }
360 Some(0) => Ok(None),
361 Some(_) => seed.deserialize(&mut *self.deserializer).map(Some)
362 }
363 }
364
365 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
366 where
367 V: DeserializeSeed<'de>
368 {
369 if let Some(n) = self.len {
370 let x = seed.deserialize(&mut *self.deserializer)?;
371 self.len = Some(n - 1);
372 Ok(x)
373 } else {
374 seed.deserialize(&mut *self.deserializer)
375 }
376 }
377}
378
379struct Enum<'a, 'de: 'a> {
380 deserializer: &'a mut Deserializer<'de>
381}
382
383impl<'a, 'de> Enum<'a, 'de> {
384 fn new(d: &'a mut Deserializer<'de>) -> Self {
385 Self { deserializer: d }
386 }
387}
388
389impl<'a, 'de> EnumAccess<'de> for Enum<'a, 'de> {
390 type Error = DecodeError;
391 type Variant = Self;
392
393 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
394 where
395 V: DeserializeSeed<'de>
396 {
397 seed.deserialize(&mut *self.deserializer).map(|v| (v, self))
398 }
399}
400
401impl<'a, 'de> VariantAccess<'de> for Enum<'a, 'de> {
402 type Error = DecodeError;
403
404 fn unit_variant(self) -> Result<(), Self::Error> {
405 Ok(())
406 }
407
408 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
409 where
410 T: DeserializeSeed<'de>
411 {
412 seed.deserialize(self.deserializer)
413 }
414
415 fn tuple_variant<V>(self, len: usize, v: V) -> Result<V::Value, Self::Error>
416 where
417 V: Visitor<'de>
418 {
419 de::Deserializer::deserialize_tuple(self.deserializer, len, v)
420 }
421
422 fn struct_variant<V>(self, _fields: &'static [&'static str], v: V) -> Result<V::Value, Self::Error>
423 where
424 V: Visitor<'de>
425 {
426 de::Deserializer::deserialize_map(self.deserializer, v)
427 }
428}