use serde::de::{Error as DeError, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeMap, SerializeSeq};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt::{Debug, Display, Formatter};
use std::marker::PhantomData;
use std::str::FromStr;
use std::{
clone::Clone,
collections::btree_map::{Values, ValuesMut},
};
/// Implemented by row types stored in a [`Table`]: exposes the value that
/// identifies the row.
///
/// The table indexes each row under the key returned here at insertion time,
/// so the key should not be changed while the row is stored.
pub trait PrimaryKey {
    /// The type of the identifying key.
    type PrimaryKeyType;

    /// Borrows this row's primary key.
    fn primary_key(&self) -> &<Self as PrimaryKey>::PrimaryKeyType;
}
/// Key types usable as a table's primary key.
///
/// Keys must be totally ordered (for the backing `BTreeMap`), printable via
/// `Display`, and parseable back from that textual form.
pub trait TableKey: Ord + FromStr + Display + Debug + Clone {
    /// Parses a key from its string form, reporting failures as a message.
    fn parse_key(value: &str) -> Result<Self, String>
    where
        Self: Sized;
}

/// Every ordered, printable, clonable `FromStr` type whose parse error is
/// `Display` is automatically a `TableKey`; parsing defers to `FromStr`.
impl<K> TableKey for K
where
    K: Ord + FromStr + Display + Debug + Clone,
    <K as FromStr>::Err: Display,
{
    fn parse_key(value: &str) -> Result<Self, String> {
        let parsed: Result<K, <K as FromStr>::Err> = value.parse();
        parsed.map_err(|reason| reason.to_string())
    }
}
pub struct Table<V>
where
V: PrimaryKey,
{
inner: BTreeMap<<V as PrimaryKey>::PrimaryKeyType, V>,
}
/// Formats the table as `Table { inner: {key: row, ..} }`.
impl<V> Debug for Table<V>
where
    V: PrimaryKey + Debug,
    <V as PrimaryKey>::PrimaryKeyType: Debug,
{
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("Table");
        out.field("inner", &self.inner);
        out.finish()
    }
}
impl<V> Clone for Table<V>
where
V: PrimaryKey + Clone,
V::PrimaryKeyType: Clone,
{
fn clone(&self) -> Self {
Table {
inner: self.inner.clone(),
}
}
}
/// Creates an empty table.
impl<V> Default for Table<V>
where
    V: PrimaryKey,
{
    fn default() -> Self {
        let inner = BTreeMap::new();
        Self { inner }
    }
}
impl<V> Serialize for Table<V>
where
V: PrimaryKey + Serialize + for<'a> Deserialize<'a>,
V::PrimaryKeyType: TableKey,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if serializer.is_human_readable() {
let mut map = serializer.serialize_map(Some(self.inner.len()))?;
for (k, v) in &self.inner {
map.serialize_entry(&k.to_string(), v)?;
}
map.end()
} else {
let mut seq = serializer.serialize_seq(Some(self.inner.len()))?;
for v in self.inner.values() {
seq.serialize_element(v)?;
}
seq.end()
}
}
}
/// Deserializes a [`Table`] from either shape produced by its `Serialize` impl.
///
/// * Human-readable formats: a map whose keys are primary keys in string form.
///   Each key is parsed with [`TableKey::parse_key`] and must equal the primary
///   key of the row stored under it.
/// * Other formats: a sequence of rows, each indexed by its own primary key.
///
/// # Errors
///
/// Besides format-level errors, fails in three cases:
/// * a map key cannot be parsed;
/// * a map key disagrees with its row's primary key;
/// * two rows share a primary key. `Table::add` also refuses duplicates, so
///   they are never silently overwritten here.
impl<'de, V> Deserialize<'de> for Table<V>
where
    V: PrimaryKey + Deserialize<'de>,
    V::PrimaryKeyType: TableKey,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        if deserializer.is_human_readable() {
            /// Reads the `{ "<key>": row, ... }` form.
            struct MapVisitor<V>(PhantomData<V>);

            impl<'de, V> Visitor<'de> for MapVisitor<V>
            where
                V: PrimaryKey + Deserialize<'de>,
                V::PrimaryKeyType: TableKey,
            {
                type Value = Table<V>;

                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                    f.write_str("a map of stringified primary keys to rows")
                }

                fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
                where
                    A: MapAccess<'de>,
                {
                    let mut inner: BTreeMap<V::PrimaryKeyType, V> = BTreeMap::new();
                    while let Some((k_str, v)) = map.next_entry::<String, V>()? {
                        let k = V::PrimaryKeyType::parse_key(&k_str).map_err(|e| {
                            A::Error::custom(format!(
                                "failed to parse primary key '{}': {}",
                                k_str, e
                            ))
                        })?;
                        // The map key duplicates information held by the row.
                        // Accepting a mismatch would file the row under a key
                        // other than its own primary key.
                        if v.primary_key() != &k {
                            return Err(A::Error::custom(format!(
                                "map key '{}' does not match primary key '{}' of its row",
                                k,
                                v.primary_key()
                            )));
                        }
                        if inner.contains_key(&k) {
                            return Err(A::Error::custom(format!(
                                "duplicate primary key '{}'",
                                k
                            )));
                        }
                        inner.insert(k, v);
                    }
                    Ok(Table { inner })
                }
            }

            deserializer.deserialize_map(MapVisitor::<V>(PhantomData))
        } else {
            /// Reads the `[row, row, ...]` form, keying each row by itself.
            struct SeqVisitor<V>(PhantomData<V>);

            impl<'de, V> Visitor<'de> for SeqVisitor<V>
            where
                V: PrimaryKey + Deserialize<'de>,
                V::PrimaryKeyType: TableKey,
            {
                type Value = Table<V>;

                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                    f.write_str("a sequence of table rows")
                }

                fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
                {
                    let mut inner: BTreeMap<V::PrimaryKeyType, V> = BTreeMap::new();
                    while let Some(v) = seq.next_element::<V>()? {
                        let k = v.primary_key().clone();
                        if inner.contains_key(&k) {
                            return Err(A::Error::custom(format!(
                                "duplicate primary key '{}'",
                                k
                            )));
                        }
                        inner.insert(k, v);
                    }
                    Ok(Table { inner })
                }
            }

            deserializer.deserialize_seq(SeqVisitor::<V>(PhantomData))
        }
    }
}
/// Row-level operations on a [`Table`].
///
/// Only `Ord` is required of the key type. Serde support is not needed to
/// query or modify a table; it is required only by the `Serialize` and
/// `Deserialize` impls.
impl<V> Table<V>
where
    V: PrimaryKey,
    V::PrimaryKeyType: Ord,
{
    /// Inserts `value` if no row with the same primary key exists.
    ///
    /// Returns `Some` holding a clone of the inserted row on success. Returns
    /// `None`, leaving the table unchanged, when the key is already taken.
    pub fn add(&mut self, value: V) -> Option<V>
    where
        V: Clone,
        V::PrimaryKeyType: Clone,
    {
        let key = value.primary_key();
        if self.inner.contains_key(key) {
            return None;
        }
        self.inner.insert(key.clone(), value.clone());
        Some(value)
    }

    /// Borrows the row stored under `key`, if any.
    pub fn get(&self, key: &V::PrimaryKeyType) -> Option<&V> {
        self.inner.get(key)
    }

    /// Mutably borrows the row stored under `key`, if any.
    ///
    /// The row stays filed under `key`, so the caller should not change its
    /// primary key through this reference. Use [`Table::edit`] for that.
    pub fn get_mut(&mut self, key: &V::PrimaryKeyType) -> Option<&mut V> {
        self.inner.get_mut(key)
    }

    /// Replaces the row stored under `key` with `new_value`. The replacement
    /// may carry a different primary key.
    ///
    /// Returns `Some` holding a clone of the new row on success. Returns
    /// `None`, leaving the table unchanged, when `key` is absent or when
    /// `new_value`'s key already belongs to a different row.
    pub fn edit(&mut self, key: &V::PrimaryKeyType, new_value: V) -> Option<V>
    where
        V: Clone,
        V::PrimaryKeyType: Clone,
    {
        let new_key = new_value.primary_key();
        // Re-keying onto a key owned by another row would overwrite that row.
        if key != new_key && self.inner.contains_key(new_key) {
            return None;
        }
        // Nothing to edit if `key` is absent; `?` returns `None` untouched.
        self.inner.remove(key)?;
        self.inner.insert(new_key.clone(), new_value.clone());
        Some(new_value)
    }

    /// Removes and returns the row stored under `key`, if any.
    pub fn delete(&mut self, key: &V::PrimaryKeyType) -> Option<V> {
        self.inner.remove(key)
    }

    /// Returns every row matching `predicate`, in ascending primary-key order.
    pub fn search<F>(&self, predicate: F) -> Vec<&V>
    where
        F: Fn(&V) -> bool,
    {
        self.inner.values().filter(|&val| predicate(val)).collect()
    }

    /// Like [`Table::search`], but orders the matches with `comparator`.
    ///
    /// The sort is stable, so rows that compare equal keep ascending
    /// primary-key order.
    pub fn search_ordered<F, O>(&self, predicate: F, comparator: O) -> Vec<&V>
    where
        F: Fn(&V) -> bool,
        O: Fn(&&V, &&V) -> std::cmp::Ordering,
    {
        let mut result = self.search(predicate);
        result.sort_by(comparator);
        result
    }

    /// Iterates over all rows in ascending primary-key order.
    pub fn values(&self) -> Values<'_, V::PrimaryKeyType, V> {
        self.inner.values()
    }

    /// Iterates mutably over all rows in ascending primary-key order.
    ///
    /// Callers should not change a row's primary key through this iterator.
    pub fn values_mut(&mut self) -> ValuesMut<'_, V::PrimaryKeyType, V> {
        self.inner.values_mut()
    }
}
#[cfg(test)]
mod test {
    use super::{PrimaryKey, Table};
    use serde::{Deserialize, Serialize};

    /// Minimal row type keyed by `id`.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct User {
        id: usize,
        name: String,
        age: usize,
    }

    impl PrimaryKey for User {
        type PrimaryKeyType = usize;

        fn primary_key(&self) -> &usize {
            &self.id
        }
    }

    /// JSON is human-readable, so the table is written as a key -> row map.
    #[test]
    fn json_roundtrip_as_map() {
        let row = User {
            id: 0,
            name: String::new(),
            age: 0,
        };
        let mut table: Table<User> = Table::default();
        table.add(row);

        let encoded = serde_json::to_string(&table).unwrap();
        assert_eq!(encoded, r#"{"0":{"id":0,"name":"","age":0}}"#);

        let decoded: Table<User> = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.get(&0).is_some());
    }

    /// bincode is not human-readable, so the table is written as a row sequence.
    #[test]
    #[cfg(feature = "encrypted")]
    fn bincode_roundtrip_as_seq() {
        use crate::encrypted::bincode_cfg;

        let mut table: Table<User> = Table::default();
        (0..3).for_each(|i| {
            table.add(User {
                id: i,
                name: format!("u{i}"),
                age: i,
            });
        });

        let bytes = bincode::serde::encode_to_vec(&table, bincode_cfg()).unwrap();
        let (decoded, _consumed): (Table<User>, usize) =
            bincode::serde::decode_from_slice(&bytes, bincode_cfg()).unwrap();

        assert_eq!(table.values().count(), decoded.values().count());
        for id in 0..3 {
            let expected = &table.get(&id).unwrap().name;
            let actual = &decoded.get(&id).unwrap().name;
            assert_eq!(expected, actual);
        }
    }
}