use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Key-value storage interface for [`CacheEntry`] values addressed by string keys.
///
/// Implementations must be `Send + Sync`; every method takes `&self`, so any
/// mutation has to happen through interior mutability.
pub trait Cache: Send + Sync {
/// Looks up `key` and returns the stored entry by value, or `None` when the
/// implementation has nothing to serve for that key.
fn get(&self, key: &str) -> Option<CacheEntry>;
/// Stores `entry` under `key`.
fn set(&self, key: &str, entry: CacheEntry);
/// Removes whatever is stored under `key`.
fn delete(&self, key: &str);
}
/// A cached JSON value together with the metadata needed to decide whether it
/// may still be served.
#[derive(Debug, Clone)]
pub struct CacheEntry {
/// The cached payload.
pub value: Value,
/// Expiry time in whole seconds since the Unix epoch
/// (`create_cache_entry` sets it to "now" plus `max-age`).
pub expires_at: u64,
/// The parsed `Cache-Control` directives the entry was created from.
pub cache_control: CacheControlDirectives,
}
/// The subset of `Cache-Control` directives this module understands.
///
/// The `Default` value has every flag `false` and every option `None`, which
/// is what `parse_cache_control` returns for a missing header.
#[derive(Debug, Clone, Default)]
pub struct CacheControlDirectives {
/// `no-store` was present.
pub no_store: bool,
/// `no-cache` was present.
pub no_cache: bool,
/// `private` was present.
pub private: bool,
/// Value of `max-age`, in seconds, when present and numeric.
pub max_age: Option<u64>,
/// Value of `stale-while-revalidate`, in seconds, when present and numeric.
pub stale_while_revalidate: Option<u64>,
}
/// Parses a `Cache-Control` header value into [`CacheControlDirectives`].
///
/// Directive names are matched case-insensitively. Unknown directives and
/// arguments that do not parse as `u64` are ignored. A `None` header yields
/// the default (all-unset) directives.
///
/// Arguments are accepted in both token and quoted-string form
/// (`max-age=60` and `max-age="60"`). Qualified `no-cache="…"` / `private="…"`
/// directives are treated like their unqualified forms, which is the
/// conservative reading for a cache that does not track individual fields.
pub fn parse_cache_control(header: Option<&str>) -> CacheControlDirectives {
    let mut directives = CacheControlDirectives::default();
    let header = match header {
        Some(h) => h,
        None => return directives,
    };
    for part in header.split(',') {
        // Split `name[=argument]`; only the name needs case folding.
        let mut pieces = part.splitn(2, '=');
        let name = pieces.next().unwrap_or("").trim().to_ascii_lowercase();
        let argument = pieces.next().map(|raw| {
            let raw = raw.trim();
            // Drop one pair of surrounding double quotes, if both are present.
            raw.strip_prefix('"')
                .and_then(|inner| inner.strip_suffix('"'))
                .unwrap_or(raw)
        });
        match (name.as_str(), argument) {
            ("no-store", None) => directives.no_store = true,
            // A field-list argument may have been cut by the comma split;
            // the directive name alone is enough to set the flag.
            ("no-cache", _) => directives.no_cache = true,
            ("private", _) => directives.private = true,
            ("max-age", Some(raw)) => {
                if let Ok(seconds) = raw.parse::<u64>() {
                    directives.max_age = Some(seconds);
                }
            }
            ("stale-while-revalidate", Some(raw)) => {
                if let Ok(seconds) = raw.parse::<u64>() {
                    directives.stale_while_revalidate = Some(seconds);
                }
            }
            _ => {}
        }
    }
    directives
}
/// Builds a [`CacheEntry`] for `value` from a response's `Cache-Control` header.
///
/// Returns `None` when the response is not storable by this module's rules:
/// the header contains `no-store`, or it carries no usable `max-age` (which
/// includes a missing header).
///
/// The expiry is "now" (whole seconds since the Unix epoch) plus `max-age`.
pub fn create_cache_entry(value: Value, cache_control_header: Option<&str>) -> Option<CacheEntry> {
    let cache_control = parse_cache_control(cache_control_header);
    if cache_control.no_store {
        return None;
    }
    let max_age = cache_control.max_age?;
    // A clock set before the epoch is treated as time zero instead of panicking.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    Some(CacheEntry {
        value,
        // `max_age` comes straight from a header; saturate so an enormous
        // value means "effectively never expires" rather than overflowing.
        expires_at: now.saturating_add(max_age),
        cache_control,
    })
}
/// Builds the cache key for a request.
///
/// The key is `METHOD:url`, with the method upper-cased, followed by
/// `:auth_hash` when an auth hash is supplied.
pub fn generate_cache_key(method: &str, url: &str, auth_hash: Option<&str>) -> String {
    let verb = method.to_uppercase();
    match auth_hash {
        Some(hash) => format!("{}:{}:{}", verb, url, hash),
        None => format!("{}:{}", verb, url),
    }
}
/// Returns a short hex fingerprint of `s`: the first 8 bytes of its SHA-256
/// digest, hex-encoded (16 hex characters).
pub fn hash_string(s: &str) -> String {
    let digest = Sha256::digest(s.as_bytes());
    let prefix = &digest[..8];
    hex::encode(prefix)
}
/// In-memory [`Cache`] with a bounded number of entries.
///
/// When the bound is reached, `set` evicts keys starting from the front of
/// `order`, i.e. the key that was inserted earliest.
pub struct MemoryCache {
// Key -> entry map holding the cached data.
store: Arc<RwLock<HashMap<String, CacheEntry>>>,
// Keys in insertion order; the front is the next eviction candidate.
order: Arc<RwLock<VecDeque<String>>>,
// Entry count at which `set` starts evicting.
max_entries: usize,
}
impl MemoryCache {
pub fn new(max_entries: usize) -> Self {
Self {
store: Arc::new(RwLock::new(HashMap::with_capacity(max_entries))),
order: Arc::new(RwLock::new(VecDeque::with_capacity(max_entries))),
max_entries,
}
}
pub fn size(&self) -> usize {
self.store.read().unwrap().len()
}
pub fn clear(&self) {
let mut store = self.store.write().unwrap();
let mut order = self.order.write().unwrap();
store.clear();
order.clear();
}
}
impl Cache for MemoryCache {
    /// Returns a clone of the entry stored under `key` if it may still be served.
    ///
    /// An entry is fresh while the current time is strictly before
    /// `expires_at`. Once it has expired it is still returned during its
    /// `stale-while-revalidate` window, if it has one; otherwise `None`.
    /// Expired entries are not removed here (only a read lock is held).
    ///
    /// # Panics
    /// Panics if the store lock is poisoned.
    fn get(&self, key: &str) -> Option<CacheEntry> {
        let store = self.store.read().unwrap();
        let entry = store.get(key)?;
        // A clock set before the epoch is treated as time zero instead of panicking.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Fresh means lifetime > age, so the entry is already stale at the
        // expiry second itself; in particular `max-age=0` is never fresh.
        if now < entry.expires_at {
            return Some(entry.clone());
        }
        let swr = entry.cache_control.stale_while_revalidate?;
        // Both operands can originate from headers; saturate instead of overflowing.
        let stale_deadline = entry.expires_at.saturating_add(swr);
        if now < stale_deadline {
            Some(entry.clone())
        } else {
            None
        }
    }

    /// Stores `entry` under `key`.
    ///
    /// Entries marked `no-store` are dropped, and a cache created with
    /// `max_entries == 0` stores nothing. Replacing an existing key updates
    /// it in place: nothing is evicted and its position in the eviction order
    /// is kept. Inserting a new key into a full cache first evicts the
    /// earliest-inserted keys until there is room.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    fn set(&self, key: &str, entry: CacheEntry) {
        if entry.cache_control.no_store || self.max_entries == 0 {
            return;
        }
        // Locks are taken store-first, then order.
        let mut store = self.store.write().unwrap();
        let mut order = self.order.write().unwrap();
        // An overwrite does not grow the map, so it must not cost another
        // entry its slot.
        if let Some(slot) = store.get_mut(key) {
            *slot = entry;
            return;
        }
        while store.len() >= self.max_entries {
            match order.pop_front() {
                Some(oldest) => {
                    store.remove(&oldest);
                }
                // Nothing left to evict.
                None => break,
            }
        }
        order.push_back(key.to_string());
        store.insert(key.to_string(), entry);
    }

    /// Removes `key` from both the map and the eviction order.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    fn delete(&self, key: &str) {
        let mut store = self.store.write().unwrap();
        let mut order = self.order.write().unwrap();
        store.remove(key);
        order.retain(|k| k != key);
    }
}
impl Default for MemoryCache {
fn default() -> Self {
Self::new(100)
}
}
// Unit tests for header parsing, entry creation, the in-memory cache and hashing.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Covers: missing header, a lone flag, a lone numeric directive, and a
// comma-separated combination.
#[test]
fn test_parse_cache_control() {
let d = parse_cache_control(None);
assert!(!d.no_store);
assert!(d.max_age.is_none());
let d = parse_cache_control(Some("no-store"));
assert!(d.no_store);
let d = parse_cache_control(Some("max-age=3600"));
assert_eq!(d.max_age, Some(3600));
let d = parse_cache_control(Some("private, max-age=300, stale-while-revalidate=60"));
assert!(d.private);
assert_eq!(d.max_age, Some(300));
assert_eq!(d.stale_while_revalidate, Some(60));
}
// An entry is only created when the header is storable and has a max-age.
#[test]
fn test_create_cache_entry() {
assert!(create_cache_entry(json!({}), Some("no-store")).is_none());
// `private` alone yields no entry because there is no `max-age`.
assert!(create_cache_entry(json!({}), Some("private")).is_none());
let entry = create_cache_entry(json!({"test": true}), Some("max-age=3600"));
assert!(entry.is_some());
let entry = entry.unwrap();
assert_eq!(entry.value, json!({"test": true}));
}
// Round trip through set / get / delete on a small cache.
#[test]
fn test_memory_cache() {
let cache = MemoryCache::new(2);
let entry = create_cache_entry(json!("v1"), Some("max-age=3600")).unwrap();
cache.set("k1", entry);
assert!(cache.get("k1").is_some());
assert!(cache.get("k2").is_none());
cache.delete("k1");
assert!(cache.get("k1").is_none());
}
// Hashing is deterministic and distinguishes different inputs.
#[test]
fn test_hash_string() {
let h1 = hash_string("test");
let h2 = hash_string("test");
assert_eq!(h1, h2);
let h3 = hash_string("other");
assert_ne!(h1, h3);
}
}