phantom_frame/
cache.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A cache-refresh request broadcast through a [`RefreshTrigger`].
///
/// Derives `PartialEq`/`Eq` so messages can be compared directly (e.g. in tests or when
/// de-duplicating pending refreshes).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshMessage {
    /// Refresh every cache entry.
    All,
    /// Refresh only the entries whose key matches this pattern. A `*` in the pattern
    /// matches any (possibly empty) run of characters, e.g. `"GET:/api/*"`.
    Pattern(String),
}
13
14/// A trigger that can be cloned and triggered multiple times
15/// Similar to oneshot but reusable
16#[derive(Clone)]
17pub struct RefreshTrigger {
18    sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22    pub fn new() -> Self {
23        let (sender, _) = broadcast::channel(16);
24        Self { sender }
25    }
26
27    /// Trigger a full cache refresh (clears all entries)
28    pub fn trigger(&self) {
29        // Ignore errors if there are no receivers
30        let _ = self.sender.send(RefreshMessage::All);
31    }
32
33    /// Trigger a cache refresh for entries matching a pattern
34    /// Supports wildcards: "/api/*", "GET:/api/*", etc.
35    pub fn trigger_by_key_match(&self, pattern: &str) {
36        // Ignore errors if there are no receivers
37        let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38    }
39
40    /// Subscribe to refresh events
41    pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42        self.sender.subscribe()
43    }
44}
45
/// Reports whether `key` matches `pattern`.
///
/// Each `*` in `pattern` stands for any (possibly empty) run of characters, and every
/// other character must match literally (case-sensitive). A pattern without `*` matches
/// only a key identical to it.
///
/// Examples: `"GET:/api/*"` matches `"GET:/api/users/123"`, `"*/users"` matches
/// `"POST:/v2/users"`, and `"*"` matches every key.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    if key == pattern {
        return true;
    }

    // `head` is the literal text before the first '*'. With no '*' at all, only exact
    // equality could have matched, and that was already ruled out.
    let (head, after_first_star) = match pattern.split_once('*') {
        Some(halves) => halves,
        None => return false,
    };

    // `tail` is the literal text after the last '*'. `middle` holds the literal segments
    // between the first and last '*' (empty when there is only one '*').
    let (middle, tail) = after_first_star
        .rsplit_once('*')
        .unwrap_or(("", after_first_star));

    // The key must start with `head`; only the remainder is still to be matched.
    let mut rest = match key.strip_prefix(head) {
        Some(remainder) => remainder,
        None => return false,
    };

    // Middle segments must occur in order within the unconsumed text. Taking the
    // leftmost occurrence leaves the most room for the segments that follow.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }

    // The suffix is checked only against text not yet consumed, so `head` and `tail`
    // can never claim the same characters.
    rest.ends_with(tail)
}
91
92/// Cache storage for prerendered content
93#[derive(Clone)]
94pub struct CacheStore {
95    store: Arc<RwLock<HashMap<String, CachedResponse>>>,
96    refresh_trigger: RefreshTrigger,
97}
98
/// A cached response: raw body bytes, headers, and a numeric status code
/// (presumably captured from an upstream HTTP response).
///
/// `PartialEq`/`Eq` are derived so cached entries can be compared directly, e.g. to
/// check whether a freshly rendered response differs from the cached one.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw response body.
    pub body: Vec<u8>,
    /// Headers as name → value; each name holds exactly one value.
    pub headers: HashMap<String, String>,
    /// Numeric status code, e.g. `200`.
    pub status: u16,
}
105
106impl CacheStore {
107    pub fn new(refresh_trigger: RefreshTrigger) -> Self {
108        Self {
109            store: Arc::new(RwLock::new(HashMap::new())),
110            refresh_trigger,
111        }
112    }
113
114    pub async fn get(&self, key: &str) -> Option<CachedResponse> {
115        let store = self.store.read().await;
116        store.get(key).cloned()
117    }
118
119    pub async fn set(&self, key: String, response: CachedResponse) {
120        let mut store = self.store.write().await;
121        store.insert(key, response);
122    }
123
124    pub async fn clear(&self) {
125        let mut store = self.store.write().await;
126        store.clear();
127    }
128
129    /// Clear cache entries matching a pattern (supports wildcards)
130    pub async fn clear_by_pattern(&self, pattern: &str) {
131        let mut store = self.store.write().await;
132        store.retain(|key, _| !matches_pattern(key, pattern));
133    }
134
135    pub fn refresh_trigger(&self) -> &RefreshTrigger {
136        &self.refresh_trigger
137    }
138
139    /// Get the number of cached items
140    pub async fn size(&self) -> usize {
141        let store = self.store.read().await;
142        store.len()
143    }
144}
145
146impl Default for RefreshTrigger {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        // Wildcard at end
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        // Wildcard at start
        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        // Wildcard in middle
        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        // Multiple wildcards
        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
    }

    #[test]
    fn test_matches_pattern_prefix_and_suffix_do_not_overlap() {
        // "ab*ba" needs "ab" followed later by "ba"; the single 'b' in "aba" cannot
        // serve as both the end of the prefix and the start of the suffix.
        assert!(!matches_pattern("aba", "ab*ba"));
        assert!(matches_pattern("abba", "ab*ba"));
        assert!(matches_pattern("ab-ba", "ab*ba"));
    }

    #[test]
    fn test_matches_pattern_edge_cases() {
        // Empty key and/or empty pattern.
        assert!(matches_pattern("", ""));
        assert!(matches_pattern("", "*"));
        assert!(!matches_pattern("", "a*"));
        assert!(!matches_pattern("GET:/api/users", ""));

        // Consecutive wildcards behave like a single one.
        assert!(matches_pattern("GET:/api/users", "GET:/api/**"));
        assert!(matches_pattern("GET:/api/users", "GET:**/users"));

        // Middle segments must appear in the given order.
        assert!(matches_pattern("GET:/a/b/x/c", "GET:*/b/*/c"));
        assert!(!matches_pattern("GET:/a/c/x/b", "GET:*/b/*/c"));

        // Matching is case-sensitive.
        assert!(!matches_pattern("GET:/api/users", "get:/api/*"));
    }

    #[test]
    fn test_matches_pattern_non_ascii() {
        // Multi-byte characters around a wildcard must not break byte slicing.
        assert!(matches_pattern("GET:/café/menu", "GET:/caf*/menu"));
        assert!(matches_pattern("GET:/naïve", "*ïve"));
        assert!(!matches_pattern("GET:/naive", "*ïve"));
    }

    #[test]
    fn test_refresh_trigger_delivers_to_subscribers() {
        let trigger = RefreshTrigger::new();

        // With no subscribers, triggering must be a silent no-op.
        trigger.trigger();

        // A receiver only sees messages sent after it subscribed.
        let mut rx = trigger.subscribe();
        trigger.trigger();
        trigger.trigger_by_key_match("GET:/api/*");

        assert!(matches!(rx.try_recv(), Ok(RefreshMessage::All)));
        match rx.try_recv() {
            Ok(RefreshMessage::Pattern(p)) => assert_eq!(p, "GET:/api/*"),
            other => panic!("unexpected message: {:?}", other),
        }
        assert!(rx.try_recv().is_err());
    }
}