phantom_frame/
cache.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A cache-refresh request broadcast to every subscriber of a [`RefreshTrigger`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshMessage {
    /// Invalidate every cache entry.
    All,
    /// Invalidate only the entries whose key matches this pattern; each `*` in the
    /// pattern matches any (possibly empty) run of characters.
    Pattern(String),
}
13
14/// A trigger that can be cloned and triggered multiple times
15/// Similar to oneshot but reusable
16#[derive(Clone)]
17pub struct RefreshTrigger {
18    sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22    pub fn new() -> Self {
23        let (sender, _) = broadcast::channel(16);
24        Self { sender }
25    }
26
27    /// Trigger a full cache refresh (clears all entries)
28    pub fn trigger(&self) {
29        // Ignore errors if there are no receivers
30        let _ = self.sender.send(RefreshMessage::All);
31    }
32
33    /// Trigger a cache refresh for entries matching a pattern
34    /// Supports wildcards: "/api/*", "GET:/api/*", etc.
35    pub fn trigger_by_key_match(&self, pattern: &str) {
36        // Ignore errors if there are no receivers
37        let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38    }
39
40    /// Subscribe to refresh events
41    pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42        self.sender.subscribe()
43    }
44}
45
/// Returns `true` if `key` matches `pattern`, where each `*` in `pattern` matches any
/// (possibly empty) run of characters.
///
/// Rules:
/// - A pattern without `*` matches only an identical key.
/// - The text before the first `*` must be a prefix of `key`.
/// - The text after the last `*` must be a suffix of what remains after the prefix and
///   any middle segments, so prefix and suffix never share characters of `key`.
/// - Every literal segment between wildcards must occur in order, without overlapping.
///
/// Matching is case-sensitive, and a literal `*` cannot be escaped. The function does
/// not allocate, because [`CacheStore::clear_by_pattern`] calls it once per cached entry.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    if key == pattern {
        return true;
    }

    // `head` is the literal before the first `*`; no `*` at all means only an exact
    // match (already ruled out above) could succeed.
    let (head, after_head) = match pattern.split_once('*') {
        Some(split) => split,
        None => return false,
    };
    // `tail` is the literal after the last `*`; `middle` holds any segments in between
    // (empty when the pattern contains a single `*`).
    let (middle, tail) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    let mut rest = match key.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };

    // Taking the leftmost occurrence of each middle segment is always safe: it leaves
    // the largest possible remainder for the segments (and suffix) that follow.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }

    rest.ends_with(tail)
}
93
94/// Cache storage for prerendered content
95#[derive(Clone)]
96pub struct CacheStore {
97    store: Arc<RwLock<HashMap<String, CachedResponse>>>,
98    refresh_trigger: RefreshTrigger,
99}
100
/// A stored response: the raw body bytes plus the headers and status recorded with it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Header name → value pairs (presumably HTTP headers; one value per name).
    pub headers: HashMap<String, String>,
    /// Response status code (presumably an HTTP status such as 200).
    pub status: u16,
}
107
108impl CacheStore {
109    pub fn new(refresh_trigger: RefreshTrigger) -> Self {
110        Self {
111            store: Arc::new(RwLock::new(HashMap::new())),
112            refresh_trigger,
113        }
114    }
115
116    pub async fn get(&self, key: &str) -> Option<CachedResponse> {
117        let store = self.store.read().await;
118        store.get(key).cloned()
119    }
120
121    pub async fn set(&self, key: String, response: CachedResponse) {
122        let mut store = self.store.write().await;
123        store.insert(key, response);
124    }
125
126    pub async fn clear(&self) {
127        let mut store = self.store.write().await;
128        store.clear();
129    }
130
131    /// Clear cache entries matching a pattern (supports wildcards)
132    pub async fn clear_by_pattern(&self, pattern: &str) {
133        let mut store = self.store.write().await;
134        store.retain(|key, _| !matches_pattern(key, pattern));
135    }
136
137    pub fn refresh_trigger(&self) -> &RefreshTrigger {
138        &self.refresh_trigger
139    }
140
141    /// Get the number of cached items
142    pub async fn size(&self) -> usize {
143        let store = self.store.read().await;
144        store.len()
145    }
146}
147
148impl Default for RefreshTrigger {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        // Wildcard at end
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        // Wildcard at start
        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        // Wildcard in middle
        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        // Multiple wildcards
        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
        assert!(matches_pattern("", "*"));
    }

    /// The prefix and suffix literals must not share characters of the key, and
    /// middle segments must appear in pattern order.
    #[test]
    fn test_matches_pattern_anchors_do_not_overlap() {
        assert!(!matches_pattern("a", "a*a"));
        assert!(matches_pattern("aa", "a*a"));
        assert!(!matches_pattern("abc", "a*c*b"));
    }

    #[test]
    fn test_matches_pattern_consecutive_wildcards() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/**"));
        assert!(matches_pattern("GET:/api/users", "**"));
    }

    #[test]
    fn test_matches_pattern_is_case_sensitive() {
        assert!(!matches_pattern("GET:/api/users", "get:/api/*"));
    }

    #[test]
    fn test_matches_pattern_empty_inputs() {
        assert!(matches_pattern("", ""));
        assert!(!matches_pattern("", "GET:*"));
        assert!(!matches_pattern("GET:/api/users", ""));
    }

    /// Keys are sliced by byte offset internally; multi-byte characters must not
    /// break matching.
    #[test]
    fn test_matches_pattern_non_ascii() {
        assert!(matches_pattern("GET:/café/menü", "GET:/caf*/men*"));
        assert!(matches_pattern("GET:/café/menü", "*ü"));
        assert!(!matches_pattern("GET:/café/menü", "*u"));
    }

    #[test]
    fn test_refresh_trigger_delivers_to_subscribers() {
        let trigger = RefreshTrigger::new();
        let mut rx = trigger.subscribe();

        trigger.trigger();
        trigger.trigger_by_key_match("GET:/api/*");

        assert!(matches!(rx.try_recv(), Ok(RefreshMessage::All)));
        assert!(matches!(
            rx.try_recv(),
            Ok(RefreshMessage::Pattern(p)) if p == "GET:/api/*"
        ));
    }

    #[test]
    fn test_refresh_trigger_clones_share_channel() {
        let trigger = RefreshTrigger::new();
        let mut rx = trigger.subscribe();

        trigger.clone().trigger();

        assert!(matches!(rx.try_recv(), Ok(RefreshMessage::All)));
    }

    #[test]
    fn test_refresh_trigger_without_subscribers_does_not_panic() {
        let trigger = RefreshTrigger::default();
        trigger.trigger();
        trigger.trigger_by_key_match("*");
    }
}