// memberlist_plumtree/message/cache.rs
use bytes::Bytes;
use parking_lot::RwLock;
use std::{
    collections::{HashMap, VecDeque},
    sync::Arc,
    time::{Duration, Instant},
};

use super::MessageId;
15
16#[derive(Debug, Clone)]
18struct CacheEntry {
19 payload: Arc<Bytes>,
21 inserted_at: Instant,
23 #[allow(dead_code)]
26 access_count: u32,
27}
28
29#[derive(Debug)]
34pub struct MessageCache {
35 inner: RwLock<CacheInner>,
37 ttl: Duration,
39 max_size: usize,
41}
42
43#[derive(Debug)]
44struct CacheInner {
45 entries: HashMap<MessageId, CacheEntry>,
47 insertion_order: Vec<(MessageId, Instant)>,
49}
50
51impl MessageCache {
52 pub fn new(ttl: Duration, max_size: usize) -> Self {
54 Self {
55 inner: RwLock::new(CacheInner {
56 entries: HashMap::with_capacity(max_size.min(1024)),
57 insertion_order: Vec::with_capacity(max_size.min(1024)),
58 }),
59 ttl,
60 max_size,
61 }
62 }
63
64 pub fn insert(&self, id: MessageId, payload: Bytes) {
69 let now = Instant::now();
70 let payload = Arc::new(payload);
71
72 let mut inner = self.inner.write();
73
74 if inner.entries.contains_key(&id) {
76 return;
77 }
78
79 self.evict_expired_locked(&mut inner, now);
81
82 while inner.entries.len() >= self.max_size {
84 if let Some((old_id, _)) = inner.insertion_order.first().cloned() {
85 inner.entries.remove(&old_id);
86 inner.insertion_order.remove(0);
87 } else {
88 break;
89 }
90 }
91
92 inner.entries.insert(
94 id,
95 CacheEntry {
96 payload,
97 inserted_at: now,
98 access_count: 0,
99 },
100 );
101 inner.insertion_order.push((id, now));
102 }
103
104 pub fn get(&self, id: &MessageId) -> Option<Arc<Bytes>> {
109 let now = Instant::now();
110
111 {
113 let inner = self.inner.read();
114 if let Some(entry) = inner.entries.get(id) {
115 if now.duration_since(entry.inserted_at) <= self.ttl {
116 return Some(entry.payload.clone());
117 }
118 }
119 }
120
121 None
122 }
123
124 pub fn contains(&self, id: &MessageId) -> bool {
126 let now = Instant::now();
127 let inner = self.inner.read();
128
129 if let Some(entry) = inner.entries.get(id) {
130 now.duration_since(entry.inserted_at) <= self.ttl
131 } else {
132 false
133 }
134 }
135
136 pub fn get_many(&self, ids: &[MessageId]) -> HashMap<MessageId, Arc<Bytes>> {
140 let now = Instant::now();
141 let inner = self.inner.read();
142
143 let mut result = HashMap::with_capacity(ids.len());
144 for id in ids {
145 if let Some(entry) = inner.entries.get(id) {
146 if now.duration_since(entry.inserted_at) <= self.ttl {
147 result.insert(*id, entry.payload.clone());
148 }
149 }
150 }
151 result
152 }
153
154 pub fn remove(&self, id: &MessageId) -> Option<Arc<Bytes>> {
156 let mut inner = self.inner.write();
157
158 if let Some(entry) = inner.entries.remove(id) {
159 inner.insertion_order.retain(|(i, _)| i != id);
160 Some(entry.payload)
161 } else {
162 None
163 }
164 }
165
166 pub fn len(&self) -> usize {
168 self.inner.read().entries.len()
169 }
170
171 pub fn is_empty(&self) -> bool {
173 self.inner.read().entries.is_empty()
174 }
175
176 pub fn evict_expired(&self) {
178 let now = Instant::now();
179 let mut inner = self.inner.write();
180 self.evict_expired_locked(&mut inner, now);
181 }
182
183 fn evict_expired_locked(&self, inner: &mut CacheInner, now: Instant) {
185 let cutoff = now - self.ttl;
187 let mut remove_count = 0;
188
189 for (_, inserted_at) in &inner.insertion_order {
190 if *inserted_at < cutoff {
191 remove_count += 1;
192 } else {
193 break;
195 }
196 }
197
198 if remove_count > 0 {
199 let to_remove: Vec<_> = inner.insertion_order.drain(..remove_count).collect();
201 for (id, _) in to_remove {
202 inner.entries.remove(&id);
203 }
204 }
205 }
206
207 pub fn clear(&self) {
209 let mut inner = self.inner.write();
210 inner.entries.clear();
211 inner.insertion_order.clear();
212 }
213
214 pub fn stats(&self) -> CacheStats {
216 let inner = self.inner.read();
217 CacheStats {
218 entries: inner.entries.len(),
219 capacity: self.max_size,
220 ttl: self.ttl,
221 }
222 }
223}
224
/// Point-in-time snapshot of a cache's occupancy and configuration.
#[derive(Clone, Copy, Debug)]
pub struct CacheStats {
    /// Entries currently held, which may include expired entries not yet purged.
    pub entries: usize,
    /// Configured maximum number of entries.
    pub capacity: usize,
    /// Configured time-to-live applied to every entry.
    pub ttl: Duration,
}
235
236impl Default for MessageCache {
237 fn default() -> Self {
238 Self::new(Duration::from_secs(60), 10000)
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache with a one-minute TTL, so nothing expires while a test runs.
    fn long_lived_cache(max_size: usize) -> MessageCache {
        MessageCache::new(Duration::from_secs(60), max_size)
    }

    /// `count` freshly generated message ids.
    fn fresh_ids(count: usize) -> Vec<MessageId> {
        std::iter::repeat_with(MessageId::new).take(count).collect()
    }

    #[test]
    fn test_insert_and_get() {
        let cache = long_lived_cache(100);
        let id = MessageId::new();
        let payload = Bytes::from_static(b"test payload");

        cache.insert(id, payload.clone());

        let stored = cache.get(&id).unwrap();
        assert_eq!(*stored, payload);
    }

    #[test]
    fn test_contains() {
        let cache = long_lived_cache(100);
        let id = MessageId::new();

        assert!(!cache.contains(&id));
        cache.insert(id, Bytes::from_static(b"test"));
        assert!(cache.contains(&id));
    }

    #[test]
    fn test_duplicate_insert() {
        let cache = long_lived_cache(100);
        let id = MessageId::new();

        cache.insert(id, Bytes::from_static(b"first"));
        cache.insert(id, Bytes::from_static(b"second"));

        // The second insert must not overwrite the first payload.
        let stored = cache.get(&id).unwrap();
        assert_eq!(&stored[..], b"first");
    }

    #[test]
    fn test_capacity_eviction() {
        let cache = long_lived_cache(3);
        let ids = fresh_ids(5);

        for (i, id) in ids.iter().enumerate() {
            cache.insert(*id, Bytes::from(format!("payload {}", i)));
        }

        assert_eq!(cache.len(), 3);
        // Capacity is 3, so the two oldest ids are evicted and the newest three survive.
        let (evicted, kept) = ids.split_at(2);
        assert!(evicted.iter().all(|id| !cache.contains(id)));
        assert!(kept.iter().all(|id| cache.contains(id)));
    }

    #[test]
    fn test_ttl_expiration() {
        let cache = MessageCache::new(Duration::from_millis(50), 100);
        let id = MessageId::new();

        cache.insert(id, Bytes::from_static(b"test"));
        assert!(cache.contains(&id));

        // Sleep well past the 50 ms TTL.
        std::thread::sleep(Duration::from_millis(100));
        assert!(!cache.contains(&id));
    }

    #[test]
    fn test_remove() {
        let cache = long_lived_cache(100);
        let id = MessageId::new();
        cache.insert(id, Bytes::from_static(b"test"));

        assert!(cache.remove(&id).is_some());
        assert!(!cache.contains(&id));
    }

    #[test]
    fn test_get_many() {
        let cache = long_lived_cache(100);
        let ids = fresh_ids(5);
        for (i, id) in ids.iter().enumerate() {
            cache.insert(*id, Bytes::from(format!("payload {}", i)));
        }

        let found = cache.get_many(&ids[1..4]);
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_clear() {
        let cache = long_lived_cache(100);
        (0..10).for_each(|_| cache.insert(MessageId::new(), Bytes::from_static(b"test")));

        assert_eq!(cache.len(), 10);
        cache.clear();
        assert!(cache.is_empty());
    }
}