// phantom_frame/cache.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A cache-refresh request broadcast by a [`RefreshTrigger`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshMessage {
    /// Refresh (clear) every cache entry.
    All,
    /// Refresh only the entries whose key matches this pattern.
    ///
    /// The pattern may contain `*` wildcards, e.g. `"GET:/api/*"`.
    Pattern(String),
}
13
14/// A trigger that can be cloned and triggered multiple times
15/// Similar to oneshot but reusable
16#[derive(Clone)]
17pub struct RefreshTrigger {
18    sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22    pub fn new() -> Self {
23        let (sender, _) = broadcast::channel(16);
24        Self { sender }
25    }
26
27    /// Trigger a full cache refresh (clears all entries)
28    pub fn trigger(&self) {
29        // Ignore errors if there are no receivers
30        let _ = self.sender.send(RefreshMessage::All);
31    }
32
33    /// Trigger a cache refresh for entries matching a pattern
34    /// Supports wildcards: "/api/*", "GET:/api/*", etc.
35    pub fn trigger_by_key_match(&self, pattern: &str) {
36        // Ignore errors if there are no receivers
37        let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38    }
39
40    /// Subscribe to refresh events
41    pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42        self.sender.subscribe()
43    }
44}
45
/// Returns `true` when `key` matches `pattern`, where `*` stands for any
/// (possibly empty) run of characters and everything else is literal.
///
/// The text before the first `*` must be a prefix of `key`. The text after the
/// last `*` must be a suffix of what remains (prefix and suffix never overlap).
/// Each segment between them must appear in order in between. A pattern
/// without `*` only matches an identical key.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    // Literal equality always counts as a match.
    if key == pattern {
        return true;
    }

    // No wildcard and not equal: nothing else can match.
    let Some((head, after_head)) = pattern.split_once('*') else {
        return false;
    };

    // Anchor the start of the key.
    let Some(mut remaining) = key.strip_prefix(head) else {
        return false;
    };

    // Separate the anchored tail from the floating middle segments.
    let (middle, tail) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    // Consume each middle segment at its leftmost occurrence, keeping order.
    for segment in middle.split('*') {
        match remaining.find(segment) {
            Some(idx) => remaining = &remaining[idx + segment.len()..],
            None => return false,
        }
    }

    // Whatever is left must end with the tail.
    remaining.ends_with(tail)
}
91
92/// Cache storage for prerendered content
93#[derive(Clone)]
94pub struct CacheStore {
95    store: Arc<RwLock<HashMap<String, CachedResponse>>>,
96    // 404-specific store with bounded capacity and FIFO eviction
97    store_404: Arc<RwLock<HashMap<String, CachedResponse>>>,
98    keys_404: Arc<RwLock<VecDeque<String>>>,
99    cache_404_capacity: usize,
100    refresh_trigger: RefreshTrigger,
101}
102
/// A cached response: body bytes, headers and status code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw body bytes.
    pub body: Vec<u8>,
    /// Headers as name -> value (one value per header name).
    pub headers: HashMap<String, String>,
    /// Status code (e.g. `404` for entries stored with `CacheStore::set_404`).
    pub status: u16,
}
109
110impl CacheStore {
111    pub fn new(refresh_trigger: RefreshTrigger, cache_404_capacity: usize) -> Self {
112        Self {
113            store: Arc::new(RwLock::new(HashMap::new())),
114            store_404: Arc::new(RwLock::new(HashMap::new())),
115            keys_404: Arc::new(RwLock::new(VecDeque::new())),
116            cache_404_capacity,
117            refresh_trigger,
118        }
119    }
120
121    pub async fn get(&self, key: &str) -> Option<CachedResponse> {
122        let store = self.store.read().await;
123        store.get(key).cloned()
124    }
125
126    /// Get a 404 cached response (if present)
127    pub async fn get_404(&self, key: &str) -> Option<CachedResponse> {
128        let store = self.store_404.read().await;
129        store.get(key).cloned()
130    }
131
132    pub async fn set(&self, key: String, response: CachedResponse) {
133        let mut store = self.store.write().await;
134        store.insert(key, response);
135    }
136
137    /// Set a 404 cached response. Bounded by `cache_404_capacity` and evict the oldest entries when limit reached.
138    pub async fn set_404(&self, key: String, response: CachedResponse) {
139        if self.cache_404_capacity == 0 {
140            // 404 caching disabled
141            return;
142        }
143
144        let mut store = self.store_404.write().await;
145        let mut keys = self.keys_404.write().await;
146
147        // If key already exists, remove it from its position in keys and re-add to the back
148        if store.contains_key(&key) {
149            // remove the key from keys deque (linear scan; acceptable for bounded deque)
150            if let Some(pos) = keys.iter().position(|k| k == &key) {
151                keys.remove(pos);
152            }
153        }
154
155        // Insert into store and push back in keys
156        store.insert(key.clone(), response);
157        keys.push_back(key.clone());
158
159        // Evict oldest items if capacity exceeded
160        while keys.len() > self.cache_404_capacity {
161            if let Some(old_key) = keys.pop_front() {
162                store.remove(&old_key);
163            }
164        }
165    }
166
167    pub async fn clear(&self) {
168        let mut store = self.store.write().await;
169        store.clear();
170        let mut store404 = self.store_404.write().await;
171        store404.clear();
172        let mut keys = self.keys_404.write().await;
173        keys.clear();
174    }
175
176    /// Clear cache entries matching a pattern (supports wildcards)
177    pub async fn clear_by_pattern(&self, pattern: &str) {
178        let mut store = self.store.write().await;
179        store.retain(|key, _| !matches_pattern(key, pattern));
180
181        let mut store404 = self.store_404.write().await;
182        let mut keys = self.keys_404.write().await;
183        // Remove matching from store_404 and keys
184        store404.retain(|key, _| !matches_pattern(key, pattern));
185        keys.retain(|k| !matches_pattern(k, pattern));
186    }
187
188    pub fn refresh_trigger(&self) -> &RefreshTrigger {
189        &self.refresh_trigger
190    }
191
192    /// Get the number of cached items
193    pub async fn size(&self) -> usize {
194        let store = self.store.read().await;
195        store.len()
196    }
197
198    /// Size of 404 cache
199    pub async fn size_404(&self) -> usize {
200        let store = self.store_404.read().await;
201        store.len()
202    }
203}
204
205impl Default for RefreshTrigger {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header-less response with a one-byte body and the given status.
    fn response(byte: u8, status: u16) -> CachedResponse {
        CachedResponse { body: vec![byte], headers: HashMap::new(), status }
    }

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        // Wildcard at end
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        // Wildcard at start
        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        // Wildcard in middle
        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        // Multiple wildcards
        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
    }

    #[test]
    fn test_matches_pattern_edge_cases() {
        // An empty pattern only matches the empty key.
        assert!(matches_pattern("", ""));
        assert!(!matches_pattern("GET:/", ""));

        // Consecutive wildcards behave like a single one.
        assert!(matches_pattern("GET:/api/users", "**"));
        assert!(matches_pattern("GET:/api/users", "GET:**users"));

        // Prefix and suffix may not overlap.
        assert!(!matches_pattern("aba", "ab*ba"));
        assert!(matches_pattern("abba", "ab*ba"));

        // Middle segments must appear in order.
        assert!(!matches_pattern("GET:/users/api", "GET:*/api/*/users"));

        // `*` may match nothing, but literal separators around it still count.
        assert!(!matches_pattern("GET:/api/users", "GET:/api/*/users"));
    }

    #[tokio::test]
    async fn test_404_cache_set_get_and_eviction() {
        let trigger = RefreshTrigger::new();
        // capacity 2 for quicker eviction
        let store = CacheStore::new(trigger, 2);

        let resp1 = CachedResponse { body: vec![1], headers: HashMap::new(), status: 404 };
        let resp2 = CachedResponse { body: vec![2], headers: HashMap::new(), status: 404 };
        let resp3 = CachedResponse { body: vec![3], headers: HashMap::new(), status: 404 };

        // Set two 404 entries
        store.set_404("GET:/notfound1".to_string(), resp1).await;
        store.set_404("GET:/notfound2".to_string(), resp2).await;

        assert_eq!(store.size_404().await, 2);
        assert_eq!(store.get_404("GET:/notfound1").await.unwrap().body, vec![1]);

        // Add third entry - should evict oldest (notfound1)
        store.set_404("GET:/notfound3".to_string(), resp3).await;
        assert_eq!(store.size_404().await, 2);
        assert!(store.get_404("GET:/notfound1").await.is_none());
        assert_eq!(store.get_404("GET:/notfound2").await.unwrap().body, vec![2]);
        assert_eq!(store.get_404("GET:/notfound3").await.unwrap().body, vec![3]);
    }

    #[tokio::test]
    async fn test_404_reinsert_moves_key_to_back() {
        let store = CacheStore::new(RefreshTrigger::new(), 2);
        store.set_404("GET:/a".to_string(), response(1, 404)).await;
        store.set_404("GET:/b".to_string(), response(2, 404)).await;

        // Re-inserting "a" replaces its body and makes "b" the oldest entry.
        store.set_404("GET:/a".to_string(), response(10, 404)).await;
        assert_eq!(store.size_404().await, 2);

        store.set_404("GET:/c".to_string(), response(3, 404)).await;
        assert!(store.get_404("GET:/b").await.is_none());
        assert_eq!(store.get_404("GET:/a").await.unwrap().body, vec![10]);
        assert_eq!(store.get_404("GET:/c").await.unwrap().body, vec![3]);
    }

    #[tokio::test]
    async fn test_404_cache_disabled_with_zero_capacity() {
        let store = CacheStore::new(RefreshTrigger::new(), 0);
        store.set_404("GET:/missing".to_string(), response(1, 404)).await;
        assert_eq!(store.size_404().await, 0);
        assert!(store.get_404("GET:/missing").await.is_none());
    }

    #[tokio::test]
    async fn test_clear_by_pattern_removes_404_entries() {
        let trigger = RefreshTrigger::new();
        let store = CacheStore::new(trigger, 10);

        let resp = CachedResponse { body: vec![1], headers: HashMap::new(), status: 404 };
        store.set_404("GET:/api/notfound".to_string(), resp.clone()).await;
        store.set_404("GET:/api/another".to_string(), resp).await;
        assert_eq!(store.size_404().await, 2);

        store.clear_by_pattern("GET:/api/*").await;
        assert_eq!(store.size_404().await, 0);
    }

    #[tokio::test]
    async fn test_clear_by_pattern_keeps_non_matching_entries() {
        let store = CacheStore::new(RefreshTrigger::new(), 10);
        store.set("GET:/api/users".to_string(), response(1, 200)).await;
        store.set("GET:/home".to_string(), response(2, 200)).await;
        store.set_404("GET:/api/missing".to_string(), response(3, 404)).await;
        store.set_404("GET:/missing".to_string(), response(4, 404)).await;

        store.clear_by_pattern("GET:/api/*").await;

        assert_eq!(store.size().await, 1);
        assert!(store.get("GET:/api/users").await.is_none());
        assert_eq!(store.get("GET:/home").await.unwrap().body, vec![2]);
        assert_eq!(store.size_404().await, 1);
        assert!(store.get_404("GET:/api/missing").await.is_none());
        assert_eq!(store.get_404("GET:/missing").await.unwrap().body, vec![4]);
    }

    #[tokio::test]
    async fn test_clear_removes_all_entries() {
        let store = CacheStore::new(RefreshTrigger::new(), 10);
        store.set("GET:/home".to_string(), response(1, 200)).await;
        store.set_404("GET:/missing".to_string(), response(2, 404)).await;

        store.clear().await;

        assert_eq!(store.size().await, 0);
        assert_eq!(store.size_404().await, 0);
        assert!(store.get("GET:/home").await.is_none());
        assert!(store.get_404("GET:/missing").await.is_none());
    }

    #[tokio::test]
    async fn test_refresh_trigger_broadcasts_messages() {
        let trigger = RefreshTrigger::new();
        let mut rx = trigger.subscribe();

        trigger.trigger();
        assert!(matches!(rx.recv().await, Ok(RefreshMessage::All)));

        trigger.trigger_by_key_match("GET:/api/*");
        match rx.recv().await {
            Ok(RefreshMessage::Pattern(p)) => assert_eq!(p, "GET:/api/*"),
            other => panic!("unexpected message: {:?}", other),
        }

        // A store's trigger shares the channel with the handle it was built from.
        let store = CacheStore::new(trigger.clone(), 1);
        let mut store_rx = store.refresh_trigger().subscribe();
        trigger.trigger();
        assert!(matches!(store_rx.recv().await, Ok(RefreshMessage::All)));
    }
}