use super::{Cache, CacheError};
use async_trait::async_trait;
use core::{marker::PhantomData, time::Duration};
use log::trace;
use serde::{Deserialize, Serialize};
use sled::Tree;
use tokio::time::sleep;
/// Time to sleep between two polls of the tree while waiting for an entry
/// that is still pending to be filled in.
const DELAY: Duration = Duration::from_millis(10);
/// A cache that keeps its entries in a [`sled::Tree`].
///
/// `T` is the type of the cached values. No `T` is stored in the struct
/// itself; the parameter is only carried through `PhantomData`.
pub struct SledCache<T> {
/// The sled tree holding the cache entries.
tree: Tree,
/// Marker tying the cache to its value type `T`.
phantom: PhantomData<T>,
}
impl<T> SledCache<T> {
pub fn new(tree: Tree) -> Self {
Self {
tree,
phantom: Default::default(),
}
}
}
#[async_trait]
impl<T: Clone + Serialize + for<'a> Deserialize<'a> + Send + Sync> Cache<T> for SledCache<T> {
async fn get_with<'a>(
&self,
key: String,
future: Box<dyn Future<Output = T> + Send + 'a>,
) -> Result<T, CacheError> {
trace!("getting cache at {key}");
if self
.tree
.compare_and_swap::<_, Vec<u8>, Vec<u8>>(
&key,
None,
Some(bitcode::serialize(&Option::<T>::None)?),
)?
.is_ok()
{
trace!("awaiting future for cache at {key}");
let value = Box::into_pin(future).await;
trace!("setting cache at {key}");
self.tree
.insert(key.clone(), bitcode::serialize(&Some(&value))?)?;
trace!("set cache at {key}");
return Ok(value);
}
trace!("waiting for cache at {key}");
loop {
if let Some(value) = self.tree.get(&key)? {
if let Some(value) = bitcode::deserialize::<Option<T>>(&value)? {
trace!("waited for cache at {key}");
return Ok(value);
}
} else {
return self.get_with(key, future).await;
}
sleep(DELAY).await;
}
}
async fn remove(&self, key: &str) -> Result<(), CacheError> {
trace!("removing cache entry at {key}");
self.tree.remove(key)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::sync::Arc;
    use futures::future::join;
    use tempfile::TempDir;
    use tokio::sync::Mutex;

    /// Opens a cache on the `foo` tree of a sled database in `directory`.
    fn open_cache(directory: &TempDir) -> SledCache<i32> {
        let database = sled::open(directory.path()).unwrap();

        SledCache::new(database.open_tree("foo").unwrap())
    }

    /// Fetches `key` with a future that first acquires `gate`, then removes
    /// the entry again.
    async fn get_then_remove(cache: &SledCache<i32>, gate: Arc<Mutex<()>>) {
        cache
            .get_with(
                "key".into(),
                Box::new(async move {
                    let _ = gate.lock().await;
                    42
                }),
            )
            .await
            .unwrap();
        cache.remove("key").await.unwrap()
    }

    /// The first call stores the value; the second call returns the stored
    /// value instead of the output of its own future.
    #[tokio::test]
    async fn get_or_set() {
        let directory = TempDir::new().unwrap();
        let cache = open_cache(&directory);

        let first = cache
            .get_with("key".into(), Box::new(async { 42 }))
            .await
            .unwrap();
        assert_eq!(first, 42);

        let second = cache
            .get_with("key".into(), Box::new(async { 0 }))
            .await
            .unwrap();
        assert_eq!(second, 42);
    }

    /// Removing the key from inside the computing future still yields the
    /// computed value.
    #[tokio::test]
    async fn remove_while_set() {
        let directory = TempDir::new().unwrap();
        let cache = Arc::new(open_cache(&directory));
        let remover = cache.clone();

        let value = cache
            .get_with(
                "key".into(),
                Box::new(async move {
                    remover.remove("key").await.unwrap();
                    42
                }),
            )
            .await
            .unwrap();

        assert_eq!(value, 42);
    }

    /// Repeatedly joins two get-then-remove sequences on the same key.
    #[tokio::test]
    async fn remove_while_get() {
        let directory = TempDir::new().unwrap();
        let cache = open_cache(&directory);

        for _ in 0..10000 {
            let gate = Arc::new(Mutex::new(()));
            // The gate is held while the joined future is built and released
            // before that future is awaited.
            let held = gate.lock().await;
            let tasks = join(
                get_then_remove(&cache, gate.clone()),
                get_then_remove(&cache, gate.clone()),
            );

            drop(held);
            tasks.await;
        }
    }
}