use super::new_backend::CacheBackend;
use crate::error::{CacheError, Result};
use async_trait::async_trait;
use moka::future::Cache;
use std::collections::HashMap;
use std::time::Duration;
/// In-process [`CacheBackend`] implementation backed by a `moka::future::Cache`.
///
/// Keys are `String`s and values are raw byte vectors. Expiry and capacity are
/// configured once, at construction, through [`MemoryBackendBuilder`]. Per
/// moka's documentation, cloning a `Cache` is cheap and the clones share the same
/// underlying storage, so cloned backends see each other's writes.
#[derive(Clone)]
pub struct MemoryBackend {
    // The moka cache holding every entry; all trait methods delegate to it.
    cache: Cache<String, Vec<u8>>,
}
impl MemoryBackend {
    /// Creates a backend using the default builder configuration.
    ///
    /// Equivalent to `MemoryBackend::builder().build()`.
    pub fn new() -> Self {
        let defaults = MemoryBackendBuilder::default();
        defaults.build()
    }

    /// Returns a [`MemoryBackendBuilder`] preloaded with the default settings,
    /// ready for customisation before calling `build`.
    pub fn builder() -> MemoryBackendBuilder {
        Default::default()
    }
}
impl Default for MemoryBackend {
fn default() -> Self {
Self::new()
}
}
/// Configuration builder for [`MemoryBackend`].
///
/// Obtain one with [`MemoryBackend::builder`] (or `MemoryBackendBuilder::default()`),
/// chain the setters, then call `build`. `Debug` and `Clone` are derived so a
/// configuration can be logged or reused to build several independent backends.
#[derive(Debug, Clone)]
pub struct MemoryBackendBuilder {
    /// Maximum number of entries; forwarded to moka's `max_capacity`.
    capacity: u64,
    /// Optional time-to-live applied to every entry; `None` disables it.
    ttl: Option<Duration>,
    /// Optional time-to-idle applied to every entry; `None` disables it.
    time_to_idle: Option<Duration>,
}
impl Default for MemoryBackendBuilder {
fn default() -> Self {
Self {
capacity: 10_000, ttl: None,
time_to_idle: None,
}
}
}
impl MemoryBackendBuilder {
    /// Sets the maximum number of entries the cache may hold (default `10_000`).
    ///
    /// The value is forwarded to moka's `max_capacity`; beyond it moka evicts entries.
    #[must_use]
    pub fn capacity(mut self, capacity: u64) -> Self {
        self.capacity = capacity;
        self
    }

    /// Sets a time-to-live applied to every entry (forwarded to moka's `time_to_live`).
    #[must_use]
    pub fn ttl(mut self, ttl: Duration) -> Self {
        self.ttl = Some(ttl);
        self
    }

    /// Sets a time-to-idle applied to every entry (forwarded to moka's `time_to_idle`).
    #[must_use]
    pub fn time_to_idle(mut self, tti: Duration) -> Self {
        self.time_to_idle = Some(tti);
        self
    }

    /// Consumes the builder and constructs the configured [`MemoryBackend`].
    ///
    /// Expiry policies that were never set are simply not configured on the
    /// underlying cache.
    #[must_use]
    pub fn build(self) -> MemoryBackend {
        let mut builder = Cache::builder().max_capacity(self.capacity);
        if let Some(ttl) = self.ttl {
            builder = builder.time_to_live(ttl);
        }
        if let Some(tti) = self.time_to_idle {
            builder = builder.time_to_idle(tti);
        }
        MemoryBackend {
            cache: builder.build(),
        }
    }
}
#[async_trait]
impl CacheBackend for MemoryBackend {
    /// Looks up `key`, returning a clone of the stored bytes, or `None` when the
    /// cache holds no value for it.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
        Ok(self.cache.get(key).await)
    }

    /// Stores `value` under `key`, replacing any previous value.
    ///
    /// NOTE: the per-call `_ttl` argument is ignored. The cache has no per-entry
    /// expiry policy, so entries expire only according to the builder-level
    /// `ttl` / `time_to_idle` settings (or are evicted when over capacity).
    async fn set(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
        self.cache.insert(key.to_string(), value).await;
        Ok(())
    }

    /// Removes `key` if present; deleting a missing key is not an error.
    async fn delete(&self, key: &str) -> Result<()> {
        self.cache.invalidate(key).await;
        Ok(())
    }

    /// Reports whether the cache currently holds a value for `key`.
    ///
    /// Per moka's docs, `contains_key` is not counted as a read, so unlike
    /// `get` it does not reset the entry's time-to-idle timer.
    async fn exists(&self, key: &str) -> Result<bool> {
        Ok(self.cache.contains_key(key))
    }

    /// Invalidates every entry.
    ///
    /// Per moka's docs, lookups miss immediately afterwards while the memory is
    /// reclaimed lazily by the cache's maintenance tasks.
    async fn clear(&self) -> Result<()> {
        self.cache.invalidate_all();
        Ok(())
    }

    /// Invalidates all cached entries; the backend owns no other resource.
    async fn close(&self) -> Result<()> {
        self.cache.invalidate_all();
        Ok(())
    }

    /// Returns the remaining time-to-live of `key`.
    ///
    /// Remaining lifetimes are not tracked per entry, so a present key always
    /// yields `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::NotFound`] when `key` is not in the cache.
    async fn ttl(&self, key: &str) -> Result<Option<Duration>> {
        if self.cache.contains_key(key) {
            Ok(None)
        } else {
            Err(CacheError::NotFound(key.to_string()))
        }
    }

    /// Changing a single entry's expiry is unsupported here; always `Ok(false)`.
    async fn expire(&self, _key: &str, _ttl: Duration) -> Result<bool> {
        Ok(false)
    }

    /// The cache lives in-process with no external dependency: always healthy.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Returns backend statistics under the keys `type`, `entry_count` and
    /// `weighted_size`.
    async fn stats(&self) -> Result<HashMap<String, String>> {
        // moka applies inserts, evictions and invalidations to its counters
        // lazily; without flushing the pending maintenance work, `entry_count`
        // and `weighted_size` can lag behind (e.g. report 0 right after inserts,
        // or still count entries removed by `clear`).
        self.cache.run_pending_tasks().await;

        let mut stats = HashMap::with_capacity(3);
        stats.insert("type".to_string(), "memory".to_string());
        stats.insert(
            "entry_count".to_string(),
            self.cache.entry_count().to_string(),
        );
        stats.insert(
            "weighted_size".to_string(),
            self.cache.weighted_size().to_string(),
        );
        Ok(stats)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::CacheError;
    use std::time::Duration;

    /// Flushes moka's pending maintenance, then returns the `entry_count`
    /// reported by `stats()`, so the count reflects every preceding write.
    async fn synced_entry_count(backend: &MemoryBackend) -> u64 {
        backend.cache.run_pending_tasks().await;
        let stats = backend.stats().await.unwrap();
        stats
            .get("entry_count")
            .expect("stats must report entry_count")
            .parse()
            .expect("entry_count must be an unsigned integer")
    }

    #[tokio::test]
    async fn test_memory_backend_basic() {
        let backend = MemoryBackend::new();
        backend.set("key1", b"value1".to_vec(), None).await.unwrap();
        let value = backend.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));
        assert!(backend.exists("key1").await.unwrap());
        assert!(!backend.exists("key2").await.unwrap());
        backend.delete("key1").await.unwrap();
        assert!(!backend.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_backend_ttl() {
        let backend = MemoryBackend::builder()
            .ttl(Duration::from_millis(100))
            .build();
        backend.set("key1", b"value1".to_vec(), None).await.unwrap();
        assert!(backend.exists("key1").await.unwrap());
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(!backend.exists("key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_backend_capacity() {
        let backend = MemoryBackend::builder().capacity(2).build();
        backend.set("key1", b"value1".to_vec(), None).await.unwrap();
        backend.set("key2", b"value2".to_vec(), None).await.unwrap();
        backend.set("key3", b"value3".to_vec(), None).await.unwrap();
        // Without syncing first, the lazily-updated count could be 0 and the
        // bound below would pass vacuously.
        assert!(synced_entry_count(&backend).await <= 2);
    }

    #[tokio::test]
    async fn test_memory_backend_clear() {
        let backend = MemoryBackend::new();
        backend.set("key1", b"value1".to_vec(), None).await.unwrap();
        backend.set("key2", b"value2".to_vec(), None).await.unwrap();
        backend.clear().await.unwrap();
        assert!(!backend.exists("key1").await.unwrap());
        assert!(!backend.exists("key2").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_backend_stats() {
        let backend = MemoryBackend::new();
        backend.set("key1", b"value1".to_vec(), None).await.unwrap();
        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.get("type"), Some(&"memory".to_string()));
        assert!(stats.contains_key("entry_count"));
        assert!(stats.contains_key("weighted_size"));
        assert_eq!(synced_entry_count(&backend).await, 1);
    }

    #[tokio::test]
    async fn test_memory_backend_ttl_expire_and_health() {
        let backend = MemoryBackend::new();
        backend
            .set("key1", b"value1".to_vec(), Some(Duration::from_secs(1)))
            .await
            .unwrap();
        // Remaining TTL is not tracked per entry.
        assert_eq!(backend.ttl("key1").await.unwrap(), None);
        assert!(matches!(
            backend.ttl("missing").await,
            Err(CacheError::NotFound(k)) if k == "missing"
        ));
        assert!(!backend.expire("key1", Duration::from_secs(5)).await.unwrap());
        assert!(backend.health_check().await.unwrap());
    }
}