use super::{Cache, CacheError};
use async_trait::async_trait;
use core::{marker::PhantomData, time::Duration};
use fjall::SingleWriterTxKeyspace;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;
/// Pause between successive polls made by a `get_with` caller that found
/// another caller's placeholder under its key.
const DELAY: Duration = Duration::new(0, 10_000_000);
/// A `Cache` whose entries live in a fjall single-writer transactional keyspace.
///
/// Entries are stored bitcode-serialized as `Option<T>`, keyed by the cache key.
pub struct FjallCache<T> {
    /// Marks the cached value type `T`; no `T` is stored inline.
    phantom: PhantomData<T>,
    /// Keyspace that holds the serialized entries.
    keyspace: SingleWriterTxKeyspace,
}
impl<T> FjallCache<T> {
pub fn new(keyspace: SingleWriterTxKeyspace) -> Self {
Self {
keyspace,
phantom: Default::default(),
}
}
}
#[async_trait]
impl<T: Clone + Serialize + for<'a> Deserialize<'a> + Send + Sync> Cache<T> for FjallCache<T> {
async fn get_with<'a>(
&self,
key: String,
future: Box<dyn Future<Output = T> + Send + 'a>,
) -> Result<T, CacheError> {
let placeholder = bitcode::serialize(&None::<T>)?;
let previous = self.keyspace.fetch_update(key.clone(), |previous| {
Some(if let Some(value) = previous {
value.to_vec().into()
} else {
placeholder.clone().into()
})
})?;
if previous.is_none() {
let value = Box::into_pin(future).await;
self.keyspace
.insert(key.clone(), bitcode::serialize(&Some(&value))?)?;
return Ok(value);
}
loop {
if let Some(value) = self.keyspace.get(key.as_bytes())? {
if let Some(value) = bitcode::deserialize::<Option<T>>(&value)? {
return Ok(value);
}
} else {
return self.get_with(key, future).await;
}
sleep(DELAY).await;
}
}
async fn remove(&self, key: &str) -> Result<(), CacheError> {
self.keyspace.remove(key)?;
Ok(())
}
}
// Each test opens a fresh database in a temporary directory. `directory` and
// `db` are locals that stay alive until the end of the test body.
#[cfg(test)]
mod tests {
use super::*;
use alloc::sync::Arc;
use futures::future::join;
use tempfile::TempDir;
use tokio::sync::Mutex;
// A miss computes and stores 42. The second call finds the stored value and
// returns it without awaiting its own future (`async { 0 }`).
#[tokio::test]
async fn get_or_set() {
let directory = TempDir::new().unwrap();
let db = fjall::SingleWriterTxDatabase::builder(directory.path())
.open()
.unwrap();
let cache = FjallCache::new(
db.keyspace("foo", || fjall::KeyspaceCreateOptions::default())
.unwrap(),
);
assert_eq!(
cache
.get_with("key".into(), Box::new(async { 42 }))
.await
.unwrap(),
42,
);
assert_eq!(
cache
.get_with("key".into(), Box::new(async { 0 }))
.await
.unwrap(),
42,
);
}
// The producing future removes "key" (its own placeholder) before yielding
// 42; `get_with` must still return the produced value.
#[tokio::test]
async fn remove_while_set() {
let directory = TempDir::new().unwrap();
let db = fjall::SingleWriterTxDatabase::builder(directory.path())
.open()
.unwrap();
let cache = Arc::new(FjallCache::new(
db.keyspace("foo", || fjall::KeyspaceCreateOptions::default())
.unwrap(),
));
assert_eq!(
// The clone is the receiver; the original `Arc` is moved into the
// producing future below.
cache
.clone()
.get_with(
"key".into(),
Box::new(async move {
cache.remove("key").await.unwrap();
42
}),
)
.await
.unwrap(),
42,
);
}
// Runs two `get_with` + `remove` sequences on the same key concurrently, 10000
// times. Both sequences are joined on one task, so a caller may observe the
// key being removed underneath it.
// NOTE(review): `lock` is released at `drop(lock)` before `future.await` first
// polls the joined future (futures are lazy). Each producer's
// `let _ = mutex.lock().await` also drops its guard immediately. The mutex
// therefore likely does not make either producer wait; confirm the intended
// interleaving.
#[tokio::test]
async fn remove_while_get() {
let directory = TempDir::new().unwrap();
let db = fjall::SingleWriterTxDatabase::builder(directory.path())
.open()
.unwrap();
let cache = FjallCache::new(
db.keyspace("foo", || fjall::KeyspaceCreateOptions::default())
.unwrap(),
);
for _ in 0..10000 {
let mutex = Arc::new(Mutex::new(()));
let mutex1 = mutex.clone();
let lock = mutex1.lock().await;
let future = join(
{
let mutex = mutex.clone();
async {
cache
.get_with(
"key".into(),
Box::new(async move {
let _ = mutex.lock().await;
42
}),
)
.await
.unwrap();
cache.remove("key").await.unwrap()
}
},
async {
cache
.get_with(
"key".into(),
Box::new(async move {
let _ = mutex.lock().await;
42
}),
)
.await
.unwrap();
cache.remove("key").await.unwrap()
},
);
drop(lock);
future.await;
}
}
}