1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A refresh request broadcast through a [`RefreshTrigger`] to its subscribers.
///
/// Derives `PartialEq`/`Eq` so receivers (and tests) can compare messages directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshMessage {
    /// Refresh everything (sent by `RefreshTrigger::trigger`).
    All,
    /// Refresh only keys matching this pattern (sent by `RefreshTrigger::trigger_by_key_match`).
    Pattern(String),
}
13
14#[derive(Clone)]
17pub struct RefreshTrigger {
18 sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22 pub fn new() -> Self {
23 let (sender, _) = broadcast::channel(16);
24 Self { sender }
25 }
26
27 pub fn trigger(&self) {
29 let _ = self.sender.send(RefreshMessage::All);
31 }
32
33 pub fn trigger_by_key_match(&self, pattern: &str) {
36 let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38 }
39
40 pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42 self.sender.subscribe()
43 }
44}
45
/// Reports whether `key` matches the glob-style `pattern`.
///
/// `*` stands for any (possibly empty) run of characters; every other character
/// is literal. A pattern with no `*` matches only an identical key, and a key that
/// equals the pattern verbatim always matches.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    if key == pattern {
        return true;
    }

    let segments: Vec<&str> = pattern.split('*').collect();

    // `split` always yields at least one segment.
    let (head, rest) = match segments.split_first() {
        Some((&head, rest)) => (head, rest),
        None => return false,
    };
    let (tail, middle) = match rest.split_last() {
        Some((&tail, middle)) => (tail, middle),
        // A single segment means the pattern has no `*` and already differs from `key`.
        None => return false,
    };

    // Text before the first `*` must be a literal prefix of the key.
    let after_head = match key.strip_prefix(head) {
        Some(remaining) => remaining,
        None => return false,
    };

    // Consume each inner segment at its leftmost occurrence. Taking the earliest
    // match leaves the longest remainder for the segments that follow.
    let remainder = middle.iter().try_fold(after_head, |remaining, &segment| {
        remaining
            .find(segment)
            .map(|at| &remaining[at + segment.len()..])
    });

    // Text after the last `*` must end whatever has not been consumed yet.
    match remainder {
        Some(remaining) => remaining.ends_with(tail),
        None => false,
    }
}
91
92#[derive(Clone)]
94pub struct CacheStore {
95 store: Arc<RwLock<HashMap<String, CachedResponse>>>,
96 store_404: Arc<RwLock<HashMap<String, CachedResponse>>>,
98 keys_404: Arc<RwLock<VecDeque<String>>>,
99 cache_404_capacity: usize,
100 refresh_trigger: RefreshTrigger,
101}
102
/// A cached response: raw body bytes, headers, and a numeric status code.
///
/// Derives `PartialEq`/`Eq` so cached entries can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw response body.
    pub body: Vec<u8>,
    /// Header name → value pairs.
    pub headers: HashMap<String, String>,
    /// Status code (the 404 cache stores entries with `status: 404` in the tests).
    pub status: u16,
}
109
110impl CacheStore {
111 pub fn new(refresh_trigger: RefreshTrigger, cache_404_capacity: usize) -> Self {
112 Self {
113 store: Arc::new(RwLock::new(HashMap::new())),
114 store_404: Arc::new(RwLock::new(HashMap::new())),
115 keys_404: Arc::new(RwLock::new(VecDeque::new())),
116 cache_404_capacity,
117 refresh_trigger,
118 }
119 }
120
121 pub async fn get(&self, key: &str) -> Option<CachedResponse> {
122 let store = self.store.read().await;
123 store.get(key).cloned()
124 }
125
126 pub async fn get_404(&self, key: &str) -> Option<CachedResponse> {
128 let store = self.store_404.read().await;
129 store.get(key).cloned()
130 }
131
132 pub async fn set(&self, key: String, response: CachedResponse) {
133 let mut store = self.store.write().await;
134 store.insert(key, response);
135 }
136
137 pub async fn set_404(&self, key: String, response: CachedResponse) {
139 if self.cache_404_capacity == 0 {
140 return;
142 }
143
144 let mut store = self.store_404.write().await;
145 let mut keys = self.keys_404.write().await;
146
147 if store.contains_key(&key) {
149 if let Some(pos) = keys.iter().position(|k| k == &key) {
151 keys.remove(pos);
152 }
153 }
154
155 store.insert(key.clone(), response);
157 keys.push_back(key.clone());
158
159 while keys.len() > self.cache_404_capacity {
161 if let Some(old_key) = keys.pop_front() {
162 store.remove(&old_key);
163 }
164 }
165 }
166
167 pub async fn clear(&self) {
168 let mut store = self.store.write().await;
169 store.clear();
170 let mut store404 = self.store_404.write().await;
171 store404.clear();
172 let mut keys = self.keys_404.write().await;
173 keys.clear();
174 }
175
176 pub async fn clear_by_pattern(&self, pattern: &str) {
178 let mut store = self.store.write().await;
179 store.retain(|key, _| !matches_pattern(key, pattern));
180
181 let mut store404 = self.store_404.write().await;
182 let mut keys = self.keys_404.write().await;
183 store404.retain(|key, _| !matches_pattern(key, pattern));
185 keys.retain(|k| !matches_pattern(k, pattern));
186 }
187
188 pub fn refresh_trigger(&self) -> &RefreshTrigger {
189 &self.refresh_trigger
190 }
191
192 pub async fn size(&self) -> usize {
194 let store = self.store.read().await;
195 store.len()
196 }
197
198 pub async fn size_404(&self) -> usize {
200 let store = self.store_404.read().await;
201 store.len()
202 }
203}
204
205impl Default for RefreshTrigger {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 404 response whose body is a single marker byte.
    fn not_found(marker: u8) -> CachedResponse {
        CachedResponse { body: vec![marker], headers: HashMap::new(), status: 404 }
    }

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
    }

    #[test]
    fn test_matches_pattern_edge_cases() {
        // Without a wildcard, only an identical key matches.
        assert!(!matches_pattern("GET:/api/users", "GET:/api"));
        // The empty pattern only matches the empty key.
        assert!(matches_pattern("", ""));
        assert!(!matches_pattern("x", ""));
        // Consecutive wildcards behave like a single one.
        assert!(matches_pattern("GET:/api/users", "GET:**users"));
        // The literal prefix and suffix may not overlap within the key.
        assert!(!matches_pattern("aba", "ab*ba"));
        assert!(!matches_pattern("a", "a*a"));
        assert!(matches_pattern("aa", "a*a"));
    }

    #[tokio::test]
    async fn test_404_cache_set_get_and_eviction() {
        let trigger = RefreshTrigger::new();
        let store = CacheStore::new(trigger, 2);

        let resp1 = CachedResponse { body: vec![1], headers: HashMap::new(), status: 404 };
        let resp2 = CachedResponse { body: vec![2], headers: HashMap::new(), status: 404 };
        let resp3 = CachedResponse { body: vec![3], headers: HashMap::new(), status: 404 };

        store.set_404("GET:/notfound1".to_string(), resp1.clone()).await;
        store.set_404("GET:/notfound2".to_string(), resp2.clone()).await;

        assert_eq!(store.size_404().await, 2);
        assert_eq!(store.get_404("GET:/notfound1").await.unwrap().body, vec![1]);

        store.set_404("GET:/notfound3".to_string(), resp3.clone()).await;
        assert_eq!(store.size_404().await, 2);
        assert!(store.get_404("GET:/notfound1").await.is_none());
        assert_eq!(store.get_404("GET:/notfound2").await.unwrap().body, vec![2]);
        assert_eq!(store.get_404("GET:/notfound3").await.unwrap().body, vec![3]);
    }

    #[tokio::test]
    async fn test_404_cache_reinsert_refreshes_eviction_order() {
        let store = CacheStore::new(RefreshTrigger::new(), 2);
        store.set_404("GET:/a".to_string(), not_found(1)).await;
        store.set_404("GET:/b".to_string(), not_found(2)).await;

        // Re-setting /a moves it behind /b, so /b becomes the oldest entry.
        store.set_404("GET:/a".to_string(), not_found(10)).await;
        assert_eq!(store.size_404().await, 2);

        store.set_404("GET:/c".to_string(), not_found(3)).await;
        assert!(store.get_404("GET:/b").await.is_none());
        assert_eq!(store.get_404("GET:/a").await.unwrap().body, vec![10]);
        assert_eq!(store.get_404("GET:/c").await.unwrap().body, vec![3]);
    }

    #[tokio::test]
    async fn test_404_cache_disabled_with_zero_capacity() {
        let store = CacheStore::new(RefreshTrigger::new(), 0);
        store.set_404("GET:/missing".to_string(), not_found(1)).await;
        assert_eq!(store.size_404().await, 0);
        assert!(store.get_404("GET:/missing").await.is_none());
    }

    #[tokio::test]
    async fn test_clear_empties_both_stores() {
        let store = CacheStore::new(RefreshTrigger::new(), 4);
        let ok = CachedResponse { body: vec![9], headers: HashMap::new(), status: 200 };
        store.set("GET:/ok".to_string(), ok).await;
        store.set_404("GET:/missing".to_string(), not_found(1)).await;

        store.clear().await;
        assert_eq!(store.size().await, 0);
        assert_eq!(store.size_404().await, 0);
    }

    #[tokio::test]
    async fn test_clear_by_pattern_removes_404_entries() {
        let trigger = RefreshTrigger::new();
        let store = CacheStore::new(trigger, 10);

        let resp = CachedResponse { body: vec![1], headers: HashMap::new(), status: 404 };
        store.set_404("GET:/api/notfound".to_string(), resp.clone()).await;
        store.set_404("GET:/api/another".to_string(), resp.clone()).await;
        assert_eq!(store.size_404().await, 2);

        store.clear_by_pattern("GET:/api/*").await;
        assert_eq!(store.size_404().await, 0);
    }

    #[tokio::test]
    async fn test_clear_by_pattern_keeps_non_matching_entries() {
        let store = CacheStore::new(RefreshTrigger::new(), 4);
        let ok = CachedResponse { body: vec![9], headers: HashMap::new(), status: 200 };
        store.set("GET:/api/users".to_string(), ok.clone()).await;
        store.set("GET:/static/app.js".to_string(), ok).await;
        store.set_404("GET:/api/missing".to_string(), not_found(1)).await;
        store.set_404("GET:/static/missing".to_string(), not_found(2)).await;

        store.clear_by_pattern("GET:/api/*").await;
        assert!(store.get("GET:/api/users").await.is_none());
        assert!(store.get("GET:/static/app.js").await.is_some());
        assert!(store.get_404("GET:/api/missing").await.is_none());
        assert_eq!(store.get_404("GET:/static/missing").await.unwrap().body, vec![2]);
    }

    #[tokio::test]
    async fn test_clear_by_pattern_keeps_404_eviction_order_consistent() {
        let store = CacheStore::new(RefreshTrigger::new(), 2);
        store.set_404("GET:/api/a".to_string(), not_found(1)).await;
        store.set_404("GET:/b".to_string(), not_found(2)).await;

        // Removing /api/a must also drop it from the eviction queue.
        store.clear_by_pattern("GET:/api/*").await;
        store.set_404("GET:/c".to_string(), not_found(3)).await;
        assert_eq!(store.size_404().await, 2);

        store.set_404("GET:/d".to_string(), not_found(4)).await;
        assert_eq!(store.size_404().await, 2);
        assert!(store.get_404("GET:/b").await.is_none());
        assert!(store.get_404("GET:/c").await.is_some());
        assert!(store.get_404("GET:/d").await.is_some());
    }

    #[tokio::test]
    async fn test_refresh_trigger_delivers_messages() {
        let trigger = RefreshTrigger::new();
        // Sending with no subscribers is a silent no-op.
        trigger.trigger();

        let mut rx = trigger.subscribe();
        trigger.trigger();
        trigger.trigger_by_key_match("GET:/api/*");

        assert!(matches!(rx.recv().await, Ok(RefreshMessage::All)));
        match rx.recv().await {
            Ok(RefreshMessage::Pattern(pattern)) => assert_eq!(pattern, "GET:/api/*"),
            other => panic!("unexpected message: {other:?}"),
        }
    }
}
287}