1use std::collections::hash_map::RandomState;
6use std::hash::{BuildHasher, Hash, Hasher};
7use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
8use std::time::{Duration, Instant};
9
10use dashmap::DashMap;
11
/// A cached payload together with its expiry and usage metadata.
///
/// The usage counters are atomics so they can be updated through a shared
/// reference (see [`CacheEntry::record_access`]).
#[derive(Debug)]
pub struct CacheEntry {
    /// The cached bytes.
    pub data: Vec<u8>,
    /// Moment the entry was constructed; expiry is measured from here.
    pub created_at: Instant,
    /// How long after `created_at` the entry stays valid.
    pub ttl: Duration,
    /// Number of times [`CacheEntry::record_access`] has been called.
    pub access_count: AtomicU64,
    /// Time of the most recent access (or of construction if the entry has
    /// never been accessed), in nanoseconds since a process-wide monotonic
    /// reference instant. Only meaningful when compared with the
    /// `last_access` of other entries in the same process.
    pub last_access: AtomicU64,
}

impl CacheEntry {
    /// Creates an entry holding `data` that expires `ttl` after now.
    ///
    /// `access_count` starts at zero and `last_access` is initialised to the
    /// construction time, so a never-read entry still ages relative to others.
    pub fn new(data: Vec<u8>, ttl: Duration) -> Self {
        let last_access = AtomicU64::new(Self::monotonic_nanos());
        Self {
            data,
            created_at: Instant::now(),
            ttl,
            access_count: AtomicU64::new(0),
            last_access,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since
    /// `created_at`.
    pub fn is_expired(&self) -> bool {
        self.created_at.elapsed() > self.ttl
    }

    /// Records one access: bumps `access_count` and refreshes `last_access`.
    pub fn record_access(&self) {
        self.access_count.fetch_add(1, Ordering::Relaxed);
        self.last_access
            .store(Self::monotonic_nanos(), Ordering::Relaxed);
    }

    /// Nanoseconds elapsed since the first time any entry asked for the time.
    ///
    /// `Instant` cannot be stored in an atomic, so timestamps are kept as an
    /// integer offset from one shared epoch; sharing the epoch is what makes
    /// the values of different entries comparable.
    fn monotonic_nanos() -> u64 {
        static EPOCH: std::sync::OnceLock<Instant> = std::sync::OnceLock::new();
        let epoch = *EPOCH.get_or_init(Instant::now);
        // Saturate instead of truncating; u64 nanoseconds cover ~584 years.
        u64::try_from(epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }
}
52
/// Limits and defaults that govern a request cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Time-to-live used for entries stored without an explicit TTL.
    pub default_ttl: Duration,
    /// Entry count at which storing a new entry first evicts an old one.
    pub max_entries: usize,
    /// Budget, in bytes of payload data, for all stored entries combined.
    pub max_memory_bytes: usize,
}

impl Default for CacheConfig {
    /// A 60-second TTL, 10 000 entries and 256 MiB of payload data.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        CacheConfig {
            default_ttl: Duration::from_secs(60),
            max_entries: 10_000,
            max_memory_bytes: 256 * MIB,
        }
    }
}
73
/// Point-in-time snapshot of a request cache's counters.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Misses that were reported to the cache.
    pub misses: u64,
    /// Entries that were removed from the cache.
    pub evictions: u64,
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// Bytes of payload data accounted for when the snapshot was taken.
    pub memory_bytes: usize,
    /// `hits / (hits + misses)`, or `0.0` when there were no lookups.
    pub hit_rate: f64,
}
84
85#[derive(Debug)]
87pub struct RequestCache {
88 entries: DashMap<u64, CacheEntry>,
90 config: CacheConfig,
92 hits: AtomicU64,
94 misses: AtomicU64,
96 evictions: AtomicU64,
98 memory_bytes: AtomicUsize,
100}
101
102impl RequestCache {
103 pub fn new(config: CacheConfig) -> Self {
105 Self {
106 entries: DashMap::new(),
107 config,
108 hits: AtomicU64::new(0),
109 misses: AtomicU64::new(0),
110 evictions: AtomicU64::new(0),
111 memory_bytes: AtomicUsize::new(0),
112 }
113 }
114
115 pub fn generate_key(method: &str, args: &[u8]) -> u64 {
117 let hasher = RandomState::new().build_hasher();
118 let mut hasher = hasher;
119 method.hash(&mut hasher);
120 args.hash(&mut hasher);
121 hasher.finish()
122 }
123
124 pub fn get(&self, key: u64) -> Option<Vec<u8>> {
126 let entry = self.entries.get(&key)?;
127
128 if entry.is_expired() {
129 drop(entry);
130 self.remove(key);
131 return None;
132 }
133
134 entry.record_access();
135 self.hits.fetch_add(1, Ordering::Relaxed);
136 Some(entry.data.clone())
137 }
138
139 pub fn set(&self, key: u64, data: Vec<u8>, ttl: Option<Duration>) {
141 if self.entries.len() >= self.config.max_entries {
143 self.evict_one();
144 }
145
146 let entry_size = data.len();
147 if self.memory_bytes.load(Ordering::Relaxed) + entry_size > self.config.max_memory_bytes {
148 self.evict_until_fits(entry_size);
149 }
150
151 let ttl = ttl.unwrap_or(self.config.default_ttl);
152 let entry = CacheEntry::new(data, ttl);
153 let entry_size = entry.data.len();
154
155 self.entries.insert(key, entry);
156 self.memory_bytes.fetch_add(entry_size, Ordering::Relaxed);
157 }
158
159 pub fn remove(&self, key: u64) -> Option<CacheEntry> {
161 if let Some((_, entry)) = self.entries.remove(&key) {
162 self.memory_bytes.fetch_sub(entry.data.len(), Ordering::Relaxed);
163 self.evictions.fetch_add(1, Ordering::Relaxed);
164 Some(entry)
165 } else {
166 None
167 }
168 }
169
170 pub fn clear(&self) {
172 self.entries.clear();
173 self.memory_bytes.store(0, Ordering::Relaxed);
174 }
175
176 fn evict_one(&self) {
178 let mut oldest: Option<(u64, u64)> = None;
179
180 for entry in self.entries.iter() {
181 let last_access = entry.value().last_access.load(Ordering::Relaxed);
182 if oldest.is_none() || last_access < oldest.unwrap().1 {
183 oldest = Some((*entry.key(), last_access));
184 }
185 }
186
187 if let Some((key, _)) = oldest {
188 self.remove(key);
189 }
190 }
191
192 fn evict_until_fits(&self, size: usize) {
194 while self.memory_bytes.load(Ordering::Relaxed) + size > self.config.max_memory_bytes {
195 self.evict_one();
196 if self.entries.is_empty() {
197 break;
198 }
199 }
200 }
201
202 pub fn stats(&self) -> CacheStats {
204 let hits = self.hits.load(Ordering::Relaxed);
205 let misses = self.misses.load(Ordering::Relaxed);
206 let total = hits + misses;
207
208 CacheStats {
209 hits,
210 misses,
211 evictions: self.evictions.load(Ordering::Relaxed),
212 entries: self.entries.len(),
213 memory_bytes: self.memory_bytes.load(Ordering::Relaxed),
214 hit_rate: if total > 0 { hits as f64 / total as f64 } else { 0.0 },
215 }
216 }
217
218 pub fn record_miss(&self) {
220 self.misses.fetch_add(1, Ordering::Relaxed);
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache that uses the stock configuration.
    fn default_cache() -> RequestCache {
        RequestCache::new(CacheConfig::default())
    }

    /// A stored payload comes back unchanged under the same key.
    #[test]
    fn test_cache_set_get() {
        let cache = default_cache();
        let key = RequestCache::generate_key("method", b"args");
        let payload = b"response".to_vec();

        cache.set(key, payload.clone(), None);

        assert_eq!(cache.get(key), Some(payload));
    }

    /// Looking up an unknown key yields nothing, and a reported miss shows up
    /// in the statistics.
    #[test]
    fn test_cache_miss() {
        let cache = default_cache();

        assert_eq!(cache.get(12345), None);

        // The miss is reported explicitly by the caller.
        cache.record_miss();

        assert_eq!(cache.stats().misses, 1);
    }

    /// An entry is no longer returned once its TTL has elapsed.
    #[test]
    fn test_cache_expiration() {
        let short_lived = CacheConfig {
            default_ttl: Duration::from_millis(100),
            ..CacheConfig::default()
        };
        let cache = RequestCache::new(short_lived);
        let key = RequestCache::generate_key("method", b"args");

        cache.set(key, b"response".to_vec(), None);

        // Wait past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert_eq!(cache.get(key), None);
    }

    /// `clear` empties the cache.
    #[test]
    fn test_cache_clear() {
        let cache = default_cache();

        (0..10u64).for_each(|i| cache.set(i, vec![i as u8], None));
        assert_eq!(cache.stats().entries, 10);

        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }
}