use std::collections::HashMap;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
/// Storage layer used by [`Cache`]: a byte-oriented key/value store with
/// per-entry expiry.
///
/// The `Send + Sync` bound lets a single backend be shared between threads.
/// Every method takes `&self`, so implementations must handle mutation
/// internally (for example behind a lock).
pub trait CacheBackend: Send + Sync {
    /// Looks up `key`, returning a copy of its value if it is present and has
    /// not expired.
    fn get(&self, key: &str) -> Option<Vec<u8>>;

    /// Stores `value` under `key`, replacing any previous entry. The entry
    /// becomes unavailable once `ttl` has elapsed.
    fn set(&self, key: &str, value: &[u8], ttl: Duration);

    /// Deletes the entry stored under `key`, if there is one.
    fn remove(&self, key: &str);

    /// Deletes every entry.
    fn clear(&self);
}
/// Stored entries: key -> (value, expiry deadline). A `None` deadline means the
/// entry never expires, which is how TTLs too large for `Instant` are kept.
type MemoryEntries = HashMap<String, (Vec<u8>, Option<Instant>)>;

/// In-process [`CacheBackend`] keeping entries in a `HashMap` behind an `RwLock`.
///
/// Expired entries are never returned. They are evicted when `get` encounters
/// them, or in bulk via [`MemoryCache::purge_expired`].
#[derive(Debug)]
pub struct MemoryCache {
    data: std::sync::RwLock<MemoryEntries>,
}

impl MemoryCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            data: std::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Number of entries that have not yet expired.
    ///
    /// Scans every stored entry, so this is O(n) in the number of stored
    /// (including not-yet-purged expired) entries.
    pub fn len(&self) -> usize {
        let now = Instant::now();
        self.entries()
            .values()
            .filter(|(_, expiry)| Self::is_live(*expiry, now))
            .count()
    }

    /// Returns `true` when no unexpired entry remains (consistent with `len() == 0`).
    pub fn is_empty(&self) -> bool {
        let now = Instant::now();
        !self
            .entries()
            .values()
            .any(|(_, expiry)| Self::is_live(*expiry, now))
    }

    /// Removes every expired entry and returns how many were removed.
    ///
    /// Without this (or a `get` on the key), expired entries keep occupying memory.
    pub fn purge_expired(&self) -> usize {
        let now = Instant::now();
        let mut entries = self.entries_mut();
        let before = entries.len();
        entries.retain(|_, (_, expiry)| Self::is_live(*expiry, now));
        before - entries.len()
    }

    /// Whether an entry with deadline `expiry` is still valid at `now`.
    fn is_live(expiry: Option<Instant>, now: Instant) -> bool {
        match expiry {
            Some(deadline) => now < deadline,
            None => true,
        }
    }

    /// Shared access to the map.
    ///
    /// A panic while a guard is held poisons the lock, but the map is still
    /// structurally valid. Recover its contents instead of turning one panic
    /// into a panic for every later caller.
    fn entries(&self) -> std::sync::RwLockReadGuard<'_, MemoryEntries> {
        self.data
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Exclusive access to the map; recovers from poisoning like [`Self::entries`].
    fn entries_mut(&self) -> std::sync::RwLockWriteGuard<'_, MemoryEntries> {
        self.data
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for MemoryCache {
    fn default() -> Self {
        Self::new()
    }
}

impl CacheBackend for MemoryCache {
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        // Fast path under the shared lock: a miss or a live hit.
        {
            let entries = self.entries();
            match entries.get(key) {
                None => return None,
                Some((value, expiry)) if Self::is_live(*expiry, Instant::now()) => {
                    return Some(value.clone());
                }
                Some(_) => {}
            }
        }
        // The entry exists but has expired: reclaim it under the exclusive lock.
        // Another thread may have stored a fresh value in between, so re-check.
        let mut entries = self.entries_mut();
        if let Some((value, expiry)) = entries.get(key) {
            if Self::is_live(*expiry, Instant::now()) {
                return Some(value.clone());
            }
        }
        entries.remove(key);
        None
    }

    fn set(&self, key: &str, value: &[u8], ttl: Duration) {
        // `Instant + Duration` panics on overflow (e.g. `Duration::MAX`); such
        // TTLs are stored as "never expires" instead.
        let expiry = Instant::now().checked_add(ttl);
        self.entries_mut()
            .insert(key.to_owned(), (value.to_vec(), expiry));
    }

    fn remove(&self, key: &str) {
        self.entries_mut().remove(key);
    }

    fn clear(&self) {
        self.entries_mut().clear();
    }
}
#[derive(Debug, Clone)]
pub struct FileCache {
dir: PathBuf,
}
impl FileCache {
pub fn new<P: AsRef<Path>>(dir: P) -> Self {
let dir = dir.as_ref().to_path_buf();
fs::create_dir_all(&dir).ok();
Self { dir }
}
fn cache_path(&self, key: &str) -> PathBuf {
let safe_key = key.replace('/', "_").replace('\\', "_");
self.dir.join(format!("{}.cache", safe_key))
}
pub fn cleanup_expired(&self) -> io::Result<usize> {
let mut removed = 0;
for entry in fs::read_dir(&self.dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) != Some("cache") {
continue;
}
if let Ok(Some(expiry)) = self.read_expiry(&path) {
if Instant::now() >= expiry {
fs::remove_file(&path)?;
removed += 1;
}
}
}
Ok(removed)
}
fn read_expiry(&self, path: &Path) -> io::Result<Option<Instant>> {
let data = fs::read(path)?;
if data.len() < 8 {
return Ok(None);
}
let ttl_micros = u64::from_le_bytes(data[0..8].try_into().unwrap());
let stored_ttl = Duration::from_micros(ttl_micros);
let metadata = fs::metadata(path)?;
let modified = metadata.modified()?;
let file_age = modified.elapsed().unwrap_or_default();
if file_age > stored_ttl {
Ok(Some(Instant::now())) } else {
Ok(None) }
}
}
impl CacheBackend for FileCache {
fn get(&self, key: &str) -> Option<Vec<u8>> {
let path = self.cache_path(key);
let data = fs::read(&path).ok()?;
if data.len() < 8 {
return None;
}
let ttl_micros = u64::from_le_bytes(data[0..8].try_into().unwrap());
let stored_ttl = Duration::from_micros(ttl_micros);
let metadata = fs::metadata(&path).ok()?;
let modified = metadata.modified().ok()?;
let file_age = modified.elapsed().ok()?;
if file_age > stored_ttl {
fs::remove_file(&path).ok();
return None;
}
Some(data[8..].to_vec())
}
fn set(&self, key: &str, value: &[u8], ttl: Duration) {
let path = self.cache_path(key);
if let Ok(mut file) = fs::File::create(&path) {
let ttl_micros = ttl.as_micros() as u64;
file.write_all(&ttl_micros.to_le_bytes()).ok();
file.write_all(value).ok();
}
}
fn remove(&self, key: &str) {
let path = self.cache_path(key);
fs::remove_file(&path).ok();
}
fn clear(&self) {
fs::remove_dir_all(&self.dir).ok();
fs::create_dir_all(&self.dir).ok();
}
}
/// Front-end over a [`CacheBackend`] that supplies a default time-to-live.
///
/// All storage is delegated to the backend. `Cache` only remembers which TTL
/// to pass along when the caller does not give one explicitly.
pub struct Cache<B: CacheBackend> {
    backend: B,
    ttl: Duration,
}

impl<B: CacheBackend> Cache<B> {
    /// Wraps `backend`, using `ttl` for every [`Cache::set`] call.
    pub fn new(backend: B, ttl: Duration) -> Self {
        Cache { ttl, backend }
    }

    /// The default TTL applied by [`Cache::set`].
    pub fn ttl(&self) -> Duration {
        self.ttl
    }

    /// Borrows the underlying backend.
    pub fn backend(&self) -> &B {
        &self.backend
    }

    /// Returns the value stored under `key`, as reported by the backend.
    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
        let backend = &self.backend;
        backend.get(key)
    }

    /// Stores `value` under `key` using the default TTL.
    pub fn set(&self, key: &str, value: &[u8]) {
        let default_ttl = self.ttl;
        self.set_with_ttl(key, value, default_ttl);
    }

    /// Stores `value` under `key` with an explicit `ttl`.
    pub fn set_with_ttl(&self, key: &str, value: &[u8], ttl: Duration) {
        let backend = &self.backend;
        backend.set(key, value, ttl);
    }

    /// Deletes the entry for `key` from the backend.
    pub fn remove(&self, key: &str) {
        let backend = &self.backend;
        backend.remove(key);
    }

    /// Deletes every entry from the backend.
    pub fn clear(&self) {
        let backend = &self.backend;
        backend.clear();
    }

    /// Returns the cached value for `key`, or computes it with `fetch`.
    ///
    /// On a miss, `fetch` runs once. A successful result is stored with the
    /// default TTL before being returned. An error from `fetch` is returned
    /// unchanged and nothing is cached.
    pub fn get_or_fetch<F, E>(&self, key: &str, fetch: F) -> Result<Vec<u8>, E>
    where
        F: FnOnce() -> Result<Vec<u8>, E>,
    {
        match self.get(key) {
            Some(hit) => Ok(hit),
            None => {
                let fresh = fetch()?;
                self.set(key, &fresh);
                Ok(fresh)
            }
        }
    }
}
impl Cache<MemoryCache> {
pub fn memory(ttl: Duration) -> Self {
Self::new(MemoryCache::new(), ttl)
}
}
impl Cache<FileCache> {
pub fn file<P: AsRef<Path>>(dir: P, ttl: Duration) -> Self {
Self::new(FileCache::new(dir), ttl)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Returns a scratch directory unique to this test and process, removing
    /// leftovers first. Parallel or repeated test runs therefore cannot see
    /// each other's files.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "rustdx_test_{}_{}",
            name,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn test_memory_cache_basic() {
        let cache = MemoryCache::new();
        cache.set("key1", b"data1", Duration::from_secs(60));
        let result = cache.get("key1");
        assert_eq!(result, Some(b"data1".to_vec()));
        let result = cache.get("key2");
        assert_eq!(result, None);
    }

    #[test]
    fn test_memory_cache_expiry() {
        let cache = MemoryCache::new();
        // A long-lived control entry shows that unexpired data survives. This
        // avoids asserting a hit against a 10 ms deadline, which is flaky on
        // loaded machines.
        cache.set("fresh", b"data1", Duration::from_secs(60));
        cache.set("stale", b"data2", Duration::from_millis(10));
        thread::sleep(Duration::from_millis(20));
        assert_eq!(cache.get("fresh"), Some(b"data1".to_vec()));
        assert_eq!(cache.get("stale"), None);
    }

    #[test]
    fn test_memory_cache_remove() {
        let cache = MemoryCache::new();
        cache.set("key1", b"data1", Duration::from_secs(60));
        assert!(cache.get("key1").is_some());
        cache.remove("key1");
        assert!(cache.get("key1").is_none());
    }

    #[test]
    fn test_memory_cache_clear() {
        let cache = MemoryCache::new();
        cache.set("key1", b"data1", Duration::from_secs(60));
        cache.set("key2", b"data2", Duration::from_secs(60));
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_manager() {
        let cache = Cache::memory(Duration::from_secs(60));
        cache.set("key1", b"data1");
        let result = cache.get("key1");
        assert_eq!(result, Some(b"data1".to_vec()));
    }

    #[test]
    fn test_get_or_fetch() {
        let cache = Cache::memory(Duration::from_secs(60));
        let mut call_count = 0;
        let result1: Result<Vec<u8>, ()> = cache.get_or_fetch("key1", || {
            call_count += 1;
            Ok(vec![1, 2, 3])
        });
        assert_eq!(result1.unwrap(), vec![1, 2, 3]);
        assert_eq!(call_count, 1);
        let result2: Result<Vec<u8>, ()> = cache.get_or_fetch("key1", || {
            call_count += 1;
            Ok(vec![1, 2, 3])
        });
        assert_eq!(result2.unwrap(), vec![1, 2, 3]);
        assert_eq!(call_count, 1);
    }

    #[test]
    fn test_file_cache() {
        let temp_dir = scratch_dir("cache");
        let cache = FileCache::new(&temp_dir);
        cache.set("key1", b"data1", Duration::from_secs(60));
        assert_eq!(cache.get("key1"), Some(b"data1".to_vec()));
        cache.clear();
        assert_eq!(cache.get("key1"), None);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    #[test]
    fn test_file_cache_expiry() {
        let temp_dir = scratch_dir("cache_expiry");
        let cache = FileCache::new(&temp_dir);
        // A long-lived control entry replaces an immediate hit check against a
        // 10 ms TTL. That check fails on slow machines and on file systems with
        // coarse modification times.
        cache.set("fresh", b"data1", Duration::from_secs(60));
        cache.set("stale", b"data2", Duration::from_millis(10));
        thread::sleep(Duration::from_millis(20));
        assert_eq!(cache.get("fresh"), Some(b"data1".to_vec()));
        assert_eq!(cache.get("stale"), None);
        cache.clear();
        let _ = std::fs::remove_dir_all(&temp_dir);
    }
}