1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A refresh request broadcast through [`RefreshTrigger`].
///
/// The consumer of these messages is not visible in this file; presumably a
/// listener maps `All` to a full cache clear and `Pattern` to a pattern-based
/// clear — confirm against the subscriber.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum RefreshMessage {
    /// Request covering every cache entry.
    All,
    /// Request covering only keys matching this pattern (`*` is a wildcard
    /// for any run of characters, as interpreted by `matches_pattern`).
    Pattern(String),
}
13
14#[derive(Clone)]
17pub struct RefreshTrigger {
18 sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22 pub fn new() -> Self {
23 let (sender, _) = broadcast::channel(16);
24 Self { sender }
25 }
26
27 pub fn trigger(&self) {
29 let _ = self.sender.send(RefreshMessage::All);
31 }
32
33 pub fn trigger_by_key_match(&self, pattern: &str) {
36 let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38 }
39
40 pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42 self.sender.subscribe()
43 }
44}
45
/// Reports whether `key` matches `pattern`, where each `*` in `pattern`
/// stands for any (possibly empty) run of characters.
///
/// Matching rules:
/// - an exact string match always succeeds (so a key containing a literal `*`
///   matches an identical pattern);
/// - a pattern with no `*` matches only by exact equality;
/// - the text before the first `*` must be a prefix of `key`;
/// - the text after the last `*` must be a suffix of whatever remains after
///   earlier segments have been consumed, so the prefix and suffix never overlap;
/// - segments between wildcards are located left-to-right, each at its
///   earliest occurrence after the previous one.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    if key == pattern {
        return true;
    }

    // `split` always yields at least one segment, so this cannot underflow.
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    if last == 0 {
        // No wildcard present and the exact comparison already failed.
        return false;
    }

    // Shrinking view of the part of `key` not yet consumed by a segment.
    let mut rest = key;
    for (index, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            // Adjacent, leading, or trailing wildcards impose no constraint.
            continue;
        }
        match index {
            0 => match rest.strip_prefix(segment) {
                Some(after) => rest = after,
                None => return false,
            },
            i if i == last => {
                if !rest.ends_with(segment) {
                    return false;
                }
            }
            _ => match rest.find(segment) {
                Some(at) => rest = &rest[at + segment.len()..],
                None => return false,
            },
        }
    }

    true
}
93
94#[derive(Clone)]
96pub struct CacheStore {
97 store: Arc<RwLock<HashMap<String, CachedResponse>>>,
98 refresh_trigger: RefreshTrigger,
99}
100
/// A response snapshot stored in the cache.
///
/// Every field is `Eq`, so entries can be compared directly (for example, to
/// detect whether a freshly fetched response differs from the cached one).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw body bytes.
    pub body: Vec<u8>,
    /// Headers keyed by name, one value per name (presumably HTTP headers —
    /// confirm against the producer).
    pub headers: HashMap<String, String>,
    /// Numeric status code (presumably an HTTP status such as 200 or 404).
    pub status: u16,
}
107
108impl CacheStore {
109 pub fn new(refresh_trigger: RefreshTrigger) -> Self {
110 Self {
111 store: Arc::new(RwLock::new(HashMap::new())),
112 refresh_trigger,
113 }
114 }
115
116 pub async fn get(&self, key: &str) -> Option<CachedResponse> {
117 let store = self.store.read().await;
118 store.get(key).cloned()
119 }
120
121 pub async fn set(&self, key: String, response: CachedResponse) {
122 let mut store = self.store.write().await;
123 store.insert(key, response);
124 }
125
126 pub async fn clear(&self) {
127 let mut store = self.store.write().await;
128 store.clear();
129 }
130
131 pub async fn clear_by_pattern(&self, pattern: &str) {
133 let mut store = self.store.write().await;
134 store.retain(|key, _| !matches_pattern(key, pattern));
135 }
136
137 pub fn refresh_trigger(&self) -> &RefreshTrigger {
138 &self.refresh_trigger
139 }
140
141 pub async fn size(&self) -> usize {
143 let store = self.store.read().await;
144 store.len()
145 }
146}
147
148impl Default for RefreshTrigger {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
        // No wildcard and not identical: a trailing slash is significant.
        assert!(!matches_pattern("GET:/api/users", "GET:/api/users/"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
        assert!(matches_pattern("", "*"));
        assert!(matches_pattern("GET:/api/users", "**"));
    }

    #[test]
    fn test_matches_pattern_edge_cases() {
        // The anchored prefix and suffix must not overlap.
        assert!(!matches_pattern("a", "a*a"));
        assert!(matches_pattern("aa", "a*a"));

        // A trailing wildcard still requires the full prefix.
        assert!(!matches_pattern("GET:/api", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/", "GET:/api/*"));

        // The empty pattern only matches the empty key.
        assert!(matches_pattern("", ""));
        assert!(!matches_pattern("GET:/", ""));

        // A key containing a literal '*' matches an identical pattern.
        assert!(matches_pattern("GET:/a*b", "GET:/a*b"));

        // Multiple middle segments are matched in order.
        assert!(matches_pattern("a/x/b/x/c", "a*x*c"));
        assert!(!matches_pattern("a/x/b", "a*x*c"));

        // Non-ASCII keys are sliced only at match boundaries.
        assert!(matches_pattern("GET:/café/menu", "GET:/*/menu"));
    }

    #[test]
    fn test_refresh_trigger_broadcasts_to_subscribers() {
        let trigger = RefreshTrigger::new();
        let mut receiver = trigger.subscribe();

        trigger.trigger();
        trigger.trigger_by_key_match("GET:/api/*");

        assert!(matches!(receiver.try_recv(), Ok(RefreshMessage::All)));
        match receiver.try_recv() {
            Ok(RefreshMessage::Pattern(pattern)) => assert_eq!(pattern, "GET:/api/*"),
            other => panic!("unexpected message: {:?}", other),
        }
    }

    #[test]
    fn test_refresh_trigger_without_subscribers_does_not_panic() {
        // With no receivers, `send` fails; the trigger must swallow that.
        let trigger = RefreshTrigger::default();
        trigger.trigger();
        trigger.trigger_by_key_match("*");
    }
}