oxigdal_cache_advanced/coherency/
protocol.rs1use crate::error::Result;
10use crate::multi_tier::CacheKey;
11use std::collections::{HashMap, HashSet};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
/// State of a single cache line under the MSI protocol.
///
/// Keys that have never been cached are reported as [`MSIState::Invalid`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MSIState {
    /// Dirty copy owned by this node; eviction or a remote invalidation answers
    /// with a write-back.
    Modified,
    /// Clean copy that peers may also hold; a local write invalidates peers first.
    Shared,
    /// No usable copy on this node.
    Invalid,
}
25
/// State of a single cache line under the MESI protocol.
///
/// Keys that have never been cached are reported as [`MESIState::Invalid`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MESIState {
    /// Dirty copy owned by this node; eviction answers with a write-back.
    Modified,
    /// Clean copy held only by this node; a local write upgrades it to
    /// `Modified` without sending invalidations.
    Exclusive,
    /// Clean copy that peers may also hold; a local write invalidates peers first.
    Shared,
    /// No usable copy on this node.
    Invalid,
}
38
39#[derive(Debug, Clone)]
41pub enum CoherencyMessage {
42 Read(CacheKey),
44 Write(CacheKey),
46 Invalidate(CacheKey),
48 InvalidateAck(CacheKey),
50 WriteBack(CacheKey),
52 Shared(CacheKey),
54}
55
56pub struct MSIProtocol {
58 states: Arc<RwLock<HashMap<CacheKey, MSIState>>>,
60 #[allow(dead_code)]
62 node_id: String,
63 peer_nodes: Arc<RwLock<HashSet<String>>>,
65 pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
67}
68
69impl MSIProtocol {
70 pub fn new(node_id: String) -> Self {
72 Self {
73 states: Arc::new(RwLock::new(HashMap::new())),
74 node_id,
75 peer_nodes: Arc::new(RwLock::new(HashSet::new())),
76 pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
77 }
78 }
79
80 pub async fn add_peer(&self, peer_id: String) {
82 self.peer_nodes.write().await.insert(peer_id);
83 }
84
85 pub async fn remove_peer(&self, peer_id: &str) {
87 self.peer_nodes.write().await.remove(peer_id);
88 }
89
90 pub async fn get_state(&self, key: &CacheKey) -> MSIState {
92 self.states
93 .read()
94 .await
95 .get(key)
96 .copied()
97 .unwrap_or(MSIState::Invalid)
98 }
99
100 pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
102 let state = self.get_state(key).await;
103 let mut messages = Vec::new();
104
105 match state {
106 MSIState::Modified | MSIState::Shared => {
107 Ok(messages)
109 }
110 MSIState::Invalid => {
111 messages.push(CoherencyMessage::Read(key.clone()));
113
114 self.states
116 .write()
117 .await
118 .insert(key.clone(), MSIState::Shared);
119
120 Ok(messages)
121 }
122 }
123 }
124
125 pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
127 let state = self.get_state(key).await;
128 let mut messages = Vec::new();
129
130 match state {
131 MSIState::Modified => {
132 Ok(messages)
134 }
135 MSIState::Shared => {
136 let peers = self.peer_nodes.read().await;
138 for _peer in peers.iter() {
139 messages.push(CoherencyMessage::Invalidate(key.clone()));
140 }
141
142 self.pending_invalidations
144 .write()
145 .await
146 .insert(key.clone(), peers.clone());
147
148 self.states
150 .write()
151 .await
152 .insert(key.clone(), MSIState::Modified);
153
154 Ok(messages)
155 }
156 MSIState::Invalid => {
157 let peers = self.peer_nodes.read().await;
159 for _peer in peers.iter() {
160 messages.push(CoherencyMessage::Invalidate(key.clone()));
161 }
162
163 self.pending_invalidations
164 .write()
165 .await
166 .insert(key.clone(), peers.clone());
167
168 self.states
169 .write()
170 .await
171 .insert(key.clone(), MSIState::Modified);
172
173 Ok(messages)
174 }
175 }
176 }
177
178 pub async fn handle_remote_invalidate(&self, key: &CacheKey) -> Result<CoherencyMessage> {
180 let state = self.get_state(key).await;
181
182 match state {
183 MSIState::Modified => {
184 self.states
186 .write()
187 .await
188 .insert(key.clone(), MSIState::Invalid);
189 Ok(CoherencyMessage::WriteBack(key.clone()))
190 }
191 MSIState::Shared => {
192 self.states
194 .write()
195 .await
196 .insert(key.clone(), MSIState::Invalid);
197 Ok(CoherencyMessage::InvalidateAck(key.clone()))
198 }
199 MSIState::Invalid => {
200 Ok(CoherencyMessage::InvalidateAck(key.clone()))
202 }
203 }
204 }
205
206 pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
208 let mut pending = self.pending_invalidations.write().await;
209 if let Some(waiting) = pending.get_mut(key) {
210 waiting.remove(from_node);
211 if waiting.is_empty() {
212 pending.remove(key);
213 }
214 }
215 }
216
217 pub async fn invalidations_complete(&self, key: &CacheKey) -> bool {
219 let pending = self.pending_invalidations.read().await;
220 !pending.contains_key(key)
221 }
222
223 pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
225 let state = self.get_state(key).await;
226
227 match state {
228 MSIState::Modified => {
229 self.states.write().await.remove(key);
231 Ok(Some(CoherencyMessage::WriteBack(key.clone())))
232 }
233 MSIState::Shared | MSIState::Invalid => {
234 self.states.write().await.remove(key);
236 Ok(None)
237 }
238 }
239 }
240}
241
242pub struct MESIProtocol {
244 states: Arc<RwLock<HashMap<CacheKey, MESIState>>>,
246 #[allow(dead_code)]
248 node_id: String,
249 peer_nodes: Arc<RwLock<HashSet<String>>>,
251 pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
253}
254
255impl MESIProtocol {
256 pub fn new(node_id: String) -> Self {
258 Self {
259 states: Arc::new(RwLock::new(HashMap::new())),
260 node_id,
261 peer_nodes: Arc::new(RwLock::new(HashSet::new())),
262 pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
263 }
264 }
265
266 pub async fn add_peer(&self, peer_id: String) {
268 self.peer_nodes.write().await.insert(peer_id);
269 }
270
271 pub async fn get_state(&self, key: &CacheKey) -> MESIState {
273 self.states
274 .read()
275 .await
276 .get(key)
277 .copied()
278 .unwrap_or(MESIState::Invalid)
279 }
280
281 pub async fn handle_read(
283 &self,
284 key: &CacheKey,
285 has_other_copy: bool,
286 ) -> Result<Vec<CoherencyMessage>> {
287 let state = self.get_state(key).await;
288 let mut messages = Vec::new();
289
290 match state {
291 MESIState::Modified | MESIState::Exclusive | MESIState::Shared => {
292 Ok(messages)
294 }
295 MESIState::Invalid => {
296 messages.push(CoherencyMessage::Read(key.clone()));
297
298 let new_state = if has_other_copy {
300 MESIState::Shared
301 } else {
302 MESIState::Exclusive
303 };
304
305 self.states.write().await.insert(key.clone(), new_state);
306 Ok(messages)
307 }
308 }
309 }
310
311 pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
313 let state = self.get_state(key).await;
314 let mut messages = Vec::new();
315
316 match state {
317 MESIState::Modified => {
318 Ok(messages)
320 }
321 MESIState::Exclusive => {
322 self.states
324 .write()
325 .await
326 .insert(key.clone(), MESIState::Modified);
327 Ok(messages)
328 }
329 MESIState::Shared | MESIState::Invalid => {
330 let peers = self.peer_nodes.read().await;
332 for _peer in peers.iter() {
333 messages.push(CoherencyMessage::Invalidate(key.clone()));
334 }
335
336 self.pending_invalidations
337 .write()
338 .await
339 .insert(key.clone(), peers.clone());
340
341 self.states
342 .write()
343 .await
344 .insert(key.clone(), MESIState::Modified);
345
346 Ok(messages)
347 }
348 }
349 }
350
351 pub async fn handle_remote_read(&self, key: &CacheKey) -> Result<CoherencyMessage> {
353 let state = self.get_state(key).await;
354
355 match state {
356 MESIState::Modified => {
357 self.states
359 .write()
360 .await
361 .insert(key.clone(), MESIState::Shared);
362 Ok(CoherencyMessage::Shared(key.clone()))
363 }
364 MESIState::Exclusive => {
365 self.states
367 .write()
368 .await
369 .insert(key.clone(), MESIState::Shared);
370 Ok(CoherencyMessage::Shared(key.clone()))
371 }
372 MESIState::Shared => {
373 Ok(CoherencyMessage::Shared(key.clone()))
375 }
376 MESIState::Invalid => {
377 Ok(CoherencyMessage::InvalidateAck(key.clone()))
379 }
380 }
381 }
382
383 pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
385 let state = self.get_state(key).await;
386
387 match state {
388 MESIState::Modified => {
389 self.states.write().await.remove(key);
390 Ok(Some(CoherencyMessage::WriteBack(key.clone())))
391 }
392 _ => {
393 self.states.write().await.remove(key);
394 Ok(None)
395 }
396 }
397 }
398}
399
400pub struct DirectoryCoherency {
402 directory: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
404 modified_by: Arc<RwLock<HashMap<CacheKey, String>>>,
406 node_id: String,
408}
409
410impl DirectoryCoherency {
411 pub fn new(node_id: String) -> Self {
413 Self {
414 directory: Arc::new(RwLock::new(HashMap::new())),
415 modified_by: Arc::new(RwLock::new(HashMap::new())),
416 node_id,
417 }
418 }
419
420 pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
422 let mut dir = self.directory.write().await;
423 let modified = self.modified_by.read().await;
424
425 let mut messages = Vec::new();
426
427 if let Some(_modifier) = modified.get(key) {
428 messages.push(CoherencyMessage::Read(key.clone()));
430 }
431
432 dir.entry(key.clone())
434 .or_insert_with(HashSet::new)
435 .insert(self.node_id.clone());
436
437 Ok(messages)
438 }
439
440 pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
442 let mut dir = self.directory.write().await;
443 let mut modified = self.modified_by.write().await;
444
445 let mut messages = Vec::new();
446
447 if let Some(sharers) = dir.get(key) {
449 for sharer in sharers.iter() {
450 if sharer != &self.node_id {
451 messages.push(CoherencyMessage::Invalidate(key.clone()));
452 }
453 }
454 }
455
456 modified.insert(key.clone(), self.node_id.clone());
458
459 dir.insert(key.clone(), {
461 let mut set = HashSet::new();
462 set.insert(self.node_id.clone());
463 set
464 });
465
466 Ok(messages)
467 }
468
469 pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
471 let mut dir = self.directory.write().await;
472 if let Some(sharers) = dir.get_mut(key) {
473 sharers.remove(from_node);
474 }
475 }
476
477 pub async fn get_sharers(&self, key: &CacheKey) -> HashSet<String> {
479 self.directory
480 .read()
481 .await
482 .get(key)
483 .cloned()
484 .unwrap_or_default()
485 }
486}
487
488pub struct InvalidationBatcher {
490 pending: Arc<RwLock<HashMap<String, HashSet<CacheKey>>>>,
492 batch_size: usize,
494}
495
496impl InvalidationBatcher {
497 pub fn new(batch_size: usize) -> Self {
499 Self {
500 pending: Arc::new(RwLock::new(HashMap::new())),
501 batch_size,
502 }
503 }
504
505 pub async fn add_invalidation(&self, node: String, key: CacheKey) -> Option<Vec<CacheKey>> {
507 let mut pending = self.pending.write().await;
508 let keys = pending.entry(node.clone()).or_insert_with(HashSet::new);
509
510 keys.insert(key);
511
512 if keys.len() >= self.batch_size {
514 let batch: Vec<CacheKey> = keys.iter().cloned().collect();
515 keys.clear();
516 Some(batch)
517 } else {
518 None
519 }
520 }
521
522 pub async fn flush(&self) -> HashMap<String, Vec<CacheKey>> {
524 let mut pending = self.pending.write().await;
525 let result: HashMap<String, Vec<CacheKey>> = pending
526 .iter()
527 .map(|(node, keys)| (node.clone(), keys.iter().cloned().collect()))
528 .collect();
529
530 pending.clear();
531 result
532 }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// A read miss then a write drives MSI from Invalid to Shared to Modified.
    #[tokio::test]
    async fn test_msi_protocol() {
        let node = MSIProtocol::new(String::from("node1"));
        node.add_peer(String::from("node2")).await;
        let key = String::from("test_key");

        let read_msgs = node.handle_read(&key).await.unwrap_or_default();
        assert_eq!(read_msgs.len(), 1);
        assert_eq!(node.get_state(&key).await, MSIState::Shared);

        let write_msgs = node.handle_write(&key).await.unwrap_or_default();
        assert!(!write_msgs.is_empty());
        assert_eq!(node.get_state(&key).await, MSIState::Modified);
    }

    /// An unshared read miss lands in Exclusive, and a write upgrades it to Modified.
    #[tokio::test]
    async fn test_mesi_protocol() {
        let node = MESIProtocol::new(String::from("node1"));
        node.add_peer(String::from("node2")).await;
        let key = String::from("test_key");

        node.handle_read(&key, false).await.unwrap_or_default();
        assert_eq!(node.get_state(&key).await, MESIState::Exclusive);

        node.handle_write(&key).await.unwrap_or_default();
        assert_eq!(node.get_state(&key).await, MESIState::Modified);
    }

    /// A node that is the only sharer can write without invalidating anyone.
    #[tokio::test]
    async fn test_directory_coherency() {
        let directory = DirectoryCoherency::new(String::from("node1"));
        let key = String::from("test_key");

        directory.handle_read(&key).await.unwrap_or_default();
        assert!(directory.get_sharers(&key).await.contains("node1"));

        let invalidations = directory.handle_write(&key).await.unwrap_or_default();
        assert!(invalidations.is_empty());
    }

    /// A batch is only emitted once the per-node queue reaches `batch_size` keys.
    #[tokio::test]
    async fn test_invalidation_batcher() {
        let batcher = InvalidationBatcher::new(3);

        for name in ["key1", "key2"].iter() {
            let early = batcher
                .add_invalidation(String::from("node1"), name.to_string())
                .await;
            assert!(early.is_none());
        }

        let full = batcher
            .add_invalidation(String::from("node1"), String::from("key3"))
            .await;
        assert!(full.is_some());
        assert_eq!(full.unwrap_or_default().len(), 3);
    }
}