use super::EvictionManager;
use crate::key::CompactCacheKey;
use async_trait::async_trait;
use log::{info, warn};
use pingora_error::{BError, ErrorType::*, OrErr, Result};
use pingora_lru::Lru;
use rand::Rng;
use serde::de::SeqAccess;
use serde::{Deserialize, Serialize};
use std::fs::{rename, File};
use std::hash::{Hash, Hasher};
use std::io::prelude::*;
use std::path::Path;
use std::time::SystemTime;
/// An LRU-based eviction manager backed by a sharded [`pingora_lru::Lru`],
/// keyed by the hash of the compact cache key. `N` is the number of shards.
pub struct Manager<const N: usize>(Lru<CompactCacheKey, N>);
/// Serialization helper pairing an LRU node's key with its weight (size).
#[derive(Debug, Serialize, Deserialize)]
struct SerdeHelperNode(CompactCacheKey, usize);
impl<const N: usize> Manager<N> {
    /// Create a [Manager] with the given size `limit` and an estimated
    /// per-shard item `capacity`.
    pub fn with_capacity(limit: usize, capacity: usize) -> Self {
        Manager(Lru::with_capacity(limit, capacity))
    }

    /// Like [Self::with_capacity], with an optional eviction `watermark`
    /// forwarded to the underlying LRU.
    pub fn with_capacity_and_watermark(
        limit: usize,
        capacity: usize,
        watermark: Option<usize>,
    ) -> Self {
        Manager(Lru::with_capacity_and_watermark(limit, capacity, watermark))
    }

    /// The number of shards (`N`).
    pub fn shards(&self) -> usize {
        self.0.shards()
    }

    /// Total weight currently tracked by the given shard.
    pub fn shard_weight(&self, shard: usize) -> usize {
        self.0.shard_weight(shard)
    }

    /// Number of items currently tracked by the given shard.
    pub fn shard_len(&self, shard: usize) -> usize {
        self.0.shard_len(shard)
    }

    /// Map a cache key to the index of the shard that stores it.
    pub fn get_shard_for_key(&self, key: &CompactCacheKey) -> usize {
        (u64key(key) % N as u64) as usize
    }

    /// Serialize the given shard into a msgpack byte buffer.
    ///
    /// # Errors
    /// Returns `InternalError` if serialization fails.
    ///
    /// # Panics
    /// Panics if `shard >= N`.
    pub fn serialize_shard(&self, shard: usize) -> Result<Vec<u8>> {
        use rmp_serde::encode::Serializer;
        use serde::ser::SerializeSeq;
        use serde::ser::Serializer as _;

        assert!(shard < N);

        // Snapshot the shard first so the shard is not borrowed while
        // serializing.
        let mut nodes = Vec::with_capacity(self.0.shard_len(shard));
        self.0.iter_for_each(shard, |(node, size)| {
            nodes.push(SerdeHelperNode(node.clone(), size));
        });

        let mut ser = Serializer::new(vec![]);
        let mut seq = ser
            .serialize_seq(Some(self.0.shard_len(shard)))
            .or_err(InternalError, "fail to serialize node")?;
        for node in nodes {
            // Propagate serialization failures instead of panicking
            // (previously an `unwrap()`), consistent with the rest of this
            // function's error handling.
            seq.serialize_element(&node)
                .or_err(InternalError, "fail to serialize node")?;
        }
        seq.end().or_err(InternalError, "when serializing LRU")?;
        Ok(ser.into_inner())
    }

    /// Deserialize one shard's msgpack buffer, inserting each node into this
    /// manager as it is decoded (see [InsertToManager]).
    ///
    /// # Errors
    /// Returns `InternalError` if the buffer cannot be decoded.
    pub fn deserialize_shard(&self, buf: &[u8]) -> Result<()> {
        use rmp_serde::decode::Deserializer;
        use serde::de::Deserializer as _;

        let mut de = Deserializer::new(buf);
        let visitor = InsertToManager { lru: self };
        de.deserialize_seq(visitor)
            .or_err(InternalError, "when deserializing LRU")?;
        Ok(())
    }

    /// Return the tracked weight of `item` without promoting it, or `None`
    /// if it is not tracked.
    pub fn peek_weight(&self, item: &CompactCacheKey) -> Option<usize> {
        let key = u64key(item);
        self.0.peek_weight(key)
    }
}
/// Serde visitor that inserts every deserialized node into the target
/// manager as it is read, avoiding an intermediate collection.
struct InsertToManager<'a, const N: usize> {
    lru: &'a Manager<N>,
}
impl<'de, const N: usize> serde::de::Visitor<'de> for InsertToManager<'_, N> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("array of lru nodes")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
while let Some(node) = seq.next_element::<SerdeHelperNode>()? {
let key = u64key(&node.0);
self.lru.0.insert_tail(key, node.0, node.1); }
Ok(())
}
}
/// Hash a compact cache key into the `u64` key used by the underlying LRU.
#[inline]
fn u64key(key: &CompactCacheKey) -> u64 {
    let mut state = ahash::AHasher::default();
    key.hash(&mut state);
    state.finish()
}
/// Base file name for persisted shards; shard `i` is saved as `lru.data.<i>`.
const FILE_NAME: &str = "lru.data";
/// Build an error message of the form `"<s> <path>"`.
#[inline]
fn err_str_path(s: &str, path: &Path) -> String {
    format!("{} {}", s, path.display())
}
#[async_trait]
impl<const N: usize> EvictionManager for Manager<N> {
    // Total tracked weight across all shards.
    fn total_size(&self) -> usize {
        self.0.weight()
    }

    // Total number of tracked items across all shards.
    fn total_items(&self) -> usize {
        self.0.len()
    }

    // Cumulative weight of items evicted so far.
    fn evicted_size(&self) -> usize {
        self.0.evicted_weight()
    }

    // Cumulative count of items evicted so far.
    fn evicted_items(&self) -> usize {
        self.0.evicted_len()
    }

    // Insert (or update) `item` with the given weight, then evict down to
    // the size limit. Returns the keys evicted as a consequence.
    // `_fresh_until` is unused: this policy is purely recency-based.
    fn admit(
        &self,
        item: CompactCacheKey,
        size: usize,
        _fresh_until: SystemTime,
    ) -> Vec<CompactCacheKey> {
        let key = u64key(&item);
        self.0.admit(key, item, size);
        self.0
            .evict_to_limit()
            .into_iter()
            .map(|(key, _weight)| key)
            .collect()
    }

    // Adjust the tracked weight of `item`, then evict down to the limit.
    // Returns any evicted keys.
    fn increment_weight(
        &self,
        item: &CompactCacheKey,
        delta: usize,
        max_weight: Option<usize>,
    ) -> Vec<CompactCacheKey> {
        let key = u64key(item);
        self.0.increment_weight(key, delta, max_weight);
        self.0
            .evict_to_limit()
            .into_iter()
            .map(|(key, _weight)| key)
            .collect()
    }

    fn remove(&self, item: &CompactCacheKey) {
        let key = u64key(item);
        self.0.remove(key);
    }

    // Promote `item` to most-recently-used. If it was not tracked, admit it
    // with `size` and return false; return true if it was already present.
    fn access(&self, item: &CompactCacheKey, size: usize, _fresh_until: SystemTime) -> bool {
        let key = u64key(item);
        if !self.0.promote(key) {
            self.0.admit(key, item.clone(), size);
            false
        } else {
            true
        }
    }

    // Check membership without changing recency order.
    fn peek(&self, item: &CompactCacheKey) -> bool {
        let key = u64key(item);
        self.0.peek(key)
    }

    // Persist every shard under `dir_path`, one file per shard
    // (`lru.data.<i>`). Each shard is written to a uniquely named temp file
    // and then renamed into place so readers never observe a partial file.
    // All blocking filesystem work runs on the blocking thread pool.
    async fn save(&self, dir_path: &str) -> Result<()> {
        let dir_path_str = dir_path.to_owned();
        tokio::task::spawn_blocking(move || {
            let dir_path = Path::new(&dir_path_str);
            std::fs::create_dir_all(dir_path)
                .or_err_with(InternalError, || err_str_path("fail to create", dir_path))
        })
        .await
        .or_err(InternalError, "async blocking IO failure")??;
        for i in 0..N {
            let data = self.serialize_shard(i)?;
            let dir_path = dir_path.to_owned();
            tokio::task::spawn_blocking(move || {
                let dir_path = Path::new(&dir_path);
                let final_path = dir_path.join(format!("{}.{i}", FILE_NAME));
                // Random suffix so concurrent saves cannot clobber each
                // other's temp file.
                let random_suffix: u32 = rand::thread_rng().gen();
                let temp_path =
                    dir_path.join(format!("{}.{i}.{:08x}.tmp", FILE_NAME, random_suffix));
                let mut file = File::create(&temp_path)
                    .or_err_with(InternalError, || err_str_path("fail to create", &temp_path))?;
                file.write_all(&data).or_err_with(InternalError, || {
                    err_str_path("fail to write to", &temp_path)
                })?;
                file.flush().or_err_with(InternalError, || {
                    err_str_path("fail to flush temp file", &temp_path)
                })?;
                // NOTE(review): the file is flushed but not fsync'd before
                // the rename, so on power loss the renamed file may be
                // incomplete — confirm this durability level is intended.
                rename(&temp_path, &final_path).or_err_with(InternalError, || {
                    format!(
                        "Failed to rename file from {} to {}",
                        temp_path.display(),
                        final_path.display(),
                    )
                })
            })
            .await
            .or_err(InternalError, "async blocking IO failure")??;
        }
        Ok(())
    }

    // Load previously saved shards from `dir_path`. A shard that fails to
    // deserialize is logged and skipped. Orphaned temp files from
    // interrupted saves are cleaned up at the end (best effort, detached).
    async fn load(&self, dir_path: &str) -> Result<()> {
        let mut loaded_shards = 0;
        for i in 0..N {
            let dir_path = dir_path.to_owned();
            let data = tokio::task::spawn_blocking(move || {
                let file_path = Path::new(&dir_path).join(format!("{}.{i}", FILE_NAME));
                // NOTE(review): a missing/unreadable shard file aborts the
                // whole load (via `?` below), whereas a corrupt shard is
                // merely skipped — confirm this asymmetry is intended.
                let mut file = File::open(&file_path)
                    .or_err_with(InternalError, || err_str_path("fail to open", &file_path))?;
                let mut buffer = Vec::with_capacity(8192);
                file.read_to_end(&mut buffer)
                    .or_err_with(InternalError, || {
                        err_str_path("fail to read from", &file_path)
                    })?;
                Ok::<Vec<u8>, BError>(buffer)
            })
            .await
            .or_err(InternalError, "async blocking IO failure")??;
            if let Err(e) = self.deserialize_shard(&data) {
                warn!("Failed to deserialize shard {}: {}. Skipping shard.", i, e);
                continue;
            }
            loaded_shards += 1;
        }
        if loaded_shards < N {
            warn!(
                "Only loaded {}/{} shards. Cache may be incomplete.",
                loaded_shards, N
            )
        } else {
            info!("Successfully loaded {}/{} shards.", loaded_shards, N)
        }
        // Fire-and-forget: the caller does not wait for cleanup to finish.
        cleanup_temp_files(dir_path);
        Ok(())
    }
}
/// Best-effort removal of orphaned `lru.data*.tmp` files left behind by
/// interrupted saves. Runs detached on the blocking thread pool; callers do
/// not wait for completion.
fn cleanup_temp_files(dir_path: &str) {
    let dir_path = Path::new(dir_path).to_owned();
    tokio::task::spawn_blocking(move || {
        // Nothing to do if the directory was never created.
        if !dir_path.exists() {
            return;
        }
        let entries = match std::fs::read_dir(&dir_path) {
            Ok(entries) => entries,
            Err(e) => {
                warn!("Failed to read directory {}: {e}", dir_path.display());
                return;
            }
        };
        let mut cleaned_count = 0;
        let mut error_count = 0;
        for entry in entries {
            let entry = match entry {
                Ok(entry) => entry,
                Err(e) => {
                    warn!(
                        "Failed to read directory entry in {}: {e}",
                        dir_path.display()
                    );
                    error_count += 1;
                    continue;
                }
            };
            let name = entry.file_name();
            let name = name.to_string_lossy();
            // Only touch our own temp files, e.g. "lru.data.0.12345678.tmp".
            if !name.starts_with(FILE_NAME) || !name.ends_with(".tmp") {
                continue;
            }
            match std::fs::remove_file(entry.path()) {
                Ok(()) => {
                    info!("Cleaned up orphaned temp file: {}", entry.path().display());
                    cleaned_count += 1;
                }
                Err(e) => {
                    warn!("Failed to remove temp file {}: {e}", entry.path().display());
                    error_count += 1;
                }
            }
        }
        if cleaned_count > 0 || error_count > 0 {
            info!(
                "Temp file cleanup completed. Removed: {cleaned_count}, Errors: {error_count}"
            );
        }
    });
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::CacheKey;

    #[test]
    fn test_admission() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now(); // freshness is ignored by this policy
        let key1 = CacheKey::new("", "a", "1").to_compact();
        let v = lru.admit(key1.clone(), 1, until);
        assert_eq!(v.len(), 0);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        let v = lru.admit(key2.clone(), 2, until);
        assert_eq!(v.len(), 0);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        let v = lru.admit(key3, 1, until);
        assert_eq!(v.len(), 0);
        // Exceeding the limit evicts the two least recently used entries.
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru.admit(key4, 2, until);
        assert_eq!(v.len(), 2);
        assert_eq!(v[0], key1);
        assert_eq!(v[1], key2);
    }

    #[test]
    fn test_access() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        let v = lru.admit(key1.clone(), 1, until);
        assert_eq!(v.len(), 0);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        let v = lru.admit(key2.clone(), 2, until);
        assert_eq!(v.len(), 0);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        let v = lru.admit(key3, 1, until);
        assert_eq!(v.len(), 0);
        // Accessing key1 promotes it, so key2 becomes the LRU entry.
        lru.access(&key1, 1, until);
        assert_eq!(v.len(), 0);
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru.admit(key4, 2, until);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0], key2);
    }

    #[test]
    fn test_remove() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        let v = lru.admit(key1.clone(), 1, until);
        assert_eq!(v.len(), 0);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        let v = lru.admit(key2.clone(), 2, until);
        assert_eq!(v.len(), 0);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        let v = lru.admit(key3, 1, until);
        assert_eq!(v.len(), 0);
        // After removing key1, key2 is the oldest remaining entry.
        lru.remove(&key1);
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru.admit(key4, 2, until);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0], key2);
    }

    #[test]
    fn test_access_add() {
        // access() on an untracked key admits it.
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        lru.access(&key1, 1, until);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        lru.access(&key2, 2, until);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        lru.access(&key3, 2, until);
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru.admit(key4, 2, until);
        assert_eq!(v.len(), 2);
        assert_eq!(v[0], key1);
        assert_eq!(v[1], key2);
    }

    #[test]
    fn test_admit_update() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        let v = lru.admit(key1.clone(), 1, until);
        assert_eq!(v.len(), 0);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        let v = lru.admit(key2.clone(), 2, until);
        assert_eq!(v.len(), 0);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        let v = lru.admit(key3, 1, until);
        assert_eq!(v.len(), 0);
        // Re-admitting key2 updates its weight (2 -> 1) and promotes it.
        let v = lru.admit(key2, 1, until);
        assert_eq!(v.len(), 0);
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru.admit(key4.clone(), 1, until);
        assert_eq!(v.len(), 0);
        // Growing key4's weight (1 -> 2) pushes total over the limit.
        let v = lru.admit(key4, 2, until);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0], key1);
    }

    #[test]
    fn test_peek() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        lru.access(&key1, 1, until);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        lru.access(&key2, 2, until);
        assert!(lru.peek(&key1));
        assert!(lru.peek(&key2));
    }

    #[test]
    fn test_serde() {
        let lru = Manager::<1>::with_capacity(4, 10);
        let until = SystemTime::now();
        let key1 = CacheKey::new("", "a", "1").to_compact();
        let v = lru.admit(key1.clone(), 1, until);
        assert_eq!(v.len(), 0);
        let key2 = CacheKey::new("", "b", "1").to_compact();
        let v = lru.admit(key2.clone(), 2, until);
        assert_eq!(v.len(), 0);
        let key3 = CacheKey::new("", "c", "1").to_compact();
        let v = lru.admit(key3, 1, until);
        assert_eq!(v.len(), 0);
        lru.access(&key1, 1, until);
        assert_eq!(v.len(), 0);
        // A round trip through serialize/deserialize preserves recency order:
        // key2 is still the LRU entry in the restored manager.
        let ser = lru.serialize_shard(0).unwrap();
        let lru2 = Manager::<1>::with_capacity(4, 10);
        lru2.deserialize_shard(&ser).unwrap();
        let key4 = CacheKey::new("", "d", "1").to_compact();
        let v = lru2.admit(key4, 2, until);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0], key2);
    }

    #[tokio::test]
    async fn test_save_to_disk() {
        let until = SystemTime::now();
        let lru = Manager::<2>::with_capacity(10, 10);
        lru.admit(CacheKey::new("", "a", "1").to_compact(), 1, until);
        lru.admit(CacheKey::new("", "b", "1").to_compact(), 2, until);
        lru.admit(CacheKey::new("", "c", "1").to_compact(), 1, until);
        lru.admit(CacheKey::new("", "d", "1").to_compact(), 1, until);
        lru.admit(CacheKey::new("", "e", "1").to_compact(), 2, until);
        lru.admit(CacheKey::new("", "f", "1").to_compact(), 1, until);
        lru.save("/tmp/test_lru_save").await.unwrap();
        let lru2 = Manager::<2>::with_capacity(4, 10);
        lru2.load("/tmp/test_lru_save").await.unwrap();
        // Both shards round-trip byte-identically.
        let ser0 = lru.serialize_shard(0).unwrap();
        let ser1 = lru.serialize_shard(1).unwrap();
        assert_eq!(ser0, lru2.serialize_shard(0).unwrap());
        assert_eq!(ser1, lru2.serialize_shard(1).unwrap());
    }

    #[tokio::test]
    async fn test_temp_file_cleanup() {
        let test_dir = "/tmp/test_lru_cleanup";
        let dir_path = Path::new(test_dir);
        std::fs::create_dir_all(dir_path).unwrap();
        let temp_files = [
            "lru.data.0.12345678.tmp",
            "lru.data.1.abcdef00.tmp",
            "other_file.tmp", // foreign temp file: must survive
            "lru.data.2",     // real shard file: must survive
        ];
        for file in temp_files {
            let file_path = dir_path.join(file);
            std::fs::write(&file_path, b"test").unwrap();
        }
        cleanup_temp_files(test_dir);
        // cleanup_temp_files is detached; give the blocking task time to run.
        tokio::time::sleep(core::time::Duration::from_secs(1)).await;
        assert!(!dir_path.join("lru.data.0.12345678.tmp").exists());
        assert!(!dir_path.join("lru.data.1.abcdef00.tmp").exists());
        assert!(dir_path.join("other_file.tmp").exists());
        assert!(dir_path.join("lru.data.2").exists());
        std::fs::remove_dir_all(dir_path).unwrap();
    }
}