use crate::error::Result;
use crate::traits::cache::Cache;
use async_trait::async_trait;
use moka::Expiry;
use moka::future::Cache as MokaCache;
use std::time::{Duration, Instant};
/// TTL used by `InMemoryCache::new` and `InMemoryCacheBuilder::new` when no
/// other default is configured: one day.
const DEFAULT_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// A cached payload together with its optional per-entry lifetime.
///
/// `Clone` is required because moka hands out clones of stored values.
#[derive(Clone)]
struct CacheEntry {
    /// Raw bytes returned by `get_bytes`.
    value: Vec<u8>,
    /// Per-entry TTL override; `None` means the expiry policy's default TTL
    /// is used instead.
    custom_ttl: Option<Duration>,
}
/// Expiry policy that honours each entry's own TTL and otherwise falls back
/// to a cache-wide default.
struct CacheExpiry {
    /// Lifetime given to entries whose `custom_ttl` is `None`.
    default_ttl: Duration,
}
impl Expiry<String, CacheEntry> for CacheExpiry {
    /// A newly inserted entry lives for its own TTL, or the default when it
    /// has none.
    fn expire_after_create(
        &self,
        _key: &String,
        value: &CacheEntry,
        _created_at: Instant,
    ) -> Option<Duration> {
        value.custom_ttl.or(Some(self.default_ttl))
    }

    /// Reads leave the remaining lifetime untouched: whatever time was left
    /// before the read is passed straight back.
    fn expire_after_read(
        &self,
        _key: &String,
        _value: &CacheEntry,
        _read_at: Instant,
        duration_until_expiry: Option<Duration>,
        _last_modified_at: Instant,
    ) -> Option<Duration> {
        duration_until_expiry
    }

    /// Overwriting a key restarts its clock from the new value's TTL (or the
    /// default), ignoring any time left on the previous value.
    fn expire_after_update(
        &self,
        _key: &String,
        value: &CacheEntry,
        _updated_at: Instant,
        _duration_until_expiry: Option<Duration>,
    ) -> Option<Duration> {
        value.custom_ttl.or(Some(self.default_ttl))
    }
}
/// Bounded, TTL-aware in-process cache of byte values keyed by string.
///
/// Cloning is cheap in the sense that clones share the same underlying moka
/// cache rather than copying its contents.
#[derive(Clone)]
pub struct InMemoryCache {
    /// Backing moka cache; expiry is driven by `CacheExpiry`.
    inner: MokaCache<String, CacheEntry>,
}
impl InMemoryCache {
    /// Creates a cache holding at most `max_entries` entries, using
    /// [`DEFAULT_TTL`] (one day) for entries stored without an explicit TTL.
    #[must_use]
    pub fn new(max_entries: u64) -> Self {
        Self::with_ttl(max_entries, DEFAULT_TTL)
    }

    /// Creates a cache holding at most `max_entries` entries, using
    /// `default_ttl` for entries stored without an explicit TTL.
    #[must_use]
    pub fn with_ttl(max_entries: u64, default_ttl: Duration) -> Self {
        // Delegate to the builder so the moka configuration is defined in a
        // single place.
        Self::builder()
            .max_entries(max_entries)
            .time_to_live(default_ttl)
            .build()
    }

    /// Returns a builder for configuring capacity and default TTL.
    #[must_use]
    pub fn builder() -> InMemoryCacheBuilder {
        InMemoryCacheBuilder::new()
    }

    /// Runs moka's pending maintenance work (such as evicting expired or
    /// over-capacity entries) to completion.
    ///
    /// Mainly useful before inspecting [`Self::entry_count`], e.g. in tests.
    pub async fn run_pending_tasks(&self) {
        self.inner.run_pending_tasks().await;
    }

    /// Returns moka's entry count.
    ///
    /// The count may lag behind recent writes until
    /// [`Self::run_pending_tasks`] has run.
    #[must_use]
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }

    /// Returns moka's weighted size.
    ///
    /// The constructors here configure no weigher, so this presumably tracks
    /// `entry_count`. NOTE(review): confirm against moka docs.
    #[must_use]
    pub fn weighted_size(&self) -> u64 {
        self.inner.weighted_size()
    }
}
/// Builder for [`InMemoryCache`] that configures capacity and default TTL.
#[derive(Debug, Clone)]
pub struct InMemoryCacheBuilder {
    /// Maximum number of entries passed to moka's `max_capacity`.
    max_entries: u64,
    /// TTL applied to entries stored without an explicit TTL.
    default_ttl: Duration,
}
impl InMemoryCacheBuilder {
    /// Creates a builder defaulting to 10,000 entries and [`DEFAULT_TTL`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_entries: 10_000,
            default_ttl: DEFAULT_TTL,
        }
    }

    /// Sets the maximum number of entries the cache may hold.
    #[must_use]
    pub fn max_entries(mut self, max: u64) -> Self {
        self.max_entries = max;
        self
    }

    /// Sets the TTL used for entries stored without an explicit TTL.
    #[must_use]
    pub fn time_to_live(mut self, ttl: Duration) -> Self {
        self.default_ttl = ttl;
        self
    }

    /// Builds the configured [`InMemoryCache`].
    ///
    /// Per-entry TTLs override `default_ttl` through `CacheExpiry`.
    #[must_use]
    pub fn build(self) -> InMemoryCache {
        let expiry = CacheExpiry {
            default_ttl: self.default_ttl,
        };
        let cache = MokaCache::builder()
            .max_capacity(self.max_entries)
            .expire_after(expiry)
            .build();
        InMemoryCache { inner: cache }
    }
}
impl Default for InMemoryCacheBuilder {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Cache for InMemoryCache {
    /// Returns a copy of the bytes stored under `key`, or `None` when the key
    /// is absent or its entry has expired. Never fails.
    async fn get_bytes(&self, key: &str) -> Result<Option<Vec<u8>>> {
        Ok(self.inner.get(key).await.map(|entry| entry.value))
    }

    /// Stores `value` under `key`, replacing any previous entry.
    ///
    /// `ttl` overrides the cache's default TTL for this entry; `None` uses
    /// the default. Overwriting an existing key restarts its expiry clock.
    async fn set_bytes(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
        let entry = CacheEntry {
            value,
            custom_ttl: ttl,
        };
        self.inner.insert(key.to_owned(), entry).await;
        Ok(())
    }

    /// Removes `key` if present. Deleting a missing key is not an error.
    async fn delete(&self, key: &str) -> Result<()> {
        // `invalidate` drops the entry without returning it. `remove` would
        // hand back a clone of the stored bytes, which is wasted work here
        // because the value is discarded.
        self.inner.invalidate(key).await;
        Ok(())
    }

    /// Discards every entry in the cache.
    async fn clear(&self) -> Result<()> {
        self.inner.invalidate_all();
        // Run maintenance so invalidated entries are actually purged before
        // this call returns.
        self.inner.run_pending_tasks().await;
        Ok(())
    }

    /// Always healthy: the cache lives entirely in process memory and has no
    /// external backend that could become unavailable.
    fn is_healthy(&self) -> bool {
        true
    }
}
impl Default for InMemoryCache {
fn default() -> Self {
Self::new(10_000)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::cache::CacheExt;

    #[tokio::test]
    async fn test_get_set() {
        let cache = InMemoryCache::new(100);
        cache.set("key1", &"value1", None).await.unwrap();
        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        // The default TTL is far longer than the sleep, so only the
        // per-entry TTL can explain the entry disappearing.
        let cache = InMemoryCache::with_ttl(100, Duration::from_secs(60));
        cache
            .set("key1", &"value1", Some(Duration::from_millis(10)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        cache.run_pending_tasks().await;
        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_default_ttl_expiration() {
        // With no per-entry TTL, the cache-wide default must apply.
        let cache = InMemoryCache::with_ttl(100, Duration::from_millis(10));
        cache.set("key1", &"value1", None).await.unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        cache.run_pending_tasks().await;
        assert_eq!(cache.get::<String>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_custom_ttl_outlives_short_default() {
        // A long per-entry TTL must override a short default.
        let cache = InMemoryCache::with_ttl(100, Duration::from_millis(10));
        cache
            .set("key1", &"value1", Some(Duration::from_secs(60)))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        cache.run_pending_tasks().await;
        assert_eq!(
            cache.get::<String>("key1").await.unwrap(),
            Some("value1".to_string())
        );
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = InMemoryCache::new(100);
        cache.set("key1", &"value1", None).await.unwrap();
        cache.delete("key1").await.unwrap();
        let value: Option<String> = cache.get("key1").await.unwrap();
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_delete_missing_key_is_ok() {
        let cache = InMemoryCache::new(100);
        cache.delete("never-set").await.unwrap();
        assert_eq!(cache.get::<String>("never-set").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = InMemoryCache::new(100);
        cache.set("key1", &"value1", None).await.unwrap();
        cache.set("key2", &"value2", None).await.unwrap();
        cache.clear().await.unwrap();
        assert_eq!(cache.get::<String>("key1").await.unwrap(), None);
        assert_eq!(cache.get::<String>("key2").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_bounded_cache_does_not_grow_unbounded() {
        let cache = InMemoryCache::new(10);
        for i in 0..100 {
            cache
                .set(&format!("key{}", i), &format!("value{}", i), None)
                .await
                .unwrap();
        }
        cache.run_pending_tasks().await;
        let size = cache.entry_count();
        // Eviction is approximate, so allow a little slack over the bound.
        assert!(
            size <= 15,
            "Cache should be bounded near max_entries, got {}",
            size
        );
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        use std::sync::Arc;
        let cache = Arc::new(InMemoryCache::new(1000));
        let mut handles = vec![];
        for i in 0..10 {
            let cache = cache.clone();
            handles.push(tokio::spawn(async move {
                for j in 0..100 {
                    let key = format!("key{}_{}", i, j);
                    cache
                        .set(&key, &format!("value{}_{}", i, j), None)
                        .await
                        .unwrap();
                    let _: Option<String> = cache.get(&key).await.unwrap();
                }
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        cache.set("final", &"value", None).await.unwrap();
        let value: Option<String> = cache.get("final").await.unwrap();
        assert_eq!(value, Some("value".to_string()));
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let cache = InMemoryCache::builder()
            .max_entries(500)
            .time_to_live(Duration::from_secs(60))
            .build();
        cache.set("key", &"value", None).await.unwrap();
        let value: Option<String> = cache.get("key").await.unwrap();
        assert_eq!(value, Some("value".to_string()));
    }
}