use core::time::Duration;
use core::{fmt::Display, hash::Hash};
use std::{
io::Read,
path::{Path, PathBuf},
};
use alloc::vec::Vec;
use hashbrown::HashMap;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::cache_file::CacheFile;
/// Key-value cache that mirrors the entries of a backing [`CacheFile`] in memory.
///
/// Entries are stored in the file as JSON documents, each followed by `separator`.
#[derive(Debug)]
pub struct Cache<K, V> {
// Entries known to this instance: those parsed from the file content plus those
// inserted through this instance.
in_memory_cache: HashMap<K, V>,
// Handle on the backing file; locked before reading/writing and unlocked afterwards.
file: CacheFile,
// Byte sequence appended after every written entry and searched for when splitting
// the file content back into entries.
separator: Vec<u8>,
}
/// Options used to create a [`Cache`].
///
/// Every field is optional; unset fields are replaced by defaults when the options
/// are resolved.
#[derive(Default)]
pub struct CacheOption {
// Bytes separating entries in the cache file; defaults to `b"\n"`.
separator: Option<Vec<u8>>,
// Version path segment; defaults to `CARGO_PKG_VERSION` of the crate being compiled.
version: Option<String>,
// Name path segment; defaults to `"cubecl"`.
name: Option<String>,
// Root directory of the cache; defaults to `<home directory>/.cache`.
root: Option<PathBuf>,
// Duration forwarded to `CacheFile::new`; defaults to 30 seconds.
lock_max_duration: Option<Duration>,
}
/// Error returned by [`Cache::insert`] when a key is already bound to a different value.
#[derive(Debug)]
pub enum CacheError<K: Serialize, V: Serialize> {
/// The in-memory cache already holds `key` with a value that differs from the one
/// being inserted.
///
/// `value_previous` is the value already cached, `value_updated` the rejected one.
#[allow(missing_docs)]
DuplicatedKey {
key: K,
value_previous: V,
value_updated: V,
},
/// An entry parsed from the cache file content holds `key` with a value that differs
/// from the one being inserted.
///
/// `value_previous` is the value found in the file content, `value_updated` the
/// rejected one.
#[allow(missing_docs)]
KeyOutOfSync {
key: K,
value_previous: V,
value_updated: V,
},
}
impl CacheOption {
    /// Sets the byte sequence that separates entries in the cache file.
    ///
    /// Defaults to `b"\n"` when left unset.
    pub fn separator<S: Into<Vec<u8>>>(mut self, separator: S) -> Self {
        self.separator = Some(separator.into());
        self
    }

    /// Sets the version segment of the cache file path.
    ///
    /// Defaults to `CARGO_PKG_VERSION` when left unset.
    pub fn version<V: Into<String>>(mut self, version: V) -> Self {
        self.version = Some(version.into());
        self
    }

    /// Sets the name segment of the cache file path.
    ///
    /// Defaults to `"cubecl"` when left unset.
    pub fn name<R: Into<String>>(mut self, name: R) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the root directory under which the cache file is stored.
    ///
    /// Defaults to `<home directory>/.cache` when left unset.
    pub fn root<R: Into<PathBuf>>(mut self, path: R) -> Self {
        self.root = Some(path.into());
        self
    }

    /// Sets the duration handed to the cache file as its maximum lock duration.
    ///
    /// Defaults to 30 seconds when left unset.
    pub fn lock_max_duration(mut self, duration: Duration) -> Self {
        self.lock_max_duration = Some(duration);
        self
    }

    /// Replaces every unset option by its default.
    ///
    /// Returns `(separator, name, version, root, lock_max_duration)`; note that the
    /// name comes *before* the version even though both are `String`s.
    ///
    /// # Panics
    ///
    /// Panics when no root was configured and no home directory can be determined.
    fn resolve(self) -> (Vec<u8>, String, String, PathBuf, Duration) {
        let Self {
            separator,
            version,
            name,
            root,
            lock_max_duration,
        } = self;

        let separator = separator.unwrap_or_else(|| b"\n".to_vec());
        let version = version.unwrap_or_else(|| std::env!("CARGO_PKG_VERSION").to_string());
        let name = name.unwrap_or_else(|| "cubecl".to_string());
        let lock_max_duration = lock_max_duration.unwrap_or_else(|| Duration::from_secs(30));
        let root = root.unwrap_or_else(|| {
            dirs::home_dir()
                .expect("An home directory should exist")
                .join(".cache")
        });

        (separator, name, version, root, lock_max_duration)
    }
}
/// Bounds required from the key type of a [`Cache`]: serializable in both directions,
/// comparable, hashable and clonable.
pub trait CacheKey: Serialize + DeserializeOwned + PartialEq + Eq + Hash + Clone {}
/// Bounds required from the value type of a [`Cache`]: serializable in both directions,
/// comparable and clonable.
pub trait CacheValue: Serialize + DeserializeOwned + PartialEq + Eq + Clone {}
// Blanket implementations: any type satisfying the bounds is automatically a key/value.
impl<T: Serialize + DeserializeOwned + PartialEq + Eq + Clone + Hash> CacheKey for T {}
impl<T: Serialize + DeserializeOwned + PartialEq + Eq + Clone> CacheValue for T {}
impl<K: CacheKey, V: CacheValue> Cache<K, V> {
    /// Creates a cache backed by the file derived from `path` and `option`, and loads
    /// the entries already present in that file.
    ///
    /// An empty separator cannot delimit entries, so it is replaced by `b"\n"`.
    ///
    /// # Panics
    ///
    /// Panics if the cache content cannot be read, or if the options cannot be
    /// resolved (see [`CacheOption`]).
    pub fn new<P: AsRef<Path>>(path: P, option: CacheOption) -> Self {
        let (separator, name, version, root, lock_max_duration) = option.resolve();
        // `slice::windows(0)` panics and entries written without a delimiter could
        // never be split again, so never keep an empty separator.
        let separator = if separator.is_empty() {
            b"\n".to_vec()
        } else {
            separator
        };
        let path = get_persistent_cache_file_path(path, root, name, version);

        let mut this = Self {
            in_memory_cache: HashMap::new(),
            file: CacheFile::new(&path, lock_max_duration),
            separator,
        };

        // Nothing is being inserted, so the synchronization cannot report a conflict.
        this.sync_with_file(None).ok();
        this.file.unlock();
        this
    }

    /// Synchronizes with the cache file, then calls `func` on every cached entry.
    ///
    /// The file is unlocked only after `func` has been called on all entries.
    ///
    /// # Panics
    ///
    /// Panics if the cache content cannot be read.
    pub fn for_each<F: FnMut(&K, &V)>(&mut self, mut func: F) {
        self.sync_with_file(None).ok();
        for (key, value) in self.in_memory_cache.iter() {
            func(key, value)
        }
        self.file.unlock();
    }

    /// Returns the value cached in memory for `key`, without touching the file.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.in_memory_cache.get(key)
    }

    /// Returns the number of entries cached in memory.
    pub fn len(&self) -> usize {
        self.in_memory_cache.len()
    }

    /// Returns `true` when no entry is cached in memory.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Inserts `key => value` in the cache and appends it to the cache file.
    ///
    /// Inserting a key that is already bound to an equal value is a no-op.
    ///
    /// # Errors
    ///
    /// * [`CacheError::KeyOutOfSync`] when the content read from the file binds `key`
    ///   to a different value.
    /// * [`CacheError::DuplicatedKey`] when the in-memory cache binds `key` to a
    ///   different value.
    ///
    /// # Panics
    ///
    /// Panics if the cache content cannot be read or the entry cannot be serialized.
    pub fn insert(&mut self, key: K, value: V) -> Result<(), CacheError<K, V>> {
        let result = self.insert_locked(key, value);
        // Single release point: every outcome of the insertion unlocks the file.
        self.file.unlock();
        result
    }

    /// Performs the insertion itself; the caller is responsible for unlocking the file.
    fn insert_locked(&mut self, key: K, value: V) -> Result<(), CacheError<K, V>> {
        self.sync_with_file(Some((&key, &value)))?;

        match self.in_memory_cache.get(&key) {
            Some(existing) if existing == &value => Ok(()),
            Some(existing) => Err(CacheError::DuplicatedKey {
                key,
                value_previous: existing.clone(),
                value_updated: value,
            }),
            None => {
                self.insert_unchecked(key, value);
                Ok(())
            }
        }
    }

    /// Locks the file and, when the lock yields a reader, reads its whole content and
    /// merges it into the in-memory cache. The file is left locked.
    fn sync_with_file(&mut self, new_insert: Option<(&K, &V)>) -> Result<(), CacheError<K, V>> {
        if let Some(mut reader) = self.file.lock() {
            let mut buffer = Vec::new();
            reader
                .read_to_end(&mut buffer)
                .expect("Can read the cache content");
            return self.sync_content(&buffer, new_insert);
        }
        Ok(())
    }

    /// Parses every separator-terminated entry of `bytes` into the in-memory cache.
    ///
    /// Entries that fail to deserialize are logged and skipped; bytes after the last
    /// separator are ignored. When `new_insert` is given, the first parsed entry with
    /// the same key but a different value is reported as
    /// [`CacheError::KeyOutOfSync`]; parsing still continues to the end of `bytes`.
    fn sync_content(
        &mut self,
        bytes: &[u8],
        new_insert: Option<(&K, &V)>,
    ) -> Result<(), CacheError<K, V>> {
        // Invariant established by `new`: the separator is never empty.
        let separator_len = self.separator.len();
        let mut start = 0;
        let mut result = Ok(());

        while let Some(pos) = bytes[start..]
            .windows(separator_len)
            .position(|window| window == self.separator.as_slice())
        {
            // `pos` is relative to `start`.
            let end = start + pos;

            match serde_json::from_slice::<Entry<K, V>>(&bytes[start..end]) {
                Ok(entry) => {
                    if let Some((key, value)) = new_insert
                        && result.is_ok()
                        && key == &entry.key
                        && value != &entry.value
                    {
                        result = Err(CacheError::KeyOutOfSync {
                            key: entry.key.clone(),
                            value_previous: entry.value.clone(),
                            value_updated: value.clone(),
                        })
                    }
                    self.in_memory_cache.insert(entry.key, entry.value);
                }
                Err(err) => {
                    log::warn!(
                        "Corrupted cache file {}, ignoring entry ({}..{}) : {err}",
                        self.file,
                        start,
                        end,
                    );
                }
            };

            start = end + separator_len;
        }

        result
    }

    /// Appends the entry, followed by the separator, to the file and records it in
    /// memory, without checking for an existing binding.
    fn insert_unchecked(&mut self, key: K, value: V) {
        let entry = Entry { key, value };
        let mut bytes = serde_json::to_vec(&entry).expect("Can serialize data");
        bytes.extend_from_slice(&self.separator);

        self.file.write(&bytes);
        self.in_memory_cache.insert(entry.key, entry.value);
    }
}
impl<K: CacheKey, V: CacheValue> Display for Cache<K, V> {
    /// Writes the cache file's own display form, then one line per in-memory entry
    /// with the key and value rendered as pretty-printed JSON.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}", self.file)?;

        self.in_memory_cache.iter().try_for_each(|(k, v)| {
            let key = serde_json::to_string_pretty(k).unwrap();
            let value = serde_json::to_string_pretty(v).unwrap();
            writeln!(f, " [{key}] => {value}")
        })
    }
}
/// Builds the path of a cache file: `root/<name>/<version>/<path_partial>`, with every
/// segment sanitized and the `json.log` extension ensured on the last one.
///
/// * `path_partial` - Relative location of the cache; `/` segments are skipped.
/// * `root` - Directory under which all caches are stored.
/// * `name` / `version` - First two path segments under `root`.
///
/// When the last segment of `path_partial` is `json.log` or already ends with
/// `.json.log` the path is left as is; otherwise its extension (if any) is replaced by
/// `json.log`.
fn get_persistent_cache_file_path<P: AsRef<Path>>(
    path_partial: P,
    root: PathBuf,
    name: String,
    version: String,
) -> PathBuf {
    const EXTENSION: &str = "json.log";
    const DOTTED_EXTENSION: &str = ".json.log";

    let path_partial: &Path = path_partial.as_ref();

    // `Path::ends_with` only matches whole components, so it would never recognize
    // `cache.json.log` and the extension would be appended a second time. Look at the
    // text of the last component instead.
    let has_extension = path_partial.file_name().is_some_and(|last| {
        let last = last.to_string_lossy();
        last == EXTENSION || last.ends_with(DOTTED_EXTENSION)
    });

    let mut path = root
        .join(sanitize_path_segment(&name))
        .join(sanitize_path_segment(&version));

    for segment in path_partial.iter() {
        // Skip the root segment so an absolute partial path stays under `root`.
        if segment == "/" {
            continue;
        }
        // Lossy conversion: a non UTF-8 segment is sanitized rather than causing a panic.
        path.push(sanitize_path_segment(&segment.to_string_lossy()));
    }

    if !has_extension {
        path.set_extension(EXTENSION);
    }

    path
}
/// One key/value record as stored in the cache file (serialized as JSON).
#[derive(Serialize, Deserialize)]
struct Entry<K, V> {
// Key of the record.
key: K,
// Value bound to the key.
value: V,
}
impl<K: Serialize, V: Serialize> core::fmt::Debug for Entry<K, V> {
    /// Renders the entry as pretty-printed JSON.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let json = serde_json::to_string_pretty(self).unwrap();
        f.write_str(&json)
    }
}
/// Sanitizes a single path segment with `sanitize_filename`, using `_` as the
/// replacement string and the crate's defaults for every other option.
pub(crate) fn sanitize_path_segment(segment: &str) -> String {
    let options = sanitize_filename::Options {
        replacement: "_",
        ..Default::default()
    };

    sanitize_filename::sanitize_with_options(segment, options)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct keys can be inserted, re-inserting a key with another value is
    /// rejected, and the original values stay readable.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_cache_simple() {
        // Use a private root instead of the home directory so the test neither depends
        // on leftovers from earlier runs nor requires a home directory to exist.
        let root = std::env::temp_dir()
            .join("cubecl-cache-test")
            .join(std::process::id().to_string());
        // Best effort: the directory usually does not exist yet.
        let _ = std::fs::remove_dir_all(&root);

        let key1 = || "key1".to_string();
        let key2 = || "key2".to_string();
        let value1 = || "value1".to_string();
        let value2 = || "value2".to_string();

        let mut cache =
            Cache::<String, String>::new("test", CacheOption::default().root(root.clone()));

        cache.insert(key1(), value1()).unwrap();
        cache.insert(key2(), value2()).unwrap();

        let result = cache.insert(key1(), value2());
        assert!(
            result.is_err(),
            "Can't reinsert the same key with a different value."
        );

        assert_eq!(cache.len(), 2);

        let value1_actual = cache.get(&key1()).unwrap();
        assert_eq!(value1_actual, &value1());

        let value2_actual = cache.get(&key2()).unwrap();
        assert_eq!(value2_actual, &value2());

        // Best-effort cleanup of the files created by this test.
        drop(cache);
        let _ = std::fs::remove_dir_all(&root);
    }
}