#![allow(clippy::expect_used)]
use crate::IOKash;
use directories::BaseDirs;
use instant::Duration;
use serde::Serialize;
use serde::de::DeserializeOwned;
use sled::Db;
use std::marker::PhantomData;
use std::path::Path;
use std::{path::PathBuf, time::SystemTime};
use thiserror::Error;
/// Builder for a [`DiskCache`]; create one with [`DiskCacheBuilder::new`] (or
/// [`DiskCache::new`]), chain the `set_*` methods, then call `build`.
pub struct DiskCacheBuilder<K, V> {
    /// Cache name; combined with the on-disk format version to form the
    /// database directory name inside `dir`.
    cache_name: String,
    /// Parent directory for the database; `None` selects the platform default.
    dir: Option<PathBuf>,
    /// Time-to-live for entries, in seconds; `None` means entries never expire.
    seconds: Option<u64>,
    /// When `true`, the built cache flushes to disk after `set`, `remove`,
    /// `clear` and `remove_expired_entries`.
    sync_to_disk_on_cache_change: bool,
    /// Optional custom sled configuration; its path is overwritten at build time.
    connection_config: Option<sled::Config>,
    /// Binds the builder to the key/value types without storing any.
    _phantom: PhantomData<(K, V)>,
}
#[derive(Error, Debug)]
pub enum DiskCacheBuildError {
#[error("Storage connection error")]
ConnectionError(#[from] sled::Error),
#[error("Connection string not specified or invalid in env var {env_key:?}: {error:?}")]
MissingPath {
env_key: String,
error: std::env::VarError,
},
}
/// Suffix of the default parent directory (`<exe name>_kash_disk_cache`).
const DISK_FILE_PREFIX: &str = "kash_disk_cache";
/// On-disk format version; appended to each cache's directory name as `_v<N>`.
const DISK_FILE_VERSION: u64 = 1;
impl<K, V> DiskCacheBuilder<K, V>
where
K: ToString,
V: Serialize + DeserializeOwned,
{
#[must_use]
pub fn new(cache_name: &str) -> Self {
Self {
seconds: None,
sync_to_disk_on_cache_change: false,
dir: None,
cache_name: cache_name.to_string(),
connection_config: None,
_phantom: Default::default(),
}
}
#[must_use]
pub fn set_ttl(mut self, seconds: u64) -> Self {
self.seconds = Some(seconds);
self
}
#[must_use]
pub fn set_disk_directory<P: AsRef<Path>>(mut self, dir: P) -> Self {
self.dir = Some(dir.as_ref().into());
self
}
#[must_use]
pub fn set_sync_to_disk_on_cache_change(mut self, sync_to_disk_on_cache_change: bool) -> Self {
self.sync_to_disk_on_cache_change = sync_to_disk_on_cache_change;
self
}
#[must_use]
pub fn set_connection_config(mut self, config: sled::Config) -> Self {
self.connection_config = Some(config);
self
}
fn default_disk_dir() -> PathBuf {
BaseDirs::new().map_or_else(
|| std::env::current_dir().expect("disk cache unable to determine current directory"),
|base_dirs| {
let exe_name = std::env::current_exe()
.ok()
.and_then(|path| {
path.file_name()
.and_then(|os_str| os_str.to_str().map(|s| format!("{s}_")))
})
.unwrap_or_default();
let dir_prefix = format!("{exe_name}{DISK_FILE_PREFIX}");
base_dirs.cache_dir().join(dir_prefix)
},
)
}
pub fn build(self) -> Result<DiskCache<K, V>, DiskCacheBuildError> {
let dir = self.dir.unwrap_or_else(|| Self::default_disk_dir());
let path = dir.join(format!("{}_v{}", self.cache_name, DISK_FILE_VERSION));
let connection = match self.connection_config {
Some(config) => config.path(path.clone()).open()?,
None => sled::open(path.clone())?,
};
Ok(DiskCache {
seconds: self.seconds,
sync_to_disk_on_cache_change: self.sync_to_disk_on_cache_change,
version: DISK_FILE_VERSION,
path,
connection,
_phantom: self._phantom,
})
}
}
/// A persistent key/value cache backed by a sled database on disk.
///
/// Keys are stored as their `ToString` form. Values are stored MessagePack-encoded
/// together with their creation time. Construct via [`DiskCache::new`].
pub struct DiskCache<K, V> {
    /// Time-to-live for entries, in seconds; `None` disables expiry.
    pub(super) seconds: Option<u64>,
    /// When `true`, mutating operations flush the database before returning.
    sync_to_disk_on_cache_change: bool,
    /// On-disk format version this cache was opened with (never read in this file).
    #[allow(dead_code)]
    version: u64,
    /// Directory of the sled database; only read by tests, hence the allow.
    #[allow(dead_code)]
    path: PathBuf,
    /// Handle to the underlying sled database.
    connection: Db,
    /// Binds the cache to its key/value types without storing any.
    _phantom: PhantomData<(K, V)>,
}
impl<K, V> DiskCache<K, V>
where
    K: ToString,
    V: Serialize + DeserializeOwned,
{
    /// Start building a cache named `cache_name`.
    ///
    /// Shorthand for [`DiskCacheBuilder::new`]; finish with `.build()`.
    #[allow(clippy::new_ret_no_self)]
    #[must_use]
    pub fn new(cache_name: &str) -> DiskCacheBuilder<K, V> {
        DiskCacheBuilder::new(cache_name)
    }

    /// Physically delete every entry whose TTL has elapsed.
    ///
    /// When no TTL is configured nothing can expire, so no entries are scanned.
    /// Entries that fail to decode are left in place. An entry is only deleted
    /// if it still holds the exact bytes that were judged expired. If another
    /// writer stored a fresh value under the same key in the meantime, that
    /// value is kept.
    ///
    /// # Errors
    /// Returns [`DiskCacheError::StorageError`] if iterating, deleting or
    /// flushing the database fails.
    pub fn remove_expired_entries(&self) -> Result<(), DiskCacheError> {
        if let Some(lifetime_seconds) = self.seconds {
            let now = SystemTime::now();
            for entry in self.connection.iter() {
                let (key, value) = entry?;
                let Ok(kash) = rmp_serde::from_slice::<KashDiskValue<V>>(&value) else {
                    continue;
                };
                if Self::is_expired(kash.created_at, lifetime_seconds, now) {
                    // A blind `remove(key)` could delete a value written after our read.
                    // The compare-and-swap only deletes if the bytes are unchanged. An inner
                    // `Err(CompareAndSwapError)` just means we lost that race, which is fine.
                    let _ = self
                        .connection
                        .compare_and_swap(&key, Some(&value), None::<&[u8]>)?;
                }
            }
        }
        if self.sync_to_disk_on_cache_change {
            self.connection.flush()?;
        }
        Ok(())
    }

    /// Borrow the underlying sled database.
    #[must_use]
    pub fn connection(&self) -> &Db {
        &self.connection
    }

    /// Mutably borrow the underlying sled database.
    pub fn connection_mut(&mut self) -> &mut Db {
        &mut self.connection
    }

    /// Return the wrapped value unless the cache has a TTL and it has elapsed.
    fn check_expiration(&self, kash: KashDiskValue<V>) -> Option<V> {
        match self.seconds {
            Some(ttl) if Self::is_expired(kash.created_at, ttl, SystemTime::now()) => None,
            _ => Some(kash.value),
        }
    }

    /// `true` once at least `ttl_seconds` have passed between `created_at` and `now`.
    ///
    /// A `created_at` in the future (e.g. after a wall-clock adjustment) counts as
    /// age zero, i.e. not expired.
    fn is_expired(created_at: SystemTime, ttl_seconds: u64, now: SystemTime) -> bool {
        now.duration_since(created_at).unwrap_or_default() >= Duration::from_secs(ttl_seconds)
    }
}
#[derive(Error, Debug)]
pub enum DiskCacheError {
#[error("Storage error")]
StorageError(#[from] sled::Error),
#[error("Error deserializing cached value")]
CacheDeserializationError(#[from] rmp_serde::decode::Error),
#[error("Error serializing cached value")]
CacheSerializationError(#[from] rmp_serde::encode::Error),
}
/// On-disk envelope for a cached value: the value plus the wall-clock time it was stored.
///
/// Entries are encoded with `rmp_serde::to_vec`, which by default writes structs
/// positionally (as MessagePack arrays). The field ORDER below is therefore part of
/// the on-disk format. Reordering, inserting or removing fields makes previously
/// written entries undecodable.
/// NOTE(review): if the layout must change, bumping `DISK_FILE_VERSION` (which is
/// part of the database directory name) is presumably the intended migration —
/// confirm.
#[derive(serde::Serialize, serde::Deserialize)]
struct KashDiskValue<V> {
    /// The cached value itself.
    pub(crate) value: V,
    /// When the value was written; compared against the cache TTL.
    pub(crate) created_at: SystemTime,
}
impl<V> KashDiskValue<V> {
    /// Wrap `value`, stamping it with the current wall-clock time.
    fn new(value: V) -> Self {
        let created_at = SystemTime::now();
        Self { created_at, value }
    }
}
impl<K, V> IOKash<K, V> for DiskCache<K, V>
where
    K: ToString,
    V: Serialize + DeserializeOwned,
{
    type Error = DiskCacheError;

    /// Look up `k`, returning `None` if it is absent or expired.
    ///
    /// A hit is a plain read; the database is only written when an expired or
    /// (with a TTL configured) undecodable entry has to be evicted. Eviction is a
    /// compare-and-swap, so a value written concurrently after our read is never
    /// removed. Without a TTL, an undecodable entry is reported as
    /// [`DiskCacheError::CacheDeserializationError`].
    fn get(&self, k: &K) -> Result<Option<V>, DiskCacheError> {
        let key = k.to_string();
        let Some(data) = self.connection.get(&key)? else {
            return Ok(None);
        };
        let Some(ttl_seconds) = self.seconds else {
            let kash = rmp_serde::from_slice::<KashDiskValue<V>>(&data)?;
            return Ok(Some(kash.value));
        };
        if let Ok(kash) = rmp_serde::from_slice::<KashDiskValue<V>>(&data) {
            let age = SystemTime::now()
                .duration_since(kash.created_at)
                .unwrap_or_default();
            if age < Duration::from_secs(ttl_seconds) {
                return Ok(Some(kash.value));
            }
        }
        // Expired or undecodable: evict it, but only if nobody replaced it since
        // our read. An inner `Err(CompareAndSwapError)` means we lost that race,
        // and the newer value must be kept.
        let _ = self
            .connection
            .compare_and_swap(&key, Some(&data), None::<&[u8]>)?;
        Ok(None)
    }

    /// Store `v` under `k`, returning the previous value if it had not expired.
    ///
    /// # Errors
    /// Serialization and storage failures are returned. If the *previous* entry
    /// cannot be decoded, `CacheDeserializationError` is returned. In that case
    /// the new value has already been stored (and flushed, if configured).
    fn set(&self, k: K, v: V) -> Result<Option<V>, DiskCacheError> {
        let key = k.to_string();
        let value = rmp_serde::to_vec(&KashDiskValue::new(v))?;
        let previous = self.connection.insert(key, value)?;
        // Flush before decoding the previous value: a decode error must not skip
        // persisting a write that has already happened.
        if self.sync_to_disk_on_cache_change {
            self.connection.flush()?;
        }
        let Some(data) = previous else {
            return Ok(None);
        };
        let kash = rmp_serde::from_slice::<KashDiskValue<V>>(&data)?;
        Ok(self.check_expiration(kash))
    }

    /// Delete `k`, returning its value if it had not expired.
    ///
    /// # Errors
    /// Storage failures are returned. If the removed entry cannot be decoded,
    /// `CacheDeserializationError` is returned. The entry has still been removed
    /// (and the removal flushed, if configured).
    fn remove(&self, k: &K) -> Result<Option<V>, DiskCacheError> {
        let key = k.to_string();
        let removed = self.connection.remove(key)?;
        // Flush before decoding so a decode error cannot leave the removal unpersisted.
        if self.sync_to_disk_on_cache_change {
            self.connection.flush()?;
        }
        let Some(data) = removed else {
            return Ok(None);
        };
        let kash = rmp_serde::from_slice::<KashDiskValue<V>>(&data)?;
        Ok(self.check_expiration(kash))
    }

    /// Delete every entry, flushing afterwards if configured.
    fn clear(&self) -> Result<(), DiskCacheError> {
        self.connection.clear()?;
        if self.sync_to_disk_on_cache_change {
            self.connection.flush()?;
        }
        Ok(())
    }

    /// The current TTL in seconds, if any.
    fn ttl(&self) -> Option<u64> {
        self.seconds
    }

    /// Set a new TTL (applies to existing entries too), returning the old one.
    fn set_ttl(&mut self, seconds: u64) -> Option<u64> {
        let old = self.seconds;
        self.seconds = Some(seconds);
        old
    }

    /// Disable expiry, returning the previous TTL.
    fn unset_ttl(&mut self) -> Option<u64> {
        self.seconds.take()
    }
}
#[allow(clippy::unwrap_used, non_snake_case)]
#[cfg(test)]
mod test_DiskCache {
    use googletest::{
        assert_that,
        matchers::{anything, eq, none, ok, some},
    };
    use std::thread::sleep;
    use std::time::Duration;
    use tempfile::TempDir;

    use super::*;

    /// `temp_dir!()` creates a fresh temporary directory. `temp_dir!(no_exist)`
    /// returns a `TempDir` whose path has already been deleted, for APIs (like
    /// `copy_dir`) that require the destination not to exist yet.
    macro_rules! temp_dir {
        () => {
            TempDir::new().expect("Error creating temp dir")
        };
        (no_exist) => {{
            let tmp_dir = TempDir::new().expect("Error creating temp dir");
            std::fs::remove_dir_all(tmp_dir.path()).expect("error emptying the tmp dir");
            tmp_dir
        }};
    }

    fn now_millis() -> u128 {
        SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis()
    }

    /// Sleep for `secs` seconds plus a small margin, so an entry written before
    /// the call is strictly older than a `secs`-second TTL afterwards.
    fn sleep_past_ttl(secs: u64) {
        sleep(Duration::from_secs(secs));
        sleep(Duration::from_micros(500));
    }

    const TEST_KEY: u32 = 1;
    const TEST_VAL: u32 = 100;
    const TEST_KEY_1: u32 = 2;
    const TEST_VAL_1: u32 = 200;
    const LIFE_SPAN_2_SECS: u64 = 2;
    const LIFE_SPAN_1_SEC: u64 = 1;

    #[googletest::test]
    fn cache_get_after_cache_remove_returns_none() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .build()
            .unwrap();
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            none(),
            "Getting a non-existent key-value should return None"
        );
        let kash = cache.set(TEST_KEY, TEST_VAL).unwrap();
        assert_that!(kash, none(), "Setting a new key-value should return None");
        let kash = cache.set(TEST_KEY, TEST_VAL_1).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL)),
            "Setting an existing key-value should return the old value"
        );
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL_1)),
            "Getting an existing key-value should return the value"
        );
        let kash = cache.remove(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL_1)),
            "Removing an existing key-value should return the value"
        );
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(kash, none(), "Getting a removed key should return None");
        drop(cache);
    }

    #[googletest::test]
    fn cache_get_after_cache_clear_returns_none() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .build()
            .unwrap();
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            none(),
            "Getting a non-existent key-value should return None"
        );
        let kash = cache.set(TEST_KEY, TEST_VAL).unwrap();
        assert_that!(kash, none(), "Setting a new key-value should return None");
        let kash = cache.set(TEST_KEY, TEST_VAL_1).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL)),
            "Setting an existing key-value should return the old value"
        );
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL_1)),
            "Getting an existing key-value should return the value"
        );
        cache.clear().unwrap();
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(kash, none(), "Getting a cleared key should return None");
        drop(cache);
    }

    #[googletest::test]
    fn values_expire_when_lifespan_elapses_returning_none() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_ttl(LIFE_SPAN_2_SECS)
            .build()
            .unwrap();
        assert_that!(
            cache.get(&TEST_KEY),
            ok(none()),
            "Getting a non-existent key-value should return None"
        );
        assert_that!(
            cache.set(TEST_KEY, 100),
            ok(none()),
            "Setting a new key-value should return None"
        );
        assert_that!(
            cache.get(&TEST_KEY),
            ok(some(anything())),
            "Getting an existing key-value before it expires should return the value"
        );
        sleep_past_ttl(LIFE_SPAN_2_SECS);
        assert_that!(
            cache.get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );
    }

    #[googletest::test]
    fn set_lifespan_to_a_different_lifespan_is_respected() {
        let tmp_dir = temp_dir!();
        let mut cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_ttl(LIFE_SPAN_2_SECS)
            .build()
            .unwrap();
        assert_that!(
            cache.get(&TEST_KEY),
            ok(none()),
            "Getting a non-existent key-value should return None"
        );
        assert_that!(
            cache.set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a new key-value should return None"
        );
        sleep_past_ttl(LIFE_SPAN_2_SECS);
        assert_that!(
            cache.get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );
        let old_from_setting_lifespan = cache
            .set_ttl(LIFE_SPAN_1_SEC)
            .expect("error setting new ttl");
        assert_that!(
            old_from_setting_lifespan,
            eq(LIFE_SPAN_2_SECS),
            "Setting ttl should return the old ttl"
        );
        assert_that!(
            cache.set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a previously expired key-value should return None"
        );
        assert_that!(
            cache.get(&TEST_KEY),
            ok(some(eq(&TEST_VAL))),
            "Getting a new set (previously expired) key-value should return the value"
        );
        sleep_past_ttl(LIFE_SPAN_1_SEC);
        assert_that!(
            cache.get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );
        cache.set_ttl(10).expect("error setting ttl");
        assert_that!(
            cache.set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a previously expired key-value should return None"
        );
        assert_that!(
            cache.set(TEST_KEY_1, TEST_VAL),
            ok(none()),
            "Setting a new, separate, key-value should return None"
        );
        assert_that!(
            cache.get(&TEST_KEY),
            ok(some(eq(&TEST_VAL))),
            "Getting a new set (previously expired) key-value should return the value"
        );
        assert_that!(
            cache.get(&TEST_KEY),
            ok(some(eq(&TEST_VAL))),
            "Getting the same value again should return the value"
        );
    }

    #[googletest::test]
    fn remove_expired_entries_deletes_only_expired_values() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_ttl(LIFE_SPAN_1_SEC)
            .build()
            .unwrap();
        cache.set(TEST_KEY, TEST_VAL).unwrap();
        sleep_past_ttl(LIFE_SPAN_1_SEC);
        cache.set(TEST_KEY_1, TEST_VAL_1).unwrap();
        cache.remove_expired_entries().unwrap();
        let expired_still_stored = cache
            .connection
            .contains_key(TEST_KEY.to_string())
            .unwrap();
        assert_that!(
            expired_still_stored,
            eq(false),
            "An expired entry should be physically removed from storage"
        );
        assert_that!(
            cache.get(&TEST_KEY_1),
            ok(some(eq(&TEST_VAL_1))),
            "A fresh entry should survive remove_expired_entries"
        );
    }

    #[googletest::test]
    fn does_not_break_when_constructed_using_default_disk_directory() {
        // Use `-` rather than `:` as a separator: `:` is not a legal file-name
        // character on Windows.
        let cache: DiskCache<u32, u32> =
            DiskCache::new(&format!("{}-disk-cache-test-default-dir", now_millis()))
                .build()
                .unwrap();
        let kash = cache.get(&TEST_KEY).unwrap();
        assert_that!(
            kash,
            none(),
            "Getting a non-existent key-value should return None"
        );
        let kash = cache.set(TEST_KEY, TEST_VAL).unwrap();
        assert_that!(kash, none(), "Setting a new key-value should return None");
        let kash = cache.set(TEST_KEY, TEST_VAL_1).unwrap();
        assert_that!(
            kash,
            some(eq(TEST_VAL)),
            "Setting an existing key-value should return the old value"
        );
        // Close the database before deleting its files; open handles make the
        // removal fail on platforms such as Windows.
        let cache_path = cache.path.clone();
        drop(cache);
        std::fs::remove_dir_all(cache_path).expect("error in clean up removing the cache dir");
    }

    mod set_sync_to_disk_on_cache_change {
        mod when_no_auto_flushing {
            use super::super::*;

            /// Run `run_on_original_cache` against a cache with sled's periodic
            /// flushing disabled, then copy its files *without* flushing (simulating a
            /// crash) and run `run_on_recovered_cache` against a cache opened from
            /// the copy. Only data that was explicitly flushed survives the copy.
            fn check_on_recovered_cache(
                set_sync_to_disk_on_cache_change: bool,
                run_on_original_cache: fn(&DiskCache<u32, u32>),
                run_on_recovered_cache: fn(&DiskCache<u32, u32>),
            ) {
                const CACHE_NAME: &str = "test-cache";
                let original_cache_tmp_dir = temp_dir!();
                let copied_cache_tmp_dir = temp_dir!(no_exist);
                let cache: DiskCache<u32, u32> = DiskCache::new(CACHE_NAME)
                    .set_disk_directory(original_cache_tmp_dir.path())
                    .set_sync_to_disk_on_cache_change(set_sync_to_disk_on_cache_change)
                    .set_connection_config(sled::Config::new().flush_every_ms(None))
                    .build()
                    .unwrap();
                cache
                    .connection
                    .flush()
                    .expect("error flushing cache before any cache setting");
                run_on_original_cache(&cache);
                let recovered_cache = clone_cache_to_new_location_no_flushing(
                    CACHE_NAME,
                    &cache,
                    copied_cache_tmp_dir.path(),
                );
                assert_that!(recovered_cache.connection.was_recovered(), eq(true));
                run_on_recovered_cache(&recovered_cache);
            }

            mod changes_do_not_persist_after_recovery_when_set_to_false {
                use super::*;

                #[googletest::test]
                fn for_cache_set() {
                    check_on_recovered_cache(
                        false,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(none()),
                                "set_sync_to_disk_on_cache_change is false, and there is no auto-flushing, so the cache should not have persisted"
                            );
                        },
                    );
                }

                #[googletest::test]
                fn for_cache_remove() {
                    check_on_recovered_cache(
                        false,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                            cache.connection.flush().expect("error flushing cache");
                            cache
                                .remove(&TEST_KEY)
                                .expect("error removing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(some(eq(&TEST_VAL))),
                                "set_sync_to_disk_on_cache_change is false, and there is no auto-flushing, so the cache_remove should not have persisted"
                            );
                        },
                    );
                }

                #[googletest::test]
                fn for_cache_clear() {
                    check_on_recovered_cache(
                        false,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                            cache.connection.flush().expect("error flushing cache");
                            cache
                                .clear()
                                .expect("error clearing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(some(eq(&TEST_VAL))),
                                "set_sync_to_disk_on_cache_change is false, and there is no auto-flushing, so the cache_clear should not have persisted"
                            );
                        },
                    );
                }
            }

            mod changes_persist_after_recovery_when_set_to_true {
                use super::*;

                #[googletest::test]
                fn for_cache_set() {
                    check_on_recovered_cache(
                        true,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(some(eq(&TEST_VAL))),
                                "Getting a set key should return the value"
                            );
                        },
                    );
                }

                #[googletest::test]
                fn for_cache_remove() {
                    check_on_recovered_cache(
                        true,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                            cache
                                .remove(&TEST_KEY)
                                .expect("error removing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(none()),
                                "Getting a removed key should return None"
                            );
                        },
                    );
                }

                #[googletest::test]
                fn for_cache_clear() {
                    check_on_recovered_cache(
                        true,
                        |cache| {
                            cache
                                .set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                            cache
                                .clear()
                                .expect("error clearing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.get(&TEST_KEY),
                                ok(none()),
                                "Getting a cleared key should return None"
                            );
                        },
                    );
                }
            }

            /// Copy the directory containing `cache`'s database to `new_location`
            /// (which must not exist yet) without flushing first, then open a cache
            /// named `cache_name` from the copy.
            fn clone_cache_to_new_location_no_flushing(
                cache_name: &str,
                cache: &DiskCache<u32, u32>,
                new_location: &Path,
            ) -> DiskCache<u32, u32> {
                copy_dir::copy_dir(cache.path.parent().unwrap(), new_location)
                    .expect("error copying cache files to new location");
                DiskCache::new(cache_name)
                    .set_disk_directory(new_location)
                    .build()
                    .expect("error building cache from copied files")
            }
        }
    }
}