use crate::error::{EdgeError, Result};
use bytes::Bytes;
use chrono::{DateTime, Utc};
use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::Arc;
/// Strategy used by the cache to pick eviction victims once the configured
/// byte budget is exceeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CachePolicy {
/// Evict least-recently-used entries first.
Lru,
/// Evict least-frequently-used entries first (ranked by recorded access count).
Lfu,
/// Evict only entries whose time-to-live has elapsed.
Ttl,
/// Evict the largest entries first.
SizeBased,
}
/// Tunable parameters for constructing a [`Cache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Maximum total payload bytes held in memory before eviction kicks in.
pub max_size: usize,
/// Eviction strategy applied when the byte budget is exceeded.
pub policy: CachePolicy,
/// Default time-to-live (seconds) applied to every inserted entry; `None` disables expiry.
pub ttl_secs: Option<u64>,
/// When true (and `cache_dir` is set), entries are mirrored to a sled database.
pub persistent: bool,
/// Directory for the persistent store; only used when `persistent` is true.
pub cache_dir: Option<PathBuf>,
/// Maximum number of entries the in-memory LRU map will hold (must be > 0).
pub max_entries: usize,
}
impl Default for CacheConfig {
    /// General-purpose preset: crate-default byte budget, 1000 entries,
    /// one-hour TTL, LRU eviction, memory-only.
    fn default() -> Self {
        Self {
            policy: CachePolicy::Lru,
            max_size: crate::DEFAULT_CACHE_SIZE,
            max_entries: 1000,
            ttl_secs: Some(3600),
            persistent: false,
            cache_dir: None,
        }
    }
}
impl CacheConfig {
    /// Small-footprint preset: 1 MiB budget, 100 entries, 30-minute TTL,
    /// memory-only, LRU eviction.
    pub fn minimal() -> Self {
        Self {
            policy: CachePolicy::Lru,
            max_size: 1024 * 1024,
            max_entries: 100,
            ttl_secs: Some(1800),
            persistent: false,
            cache_dir: None,
        }
    }

    /// Offline-first preset: 50 MiB budget, 5000 entries, no expiry, with a
    /// persistent sled store rooted at `.oxigdal_cache`.
    pub fn offline_first() -> Self {
        Self {
            policy: CachePolicy::Lru,
            max_size: 50 * 1024 * 1024,
            max_entries: 5000,
            ttl_secs: None,
            persistent: true,
            cache_dir: Some(PathBuf::from(".oxigdal_cache")),
        }
    }
}
/// A single cached value plus the timestamps and counters the eviction
/// policies rely on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
/// Key this entry is stored under.
pub key: String,
/// Cached payload.
pub data: Bytes,
/// Creation timestamp.
pub created_at: DateTime<Utc>,
/// Timestamp of the most recent `mark_accessed` call.
pub accessed_at: DateTime<Utc>,
/// Number of recorded accesses (saturating).
pub access_count: u64,
/// Payload size in bytes (`data.len()` at construction).
pub size: usize,
/// Absolute expiry deadline; `None` means the entry never expires.
pub expires_at: Option<DateTime<Utc>>,
}
impl CacheEntry {
    /// Creates a non-expiring entry stamped with the current time.
    pub fn new(key: String, data: Bytes) -> Self {
        let stamp = Utc::now();
        Self {
            size: data.len(),
            key,
            data,
            created_at: stamp,
            accessed_at: stamp,
            access_count: 0,
            expires_at: None,
        }
    }

    /// Creates an entry that expires `ttl_secs` seconds from now.
    pub fn with_ttl(key: String, data: Bytes, ttl_secs: u64) -> Self {
        let deadline = Utc::now() + chrono::Duration::seconds(ttl_secs as i64);
        Self {
            expires_at: Some(deadline),
            ..Self::new(key, data)
        }
    }

    /// True once the entry's deadline (if any) has passed.
    pub fn is_expired(&self) -> bool {
        self.expires_at.map_or(false, |deadline| Utc::now() > deadline)
    }

    /// Records a hit: bumps the access counter (saturating) and refreshes
    /// the last-accessed timestamp.
    pub fn mark_accessed(&mut self) {
        self.access_count = self.access_count.saturating_add(1);
        self.accessed_at = Utc::now();
    }
}
/// In-memory LRU cache with optional sled-backed persistence and the
/// size/metadata bookkeeping that drives the eviction policies.
pub struct Cache {
config: CacheConfig,
// Bounded (by `max_entries`) in-memory store; recency order drives LRU eviction.
lru_cache: Arc<RwLock<LruCache<String, CacheEntry>>>,
// Per-key size/access bookkeeping used by the LFU and size-based policies.
metadata: Arc<RwLock<HashMap<String, CacheMetadata>>>,
// Running sum of in-memory entry payload sizes, in bytes.
current_size: Arc<RwLock<usize>>,
// Present only when `config.persistent` is true and a `cache_dir` was given.
persistent_storage: Option<sled::Db>,
}
/// Lightweight per-key bookkeeping kept beside the LRU map so eviction
/// policies can rank entries without walking the cache itself.
#[derive(Debug, Clone)]
struct CacheMetadata {
// Payload size of the entry, in bytes.
size: usize,
// Ranking key for LFU eviction.
access_count: u64,
}
impl Cache {
/// Builds a cache from `config`.
///
/// # Errors
/// Fails when `max_entries` is zero or the persistent store cannot be
/// opened.
///
/// NOTE(review): `persistent = true` without a `cache_dir` silently yields
/// a memory-only cache — confirm that is intended.
pub fn new(config: CacheConfig) -> Result<Self> {
    let capacity = NonZeroUsize::new(config.max_entries)
        .ok_or_else(|| EdgeError::invalid_config("max_entries must be greater than 0"))?;
    // Open the sled store only when persistence is both requested and
    // given a directory.
    let persistent_storage = match (config.persistent, config.cache_dir.as_ref()) {
        (true, Some(dir)) => {
            Some(sled::open(dir).map_err(|e| EdgeError::storage(e.to_string()))?)
        }
        _ => None,
    };
    Ok(Self {
        lru_cache: Arc::new(RwLock::new(LruCache::new(capacity))),
        metadata: Arc::new(RwLock::new(HashMap::new())),
        current_size: Arc::new(RwLock::new(0)),
        persistent_storage,
        config,
    })
}
pub fn get(&self, key: &str) -> Result<Option<Bytes>> {
let mut cache = self.lru_cache.write();
if let Some(entry) = cache.get_mut(key) {
if !entry.is_expired() {
entry.mark_accessed();
return Ok(Some(entry.data.clone()));
} else {
cache.pop(key);
let mut meta = self.metadata.write();
meta.remove(key);
}
}
drop(cache);
if let Some(db) = &self.persistent_storage {
if let Some(value) = db.get(key).map_err(|e| EdgeError::storage(e.to_string()))? {
let entry: CacheEntry = serde_json::from_slice(&value)
.map_err(|e| EdgeError::deserialization(e.to_string()))?;
if !entry.is_expired() {
let mut cache = self.lru_cache.write();
cache.put(key.to_string(), entry.clone());
return Ok(Some(entry.data));
}
}
}
Ok(None)
}
/// Inserts `data` under `key`, applying the configured TTL, evicting per
/// policy to make room, and mirroring the entry to the persistent store
/// when one is configured.
///
/// # Errors
/// Rejects entries larger than the whole cache budget, and propagates
/// serialization/storage failures from the persistent store.
pub fn put(&self, key: String, data: Bytes) -> Result<()> {
    let entry_size = data.len();
    if entry_size > self.config.max_size {
        return Err(EdgeError::cache(format!(
            "Entry size {} exceeds max cache size {}",
            entry_size, self.config.max_size
        )));
    }
    let entry = if let Some(ttl) = self.config.ttl_secs {
        CacheEntry::with_ttl(key.clone(), data, ttl)
    } else {
        CacheEntry::new(key.clone(), data)
    };
    self.evict_if_needed(entry_size)?;
    // `push` reports what it displaced: either the previous value for this
    // same key, or the LRU victim when the map is at max_entries. Bug fix:
    // the old code used `put` and ignored both cases, so overwrites and
    // capacity evictions inflated current_size and left stale metadata.
    let displaced = {
        let mut cache = self.lru_cache.write();
        cache.push(key.clone(), entry.clone())
    };
    if let Some((old_key, old_entry)) = displaced {
        self.metadata.write().remove(&old_key);
        let mut size = self.current_size.write();
        *size = size.saturating_sub(old_entry.size);
    }
    self.metadata.write().insert(
        key.clone(),
        CacheMetadata {
            size: entry_size,
            access_count: 0,
        },
    );
    {
        let mut current_size = self.current_size.write();
        *current_size = current_size.saturating_add(entry_size);
    }
    if let Some(db) = &self.persistent_storage {
        let serialized =
            serde_json::to_vec(&entry).map_err(|e| EdgeError::serialization(e.to_string()))?;
        db.insert(key.as_bytes(), serialized)
            .map_err(|e| EdgeError::storage(e.to_string()))?;
    }
    Ok(())
}
/// Deletes `key` from memory, the bookkeeping maps, and (when configured)
/// the persistent store, returning its payload if it was cached in memory.
pub fn remove(&self, key: &str) -> Result<Option<Bytes>> {
    // Take the entry out of the in-memory map first; a miss touches nothing.
    let popped = self.lru_cache.write().pop(key);
    let Some(entry) = popped else {
        return Ok(None);
    };
    self.metadata.write().remove(key);
    {
        let mut size = self.current_size.write();
        *size = size.saturating_sub(entry.size);
    }
    if let Some(db) = &self.persistent_storage {
        db.remove(key.as_bytes())
            .map_err(|e| EdgeError::storage(e.to_string()))?;
    }
    Ok(Some(entry.data))
}
/// Empties the cache: in-memory entries, bookkeeping, and (when configured)
/// the persistent store.
pub fn clear(&self) -> Result<()> {
    self.lru_cache.write().clear();
    self.metadata.write().clear();
    *self.current_size.write() = 0;
    if let Some(db) = &self.persistent_storage {
        db.clear().map_err(|e| EdgeError::storage(e.to_string()))?;
    }
    Ok(())
}
/// Total payload bytes currently tracked in memory.
pub fn size(&self) -> usize {
    let bytes = self.current_size.read();
    *bytes
}

/// Number of entries currently held in the in-memory cache.
pub fn len(&self) -> usize {
    let cache = self.lru_cache.read();
    cache.len()
}

/// True when the in-memory cache holds no entries.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
fn evict_if_needed(&self, new_entry_size: usize) -> Result<()> {
let current_size = *self.current_size.read();
let target_size = self.config.max_size.saturating_sub(new_entry_size);
if current_size <= target_size {
return Ok(());
}
let mut to_evict = Vec::new();
let mut freed_size = 0;
match self.config.policy {
CachePolicy::Lru => {
let mut cache = self.lru_cache.write();
while freed_size < current_size.saturating_sub(target_size) && !cache.is_empty() {
if let Some((key, entry)) = cache.pop_lru() {
freed_size = freed_size.saturating_add(entry.size);
to_evict.push(key);
}
}
}
CachePolicy::Lfu => {
let meta = self.metadata.read();
let mut entries: Vec<_> = meta.iter().collect();
entries.sort_by_key(|(_, m)| m.access_count);
for (key, metadata) in entries {
if freed_size >= current_size.saturating_sub(target_size) {
break;
}
freed_size = freed_size.saturating_add(metadata.size);
to_evict.push(key.clone());
}
}
CachePolicy::Ttl => {
let cache = self.lru_cache.read();
for (key, entry) in cache.iter() {
if entry.is_expired() {
freed_size = freed_size.saturating_add(entry.size);
to_evict.push(key.clone());
}
}
}
CachePolicy::SizeBased => {
let meta = self.metadata.read();
let mut entries: Vec<_> = meta.iter().collect();
entries.sort_by_key(|(_, m)| std::cmp::Reverse(m.size));
for (key, metadata) in entries {
if freed_size >= current_size.saturating_sub(target_size) {
break;
}
freed_size = freed_size.saturating_add(metadata.size);
to_evict.push(key.clone());
}
}
}
for key in to_evict {
self.remove(&key)?;
}
Ok(())
}
/// Snapshot of the cache's current occupancy alongside its configured
/// limits and policy.
pub fn stats(&self) -> CacheStats {
    let entries = self.len();
    let size_bytes = self.size();
    CacheStats {
        entries,
        size_bytes,
        max_size_bytes: self.config.max_size,
        max_entries: self.config.max_entries,
        policy: self.config.policy,
    }
}
}
/// Point-in-time snapshot of cache occupancy and configuration, as returned
/// by `Cache::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
/// Number of entries currently in memory.
pub entries: usize,
/// Total payload bytes currently tracked in memory.
pub size_bytes: usize,
/// Configured byte budget.
pub max_size_bytes: usize,
/// Configured entry-count limit.
pub max_entries: usize,
/// Configured eviction policy.
pub policy: CachePolicy,
}
#[cfg(test)]
mod tests {
use super::*;
// A default config should always produce a usable cache.
#[test]
fn test_cache_creation() {
let config = CacheConfig::default();
let cache = Cache::new(config);
assert!(cache.is_ok());
}
// Round-trip: a stored value comes back unchanged.
#[test]
fn test_cache_put_get() -> Result<()> {
let config = CacheConfig::minimal();
let cache = Cache::new(config)?;
let key = "test_key".to_string();
let data = Bytes::from("test_data");
cache.put(key.clone(), data.clone())?;
let retrieved = cache.get(&key)?;
assert_eq!(retrieved, Some(data));
Ok(())
}
// Overfill a tiny cache (5 entries of 25 B against a 100 B budget) and
// check that a further entry can still be inserted and read back after
// eviction makes room.
#[test]
fn test_cache_eviction() -> Result<()> {
let mut config = CacheConfig::minimal();
config.max_size = 100;
config.max_entries = 10;
let cache = Cache::new(config)?;
for i in 0..5 {
let key = format!("key_{}", i);
let data = Bytes::from(vec![0u8; 25]);
cache.put(key, data)?;
}
let key = "new_key".to_string();
let data = Bytes::from(vec![0u8; 25]);
cache.put(key.clone(), data.clone())?;
let retrieved = cache.get(&key)?;
assert_eq!(retrieved, Some(data));
Ok(())
}
// remove() returns the stored payload and subsequent gets miss.
#[test]
fn test_cache_remove() -> Result<()> {
let config = CacheConfig::minimal();
let cache = Cache::new(config)?;
let key = "test_key".to_string();
let data = Bytes::from("test_data");
cache.put(key.clone(), data.clone())?;
let removed = cache.remove(&key)?;
assert_eq!(removed, Some(data));
assert_eq!(cache.get(&key)?, None);
Ok(())
}
// clear() empties the cache entirely.
#[test]
fn test_cache_clear() -> Result<()> {
let config = CacheConfig::minimal();
let cache = Cache::new(config)?;
for i in 0..5 {
let key = format!("key_{}", i);
let data = Bytes::from(format!("data_{}", i));
cache.put(key, data)?;
}
assert_eq!(cache.len(), 5);
cache.clear()?;
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
Ok(())
}
// A zero-second TTL entry must report expired almost immediately; the
// short sleep guards against clock granularity.
#[test]
fn test_entry_expiration() {
let key = "test".to_string();
let data = Bytes::from("data");
let entry = CacheEntry::with_ttl(key, data, 0);
std::thread::sleep(std::time::Duration::from_millis(10));
assert!(entry.is_expired());
}
}