firestore_serde/deserialize/
mod.rs

1use crate::firestore::{value::ValueType, ArrayValue, MapValue, Value};
2pub use error::{DeserializationError, Result};
3use prost::Message;
4use serde::{
5    de::{EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess},
6    Deserializer,
7};
8use std::convert::TryFrom;
9
10use crate::{TYPE, VALUE, VALUES};
11
12use self::{
13    plain_byte_deserializer::PlainByteDeserializer,
14    plain_string_deserializer::PlainStringDeserializer,
15};
16
17mod error;
18mod plain_byte_deserializer;
19mod plain_string_deserializer;
20
21pub struct ValueDeserializer<'de>(pub &'de Value);
22
23struct ArrayValueSeq<'de> {
24    values: std::slice::Iter<'de, Value>,
25}
26
27impl<'de> ArrayValueSeq<'de> {
28    pub fn new(values: std::slice::Iter<'de, Value>) -> Self {
29        ArrayValueSeq { values }
30    }
31}
32
33impl<'de> SeqAccess<'de> for ArrayValueSeq<'de> {
34    type Error = DeserializationError;
35
36    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
37    where
38        T: serde::de::DeserializeSeed<'de>,
39    {
40        if let Some(v) = self.values.next() {
41            seed.deserialize(&mut ValueDeserializer(v)).map(Some)
42        } else {
43            Ok(None)
44        }
45    }
46}
47
/// Sequence cursor that exposes a byte slice one `u8` at a time.
struct BytesSeq<'de> {
    /// Bytes that have not been handed to the visitor yet.
    bytes: core::slice::Iter<'de, u8>,
}

impl<'de> BytesSeq<'de> {
    /// Wraps an iterator over raw bytes.
    pub fn new(bytes: core::slice::Iter<'de, u8>) -> Self {
        Self { bytes }
    }
}
57
58impl<'de> SeqAccess<'de> for BytesSeq<'de> {
59    type Error = DeserializationError;
60
61    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
62    where
63        T: serde::de::DeserializeSeed<'de>,
64    {
65        if let Some(v) = self.bytes.next() {
66            seed.deserialize(PlainByteDeserializer(*v)).map(Some)
67        } else {
68            Ok(None)
69        }
70    }
71}
72
73struct MapValueSeq<'de> {
74    values: std::collections::hash_map::Iter<'de, String, Value>,
75    next_value: Option<&'de Value>,
76}
77
78impl<'de> MapValueSeq<'de> {
79    pub fn new(values: std::collections::hash_map::Iter<'de, String, Value>) -> Self {
80        MapValueSeq {
81            values,
82            next_value: None,
83        }
84    }
85}
86
87impl<'de> MapAccess<'de> for MapValueSeq<'de> {
88    type Error = DeserializationError;
89
90    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
91    where
92        K: serde::de::DeserializeSeed<'de>,
93    {
94        if let Some((k, v)) = self.values.next() {
95            self.next_value = Some(v);
96
97            Ok(Some(seed.deserialize(PlainStringDeserializer(k))?))
98        } else {
99            Ok(None)
100        }
101    }
102
103    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
104    where
105        V: serde::de::DeserializeSeed<'de>,
106    {
107        let value = self
108            .next_value
109            .take()
110            .expect("Shouldn't visit value before key.");
111        seed.deserialize(&mut ValueDeserializer(value))
112    }
113}
114
115impl<'de, 'a> Deserializer<'de> for &'a mut ValueDeserializer<'de> {
116    type Error = DeserializationError;
117
118    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
119    where
120        V: serde::de::Visitor<'de>,
121    {
122        Err(DeserializationError::Unrepresentable("any"))
123    }
124
125    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
126    where
127        V: serde::de::Visitor<'de>,
128    {
129        if let Value {
130            value_type: Some(ValueType::BooleanValue(v)),
131        } = self.0
132        {
133            visitor.visit_bool(*v)
134        } else {
135            Err(DeserializationError::WrongType("bool", self.0.clone()))
136        }
137    }
138
139    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
140    where
141        V: serde::de::Visitor<'de>,
142    {
143        if let Value {
144            value_type: Some(ValueType::IntegerValue(v)),
145        } = self.0
146        {
147            visitor
148                .visit_i8(i8::try_from(*v).map_err(|_| DeserializationError::IntRange("i8", *v))?)
149        } else {
150            Err(DeserializationError::WrongType("i8", self.0.clone()))
151        }
152    }
153
154    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
155    where
156        V: serde::de::Visitor<'de>,
157    {
158        if let Value {
159            value_type: Some(ValueType::IntegerValue(v)),
160        } = self.0
161        {
162            visitor.visit_i16(
163                i16::try_from(*v).map_err(|_| DeserializationError::IntRange("i16", *v))?,
164            )
165        } else {
166            Err(DeserializationError::WrongType("i16", self.0.clone()))
167        }
168    }
169
170    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
171    where
172        V: serde::de::Visitor<'de>,
173    {
174        if let Value {
175            value_type: Some(ValueType::IntegerValue(v)),
176        } = self.0
177        {
178            visitor.visit_i32(
179                i32::try_from(*v).map_err(|_| DeserializationError::IntRange("i32", *v))?,
180            )
181        } else {
182            Err(DeserializationError::WrongType("i32", self.0.clone()))
183        }
184    }
185
186    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
187    where
188        V: serde::de::Visitor<'de>,
189    {
190        if let Value {
191            value_type: Some(ValueType::IntegerValue(v)),
192        } = self.0
193        {
194            visitor.visit_i64(*v)
195        } else {
196            Err(DeserializationError::WrongType("i64", self.0.clone()))
197        }
198    }
199
200    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
201    where
202        V: serde::de::Visitor<'de>,
203    {
204        if let Value {
205            value_type: Some(ValueType::IntegerValue(v)),
206        } = self.0
207        {
208            visitor
209                .visit_u8(u8::try_from(*v).map_err(|_| DeserializationError::IntRange("u8", *v))?)
210        } else {
211            Err(DeserializationError::WrongType("i8", self.0.clone()))
212        }
213    }
214
215    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
216    where
217        V: serde::de::Visitor<'de>,
218    {
219        if let Value {
220            value_type: Some(ValueType::IntegerValue(v)),
221        } = self.0
222        {
223            visitor.visit_u16(
224                u16::try_from(*v).map_err(|_| DeserializationError::IntRange("u16", *v))?,
225            )
226        } else {
227            Err(DeserializationError::WrongType("u16", self.0.clone()))
228        }
229    }
230
231    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
232    where
233        V: serde::de::Visitor<'de>,
234    {
235        if let Value {
236            value_type: Some(ValueType::IntegerValue(v)),
237        } = self.0
238        {
239            visitor.visit_u32(
240                u32::try_from(*v).map_err(|_| DeserializationError::IntRange("u32", *v))?,
241            )
242        } else {
243            Err(DeserializationError::WrongType("u32", self.0.clone()))
244        }
245    }
246
247    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
248    where
249        V: serde::de::Visitor<'de>,
250    {
251        if let Value {
252            value_type: Some(ValueType::IntegerValue(v)),
253        } = self.0
254        {
255            visitor.visit_u64(
256                u64::try_from(*v).map_err(|_| DeserializationError::IntRange("u64", *v))?,
257            )
258        } else {
259            Err(DeserializationError::WrongType("u64", self.0.clone()))
260        }
261    }
262
263    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
264    where
265        V: serde::de::Visitor<'de>,
266    {
267        if let Value {
268            value_type: Some(ValueType::DoubleValue(v)),
269        } = self.0
270        {
271            #[allow(clippy::cast_possible_truncation)]
272            visitor.visit_f32(*v as f32)
273        } else {
274            Err(DeserializationError::WrongType("f32", self.0.clone()))
275        }
276    }
277
278    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
279    where
280        V: serde::de::Visitor<'de>,
281    {
282        if let Value {
283            value_type: Some(ValueType::DoubleValue(v)),
284        } = self.0
285        {
286            visitor.visit_f64(*v)
287        } else {
288            Err(DeserializationError::WrongType("f64", self.0.clone()))
289        }
290    }
291
292    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
293    where
294        V: serde::de::Visitor<'de>,
295    {
296        if let Value {
297            value_type: Some(ValueType::StringValue(v)),
298        } = self.0
299        {
300            if v.len() == 1 {
301                visitor.visit_char(
302                    v.chars()
303                        .next()
304                        .expect("Already checked that string has exactly one char."),
305                )
306            } else {
307                Err(DeserializationError::WrongType("char", self.0.clone()))
308            }
309        } else {
310            Err(DeserializationError::WrongType("char", self.0.clone()))
311        }
312    }
313
314    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
315    where
316        V: serde::de::Visitor<'de>,
317    {
318        if let Value {
319            value_type: Some(ValueType::StringValue(v)),
320        } = self.0
321        {
322            visitor.visit_str(v)
323        } else {
324            Err(DeserializationError::WrongType("str", self.0.clone()))
325        }
326    }
327
328    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
329    where
330        V: serde::de::Visitor<'de>,
331    {
332        if let Value {
333            value_type: Some(ValueType::StringValue(v)),
334        } = self.0
335        {
336            visitor.visit_string(v.clone())
337        } else {
338            Err(DeserializationError::WrongType("string", self.0.clone()))
339        }
340    }
341
342    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
343    where
344        V: serde::de::Visitor<'de>,
345    {
346        if let Value {
347            value_type: Some(ValueType::BytesValue(bytes)),
348        } = self.0
349        {
350            visitor.visit_bytes(bytes)
351        } else {
352            Err(DeserializationError::WrongType("bytes", self.0.clone()))
353        }
354    }
355
356    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
357    where
358        V: serde::de::Visitor<'de>,
359    {
360        if let Value {
361            value_type: Some(ValueType::BytesValue(bytes)),
362        } = self.0
363        {
364            visitor.visit_byte_buf(bytes.clone())
365        } else if let Value {
366            value_type: Some(ValueType::TimestampValue(timestamp)),
367        } = self.0
368        {
369            let bytes = timestamp.encode_to_vec();
370            visitor.visit_byte_buf(bytes)
371        } else {
372            Err(DeserializationError::WrongType("byte_buf", self.0.clone()))
373        }
374    }
375
376    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
377    where
378        V: serde::de::Visitor<'de>,
379    {
380        if let Value {
381            value_type: Some(ValueType::NullValue(_)),
382        } = self.0
383        {
384            visitor.visit_none()
385        } else {
386            visitor.visit_some(self)
387        }
388    }
389
390    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
391    where
392        V: serde::de::Visitor<'de>,
393    {
394        Err(DeserializationError::Unrepresentable("unit"))
395    }
396
397    fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
398    where
399        V: serde::de::Visitor<'de>,
400    {
401        Err(DeserializationError::Unrepresentable("unit_struct"))
402    }
403
404    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
405    where
406        V: serde::de::Visitor<'de>,
407    {
408        visitor.visit_newtype_struct(self)
409    }
410
411    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
412    where
413        V: serde::de::Visitor<'de>,
414    {
415        if let Value {
416            value_type: Some(ValueType::ArrayValue(ArrayValue { values })),
417        } = self.0
418        {
419            visitor.visit_seq(ArrayValueSeq::new(values.iter()))
420        } else if let Value {
421            value_type: Some(ValueType::BytesValue(bytes)),
422        } = self.0
423        {
424            visitor.visit_seq(BytesSeq::new(bytes.iter()))
425        } else {
426            Err(DeserializationError::WrongType("seq", self.0.clone()))
427        }
428    }
429
430    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
431    where
432        V: serde::de::Visitor<'de>,
433    {
434        if let Value {
435            value_type: Some(ValueType::ArrayValue(ArrayValue { values })),
436        } = self.0
437        {
438            visitor.visit_seq(ArrayValueSeq::new(values.iter()))
439        } else {
440            Err(DeserializationError::WrongType("tuple", self.0.clone()))
441        }
442    }
443
444    fn deserialize_tuple_struct<V>(
445        self,
446        _name: &'static str,
447        _len: usize,
448        visitor: V,
449    ) -> Result<V::Value>
450    where
451        V: serde::de::Visitor<'de>,
452    {
453        self.deserialize_seq(visitor)
454    }
455
456    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
457    where
458        V: serde::de::Visitor<'de>,
459    {
460        if let Value {
461            value_type: Some(ValueType::MapValue(MapValue { fields })),
462        } = self.0
463        {
464            visitor.visit_map(MapValueSeq::new(fields.iter()))
465        } else {
466            Err(DeserializationError::WrongType("map", self.0.clone()))
467        }
468    }
469
470    fn deserialize_struct<V>(
471        self,
472        _name: &'static str,
473        _fields: &'static [&'static str],
474        visitor: V,
475    ) -> Result<V::Value>
476    where
477        V: serde::de::Visitor<'de>,
478    {
479        self.deserialize_map(visitor)
480    }
481
482    fn deserialize_enum<V>(
483        self,
484        _name: &'static str,
485        _variants: &'static [&'static str],
486        visitor: V,
487    ) -> Result<V::Value>
488    where
489        V: serde::de::Visitor<'de>,
490    {
491        match &self.0.value_type {
492            Some(ValueType::StringValue(v)) => visitor.visit_enum(v.clone().into_deserializer()),
493            Some(ValueType::MapValue(MapValue { fields })) => {
494                let mut typ: Option<&String> = None;
495                let mut value: Option<&Value> = None;
496
497                for (k, v) in fields {
498                    if k == TYPE {
499                        if let Value {
500                            value_type: Some(ValueType::StringValue(v)),
501                        } = v
502                        {
503                            typ = Some(v);
504                        } else {
505                            return Err(DeserializationError::WrongType("string", v.clone()));
506                        }
507                    } else if k == VALUE || k == VALUES {
508                        value = Some(v);
509                    }
510                }
511
512                let typ = if let Some(typ) = typ {
513                    typ
514                } else {
515                    return Err(DeserializationError::MissingField(TYPE));
516                };
517
518                if let Some(value) = value {
519                    visitor.visit_enum(Enum::new(typ, value))
520                } else {
521                    Err(DeserializationError::MissingField(VALUE))
522                }
523            }
524            _ => Err(DeserializationError::WrongType("enum", self.0.clone())),
525        }
526    }
527
528    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
529    where
530        V: serde::de::Visitor<'de>,
531    {
532        Err(DeserializationError::Unrepresentable("identifier"))
533    }
534
535    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
536    where
537        V: serde::de::Visitor<'de>,
538    {
539        visitor.visit_unit()
540    }
541}
542
543struct Enum<'de> {
544    typ: &'de str,
545    value: &'de Value,
546}
547
548impl<'de> Enum<'de> {
549    pub fn new(typ: &'de str, value: &'de Value) -> Self {
550        Enum { typ, value }
551    }
552}
553
554impl<'de> EnumAccess<'de> for Enum<'de> {
555    type Error = DeserializationError;
556
557    type Variant = Enum<'de>;
558
559    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
560    where
561        V: serde::de::DeserializeSeed<'de>,
562    {
563        let val = seed.deserialize(PlainStringDeserializer(self.typ))?;
564
565        Ok((val, self))
566    }
567}
568
569impl<'de> VariantAccess<'de> for Enum<'de> {
570    type Error = DeserializationError;
571
572    fn unit_variant(self) -> Result<()> {
573        panic!("Unit variant was already handled.")
574    }
575
576    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
577    where
578        T: serde::de::DeserializeSeed<'de>,
579    {
580        seed.deserialize(&mut ValueDeserializer(self.value))
581    }
582
583    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
584    where
585        V: serde::de::Visitor<'de>,
586    {
587        ValueDeserializer(self.value).deserialize_seq(visitor)
588    }
589
590    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
591    where
592        V: serde::de::Visitor<'de>,
593    {
594        ValueDeserializer(self.value).deserialize_map(visitor)
595    }
596}