use crate::error::Result;
use crate::multi_tier::CacheKey;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Local state of one cache key under the MSI protocol.
///
/// `Hash` is derived (alongside `Eq`) so states can be stored in sets or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MSIState {
    /// Written locally; yields a `WriteBack` when evicted or remotely invalidated.
    Modified,
    /// Readable copy; can be dropped without a write-back.
    Shared,
    /// No usable copy. Keys that are not tracked at all are also reported as `Invalid`.
    Invalid,
}
/// Local state of one cache key under the MESI protocol.
///
/// `Hash` is derived (alongside `Eq`) so states can be stored in sets or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MESIState {
    /// Written locally; yields a `WriteBack` on eviction.
    Modified,
    /// Clean copy obtained when no other node reported holding one; a local write
    /// upgrades it to `Modified` without sending invalidations.
    Exclusive,
    /// Clean copy that other nodes may also hold.
    Shared,
    /// No usable copy. Keys that are not tracked at all are also reported as `Invalid`.
    Invalid,
}
/// A coherency message for a single key, returned by the protocol handlers for the
/// caller to deliver.
///
/// `PartialEq`/`Eq`/`Hash` are derived so messages can be compared (e.g. in tests) and
/// de-duplicated; `CacheKey` already satisfies these bounds because it is used as a
/// `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CoherencyMessage {
    /// Request for the key's data (emitted on a read miss).
    Read(CacheKey),
    /// Write request for the key.
    Write(CacheKey),
    /// Ask a peer to drop its copy of the key.
    Invalidate(CacheKey),
    /// Acknowledges an `Invalidate`; also returned as the "no copy here" reply.
    InvalidateAck(CacheKey),
    /// Locally modified data for the key must be written back.
    WriteBack(CacheKey),
    /// Reply to a remote read: the responder keeps a shared copy.
    Shared(CacheKey),
}
/// Node-local state machine for the MSI (Modified / Shared / Invalid) coherency protocol.
///
/// Each handler computes this node's state transition for one event and returns the
/// [`CoherencyMessage`]s the caller should deliver; the type performs no I/O itself.
pub struct MSIProtocol {
    /// Per-key local state. A key that is absent is reported as [`MSIState::Invalid`].
    states: Arc<RwLock<HashMap<CacheKey, MSIState>>>,
    /// Identifier of this node.
    // Not read anywhere in this impl; the allow silences the unused-field warning.
    #[allow(dead_code)]
    node_id: String,
    /// Peers that receive an `Invalidate` when this node writes a key it does not own.
    peer_nodes: Arc<RwLock<HashSet<String>>>,
    /// For each key written by this node, the peers whose `InvalidateAck` is outstanding.
    /// Entries are removed as soon as their set becomes empty.
    pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
}

impl MSIProtocol {
    /// Creates a protocol instance for `node_id` with no peers and no tracked keys.
    pub fn new(node_id: String) -> Self {
        Self {
            states: Arc::new(RwLock::new(HashMap::new())),
            node_id,
            peer_nodes: Arc::new(RwLock::new(HashSet::new())),
            pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `peer_id`; later writes will send it an `Invalidate`.
    pub async fn add_peer(&self, peer_id: String) {
        self.peer_nodes.write().await.insert(peer_id);
    }

    /// Unregisters `peer_id` and stops waiting for its acknowledgements.
    ///
    /// A removed peer will never send `InvalidateAck`, so it is also dropped from every
    /// pending-invalidation set (sets that become empty are removed). Otherwise
    /// [`invalidations_complete`](Self::invalidations_complete) could stay `false` forever.
    pub async fn remove_peer(&self, peer_id: &str) {
        // Lock order peer_nodes -> pending_invalidations, the same order handle_write uses.
        let mut peers = self.peer_nodes.write().await;
        peers.remove(peer_id);
        self.pending_invalidations
            .write()
            .await
            .retain(|_, waiting| {
                waiting.remove(peer_id);
                !waiting.is_empty()
            });
    }

    /// Returns this node's state for `key`, or [`MSIState::Invalid`] if untracked.
    pub async fn get_state(&self, key: &CacheKey) -> MSIState {
        self.states
            .read()
            .await
            .get(key)
            .copied()
            .unwrap_or(MSIState::Invalid)
    }

    /// Handles a local read of `key`.
    ///
    /// A hit (`Modified` or `Shared`) produces no messages. A miss moves the key to
    /// `Shared` and returns a single `Read` request.
    pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
        // Check and update under one write guard so a concurrent handle_write cannot set
        // Modified between the check and the update (it would be overwritten with Shared).
        let mut states = self.states.write().await;
        let current = states.get(key).copied().unwrap_or(MSIState::Invalid);
        match current {
            MSIState::Modified | MSIState::Shared => Ok(Vec::new()),
            MSIState::Invalid => {
                states.insert(key.clone(), MSIState::Shared);
                Ok(vec![CoherencyMessage::Read(key.clone())])
            }
        }
    }

    /// Handles a local write of `key`.
    ///
    /// If the key is already `Modified`, nothing is sent. Otherwise the key becomes
    /// `Modified`, one `Invalidate` is returned per registered peer, and those peers are
    /// recorded as pending acknowledgements. With no peers, nothing is left pending.
    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
        let mut states = self.states.write().await;
        let current = states.get(key).copied().unwrap_or(MSIState::Invalid);
        match current {
            MSIState::Modified => return Ok(Vec::new()),
            MSIState::Shared | MSIState::Invalid => {}
        }
        let peers = self.peer_nodes.read().await;
        let messages = vec![CoherencyMessage::Invalidate(key.clone()); peers.len()];
        {
            let mut pending = self.pending_invalidations.write().await;
            if peers.is_empty() {
                // Nobody to wait for: an (empty) entry would never be cleared by an ack.
                pending.remove(key);
            } else {
                pending.insert(key.clone(), peers.clone());
            }
        }
        states.insert(key.clone(), MSIState::Modified);
        Ok(messages)
    }

    /// Handles an `Invalidate` for `key` received from another node.
    ///
    /// Any tracked copy becomes `Invalid`. Returns `WriteBack` if the copy was `Modified`,
    /// otherwise `InvalidateAck`. Untracked keys stay untracked.
    pub async fn handle_remote_invalidate(&self, key: &CacheKey) -> Result<CoherencyMessage> {
        let mut states = self.states.write().await;
        let previous = states
            .get_mut(key)
            .map(|state| std::mem::replace(state, MSIState::Invalid));
        Ok(match previous {
            Some(MSIState::Modified) => CoherencyMessage::WriteBack(key.clone()),
            Some(MSIState::Shared | MSIState::Invalid) | None => {
                CoherencyMessage::InvalidateAck(key.clone())
            }
        })
    }

    /// Records that `from_node` acknowledged the invalidation of `key`.
    ///
    /// Acks for keys or nodes that are not pending are ignored.
    pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
        let mut pending = self.pending_invalidations.write().await;
        if let Some(waiting) = pending.get_mut(key) {
            waiting.remove(from_node);
            if waiting.is_empty() {
                pending.remove(key);
            }
        }
    }

    /// Returns `true` when no acknowledgement is outstanding for `key`.
    pub async fn invalidations_complete(&self, key: &CacheKey) -> bool {
        let pending = self.pending_invalidations.read().await;
        !pending.contains_key(key)
    }

    /// Stops tracking `key`.
    ///
    /// Returns `Some(WriteBack)` if the evicted copy was `Modified`, otherwise `None`.
    pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
        // `remove` both reads the old state and drops it under a single lock acquisition.
        let previous = self.states.write().await.remove(key);
        Ok(match previous {
            Some(MSIState::Modified) => Some(CoherencyMessage::WriteBack(key.clone())),
            Some(MSIState::Shared | MSIState::Invalid) | None => None,
        })
    }
}
/// Node-local state machine for the MESI (Modified / Exclusive / Shared / Invalid)
/// coherency protocol.
///
/// Each handler computes this node's state transition for one event and returns the
/// [`CoherencyMessage`]s the caller should deliver; the type performs no I/O itself.
pub struct MESIProtocol {
    /// Per-key local state. A key that is absent is reported as [`MESIState::Invalid`].
    states: Arc<RwLock<HashMap<CacheKey, MESIState>>>,
    /// Identifier of this node.
    // Not read anywhere in this impl; the allow silences the unused-field warning.
    #[allow(dead_code)]
    node_id: String,
    /// Peers that receive an `Invalidate` when this node writes a key it does not own.
    peer_nodes: Arc<RwLock<HashSet<String>>>,
    /// For each key written by this node, the peers whose `InvalidateAck` is outstanding.
    /// Entries are removed as soon as their set becomes empty.
    pending_invalidations: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
}

impl MESIProtocol {
    /// Creates a protocol instance for `node_id` with no peers and no tracked keys.
    pub fn new(node_id: String) -> Self {
        Self {
            states: Arc::new(RwLock::new(HashMap::new())),
            node_id,
            peer_nodes: Arc::new(RwLock::new(HashSet::new())),
            pending_invalidations: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `peer_id`; later writes from `Shared`/`Invalid` will invalidate it.
    pub async fn add_peer(&self, peer_id: String) {
        self.peer_nodes.write().await.insert(peer_id);
    }

    /// Unregisters `peer_id` and stops waiting for its acknowledgements.
    ///
    /// A removed peer will never send `InvalidateAck`, so it is also dropped from every
    /// pending-invalidation set (sets that become empty are removed).
    pub async fn remove_peer(&self, peer_id: &str) {
        // Lock order peer_nodes -> pending_invalidations, the same order handle_write uses.
        let mut peers = self.peer_nodes.write().await;
        peers.remove(peer_id);
        self.pending_invalidations
            .write()
            .await
            .retain(|_, waiting| {
                waiting.remove(peer_id);
                !waiting.is_empty()
            });
    }

    /// Returns this node's state for `key`, or [`MESIState::Invalid`] if untracked.
    pub async fn get_state(&self, key: &CacheKey) -> MESIState {
        self.states
            .read()
            .await
            .get(key)
            .copied()
            .unwrap_or(MESIState::Invalid)
    }

    /// Handles a local read of `key`.
    ///
    /// A hit produces no messages. A miss returns a single `Read` request and moves the
    /// key to `Shared` if `has_other_copy`, otherwise to `Exclusive`.
    pub async fn handle_read(
        &self,
        key: &CacheKey,
        has_other_copy: bool,
    ) -> Result<Vec<CoherencyMessage>> {
        // Check and update under one write guard so a concurrent handle_write cannot set
        // Modified between the check and the update.
        let mut states = self.states.write().await;
        let current = states.get(key).copied().unwrap_or(MESIState::Invalid);
        match current {
            MESIState::Modified | MESIState::Exclusive | MESIState::Shared => Ok(Vec::new()),
            MESIState::Invalid => {
                let new_state = if has_other_copy {
                    MESIState::Shared
                } else {
                    MESIState::Exclusive
                };
                states.insert(key.clone(), new_state);
                Ok(vec![CoherencyMessage::Read(key.clone())])
            }
        }
    }

    /// Handles a local write of `key`.
    ///
    /// `Modified` needs nothing and `Exclusive` upgrades silently to `Modified`. From
    /// `Shared`/`Invalid` the key becomes `Modified`, one `Invalidate` is returned per
    /// registered peer, and those peers are recorded as pending acknowledgements. With no
    /// peers, nothing is left pending.
    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
        let mut states = self.states.write().await;
        let current = states.get(key).copied().unwrap_or(MESIState::Invalid);
        match current {
            MESIState::Modified => return Ok(Vec::new()),
            MESIState::Exclusive => {
                states.insert(key.clone(), MESIState::Modified);
                return Ok(Vec::new());
            }
            MESIState::Shared | MESIState::Invalid => {}
        }
        let peers = self.peer_nodes.read().await;
        let messages = vec![CoherencyMessage::Invalidate(key.clone()); peers.len()];
        {
            let mut pending = self.pending_invalidations.write().await;
            if peers.is_empty() {
                // Nobody to wait for: an (empty) entry would never be cleared by an ack.
                pending.remove(key);
            } else {
                pending.insert(key.clone(), peers.clone());
            }
        }
        states.insert(key.clone(), MESIState::Modified);
        Ok(messages)
    }

    /// Handles a read of `key` requested by another node.
    ///
    /// `Modified`/`Exclusive` downgrade to `Shared`. Any valid copy replies `Shared`; an
    /// invalid or untracked key replies `InvalidateAck`.
    // NOTE(review): a Modified copy is downgraded without returning WriteBack; confirm the
    // caller flushes the dirty data by other means.
    pub async fn handle_remote_read(&self, key: &CacheKey) -> Result<CoherencyMessage> {
        let mut states = self.states.write().await;
        let current = states.get(key).copied().unwrap_or(MESIState::Invalid);
        match current {
            MESIState::Modified | MESIState::Exclusive => {
                states.insert(key.clone(), MESIState::Shared);
                Ok(CoherencyMessage::Shared(key.clone()))
            }
            MESIState::Shared => Ok(CoherencyMessage::Shared(key.clone())),
            MESIState::Invalid => Ok(CoherencyMessage::InvalidateAck(key.clone())),
        }
    }

    /// Handles an `Invalidate` for `key` received from another node.
    ///
    /// Any tracked copy becomes `Invalid`. Returns `WriteBack` if the copy was `Modified`,
    /// otherwise `InvalidateAck`. Untracked keys stay untracked.
    pub async fn handle_remote_invalidate(&self, key: &CacheKey) -> Result<CoherencyMessage> {
        let mut states = self.states.write().await;
        let previous = states
            .get_mut(key)
            .map(|state| std::mem::replace(state, MESIState::Invalid));
        Ok(match previous {
            Some(MESIState::Modified) => CoherencyMessage::WriteBack(key.clone()),
            Some(MESIState::Exclusive | MESIState::Shared | MESIState::Invalid) | None => {
                CoherencyMessage::InvalidateAck(key.clone())
            }
        })
    }

    /// Records that `from_node` acknowledged the invalidation of `key`.
    ///
    /// Acks for keys or nodes that are not pending are ignored.
    pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
        let mut pending = self.pending_invalidations.write().await;
        if let Some(waiting) = pending.get_mut(key) {
            waiting.remove(from_node);
            if waiting.is_empty() {
                pending.remove(key);
            }
        }
    }

    /// Returns `true` when no acknowledgement is outstanding for `key`.
    pub async fn invalidations_complete(&self, key: &CacheKey) -> bool {
        !self.pending_invalidations.read().await.contains_key(key)
    }

    /// Stops tracking `key`.
    ///
    /// Returns `Some(WriteBack)` if the evicted copy was `Modified`, otherwise `None`.
    pub async fn evict(&self, key: &CacheKey) -> Result<Option<CoherencyMessage>> {
        // `remove` both reads the old state and drops it under a single lock acquisition.
        let previous = self.states.write().await.remove(key);
        Ok(match previous {
            Some(MESIState::Modified) => Some(CoherencyMessage::WriteBack(key.clone())),
            Some(MESIState::Exclusive | MESIState::Shared | MESIState::Invalid) | None => None,
        })
    }
}
/// Directory-style coherency bookkeeping.
///
/// For each key it records which nodes hold a copy and which node last wrote it.
pub struct DirectoryCoherency {
    /// Nodes recorded as holding a copy of each key.
    directory: Arc<RwLock<HashMap<CacheKey, HashSet<String>>>>,
    /// Node that most recently wrote each key.
    modified_by: Arc<RwLock<HashMap<CacheKey, String>>>,
    /// This node's identifier; it is the node the request handlers record.
    node_id: String,
}

impl DirectoryCoherency {
    /// Creates an empty directory owned by `node_id`.
    pub fn new(node_id: String) -> Self {
        Self {
            directory: Arc::new(RwLock::new(HashMap::new())),
            modified_by: Arc::new(RwLock::new(HashMap::new())),
            node_id,
        }
    }

    /// Records this node as a sharer of `key`.
    ///
    /// Returns a `Read` request only when `key` is recorded as written by a *different*
    /// node. If this node is the recorded writer it already holds the latest data, just as
    /// [`handle_write`](Self::handle_write) never sends this node an `Invalidate`.
    pub async fn handle_read(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
        // Lock order directory -> modified_by, the same order handle_write uses.
        let mut dir = self.directory.write().await;
        let modified = self.modified_by.read().await;
        let owned_elsewhere =
            matches!(modified.get(key), Some(owner) if owner != &self.node_id);
        dir.entry(key.clone())
            .or_default()
            .insert(self.node_id.clone());
        Ok(if owned_elsewhere {
            vec![CoherencyMessage::Read(key.clone())]
        } else {
            Vec::new()
        })
    }

    /// Records this node as the writer and sole sharer of `key`.
    ///
    /// Returns one `Invalidate` per previously recorded sharer other than this node.
    pub async fn handle_write(&self, key: &CacheKey) -> Result<Vec<CoherencyMessage>> {
        let mut dir = self.directory.write().await;
        let mut modified = self.modified_by.write().await;
        let other_sharers = dir.get(key).map_or(0, |sharers| {
            sharers
                .iter()
                .filter(|node| *node != &self.node_id)
                .count()
        });
        modified.insert(key.clone(), self.node_id.clone());
        dir.insert(key.clone(), std::iter::once(self.node_id.clone()).collect());
        Ok(vec![CoherencyMessage::Invalidate(key.clone()); other_sharers])
    }

    /// Removes `from_node` from the sharers of `key`.
    ///
    /// The entry is dropped once no sharers remain. `get_sharers` reports an empty set
    /// either way.
    pub async fn handle_invalidate_ack(&self, key: &CacheKey, from_node: &str) {
        let mut dir = self.directory.write().await;
        if let Some(sharers) = dir.get_mut(key) {
            sharers.remove(from_node);
            if sharers.is_empty() {
                dir.remove(key);
            }
        }
    }

    /// Returns a snapshot of the nodes recorded as sharing `key` (empty if none).
    pub async fn get_sharers(&self, key: &CacheKey) -> HashSet<String> {
        self.directory
            .read()
            .await
            .get(key)
            .cloned()
            .unwrap_or_default()
    }
}
/// Accumulates per-node invalidations so they can be sent in batches.
pub struct InvalidationBatcher {
    /// Distinct keys waiting to be invalidated, grouped by destination node.
    pending: Arc<RwLock<HashMap<String, HashSet<CacheKey>>>>,
    /// Number of distinct pending keys for one node that releases a batch immediately.
    batch_size: usize,
}

impl InvalidationBatcher {
    /// Creates an empty batcher.
    ///
    /// A `batch_size` of 0 or 1 releases every key as soon as it is added.
    pub fn new(batch_size: usize) -> Self {
        Self {
            pending: Arc::new(RwLock::new(HashMap::new())),
            batch_size,
        }
    }

    /// Queues `key` for invalidation on `node`.
    ///
    /// Returns `Some(batch)` once `node` has at least `batch_size` distinct keys queued.
    /// The batch holds all of those keys, in unspecified order, and they are no longer
    /// pending afterwards. Otherwise returns `None`. Re-adding an already queued key does
    /// not grow the batch.
    pub async fn add_invalidation(&self, node: String, key: CacheKey) -> Option<Vec<CacheKey>> {
        let mut pending = self.pending.write().await;
        let keys = pending.entry(node).or_default();
        keys.insert(key);
        if keys.len() < self.batch_size {
            return None;
        }
        // Move the keys out rather than cloning them; an empty set is left behind.
        Some(std::mem::take(keys).into_iter().collect())
    }

    /// Drains and returns every node's pending keys.
    ///
    /// Nodes with nothing queued are omitted so callers never receive empty batches. This
    /// includes nodes whose keys were just released by `add_invalidation`.
    pub async fn flush(&self) -> HashMap<String, Vec<CacheKey>> {
        self.pending
            .write()
            .await
            .drain()
            .filter(|(_, keys)| !keys.is_empty())
            .map(|(node, keys)| (node, keys.into_iter().collect()))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_msi_protocol() {
        let protocol = MSIProtocol::new("node1".to_string());
        protocol.add_peer("node2".to_string()).await;
        let key = "test_key".to_string();

        // Read miss: a single Read request, key becomes Shared.
        let result = protocol.handle_read(&key).await;
        assert!(result.is_ok());
        let messages = result.unwrap_or_default();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages.first(), Some(CoherencyMessage::Read(k)) if k == &key));
        assert_eq!(protocol.get_state(&key).await, MSIState::Shared);

        // Read hit: no traffic.
        assert!(matches!(protocol.handle_read(&key).await, Ok(ref m) if m.is_empty()));

        // Write from Shared: one Invalidate per peer, key becomes Modified.
        let result = protocol.handle_write(&key).await;
        assert!(result.is_ok());
        let messages = result.unwrap_or_default();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages.first(), Some(CoherencyMessage::Invalidate(k)) if k == &key));
        assert_eq!(protocol.get_state(&key).await, MSIState::Modified);

        // Invalidation completes once the only peer acknowledges.
        assert!(!protocol.invalidations_complete(&key).await);
        protocol.handle_invalidate_ack(&key, "node2").await;
        assert!(protocol.invalidations_complete(&key).await);
    }

    #[tokio::test]
    async fn test_msi_remote_invalidate_and_evict() {
        let protocol = MSIProtocol::new("node1".to_string());
        let key = "test_key".to_string();

        assert!(protocol.handle_write(&key).await.is_ok());
        let reply = protocol.handle_remote_invalidate(&key).await;
        assert!(matches!(reply, Ok(CoherencyMessage::WriteBack(_))));
        assert_eq!(protocol.get_state(&key).await, MSIState::Invalid);
        let reply = protocol.handle_remote_invalidate(&key).await;
        assert!(matches!(reply, Ok(CoherencyMessage::InvalidateAck(_))));

        assert!(protocol.handle_write(&key).await.is_ok());
        let evicted = protocol.evict(&key).await;
        assert!(matches!(evicted, Ok(Some(CoherencyMessage::WriteBack(_)))));
        assert_eq!(protocol.get_state(&key).await, MSIState::Invalid);
        assert!(matches!(protocol.evict(&key).await, Ok(None)));
    }

    #[tokio::test]
    async fn test_mesi_protocol() {
        let protocol = MESIProtocol::new("node1".to_string());
        protocol.add_peer("node2".to_string()).await;
        let key = "test_key".to_string();

        // Read miss with no other copy: Exclusive.
        let result = protocol.handle_read(&key, false).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap_or_default().len(), 1);
        assert_eq!(protocol.get_state(&key).await, MESIState::Exclusive);

        // Exclusive -> Modified is a silent upgrade.
        let result = protocol.handle_write(&key).await;
        assert!(matches!(result, Ok(ref m) if m.is_empty()));
        assert_eq!(protocol.get_state(&key).await, MESIState::Modified);
    }

    #[tokio::test]
    async fn test_mesi_shared_read_remote_read_and_evict() {
        let protocol = MESIProtocol::new("node1".to_string());

        let shared_key = "shared".to_string();
        assert!(protocol.handle_read(&shared_key, true).await.is_ok());
        assert_eq!(protocol.get_state(&shared_key).await, MESIState::Shared);

        let key = "exclusive".to_string();
        assert!(protocol.handle_read(&key, false).await.is_ok());
        let reply = protocol.handle_remote_read(&key).await;
        assert!(matches!(reply, Ok(CoherencyMessage::Shared(ref k)) if k == &key));
        assert_eq!(protocol.get_state(&key).await, MESIState::Shared);

        let missing = "missing".to_string();
        let reply = protocol.handle_remote_read(&missing).await;
        assert!(matches!(reply, Ok(CoherencyMessage::InvalidateAck(_))));

        let dirty = "dirty".to_string();
        assert!(protocol.handle_write(&dirty).await.is_ok());
        let evicted = protocol.evict(&dirty).await;
        assert!(matches!(evicted, Ok(Some(CoherencyMessage::WriteBack(_)))));
        assert_eq!(protocol.get_state(&dirty).await, MESIState::Invalid);
    }

    #[tokio::test]
    async fn test_directory_coherency() {
        let dir = DirectoryCoherency::new("node1".to_string());
        let key = "test_key".to_string();

        let result = dir.handle_read(&key).await;
        assert!(matches!(result, Ok(ref m) if m.is_empty()));
        assert!(dir.get_sharers(&key).await.contains("node1"));

        let result = dir.handle_write(&key).await;
        assert!(matches!(result, Ok(ref m) if m.is_empty()));
        let sharers = dir.get_sharers(&key).await;
        assert_eq!(sharers.len(), 1);
        assert!(sharers.contains("node1"));
    }

    #[tokio::test]
    async fn test_invalidation_batcher() {
        let batcher = InvalidationBatcher::new(3);
        let result = batcher
            .add_invalidation("node1".to_string(), "key1".to_string())
            .await;
        assert!(result.is_none());
        let result = batcher
            .add_invalidation("node1".to_string(), "key2".to_string())
            .await;
        assert!(result.is_none());
        // A duplicate key does not count towards the batch size.
        let result = batcher
            .add_invalidation("node1".to_string(), "key2".to_string())
            .await;
        assert!(result.is_none());
        let result = batcher
            .add_invalidation("node1".to_string(), "key3".to_string())
            .await;
        assert!(result.is_some());
        let mut batch = result.unwrap_or_default();
        batch.sort();
        assert_eq!(
            batch,
            vec!["key1".to_string(), "key2".to_string(), "key3".to_string()]
        );

        // Keys below the threshold are returned by flush, which empties the batcher.
        let result = batcher
            .add_invalidation("node2".to_string(), "key4".to_string())
            .await;
        assert!(result.is_none());
        let flushed = batcher.flush().await;
        assert_eq!(flushed.get("node2"), Some(&vec!["key4".to_string()]));
        assert!(batcher.flush().await.is_empty());
    }
}