1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4
/// A refresh request broadcast through a [`RefreshTrigger`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RefreshMessage {
    /// Unscoped refresh request; sent by [`RefreshTrigger::trigger`].
    All,
    /// Refresh request scoped to keys matching the contained pattern; sent by
    /// [`RefreshTrigger::trigger_by_key_match`]. Presumably subscribers
    /// interpret it the way `CacheStore::clear_by_pattern` does (`*` as a
    /// wildcard) — confirm against the consumer.
    Pattern(String),
}
13
14#[derive(Clone)]
17pub struct RefreshTrigger {
18 sender: broadcast::Sender<RefreshMessage>,
19}
20
21impl RefreshTrigger {
22 pub fn new() -> Self {
23 let (sender, _) = broadcast::channel(16);
24 Self { sender }
25 }
26
27 pub fn trigger(&self) {
29 let _ = self.sender.send(RefreshMessage::All);
31 }
32
33 pub fn trigger_by_key_match(&self, pattern: &str) {
36 let _ = self.sender.send(RefreshMessage::Pattern(pattern.to_string()));
38 }
39
40 pub fn subscribe(&self) -> broadcast::Receiver<RefreshMessage> {
42 self.sender.subscribe()
43 }
44}
45
/// Reports whether `key` matches the glob-style `pattern`.
///
/// `*` matches any (possibly empty) run of characters; every other character
/// must match literally and case-sensitively. A pattern containing no `*`
/// matches only an identical key. There is no escape for a literal `*`.
///
/// Examples: `"GET:/api/*"` matches `"GET:/api/users/123"`, `"*/users"`
/// matches `"POST:/v2/users"`, and `"*"` matches every key.
fn matches_pattern(key: &str, pattern: &str) -> bool {
    if pattern == key {
        return true;
    }

    // `split` is double-ended, so the anchored head and tail segments can be
    // peeled off both ends without collecting; what remains in between are
    // the floating middle segments.
    let mut segments = pattern.split('*');
    let head = segments.next().unwrap_or("");
    let tail = match segments.next_back() {
        Some(tail) => tail,
        // A single segment means the pattern has no `*`, and the exact
        // comparison above already failed.
        None => return false,
    };

    // The head must sit at the very start of the key.
    let mut rest = match key.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };

    // Each middle segment is matched at its leftmost occurrence, leaving as
    // much text as possible for the segments that follow it.
    for segment in segments {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }

    // The tail must sit at the very end of the unconsumed text, so it can
    // never overlap text already claimed by earlier segments.
    rest.ends_with(tail)
}
91
92#[derive(Clone)]
94pub struct CacheStore {
95 store: Arc<RwLock<HashMap<String, CachedResponse>>>,
96 refresh_trigger: RefreshTrigger,
97}
98
/// A stored response: raw body bytes, headers and status code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CachedResponse {
    /// Raw response body.
    pub body: Vec<u8>,
    /// Header name to value. Being a map, it holds at most one value per
    /// header name.
    pub headers: HashMap<String, String>,
    /// Status code — presumably an HTTP status; confirm with the code that
    /// populates the cache.
    pub status: u16,
}
105
106impl CacheStore {
107 pub fn new(refresh_trigger: RefreshTrigger) -> Self {
108 Self {
109 store: Arc::new(RwLock::new(HashMap::new())),
110 refresh_trigger,
111 }
112 }
113
114 pub async fn get(&self, key: &str) -> Option<CachedResponse> {
115 let store = self.store.read().await;
116 store.get(key).cloned()
117 }
118
119 pub async fn set(&self, key: String, response: CachedResponse) {
120 let mut store = self.store.write().await;
121 store.insert(key, response);
122 }
123
124 pub async fn clear(&self) {
125 let mut store = self.store.write().await;
126 store.clear();
127 }
128
129 pub async fn clear_by_pattern(&self, pattern: &str) {
131 let mut store = self.store.write().await;
132 store.retain(|key, _| !matches_pattern(key, pattern));
133 }
134
135 pub fn refresh_trigger(&self) -> &RefreshTrigger {
136 &self.refresh_trigger
137 }
138
139 pub async fn size(&self) -> usize {
141 let store = self.store.read().await;
142 store.len()
143 }
144}
145
146impl Default for RefreshTrigger {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/users"));
        assert!(!matches_pattern("GET:/api/users", "GET:/api/posts"));
    }

    #[test]
    fn test_matches_pattern_wildcard() {
        assert!(matches_pattern("GET:/api/users", "GET:/api/*"));
        assert!(matches_pattern("GET:/api/users/123", "GET:/api/*"));
        assert!(!matches_pattern("GET:/v2/users", "GET:/api/*"));

        assert!(matches_pattern("GET:/api/users", "*/users"));
        assert!(matches_pattern("POST:/v2/users", "*/users"));
        assert!(!matches_pattern("GET:/api/posts", "*/users"));

        assert!(matches_pattern("GET:/api/v1/users", "GET:/api/*/users"));
        assert!(matches_pattern("GET:/api/v2/users", "GET:/api/*/users"));
        assert!(!matches_pattern("GET:/api/v1/posts", "GET:/api/*/users"));

        assert!(matches_pattern("GET:/api/v1/users/123", "GET:*/users/*"));
        assert!(matches_pattern("POST:/v2/admin/users/456", "*/users/*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_only() {
        assert!(matches_pattern("GET:/api/users", "*"));
        assert!(matches_pattern("POST:/anything", "*"));
    }

    #[test]
    fn test_matches_pattern_edge_cases() {
        // The empty pattern matches only the empty key.
        assert!(matches_pattern("", ""));
        assert!(!matches_pattern("GET:/api/users", ""));

        // `*` also matches an empty run, including an empty key.
        assert!(matches_pattern("", "*"));
        assert!(matches_pattern("GET:/api/", "GET:/api/*"));

        // Prefix and suffix must not overlap.
        assert!(!matches_pattern("abc", "ab*bc"));
        assert!(matches_pattern("abbc", "ab*bc"));

        // Consecutive stars behave like a single one.
        assert!(matches_pattern("GET:/api/users", "GET:**/users"));

        // Every middle segment must be present, in order.
        assert!(matches_pattern("a/x/b/y/c", "a*b*c"));
        assert!(!matches_pattern("a/x/c/y/b", "a*b*c"));

        // Matching is case-sensitive, and `*` in the key is not special.
        assert!(!matches_pattern("get:/api/users", "GET:/api/*"));
        assert!(!matches_pattern("GET:/api/*", "GET:/api/users"));

        // Multibyte characters around a wildcard do not cause slicing panics.
        assert!(matches_pattern("GET:/café/menü", "GET:/*/menü"));
    }

    #[test]
    fn test_refresh_trigger_broadcasts_to_subscribers() {
        let trigger = RefreshTrigger::new();

        // Sending with no subscribers is a silent no-op, not a panic.
        trigger.trigger();

        let mut rx = trigger.subscribe();
        trigger.trigger();
        trigger.trigger_by_key_match("GET:/api/*");

        assert!(matches!(rx.try_recv(), Ok(RefreshMessage::All)));
        assert!(matches!(
            rx.try_recv(),
            Ok(RefreshMessage::Pattern(ref p)) if p == "GET:/api/*"
        ));
        // Nothing else was sent after subscribing.
        assert!(rx.try_recv().is_err());
    }
}