// array_map/serde.rs

use crate::{ArrayMap, ArraySet, Indexable};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use serde::de::{MapAccess, SeqAccess};
use serde::{Deserializer, Serializer};

7impl<K: serde::Serialize + Indexable, V: serde::Serialize, const N: usize> serde::Serialize for ArrayMap<K, V, N> {
8  fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
9  where
10    S: Serializer,
11  {
12    debug_assert_eq!(N, K::iter().count());
13    serializer.collect_map(self.iter())
14  }
15}
16
17struct ExpectingN(usize);
18
19impl serde::de::Expected for ExpectingN {
20  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
21    write!(f, "{}", self.0)
22  }
23}
24
25impl<'de, K: serde::Deserialize<'de> + Indexable, V: serde::Deserialize<'de>, const N: usize> serde::Deserialize<'de>
26  for ArrayMap<K, V, N>
27{
28  fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
29  where
30    D: Deserializer<'de>,
31  {
32    struct ArrayMapVisitor<K: Indexable, V, const N: usize> {
33      array: MaybeUninit<[V; N]>,
34      filled: [bool; N],
35      phantom: PhantomData<fn() -> K>,
36    }
37    impl<'v, K: serde::Deserialize<'v> + Indexable, V: serde::Deserialize<'v>, const N: usize> serde::de::Visitor<'v>
38      for ArrayMapVisitor<K, V, N>
39    {
40      type Value = ArrayMap<K, V, N>;
41
42      fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
43        write!(formatter, "A map of {} values", N)
44      }
45
46      #[allow(unsafe_code)]
47      fn visit_map<A>(mut self, mut map: A) -> Result<Self::Value, <A as MapAccess<'v>>::Error>
48      where
49        A: MapAccess<'v>,
50      {
51        while let Some((k, v)) = map.next_entry::<K, V>()? {
52          let index = k.index();
53          assert!(index < N);
54          // Safety: we can only write to uninit before trying to read them which we do here
55          unsafe {
56            self.array.as_mut_ptr().cast::<V>().add(index).write(v);
57          }
58          self.filled[index] = true;
59        }
60        let count = self.filled.iter().filter(|f| **f).count();
61        if count != N {
62          use serde::de::Error;
63          return Err(<A as MapAccess<'v>>::Error::invalid_length(count, &ExpectingN(N)));
64        }
65        Ok(ArrayMap {
66          // Safety we have guaranteed that all the slots have been filled
67          array: unsafe { self.array.assume_init() },
68          phantom: PhantomData,
69        })
70      }
71    }
72    debug_assert_eq!(N, K::iter().count());
73    deserializer.deserialize_map(ArrayMapVisitor {
74      array: MaybeUninit::uninit(),
75      filled: [false; N],
76      phantom: PhantomData,
77    })
78  }
79}
80
#[test]
fn test_array_map_serde() {
  use crate::test::Lowercase;
  type Map = ArrayMap<Lowercase, Option<(u8, u8)>, { Lowercase::SIZE }>;
  let mut h = Map::default();
  h[Lowercase('b')] = Some((50, 80));
  h[Lowercase('c')] = Some((10, 20));
  let s = serde_json::to_string(&h).unwrap();
  let h_new = serde_json::from_str::<Map>(&s).unwrap();
  assert_eq!(h, h_new);
  // Every key is required, so a map with no entries must be rejected rather than
  // producing a partially-initialized value.
  assert!(serde_json::from_str::<Map>("{}").is_err());
}

93impl<K: serde::Serialize + Indexable, const N: usize> serde::Serialize for ArraySet<K, N> {
94  fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
95  where
96    S: Serializer,
97  {
98    debug_assert_eq!(K::SIZE, K::iter().count());
99    serializer.collect_seq(self.keys())
100  }
101}
102
103impl<'de, K: serde::Deserialize<'de> + Indexable, const N: usize> serde::Deserialize<'de> for ArraySet<K, N> {
104  fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
105  where
106    D: Deserializer<'de>,
107  {
108    struct ArraySetVisitor<K: Indexable, const N: usize> {
109      set: ArraySet<K, N>,
110    }
111    impl<'v, K: serde::Deserialize<'v> + Indexable, const N: usize> serde::de::Visitor<'v> for ArraySetVisitor<K, N> {
112      type Value = ArraySet<K, N>;
113
114      fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
115        write!(formatter, "A sequence of values")
116      }
117
118      #[allow(unsafe_code)]
119      fn visit_seq<A>(mut self, mut seq: A) -> Result<Self::Value, <A as SeqAccess<'v>>::Error>
120      where
121        A: SeqAccess<'v>,
122      {
123        while let Some(k) = seq.next_element::<K>()? {
124          self.set.insert(k);
125        }
126        Ok(self.set)
127      }
128    }
129    debug_assert_eq!(K::SIZE, K::iter().count());
130    deserializer.deserialize_seq(ArraySetVisitor { set: ArraySet::default() })
131  }
132}
133
#[test]
fn test_array_set_serde() {
  use crate::test::Lowercase;
  type Set = ArraySet<Lowercase, { crate::set_size(Lowercase::SIZE) }>;
  let mut h = Set::default();
  h.insert(Lowercase('b'));
  h.insert(Lowercase('c'));
  let s = serde_json::to_string(&h).unwrap();
  let h_new = serde_json::from_str::<Set>(&s).unwrap();
  assert_eq!(h, h_new);
  // An empty sequence decodes to the default (empty) set.
  assert_eq!(serde_json::from_str::<Set>("[]").unwrap(), Set::default());
}