1use serde::de::{Error as DeError, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeMap, SerializeSeq};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5use std::fmt::{Debug, Display};
6use std::marker::PhantomData;
7use std::str::FromStr;
8use std::{
9 clone::Clone,
10 collections::btree_map::{Values, ValuesMut},
11};
12
/// A row type that can be stored in a [`Table`].
///
/// The table indexes rows by the value this trait exposes, so at most one row
/// per distinct key can be held at a time.
pub trait PrimaryKey {
    /// Type of the value that uniquely identifies a row.
    type PrimaryKeyType;

    /// Borrows this row's primary key.
    fn primary_key(&self) -> &Self::PrimaryKeyType;
}
18
19#[derive(Default, Debug, Clone)]
23pub struct Table<V>
24where
25 V: PrimaryKey + Serialize,
26 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
27 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
28{
29 inner: BTreeMap<<V as PrimaryKey>::PrimaryKeyType, V>,
30}
31
32impl<V> Serialize for Table<V>
33where
34 V: PrimaryKey + Serialize + for<'a> Deserialize<'a>,
35 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
36 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
37{
38 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39 where
40 S: serde::Serializer,
41 {
42 if serializer.is_human_readable() {
43 let mut map = serializer.serialize_map(Some(self.inner.len()))?;
45 for (k, v) in &self.inner {
46 map.serialize_entry(&k.to_string(), v)?;
47 }
48 map.end()
49 } else {
50 let mut seq = serializer.serialize_seq(Some(self.inner.len()))?;
52 for v in self.inner.values() {
53 seq.serialize_element(v)?;
54 }
55 seq.end()
56 }
57 }
58}
59
60impl<'de, V> Deserialize<'de> for Table<V>
61where
62 V: PrimaryKey + Serialize + Deserialize<'de>,
63 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
64 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
65{
66 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67 where
68 D: serde::Deserializer<'de>,
69 {
70 if deserializer.is_human_readable() {
71 struct MapVisitor<V>(PhantomData<V>);
73
74 impl<'de, V> Visitor<'de> for MapVisitor<V>
75 where
76 V: PrimaryKey + Serialize + Deserialize<'de>,
77 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
78 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
79 {
80 type Value = Table<V>;
81
82 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
83 f.write_str("a map of stringified primary keys to rows")
84 }
85
86 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
87 where
88 A: MapAccess<'de>,
89 {
90 let mut inner = BTreeMap::new();
91 while let Some((k_str, v)) = map.next_entry::<String, V>()? {
92 let k = V::PrimaryKeyType::from_str(&k_str).map_err(|e| {
93 A::Error::custom(format!(
94 "failed to parse primary key '{}': {}",
95 k_str, e
96 ))
97 })?;
98 inner.insert(k, v);
100 }
101 Ok(Table { inner })
102 }
103 }
104
105 deserializer.deserialize_map(MapVisitor::<V>(PhantomData))
106 } else {
107 struct SeqVisitor<V>(PhantomData<V>);
109
110 impl<'de, V> Visitor<'de> for SeqVisitor<V>
111 where
112 V: PrimaryKey + Serialize + Deserialize<'de>,
113 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
114 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
115 {
116 type Value = Table<V>;
117
118 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
119 f.write_str("a sequence of table rows")
120 }
121
122 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
123 where
124 A: SeqAccess<'de>,
125 {
126 let mut inner = BTreeMap::new();
127 while let Some(v) = seq.next_element::<V>()? {
128 let k = v.primary_key().clone();
129 inner.insert(k, v);
130 }
131 Ok(Table { inner })
132 }
133 }
134
135 deserializer.deserialize_seq(SeqVisitor::<V>(PhantomData))
136 }
137 }
138}
139
140impl<V> Table<V>
141where
142 V: PrimaryKey + Serialize + for<'a> Deserialize<'a>,
143 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
144 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
145{
146 pub fn add(&mut self, value: V) -> Option<V>
148 where
149 V: Clone,
150 V::PrimaryKeyType: Clone,
151 {
152 let key = value.primary_key();
153 if !self.inner.contains_key(key) {
154 self.inner.insert(key.clone(), value.clone());
155 return Some(value);
156 }
157 None
158 }
159
160 pub fn get(&self, key: &V::PrimaryKeyType) -> Option<&V> {
162 self.inner.get(key)
163 }
164
165 pub fn get_mut(&mut self, key: &V::PrimaryKeyType) -> Option<&mut V> {
167 self.inner.get_mut(key)
168 }
169
170 pub fn edit(&mut self, key: &V::PrimaryKeyType, new_value: V) -> Option<V>
172 where
173 V: Clone,
174 V::PrimaryKeyType: Clone,
175 {
176 let new_key = new_value.primary_key();
177 if (key == new_key || !self.inner.contains_key(new_key)) && self.inner.remove(key).is_some()
178 {
179 self.inner.insert(new_key.clone(), new_value.clone());
180 return Some(new_value);
181 }
182 None
183 }
184
185 pub fn delete(&mut self, key: &V::PrimaryKeyType) -> Option<V> {
187 self.inner.remove(key)
188 }
189
190 pub fn search<F>(&self, predicate: F) -> Vec<&V>
192 where
193 F: Fn(&V) -> bool,
194 {
195 self.inner.values().filter(|&val| predicate(val)).collect()
196 }
197
198 pub fn search_ordered<F, O>(&self, predicate: F, comparator: O) -> Vec<&V>
200 where
201 F: Fn(&V) -> bool,
202 O: Fn(&&V, &&V) -> std::cmp::Ordering,
203 {
204 let mut result = self.search(predicate);
205 result.sort_by(comparator);
206 result
207 }
208
209 pub fn values(&self) -> Values<'_, V::PrimaryKeyType, V> {
211 self.inner.values()
212 }
213
214 pub fn values_mut(&mut self) -> ValuesMut<'_, V::PrimaryKeyType, V> {
216 self.inner.values_mut()
217 }
218}
219
#[cfg(test)]
mod test {
    use super::{PrimaryKey, Table};
    use serde::{Deserialize, Serialize};

    /// Minimal row type identified by its `id` field.
    #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct User {
        id: usize,
        name: String,
        age: usize,
    }

    impl PrimaryKey for User {
        type PrimaryKeyType = usize;
        fn primary_key(&self) -> &Self::PrimaryKeyType {
            &self.id
        }
    }

    /// JSON is human-readable, so the table must come out as a map keyed by
    /// the stringified id and load back from that map.
    #[test]
    fn json_roundtrip_as_map() {
        let row = User {
            id: 0,
            name: String::new(),
            age: 0,
        };
        let mut table = Table::default();
        table.add(row);

        let json = serde_json::to_string(&table).unwrap();
        assert_eq!(json, r#"{"0":{"id":0,"name":"","age":0}}"#);

        let decoded: Table<User> = serde_json::from_str(&json).unwrap();
        assert!(decoded.get(&0).is_some());
    }

    /// bincode is a compact format, so the table travels as a sequence of rows
    /// and the keys are rebuilt from each row on load.
    #[test]
    #[cfg(feature = "encrypted")]
    fn bincode_roundtrip_as_seq() {
        let mut table = Table::default();
        let mut id = 0;
        while id < 3 {
            table.add(User {
                id,
                name: format!("u{id}"),
                age: id,
            });
            id += 1;
        }

        let encoded = bincode::serialize(&table).unwrap();
        let decoded: Table<User> = bincode::deserialize(&encoded).unwrap();

        assert_eq!(table.values().count(), decoded.values().count());
        for key in 0..3 {
            let before = table.get(&key).unwrap();
            let after = decoded.get(&key).unwrap();
            assert_eq!(before.name, after.name);
        }
    }
}