use async_trait::async_trait;
use std::{
borrow::Cow,
path::{Path, PathBuf},
};
use crate::{
traits::{CacheKey, CacheStrategy, RecoverableStrategy},
CacheCapacity, DiskUtil, Result,
};
/// Label placed in `Error::LimitExceeded::limit_kind` when a `put` would push
/// the stored byte total above the configured byte limit.
const LIMIT_KIND_BYTE: &str = "Stored bytes";
/// Label placed in `Error::LimitExceeded::limit_kind` when a `put` would push
/// the stored entry total above the configured entry limit.
const LIMIT_KIND_ENTRY: &str = "Stored entries";
/// Handle to one value that the [`Disk`] strategy has stored as a file.
///
/// Instances are produced by `Disk::put` and `Disk::recover` and are handed
/// back to the strategy to read or delete the underlying file.
#[derive(Debug)]
pub struct Entry {
// Full path of the file that holds the value.
path: PathBuf,
// Length of the stored value in bytes; passed to `DiskUtil::read` when the
// value is loaded and subtracted from the byte counter when it is deleted.
byte_len: usize,
}
/// Cache strategy that stores every value as a separate file inside a
/// directory, with optional caps on total bytes and on number of entries.
#[derive(Debug)]
pub struct Disk {
// Directory the value files are written into.
cache_dir: PathBuf,
// Maximum total number of stored bytes; `None` disables the check.
byte_limit: Option<usize>,
// Maximum number of stored entries; `None` disables the check.
entry_limit: Option<usize>,
// Running total of bytes accounted for by `put`, `delete` and `recover`.
current_byte_count: usize,
// Running total of entries accounted for by `put`, `delete` and `recover`.
current_entry_count: usize,
}
impl Disk {
pub fn new<'a>(
cache_dir: impl Into<Cow<'a, Path>>,
byte_limit: Option<usize>,
entry_limit: Option<usize>,
) -> Self {
Self {
cache_dir: cache_dir.into().into_owned(),
byte_limit,
entry_limit,
..Default::default()
}
}
}
impl Default for Disk {
    /// Returns a strategy rooted at the relative directory `cache`, with no
    /// byte limit, no entry limit and both usage counters at zero.
    fn default() -> Self {
        let cache_dir: PathBuf = "cache".into();
        Self {
            cache_dir,
            byte_limit: None,
            entry_limit: None,
            current_byte_count: 0,
            current_entry_count: 0,
        }
    }
}
#[async_trait]
impl CacheStrategy for Disk {
    type CacheEntry = Entry;

    /// Prepares the strategy by creating the cache directory through
    /// `DiskUtil::create_dir`.
    async fn setup(&mut self) -> Result<()> {
        DiskUtil::create_dir(&self.cache_dir).await
    }

    /// Writes `value` to the file `cache_dir.join(key.to_key())` and returns
    /// the [`Entry`] describing it.
    ///
    /// # Errors
    ///
    /// * `Error::LimitExceeded` with the "Stored bytes" label when the new
    ///   byte total would exceed `byte_limit`.
    /// * `Error::LimitExceeded` with the "Stored entries" label when the new
    ///   entry total would exceed `entry_limit`.
    /// * Any error returned by `DiskUtil::write`.
    ///
    /// The usage counters are only updated after the write succeeded, so a
    /// failed `put` leaves the accounting untouched.
    async fn put<'a, K, V>(&mut self, key: &K, value: V) -> Result<Self::CacheEntry>
    where
        K: CacheKey + Sync + Send,
        V: Into<Cow<'a, [u8]>> + Send,
    {
        let value = value.into();
        let byte_len = value.as_ref().len();

        if let Some(byte_limit) = self.byte_limit {
            // `checked_add` keeps the check sound even if the sum does not fit
            // in `usize`: an overflowing total is above every possible limit.
            let fits = matches!(
                self.current_byte_count.checked_add(byte_len),
                Some(total) if total <= byte_limit
            );
            if !fits {
                return Err(crate::Error::LimitExceeded {
                    limit_kind: LIMIT_KIND_BYTE.into(),
                });
            }
        }

        if let Some(entry_limit) = self.entry_limit {
            // Same condition as `count + 1 > limit`, without the addition
            // that could overflow.
            if self.current_entry_count >= entry_limit {
                return Err(crate::Error::LimitExceeded {
                    limit_kind: LIMIT_KIND_ENTRY.into(),
                });
            }
        }

        let path = self.cache_dir.join(key.to_key());
        DiskUtil::write(&path, value.as_ref()).await?;

        // Saturate instead of wrapping so that an (unlimited) cache can never
        // report a small usage after an overflow.
        self.current_byte_count = self.current_byte_count.saturating_add(byte_len);
        self.current_entry_count = self.current_entry_count.saturating_add(1);

        Ok(Entry { path, byte_len })
    }

    /// Reads the value described by `entry` from disk and returns it as an
    /// owned buffer. The entry's recorded length is passed to
    /// `DiskUtil::read`.
    async fn get<'a>(&self, entry: &'a Self::CacheEntry) -> Result<Cow<'a, [u8]>> {
        DiskUtil::read(&entry.path, Some(entry.byte_len))
            .await
            .map(Cow::Owned)
    }

    /// Reads the value described by `entry`, then deletes it from disk and
    /// returns the data.
    ///
    /// If the read fails nothing is deleted; if the delete fails the error is
    /// returned and the data that was read is dropped.
    async fn take(&mut self, entry: Self::CacheEntry) -> Result<Vec<u8>> {
        let data = DiskUtil::read(&entry.path, Some(entry.byte_len)).await?;
        self.delete(entry).await?;
        Ok(data)
    }

    /// Deletes the file behind `entry` and releases its share of the usage
    /// counters.
    ///
    /// The counters are only touched after `DiskUtil::delete` succeeded.
    async fn delete(&mut self, entry: Self::CacheEntry) -> Result<()> {
        DiskUtil::delete(&entry.path).await?;

        // Saturating subtraction: a plain `-=` would wrap to a huge value in
        // release builds if the counters ever got out of sync with the entry,
        // which would make every later `put` fail the limit checks.
        self.current_byte_count = self.current_byte_count.saturating_sub(entry.byte_len);
        self.current_entry_count = self.current_entry_count.saturating_sub(1);

        Ok(())
    }

    /// Returns a `CacheCapacity` built from the byte limit and the current
    /// byte count, or `None` when no byte limit is configured.
    fn get_cache_capacity(&self) -> Option<CacheCapacity> {
        self.byte_limit
            .map(|byte_limit| CacheCapacity::new(byte_limit, self.current_byte_count))
    }
}
#[async_trait]
impl RecoverableStrategy for Disk {
    /// Rebuilds cache entries from the files already present in the cache
    /// directory.
    ///
    /// Every non-directory item directly inside `cache_dir` is offered to
    /// `recover_key` by file name:
    ///
    /// * If a key is returned, the file is read through `DiskUtil::read` to
    ///   learn its length, the usage counters are increased accordingly and
    ///   the `(key, Entry)` pair is added to the result.
    /// * If no key is returned, or the file name is not valid UTF-8, the file
    ///   is moved into the `lost+found` sub-directory on a best-effort basis
    ///   (a failed move is ignored).
    ///
    /// Directory items that cannot be listed are skipped silently.
    ///
    /// # Errors
    ///
    /// Fails if `lost+found` cannot be created, if the cache directory cannot
    /// be listed, or if reading a recognised file fails. Entries counted
    /// before such a failure stay counted.
    async fn recover<K, F>(&mut self, mut recover_key: F) -> Result<Vec<(K, Self::CacheEntry)>>
    where
        K: Send,
        F: Fn(&str) -> Option<K> + Send,
    {
        let quarantine_dir = self.cache_dir.join("lost+found");
        std::fs::create_dir_all(&quarantine_dir)?;

        // Best effort: files without a usable name, or that cannot be moved,
        // are simply left where they are.
        let quarantine = |orphan: &Path| {
            if let Some(name) = orphan.file_name() {
                _ = std::fs::rename(orphan, quarantine_dir.join(name));
            }
        };

        let mut recovered = Vec::new();
        for dir_entry in std::fs::read_dir(&self.cache_dir)?.filter_map(Result::ok) {
            let path = dir_entry.path();
            if path.is_dir() {
                continue;
            }

            let key = match path
                .file_name()
                .and_then(|name| name.to_str())
                .and_then(&mut recover_key)
            {
                Some(key) => key,
                None => {
                    quarantine(&path);
                    continue;
                }
            };

            // The file is read in full only to learn how many bytes it holds.
            let byte_len = DiskUtil::read(&path, None).await?.len();
            self.current_byte_count += byte_len;
            self.current_entry_count += 1;
            recovered.push((key, Entry { path, byte_len }));
        }

        Ok(recovered)
    }
}
#[cfg(test)]
mod tests {
    use super::{Disk, LIMIT_KIND_BYTE, LIMIT_KIND_ENTRY};
    use crate::{async_test, utils::test::TempDir, Cache, Error, NO_COMPRESSION};

    async_test! {
        // Without limits, the counters must follow every put and delete.
        async fn test_default() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Disk::new(temp_dir.as_ref(), None, None), NO_COMPRESSION).await.unwrap();

            cache.put("foo", b"foo".to_vec()).await.unwrap();
            assert_eq!(cache.strategy().current_byte_count, 3);
            assert_eq!(cache.strategy().current_entry_count, 1);

            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.strategy().current_byte_count, 6);
            assert_eq!(cache.strategy().current_entry_count, 2);

            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            assert!(cache.get("baz").await.is_err());

            cache.delete("foo").await.unwrap();
            assert_eq!(cache.strategy().current_byte_count, 3);
            assert_eq!(cache.strategy().current_entry_count, 1);

            cache.delete("bar").await.unwrap();
            assert_eq!(cache.strategy().current_byte_count, 0);
            assert_eq!(cache.strategy().current_entry_count, 0);
        }

        // Two 3-byte values fill a 6-byte budget exactly; the third must be
        // rejected with the byte-limit label.
        async fn test_strategy_with_byte_limit() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Disk::new(temp_dir.as_ref(), Some(6), None), NO_COMPRESSION).await.unwrap();

            let foo_data = b"foo".to_vec();
            let bar_data = b"bar".to_vec();
            let baz_data = b"baz".to_vec();
            assert_eq!(foo_data.len(), 3);
            assert_eq!(bar_data.len(), 3);
            assert_eq!(baz_data.len(), 3);

            cache.put("foo", foo_data.clone()).await.unwrap();
            cache.put("bar", bar_data.clone()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), foo_data.as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), bar_data.as_slice());

            // A successful put must fail the test; otherwise the limit could
            // be ignored entirely without anyone noticing.
            match cache.put("baz", baz_data).await {
                Err(Error::LimitExceeded { limit_kind }) => {
                    assert_eq!(limit_kind, LIMIT_KIND_BYTE);
                }
                Err(err) => panic!("Unexpected error: {:?}", err),
                Ok(_) => panic!("Expected the byte limit to reject the third entry"),
            }
        }

        // With room for two entries, the third must be rejected with the
        // entry-limit label.
        async fn test_strategy_with_entry_limit() {
            let temp_dir = TempDir::new();
            // The limit is 2 so that the third put below actually hits it.
            let mut cache = Cache::new(Disk::new(temp_dir.as_ref(), None, Some(2)), NO_COMPRESSION).await.unwrap();

            cache.put("foo", b"foo".to_vec()).await.unwrap();
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());

            match cache.put("baz", b"baz".to_vec()).await {
                Err(Error::LimitExceeded { limit_kind }) => {
                    assert_eq!(limit_kind, LIMIT_KIND_ENTRY);
                }
                Err(err) => panic!("Unexpected error: {:?}", err),
                Ok(_) => panic!("Expected the entry limit to reject the third entry"),
            }
        }

        // A fresh cache over an already populated directory must pick up the
        // existing files and account for them.
        async fn test_recovery() {
            let temp_dir = TempDir::new();

            {
                let mut cache = Cache::new(Disk::new(temp_dir.as_ref(), None, None), NO_COMPRESSION).await.unwrap();
                cache.put("foo", b"foo".to_vec()).await.unwrap();
                cache.put("bar", b"bar".to_vec()).await.unwrap();
            }

            {
                let mut cache = Cache::new(Disk::new(temp_dir.as_ref(), None, None), NO_COMPRESSION).await.unwrap();
                let recovered_items = cache
                    .recover(|k| Some(k.to_string()))
                    .await
                    .expect("Failed to recover");
                assert_eq!(recovered_items, 2);
                assert_eq!(cache.strategy().current_byte_count, 6);
                assert_eq!(cache.strategy().current_entry_count, 2);
            }
        }
    }
}