1use serde::de::{Error as DeError, MapAccess, SeqAccess, Visitor};
2use serde::ser::{SerializeMap, SerializeSeq};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5use std::fmt::{Debug, Display};
6use std::marker::PhantomData;
7use std::str::FromStr;
8use std::{
9 clone::Clone,
10 collections::btree_map::{Values, ValuesMut},
11};
12
/// Implemented by row types that carry their own unique key.
///
/// A [`Table`] uses the value returned by [`PrimaryKey::primary_key`] as the
/// map key under which a row is stored.
pub trait PrimaryKey {
    /// Type of the key that identifies a row.
    type PrimaryKeyType;

    /// Borrows the key stored inside this row.
    fn primary_key(&self) -> &Self::PrimaryKeyType;
}
18
19#[derive(Default, Debug, Clone)]
44pub struct Table<V>
45where
46 V: PrimaryKey + Serialize,
47 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
48 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
49{
50 inner: BTreeMap<<V as PrimaryKey>::PrimaryKeyType, V>,
51}
52
53impl<V> Serialize for Table<V>
54where
55 V: PrimaryKey + Serialize + for<'a> Deserialize<'a>,
56 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
57 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
58{
59 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 S: serde::Serializer,
62 {
63 if serializer.is_human_readable() {
64 let mut map = serializer.serialize_map(Some(self.inner.len()))?;
66 for (k, v) in &self.inner {
67 map.serialize_entry(&k.to_string(), v)?;
68 }
69 map.end()
70 } else {
71 let mut seq = serializer.serialize_seq(Some(self.inner.len()))?;
73 for v in self.inner.values() {
74 seq.serialize_element(v)?;
75 }
76 seq.end()
77 }
78 }
79}
80
81impl<'de, V> Deserialize<'de> for Table<V>
82where
83 V: PrimaryKey + Serialize + Deserialize<'de>,
84 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
85 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
86{
87 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
88 where
89 D: serde::Deserializer<'de>,
90 {
91 if deserializer.is_human_readable() {
92 struct MapVisitor<V>(PhantomData<V>);
94
95 impl<'de, V> Visitor<'de> for MapVisitor<V>
96 where
97 V: PrimaryKey + Serialize + Deserialize<'de>,
98 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
99 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
100 {
101 type Value = Table<V>;
102
103 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
104 f.write_str("a map of stringified primary keys to rows")
105 }
106
107 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
108 where
109 A: MapAccess<'de>,
110 {
111 let mut inner = BTreeMap::new();
112 while let Some((k_str, v)) = map.next_entry::<String, V>()? {
113 let k = V::PrimaryKeyType::from_str(&k_str).map_err(|e| {
114 A::Error::custom(format!(
115 "failed to parse primary key '{}': {}",
116 k_str, e
117 ))
118 })?;
119 inner.insert(k, v);
121 }
122 Ok(Table { inner })
123 }
124 }
125
126 deserializer.deserialize_map(MapVisitor::<V>(PhantomData))
127 } else {
128 struct SeqVisitor<V>(PhantomData<V>);
130
131 impl<'de, V> Visitor<'de> for SeqVisitor<V>
132 where
133 V: PrimaryKey + Serialize + Deserialize<'de>,
134 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
135 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
136 {
137 type Value = Table<V>;
138
139 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
140 f.write_str("a sequence of table rows")
141 }
142
143 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
144 where
145 A: SeqAccess<'de>,
146 {
147 let mut inner = BTreeMap::new();
148 while let Some(v) = seq.next_element::<V>()? {
149 let k = v.primary_key().clone();
150 inner.insert(k, v);
151 }
152 Ok(Table { inner })
153 }
154 }
155
156 deserializer.deserialize_seq(SeqVisitor::<V>(PhantomData))
157 }
158 }
159}
160
161impl<V> Table<V>
162where
163 V: PrimaryKey + Serialize + for<'a> Deserialize<'a>,
164 V::PrimaryKeyType: Ord + FromStr + Display + Debug + Clone,
165 <<V as PrimaryKey>::PrimaryKeyType as FromStr>::Err: std::fmt::Display,
166{
167 pub fn add(&mut self, value: V) -> Option<V>
169 where
170 V: Clone,
171 V::PrimaryKeyType: Clone,
172 {
173 let key = value.primary_key();
174 if !self.inner.contains_key(key) {
175 self.inner.insert(key.clone(), value.clone());
176 return Some(value);
177 }
178 None
179 }
180
181 pub fn get(&self, key: &V::PrimaryKeyType) -> Option<&V> {
183 self.inner.get(key)
184 }
185
186 pub fn get_mut(&mut self, key: &V::PrimaryKeyType) -> Option<&mut V> {
188 self.inner.get_mut(key)
189 }
190
191 pub fn edit(&mut self, key: &V::PrimaryKeyType, new_value: V) -> Option<V>
193 where
194 V: Clone,
195 V::PrimaryKeyType: Clone,
196 {
197 let new_key = new_value.primary_key();
198 if (key == new_key || !self.inner.contains_key(new_key)) && self.inner.remove(key).is_some()
199 {
200 self.inner.insert(new_key.clone(), new_value.clone());
201 return Some(new_value);
202 }
203 None
204 }
205
206 pub fn delete(&mut self, key: &V::PrimaryKeyType) -> Option<V> {
208 self.inner.remove(key)
209 }
210
211 pub fn search<F>(&self, predicate: F) -> Vec<&V>
213 where
214 F: Fn(&V) -> bool,
215 {
216 self.inner.values().filter(|&val| predicate(val)).collect()
217 }
218
219 pub fn search_ordered<F, O>(&self, predicate: F, comparator: O) -> Vec<&V>
221 where
222 F: Fn(&V) -> bool,
223 O: Fn(&&V, &&V) -> std::cmp::Ordering,
224 {
225 let mut result = self.search(predicate);
226 result.sort_by(comparator);
227 result
228 }
229
230 pub fn values(&self) -> Values<'_, V::PrimaryKeyType, V> {
232 self.inner.values()
233 }
234
235 pub fn values_mut(&mut self) -> ValuesMut<'_, V::PrimaryKeyType, V> {
237 self.inner.values_mut()
238 }
239}
240
#[cfg(test)]
mod test {
    use super::{PrimaryKey, Table};
    use serde::{Deserialize, Serialize};

    /// Minimal row type keyed by its numeric `id`.
    #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct User {
        id: usize,
        name: String,
        age: usize,
    }

    impl PrimaryKey for User {
        type PrimaryKeyType = usize;

        fn primary_key(&self) -> &Self::PrimaryKeyType {
            &self.id
        }
    }

    /// JSON is human-readable, so the table must round-trip as a map whose
    /// keys are the stringified ids.
    #[test]
    fn json_roundtrip_as_map() {
        let mut users = Table::default();
        users.add(User {
            id: 0,
            name: String::new(),
            age: 0,
        });

        let json = serde_json::to_string(&users).unwrap();
        assert_eq!(json, r#"{"0":{"id":0,"name":"","age":0}}"#);

        let restored: Table<User> = serde_json::from_str(&json).unwrap();
        assert!(restored.get(&0).is_some());
    }

    /// Round-trips three rows through bincode and checks that the row count
    /// and each row's name survive.
    #[test]
    #[cfg(feature = "encrypted")]
    fn bincode_roundtrip_as_seq() {
        use crate::encrypted::bincode_cfg;

        let mut original = Table::default();
        for id in 0..3 {
            original.add(User {
                id,
                name: format!("u{id}"),
                age: id,
            });
        }

        let encoded = bincode::serde::encode_to_vec(&original, bincode_cfg()).unwrap();
        let (decoded, _): (Table<User>, usize) =
            bincode::serde::decode_from_slice(&encoded, bincode_cfg()).unwrap();

        assert_eq!(original.values().count(), decoded.values().count());
        for id in 0..3 {
            let expected = &original.get(&id).unwrap().name;
            let actual = &decoded.get(&id).unwrap().name;
            assert_eq!(expected, actual);
        }
    }
}