1use std::cmp::Ord;
2use std::collections::{BTreeMap, HashMap};
3use std::fmt;
4use std::hash::Hash;
5use std::marker::PhantomData;
6
7use serde::de::{DeserializeSeed, Deserializer, MapAccess, Visitor};
8
9use super::util::SerializeVecMap;
10use super::{Context, Presto, PrestoMapKey, PrestoTy};
11
/// Implements [`Presto`] for a two-parameter map collection and emits the
/// serde deserialization seed that goes with it.
///
/// Usage: `gen_map!(MapType<Bound, ...>, SeedName);`
///
/// * `MapType` — a map type taking `<K, V>` that implements `Default` and has
///   an `insert(K, V)` method (e.g. `HashMap`, `BTreeMap`).
/// * `Bound, ...` — extra trait bounds required of the key type `K`, on top of
///   `PrestoMapKey` (e.g. `Eq, Hash` for a hash map, `Ord` for a B-tree map).
/// * `SeedName` — name of the generated `pub` seed struct. It implements both
///   serde `Visitor` and `DeserializeSeed`, producing a `MapType<K, V>`.
macro_rules! gen_map {
    ($ty:ident < $($bound:ident),* >, $seed:ident) => {
        impl<K, V> Presto for $ty<K, V>
        where
            K: PrestoMapKey $(+ $bound)*,
            V: Presto,
        {
            type ValueType<'a> = SerializeVecMap<K::ValueType<'a>, V::ValueType<'a>>
            where
                K: 'a,
                V: 'a;
            type Seed<'a, 'de> = $seed<'a, K, V>;

            /// Eagerly converts every `(key, value)` pair into its serializable
            /// form and wraps the resulting sequence in a `SerializeVecMap`.
            fn value(&self) -> Self::ValueType<'_> {
                let entries = self
                    .iter()
                    .map(|(key, val)| (key.value(), val.value()))
                    .collect();
                SerializeVecMap { iter: entries }
            }

            /// The Presto type of this map: `Map(key type, value type)`.
            fn ty() -> PrestoTy {
                let key = Box::new(K::ty());
                let value = Box::new(V::ty());
                PrestoTy::Map(key, value)
            }

            /// Builds a seed from `ctx`, splitting its `PrestoTy::Map` into the
            /// key and value types used for the nested seeds.
            ///
            /// # Panics
            ///
            /// Panics (`unreachable!`) if `ctx.ty()` is not `PrestoTy::Map`.
            fn seed<'a, 'de>(ctx: &'a Context<'a>) -> Self::Seed<'a, 'de> {
                match ctx.ty() {
                    PrestoTy::Map(key_box, value_box) => $seed {
                        ctx,
                        key_ty: &**key_box,
                        value_ty: &**value_box,
                        _marker: PhantomData,
                    },
                    // A non-map context is treated as an invariant violation.
                    _ => unreachable!(),
                }
            }

            /// An empty map, as produced by the collection's `Default` impl.
            fn empty() -> Self {
                <Self as Default>::default()
            }
        }

        /// Deserialization seed for the generated map impl.
        ///
        /// Holds the decoding context together with the key and value Presto
        /// types taken out of the enclosing `PrestoTy::Map`.
        pub struct $seed<'a, K, V> {
            ctx: &'a Context<'a>,
            key_ty: &'a PrestoTy,
            value_ty: &'a PrestoTy,
            _marker: PhantomData<(K, V)>,
        }

        impl<'a, 'de, K, V> Visitor<'de> for $seed<'a, K, V>
        where
            K: PrestoMapKey $(+ $bound)*,
            V: Presto,
        {
            type Value = $ty<K, V>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("map")
            }

            /// Reads entries one at a time, decoding each key with `K`'s seed
            /// and each value with `V`'s seed, and inserts them into a fresh
            /// default map. A repeated key keeps the last value seen.
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // One derived context per side, so nested seeds see the key
                // type and the value type respectively.
                let key_ctx = self.ctx.with_ty(self.key_ty);
                let value_ctx = self.ctx.with_ty(self.value_ty);
                let mut out: Self::Value = Default::default();
                loop {
                    match map.next_entry_seed(K::seed(&key_ctx), V::seed(&value_ctx))? {
                        Some((key, val)) => {
                            out.insert(key, val);
                        }
                        None => break,
                    }
                }
                Ok(out)
            }
        }

        impl<'a, 'de, K, V> DeserializeSeed<'de> for $seed<'a, K, V>
        where
            K: PrestoMapKey $(+ $bound)*,
            V: Presto,
        {
            type Value = $ty<K, V>;

            /// Asks the deserializer for a map; the seed acts as its own visitor.
            fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
            where
                D: Deserializer<'de>,
            {
                deserializer.deserialize_map(self)
            }
        }
    };
}
97
98gen_map!(HashMap<Eq, Hash>, HashMapSeed);
99gen_map!(BTreeMap<Ord>, BTreeMapSeed);