enum_table/
impls.rs

1use std::ops::{Index, IndexMut};
2
3use crate::{EnumTable, Enumable};
4
5impl<K: Enumable + core::fmt::Debug, V: core::fmt::Debug, const N: usize> core::fmt::Debug
6    for EnumTable<K, V, N>
7{
8    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
9        f.debug_map().entries(self.iter()).finish()
10    }
11}
12
13impl<K: Enumable, V: Clone, const N: usize> Clone for EnumTable<K, V, N> {
14    fn clone(&self) -> Self {
15        Self {
16            table: self.table.clone(),
17            _phantom: core::marker::PhantomData,
18        }
19    }
20}
21
22impl<K: Enumable, V: PartialEq, const N: usize> PartialEq for EnumTable<K, V, N> {
23    fn eq(&self, other: &Self) -> bool {
24        self.table.eq(&other.table)
25    }
26}
27
28impl<K: Enumable, V: Eq, const N: usize> Eq for EnumTable<K, V, N> {}
29
30impl<K: Enumable, V: std::hash::Hash, const N: usize> std::hash::Hash for EnumTable<K, V, N> {
31    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32        self.table.hash(state);
33    }
34}
35
36impl<K: Enumable, V: Default, const N: usize> Default for EnumTable<K, V, N> {
37    fn default() -> Self {
38        EnumTable::new_with_fn(|_| Default::default())
39    }
40}
41
42impl<K: Enumable, V, const N: usize> Index<K> for EnumTable<K, V, N> {
43    type Output = V;
44
45    fn index(&self, index: K) -> &Self::Output {
46        self.get(&index)
47    }
48}
49
50impl<K: Enumable, V, const N: usize> IndexMut<K> for EnumTable<K, V, N> {
51    fn index_mut(&mut self, index: K) -> &mut Self::Output {
52        self.get_mut(&index)
53    }
54}
55
#[cfg(feature = "serde")]
impl<K, V, const N: usize> serde::Serialize for EnumTable<K, V, N>
where
    K: Enumable + serde::Serialize,
    V: serde::Serialize,
{
    /// Writes the table as a map announced with length `N`, emitting one entry
    /// per `(key, value)` pair yielded by `self.iter()`, in that order.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;

        let mut state = serializer.serialize_map(Some(N))?;
        self.iter()
            .try_for_each(|(variant, item)| state.serialize_entry(variant, item))?;
        state.end()
    }
}
74
#[cfg(feature = "serde")]
impl<'de, K, V, const N: usize> serde::Deserialize<'de> for EnumTable<K, V, N>
where
    K: Enumable + serde::Deserialize<'de> + core::fmt::Debug,
    V: serde::Deserialize<'de>,
{
    /// Deserializes a table from a map whose keys are enum variants.
    ///
    /// Every map entry is collected and handed to `EnumTable::try_from_vec`,
    /// which decides whether the entries form a complete table.
    ///
    /// # Errors
    ///
    /// Fails if the input is not a map, if any key or value fails to
    /// deserialize, if `try_from_vec` reports a size mismatch (mapped to
    /// `invalid_length`), or if it reports a missing variant (mapped to a
    /// custom error naming that variant).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{MapAccess, Visitor};
        use std::fmt;
        use std::marker::PhantomData;

        /// Stateless visitor; its type parameters only select the output type.
        struct EnumTableVisitor<K, V, const N: usize> {
            _phantom: PhantomData<(K, V)>,
        }

        impl<'de, K, V, const N: usize> Visitor<'de> for EnumTableVisitor<K, V, N>
        where
            K: Enumable + serde::Deserialize<'de> + core::fmt::Debug,
            V: serde::Deserialize<'de>,
        {
            type Value = EnumTable<K, V, N>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a map with all enum variants as keys")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                use crate::EnumTableFromVecError;

                // Pre-size for the expected case of one entry per table slot (`N`).
                let mut values: Vec<(K, V)> = Vec::with_capacity(N);
                while let Some(entry) = map.next_entry::<K, V>()? {
                    values.push(entry);
                }

                match EnumTable::try_from_vec(values) {
                    Ok(table) => Ok(table),
                    Err(EnumTableFromVecError::InvalidSize { expected, found }) => {
                        // serde renders this as "invalid length {found}, expected {exp}",
                        // so `exp` must describe only the expectation. Embedding
                        // "expected ..." / "found ..." here duplicated that text.
                        let exp = format!("a map with {expected} entries");
                        Err(serde::de::Error::invalid_length(found, &exp.as_str()))
                    }
                    Err(EnumTableFromVecError::MissingVariant(variant)) => {
                        // The variant is absent from the input, so `invalid_value`
                        // (which describes an unexpected value that *is* present)
                        // was misleading; name the missing variant directly.
                        Err(serde::de::Error::custom(format!(
                            "missing enum variant {variant:?}"
                        )))
                    }
                }
            }
        }

        deserializer.deserialize_map(EnumTableVisitor::<K, V, N> {
            _phantom: PhantomData,
        })
    }
}
139
#[cfg(test)]
mod tests {
    use std::hash::{Hash, Hasher};

    use super::*;

    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enumable)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    enum Color {
        Red,
        Green,
        Blue,
    }

    const TABLES: EnumTable<Color, &'static str, { Color::COUNT }> =
        crate::et!(Color, &'static str, |color| match color {
            Color::Red => "Red",
            Color::Green => "Green",
            Color::Blue => "Blue",
        });

    const ANOTHER_TABLES: EnumTable<Color, &'static str, { Color::COUNT }> =
        crate::et!(Color, &'static str, |color| match color {
            Color::Red => "Red",
            Color::Green => "Green",
            Color::Blue => "Blue",
        });

    /// Hashes `value` with a freshly constructed `DefaultHasher`, so results
    /// from separate calls are directly comparable.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn debug_impl() {
        assert_eq!(
            format!("{TABLES:?}"),
            r#"{Red: "Red", Green: "Green", Blue: "Blue"}"#
        );
    }

    #[test]
    fn clone_impl() {
        let cloned = TABLES.clone();
        assert_eq!(cloned, TABLES);
    }

    #[test]
    fn eq_impl() {
        assert!(TABLES == ANOTHER_TABLES);
        assert!(TABLES != EnumTable::new_with_fn(|_| "Unknown"));
    }

    #[test]
    fn hash_impl() {
        // Equal tables must hash equally (the `Hash`/`Eq` contract).
        assert_eq!(hash_of(&TABLES), hash_of(&ANOTHER_TABLES));
    }

    #[test]
    fn default_impl() {
        let default_table: EnumTable<Color, &'static str, { Color::COUNT }> = EnumTable::default();
        assert_eq!(default_table.get(&Color::Red), &"");
        assert_eq!(default_table.get(&Color::Green), &"");
        assert_eq!(default_table.get(&Color::Blue), &"");
    }

    #[test]
    fn index_impl() {
        assert_eq!(TABLES[Color::Red], "Red");
        assert_eq!(TABLES[Color::Green], "Green");
        assert_eq!(TABLES[Color::Blue], "Blue");

        let mut mutable_table = TABLES.clone();
        mutable_table[Color::Red] = "Changed Red";
        assert_eq!(mutable_table[Color::Red], "Changed Red");
        // A write through `IndexMut` must only touch the addressed slot.
        assert_eq!(mutable_table[Color::Green], "Green");
        assert_eq!(mutable_table[Color::Blue], "Blue");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_serialize() {
        let json = serde_json::to_string(&TABLES).unwrap();
        assert!(json.contains(r#""Red":"Red""#));
        assert!(json.contains(r#""Green":"Green""#));
        assert!(json.contains(r#""Blue":"Blue""#));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_deserialize() {
        let json = r#"{"Red":"Red","Green":"Green","Blue":"Blue"}"#;
        let table: EnumTable<Color, &str, { Color::COUNT }> = serde_json::from_str(json).unwrap();

        assert_eq!(table.get(&Color::Red), &"Red");
        assert_eq!(table.get(&Color::Green), &"Green");
        assert_eq!(table.get(&Color::Blue), &"Blue");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let original = TABLES;
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: EnumTable<Color, &str, { Color::COUNT }> =
            serde_json::from_str(&json).unwrap();

        assert_eq!(original, deserialized);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_missing_variant_error() {
        // Missing Blue variant
        let json = r#"{"Red":"Red","Green":"Green"}"#;
        let result: Result<EnumTable<Color, &str, { Color::COUNT }>, _> =
            serde_json::from_str(json);

        assert!(result.is_err());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_duplicate_key_error() {
        // Four entries for a three-variant enum: the duplicated `Red` key makes
        // the entry count disagree with the table size, which must be rejected.
        let json = r#"{"Red":"a","Red":"b","Green":"Green","Blue":"Blue"}"#;
        let result: Result<EnumTable<Color, &str, { Color::COUNT }>, _> =
            serde_json::from_str(json);

        assert!(result.is_err());
    }
}
261}