// oxigdal_cache_advanced/coherency/protocol.rs
//! Cache coherency protocols for distributed caching
//!
//! Implements various coherency protocols:
//! - MSI protocol (Modified, Shared, Invalid)
//! - MESI protocol (Modified, Exclusive, Shared, Invalid)
//! - Directory-based coherency for large clusters
//! - Invalidation batching for performance

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use tokio::sync::RwLock;

use crate::error::Result;
use crate::multi_tier::CacheKey;

/// Cache line state in the MSI protocol.
///
/// A line moves between these states in response to local reads/writes and
/// remote invalidations (see [`MSIProtocol`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MSIState {
    /// Modified - this cache has exclusive ownership and the line has been modified (dirty).
    Modified,
    /// Shared - this cache holds a valid copy that may also exist on other nodes.
    Shared,
    /// Invalid - the cache line holds no valid data.
    Invalid,
}
25
/// Cache line state in the MESI protocol.
///
/// Extends MSI with an Exclusive state: a clean copy known to be the only
/// one in the system, so a write can upgrade to Modified without any
/// invalidation traffic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MESIState {
    /// Modified - exclusive ownership, line has been modified (dirty).
    Modified,
    /// Exclusive - exclusive ownership, line not modified (clean).
    Exclusive,
    /// Shared - valid copy, may be shared with other nodes.
    Shared,
    /// Invalid - no valid data.
    Invalid,
}
38
/// Coherency message types exchanged between nodes.
///
/// Each variant carries the cache key it refers to; the protocol
/// implementations in this module return these for the caller to deliver
/// over whatever transport the cluster uses.
#[derive(Debug, Clone)]
pub enum CoherencyMessage {
    /// Read request - ask peers for the current value of the line.
    Read(CacheKey),
    /// Write request.
    Write(CacheKey),
    /// Invalidate request - tell a peer to drop its copy.
    Invalidate(CacheKey),
    /// Invalidate acknowledgment - confirms a copy was dropped.
    InvalidateAck(CacheKey),
    /// Write-back notification - dirty data must be flushed downstream.
    WriteBack(CacheKey),
    /// Shared response - sender now holds the line in Shared state.
    Shared(CacheKey),
}
55
/// MSI coherency protocol implementation.
///
/// Tracks per-key line states plus the set of peers whose invalidation
/// acknowledgments are still outstanding.
pub struct MSIProtocol {
    /// Cache line states (absent key == Invalid).
    states: Arc<RwLock<HashMap<CacheKey, MSIState>>>,
    /// This node's ID.
    // NOTE(review): not read anywhere in this file yet (hence the allow);
    // presumably kept for logging/transport integration.
    #[allow(dead_code)]
    node_id: String,
    /// Other nodes in the system.
    peer_nodes: Arc<RwLock<HashSet<String>>>,
    /// Pending invalidations: key -> peers that have not yet acked.
    pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
}
68
69impl MSIProtocol {
70    /// Create new MSI protocol instance
71    pub fn new(node_id: String) -> Self {
72        Self {
73            states: Arc::new(RwLock::new(HashMap::new())),
74            node_id,
75            peer_nodes: Arc::new(RwLock::new(HashSet::new())),
76            pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
77        }
78    }
79
80    /// Add peer node
81    pub async fn add_peer(&self, peer_id: String) {
82        self.peer_nodes.write().await.insert(peer_id);
83    }
84
85    /// Remove peer node
86    pub async fn remove_peer(&self, peer_id: &str) {
87        self.peer_nodes.write().await.remove(peer_id);
88    }
89
90    /// Get current state of a cache line
91    pub async fn get_state(&self, key: &CacheKey) -> MSIState {
92        self.states
93            .read()
94            .await
95            .get(key)
96            .copied()
97            .unwrap_or(MSIState::Invalid)
98    }
99
100    /// Handle read request
101    pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
102        let state = self.get_state(key).await;
103        let mut messages = Vec::new();
104
105        match state {
106            MSIState::Modified | MSIState::Shared => {
107                // Already have valid copy, no action needed
108                Ok(messages)
109            }
110            MSIState::Invalid => {
111                // Request from other nodes
112                messages.push(CoherencyMessage::Read(key.clone()));
113
114                // Transition to Shared state
115                self.states
116                    .write()
117                    .await
118                    .insert(key.clone(), MSIState::Shared);
119
120                Ok(messages)
121            }
122        }
123    }
124
125    /// Handle write request
126    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
127        let state = self.get_state(key).await;
128        let mut messages = Vec::new();
129
130        match state {
131            MSIState::Modified => {
132                // Already have exclusive access
133                Ok(messages)
134            }
135            MSIState::Shared => {
136                // Need to invalidate all other copies
137                let peers = self.peer_nodes.read().await;
138                for _peer in peers.iter() {
139                    messages.push(CoherencyMessage::Invalidate(key.clone()));
140                }
141
142                // Track pending invalidations
143                self.pending_invalidations
144                    .write()
145                    .await
146                    .insert(key.clone(), peers.clone());
147
148                // Transition to Modified state
149                self.states
150                    .write()
151                    .await
152                    .insert(key.clone(), MSIState::Modified);
153
154                Ok(messages)
155            }
156            MSIState::Invalid => {
157                // Request exclusive access
158                let peers = self.peer_nodes.read().await;
159                for _peer in peers.iter() {
160                    messages.push(CoherencyMessage::Invalidate(key.clone()));
161                }
162
163                self.pending_invalidations
164                    .write()
165                    .await
166                    .insert(key.clone(), peers.clone());
167
168                self.states
169                    .write()
170                    .await
171                    .insert(key.clone(), MSIState::Modified);
172
173                Ok(messages)
174            }
175        }
176    }
177
178    /// Handle invalidation request from remote node
179    pub async fn handle_remote_invalidate(&self, key: &CacheKey) -> Result<CoherencyMessage> {
180        let state = self.get_state(key).await;
181
182        match state {
183            MSIState::Modified => {
184                // Need to write back modified data
185                self.states
186                    .write()
187                    .await
188                    .insert(key.clone(), MSIState::Invalid);
189                Ok(CoherencyMessage::WriteBack(key.clone()))
190            }
191            MSIState::Shared => {
192                // Just invalidate
193                self.states
194                    .write()
195                    .await
196                    .insert(key.clone(), MSIState::Invalid);
197                Ok(CoherencyMessage::InvalidateAck(key.clone()))
198            }
199            MSIState::Invalid => {
200                // Already invalid
201                Ok(CoherencyMessage::InvalidateAck(key.clone()))
202            }
203        }
204    }
205
206    /// Handle invalidation acknowledgment
207    pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
208        let mut pending = self.pending_invalidations.write().await;
209        if let Some(waiting) = pending.get_mut(key) {
210            waiting.remove(from_node);
211            if waiting.is_empty() {
212                pending.remove(key);
213            }
214        }
215    }
216
217    /// Check if invalidations are complete
218    pub async fn invalidations_complete(&self, key: &CacheKey) -> bool {
219        let pending = self.pending_invalidations.read().await;
220        !pending.contains_key(key)
221    }
222
223    /// Evict cache line
224    pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
225        let state = self.get_state(key).await;
226
227        match state {
228            MSIState::Modified => {
229                // Write back modified data
230                self.states.write().await.remove(key);
231                Ok(Some(CoherencyMessage::WriteBack(key.clone())))
232            }
233            MSIState::Shared | MSIState::Invalid => {
234                // No write-back needed
235                self.states.write().await.remove(key);
236                Ok(None)
237            }
238        }
239    }
240}
241
/// MESI coherency protocol implementation.
///
/// Same bookkeeping as [`MSIProtocol`] with the extra Exclusive state:
/// per-key line states, the peer set, and outstanding invalidation acks.
pub struct MESIProtocol {
    /// Cache line states (absent key == Invalid).
    states: Arc<RwLock<HashMap<CacheKey, MESIState>>>,
    /// This node's ID.
    // NOTE(review): not read anywhere in this file yet (hence the allow);
    // presumably kept for logging/transport integration.
    #[allow(dead_code)]
    node_id: String,
    /// Peer nodes participating in the protocol.
    peer_nodes: Arc<RwLock<HashSet<String>>>,
    /// Pending invalidations: key -> peers that have not yet acked.
    pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
}
254
255impl MESIProtocol {
256    /// Create new MESI protocol instance
257    pub fn new(node_id: String) -> Self {
258        Self {
259            states: Arc::new(RwLock::new(HashMap::new())),
260            node_id,
261            peer_nodes: Arc::new(RwLock::new(HashSet::new())),
262            pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
263        }
264    }
265
266    /// Add peer node
267    pub async fn add_peer(&self, peer_id: String) {
268        self.peer_nodes.write().await.insert(peer_id);
269    }
270
271    /// Get current state
272    pub async fn get_state(&self, key: &CacheKey) -> MESIState {
273        self.states
274            .read()
275            .await
276            .get(key)
277            .copied()
278            .unwrap_or(MESIState::Invalid)
279    }
280
281    /// Handle read request
282    pub async fn handle_read(
283        &self,
284        key: &CacheKey,
285        has_other_copy: bool,
286    ) -> Result<Vec<CoherencyMessage>> {
287        let state = self.get_state(key).await;
288        let mut messages = Vec::new();
289
290        match state {
291            MESIState::Modified | MESIState::Exclusive | MESIState::Shared => {
292                // Already have valid copy
293                Ok(messages)
294            }
295            MESIState::Invalid => {
296                messages.push(CoherencyMessage::Read(key.clone()));
297
298                // Transition based on whether other copies exist
299                let new_state = if has_other_copy {
300                    MESIState::Shared
301                } else {
302                    MESIState::Exclusive
303                };
304
305                self.states.write().await.insert(key.clone(), new_state);
306                Ok(messages)
307            }
308        }
309    }
310
311    /// Handle write request
312    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
313        let state = self.get_state(key).await;
314        let mut messages = Vec::new();
315
316        match state {
317            MESIState::Modified => {
318                // Already have exclusive modified access
319                Ok(messages)
320            }
321            MESIState::Exclusive => {
322                // Upgrade to Modified
323                self.states
324                    .write()
325                    .await
326                    .insert(key.clone(), MESIState::Modified);
327                Ok(messages)
328            }
329            MESIState::Shared | MESIState::Invalid => {
330                // Invalidate all other copies
331                let peers = self.peer_nodes.read().await;
332                for _peer in peers.iter() {
333                    messages.push(CoherencyMessage::Invalidate(key.clone()));
334                }
335
336                self.pending_invalidations
337                    .write()
338                    .await
339                    .insert(key.clone(), peers.clone());
340
341                self.states
342                    .write()
343                    .await
344                    .insert(key.clone(), MESIState::Modified);
345
346                Ok(messages)
347            }
348        }
349    }
350
351    /// Handle remote read request
352    pub async fn handle_remote_read(&self, key: &CacheKey) -> Result<CoherencyMessage> {
353        let state = self.get_state(key).await;
354
355        match state {
356            MESIState::Modified => {
357                // Downgrade to Shared and provide data
358                self.states
359                    .write()
360                    .await
361                    .insert(key.clone(), MESIState::Shared);
362                Ok(CoherencyMessage::Shared(key.clone()))
363            }
364            MESIState::Exclusive => {
365                // Downgrade to Shared
366                self.states
367                    .write()
368                    .await
369                    .insert(key.clone(), MESIState::Shared);
370                Ok(CoherencyMessage::Shared(key.clone()))
371            }
372            MESIState::Shared => {
373                // Already shared
374                Ok(CoherencyMessage::Shared(key.clone()))
375            }
376            MESIState::Invalid => {
377                // No valid copy
378                Ok(CoherencyMessage::InvalidateAck(key.clone()))
379            }
380        }
381    }
382
383    /// Evict cache line
384    pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
385        let state = self.get_state(key).await;
386
387        match state {
388            MESIState::Modified => {
389                self.states.write().await.remove(key);
390                Ok(Some(CoherencyMessage::WriteBack(key.clone())))
391            }
392            _ => {
393                self.states.write().await.remove(key);
394                Ok(None)
395            }
396        }
397    }
398}
399
/// Directory-based coherency for large-scale systems.
///
/// Instead of broadcasting, a directory records which nodes hold each key
/// and which node (if any) holds a modified copy, so coherency traffic is
/// targeted rather than all-to-all.
pub struct DirectoryCoherency {
    /// Directory entries: key -> set of nodes holding a copy.
    directory: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
    /// Modified-state tracking: key -> node holding the modified copy.
    modified_by: Arc<RwLock<HashMap<CacheKey, String>>>,
    /// This node's ID (recorded as sharer/modifier in the maps above).
    node_id: String,
}
409
410impl DirectoryCoherency {
411    /// Create new directory coherency
412    pub fn new(node_id: String) -> Self {
413        Self {
414            directory: Arc::new(RwLock::new(HashMap::new())),
415            modified_by: Arc::new(RwLock::new(HashMap::new())),
416            node_id,
417        }
418    }
419
420    /// Handle read request
421    pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
422        let mut dir = self.directory.write().await;
423        let modified = self.modified_by.read().await;
424
425        let mut messages = Vec::new();
426
427        if let Some(_modifier) = modified.get(key) {
428            // Request data from modifier
429            messages.push(CoherencyMessage::Read(key.clone()));
430        }
431
432        // Add this node to sharers
433        dir.entry(key.clone())
434            .or_insert_with(HashSet::new)
435            .insert(self.node_id.clone());
436
437        Ok(messages)
438    }
439
440    /// Handle write request
441    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
442        let mut dir = self.directory.write().await;
443        let mut modified = self.modified_by.write().await;
444
445        let mut messages = Vec::new();
446
447        // Invalidate all sharers
448        if let Some(sharers) = dir.get(key) {
449            for sharer in sharers.iter() {
450                if sharer != &self.node_id {
451                    messages.push(CoherencyMessage::Invalidate(key.clone()));
452                }
453            }
454        }
455
456        // Mark as modified by this node
457        modified.insert(key.clone(), self.node_id.clone());
458
459        // Clear sharers
460        dir.insert(key.clone(), {
461            let mut set = HashSet::new();
462            set.insert(self.node_id.clone());
463            set
464        });
465
466        Ok(messages)
467    }
468
469    /// Handle invalidation acknowledgment
470    pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
471        let mut dir = self.directory.write().await;
472        if let Some(sharers) = dir.get_mut(key) {
473            sharers.remove(from_node);
474        }
475    }
476
477    /// Get nodes with copies
478    pub async fn get_sharers(&self, key: &CacheKey) -> HashSet<String> {
479        self.directory
480            .read()
481            .await
482            .get(key)
483            .cloned()
484            .unwrap_or_default()
485    }
486}
487
/// Batched invalidation for performance.
///
/// Accumulates per-node invalidation keys and hands back a batch once
/// `batch_size` keys have piled up for a node, reducing per-message overhead.
pub struct InvalidationBatcher {
    /// Pending invalidations: target node -> keys queued for it.
    pending: Arc<RwLock<HashMap<String, HashSet<CacheKey>>>>,
    /// Number of queued keys for one node that triggers a batch flush.
    batch_size: usize,
}
495
496impl InvalidationBatcher {
497    /// Create new invalidation batcher
498    pub fn new(batch_size: usize) -> Self {
499        Self {
500            pending: Arc::new(RwLock::new(HashMap::new())),
501            batch_size,
502        }
503    }
504
505    /// Add invalidation to batch
506    pub async fn add_invalidation(&self, node: String, key: CacheKey) -> Option<Vec<CacheKey>> {
507        let mut pending = self.pending.write().await;
508        let keys = pending.entry(node.clone()).or_insert_with(HashSet::new);
509
510        keys.insert(key);
511
512        // Flush if batch size reached
513        if keys.len() >= self.batch_size {
514            let batch: Vec<CacheKey> = keys.iter().cloned().collect();
515            keys.clear();
516            Some(batch)
517        } else {
518            None
519        }
520    }
521
522    /// Flush all pending invalidations
523    pub async fn flush(&self) -> HashMap<String, Vec<CacheKey>> {
524        let mut pending = self.pending.write().await;
525        let result: HashMap<String, Vec<CacheKey>> = pending
526            .iter()
527            .map(|(node, keys)| (node.clone(), keys.iter().cloned().collect()))
528            .collect();
529
530        pending.clear();
531        result
532    }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_msi_protocol() {
        let protocol = MSIProtocol::new("node1".to_string());
        protocol.add_peer("node2".to_string()).await;

        let key = "test_key".to_string();

        // Read should transition Invalid -> Shared and request the line.
        // FIX: use expect() instead of unwrap_or_default(); the latter
        // silently swapped an Err for an empty Vec and lost the error.
        let messages = protocol.handle_read(&key).await.expect("handle_read failed");
        assert_eq!(messages.len(), 1);
        assert_eq!(protocol.get_state(&key).await, MSIState::Shared);

        // Write should send one invalidation per peer and become Modified.
        let messages = protocol.handle_write(&key).await.expect("handle_write failed");
        assert!(!messages.is_empty());
        assert_eq!(protocol.get_state(&key).await, MSIState::Modified);
    }

    #[tokio::test]
    async fn test_mesi_protocol() {
        let protocol = MESIProtocol::new("node1".to_string());
        protocol.add_peer("node2".to_string()).await;

        let key = "test_key".to_string();

        // Read with no other copies should land in Exclusive.
        protocol
            .handle_read(&key, false)
            .await
            .expect("handle_read failed");
        assert_eq!(protocol.get_state(&key).await, MESIState::Exclusive);

        // Write from Exclusive should silently upgrade to Modified.
        protocol.handle_write(&key).await.expect("handle_write failed");
        assert_eq!(protocol.get_state(&key).await, MESIState::Modified);
    }

    #[tokio::test]
    async fn test_directory_coherency() {
        let dir = DirectoryCoherency::new("node1".to_string());
        let key = "test_key".to_string();

        dir.handle_read(&key).await.expect("handle_read failed");
        let sharers = dir.get_sharers(&key).await;
        assert!(sharers.contains("node1"));

        let messages = dir.handle_write(&key).await.expect("handle_write failed");
        assert!(messages.is_empty()); // No other sharers yet
    }

    #[tokio::test]
    async fn test_invalidation_batcher() {
        let batcher = InvalidationBatcher::new(3);

        // First two invalidations stay queued.
        let result = batcher
            .add_invalidation("node1".to_string(), "key1".to_string())
            .await;
        assert!(result.is_none());

        let result = batcher
            .add_invalidation("node1".to_string(), "key2".to_string())
            .await;
        assert!(result.is_none());

        // The third reaches the batch size and triggers a flush.
        let result = batcher
            .add_invalidation("node1".to_string(), "key3".to_string())
            .await;
        let batch = result.expect("batch should have flushed");
        assert_eq!(batch.len(), 3);
    }
}