Skip to main content

ferro_rs/cache/
memory.rs

//! In-memory cache implementation for tests and as a fallback store.
//!
//! Provides a thread-safe in-memory cache that mimics Redis behavior,
//! including per-key TTL expiration.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::RwLock;
9use std::time::{Duration, Instant};
10
11use super::store::CacheStore;
12use crate::error::FrameworkError;
13
/// A single cached value together with its optional expiry deadline.
///
/// `expires_at == None` means the entry never expires.
#[derive(Clone)]
struct CacheEntry {
    /// The stored raw string value.
    value: String,
    /// Absolute instant after which the entry counts as expired.
    expires_at: Option<Instant>,
}

impl CacheEntry {
    /// Whether this entry's deadline has passed.
    ///
    /// An entry is only expired once the current instant is strictly later than
    /// its deadline; entries without a deadline never expire.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline < Instant::now(),
            None => false,
        }
    }
}
26
27/// In-memory cache implementation
28///
29/// Thread-safe cache that stores values in memory with optional TTL.
30/// Use this as a fallback when Redis is unavailable, or in tests.
31///
32/// # Example
33///
34/// ```rust,ignore
35/// use ferro_rs::cache::InMemoryCache;
36///
37/// let cache = InMemoryCache::new();
38/// ```
39pub struct InMemoryCache {
40    store: RwLock<HashMap<String, CacheEntry>>,
41    prefix: String,
42}
43
44impl InMemoryCache {
45    /// Create a new empty in-memory cache
46    pub fn new() -> Self {
47        Self {
48            store: RwLock::new(HashMap::new()),
49            prefix: "ferro_cache:".to_string(),
50        }
51    }
52
53    /// Create with a custom prefix
54    pub fn with_prefix(prefix: impl Into<String>) -> Self {
55        Self {
56            store: RwLock::new(HashMap::new()),
57            prefix: prefix.into(),
58        }
59    }
60
61    fn prefixed_key(&self, key: &str) -> String {
62        format!("{}{}", self.prefix, key)
63    }
64}
65
66impl Default for InMemoryCache {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72#[async_trait]
73impl CacheStore for InMemoryCache {
74    async fn get_raw(&self, key: &str) -> Result<Option<String>, FrameworkError> {
75        let key = self.prefixed_key(key);
76
77        let store = self
78            .store
79            .read()
80            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
81
82        match store.get(&key) {
83            Some(entry) if !entry.is_expired() => Ok(Some(entry.value.clone())),
84            _ => Ok(None),
85        }
86    }
87
88    async fn put_raw(
89        &self,
90        key: &str,
91        value: &str,
92        ttl: Option<Duration>,
93    ) -> Result<(), FrameworkError> {
94        let key = self.prefixed_key(key);
95
96        let entry = CacheEntry {
97            value: value.to_string(),
98            expires_at: ttl.map(|d| Instant::now() + d),
99        };
100
101        let mut store = self
102            .store
103            .write()
104            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
105
106        store.insert(key, entry);
107        Ok(())
108    }
109
110    async fn has(&self, key: &str) -> Result<bool, FrameworkError> {
111        let key = self.prefixed_key(key);
112
113        let store = self
114            .store
115            .read()
116            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
117
118        Ok(store.get(&key).map(|e| !e.is_expired()).unwrap_or(false))
119    }
120
121    async fn forget(&self, key: &str) -> Result<bool, FrameworkError> {
122        let key = self.prefixed_key(key);
123
124        let mut store = self
125            .store
126            .write()
127            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
128
129        Ok(store.remove(&key).is_some())
130    }
131
132    async fn flush(&self) -> Result<(), FrameworkError> {
133        let mut store = self
134            .store
135            .write()
136            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
137
138        store.clear();
139        Ok(())
140    }
141
142    async fn increment(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
143        let key = self.prefixed_key(key);
144
145        let mut store = self
146            .store
147            .write()
148            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
149
150        let current: i64 = store
151            .get(&key)
152            .filter(|e| !e.is_expired())
153            .and_then(|e| e.value.parse().ok())
154            .unwrap_or(0);
155
156        let new_value = current + amount;
157
158        store.insert(
159            key,
160            CacheEntry {
161                value: new_value.to_string(),
162                expires_at: None,
163            },
164        );
165
166        Ok(new_value)
167    }
168
169    async fn decrement(&self, key: &str, amount: i64) -> Result<i64, FrameworkError> {
170        self.increment(key, -amount).await
171    }
172
173    async fn expire(&self, key: &str, ttl: Duration) -> Result<bool, FrameworkError> {
174        let key = self.prefixed_key(key);
175
176        let mut store = self
177            .store
178            .write()
179            .map_err(|_| FrameworkError::internal("Cache lock poisoned"))?;
180
181        if let Some(entry) = store.get_mut(&key) {
182            entry.expires_at = Some(Instant::now() + ttl);
183            Ok(true)
184        } else {
185            Ok(false)
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_expire_sets_ttl() {
        let cache = InMemoryCache::new();

        // Bump the counter once so the key exists.
        cache.increment("counter", 1).await.unwrap();

        // Attach a one-second lifetime to it.
        let applied = cache
            .expire("counter", Duration::from_secs(1))
            .await
            .unwrap();
        assert!(applied, "expire should return true for existing key");

        // Still readable right away.
        let before = cache.get_raw("counter").await.unwrap();
        assert_eq!(before, Some("1".to_string()));

        // Sleep past the deadline.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // Once the deadline has passed the key reads as absent.
        let after = cache.get_raw("counter").await.unwrap();
        assert!(after.is_none(), "key should be expired after TTL");

        // A fresh increment starts from zero again.
        let restarted = cache.increment("counter", 1).await.unwrap();
        assert_eq!(restarted, 1, "increment on expired key should return 1");
    }

    #[tokio::test]
    async fn test_expire_missing_key() {
        let cache = InMemoryCache::new();

        let applied = cache
            .expire("nonexistent", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(!applied, "expire on missing key should return false");
    }

    #[tokio::test]
    async fn test_increment_then_expire_preserves_value() {
        let cache = InMemoryCache::new();

        // Bring the counter up to 5, one step at a time.
        for _ in 0..5 {
            cache.increment("counter", 1).await.unwrap();
        }

        // A TTL long enough that it cannot elapse while the test runs.
        let applied = cache
            .expire("counter", Duration::from_secs(10))
            .await
            .unwrap();
        assert!(applied);

        // Setting a TTL must not reset the stored count.
        let next = cache.increment("counter", 1).await.unwrap();
        assert_eq!(next, 6, "expire should not reset the value");
    }
}