memberlist_plumtree/message/
cache.rs

1//! Message cache for storing messages for Graft requests.
2//!
3//! Uses a combination of HashMap and time-ordered eviction to provide
4//! efficient O(1) lookups while managing memory usage.
5
use bytes::Bytes;
use parking_lot::RwLock;
use std::{
    collections::{HashMap, VecDeque},
    sync::Arc,
    time::{Duration, Instant},
};

use super::MessageId;
15
16/// Entry in the message cache.
17#[derive(Debug, Clone)]
18struct CacheEntry {
19    /// Message payload (Arc for zero-copy sharing).
20    payload: Arc<Bytes>,
21    /// When this entry was inserted.
22    inserted_at: Instant,
23    /// Number of times this message has been accessed.
24    /// Reserved for future use in cache eviction policies.
25    #[allow(dead_code)]
26    access_count: u32,
27}
28
29/// Thread-safe message cache with TTL-based eviction.
30///
31/// Messages are stored for a configurable duration to serve Graft
32/// requests from nodes that missed the initial broadcast.
33#[derive(Debug)]
34pub struct MessageCache {
35    /// Inner cache state protected by RwLock.
36    inner: RwLock<CacheInner>,
37    /// Time-to-live for cache entries.
38    ttl: Duration,
39    /// Maximum number of entries.
40    max_size: usize,
41}
42
43#[derive(Debug)]
44struct CacheInner {
45    /// Map from message ID to cache entry.
46    entries: HashMap<MessageId, CacheEntry>,
47    /// Insertion order for LRU eviction (oldest first).
48    insertion_order: Vec<(MessageId, Instant)>,
49}
50
51impl MessageCache {
52    /// Create a new message cache with the specified TTL and max size.
53    pub fn new(ttl: Duration, max_size: usize) -> Self {
54        Self {
55            inner: RwLock::new(CacheInner {
56                entries: HashMap::with_capacity(max_size.min(1024)),
57                insertion_order: Vec::with_capacity(max_size.min(1024)),
58            }),
59            ttl,
60            max_size,
61        }
62    }
63
64    /// Insert a message into the cache.
65    ///
66    /// If the message already exists, this is a no-op.
67    /// If the cache is full, oldest entries are evicted.
68    pub fn insert(&self, id: MessageId, payload: Bytes) {
69        let now = Instant::now();
70        let payload = Arc::new(payload);
71
72        let mut inner = self.inner.write();
73
74        // Check if already exists
75        if inner.entries.contains_key(&id) {
76            return;
77        }
78
79        // Evict expired entries first
80        self.evict_expired_locked(&mut inner, now);
81
82        // Evict oldest if still over capacity
83        while inner.entries.len() >= self.max_size {
84            if let Some((old_id, _)) = inner.insertion_order.first().cloned() {
85                inner.entries.remove(&old_id);
86                inner.insertion_order.remove(0);
87            } else {
88                break;
89            }
90        }
91
92        // Insert new entry
93        inner.entries.insert(
94            id,
95            CacheEntry {
96                payload,
97                inserted_at: now,
98                access_count: 0,
99            },
100        );
101        inner.insertion_order.push((id, now));
102    }
103
104    /// Get a message from the cache.
105    ///
106    /// Returns `None` if the message is not in the cache or has expired.
107    /// Returns a cloned Arc to avoid holding locks during message sending.
108    pub fn get(&self, id: &MessageId) -> Option<Arc<Bytes>> {
109        let now = Instant::now();
110
111        // Fast path: read lock
112        {
113            let inner = self.inner.read();
114            if let Some(entry) = inner.entries.get(id) {
115                if now.duration_since(entry.inserted_at) <= self.ttl {
116                    return Some(entry.payload.clone());
117                }
118            }
119        }
120
121        None
122    }
123
124    /// Check if a message exists in the cache (without returning it).
125    pub fn contains(&self, id: &MessageId) -> bool {
126        let now = Instant::now();
127        let inner = self.inner.read();
128
129        if let Some(entry) = inner.entries.get(id) {
130            now.duration_since(entry.inserted_at) <= self.ttl
131        } else {
132            false
133        }
134    }
135
136    /// Get multiple messages from the cache.
137    ///
138    /// Returns a map of found message IDs to their payloads.
139    pub fn get_many(&self, ids: &[MessageId]) -> HashMap<MessageId, Arc<Bytes>> {
140        let now = Instant::now();
141        let inner = self.inner.read();
142
143        let mut result = HashMap::with_capacity(ids.len());
144        for id in ids {
145            if let Some(entry) = inner.entries.get(id) {
146                if now.duration_since(entry.inserted_at) <= self.ttl {
147                    result.insert(*id, entry.payload.clone());
148                }
149            }
150        }
151        result
152    }
153
154    /// Remove a message from the cache.
155    pub fn remove(&self, id: &MessageId) -> Option<Arc<Bytes>> {
156        let mut inner = self.inner.write();
157
158        if let Some(entry) = inner.entries.remove(id) {
159            inner.insertion_order.retain(|(i, _)| i != id);
160            Some(entry.payload)
161        } else {
162            None
163        }
164    }
165
166    /// Get the number of entries currently in the cache.
167    pub fn len(&self) -> usize {
168        self.inner.read().entries.len()
169    }
170
171    /// Check if the cache is empty.
172    pub fn is_empty(&self) -> bool {
173        self.inner.read().entries.is_empty()
174    }
175
176    /// Remove all expired entries from the cache.
177    pub fn evict_expired(&self) {
178        let now = Instant::now();
179        let mut inner = self.inner.write();
180        self.evict_expired_locked(&mut inner, now);
181    }
182
183    /// Internal method to evict expired entries while holding the lock.
184    fn evict_expired_locked(&self, inner: &mut CacheInner, now: Instant) {
185        // Find cutoff point in insertion order
186        let cutoff = now - self.ttl;
187        let mut remove_count = 0;
188
189        for (_, inserted_at) in &inner.insertion_order {
190            if *inserted_at < cutoff {
191                remove_count += 1;
192            } else {
193                // insertion_order is sorted by time, so we can stop here
194                break;
195            }
196        }
197
198        if remove_count > 0 {
199            // Remove expired entries
200            let to_remove: Vec<_> = inner.insertion_order.drain(..remove_count).collect();
201            for (id, _) in to_remove {
202                inner.entries.remove(&id);
203            }
204        }
205    }
206
207    /// Clear all entries from the cache.
208    pub fn clear(&self) {
209        let mut inner = self.inner.write();
210        inner.entries.clear();
211        inner.insertion_order.clear();
212    }
213
214    /// Get cache statistics.
215    pub fn stats(&self) -> CacheStats {
216        let inner = self.inner.read();
217        CacheStats {
218            entries: inner.entries.len(),
219            capacity: self.max_size,
220            ttl: self.ttl,
221        }
222    }
223}
224
/// Point-in-time summary of a [`MessageCache`]'s size and configuration.
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// How many entries the cache held when the snapshot was taken.
    pub entries: usize,
    /// Configured upper bound on the number of entries.
    pub capacity: usize,
    /// Configured lifetime of each entry.
    pub ttl: Duration,
}
235
236impl Default for MessageCache {
237    fn default() -> Self {
238        Self::new(Duration::from_secs(60), 10000)
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache with a one-minute TTL and the given capacity.
    fn minute_cache(capacity: usize) -> MessageCache {
        MessageCache::new(Duration::from_secs(60), capacity)
    }

    /// Generates `count` fresh message IDs.
    fn fresh_ids(count: usize) -> Vec<MessageId> {
        (0..count).map(|_| MessageId::new()).collect()
    }

    #[test]
    fn test_insert_and_get() {
        let cache = minute_cache(100);
        let id = MessageId::new();
        let expected = Bytes::from_static(b"test payload");

        cache.insert(id, expected.clone());

        let stored = cache
            .get(&id)
            .expect("freshly inserted message must be cached");
        assert_eq!(&**stored, &expected);
    }

    #[test]
    fn test_contains() {
        let cache = minute_cache(100);
        let id = MessageId::new();

        assert!(!cache.contains(&id), "unknown id must not be reported");
        cache.insert(id, Bytes::from_static(b"test"));
        assert!(cache.contains(&id), "inserted id must be reported");
    }

    #[test]
    fn test_duplicate_insert() {
        let cache = minute_cache(100);
        let id = MessageId::new();

        cache.insert(id, Bytes::from_static(b"first"));
        // A second insert of a live id must keep the original payload.
        cache.insert(id, Bytes::from_static(b"second"));

        let stored = cache.get(&id).expect("message must still be cached");
        assert_eq!(&**stored, b"first");
    }

    #[test]
    fn test_capacity_eviction() {
        let cache = minute_cache(3);
        let ids = fresh_ids(5);

        for (n, id) in ids.iter().enumerate() {
            cache.insert(*id, Bytes::from(format!("payload {}", n)));
        }

        // Capacity is 3, so the two oldest ids must have been evicted.
        assert_eq!(cache.len(), 3);
        let (evicted, kept) = ids.split_at(2);
        assert!(evicted.iter().all(|id| !cache.contains(id)));
        assert!(kept.iter().all(|id| cache.contains(id)));
    }

    #[test]
    fn test_ttl_expiration() {
        let ttl = Duration::from_millis(50);
        let cache = MessageCache::new(ttl, 100);
        let id = MessageId::new();

        cache.insert(id, Bytes::from_static(b"test"));
        assert!(cache.contains(&id));

        // Sleep for twice the TTL so the entry is definitely stale.
        std::thread::sleep(ttl * 2);

        assert!(!cache.contains(&id));
    }

    #[test]
    fn test_remove() {
        let cache = minute_cache(100);
        let id = MessageId::new();
        cache.insert(id, Bytes::from_static(b"test"));

        assert!(cache.remove(&id).is_some());
        assert!(!cache.contains(&id));
    }

    #[test]
    fn test_get_many() {
        let cache = minute_cache(100);
        let ids = fresh_ids(5);
        ids.iter().enumerate().for_each(|(n, id)| {
            cache.insert(*id, Bytes::from(format!("payload {}", n)));
        });

        let found = cache.get_many(&ids[1..4]);
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_clear() {
        let cache = minute_cache(100);
        (0..10).for_each(|_| cache.insert(MessageId::new(), Bytes::from_static(b"test")));

        assert_eq!(cache.len(), 10);
        cache.clear();
        assert!(cache.is_empty());
    }
}