use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use crate::database::{
DatabaseEntry, DbKey, TransactionError, deserialize_from_ivec, sled_get_all_keys_raw,
sled_get_batch_raw, sled_get_raw,
};
/// In-memory cache of deserialized `T` entries, keyed by [`DbKey`].
///
/// Values are stored behind `Arc`, so handing out a cached entry only bumps a
/// reference count instead of cloning `T`.
pub struct DbCache<T>
where
    T: DatabaseEntry,
{
    cache: HashMap<DbKey, Arc<T>>,
}
impl<T: DatabaseEntry> Default for DbCache<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: DatabaseEntry> DbCache<T> {
    /// Creates an empty cache.
    pub fn new() -> Self {
        DbCache {
            cache: HashMap::new(),
        }
    }

    /// Inserts every entry yielded by `with`, keyed by the entry's own id.
    ///
    /// An entry already cached under the same id is replaced.
    pub fn fill<I>(&mut self, with: I)
    where
        I: IntoIterator<Item = Arc<T>>,
    {
        self.cache.extend(with.into_iter().map(|t| (*t.id(), t)));
    }

    /// Drops every cached entry.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Caches `item` under its own id, replacing any previous entry for that id.
    pub fn cache(&mut self, item: &Arc<T>) {
        self.cache.insert(*item.id(), Arc::clone(item));
    }

    /// Evicts the entry cached for `item`, if there is one.
    pub fn remove(&mut self, item: T::Id) {
        self.cache.remove(&*item);
    }

    /// Returns the entry stored under `id`, checking the cache before `T`'s tree in `db`.
    ///
    /// A value read from the tree is deserialized, cached under `id`, and returned.
    /// Returns `Ok(None)` when the tree has no value for `id`.
    ///
    /// # Errors
    /// Propagates errors from `sled_get_raw` and `deserialize_from_ivec`.
    pub fn get_from(&mut self, id: DbKey, db: &T::Db) -> Result<Option<Arc<T>>, TransactionError> {
        if let Some(arc) = self.cache.get(&id).cloned() {
            return Ok(Some(arc));
        }
        let Some(value) = sled_get_raw(&T::__tree(db), &id)? else {
            return Ok(None);
        };
        let value: Arc<T> = Arc::new(deserialize_from_ivec(value)?);
        self.cache.insert(id, Arc::clone(&value));
        Ok(Some(value))
    }

    /// Returns every entry in `T`'s tree, in the order `sled_get_all_keys_raw` yields the keys.
    ///
    /// # Errors
    /// Propagates errors from listing the keys or from `get_batch_from`.
    fn get_all_from(&mut self, db: &T::Db) -> Result<Vec<Arc<T>>, TransactionError> {
        self.get_batch_from(sled_get_all_keys_raw(&T::__tree(db))?, db)
    }

    /// Returns the entries for `ids`, in request order.
    ///
    /// Cache hits are served directly. All misses are fetched with a single
    /// `sled_get_batch_raw` call, then deserialized and cached.
    ///
    /// # Errors
    /// Propagates errors from `sled_get_batch_raw` and `deserialize_from_ivec`.
    /// Entries loaded before a failing one remain cached.
    ///
    /// # Panics
    /// Panics if `sled_get_batch_raw` yields fewer values than keys requested.
    fn get_batch_from<I>(&mut self, ids: I, db: &T::Db) -> Result<Vec<Arc<T>>, TransactionError>
    where
        I: IntoIterator<Item = DbKey>,
    {
        let ids: Vec<DbKey> = ids.into_iter().collect();
        // One slot per requested id, pre-filled from the cache; `None` marks a miss.
        let mut slots: Vec<Option<Arc<T>>> =
            ids.iter().map(|id| self.cache.get(id).cloned()).collect();
        // Keys still to be loaded, in the same order as the `None` slots above.
        let misses: Vec<DbKey> = ids
            .iter()
            .zip(&slots)
            .filter(|(_, slot)| slot.is_none())
            .map(|(id, _)| *id)
            .collect();
        if !misses.is_empty() {
            // NOTE(review): assumes `sled_get_batch_raw` yields exactly one value per requested
            // key, in request order. Confirm this: a silently skipped key would shift later
            // values into the wrong slots.
            let mut loaded = sled_get_batch_raw(&T::__tree(db), misses.iter().copied())?
                .into_iter()
                .map(deserialize_from_ivec::<T>);
            for slot in slots.iter_mut().filter(|slot| slot.is_none()) {
                let value = loaded
                    .next()
                    .expect("sled_get_batch_raw returned fewer values than keys requested")?;
                let value = Arc::new(value);
                self.cache.insert(*value.id(), Arc::clone(&value));
                *slot = Some(value);
            }
        }
        Ok(slots
            .into_iter()
            .map(|slot| slot.expect("every cache miss was filled from the database"))
            .collect())
    }
}
/// A database entry type with a process-wide [`DbCache`] in front of its tree.
///
/// Implementors provide only [`Cacheable::cache`]. The provided methods lock that
/// `Mutex` and delegate to the cache. They block while another thread holds the lock
/// and panic (via `unwrap`) if the mutex has been poisoned.
pub trait Cacheable: DatabaseEntry {
    /// Returns the shared cache for `Self`.
    fn cache() -> &'static Mutex<DbCache<Self>>;

    /// Looks up the entry with the given id, reading `db` on a cache miss.
    ///
    /// # Errors
    /// Propagates database and deserialization errors from the cache lookup.
    fn cache_get_from(id: Self::Id, db: &Self::Db) -> Result<Option<Arc<Self>>, TransactionError> {
        Self::cache().lock().unwrap().get_from(*id, db)
    }

    /// Looks up several entries by id, returning them in request order.
    ///
    /// # Errors
    /// Propagates database and deserialization errors from the cache lookup.
    fn cache_get_batch_from<I, A>(id: I, db: &Self::Db) -> Result<Vec<Arc<Self>>, TransactionError>
    where
        I: IntoIterator<Item = A>,
        A: std::borrow::Borrow<Self::Id>,
    {
        // Drain the caller's iterator before locking, so caller-supplied code
        // (`Iterator::next`, `Borrow::borrow`) never runs while the global cache mutex is held.
        let keys: Vec<_> = id.into_iter().map(|a| **a.borrow()).collect();
        Self::cache().lock().unwrap().get_batch_from(keys, db)
    }

    /// Returns every entry stored in `Self`'s tree, using cached values where available.
    ///
    /// # Errors
    /// Propagates database and deserialization errors from the cache lookup.
    fn cache_get_all_from(db: &Self::Db) -> Result<Vec<Arc<Self>>, TransactionError> {
        Self::cache().lock().unwrap().get_all_from(db)
    }
}