use crate::IOCached;
use directories::BaseDirs;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sled::Db;
use std::marker::PhantomData;
use std::path::Path;
use std::{path::PathBuf, time::SystemTime};
use web_time::Duration;
/// Builder that collects the settings for a [`DiskCache`] before the
/// underlying sled database is opened by [`DiskCacheBuilder::build`].
pub struct DiskCacheBuilder<K, V> {
    /// Entry lifespan in seconds; `None` leaves entries without an expiry.
    seconds: Option<u64>,
    /// Whether a successful `cache_get` resets an entry's creation time.
    refresh: bool,
    /// Whether the database is flushed after cache changes.
    sync_to_disk_on_cache_change: bool,
    /// Directory that holds the cache database; `build` falls back to a
    /// default location when this is `None`.
    disk_dir: Option<PathBuf>,
    /// Name of the cache; `build` uses it as the stem of the database path.
    cache_name: String,
    /// Optional caller-supplied sled configuration used to open the database.
    connection_config: Option<sled::Config>,
    /// Ties the builder to the key and value types without storing either.
    _phantom: PhantomData<(K, V)>,
}
use thiserror::Error;
/// Errors that can be returned while building a [`DiskCache`].
#[derive(Error, Debug)]
pub enum DiskCacheBuildError {
    /// Opening the sled database failed.
    #[error("Storage connection error")]
    ConnectionError(#[from] sled::Error),
    /// An environment variable holding the connection string was missing or invalid.
    ///
    /// NOTE(review): nothing in the visible code constructs this variant — confirm
    /// whether it is still needed before relying on it.
    #[error("Connection string not specified or invalid in env var {env_key:?}: {error:?}")]
    MissingDiskPath {
        /// Name of the environment variable that was consulted.
        env_key: String,
        /// The error produced while reading that variable.
        error: std::env::VarError,
    },
}
/// Base name of the default cache directory; `default_disk_dir` prefixes it
/// with the executable name when that can be determined.
static DISK_FILE_PREFIX: &str = "cached_disk_cache";
/// Version appended to the database path (`<cache_name>_v<version>`) and
/// recorded in `DiskCache::version`.
const DISK_FILE_VERSION: u64 = 1;
impl<K, V> DiskCacheBuilder<K, V>
where
K: ToString,
V: Serialize + DeserializeOwned,
{
pub fn new<S: AsRef<str>>(cache_name: S) -> DiskCacheBuilder<K, V> {
Self {
seconds: None,
refresh: false,
sync_to_disk_on_cache_change: false,
disk_dir: None,
cache_name: cache_name.as_ref().to_string(),
connection_config: None,
_phantom: Default::default(),
}
}
pub fn set_lifespan(mut self, seconds: u64) -> Self {
self.seconds = Some(seconds);
self
}
pub fn set_refresh(mut self, refresh: bool) -> Self {
self.refresh = refresh;
self
}
pub fn set_disk_directory<P: AsRef<Path>>(mut self, dir: P) -> Self {
self.disk_dir = Some(dir.as_ref().into());
self
}
pub fn set_sync_to_disk_on_cache_change(mut self, sync_to_disk_on_cache_change: bool) -> Self {
self.sync_to_disk_on_cache_change = sync_to_disk_on_cache_change;
self
}
pub fn set_connection_config(mut self, config: sled::Config) -> Self {
self.connection_config = Some(config);
self
}
fn default_disk_dir() -> PathBuf {
BaseDirs::new()
.map(|base_dirs| {
let exe_name = std::env::current_exe()
.ok()
.and_then(|path| {
path.file_name()
.and_then(|os_str| os_str.to_str().map(|s| format!("{}_", s)))
})
.unwrap_or_default();
let dir_prefix = format!("{}{}", exe_name, DISK_FILE_PREFIX);
base_dirs.cache_dir().join(dir_prefix)
})
.unwrap_or_else(|| {
std::env::current_dir().expect("disk cache unable to determine current directory")
})
}
pub fn build(self) -> Result<DiskCache<K, V>, DiskCacheBuildError> {
let disk_dir = self.disk_dir.unwrap_or_else(|| Self::default_disk_dir());
let disk_path = disk_dir.join(format!("{}_v{}", self.cache_name, DISK_FILE_VERSION));
let connection = match self.connection_config {
Some(config) => config.path(disk_path.clone()).open()?,
None => sled::open(disk_path.clone())?,
};
Ok(DiskCache {
seconds: self.seconds,
refresh: self.refresh,
sync_to_disk_on_cache_change: self.sync_to_disk_on_cache_change,
version: DISK_FILE_VERSION,
disk_path,
connection,
_phantom: self._phantom,
})
}
}
/// Cache whose entries are persisted in a sled database on disk.
pub struct DiskCache<K, V> {
    /// Entry lifespan in seconds; `None` means entries are never treated as expired.
    pub(super) seconds: Option<u64>,
    /// Whether a successful `cache_get` resets an entry's creation time.
    pub(super) refresh: bool,
    /// Whether the database is flushed after cache changes.
    sync_to_disk_on_cache_change: bool,
    /// Database version this cache was built with; not read by non-test code in this file.
    #[allow(unused)]
    version: u64,
    /// Path of the sled database; only read by the tests in this file.
    #[allow(unused)]
    disk_path: PathBuf,
    /// Handle to the open sled database.
    connection: Db,
    /// Ties the cache to its key and value types without storing either.
    _phantom: PhantomData<(K, V)>,
}
impl<K, V> DiskCache<K, V>
where
    K: ToString,
    V: Serialize + DeserializeOwned,
{
    /// Returns a [`DiskCacheBuilder`] for a cache named `cache_name`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(cache_name: &str) -> DiskCacheBuilder<K, V> {
        DiskCacheBuilder::new(cache_name)
    }

    /// Removes every entry whose age has reached the configured lifespan.
    ///
    /// When no lifespan is configured nothing can expire, so the database is
    /// not scanned at all. Entries that cannot be deserialized are left in
    /// place. The database is flushed afterwards when
    /// `sync_to_disk_on_cache_change` is enabled.
    ///
    /// # Errors
    ///
    /// Returns [`DiskCacheError::StorageError`] if iterating, removing or
    /// flushing fails.
    pub fn remove_expired_entries(&self) -> Result<(), DiskCacheError> {
        if let Some(lifetime_seconds) = self.seconds {
            let lifetime = Duration::from_secs(lifetime_seconds);
            let now = SystemTime::now();

            for entry in self.connection.iter() {
                // Surface storage failures instead of silently skipping them.
                let (key, value) = entry?;

                // Best effort: anything that is not a readable cache entry is skipped.
                let cached = match rmp_serde::from_slice::<CachedDiskValue<V>>(&value) {
                    Ok(cached) => cached,
                    Err(_) => continue,
                };

                // A creation time in the future counts as an age of zero.
                let age = now
                    .duration_since(cached.created_at)
                    .unwrap_or(Duration::from_secs(0));
                if age >= lifetime {
                    self.connection.remove(key)?;
                }
            }
        }

        if self.sync_to_disk_on_cache_change {
            self.connection.flush()?;
        }
        Ok(())
    }

    /// Borrows the underlying sled database handle.
    pub fn connection(&self) -> &Db {
        &self.connection
    }

    /// Mutably borrows the underlying sled database handle.
    pub fn connection_mut(&mut self) -> &mut Db {
        &mut self.connection
    }
}
/// Errors returned by [`DiskCache`] operations.
#[derive(Error, Debug)]
pub enum DiskCacheError {
    /// The sled database reported an error.
    #[error("Storage error")]
    StorageError(#[from] sled::Error),
    /// A stored entry could not be decoded by `rmp_serde`.
    #[error("Error deserializing cached value")]
    CacheDeserializationError(#[from] rmp_serde::decode::Error),
    /// A value could not be encoded by `rmp_serde`.
    #[error("Error serializing cached value")]
    CacheSerializationError(#[from] rmp_serde::encode::Error),
}
/// Record stored in the database for each cache entry: the cached value plus
/// the metadata used to decide whether it has expired.
#[derive(serde::Serialize, serde::Deserialize)]
struct CachedDiskValue<V> {
    /// The cached value itself.
    pub(crate) value: V,
    /// When the entry was created, or last refreshed by a read.
    pub(crate) created_at: SystemTime,
    /// Version number written with the entry; `new` always sets it to `1`.
    pub(crate) version: u64,
}
impl<V> CachedDiskValue<V> {
fn new(value: V) -> Self {
Self {
value,
created_at: SystemTime::now(),
version: 1,
}
}
fn refresh_created_at(&mut self) {
self.created_at = SystemTime::now();
}
}
impl<K, V> IOCached<K, V> for DiskCache<K, V>
where
K: ToString,
V: Serialize + DeserializeOwned,
{
type Error = DiskCacheError;
fn cache_get(&self, key: &K) -> Result<Option<V>, DiskCacheError> {
let key = key.to_string();
let seconds = self.seconds;
let refresh = self.refresh;
let mut cache_updated = false;
let update = |old: Option<&[u8]>| -> Option<Vec<u8>> {
let old = old?;
if seconds.is_none() {
return Some(old.to_vec());
}
let seconds = seconds.unwrap();
let mut cached = match rmp_serde::from_slice::<CachedDiskValue<V>>(old) {
Ok(cached) => cached,
Err(_) => {
return None;
}
};
if SystemTime::now()
.duration_since(cached.created_at)
.unwrap_or(Duration::from_secs(0))
< Duration::from_secs(seconds)
{
if refresh {
cached.refresh_created_at();
cache_updated = true;
}
let cache_val =
rmp_serde::to_vec(&cached).expect("error serializing cached disk value");
Some(cache_val)
} else {
None
}
};
let result = if let Some(data) = self.connection.update_and_fetch(key, update)? {
let cached = rmp_serde::from_slice::<CachedDiskValue<V>>(&data)?;
Ok(Some(cached.value))
} else {
Ok(None)
};
if cache_updated && self.sync_to_disk_on_cache_change {
self.connection.flush()?;
}
result
}
fn cache_set(&self, key: K, value: V) -> Result<Option<V>, DiskCacheError> {
let key = key.to_string();
let value = rmp_serde::to_vec(&CachedDiskValue::new(value))?;
let result = if let Some(data) = self.connection.insert(key, value)? {
let cached = rmp_serde::from_slice::<CachedDiskValue<V>>(&data)?;
if let Some(lifetime_seconds) = self.seconds {
if SystemTime::now()
.duration_since(cached.created_at)
.unwrap_or(Duration::from_secs(0))
< Duration::from_secs(lifetime_seconds)
{
Ok(Some(cached.value))
} else {
Ok(None)
}
} else {
Ok(Some(cached.value))
}
} else {
Ok(None)
};
if self.sync_to_disk_on_cache_change {
self.connection.flush()?;
}
result
}
fn cache_remove(&self, key: &K) -> Result<Option<V>, DiskCacheError> {
let key = key.to_string();
let result = if let Some(data) = self.connection.remove(key)? {
let cached = rmp_serde::from_slice::<CachedDiskValue<V>>(&data)?;
if let Some(lifetime_seconds) = self.seconds {
if SystemTime::now()
.duration_since(cached.created_at)
.unwrap_or(Duration::from_secs(0))
< Duration::from_secs(lifetime_seconds)
{
Ok(Some(cached.value))
} else {
Ok(None)
}
} else {
Ok(Some(cached.value))
}
} else {
Ok(None)
};
if self.sync_to_disk_on_cache_change {
self.connection.flush()?;
}
result
}
fn cache_lifespan(&self) -> Option<u64> {
self.seconds
}
fn cache_set_lifespan(&mut self, seconds: u64) -> Option<u64> {
let old = self.seconds;
self.seconds = Some(seconds);
old
}
fn cache_set_refresh(&mut self, refresh: bool) -> bool {
let old = self.refresh;
self.refresh = refresh;
old
}
fn cache_unset_lifespan(&mut self) -> Option<u64> {
self.seconds.take()
}
}
#[cfg(test)]
#[allow(non_snake_case)]
mod test_DiskCache {
    use googletest::{
        assert_that,
        matchers::{anything, eq, none, ok, some},
        GoogleTestSupport as _,
    };
    use std::thread::sleep;
    use std::time::Duration;
    use tempfile::TempDir;

    use super::*;

    /// Creates a temporary directory.
    ///
    /// `temp_dir!()` yields an existing, empty directory. `temp_dir!(no_exist)`
    /// yields a `TempDir` whose directory has already been removed, for callers
    /// that need a fresh path that does not exist yet.
    macro_rules! temp_dir {
        () => {
            TempDir::new().expect("Error creating temp dir")
        };
        (no_exist) => {{
            let tmp_dir = TempDir::new().expect("Error creating temp dir");
            std::fs::remove_dir_all(tmp_dir.path()).expect("error emptying the tmp dir");
            tmp_dir
        }};
    }

    /// Milliseconds since the Unix epoch; used to make cache names unique.
    fn now_millis() -> u128 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis()
    }

    const TEST_KEY: u32 = 1;
    const TEST_VAL: u32 = 100;
    const TEST_KEY_1: u32 = 2;
    const TEST_VAL_1: u32 = 200;
    const LIFE_SPAN_2_SECS: u64 = 2;
    const LIFE_SPAN_1_SEC: u64 = 1;

    /// Walks a key through get / set / overwrite / remove on a cache without a lifespan.
    #[googletest::test]
    fn cache_get_after_cache_remove_returns_none() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .build()
            .unwrap();

        let cached = cache.cache_get(&TEST_KEY).unwrap();
        assert_that!(
            cached,
            none(),
            "Getting a non-existent key-value should return None"
        );

        let cached = cache.cache_set(TEST_KEY, TEST_VAL).unwrap();
        assert_that!(cached, none(), "Setting a new key-value should return None");

        let cached = cache.cache_set(TEST_KEY, TEST_VAL_1).unwrap();
        assert_that!(
            cached,
            some(eq(TEST_VAL)),
            "Setting an existing key-value should return the old value"
        );

        let cached = cache.cache_get(&TEST_KEY).unwrap();
        assert_that!(
            cached,
            some(eq(TEST_VAL_1)),
            "Getting an existing key-value should return the value"
        );

        let cached = cache.cache_remove(&TEST_KEY).unwrap();
        assert_that!(
            cached,
            some(eq(TEST_VAL_1)),
            "Removing an existing key-value should return the value"
        );

        let cached = cache.cache_get(&TEST_KEY).unwrap();
        assert_that!(cached, none(), "Getting a removed key should return None");

        drop(cache);
    }

    /// An entry is readable before its lifespan elapses and gone afterwards.
    #[googletest::test]
    fn values_expire_when_lifespan_elapses_returning_none() {
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_lifespan(LIFE_SPAN_2_SECS)
            .build()
            .unwrap();

        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting a non-existent key-value should return None"
        );

        assert_that!(
            cache.cache_set(TEST_KEY, 100),
            ok(none()),
            "Setting a new key-value should return None"
        );
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(anything())),
            "Getting an existing key-value before it expires should return the value"
        );

        sleep(Duration::from_secs(LIFE_SPAN_2_SECS));
        // A little extra on top of the lifespan, presumably as a safety margin.
        sleep(Duration::from_micros(500));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );
    }

    /// Changing the lifespan on a live cache affects subsequent expiry checks.
    #[googletest::test]
    fn set_lifespan_to_a_different_lifespan_is_respected() {
        let tmp_dir = temp_dir!();
        let mut cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_lifespan(LIFE_SPAN_2_SECS)
            .build()
            .unwrap();

        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting a non-existent key-value should return None"
        );

        assert_that!(
            cache.cache_set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a new key-value should return None"
        );

        sleep(Duration::from_secs(LIFE_SPAN_2_SECS));
        // A little extra on top of the lifespan, presumably as a safety margin.
        sleep(Duration::from_micros(500));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );

        let old_from_setting_lifespan = cache
            .cache_set_lifespan(LIFE_SPAN_1_SEC)
            .expect("error setting new lifespan");
        assert_that!(
            old_from_setting_lifespan,
            eq(LIFE_SPAN_2_SECS),
            "Setting lifespan should return the old lifespan"
        );

        assert_that!(
            cache.cache_set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a previously expired key-value should return None"
        );
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(eq(TEST_VAL))),
            "Getting a newly set (previously expired) key-value should return the value"
        );

        sleep(Duration::from_secs(LIFE_SPAN_1_SEC));
        // A little extra on top of the lifespan, presumably as a safety margin.
        sleep(Duration::from_micros(500));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting an expired key-value should return None"
        );

        cache
            .cache_set_lifespan(10)
            .expect("error setting lifespan");
        assert_that!(
            cache.cache_set(TEST_KEY, TEST_VAL),
            ok(none()),
            "Setting a previously expired key-value should return None"
        );

        assert_that!(
            cache.cache_set(TEST_KEY_1, TEST_VAL),
            ok(none()),
            "Setting a new, separate, key-value should return None"
        );
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(eq(TEST_VAL))),
            "Getting a newly set (previously expired) key-value should return the value"
        );
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(eq(TEST_VAL))),
            "Getting the same value again should return the value"
        );
    }

    /// With refresh enabled, each successful read restarts the entry's lifespan.
    #[googletest::test]
    fn refreshing_on_cache_get_delays_cache_expiry() {
        const LIFE_SPAN: u64 = 2;
        const HALF_LIFE_SPAN: u64 = 1;
        let tmp_dir = temp_dir!();
        let cache: DiskCache<u32, u32> = DiskCache::new("test-cache")
            .set_disk_directory(tmp_dir.path())
            .set_lifespan(LIFE_SPAN)
            .set_refresh(true)
            .build()
            .unwrap();

        assert_that!(cache.cache_set(TEST_KEY, TEST_VAL), ok(none()));

        // Retrieve before expiry; this read refreshes the entry.
        sleep(Duration::from_secs(HALF_LIFE_SPAN));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(eq(TEST_VAL))),
            "Getting a value before expiry should return the value"
        );

        // The original lifespan has now elapsed, but the refresh keeps the entry alive.
        sleep(Duration::from_secs(HALF_LIFE_SPAN));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(some(eq(TEST_VAL))),
            "Getting a value after the initial expiry should return the value as we have refreshed"
        );

        // A full lifespan without any read lets the entry expire.
        sleep(Duration::from_secs(LIFE_SPAN));
        assert_that!(
            cache.cache_get(&TEST_KEY),
            ok(none()),
            "Getting a value after the refreshed expiry should return None"
        );

        drop(cache);
    }

    /// Builds a cache without an explicit disk directory and checks basic use.
    #[googletest::test]
    fn does_not_break_when_constructed_using_default_disk_directory() {
        // The timestamp keeps the cache name unique between runs.
        let cache: DiskCache<u32, u32> =
            DiskCache::new(&format!("{}:disk-cache-test-default-dir", now_millis()))
                .build()
                .unwrap();

        let cached = cache.cache_get(&TEST_KEY).unwrap();
        assert_that!(
            cached,
            none(),
            "Getting a non-existent key-value should return None"
        );

        let cached = cache.cache_set(TEST_KEY, TEST_VAL).unwrap();
        assert_that!(cached, none(), "Setting a new key-value should return None");

        let cached = cache.cache_set(TEST_KEY, TEST_VAL_1).unwrap();
        assert_that!(
            cached,
            some(eq(TEST_VAL)),
            "Setting an existing key-value should return the old value"
        );

        // Clean up the database created under the default directory.
        std::fs::remove_dir_all(cache.disk_path).expect("error in clean up removeing the cache dir")
    }

    mod set_sync_to_disk_on_cache_change {

        mod when_no_auto_flushing {
            use super::super::*;

            /// Runs `run_on_original_cache` against a cache built with
            /// `flush_every_ms(None)`, copies its files to a new location
            /// without flushing, reopens the copy and runs
            /// `run_on_recovered_cache` against it.
            ///
            /// `set_sync_to_disk_on_cache_change` is the flag value under test.
            fn check_on_recovered_cache(
                set_sync_to_disk_on_cache_change: bool,
                run_on_original_cache: fn(&DiskCache<u32, u32>) -> (),
                run_on_recovered_cache: fn(&DiskCache<u32, u32>) -> (),
            ) {
                let original_cache_tmp_dir = temp_dir!();
                let copied_cache_tmp_dir = temp_dir!(no_exist);
                const CACHE_NAME: &str = "test-cache";

                let cache: DiskCache<u32, u32> = DiskCache::new(CACHE_NAME)
                    .set_disk_directory(original_cache_tmp_dir.path())
                    .set_sync_to_disk_on_cache_change(set_sync_to_disk_on_cache_change)
                    // No periodic flushing, so only explicit flushes reach the disk.
                    .set_connection_config(sled::Config::new().flush_every_ms(None))
                    .build()
                    .unwrap();

                // Flush the empty cache so the copied files form a recoverable database.
                cache
                    .connection
                    .flush()
                    .expect("error flushing cache before any cache setting");

                run_on_original_cache(&cache);

                let recovered_cache = clone_cache_to_new_location_no_flushing(
                    CACHE_NAME,
                    &cache,
                    copied_cache_tmp_dir.path(),
                );

                assert_that!(recovered_cache.connection.was_recovered(), eq(true));

                run_on_recovered_cache(&recovered_cache);
            }

            /// With the flag off, changes that were not flushed explicitly are
            /// absent from the recovered copy.
            mod changes_do_not_persist_after_recovery_when_set_to_false {
                use super::*;

                #[googletest::test]
                fn for_cache_set() {
                    check_on_recovered_cache(
                        false,
                        |cache| {
                            cache
                                .cache_set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.cache_get(&TEST_KEY),
                                ok(none()),
                                "set_sync_to_disk_on_cache_change is false, and there is no auto-flushing, so the cache should not have persisted"
                            );
                        },
                    )
                }

                #[googletest::test]
                fn for_cache_remove() {
                    check_on_recovered_cache(
                        false,
                        |cache| {
                            cache
                                .cache_set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");

                            // Persist the set manually so only the remove is left unflushed.
                            cache.connection.flush().expect("error flushing cache");

                            cache
                                .cache_remove(&TEST_KEY)
                                .expect("error removing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.cache_get(&TEST_KEY),
                                ok(some(eq(TEST_VAL))),
                                "set_sync_to_disk_on_cache_change is false, and there is no auto-flushing, so the cache_remove should not have persisted"
                            );
                        },
                    )
                }

                #[ignore = "Not implemented"]
                #[googletest::test]
                fn for_cache_get_when_refreshing() {
                    todo!("Test not implemented.")
                }
            }

            /// With the flag on, every change is flushed and therefore visible
            /// in the recovered copy.
            mod changes_persist_after_recovery_when_set_to_true {
                use super::*;

                #[googletest::test]
                fn for_cache_set() {
                    check_on_recovered_cache(
                        true,
                        |cache| {
                            cache
                                .cache_set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.cache_get(&TEST_KEY),
                                ok(some(eq(TEST_VAL))),
                                "Getting a set key should return the value"
                            );
                        },
                    )
                }

                #[googletest::test]
                fn for_cache_remove() {
                    check_on_recovered_cache(
                        true,
                        |cache| {
                            cache
                                .cache_set(TEST_KEY, TEST_VAL)
                                .expect("error setting cache in assemble stage");

                            cache
                                .cache_remove(&TEST_KEY)
                                .expect("error removing cache in assemble stage");
                        },
                        |recovered_cache| {
                            assert_that!(
                                recovered_cache.cache_get(&TEST_KEY),
                                ok(none()),
                                "Getting a removed key should return None"
                            );
                        },
                    )
                }

                #[ignore = "Not implemented"]
                #[googletest::test]
                fn for_cache_get_when_refreshing() {
                    todo!("Test not implemented.")
                }
            }

            /// Copies the directory containing `cache`'s database to
            /// `new_location` without flushing first, then opens a cache named
            /// `cache_name` on the copy.
            fn clone_cache_to_new_location_no_flushing(
                cache_name: &str,
                cache: &DiskCache<u32, u32>,
                new_location: &Path,
            ) -> DiskCache<u32, u32> {
                // `disk_path` is `<disk dir>/<name>_v<version>`; copying its parent
                // keeps that layout under `new_location`.
                copy_dir::copy_dir(cache.disk_path.parent().unwrap(), new_location)
                    .expect("error copying cache files to new location");

                DiskCache::new(cache_name)
                    .set_disk_directory(new_location)
                    .build()
                    .expect("error building cache from copied files")
            }
        }
    }
}