use async_trait::async_trait;
use std::{
borrow::Cow,
path::{Path, PathBuf},
};
use crate::{
traits::{CacheKey, CacheStrategy, FlushableStrategy, RecoverableStrategy},
CacheCapacity, DiskUtil, Result,
};
/// `limit_kind` text used in `crate::Error::LimitExceeded` when the disk byte limit rejects a value.
const LIMIT_KIND_BYTE_DISK: &str = "Stored bytes on disk";
/// `limit_kind` text used in `crate::Error::LimitExceeded` when the disk entry-count limit rejects a value.
const LIMIT_KIND_ENTRY_DISK: &str = "Stored entries on disk";
/// Identifies which configured limit rejected a prospective entry.
enum LimitExceededKind {
    /// The tier's byte total would go above its byte limit.
    Bytes,
    /// The tier's entry count would go above its entry limit.
    Entries,
}

/// Result of checking a prospective entry against a `Limits` value.
enum LimitEvaluation {
    /// The entry fits within every checked limit.
    LimitSatisfied,
    /// The entry does not fit; carries the limit that was hit.
    LimitExceeded(LimitExceededKind),
}

impl LimitEvaluation {
    /// Returns `true` for `LimitSatisfied` and `false` for any `LimitExceeded`.
    fn is_satisfied(&self) -> bool {
        match self {
            Self::LimitSatisfied => true,
            Self::LimitExceeded(_) => false,
        }
    }
}
/// A cached value held in memory.
#[derive(Debug)]
pub struct MemoryEntry {
/// The stored bytes (owned copy of the value passed to `put`).
data: Vec<u8>,
/// Length of `data` in bytes; counted against the memory limits.
byte_len: usize,
}
/// A cached value stored as a file inside the cache directory.
#[derive(Debug)]
pub struct DiskEntry {
/// Location of the file holding the value.
path: PathBuf,
/// Size of the stored value in bytes; counted against the disk limits and passed to
/// `DiskUtil::read` when the file is read back (presumably as a size hint — confirm).
byte_len: usize,
}
/// Where a cached value lives: in memory or in a file on disk.
#[derive(Debug)]
pub enum Entry {
Memory(MemoryEntry),
Disk(DiskEntry),
}
/// Optional caps plus current usage counters for one storage tier (memory or disk).
///
/// A `None` cap means that dimension is unbounded.
#[derive(Debug, Default)]
pub struct Limits {
    /// Maximum total number of bytes the tier may hold, if any.
    byte_limit: Option<usize>,
    /// Maximum number of entries the tier may hold, if any.
    entry_limit: Option<usize>,
    /// Bytes currently accounted to the tier.
    current_byte_count: usize,
    /// Entries currently accounted to the tier.
    current_entry_count: usize,
}

impl Limits {
    /// Creates limits with the given optional byte and entry caps and zeroed usage counters.
    pub fn new(byte_limit: Option<usize>, entry_limit: Option<usize>) -> Self {
        Self {
            byte_limit,
            entry_limit,
            ..Default::default()
        }
    }

    /// Checks whether one more entry of `size` bytes fits within every configured limit.
    ///
    /// Both the byte cap and the entry cap are enforced when set. If both would be exceeded,
    /// the byte cap is reported. A sum that overflows `usize` counts as exceeding the cap.
    fn evaluate(&self, size: usize) -> LimitEvaluation {
        if Self::exceeds(self.current_byte_count, size, self.byte_limit) {
            return LimitEvaluation::LimitExceeded(LimitExceededKind::Bytes);
        }
        // Checked independently (not `else if`) so an entry cap still applies when a byte
        // cap is configured too.
        if Self::exceeds(self.current_entry_count, 1, self.entry_limit) {
            return LimitEvaluation::LimitExceeded(LimitExceededKind::Entries);
        }
        LimitEvaluation::LimitSatisfied
    }

    /// Returns `true` when `limit` is set and `current + extra` is above it or overflows.
    fn exceeds(current: usize, extra: usize, limit: Option<usize>) -> bool {
        match (limit, current.checked_add(extra)) {
            (None, _) => false,
            (Some(_), None) => true,
            (Some(limit), Some(total)) => total > limit,
        }
    }
}
/// Cache strategy that keeps values in memory while the memory limits allow it and
/// spills them to files under `cache_dir` otherwise.
#[derive(Debug)]
pub struct Hybrid {
    cache_dir: PathBuf,
    memory_limits: Limits,
    disk_limits: Limits,
}

impl Default for Hybrid {
    /// Uses a relative `cache` directory and unbounded memory and disk limits.
    fn default() -> Self {
        Self::new(Path::new("cache"), Limits::default(), Limits::default())
    }
}

impl Hybrid {
    /// Builds a strategy rooted at `cache_dir` with the given memory and disk limits.
    ///
    /// `cache_dir` may be borrowed or owned; it is always stored as an owned `PathBuf`.
    pub fn new<'a>(
        cache_dir: impl Into<Cow<'a, Path>>,
        memory_limits: Limits,
        disk_limits: Limits,
    ) -> Self {
        let owned_dir: PathBuf = cache_dir.into().into_owned();
        Self {
            cache_dir: owned_dir,
            memory_limits,
            disk_limits,
        }
    }
}
#[async_trait]
impl CacheStrategy for Hybrid {
type CacheEntry = Entry;
async fn setup(&mut self) -> Result<()> {
DiskUtil::create_dir(&self.cache_dir).await
}
async fn put<'a, K, V>(&mut self, key: &K, value: V) -> Result<Self::CacheEntry>
where
K: CacheKey + Sync + Send,
V: Into<Cow<'a, [u8]>> + Send,
{
let value = value.into();
let byte_len = value.as_ref().len();
let fits_into_memory = self.memory_limits.evaluate(byte_len);
let fits_into_disk = self.disk_limits.evaluate(byte_len);
if fits_into_memory.is_satisfied() {
self.memory_limits.current_byte_count += byte_len;
self.memory_limits.current_entry_count += 1;
Ok(Entry::Memory(MemoryEntry {
data: value.into_owned(),
byte_len,
}))
}
else if fits_into_disk.is_satisfied() {
let path = self.cache_dir.join(key.to_key());
DiskUtil::write(&path, &value).await?;
self.disk_limits.current_byte_count += byte_len;
self.disk_limits.current_entry_count += 1;
Ok(Entry::Disk(DiskEntry { path, byte_len }))
}
else {
use LimitEvaluation::LimitExceeded;
let limit_kind = Cow::Borrowed(match fits_into_disk {
LimitExceeded(LimitExceededKind::Bytes) => LIMIT_KIND_BYTE_DISK,
LimitExceeded(LimitExceededKind::Entries) => LIMIT_KIND_ENTRY_DISK,
_ => unreachable!(),
});
Err(crate::Error::LimitExceeded { limit_kind })
}
}
async fn get<'a>(&self, entry: &'a Self::CacheEntry) -> Result<Cow<'a, [u8]>> {
match entry {
Entry::Memory(entry) => Ok(Cow::Borrowed(&entry.data)),
Entry::Disk(entry) => Ok(Cow::Owned(
DiskUtil::read(&entry.path, Some(entry.byte_len)).await?,
)),
}
}
async fn take(&mut self, entry: Self::CacheEntry) -> Result<Vec<u8>> {
match entry {
Entry::Memory(entry) => {
self.memory_limits.current_byte_count -= entry.byte_len;
self.memory_limits.current_entry_count -= 1;
Ok(entry.data)
}
Entry::Disk(ref entry) => {
let data = DiskUtil::read(&entry.path, Some(entry.byte_len)).await?;
DiskUtil::delete(&entry.path).await?;
self.disk_limits.current_byte_count -= entry.byte_len;
self.disk_limits.current_entry_count -= 1;
Ok(data)
}
}
}
async fn delete(&mut self, entry: Self::CacheEntry) -> Result<()> {
match entry {
Entry::Memory(entry) => {
self.memory_limits.current_byte_count -= entry.byte_len;
self.memory_limits.current_entry_count -= 1;
}
Entry::Disk(entry) => {
DiskUtil::delete(&entry.path).await?;
self.disk_limits.current_byte_count -= entry.byte_len;
self.disk_limits.current_entry_count -= 1;
}
}
Ok(())
}
fn get_cache_capacity(&self) -> Option<CacheCapacity> {
if let (Some(memory_byte_limit), Some(disk_byte_limit)) =
(self.memory_limits.byte_limit, self.disk_limits.byte_limit)
{
Some(CacheCapacity::new(
memory_byte_limit + disk_byte_limit,
self.memory_limits.current_byte_count + self.disk_limits.current_byte_count,
))
} else {
None
}
}
}
#[async_trait]
impl RecoverableStrategy for Hybrid {
    /// Rebuilds disk entries from the files found directly in the cache directory.
    ///
    /// Each non-directory file whose name `recover_key` maps to a key becomes an
    /// `Entry::Disk` and is added to the disk usage counters. Its size is taken from file
    /// metadata, so contents are not read. Files whose name is missing, not valid UTF-8, or
    /// rejected by `recover_key` are moved (best effort) into a `lost+found` subdirectory.
    /// Directories are skipped.
    ///
    /// # Errors
    ///
    /// Fails if `lost+found` cannot be created, the cache directory cannot be listed, or a
    /// recovered file's metadata cannot be read or its size does not fit in `usize`.
    async fn recover<K, F>(&mut self, mut recover_key: F) -> Result<Vec<(K, Self::CacheEntry)>>
    where
        K: Send,
        F: Fn(&str) -> Option<K> + Send,
    {
        let lost_found_dir = self.cache_dir.join("lost+found");
        std::fs::create_dir_all(&lost_found_dir)?;
        // Best effort: a file that cannot be moved is simply left where it is.
        let move_to_lost_found = |source: &Path| {
            let Some(file_name) = source.file_name() else {
                return;
            };
            let target_path = lost_found_dir.join(file_name);
            _ = std::fs::rename(source, target_path);
        };

        let mut entries = Vec::new();
        for entry in std::fs::read_dir(&self.cache_dir)?.filter_map(|e| e.ok()) {
            let path = entry.path();
            if path.is_dir() {
                continue;
            }
            let Some(key) = path
                .file_name()
                .and_then(|p| p.to_str())
                .and_then(&mut recover_key)
            else {
                move_to_lost_found(&path);
                continue;
            };
            // Only the size is needed for accounting. Reading it from metadata avoids loading
            // every recovered file fully into memory.
            let byte_len = usize::try_from(std::fs::metadata(&path)?.len())
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
            self.disk_limits.current_byte_count += byte_len;
            self.disk_limits.current_entry_count += 1;
            entries.push((key, Entry::Disk(DiskEntry { path, byte_len })));
        }
        Ok(entries)
    }
}
#[async_trait]
impl FlushableStrategy for Hybrid {
async fn flush<K>(
&mut self,
key: &K,
entry: &Self::CacheEntry,
) -> Result<Option<Self::CacheEntry>>
where
K: CacheKey + Sync + Send,
{
let Self::CacheEntry::Memory(entry) = entry else {
return Ok(None);
};
if let LimitEvaluation::LimitExceeded(reason) = self.disk_limits.evaluate(entry.byte_len) {
let limit_kind = Cow::Borrowed(match reason {
LimitExceededKind::Bytes => LIMIT_KIND_BYTE_DISK,
LimitExceededKind::Entries => LIMIT_KIND_ENTRY_DISK,
});
return Err(crate::Error::LimitExceeded { limit_kind });
}
let path = self.cache_dir.join(key.to_key());
DiskUtil::write(&path, &entry.data).await?;
self.disk_limits.current_byte_count += entry.byte_len;
self.disk_limits.current_entry_count += 1;
Ok(Some(Entry::Disk(DiskEntry {
path,
byte_len: entry.byte_len,
})))
}
}
#[cfg(test)]
mod tests {
    use std::fs::metadata;
    use super::{Hybrid, Limits, LIMIT_KIND_BYTE_DISK, LIMIT_KIND_ENTRY_DISK};
    use crate::{async_test, utils::test::TempDir, Cache, Error, NO_COMPRESSION};
    async_test! {
        async fn test_default_strategy() {
            let mut cache = Cache::new(Hybrid::default(), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".to_vec()).await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 3);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 1);
            assert_eq!(cache.strategy().disk_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_entry_count, 0);
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 6);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 2);
            assert_eq!(cache.strategy().disk_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_entry_count, 0);
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            assert!(cache.get("baz").await.is_err());
            cache.delete("foo").await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 3);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 1);
            assert_eq!(cache.strategy().disk_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_entry_count, 0);
            cache.delete("bar").await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_entry_count, 0);
        }
        async fn test_strategy_with_memory_byte_limit() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Hybrid::new(
                temp_dir.as_ref(),
                Limits::new(Some(6), None),
                Limits::default(),
            ), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".to_vec()).await.unwrap();
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            cache.put("baz", b"baz".to_vec()).await.unwrap();
            assert!(metadata(temp_dir.as_ref().join("baz")).unwrap().is_file());
        }
        async fn test_strategy_with_memory_entry_limit() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Hybrid::new(
                temp_dir.as_ref(),
                Limits::new(None, Some(2)),
                Limits::default(),
            ), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".to_vec()).await.unwrap();
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            cache.put("baz", b"baz".to_vec()).await.unwrap();
            assert!(metadata(temp_dir.as_ref().join("baz")).unwrap().is_file());
        }
        async fn test_strategy_with_memory_and_disk_byte_limit() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Hybrid::new(
                temp_dir.as_ref(),
                Limits::new(Some(6), None),
                Limits::new(Some(6), None),
            ), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".to_vec()).await.unwrap();
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            cache.put("baz", b"baz".to_vec()).await.unwrap();
            cache.put("bax", b"bax".to_vec()).await.unwrap();
            assert!(metadata(temp_dir.as_ref().join("baz")).unwrap().is_file());
            assert!(metadata(temp_dir.as_ref().join("bax")).unwrap().is_file());
            // Memory (6/6 bytes) and disk (6/6 bytes) are both full, so this put must fail.
            match cache.put("quix", b"quix".to_vec()).await {
                Err(Error::LimitExceeded { limit_kind }) => {
                    assert_eq!(limit_kind, LIMIT_KIND_BYTE_DISK);
                }
                Err(err) => {
                    panic!("Unexpected error: {:?}", err);
                }
                Ok(_) => {
                    panic!("expected the disk byte limit to reject `quix`");
                }
            }
        }
        async fn test_strategy_with_memory_and_disk_entry_limit() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Hybrid::new(
                temp_dir.as_ref(),
                Limits::new(None, Some(2)),
                Limits::new(None, Some(2)),
            ), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".to_vec()).await.unwrap();
            cache.put("bar", b"bar".to_vec()).await.unwrap();
            assert_eq!(cache.get("foo").await.unwrap(), b"foo".as_slice());
            assert_eq!(cache.get("bar").await.unwrap(), b"bar".as_slice());
            cache.put("baz", b"baz".to_vec()).await.unwrap();
            cache.put("bax", b"bax".to_vec()).await.unwrap();
            assert!(metadata(temp_dir.as_ref().join("baz")).unwrap().is_file());
            assert!(metadata(temp_dir.as_ref().join("bax")).unwrap().is_file());
            // Memory (2/2 entries) and disk (2/2 entries) are both full, so this put must fail.
            match cache.put("quix", b"quix".to_vec()).await {
                Err(Error::LimitExceeded { limit_kind }) => {
                    assert_eq!(limit_kind, LIMIT_KIND_ENTRY_DISK);
                }
                Err(err) => {
                    panic!("Unexpected error: {:?}", err);
                }
                Ok(_) => {
                    panic!("expected the disk entry limit to reject `quix`");
                }
            }
        }
        async fn test_recovery() {
            let temp_dir = TempDir::new();
            {
                let mut cache = Cache::new(Hybrid::new(
                    temp_dir.as_ref(),
                    Limits::new(None, Some(1)),
                    Limits::default(),
                ), NO_COMPRESSION).await.unwrap();
                cache.put("foo", b"foo".to_vec()).await.unwrap();
                cache.put("bar", b"bar".to_vec()).await.unwrap();
                cache.put("baz", b"baz".to_vec()).await.unwrap();
            }
            {
                let mut cache = Cache::new(Hybrid::new(
                    temp_dir.as_ref(),
                    Limits::default(),
                    Limits::default(),
                ), NO_COMPRESSION).await.unwrap();
                let recovered_items = cache
                    .recover(|k| Some(k.to_string()))
                    .await
                    .expect("Failed to recover");
                assert_eq!(recovered_items, 2);
                assert_eq!(cache.strategy().disk_limits.current_byte_count, 6);
                assert_eq!(cache.strategy().disk_limits.current_entry_count, 2);
            }
        }
        async fn test_flush() {
            let temp_dir = TempDir::new();
            let mut cache = Cache::new(Hybrid::new(
                temp_dir.as_ref(),
                Limits::default(),
                Limits::default(),
            ), NO_COMPRESSION).await.unwrap();
            cache.put("foo", b"foo".as_slice()).await.unwrap();
            cache.put("bar", b"bar".as_slice()).await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 6);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 2);
            cache.flush().await.unwrap();
            assert_eq!(cache.strategy().memory_limits.current_byte_count, 0);
            assert_eq!(cache.strategy().memory_limits.current_entry_count, 0);
            assert_eq!(cache.strategy().disk_limits.current_byte_count, 6);
            assert_eq!(cache.strategy().disk_limits.current_entry_count, 2);
        }
    }
}