1use crate::firestore::{value::ValueType, ArrayValue, MapValue, Value};
2pub use error::{DeserializationError, Result};
3use prost::Message;
4use serde::{
5 de::{EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess},
6 Deserializer,
7};
8use std::convert::TryFrom;
9
10use crate::{TYPE, VALUE, VALUES};
11
12use self::{
13 plain_byte_deserializer::PlainByteDeserializer,
14 plain_string_deserializer::PlainStringDeserializer,
15};
16
17mod error;
18mod plain_byte_deserializer;
19mod plain_string_deserializer;
20
21pub struct ValueDeserializer<'de>(pub &'de Value);
22
23struct ArrayValueSeq<'de> {
24 values: std::slice::Iter<'de, Value>,
25}
26
27impl<'de> ArrayValueSeq<'de> {
28 pub fn new(values: std::slice::Iter<'de, Value>) -> Self {
29 ArrayValueSeq { values }
30 }
31}
32
33impl<'de> SeqAccess<'de> for ArrayValueSeq<'de> {
34 type Error = DeserializationError;
35
36 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
37 where
38 T: serde::de::DeserializeSeed<'de>,
39 {
40 if let Some(v) = self.values.next() {
41 seed.deserialize(&mut ValueDeserializer(v)).map(Some)
42 } else {
43 Ok(None)
44 }
45 }
46}
47
48struct BytesSeq<'de> {
49 bytes: core::slice::Iter<'de, u8>,
50}
51
52impl<'de> BytesSeq<'de> {
53 pub fn new(bytes: core::slice::Iter<'de, u8>) -> Self {
54 BytesSeq { bytes }
55 }
56}
57
58impl<'de> SeqAccess<'de> for BytesSeq<'de> {
59 type Error = DeserializationError;
60
61 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
62 where
63 T: serde::de::DeserializeSeed<'de>,
64 {
65 if let Some(v) = self.bytes.next() {
66 seed.deserialize(PlainByteDeserializer(*v)).map(Some)
67 } else {
68 Ok(None)
69 }
70 }
71}
72
73struct MapValueSeq<'de> {
74 values: std::collections::hash_map::Iter<'de, String, Value>,
75 next_value: Option<&'de Value>,
76}
77
78impl<'de> MapValueSeq<'de> {
79 pub fn new(values: std::collections::hash_map::Iter<'de, String, Value>) -> Self {
80 MapValueSeq {
81 values,
82 next_value: None,
83 }
84 }
85}
86
87impl<'de> MapAccess<'de> for MapValueSeq<'de> {
88 type Error = DeserializationError;
89
90 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
91 where
92 K: serde::de::DeserializeSeed<'de>,
93 {
94 if let Some((k, v)) = self.values.next() {
95 self.next_value = Some(v);
96
97 Ok(Some(seed.deserialize(PlainStringDeserializer(k))?))
98 } else {
99 Ok(None)
100 }
101 }
102
103 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
104 where
105 V: serde::de::DeserializeSeed<'de>,
106 {
107 let value = self
108 .next_value
109 .take()
110 .expect("Shouldn't visit value before key.");
111 seed.deserialize(&mut ValueDeserializer(value))
112 }
113}
114
115impl<'de, 'a> Deserializer<'de> for &'a mut ValueDeserializer<'de> {
116 type Error = DeserializationError;
117
118 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
119 where
120 V: serde::de::Visitor<'de>,
121 {
122 Err(DeserializationError::Unrepresentable("any"))
123 }
124
125 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
126 where
127 V: serde::de::Visitor<'de>,
128 {
129 if let Value {
130 value_type: Some(ValueType::BooleanValue(v)),
131 } = self.0
132 {
133 visitor.visit_bool(*v)
134 } else {
135 Err(DeserializationError::WrongType("bool", self.0.clone()))
136 }
137 }
138
139 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
140 where
141 V: serde::de::Visitor<'de>,
142 {
143 if let Value {
144 value_type: Some(ValueType::IntegerValue(v)),
145 } = self.0
146 {
147 visitor
148 .visit_i8(i8::try_from(*v).map_err(|_| DeserializationError::IntRange("i8", *v))?)
149 } else {
150 Err(DeserializationError::WrongType("i8", self.0.clone()))
151 }
152 }
153
154 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
155 where
156 V: serde::de::Visitor<'de>,
157 {
158 if let Value {
159 value_type: Some(ValueType::IntegerValue(v)),
160 } = self.0
161 {
162 visitor.visit_i16(
163 i16::try_from(*v).map_err(|_| DeserializationError::IntRange("i16", *v))?,
164 )
165 } else {
166 Err(DeserializationError::WrongType("i16", self.0.clone()))
167 }
168 }
169
170 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
171 where
172 V: serde::de::Visitor<'de>,
173 {
174 if let Value {
175 value_type: Some(ValueType::IntegerValue(v)),
176 } = self.0
177 {
178 visitor.visit_i32(
179 i32::try_from(*v).map_err(|_| DeserializationError::IntRange("i32", *v))?,
180 )
181 } else {
182 Err(DeserializationError::WrongType("i32", self.0.clone()))
183 }
184 }
185
186 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
187 where
188 V: serde::de::Visitor<'de>,
189 {
190 if let Value {
191 value_type: Some(ValueType::IntegerValue(v)),
192 } = self.0
193 {
194 visitor.visit_i64(*v)
195 } else {
196 Err(DeserializationError::WrongType("i64", self.0.clone()))
197 }
198 }
199
200 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
201 where
202 V: serde::de::Visitor<'de>,
203 {
204 if let Value {
205 value_type: Some(ValueType::IntegerValue(v)),
206 } = self.0
207 {
208 visitor
209 .visit_u8(u8::try_from(*v).map_err(|_| DeserializationError::IntRange("u8", *v))?)
210 } else {
211 Err(DeserializationError::WrongType("i8", self.0.clone()))
212 }
213 }
214
215 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
216 where
217 V: serde::de::Visitor<'de>,
218 {
219 if let Value {
220 value_type: Some(ValueType::IntegerValue(v)),
221 } = self.0
222 {
223 visitor.visit_u16(
224 u16::try_from(*v).map_err(|_| DeserializationError::IntRange("u16", *v))?,
225 )
226 } else {
227 Err(DeserializationError::WrongType("u16", self.0.clone()))
228 }
229 }
230
231 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
232 where
233 V: serde::de::Visitor<'de>,
234 {
235 if let Value {
236 value_type: Some(ValueType::IntegerValue(v)),
237 } = self.0
238 {
239 visitor.visit_u32(
240 u32::try_from(*v).map_err(|_| DeserializationError::IntRange("u32", *v))?,
241 )
242 } else {
243 Err(DeserializationError::WrongType("u32", self.0.clone()))
244 }
245 }
246
247 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
248 where
249 V: serde::de::Visitor<'de>,
250 {
251 if let Value {
252 value_type: Some(ValueType::IntegerValue(v)),
253 } = self.0
254 {
255 visitor.visit_u64(
256 u64::try_from(*v).map_err(|_| DeserializationError::IntRange("u64", *v))?,
257 )
258 } else {
259 Err(DeserializationError::WrongType("u64", self.0.clone()))
260 }
261 }
262
263 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
264 where
265 V: serde::de::Visitor<'de>,
266 {
267 if let Value {
268 value_type: Some(ValueType::DoubleValue(v)),
269 } = self.0
270 {
271 #[allow(clippy::cast_possible_truncation)]
272 visitor.visit_f32(*v as f32)
273 } else {
274 Err(DeserializationError::WrongType("f32", self.0.clone()))
275 }
276 }
277
278 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
279 where
280 V: serde::de::Visitor<'de>,
281 {
282 if let Value {
283 value_type: Some(ValueType::DoubleValue(v)),
284 } = self.0
285 {
286 visitor.visit_f64(*v)
287 } else {
288 Err(DeserializationError::WrongType("f64", self.0.clone()))
289 }
290 }
291
292 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
293 where
294 V: serde::de::Visitor<'de>,
295 {
296 if let Value {
297 value_type: Some(ValueType::StringValue(v)),
298 } = self.0
299 {
300 if v.len() == 1 {
301 visitor.visit_char(
302 v.chars()
303 .next()
304 .expect("Already checked that string has exactly one char."),
305 )
306 } else {
307 Err(DeserializationError::WrongType("char", self.0.clone()))
308 }
309 } else {
310 Err(DeserializationError::WrongType("char", self.0.clone()))
311 }
312 }
313
314 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
315 where
316 V: serde::de::Visitor<'de>,
317 {
318 if let Value {
319 value_type: Some(ValueType::StringValue(v)),
320 } = self.0
321 {
322 visitor.visit_str(v)
323 } else {
324 Err(DeserializationError::WrongType("str", self.0.clone()))
325 }
326 }
327
328 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
329 where
330 V: serde::de::Visitor<'de>,
331 {
332 if let Value {
333 value_type: Some(ValueType::StringValue(v)),
334 } = self.0
335 {
336 visitor.visit_string(v.clone())
337 } else {
338 Err(DeserializationError::WrongType("string", self.0.clone()))
339 }
340 }
341
342 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
343 where
344 V: serde::de::Visitor<'de>,
345 {
346 if let Value {
347 value_type: Some(ValueType::BytesValue(bytes)),
348 } = self.0
349 {
350 visitor.visit_bytes(bytes)
351 } else {
352 Err(DeserializationError::WrongType("bytes", self.0.clone()))
353 }
354 }
355
356 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
357 where
358 V: serde::de::Visitor<'de>,
359 {
360 if let Value {
361 value_type: Some(ValueType::BytesValue(bytes)),
362 } = self.0
363 {
364 visitor.visit_byte_buf(bytes.clone())
365 } else if let Value {
366 value_type: Some(ValueType::TimestampValue(timestamp)),
367 } = self.0
368 {
369 let bytes = timestamp.encode_to_vec();
370 visitor.visit_byte_buf(bytes)
371 } else {
372 Err(DeserializationError::WrongType("byte_buf", self.0.clone()))
373 }
374 }
375
376 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
377 where
378 V: serde::de::Visitor<'de>,
379 {
380 if let Value {
381 value_type: Some(ValueType::NullValue(_)),
382 } = self.0
383 {
384 visitor.visit_none()
385 } else {
386 visitor.visit_some(self)
387 }
388 }
389
390 fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value>
391 where
392 V: serde::de::Visitor<'de>,
393 {
394 Err(DeserializationError::Unrepresentable("unit"))
395 }
396
397 fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value>
398 where
399 V: serde::de::Visitor<'de>,
400 {
401 Err(DeserializationError::Unrepresentable("unit_struct"))
402 }
403
404 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
405 where
406 V: serde::de::Visitor<'de>,
407 {
408 visitor.visit_newtype_struct(self)
409 }
410
411 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
412 where
413 V: serde::de::Visitor<'de>,
414 {
415 if let Value {
416 value_type: Some(ValueType::ArrayValue(ArrayValue { values })),
417 } = self.0
418 {
419 visitor.visit_seq(ArrayValueSeq::new(values.iter()))
420 } else if let Value {
421 value_type: Some(ValueType::BytesValue(bytes)),
422 } = self.0
423 {
424 visitor.visit_seq(BytesSeq::new(bytes.iter()))
425 } else {
426 Err(DeserializationError::WrongType("seq", self.0.clone()))
427 }
428 }
429
430 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
431 where
432 V: serde::de::Visitor<'de>,
433 {
434 if let Value {
435 value_type: Some(ValueType::ArrayValue(ArrayValue { values })),
436 } = self.0
437 {
438 visitor.visit_seq(ArrayValueSeq::new(values.iter()))
439 } else {
440 Err(DeserializationError::WrongType("tuple", self.0.clone()))
441 }
442 }
443
444 fn deserialize_tuple_struct<V>(
445 self,
446 _name: &'static str,
447 _len: usize,
448 visitor: V,
449 ) -> Result<V::Value>
450 where
451 V: serde::de::Visitor<'de>,
452 {
453 self.deserialize_seq(visitor)
454 }
455
456 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
457 where
458 V: serde::de::Visitor<'de>,
459 {
460 if let Value {
461 value_type: Some(ValueType::MapValue(MapValue { fields })),
462 } = self.0
463 {
464 visitor.visit_map(MapValueSeq::new(fields.iter()))
465 } else {
466 Err(DeserializationError::WrongType("map", self.0.clone()))
467 }
468 }
469
470 fn deserialize_struct<V>(
471 self,
472 _name: &'static str,
473 _fields: &'static [&'static str],
474 visitor: V,
475 ) -> Result<V::Value>
476 where
477 V: serde::de::Visitor<'de>,
478 {
479 self.deserialize_map(visitor)
480 }
481
482 fn deserialize_enum<V>(
483 self,
484 _name: &'static str,
485 _variants: &'static [&'static str],
486 visitor: V,
487 ) -> Result<V::Value>
488 where
489 V: serde::de::Visitor<'de>,
490 {
491 match &self.0.value_type {
492 Some(ValueType::StringValue(v)) => visitor.visit_enum(v.clone().into_deserializer()),
493 Some(ValueType::MapValue(MapValue { fields })) => {
494 let mut typ: Option<&String> = None;
495 let mut value: Option<&Value> = None;
496
497 for (k, v) in fields {
498 if k == TYPE {
499 if let Value {
500 value_type: Some(ValueType::StringValue(v)),
501 } = v
502 {
503 typ = Some(v);
504 } else {
505 return Err(DeserializationError::WrongType("string", v.clone()));
506 }
507 } else if k == VALUE || k == VALUES {
508 value = Some(v);
509 }
510 }
511
512 let typ = if let Some(typ) = typ {
513 typ
514 } else {
515 return Err(DeserializationError::MissingField(TYPE));
516 };
517
518 if let Some(value) = value {
519 visitor.visit_enum(Enum::new(typ, value))
520 } else {
521 Err(DeserializationError::MissingField(VALUE))
522 }
523 }
524 _ => Err(DeserializationError::WrongType("enum", self.0.clone())),
525 }
526 }
527
528 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
529 where
530 V: serde::de::Visitor<'de>,
531 {
532 Err(DeserializationError::Unrepresentable("identifier"))
533 }
534
535 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
536 where
537 V: serde::de::Visitor<'de>,
538 {
539 visitor.visit_unit()
540 }
541}
542
543struct Enum<'de> {
544 typ: &'de str,
545 value: &'de Value,
546}
547
548impl<'de> Enum<'de> {
549 pub fn new(typ: &'de str, value: &'de Value) -> Self {
550 Enum { typ, value }
551 }
552}
553
554impl<'de> EnumAccess<'de> for Enum<'de> {
555 type Error = DeserializationError;
556
557 type Variant = Enum<'de>;
558
559 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
560 where
561 V: serde::de::DeserializeSeed<'de>,
562 {
563 let val = seed.deserialize(PlainStringDeserializer(self.typ))?;
564
565 Ok((val, self))
566 }
567}
568
569impl<'de> VariantAccess<'de> for Enum<'de> {
570 type Error = DeserializationError;
571
572 fn unit_variant(self) -> Result<()> {
573 panic!("Unit variant was already handled.")
574 }
575
576 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
577 where
578 T: serde::de::DeserializeSeed<'de>,
579 {
580 seed.deserialize(&mut ValueDeserializer(self.value))
581 }
582
583 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
584 where
585 V: serde::de::Visitor<'de>,
586 {
587 ValueDeserializer(self.value).deserialize_seq(visitor)
588 }
589
590 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
591 where
592 V: serde::de::Visitor<'de>,
593 {
594 ValueDeserializer(self.value).deserialize_map(visitor)
595 }
596}