1use std::ops::{Index, IndexMut};
2
3use crate::{EnumTable, Enumable};
4
5impl<K: Enumable + core::fmt::Debug, V: core::fmt::Debug, const N: usize> core::fmt::Debug
6 for EnumTable<K, V, N>
7{
8 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
9 f.debug_map().entries(self.iter()).finish()
10 }
11}
12
13impl<K: Enumable, V: Clone, const N: usize> Clone for EnumTable<K, V, N> {
14 fn clone(&self) -> Self {
15 Self {
16 table: self.table.clone(),
17 _phantom: core::marker::PhantomData,
18 }
19 }
20}
21
22impl<K: Enumable, V: PartialEq, const N: usize> PartialEq for EnumTable<K, V, N> {
23 fn eq(&self, other: &Self) -> bool {
24 self.table.eq(&other.table)
25 }
26}
27
28impl<K: Enumable, V: Eq, const N: usize> Eq for EnumTable<K, V, N> {}
29
30impl<K: Enumable, V: std::hash::Hash, const N: usize> std::hash::Hash for EnumTable<K, V, N> {
31 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32 self.table.hash(state);
33 }
34}
35
36impl<K: Enumable, V: Default, const N: usize> Default for EnumTable<K, V, N> {
37 fn default() -> Self {
38 EnumTable::new_with_fn(|_| Default::default())
39 }
40}
41
42impl<K: Enumable, V, const N: usize> Index<K> for EnumTable<K, V, N> {
43 type Output = V;
44
45 fn index(&self, index: K) -> &Self::Output {
46 self.get(&index)
47 }
48}
49
50impl<K: Enumable, V, const N: usize> IndexMut<K> for EnumTable<K, V, N> {
51 fn index_mut(&mut self, index: K) -> &mut Self::Output {
52 self.get_mut(&index)
53 }
54}
55
/// Serializes the table as a map of `N` entries, one `key -> value` pair per
/// item yielded by `EnumTable::iter`.
#[cfg(feature = "serde")]
impl<K, V, const N: usize> serde::Serialize for EnumTable<K, V, N>
where
    K: Enumable + serde::Serialize,
    V: serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;

        let mut out = serializer.serialize_map(Some(N))?;
        // Stop at the first entry that fails to serialize and propagate its error.
        self.iter()
            .try_for_each(|(key, value)| out.serialize_entry(key, value))?;
        out.end()
    }
}
74
/// Deserializes an [`EnumTable`] from a map whose keys are enum variants.
///
/// Every `(key, value)` entry of the input map is collected first. The result is
/// then handed to `EnumTable::try_from_vec`, which decides whether the collected
/// keys form a valid table.
///
/// # Errors
///
/// Returns the deserializer's error when an entry fails to deserialize, or when
/// `try_from_vec` rejects the entries:
/// - `InvalidSize` is reported through `invalid_length` with the number of entries found.
/// - `MissingVariant` is reported through a custom error naming the absent variant.
#[cfg(feature = "serde")]
impl<'de, K, V, const N: usize> serde::Deserialize<'de> for EnumTable<K, V, N>
where
    K: Enumable + serde::Deserialize<'de> + core::fmt::Debug,
    V: serde::Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{MapAccess, Visitor};
        use std::fmt;
        use std::marker::PhantomData;

        /// Visitor that gathers the incoming map's entries into an `EnumTable`.
        struct EnumTableVisitor<K, V, const N: usize> {
            _phantom: PhantomData<(K, V)>,
        }

        impl<'de, K, V, const N: usize> Visitor<'de> for EnumTableVisitor<K, V, N>
        where
            K: Enumable + serde::Deserialize<'de> + core::fmt::Debug,
            V: serde::Deserialize<'de>,
        {
            type Value = EnumTable<K, V, N>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a map with all enum variants as keys")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                use crate::EnumTableFromVecError;

                let mut values: Vec<(K, V)> = Vec::with_capacity(N);
                while let Some(entry) = map.next_entry::<K, V>()? {
                    values.push(entry);
                }

                match EnumTable::try_from_vec(values) {
                    Ok(table) => Ok(table),
                    Err(EnumTableFromVecError::InvalidSize { expected, found }) => {
                        // serde renders this as "invalid length {found}, expected {exp}",
                        // so `exp` should only describe what was wanted. Repeating
                        // "expected"/"found" here would duplicate the wording.
                        Err(serde::de::Error::invalid_length(
                            found,
                            &format!("a map with {expected} entries").as_str(),
                        ))
                    }
                    Err(EnumTableFromVecError::MissingVariant(variant)) => {
                        // The variant is absent from the input. Reporting it as an
                        // "invalid value" would wrongly suggest it was present but
                        // malformed, so name the missing key explicitly instead.
                        Err(serde::de::Error::custom(format!(
                            "missing enum variant {variant:?}"
                        )))
                    }
                }
            }
        }

        deserializer.deserialize_map(EnumTableVisitor::<K, V, N> {
            _phantom: PhantomData,
        })
    }
}
139
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use super::*;

    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enumable)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    enum Color {
        Red,
        Green,
        Blue,
    }

    /// Maps every colour to its own name.
    const TABLES: EnumTable<Color, &'static str, { Color::COUNT }> =
        crate::et!(Color, &'static str, |color| match color {
            Color::Red => "Red",
            Color::Green => "Green",
            Color::Blue => "Blue",
        });

    /// Built independently of `TABLES`, but with identical contents.
    const ANOTHER_TABLES: EnumTable<Color, &'static str, { Color::COUNT }> =
        crate::et!(Color, &'static str, |color| match color {
            Color::Red => "Red",
            Color::Green => "Green",
            Color::Blue => "Blue",
        });

    /// Hashes `value` with a fresh `DefaultHasher` and returns the digest.
    fn digest<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn debug_impl() {
        let rendered = format!("{TABLES:?}");
        assert_eq!(rendered, r#"{Red: "Red", Green: "Green", Blue: "Blue"}"#);
    }

    #[test]
    fn clone_impl() {
        let copy = TABLES.clone();
        assert_eq!(TABLES, copy);
    }

    #[test]
    fn eq_impl() {
        let placeholder: EnumTable<Color, &str, { Color::COUNT }> =
            EnumTable::new_with_fn(|_| "Unknown");
        assert!(ANOTHER_TABLES == TABLES);
        assert!(placeholder != TABLES);
    }

    #[test]
    fn hash_impl() {
        assert_eq!(digest(&TABLES), digest(&ANOTHER_TABLES));
    }

    #[test]
    fn default_impl() {
        let default_table: EnumTable<Color, &'static str, { Color::COUNT }> = EnumTable::default();
        for color in [Color::Red, Color::Green, Color::Blue] {
            assert_eq!(default_table.get(&color), &"");
        }
    }

    #[test]
    fn index_impl() {
        for (color, name) in [
            (Color::Red, "Red"),
            (Color::Green, "Green"),
            (Color::Blue, "Blue"),
        ] {
            assert_eq!(TABLES[color], name);
        }

        let mut edited = TABLES.clone();
        edited[Color::Red] = "Changed Red";
        assert_eq!(edited[Color::Red], "Changed Red");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_serialize() {
        let json = serde_json::to_string(&TABLES).unwrap();
        for pair in [r#""Red":"Red""#, r#""Green":"Green""#, r#""Blue":"Blue""#] {
            assert!(json.contains(pair));
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_deserialize() {
        let input = r#"{"Red":"Red","Green":"Green","Blue":"Blue"}"#;
        let parsed: EnumTable<Color, &str, { Color::COUNT }> = serde_json::from_str(input).unwrap();

        for (color, name) in [
            (Color::Red, "Red"),
            (Color::Green, "Green"),
            (Color::Blue, "Blue"),
        ] {
            assert_eq!(parsed.get(&color), &name);
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let json = serde_json::to_string(&TABLES).unwrap();
        let back: EnumTable<Color, &str, { Color::COUNT }> = serde_json::from_str(&json).unwrap();

        assert_eq!(TABLES, back);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_missing_variant_error() {
        // "Blue" is deliberately absent.
        let partial = r#"{"Red":"Red","Green":"Green"}"#;
        let outcome = serde_json::from_str::<EnumTable<Color, &str, { Color::COUNT }>>(partial);

        assert!(outcome.is_err());
    }
}