use anyhow::Result;
use async_trait::async_trait;
use moka::sync::Cache as MokaCache;
use crate::utils::error::OpenCratesError;
use moka::future::Cache;
#[cfg(feature = "redis")]
use deadpool_redis::Pool as RedisPool;
#[cfg(feature = "redis")]
use redis::{AsyncCommands, Client as RedisClient};
use serde::{Deserialize, Serialize};
use tracing::Level;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock;
/// Marker trait for types usable as cache keys: cloneable, hashable,
/// comparable, thread-safe, debuggable, and free of borrowed data.
pub trait CacheKey: Clone + Eq + Hash + Send + Sync + fmt::Debug + 'static {}
/// Blanket implementation: any type meeting the bounds is automatically a `CacheKey`.
impl<T> CacheKey for T where T: Clone + Eq + Hash + Send + Sync + fmt::Debug + 'static {}
/// Point-in-time operational counters reported by a cache backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    // Number of successful lookups.
    pub hits: u64,
    // Number of lookups that found nothing.
    pub misses: u64,
    // Current number of stored entries (filled in by `stats()`).
    pub entries: usize,
    // Number of insert operations performed.
    pub sets: u64,
    // Number of delete operations that actually removed an entry.
    pub deletes: u64,
    // Entries evicted by the backend. NOTE(review): no backend in this
    // file updates this counter — confirm whether it is wired up elsewhere.
    pub evictions: u64,
    // Approximate memory usage in bytes (backend-dependent; may stay 0).
    pub memory_usage: usize,
    // Time elapsed since the cache was created.
    pub uptime: Duration,
}
impl Default for CacheStats {
fn default() -> Self {
Self::new()
}
}
impl CacheStats {
    /// Creates a stats record with every counter zeroed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            hits: 0,
            misses: 0,
            entries: 0,
            sets: 0,
            deletes: 0,
            evictions: 0,
            memory_usage: 0,
            uptime: Duration::new(0, 0),
        }
    }
    /// Returns the hit rate `hits / (hits + misses)` in `[0.0, 1.0]`,
    /// or `0.0` when no lookups have been recorded.
    ///
    /// Now takes `&self`: the previous `&mut self` receiver was misleading
    /// since nothing is mutated. Existing call sites on mutable bindings
    /// remain valid (backward compatible).
    pub fn calculate_hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total > 0 {
            // Precision note: u64 -> f64 may round for counts > 2^53.
            self.hits as f64 / total as f64
        } else {
            0.0
        }
    }
}
/// Uniform async interface implemented by every cache backend in this
/// module (in-memory, Redis, layered, middleware-wrapped).
#[async_trait]
pub trait CacheBackend<K, V>: Send + Sync + fmt::Debug
where
    K: CacheKey + fmt::Debug,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Looks up `key`, returning the cached value if present.
    async fn get(&self, key: &K) -> Result<Option<V>>;
    /// Stores `value` under `key`, optionally with a per-entry TTL.
    /// Backends without per-entry TTL support may ignore `ttl`.
    async fn set(&self, key: K, value: V, ttl: Option<Duration>) -> Result<()>;
    /// Removes `key`; returns whether an entry was actually deleted.
    async fn delete(&self, key: &K) -> Result<bool>;
    /// Returns whether `key` is currently cached.
    async fn exists(&self, key: &K) -> Result<bool>;
    /// Removes every entry.
    async fn clear(&self) -> Result<()>;
    /// Lists cached keys. Backends that cannot enumerate keys return an
    /// empty list (see the memory and Redis implementations below).
    async fn keys(&self) -> Result<Vec<K>>;
    /// Returns the number of stored entries.
    async fn size(&self) -> Result<usize>;
    /// Returns a snapshot of operational counters.
    async fn stats(&self) -> Result<CacheStats>;
    /// Sets or updates the TTL for `key`; returns whether it took effect.
    async fn expire(&self, key: &K, ttl: Duration) -> Result<bool>;
    /// Returns the remaining TTL for `key`, if one is tracked.
    async fn ttl(&self, key: &K) -> Result<Option<Duration>>;
}
/// In-process cache backend built on moka's thread-safe synchronous cache.
#[derive(Debug)]
pub struct MemoryCache<K, V>
where
    K: CacheKey + Eq + Hash + Send + Sync + fmt::Debug + 'static,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    // Underlying moka cache; enforces max capacity and the cache-wide TTL.
    cache: Arc<MokaCache<K, V>>,
    // Counters updated by the `CacheBackend` methods.
    stats: Arc<RwLock<CacheStats>>,
    // Creation time, used to report uptime in `stats()`.
    created_at: Instant,
}
impl<K, V> MemoryCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Creates a cache bounded by `max_capacity` entries whose entries
    /// expire `default_ttl` after insertion.
    ///
    /// `_max_memory` is currently ignored: moka limits by entry count
    /// here, not byte size (no weigher is configured).
    #[must_use]
    pub fn new(max_capacity: usize, _max_memory: usize, default_ttl: Duration) -> Self {
        // Delegate to `with_ttl` so the moka builder configuration lives in
        // exactly one place (previously duplicated in both constructors).
        Self::with_ttl(max_capacity as u64, default_ttl)
    }
    /// Creates a cache bounded by `max_capacity` entries with a cache-wide
    /// time-to-live of `ttl`.
    #[must_use]
    pub fn with_ttl(max_capacity: u64, ttl: Duration) -> Self {
        Self {
            cache: Arc::new(
                MokaCache::builder()
                    .max_capacity(max_capacity)
                    .time_to_live(ttl)
                    .build(),
            ),
            stats: Arc::new(RwLock::new(CacheStats::new())),
            created_at: Instant::now(),
        }
    }
}
#[async_trait]
impl<K, V> CacheBackend<K, V> for MemoryCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Looks up `key` in the moka cache and records a hit or miss.
    async fn get(&self, key: &K) -> Result<Option<V>> {
        let result = self.cache.get(key);
        // Every lookup takes the async write lock to bump a counter.
        let mut stats = self.stats.write().await;
        if result.is_some() {
            stats.hits += 1;
        } else {
            stats.misses += 1;
        }
        Ok(result)
    }
    /// Inserts `key` -> `value`.
    ///
    /// NOTE(review): the per-entry `_ttl` argument is ignored — only the
    /// cache-wide `time_to_live` configured at construction applies.
    async fn set(&self, key: K, value: V, _ttl: Option<Duration>) -> Result<()> {
        self.cache.insert(key, value);
        let mut stats = self.stats.write().await;
        stats.sets += 1;
        Ok(())
    }
    /// Removes `key`, returning whether an entry was actually present.
    async fn delete(&self, key: &K) -> Result<bool> {
        let deleted = self.cache.remove(key).is_some();
        if deleted {
            let mut stats = self.stats.write().await;
            stats.deletes += 1;
        }
        Ok(deleted)
    }
    async fn exists(&self, key: &K) -> Result<bool> {
        Ok(self.cache.contains_key(key))
    }
    /// Invalidates every entry. The counters in `stats` are kept as-is.
    async fn clear(&self) -> Result<()> {
        self.cache.invalidate_all();
        Ok(())
    }
    /// NOTE(review): stub — key enumeration is not wired up for the moka
    /// backend, so this always returns an empty list.
    async fn keys(&self) -> Result<Vec<K>> {
        Ok(Vec::new())
    }
    async fn size(&self) -> Result<usize> {
        Ok(self.cache.entry_count() as usize)
    }
    /// Returns a snapshot of the counters plus live entry count and uptime.
    async fn stats(&self) -> Result<CacheStats> {
        let mut stats = self.stats.read().await.clone();
        stats.entries = self.cache.entry_count() as usize;
        stats.uptime = self.created_at.elapsed();
        Ok(stats)
    }
    /// NOTE(review): stub — per-entry TTL adjustment is unsupported by this
    /// backend, yet it reports success unconditionally.
    async fn expire(&self, _key: &K, _ttl: Duration) -> Result<bool> {
        Ok(true)
    }
    /// NOTE(review): stub — per-entry TTLs are not tracked, so the remaining
    /// TTL is always reported as `None`.
    async fn ttl(&self, _key: &K) -> Result<Option<Duration>> {
        Ok(None)
    }
}
/// Redis-backed cache storing values as JSON strings under a key prefix.
/// Only compiled when the `redis` feature is enabled.
#[cfg(feature = "redis")]
pub struct RedisCache<K, V> {
    // Deadpool connection pool.
    pool: RedisPool,
    // Namespace prepended to every key (see `make_key`).
    prefix: String,
    // Ties the otherwise-unused `K`/`V` parameters to this type.
    _phantom: std::marker::PhantomData<(K, V)>,
}
#[cfg(feature = "redis")]
impl<K, V> fmt::Debug for RedisCache<K, V> {
    /// Manual `Debug`: the pool is not `Debug`, so a placeholder string is
    /// shown in its place alongside the key prefix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("RedisCache");
        dbg.field("prefix", &self.prefix);
        dbg.field("pool", &"RedisPool");
        dbg.finish()
    }
}
#[cfg(feature = "redis")]
impl<K, V> RedisCache<K, V>
where
    K: CacheKey + redis::ToRedisArgs + redis::FromRedisValue,
    V: Clone + Send + Sync + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Connects to Redis at `redis_url` with the default `"opencrates"`
    /// key prefix.
    ///
    /// # Errors
    /// Fails if the URL is invalid or the connection pool cannot be built.
    pub async fn new(redis_url: &str) -> Result<Self> {
        Self::with_prefix(redis_url, "opencrates").await
    }
    /// Connects to Redis at `redis_url`, namespacing all keys under
    /// `prefix`. Generalizes `new`, which hard-coded the prefix.
    ///
    /// # Errors
    /// Fails if the URL is invalid or the connection pool cannot be built.
    pub async fn with_prefix(redis_url: &str, prefix: &str) -> Result<Self> {
        // Validate the URL eagerly so a bad address fails here rather than
        // on the first pool checkout.
        let _client = RedisClient::open(redis_url)
            .map_err(|e| anyhow::anyhow!("Failed to create Redis client: {}", e))?;
        let manager = deadpool_redis::Manager::new(redis_url)
            .map_err(|e| anyhow::anyhow!("Failed to create Redis manager: {}", e))?;
        let pool = deadpool_redis::Pool::builder(manager)
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to create Redis pool: {}", e))?;
        Ok(Self {
            pool,
            prefix: prefix.to_string(),
            _phantom: std::marker::PhantomData,
        })
    }
    /// Checks out a pooled connection.
    async fn get_connection(&self) -> Result<deadpool_redis::Connection> {
        self.pool
            .get()
            .await
            .map_err(|e| anyhow::anyhow!("Failed to get Redis connection: {}", e))
    }
    /// Builds the namespaced Redis key as `"<prefix>:<key as Debug>"`.
    ///
    /// NOTE(review): this relies on the key's `Debug` representation being
    /// stable; a change in `Debug` output silently orphans existing entries.
    fn make_key(&self, key: &K) -> String {
        format!("{}:{:?}", self.prefix, key)
    }
}
#[cfg(feature = "redis")]
#[async_trait]
impl<K, V> CacheBackend<K, V> for RedisCache<K, V>
where
    K: CacheKey + redis::ToRedisArgs + redis::FromRedisValue,
    V: Clone + Send + Sync + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static,
{
    /// Fetches the JSON payload stored under the prefixed key and
    /// deserializes it into `V`.
    async fn get(&self, key: &K) -> Result<Option<V>> {
        let redis_key = self.make_key(key);
        let mut conn = self.get_connection().await?;
        let result: Option<String> = conn
            .get(&redis_key)
            .await
            .map_err(|e| anyhow::anyhow!("Redis get error: {}", e))?;
        match result {
            Some(json_str) => {
                let value: V = serde_json::from_str(&json_str)
                    .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }
    /// Serializes `value` to JSON and stores it: SETEX when a TTL is given,
    /// plain SET otherwise.
    async fn set(&self, key: K, value: V, ttl: Option<Duration>) -> Result<()> {
        let redis_key = self.make_key(&key);
        let json_str = serde_json::to_string(&value)
            .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))?;
        let mut conn = self.get_connection().await?;
        if let Some(ttl) = ttl {
            // SETEX takes whole seconds; sub-second TTL precision is lost.
            redis::cmd("SETEX")
                .arg(&redis_key)
                .arg(ttl.as_secs())
                .arg(&json_str)
                .query_async::<_, ()>(&mut *conn)
                .await
                .map_err(|e| anyhow::anyhow!("Redis setex error: {}", e))?;
        } else {
            redis::cmd("SET")
                .arg(&redis_key)
                .arg(&json_str)
                .query_async::<_, ()>(&mut *conn)
                .await
                .map_err(|e| anyhow::anyhow!("Redis set error: {}", e))?;
        }
        Ok(())
    }
    /// Deletes the prefixed key; DEL reports how many keys were removed.
    async fn delete(&self, key: &K) -> Result<bool> {
        let redis_key = self.make_key(key);
        let mut conn = self.get_connection().await?;
        let result: i32 = redis::cmd("DEL")
            .arg(&redis_key)
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis del error: {}", e))?;
        Ok(result > 0)
    }
    async fn exists(&self, key: &K) -> Result<bool> {
        let redis_key = self.make_key(key);
        let mut conn = self.get_connection().await?;
        let result: bool = redis::cmd("EXISTS")
            .arg(&redis_key)
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis exists error: {}", e))?;
        Ok(result)
    }
    /// NOTE(review): FLUSHDB wipes the ENTIRE logical Redis database, not
    /// just keys under this cache's prefix — any other data sharing the DB
    /// is destroyed too. Confirm this is intended.
    async fn clear(&self) -> Result<()> {
        let mut conn = self.get_connection().await?;
        redis::cmd("FLUSHDB")
            .query_async::<_, ()>(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis flushdb error: {}", e))?;
        Ok(())
    }
    /// NOTE(review): stub — SCAN-based key enumeration is not implemented;
    /// this always returns an empty list.
    async fn keys(&self) -> Result<Vec<K>> {
        Ok(Vec::new())
    }
    /// NOTE(review): DBSIZE counts every key in the database, not only the
    /// keys written under this cache's prefix.
    async fn size(&self) -> Result<usize> {
        let mut conn = self.get_connection().await?;
        let size: usize = redis::cmd("DBSIZE")
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis dbsize error: {}", e))?;
        Ok(size)
    }
    /// Parses hit/miss/memory figures out of `INFO stats` output and
    /// combines them with the DBSIZE entry count.
    async fn stats(&self) -> Result<CacheStats> {
        let mut conn = self.get_connection().await?;
        let info: String = redis::cmd("INFO")
            .arg("stats")
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis info error: {}", e))?;
        let mut stats = CacheStats::new();
        // INFO output is `key:value` lines; unparsable values fall back to 0.
        for line in info.lines() {
            if let Some((key, value)) = line.split_once(':') {
                match key {
                    "keyspace_hits" => {
                        stats.hits = value.parse().unwrap_or(0);
                    }
                    "keyspace_misses" => {
                        stats.misses = value.parse().unwrap_or(0);
                    }
                    "used_memory" => {
                        stats.memory_usage = value.parse().unwrap_or(0);
                    }
                    _ => {}
                }
            }
        }
        stats.entries = self.size().await?;
        Ok(stats)
    }
    /// Applies a TTL to the prefixed key; EXPIRE reports whether the key
    /// existed and the timeout was set.
    async fn expire(&self, key: &K, ttl: Duration) -> Result<bool> {
        let redis_key = self.make_key(key);
        let mut conn = self.get_connection().await?;
        let result: bool = redis::cmd("EXPIRE")
            .arg(&redis_key)
            .arg(ttl.as_secs())
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis expire error: {}", e))?;
        Ok(result)
    }
    /// Reads the remaining TTL.
    ///
    /// NOTE(review): Redis TTL returns -1 for "exists, no expiry" and -2
    /// for "key missing"; both cases are collapsed into `None` here.
    async fn ttl(&self, key: &K) -> Result<Option<Duration>> {
        let redis_key = self.make_key(key);
        let mut conn = self.get_connection().await?;
        let ttl_secs: i64 = redis::cmd("TTL")
            .arg(&redis_key)
            .query_async(&mut *conn)
            .await
            .map_err(|e| anyhow::anyhow!("Redis ttl error: {}", e))?;
        if ttl_secs > 0 {
            Ok(Some(Duration::from_secs(ttl_secs as u64)))
        } else {
            Ok(None)
        }
    }
}
/// Two-tier cache: a fast L1 (typically memory) with an optional slower,
/// shared L2 (typically Redis). Reads fall through L1 to L2; L2 hits are
/// promoted back into L1.
#[derive(Debug)]
pub struct LayeredCache<K, V> {
    // First-level cache, always consulted first.
    l1: Arc<dyn CacheBackend<K, V>>,
    // Optional second-level cache.
    l2: Option<Arc<dyn CacheBackend<K, V>>>,
    // Counters for lookups made through this wrapper.
    stats: Arc<RwLock<CacheStats>>,
}
impl<K, V> LayeredCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Builds a layered cache from a mandatory L1 and an optional L2.
    pub fn new(l1: Arc<dyn CacheBackend<K, V>>, l2: Option<Arc<dyn CacheBackend<K, V>>>) -> Self {
        let stats = Arc::new(RwLock::new(CacheStats::new()));
        Self { l1, l2, stats }
    }
    /// Convenience constructor for a single-layer (memory-only) cache.
    pub fn memory_only(memory_cache: Arc<dyn CacheBackend<K, V>>) -> Self {
        Self::new(memory_cache, None)
    }
    /// Convenience constructor for a memory L1 backed by a Redis L2.
    pub fn with_redis(
        memory_cache: Arc<dyn CacheBackend<K, V>>,
        redis_cache: Arc<dyn CacheBackend<K, V>>,
    ) -> Self {
        Self::new(memory_cache, Some(redis_cache))
    }
}
#[async_trait]
impl<K, V> CacheBackend<K, V> for LayeredCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Reads through the layers: L1 first, then L2. An L2 hit is written
    /// back into L1 (best effort — the back-fill error is ignored).
    ///
    /// NOTE(review): hits/misses are counted both here and inside each
    /// layer, so `stats()` (which sums layer stats into these counters)
    /// double-counts lookups made through this wrapper.
    async fn get(&self, key: &K) -> Result<Option<V>> {
        if let Some(value) = self.l1.get(key).await? {
            let mut stats = self.stats.write().await;
            stats.hits += 1;
            return Ok(Some(value));
        }
        if let Some(l2) = &self.l2 {
            if let Some(value) = l2.get(key).await? {
                // Promote into L1; failure to back-fill is not fatal.
                let _ = self.l1.set(key.clone(), value.clone(), None).await;
                let mut stats = self.stats.write().await;
                stats.hits += 1;
                return Ok(Some(value));
            }
        }
        let mut stats = self.stats.write().await;
        stats.misses += 1;
        Ok(None)
    }
    /// Writes through to every configured layer.
    async fn set(&self, key: K, value: V, ttl: Option<Duration>) -> Result<()> {
        self.l1.set(key.clone(), value.clone(), ttl).await?;
        if let Some(l2) = &self.l2 {
            l2.set(key, value, ttl).await?;
        }
        Ok(())
    }
    /// Deletes from both layers; true if either layer held the key.
    async fn delete(&self, key: &K) -> Result<bool> {
        let mut deleted = self.l1.delete(key).await?;
        if let Some(l2) = &self.l2 {
            deleted = l2.delete(key).await? || deleted;
        }
        Ok(deleted)
    }
    /// True if the key exists in any layer (L1 checked first).
    async fn exists(&self, key: &K) -> Result<bool> {
        if self.l1.exists(key).await? {
            return Ok(true);
        }
        if let Some(l2) = &self.l2 {
            return l2.exists(key).await;
        }
        Ok(false)
    }
    /// Clears every layer.
    async fn clear(&self) -> Result<()> {
        self.l1.clear().await?;
        if let Some(l2) = &self.l2 {
            l2.clear().await?;
        }
        Ok(())
    }
    /// Delegates to L1 only; L2 key enumeration is not attempted.
    async fn keys(&self) -> Result<Vec<K>> {
        self.l1.keys().await
    }
    /// Sum of the layer sizes; entries present in both layers count twice.
    async fn size(&self) -> Result<usize> {
        let l1_size = self.l1.size().await?;
        if let Some(l2) = &self.l2 {
            let l2_size = l2.size().await?;
            Ok(l1_size + l2_size)
        } else {
            Ok(l1_size)
        }
    }
    /// Aggregates this wrapper's own counters with each layer's stats.
    async fn stats(&self) -> Result<CacheStats> {
        let mut combined_stats = self.stats.read().await.clone();
        let l1_stats = self.l1.stats().await?;
        combined_stats.hits += l1_stats.hits;
        combined_stats.misses += l1_stats.misses;
        combined_stats.entries += l1_stats.entries;
        combined_stats.memory_usage += l1_stats.memory_usage;
        if let Some(l2) = &self.l2 {
            let l2_stats = l2.stats().await?;
            combined_stats.hits += l2_stats.hits;
            combined_stats.misses += l2_stats.misses;
            combined_stats.entries += l2_stats.entries;
            combined_stats.memory_usage += l2_stats.memory_usage;
        }
        // The previous `combined_stats.calculate_hit_rate()` call here was
        // dead code: it computed the rate and discarded the result. Callers
        // derive the rate from `hits`/`misses` themselves.
        Ok(combined_stats)
    }
    /// Applies the TTL in every layer; true if any layer succeeded.
    async fn expire(&self, key: &K, ttl: Duration) -> Result<bool> {
        let mut expired = self.l1.expire(key, ttl).await?;
        if let Some(l2) = &self.l2 {
            expired = l2.expire(key, ttl).await? || expired;
        }
        Ok(expired)
    }
    /// First known TTL, preferring L1.
    async fn ttl(&self, key: &K) -> Result<Option<Duration>> {
        if let Some(ttl) = self.l1.ttl(key).await? {
            return Ok(Some(ttl));
        }
        if let Some(l2) = &self.l2 {
            return l2.ttl(key).await;
        }
        Ok(None)
    }
}
/// Registry of named cache instances, grouped by backend flavor.
/// Cloning is cheap: the maps hold `Arc` handles, not cache data.
#[derive(Debug, Clone)]
pub struct CacheManager {
    // Pure in-memory caches.
    memory_caches: HashMap<String, Arc<dyn CacheBackend<String, String>>>,
    // Redis-backed caches (populated only when the `redis` feature is on).
    redis_caches: HashMap<String, Arc<dyn CacheBackend<String, String>>>,
    // Layered (L1+L2) caches; these take precedence in `get_cache`.
    layered_caches: HashMap<String, Arc<dyn CacheBackend<String, String>>>,
}
impl CacheManager {
    /// Creates an empty manager with no registered caches.
    #[must_use]
    pub fn new() -> Self {
        Self {
            memory_caches: HashMap::new(),
            redis_caches: HashMap::new(),
            layered_caches: HashMap::new(),
        }
    }
    /// Creates, registers, and returns an in-memory cache under `name`.
    /// Re-using a name replaces the previous registration.
    pub fn create_memory_cache(
        &mut self,
        name: String,
        max_capacity: usize,
        max_memory: usize,
        default_ttl: Duration,
    ) -> Arc<dyn CacheBackend<String, String>> {
        let backend: Arc<dyn CacheBackend<String, String>> =
            Arc::new(MemoryCache::new(max_capacity, max_memory, default_ttl));
        self.memory_caches.insert(name, Arc::clone(&backend));
        backend
    }
    /// Creates, registers, and returns a Redis-backed cache under `name`.
    ///
    /// # Errors
    /// Fails when the Redis client or pool cannot be created.
    #[cfg(feature = "redis")]
    pub async fn create_redis_cache(
        &mut self,
        name: String,
        redis_url: &str,
    ) -> Result<Arc<dyn CacheBackend<String, String>>> {
        let backend: Arc<dyn CacheBackend<String, String>> =
            Arc::new(RedisCache::new(redis_url).await?);
        self.redis_caches.insert(name, Arc::clone(&backend));
        Ok(backend)
    }
    /// Creates, registers, and returns a layered cache under `name`.
    pub fn create_layered_cache(
        &mut self,
        name: String,
        memory_cache: Arc<dyn CacheBackend<String, String>>,
        redis_cache: Option<Arc<dyn CacheBackend<String, String>>>,
    ) -> Arc<dyn CacheBackend<String, String>> {
        let backend: Arc<dyn CacheBackend<String, String>> =
            Arc::new(LayeredCache::new(memory_cache, redis_cache));
        self.layered_caches.insert(name, Arc::clone(&backend));
        backend
    }
    /// Looks up a registered cache by name. Layered caches shadow Redis
    /// caches, which in turn shadow memory caches.
    #[must_use]
    pub fn get_cache(&self, name: &str) -> Option<Arc<dyn CacheBackend<String, String>>> {
        for registry in [&self.layered_caches, &self.redis_caches, &self.memory_caches] {
            if let Some(cache) = registry.get(name) {
                return Some(Arc::clone(cache));
            }
        }
        None
    }
}
impl Default for CacheManager {
fn default() -> Self {
Self::new()
}
}
pub mod config {
use super::{fmt, Arc, CacheBackend, CacheKey, Duration, MemoryCache, Result};
/// Declarative cache configuration consumed by `CacheBuilder`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    // Maximum number of entries the cache may hold.
    pub max_entries: usize,
    // Cache-wide entry lifetime; `None` means entries never expire.
    pub default_ttl: Option<Duration>,
    // How often expired entries should be swept.
    // NOTE(review): not consumed by `CacheBuilder::build` — confirm where
    // (or whether) this is used.
    pub cleanup_interval: Duration,
    // Whether to collect cache metrics.
    pub enable_metrics: bool,
    // Whether values above `compression_threshold` should be compressed.
    // NOTE(review): no backend in this file implements compression.
    pub enable_compression: bool,
    // Minimum value size (bytes) before compression kicks in.
    pub compression_threshold: usize,
    // Preferred eviction strategy (see `CacheBuilder::build` for support).
    pub eviction_policy: EvictionPolicy,
    // How writes propagate to a backing store.
    pub write_policy: WritePolicy,
    // Consistency guarantee requested for distributed setups.
    pub consistency_level: ConsistencyLevel,
}
/// Entry-eviction strategies a cache may be asked to use.
/// NOTE(review): `CacheBuilder::build` currently maps every variant to the
/// same moka-backed `MemoryCache`.
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
    // Least-recently-used entries are evicted first.
    LRU,
    // Least-frequently-used entries are evicted first.
    LFU,
    // Oldest-inserted entries are evicted first.
    FIFO,
    // Entries are evicted at random.
    Random,
    // Eviction driven purely by time-to-live expiry.
    TTL,
}
/// How writes propagate from the cache to a backing store.
#[derive(Debug, Clone)]
pub enum WritePolicy {
    // Write synchronously to cache and store.
    WriteThrough,
    // Buffer writes and flush them in batches.
    WriteBehind {
        // Number of writes buffered before a flush.
        batch_size: usize,
        // Maximum time a buffered write may wait.
        flush_interval: Duration,
    },
    // Write to the store only, bypassing the cache.
    WriteAround,
}
/// Consistency guarantee requested for multi-node cache deployments.
#[derive(Debug, Clone)]
pub enum ConsistencyLevel {
    // Reads may observe stale data for a bounded time.
    Eventual,
    // Reads always observe the latest committed write.
    Strong,
    // Reads within one session observe that session's writes.
    Session,
}
impl Default for CacheConfig {
    /// Sensible defaults: 10k entries, 1h TTL, 5min cleanup, metrics on,
    /// compression off, LRU / write-through / eventual consistency.
    fn default() -> Self {
        Self {
            max_entries: 10000,
            default_ttl: Some(Duration::from_secs(3600)),
            cleanup_interval: Duration::from_secs(300),
            enable_metrics: true,
            enable_compression: false,
            compression_threshold: 1024,
            eviction_policy: EvictionPolicy::LRU,
            write_policy: WritePolicy::WriteThrough,
            consistency_level: ConsistencyLevel::Eventual,
        }
    }
}
impl CacheConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the maximum number of cached entries.
    #[must_use]
    pub fn with_max_entries(self, max_entries: usize) -> Self {
        Self {
            max_entries,
            ..self
        }
    }
    /// Sets the cache-wide entry TTL.
    #[must_use]
    pub fn with_default_ttl(self, ttl: Duration) -> Self {
        Self {
            default_ttl: Some(ttl),
            ..self
        }
    }
    /// Disables entry expiry entirely.
    #[must_use]
    pub fn with_no_ttl(self) -> Self {
        Self {
            default_ttl: None,
            ..self
        }
    }
    /// Sets how often expired entries are swept.
    #[must_use]
    pub fn with_cleanup_interval(self, interval: Duration) -> Self {
        Self {
            cleanup_interval: interval,
            ..self
        }
    }
    /// Toggles metrics collection.
    #[must_use]
    pub fn with_metrics(self, enabled: bool) -> Self {
        Self {
            enable_metrics: enabled,
            ..self
        }
    }
    /// Toggles compression and sets its minimum-size threshold.
    #[must_use]
    pub fn with_compression(self, enabled: bool, threshold: usize) -> Self {
        Self {
            enable_compression: enabled,
            compression_threshold: threshold,
            ..self
        }
    }
    /// Sets the eviction strategy.
    #[must_use]
    pub fn with_eviction_policy(self, policy: EvictionPolicy) -> Self {
        Self {
            eviction_policy: policy,
            ..self
        }
    }
    /// Sets the write-propagation policy.
    #[must_use]
    pub fn with_write_policy(self, policy: WritePolicy) -> Self {
        Self {
            write_policy: policy,
            ..self
        }
    }
    /// Sets the requested consistency level.
    #[must_use]
    pub fn with_consistency_level(self, level: ConsistencyLevel) -> Self {
        Self {
            consistency_level: level,
            ..self
        }
    }
    /// Checks the configuration for internally inconsistent values.
    ///
    /// # Errors
    /// Fails on a zero entry limit, a zero cleanup interval, or a zero
    /// compression threshold while compression is enabled.
    pub fn validate(&self) -> Result<()> {
        anyhow::ensure!(self.max_entries > 0, "max_entries must be greater than 0");
        anyhow::ensure!(
            !self.cleanup_interval.is_zero(),
            "cleanup_interval must be greater than 0"
        );
        anyhow::ensure!(
            !self.enable_compression || self.compression_threshold > 0,
            "compression_threshold must be greater than 0 when compression is enabled"
        );
        Ok(())
    }
}
/// Fluent builder that turns a `CacheConfig` into a `CacheBackend`.
#[derive(Debug)]
pub struct CacheBuilder {
    // Configuration accumulated by the builder methods.
    config: CacheConfig,
}
impl CacheBuilder {
    /// Creates a builder seeded with `CacheConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: CacheConfig::default(),
        }
    }
    /// Replaces the entire configuration.
    #[must_use]
    pub fn with_config(mut self, config: CacheConfig) -> Self {
        self.config = config;
        self
    }
    /// Sets the maximum number of cached entries.
    #[must_use]
    pub fn max_entries(mut self, max_entries: usize) -> Self {
        self.config.max_entries = max_entries;
        self
    }
    /// Sets the cache-wide entry TTL.
    #[must_use]
    pub fn default_ttl(mut self, ttl: Duration) -> Self {
        self.config.default_ttl = Some(ttl);
        self
    }
    /// Enables metrics collection.
    #[must_use]
    pub fn enable_metrics(mut self) -> Self {
        self.config.enable_metrics = true;
        self
    }
    /// Enables compression for values larger than `threshold` bytes.
    #[must_use]
    pub fn enable_compression(mut self, threshold: usize) -> Self {
        self.config.enable_compression = true;
        self.config.compression_threshold = threshold;
        self
    }
    /// Validates the configuration and builds a cache backend.
    ///
    /// # Errors
    /// Returns an error when the configuration fails `validate()`.
    pub async fn build<K, V>(self) -> Result<Arc<dyn CacheBackend<K, V>>>
    where
        K: CacheKey,
        V: Clone + Send + Sync + fmt::Debug + 'static,
    {
        self.config.validate()?;
        // A `None` TTL (set via `CacheConfig::with_no_ttl`) previously hit
        // `default_ttl.unwrap()` and panicked. Map it to an effectively
        // unbounded lifetime (~30 years) instead, since `MemoryCache`
        // requires a concrete duration.
        let ttl = self
            .config
            .default_ttl
            .unwrap_or(Duration::from_secs(60 * 60 * 24 * 365 * 30));
        // Every `EvictionPolicy` variant previously mapped to an identical
        // `MemoryCache` construction; the five duplicate match arms are
        // collapsed into one. Behavior is unchanged for all policies.
        let cache: Arc<dyn CacheBackend<K, V>> =
            Arc::new(MemoryCache::new(self.config.max_entries, 0, ttl));
        Ok(cache)
    }
}
impl Default for CacheBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod config_tests {
    use super::*;
    // Builder methods must set exactly the fields they name.
    #[test]
    fn test_cache_config_builder() {
        let config = CacheConfig::new()
            .with_max_entries(5000)
            .with_default_ttl(Duration::from_secs(1800))
            .with_metrics(true)
            .with_compression(true, 512)
            .with_eviction_policy(EvictionPolicy::LFU);
        assert_eq!(config.max_entries, 5000);
        assert_eq!(config.default_ttl, Some(Duration::from_secs(1800)));
        assert!(config.enable_metrics);
        assert!(config.enable_compression);
        assert_eq!(config.compression_threshold, 512);
        assert!(matches!(config.eviction_policy, EvictionPolicy::LFU));
    }
    // `validate` must reject a zero entry count and a zero compression
    // threshold when compression is enabled.
    #[test]
    fn test_config_validation() {
        let valid_config = CacheConfig::default();
        assert!(valid_config.validate().is_ok());
        let invalid_config = CacheConfig::default().with_max_entries(0);
        assert!(invalid_config.validate().is_err());
        let invalid_compression = CacheConfig::default().with_compression(true, 0);
        assert!(invalid_compression.validate().is_err());
    }
    // End-to-end: a built cache stores a value and returns it on lookup.
    #[tokio::test]
    async fn test_cache_builder() {
        let cache = CacheBuilder::new()
            .max_entries(1000)
            .default_ttl(Duration::from_secs(600))
            .enable_metrics()
            .build::<String, String>()
            .await
            .unwrap();
        cache
            .set("test_key".to_string(), "test_value".to_string(), None)
            .await
            .unwrap();
        let result = cache.get(&"test_key".to_string()).await.unwrap();
        assert_eq!(result, Some("test_value".to_string()));
    }
}
}
pub mod middleware {
use super::{
async_trait, fmt, AsyncCommands, CacheBackend, CacheKey, CacheStats, Duration, Level,
Result, SystemTime,
};
use std::sync::Arc;
/// Hooks invoked around each cache operation by `MiddlewareCache`.
#[async_trait]
pub trait CacheMiddleware<K, V>: Send + Sync + fmt::Debug {
    /// Runs before a lookup; returning `Some` short-circuits the cache.
    async fn before_get(&self, key: &K) -> Result<Option<V>>;
    /// Runs after a lookup with the (possibly absent) result.
    async fn after_get(&self, key: &K, value: &Option<V>) -> Result<()>;
    /// Runs before a write; an `Err` aborts the write.
    async fn before_set(&self, key: &K, value: &V) -> Result<()>;
    /// Runs after a successful write.
    async fn after_set(&self, key: &K, value: &V) -> Result<()>;
    /// Runs before a delete; an `Err` aborts the delete.
    async fn before_delete(&self, key: &K) -> Result<()>;
    /// Runs after a delete with whether an entry was removed.
    async fn after_delete(&self, key: &K, deleted: bool) -> Result<()>;
}
/// Middleware that traces every cache operation at a fixed log level.
#[derive(Debug)]
pub struct LoggingMiddleware {
    // Level used for every emitted tracing event.
    log_level: Level,
}
impl LoggingMiddleware {
    /// Creates a logger emitting at `log_level`.
    #[must_use]
    pub fn new(log_level: Level) -> Self {
        Self { log_level }
    }
}
#[async_trait]
impl<K, V> CacheMiddleware<K, V> for LoggingMiddleware
where
    K: CacheKey + fmt::Display,
    V: fmt::Debug + Sync,
{
    // The `tracing` macros take the level as part of the macro name, so the
    // dynamic `self.log_level` must be dispatched with a match in every
    // hook — the repetition below is forced by that design.
    /// Logs the lookup; never short-circuits (always returns `Ok(None)`).
    async fn before_get(&self, key: &K) -> Result<Option<V>> {
        match self.log_level {
            Level::ERROR => tracing::error!("Cache GET: {}", key),
            Level::WARN => tracing::warn!("Cache GET: {}", key),
            Level::INFO => tracing::info!("Cache GET: {}", key),
            Level::DEBUG => tracing::debug!("Cache GET: {}", key),
            Level::TRACE => tracing::trace!("Cache GET: {}", key),
        }
        Ok(None)
    }
    /// Logs whether the lookup was a hit or a miss.
    async fn after_get(&self, key: &K, value: &Option<V>) -> Result<()> {
        match value {
            Some(_) => match self.log_level {
                Level::ERROR => tracing::error!("Cache HIT: {}", key),
                Level::WARN => tracing::warn!("Cache HIT: {}", key),
                Level::INFO => tracing::info!("Cache HIT: {}", key),
                Level::DEBUG => tracing::debug!("Cache HIT: {}", key),
                Level::TRACE => tracing::trace!("Cache HIT: {}", key),
            },
            None => match self.log_level {
                Level::ERROR => tracing::error!("Cache MISS: {}", key),
                Level::WARN => tracing::warn!("Cache MISS: {}", key),
                Level::INFO => tracing::info!("Cache MISS: {}", key),
                Level::DEBUG => tracing::debug!("Cache MISS: {}", key),
                Level::TRACE => tracing::trace!("Cache MISS: {}", key),
            },
        }
        Ok(())
    }
    /// Logs the impending write.
    async fn before_set(&self, key: &K, _value: &V) -> Result<()> {
        match self.log_level {
            Level::ERROR => tracing::error!("Cache SET: {}", key),
            Level::WARN => tracing::warn!("Cache SET: {}", key),
            Level::INFO => tracing::info!("Cache SET: {}", key),
            Level::DEBUG => tracing::debug!("Cache SET: {}", key),
            Level::TRACE => tracing::trace!("Cache SET: {}", key),
        }
        Ok(())
    }
    /// Logs completion of the write.
    async fn after_set(&self, key: &K, _value: &V) -> Result<()> {
        match self.log_level {
            Level::ERROR => tracing::error!("Cache SET completed: {}", key),
            Level::WARN => tracing::warn!("Cache SET completed: {}", key),
            Level::INFO => tracing::info!("Cache SET completed: {}", key),
            Level::DEBUG => tracing::debug!("Cache SET completed: {}", key),
            Level::TRACE => tracing::trace!("Cache SET completed: {}", key),
        }
        Ok(())
    }
    /// Logs the impending delete.
    async fn before_delete(&self, key: &K) -> Result<()> {
        match self.log_level {
            Level::ERROR => tracing::error!("Cache DELETE: {}", key),
            Level::WARN => tracing::warn!("Cache DELETE: {}", key),
            Level::INFO => tracing::info!("Cache DELETE: {}", key),
            Level::DEBUG => tracing::debug!("Cache DELETE: {}", key),
            Level::TRACE => tracing::trace!("Cache DELETE: {}", key),
        }
        Ok(())
    }
    /// Logs the delete outcome ("completed" when an entry was removed,
    /// "failed" when the key was absent).
    async fn after_delete(&self, key: &K, deleted: bool) -> Result<()> {
        let status = if deleted { "completed" } else { "failed" };
        match self.log_level {
            Level::ERROR => {
                tracing::error!("Cache DELETE {}: {} (deleted: {})", status, key, deleted);
            }
            Level::WARN => {
                tracing::warn!("Cache DELETE {}: {} (deleted: {})", status, key, deleted);
            }
            Level::INFO => {
                tracing::info!("Cache DELETE {}: {} (deleted: {})", status, key, deleted);
            }
            Level::DEBUG => {
                tracing::debug!("Cache DELETE {}: {} (deleted: {})", status, key, deleted);
            }
            Level::TRACE => {
                tracing::trace!("Cache DELETE {}: {} (deleted: {})", status, key, deleted);
            }
        }
        Ok(())
    }
}
/// Middleware that feeds cache activity into the application metrics sink.
#[derive(Debug)]
pub struct MetricsMiddleware {
    // Shared metrics registry.
    metrics: Arc<crate::utils::metrics::OpenCratesMetrics>,
}
impl MetricsMiddleware {
    /// Creates a middleware reporting to `metrics`.
    #[must_use]
    pub fn new(metrics: Arc<crate::utils::metrics::OpenCratesMetrics>) -> Self {
        Self { metrics }
    }
}
#[async_trait]
impl<K, V> CacheMiddleware<K, V> for MetricsMiddleware
where
    K: CacheKey,
    V: fmt::Debug + Sync,
{
    /// No-op: metrics are recorded only after the lookup completes.
    async fn before_get(&self, _key: &K) -> Result<Option<V>> {
        Ok(None)
    }
    /// Records a hit or miss depending on the lookup outcome.
    /// Metric-recording failures are deliberately swallowed (`unwrap_or`).
    async fn after_get(&self, _key: &K, value: &Option<V>) -> Result<()> {
        match value {
            Some(_) => self.metrics.record_cache_hit().await.unwrap_or(()),
            None => self.metrics.record_cache_miss().await.unwrap_or(()),
        }
        Ok(())
    }
    async fn before_set(&self, _key: &K, _value: &V) -> Result<()> {
        Ok(())
    }
    /// NOTE(review): a successful SET is recorded as a cache *hit*, which
    /// skews the hit metric. The unit tests below assert exactly this
    /// behavior, so it appears intentional — confirm before changing.
    async fn after_set(&self, _key: &K, _value: &V) -> Result<()> {
        self.metrics.record_cache_hit().await.unwrap_or(());
        Ok(())
    }
    async fn before_delete(&self, _key: &K) -> Result<()> {
        Ok(())
    }
    /// NOTE(review): deletes are folded into the hit/miss counters
    /// (hit = entry removed, miss = key absent) rather than a dedicated
    /// delete metric — confirm this mapping is wanted.
    async fn after_delete(&self, _key: &K, deleted: bool) -> Result<()> {
        if deleted {
            self.metrics.record_cache_hit().await.unwrap_or(());
        } else {
            self.metrics.record_cache_miss().await.unwrap_or(());
        }
        Ok(())
    }
}
/// Decorator that runs a middleware chain around an inner `CacheBackend`.
#[derive(Debug)]
pub struct MiddlewareCache<K, V> {
    // The wrapped backend that performs the real work.
    cache: Arc<dyn CacheBackend<K, V>>,
    // Hooks executed in insertion order around each operation.
    middlewares: Vec<Arc<dyn CacheMiddleware<K, V>>>,
}
impl<K, V> MiddlewareCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Wraps `cache` with an initially empty middleware chain.
    pub fn new(cache: Arc<dyn CacheBackend<K, V>>) -> Self {
        let middlewares = Vec::new();
        Self { cache, middlewares }
    }
    /// Appends a middleware; middlewares run in insertion order.
    pub fn add_middleware(mut self, middleware: Arc<dyn CacheMiddleware<K, V>>) -> Self {
        self.middlewares.push(middleware);
        self
    }
    /// Runs every `before_get` hook; the first hook that yields a value
    /// short-circuits the rest of the chain.
    async fn run_before_get(&self, key: &K) -> Result<Option<V>> {
        for mw in self.middlewares.iter() {
            let supplied = mw.before_get(key).await?;
            if supplied.is_some() {
                return Ok(supplied);
            }
        }
        Ok(None)
    }
    /// Runs every `after_get` hook in order.
    async fn run_after_get(&self, key: &K, value: &Option<V>) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.after_get(key, value).await?;
        }
        Ok(())
    }
    /// Runs every `before_set` hook; the first error aborts the chain.
    async fn run_before_set(&self, key: &K, value: &V) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.before_set(key, value).await?;
        }
        Ok(())
    }
    /// Runs every `after_set` hook in order.
    async fn run_after_set(&self, key: &K, value: &V) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.after_set(key, value).await?;
        }
        Ok(())
    }
    /// Runs every `before_delete` hook; the first error aborts the chain.
    async fn run_before_delete(&self, key: &K) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.before_delete(key).await?;
        }
        Ok(())
    }
    /// Runs every `after_delete` hook in order.
    async fn run_after_delete(&self, key: &K, deleted: bool) -> Result<()> {
        for mw in self.middlewares.iter() {
            mw.after_delete(key, deleted).await?;
        }
        Ok(())
    }
}
#[async_trait]
impl<K, V> CacheBackend<K, V> for MiddlewareCache<K, V>
where
    K: CacheKey,
    V: Clone + Send + Sync + fmt::Debug + 'static,
{
    /// Lookup wrapped by the hooks: a `before_get` short-circuit still
    /// triggers the `after_get` hooks before returning.
    async fn get(&self, key: &K) -> Result<Option<V>> {
        if let Some(value) = self.run_before_get(key).await? {
            self.run_after_get(key, &Some(value.clone())).await?;
            return Ok(Some(value));
        }
        let result = self.cache.get(key).await?;
        self.run_after_get(key, &result).await?;
        Ok(result)
    }
    /// Write wrapped by `before_set`/`after_set`; a failing `before_set`
    /// (e.g. rate limit or validation) aborts the write entirely.
    async fn set(&self, key: K, value: V, ttl: Option<Duration>) -> Result<()> {
        self.run_before_set(&key, &value).await?;
        self.cache.set(key.clone(), value.clone(), ttl).await?;
        self.run_after_set(&key, &value).await?;
        Ok(())
    }
    /// Delete wrapped by `before_delete`/`after_delete`.
    async fn delete(&self, key: &K) -> Result<bool> {
        self.run_before_delete(key).await?;
        let deleted = self.cache.delete(key).await?;
        self.run_after_delete(key, deleted).await?;
        Ok(deleted)
    }
    // The remaining operations bypass the middleware chain and delegate
    // straight to the wrapped backend.
    async fn exists(&self, key: &K) -> Result<bool> {
        self.cache.exists(key).await
    }
    async fn clear(&self) -> Result<()> {
        self.cache.clear().await
    }
    async fn size(&self) -> Result<usize> {
        self.cache.size().await
    }
    async fn keys(&self) -> Result<Vec<K>> {
        self.cache.keys().await
    }
    async fn stats(&self) -> Result<CacheStats> {
        self.cache.stats().await
    }
    async fn expire(&self, key: &K, ttl: Duration) -> Result<bool> {
        self.cache.expire(key, ttl).await
    }
    async fn ttl(&self, key: &K) -> Result<Option<Duration>> {
        self.cache.ttl(key).await
    }
}
/// Middleware enforcing a fixed one-second-window request budget across
/// get/set/delete operations.
#[derive(Debug)]
pub struct RateLimitMiddleware {
    // Budget for each one-second window.
    max_requests_per_second: u64,
    // Start of the current window.
    window_start: Arc<std::sync::Mutex<SystemTime>>,
    // Requests counted in the current window.
    request_count: Arc<std::sync::Mutex<u64>>,
}
impl RateLimitMiddleware {
    /// Creates a limiter allowing `max_requests_per_second` operations per
    /// one-second fixed window.
    #[must_use]
    pub fn new(max_requests_per_second: u64) -> Self {
        Self {
            max_requests_per_second,
            window_start: Arc::new(std::sync::Mutex::new(SystemTime::now())),
            request_count: Arc::new(std::sync::Mutex::new(0)),
        }
    }
    /// Counts one request against the current window.
    ///
    /// # Errors
    /// Returns an error when the per-second budget is exhausted.
    fn check_rate_limit(&self) -> Result<()> {
        let now = SystemTime::now();
        // Recover from a poisoned mutex instead of panicking (the previous
        // `.unwrap()` propagated a panic from any poisoned thread): the
        // guarded data — a timestamp and a counter — cannot be left in an
        // invalid state by a panicking holder.
        let mut window_start = self
            .window_start
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut count = self
            .request_count
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Start a fresh window once a full second has elapsed. `SystemTime`
        // can go backwards; `unwrap_or_default` treats that as "same window".
        if now.duration_since(*window_start).unwrap_or_default() >= Duration::from_secs(1) {
            *window_start = now;
            *count = 0;
        }
        if *count >= self.max_requests_per_second {
            return Err(anyhow::anyhow!("Rate limit exceeded"));
        }
        *count += 1;
        Ok(())
    }
}
#[async_trait]
impl<K, V> CacheMiddleware<K, V> for RateLimitMiddleware
where
    K: CacheKey,
    V: fmt::Debug,
{
    /// Gates reads on the rate limit; never supplies a value itself.
    async fn before_get(&self, _key: &K) -> Result<Option<V>> {
        self.check_rate_limit().map(|()| None)
    }
    async fn after_get(&self, _key: &K, _value: &Option<V>) -> Result<()> {
        Ok(())
    }
    /// Gates writes on the rate limit.
    async fn before_set(&self, _key: &K, _value: &V) -> Result<()> {
        self.check_rate_limit()
    }
    async fn after_set(&self, _key: &K, _value: &V) -> Result<()> {
        Ok(())
    }
    /// Gates deletes on the rate limit.
    async fn before_delete(&self, _key: &K) -> Result<()> {
        self.check_rate_limit()
    }
    async fn after_delete(&self, _key: &K, _deleted: bool) -> Result<()> {
        Ok(())
    }
}
/// Middleware that rejects operations whose key or value fails a
/// user-supplied validation closure.
pub struct ValidationMiddleware<K, V> {
    // Validates keys on get/set/delete.
    key_validator: Arc<dyn Fn(&K) -> Result<()> + Send + Sync>,
    // Validates values on set.
    value_validator: Arc<dyn Fn(&V) -> Result<()> + Send + Sync>,
}
impl<K, V> fmt::Debug for ValidationMiddleware<K, V> {
    /// Manual `Debug`: the validator closures are not `Debug`, so their
    /// signatures are shown as placeholder strings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("ValidationMiddleware");
        dbg.field("key_validator", &"Fn(&K) -> Result<()>");
        dbg.field("value_validator", &"Fn(&V) -> Result<()>");
        dbg.finish()
    }
}
impl<K, V> ValidationMiddleware<K, V> {
    /// Builds a middleware from a key validator and a value validator;
    /// each closure returns `Err` to reject the corresponding operation.
    pub fn new<KF, VF>(key_validator: KF, value_validator: VF) -> Self
    where
        KF: Fn(&K) -> Result<()> + Send + Sync + 'static,
        VF: Fn(&V) -> Result<()> + Send + Sync + 'static,
    {
        let key_validator: Arc<dyn Fn(&K) -> Result<()> + Send + Sync> = Arc::new(key_validator);
        let value_validator: Arc<dyn Fn(&V) -> Result<()> + Send + Sync> =
            Arc::new(value_validator);
        Self {
            key_validator,
            value_validator,
        }
    }
}
#[async_trait]
impl<K, V> CacheMiddleware<K, V> for ValidationMiddleware<K, V>
where
    K: CacheKey,
    V: fmt::Debug + Sync,
{
    /// Validates the key before a read; never supplies a value itself.
    async fn before_get(&self, key: &K) -> Result<Option<V>> {
        (self.key_validator)(key).map(|()| None)
    }
    async fn after_get(&self, _key: &K, _value: &Option<V>) -> Result<()> {
        Ok(())
    }
    /// Validates key and then value before a write.
    async fn before_set(&self, key: &K, value: &V) -> Result<()> {
        (self.key_validator)(key).and_then(|()| (self.value_validator)(value))
    }
    async fn after_set(&self, _key: &K, _value: &V) -> Result<()> {
        Ok(())
    }
    /// Validates the key before a delete.
    async fn before_delete(&self, key: &K) -> Result<()> {
        (self.key_validator)(key)
    }
    async fn after_delete(&self, _key: &K, _deleted: bool) -> Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod middleware_tests {
use super::*;
use crate::utils::cache::MemoryCache;
use crate::utils::metrics::OpenCratesMetrics;
// Smoke test: a logging middleware can be attached without panicking.
#[tokio::test]
async fn test_logging_middleware() {
    let cache = Arc::new(MemoryCache::<String, String>::new(
        100,
        0,
        Duration::from_secs(60),
    ));
    let logging_middleware = Arc::new(LoggingMiddleware::new(Level::INFO));
    let mut _middleware_cache =
        MiddlewareCache::new(cache).add_middleware(logging_middleware);
}
// Pins the metric mapping implemented by `MetricsMiddleware`:
// get-miss -> miss, set -> hit, get-hit -> hit.
#[tokio::test]
async fn test_metrics_middleware() {
    let cache = Arc::new(MemoryCache::<String, String>::new(
        100,
        0,
        Duration::from_secs(60),
    ));
    let metrics: Arc<OpenCratesMetrics> = Arc::new(
        crate::utils::metrics::OpenCratesMetrics::new()
            .await
            .unwrap(),
    );
    let metrics_middleware = Arc::new(MetricsMiddleware::new(metrics.clone()));
    let middleware_cache = MiddlewareCache::new(cache).add_middleware(metrics_middleware);
    let _ = middleware_cache.get(&"miss".to_string()).await;
    let cache_misses = metrics.cache_misses.get().await;
    assert_eq!(
        cache_misses, 1,
        "Cache misses should be incremented on get miss"
    );
    middleware_cache
        .set("key".to_string(), "value".to_string(), None)
        .await
        .unwrap();
    let cache_hits = metrics.cache_hits.get().await;
    assert_eq!(cache_hits, 1, "Cache hits should be incremented on set");
    let _ = middleware_cache.get(&"key".to_string()).await;
    let cache_hits = metrics.cache_hits.get().await;
    assert_eq!(cache_hits, 2, "Cache hits should be incremented on get hit");
}
// The third write inside one second must exceed a 2 req/s budget.
#[tokio::test]
async fn test_rate_limit_middleware() {
    let cache = Arc::new(MemoryCache::new(100, 0, Duration::from_secs(60)));
    let rate_limit_middleware = Arc::new(RateLimitMiddleware::new(2));
    let middleware_cache =
        MiddlewareCache::new(cache).add_middleware(rate_limit_middleware);
    middleware_cache
        .set("key1".to_string(), "value1".to_string(), None)
        .await
        .unwrap();
    middleware_cache
        .set("key2".to_string(), "value2".to_string(), None)
        .await
        .unwrap();
    let result = middleware_cache
        .set("key3".to_string(), "value3".to_string(), None)
        .await;
    assert!(result.is_err());
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("Rate limit exceeded"));
}
// Empty keys and values longer than 100 chars must be rejected by the
// configured validators; a valid pair must pass.
#[tokio::test]
async fn test_validation_middleware() {
    let cache = Arc::new(MemoryCache::new(100, 0, Duration::from_secs(60)));
    let validation_middleware = Arc::new(ValidationMiddleware::new(
        |key: &String| {
            if key.is_empty() {
                Err(anyhow::anyhow!("Key cannot be empty"))
            } else {
                Ok(())
            }
        },
        |value: &String| {
            if value.len() > 100 {
                Err(anyhow::anyhow!("Value too long"))
            } else {
                Ok(())
            }
        },
    ));
    let middleware_cache =
        MiddlewareCache::new(cache).add_middleware(validation_middleware);
    middleware_cache
        .set("valid_key".to_string(), "valid_value".to_string(), None)
        .await
        .unwrap();
    let result = middleware_cache
        .set(String::new(), "value".to_string(), None)
        .await;
    assert!(result.is_err());
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("Key cannot be empty"));
    let long_value = "a".repeat(101);
    let result = middleware_cache
        .set("key".to_string(), long_value, None)
        .await;
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("Value too long"));
}
#[tokio::test]
async fn test_multiple_middlewares() {
let cache = Arc::new(MemoryCache::<String, String>::new(
100,
0,
Duration::from_secs(60),
));
let logging_middleware = Arc::new(LoggingMiddleware::new(Level::INFO));
let metrics: Arc<OpenCratesMetrics> = Arc::new(
crate::utils::metrics::OpenCratesMetrics::new()
.await
.unwrap(),
);
let metrics_middleware = Arc::new(MetricsMiddleware::new(metrics.clone()));
let middleware_cache = MiddlewareCache::new(cache)
.add_middleware(logging_middleware)
.add_middleware(metrics_middleware);
middleware_cache.get(&"miss".to_string()).await.unwrap();
middleware_cache
.set("hit".to_string(), "value".to_string(), None)
.await
.unwrap();
middleware_cache.get(&"hit".to_string()).await.unwrap();
let misses = metrics.cache_misses.get().await;
let hits = metrics.cache_hits.get().await;
assert_eq!(misses, 1, "Should have 1 cache miss");
assert_eq!(hits, 2, "Should have 2 cache hits (from set and get)");
}
}
}
pub use config::*;
pub use middleware::*;
pub mod convenience {
    use super::{
        config::CacheConfig, middleware::MetricsMiddleware, middleware::ValidationMiddleware,
        CacheBackend, CacheKey, MiddlewareCache,
    };
    use crate::utils::metrics::OpenCratesMetrics;
    use anyhow::Result;
    use std::{fmt, sync::Arc, time::Duration};

    /// Builds a plain in-memory cache with `max_entries` capacity and a
    /// one-hour default TTL, returned behind the `CacheBackend` trait object.
    pub async fn create_memory_cache<K, V>(max_entries: usize) -> Arc<dyn CacheBackend<K, V>>
    where
        K: CacheKey,
        V: Clone + Send + Sync + fmt::Debug + 'static,
    {
        Arc::new(super::MemoryCache::new(
            max_entries,
            0,
            Duration::from_secs(3600),
        ))
    }

    /// Builds an in-memory cache (one-hour default TTL) wrapped with a
    /// metrics middleware, and returns both the cache and its metrics handle.
    ///
    /// # Panics
    ///
    /// Panics if the metrics registry cannot be initialized. Callers that
    /// want a fallible variant should use [`create_production_cache`].
    pub async fn create_monitored_cache<K, V>(
        max_entries: usize,
    ) -> (
        Arc<MiddlewareCache<K, V>>,
        Arc<crate::utils::metrics::OpenCratesMetrics>,
    )
    where
        K: CacheKey + fmt::Display,
        V: Clone + Send + Sync + fmt::Debug + 'static,
    {
        let cache = Arc::new(super::MemoryCache::new(
            max_entries,
            0,
            Duration::from_secs(3600),
        ));
        // `expect` (not bare `unwrap`) so the panic states why startup failed;
        // the infallible signature is kept for backward compatibility.
        let metrics: Arc<OpenCratesMetrics> = Arc::new(
            OpenCratesMetrics::new()
                .await
                .expect("failed to initialize cache metrics registry"),
        );
        let middleware = Arc::new(MetricsMiddleware::new(metrics.clone()));
        let middleware_cache = Arc::new(MiddlewareCache::new(cache).add_middleware(middleware));
        (middleware_cache, metrics)
    }

    /// Builds the full production cache stack from `config`: an in-memory
    /// backend (falling back to a one-hour TTL when the config has none)
    /// wrapped with metrics and key-validation middlewares.
    ///
    /// # Errors
    ///
    /// Returns an error if the metrics registry cannot be initialized.
    pub async fn create_production_cache<K, V>(
        config: CacheConfig,
    ) -> Result<(
        Arc<MiddlewareCache<K, V>>,
        Arc<crate::utils::metrics::OpenCratesMetrics>,
    )>
    where
        K: CacheKey + fmt::Display,
        V: Clone + Send + Sync + fmt::Debug + 'static,
    {
        let memory_cache = Arc::new(super::MemoryCache::new(
            config.max_entries,
            0,
            config.default_ttl.unwrap_or(Duration::from_secs(3600)),
        ));
        let metrics: Arc<OpenCratesMetrics> = Arc::new(OpenCratesMetrics::new().await?);
        let metrics_middleware = Arc::new(MetricsMiddleware::new(metrics.clone()));
        // Keys are rejected when their Display form is empty; values are not
        // validated here.
        let validation_middleware = Arc::new(ValidationMiddleware::new(
            |key: &K| {
                if key.to_string().is_empty() {
                    Err(anyhow::anyhow!("Cache key cannot be empty"))
                } else {
                    Ok(())
                }
            },
            |_value: &V| Ok(()),
        ));
        let middleware_cache = Arc::new(
            MiddlewareCache::new(memory_cache)
                .add_middleware(metrics_middleware)
                .add_middleware(validation_middleware),
        );
        Ok((middleware_cache, metrics))
    }

    #[cfg(test)]
    mod convenience_tests {
        use super::*;

        /// A plain memory cache round-trips a set/get.
        #[tokio::test]
        async fn test_create_memory_cache() {
            let cache = create_memory_cache::<String, String>(100).await;
            cache
                .set("key".to_string(), "value".to_string(), None)
                .await
                .unwrap();
            let value = cache.get(&"key".to_string()).await.unwrap();
            assert_eq!(value, Some("value".to_string()));
        }

        /// A monitored cache records a miss for an absent key.
        #[tokio::test]
        async fn test_create_monitored_cache() {
            let (cache, metrics) = create_monitored_cache::<String, String>(100).await;
            cache.get(&"miss".to_string()).await.unwrap();
            let misses = metrics.cache_misses.get().await;
            assert_eq!(misses, 1, "Should increment misses on get miss");
        }

        /// A production cache built from the default config records a miss
        /// for an absent key.
        #[tokio::test]
        async fn test_create_production_cache() {
            let (cache, metrics) =
                create_production_cache::<String, String>(CacheConfig::default())
                    .await
                    .unwrap();
            cache.get(&"miss".to_string()).await.unwrap();
            let misses = metrics.cache_misses.get().await;
            assert_eq!(
                misses, 1,
                "Should increment misses on get miss for production cache"
            );
        }
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use std::time::Duration;

    /// End-to-end check of the production stack: build the cache from a
    /// config, write two entries (one inheriting the default TTL, one with
    /// its own), then read both back immediately.
    #[tokio::test]
    async fn test_full_cache_lifecycle() {
        let config = CacheConfig::new()
            .with_max_entries(100)
            .with_default_ttl(Duration::from_millis(100))
            .with_cleanup_interval(Duration::from_millis(50))
            .with_metrics(true);
        let (cache, _metrics) = convenience::create_production_cache::<String, String>(config)
            .await
            .unwrap();

        // (key, value, per-entry TTL override)
        let entries = [
            ("key1", "value1", None),
            ("key2", "value2", Some(Duration::from_millis(50))),
        ];
        for (key, value, ttl) in entries {
            cache
                .set(key.to_string(), value.to_string(), ttl)
                .await
                .unwrap();
        }
        // Reads happen right after the writes, so neither TTL has elapsed.
        for (key, value, _) in entries {
            assert_eq!(
                cache.get(&key.to_string()).await.unwrap(),
                Some(value.to_string())
            );
        }
    }
}