use crate::CacheStats;
use crate::error::Result;
use crate::multi_tier::{CacheKey, CacheValue};
use async_trait::async_trait;
use dashmap::DashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::RwLock;
/// A physical cache node participating in the cluster.
///
/// Equality compares all fields, but hashing uses `id` only (see the manual
/// `Hash` impl below) — equal nodes always have equal ids, so the
/// `Hash`/`Eq` contract holds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
    /// Unique identifier of the node within the cluster.
    pub id: String,
    /// Network address used to reach the node (e.g. "127.0.0.1:8001").
    pub address: String,
    /// Relative capacity weight. NOTE(review): intended to bias key
    /// placement toward heavier nodes — confirm how the ring consumes it.
    pub weight: usize,
}
impl Hash for Node {
    /// Hashes by `id` alone: two `Node`s that compare equal necessarily
    /// share the same `id`, so this upholds the `Hash`/`Eq` contract while
    /// ignoring fields (address, weight) irrelevant to identity.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.id, state)
    }
}
/// A consistent-hash ring: each physical node is expanded into many virtual
/// points so keys spread evenly and node churn only remaps a small fraction
/// of the keyspace.
pub struct ConsistentHashRing {
    /// `(hash, node)` pairs kept sorted by hash; a physical node appears
    /// once per virtual point.
    ring: Vec<(u64, Node)>,
    /// Base number of virtual points created for each physical node.
    virtual_nodes: usize,
}
impl ConsistentHashRing {
    /// Creates an empty ring; each physical node added later contributes
    /// `virtual_nodes` points per unit of weight.
    pub fn new(virtual_nodes: usize) -> Self {
        Self {
            ring: Vec::new(),
            virtual_nodes,
        }
    }

    /// Adds (or re-adds) a node to the ring.
    ///
    /// Fixes two issues in the original: the node's `weight` was ignored
    /// (it now scales the virtual-point count, identical behavior for
    /// weight 1), and adding the same id twice accumulated duplicate
    /// points (adds are now idempotent). `weight.max(1)` keeps a
    /// misconfigured weight of 0 from silently excluding the node.
    pub fn add_node(&mut self, node: Node) {
        // Idempotency: drop any virtual points from a previous add of
        // this id before inserting the fresh set.
        self.remove_node(&node.id);
        let points = self.virtual_nodes.saturating_mul(node.weight.max(1));
        for i in 0..points {
            let virtual_key = format!("{}:{}", node.id, i);
            let hash = self.hash_key(&virtual_key);
            self.ring.push((hash, node.clone()));
        }
        // Lookups rely on a hash-sorted ring for binary search.
        self.ring.sort_by_key(|(hash, _)| *hash);
    }

    /// Removes every virtual point belonging to `node_id`. Order of the
    /// remaining points is preserved, so the ring stays sorted.
    pub fn remove_node(&mut self, node_id: &str) {
        self.ring.retain(|(_, node)| node.id != node_id);
    }

    /// Returns the node owning `key`: the first virtual point at or after
    /// the key's hash, wrapping to the start of the ring past the end.
    pub fn get_node(&self, key: &CacheKey) -> Option<&Node> {
        if self.ring.is_empty() {
            return None;
        }
        let hash = self.hash_key(key);
        let idx = self.ring.partition_point(|(h, _)| *h < hash);
        // `idx == len` means the hash is beyond the last point: wrap to 0.
        let (_, node) = &self.ring[idx % self.ring.len()];
        Some(node)
    }

    /// Returns up to `n` distinct nodes responsible for `key`, walking the
    /// ring clockwise from the key's position (primary first, then the
    /// replica candidates). Fewer than `n` are returned if the cluster is
    /// smaller than `n`.
    pub fn get_nodes(&self, key: &CacheKey, n: usize) -> Vec<&Node> {
        if self.ring.is_empty() {
            return Vec::new();
        }
        let hash = self.hash_key(key);
        let start_idx = self.ring.partition_point(|(h, _)| *h < hash);
        let mut nodes = Vec::new();
        let mut seen = std::collections::HashSet::new();
        for i in 0..self.ring.len() {
            let (_, node) = &self.ring[(start_idx + i) % self.ring.len()];
            // Consecutive virtual points often belong to the same physical
            // node; dedupe by id so replicas land on distinct nodes.
            if seen.insert(node.id.as_str()) {
                nodes.push(node);
                if nodes.len() >= n {
                    break;
                }
            }
        }
        nodes
    }

    /// Hashes a key (or virtual-node label) onto the ring's u64 keyspace.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable
    /// across Rust releases; if separate processes must agree on ring
    /// placement, pin an explicit hash function — confirm deployment
    /// assumptions.
    fn hash_key(&self, key: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// All distinct physical nodes currently on the ring, in ring order of
    /// first appearance.
    pub fn nodes(&self) -> Vec<Node> {
        let mut seen = std::collections::HashSet::new();
        let mut nodes = Vec::new();
        for (_, node) in &self.ring {
            if seen.insert(node.id.as_str()) {
                nodes.push(node.clone());
            }
        }
        nodes
    }

    /// Total number of virtual points (not physical nodes) on the ring.
    pub fn size(&self) -> usize {
        self.ring.len()
    }
}
/// Distributed cache facade: local concurrent storage plus a consistent-hash
/// ring describing which cluster node(s) own each key.
pub struct DistributedCache {
    /// Entries physically stored on this node.
    local: Arc<DashMap<CacheKey, CacheValue>>,
    /// Cluster membership as a consistent-hash ring, shared across tasks.
    ring: Arc<RwLock<ConsistentHashRing>>,
    /// This process's identity; used to decide whether a key is ours.
    local_node: Node,
    /// How many distinct nodes should hold a copy of each key.
    replication_factor: usize,
    /// Access count at or above which `is_hot_key` reports a key as hot.
    hot_key_threshold: u64,
    /// Hit/miss/size counters, updated under a write lock.
    stats: Arc<RwLock<CacheStats>>,
}
impl DistributedCache {
    /// Creates a cache owned by `local_node`, seeding the hash ring with the
    /// local node (150 virtual points) so single-node operation works
    /// immediately.
    pub fn new(local_node: Node, replication_factor: usize) -> Self {
        let mut ring = ConsistentHashRing::new(150);
        ring.add_node(local_node.clone());
        Self {
            local: Arc::new(DashMap::new()),
            ring: Arc::new(RwLock::new(ring)),
            local_node,
            replication_factor,
            // Default hotness cutoff; see `is_hot_key`.
            hot_key_threshold: 100,
            stats: Arc::new(RwLock::new(CacheStats::new())),
        }
    }

    /// Adds a peer to the ring; subsequent keys may hash to it.
    pub async fn add_peer(&self, node: Node) {
        self.ring.write().await.add_node(node);
    }

    /// Removes a peer's virtual points from the ring.
    pub async fn remove_peer(&self, node_id: &str) {
        self.ring.write().await.remove_node(node_id);
    }

    /// Looks up `key`, recording a hit or miss in the stats.
    ///
    /// Only local storage is consulted: if the ring assigns the key to a
    /// remote node this returns `Ok(None)` and counts a miss — fetching
    /// from peers is not implemented here. An empty ring returns `Ok(None)`
    /// without touching the stats (unreachable in practice, since `new`
    /// always seeds the local node).
    pub async fn get(&self, key: &CacheKey) -> Result<Option<CacheValue>> {
        // Resolve ownership in a tight scope so the ring read lock is
        // released before the stats lock is taken.
        let owner_is_local = {
            let ring = self.ring.read().await;
            match ring.get_node(key) {
                Some(node) => node.id == self.local_node.id,
                None => return Ok(None),
            }
        };
        let found = if owner_is_local {
            self.local.get_mut(key).map(|mut entry| {
                // Bump the per-entry access counter used by `is_hot_key`.
                entry.record_access();
                entry.clone()
            })
        } else {
            None
        };
        let mut stats = self.stats.write().await;
        if found.is_some() {
            stats.hits += 1;
        } else {
            stats.misses += 1;
        }
        Ok(found)
    }

    /// Stores `key` locally when this node is one of its replicas.
    ///
    /// Fixes the original's double-accounting: overwriting an existing key
    /// no longer inflates `item_count`, and the replaced value's size is
    /// subtracted from `bytes_stored`. NOTE: replication to remote peers is
    /// not performed here — only the local replica is written.
    pub async fn put(&self, key: CacheKey, value: CacheValue) -> Result<()> {
        let should_store_locally = {
            let ring = self.ring.read().await;
            ring.get_nodes(&key, self.replication_factor)
                .iter()
                .any(|n| n.id == self.local_node.id)
        };
        if should_store_locally {
            let new_size = value.size as u64;
            // `insert` hands back any value it replaced; no clones needed
            // since `key` and `value` are owned.
            let previous = self.local.insert(key, value);
            let mut stats = self.stats.write().await;
            stats.bytes_stored += new_size;
            match previous {
                Some(old) => {
                    stats.bytes_stored =
                        stats.bytes_stored.saturating_sub(old.size as u64);
                }
                None => stats.item_count += 1,
            }
        }
        Ok(())
    }

    /// Removes `key` from local storage, keeping the size/count stats
    /// consistent. Returns `Ok(true)` if an entry was actually removed.
    pub async fn remove(&self, key: &CacheKey) -> Result<bool> {
        match self.local.remove(key) {
            Some((_, value)) => {
                let mut stats = self.stats.write().await;
                stats.bytes_stored =
                    stats.bytes_stored.saturating_sub(value.size as u64);
                stats.item_count = stats.item_count.saturating_sub(1);
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Whether a locally-stored key has been accessed at least
    /// `hot_key_threshold` times. Unknown keys are never hot.
    pub fn is_hot_key(&self, key: &CacheKey) -> bool {
        self.local
            .get(key)
            .map_or(false, |value| value.access_count >= self.hot_key_threshold)
    }

    /// Snapshot of the current hit/miss/size counters.
    pub async fn stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }

    /// All distinct nodes currently on the ring (including the local node).
    pub async fn peers(&self) -> Vec<Node> {
        self.ring.read().await.nodes()
    }

    /// Drops local entries this node is no longer a replica for (after ring
    /// membership changed). Stats are kept consistent via `remove`.
    pub async fn rebalance(&self) -> Result<()> {
        // Collect orphaned keys under the ring read lock, then release it
        // before awaiting `remove`, which takes other locks.
        let orphaned: Vec<CacheKey> = {
            let ring = self.ring.read().await;
            self.local
                .iter()
                .filter(|entry| {
                    !ring
                        .get_nodes(entry.key(), self.replication_factor)
                        .iter()
                        .any(|n| n.id == self.local_node.id)
                })
                .map(|entry| entry.key().clone())
                .collect()
        };
        for key in orphaned {
            self.remove(&key).await?;
        }
        Ok(())
    }

    /// Empties local storage and resets all counters.
    pub async fn clear(&self) -> Result<()> {
        self.local.clear();
        let mut stats = self.stats.write().await;
        *stats = CacheStats::new();
        Ok(())
    }
}
/// Ownership and versioning metadata attached to a replicated cache entry,
/// exchanged between peers as part of `CacheOperation` messages.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CacheMetadata {
    /// Entry version; also carried by `CacheOperation::Delete`.
    /// NOTE(review): conflict-resolution semantics live in the protocol
    /// implementation — confirm versions are monotonically assigned there.
    pub version: u64,
    /// Id of the node considered the primary owner of the entry.
    pub owner: String,
    /// Ids of the nodes holding replica copies.
    pub replicas: Vec<String>,
    /// Timestamp of the last modification (UTC).
    pub last_modified: chrono::DateTime<chrono::Utc>,
}
/// A replication/invalidation message exchanged between peers via a
/// `DistributedProtocol` implementation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum CacheOperation {
    /// Store (or overwrite) an entry; the value travels as raw bytes.
    Put {
        key: CacheKey,
        value: Vec<u8>,
        metadata: CacheMetadata,
    },
    /// Delete an entry, scoped to a specific version.
    Delete {
        key: CacheKey,
        version: u64,
    },
    /// Drop any cached copy of the key without a version check.
    Invalidate {
        key: CacheKey,
    },
}
/// Wire protocol between cache peers: fan-out of operations, handling of
/// incoming operations, and pairwise synchronization.
#[async_trait]
pub trait DistributedProtocol: Send + Sync {
    /// Sends `operation` to all known peers.
    async fn broadcast(&self, operation: CacheOperation) -> Result<()>;
    /// Applies an operation received from a peer to local state.
    async fn handle_operation(&self, operation: CacheOperation) -> Result<()>;
    /// Reconciles local state with the peer identified by `peer_id`.
    async fn sync_with_peer(&self, peer_id: &str) -> Result<()>;
}
/// Source of cluster membership: enumerating, registering, and
/// health-checking peer nodes.
#[async_trait]
pub trait PeerDiscovery: Send + Sync {
    /// Returns the currently known peer nodes.
    async fn discover(&self) -> Result<Vec<Node>>;
    /// Announces `node` to the discovery backend.
    async fn register(&self, node: Node) -> Result<()>;
    /// Withdraws the node with the given id from the backend.
    async fn unregister(&self, node_id: &str) -> Result<()>;
    /// Reports whether the node with the given id is considered healthy.
    async fn health_check(&self, node_id: &str) -> Result<bool>;
}
/// `PeerDiscovery` backed by a fixed, pre-configured peer list; useful for
/// tests and static deployments.
pub struct StaticPeerDiscovery {
    /// The immutable set of peers supplied at construction time.
    peers: Vec<Node>,
}
impl StaticPeerDiscovery {
pub fn new(peers: Vec<Node>) -> Self {
Self { peers }
}
}
#[async_trait]
impl PeerDiscovery for StaticPeerDiscovery {
    /// Hands back a copy of the configured peer list; never fails.
    async fn discover(&self) -> Result<Vec<Node>> {
        Ok(self.peers.to_vec())
    }

    /// No-op: the peer set is fixed at construction.
    async fn register(&self, _node: Node) -> Result<()> {
        Ok(())
    }

    /// No-op: peers cannot be withdrawn from a static list.
    async fn unregister(&self, _node_id: &str) -> Result<()> {
        Ok(())
    }

    /// Static peers are always reported healthy; no probe is performed.
    async fn health_check(&self, _node_id: &str) -> Result<bool> {
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Two nodes at 150 virtual points each must yield 300 ring entries,
    /// and any key must resolve to some node.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(150);
        let node1 = Node {
            id: "node1".to_string(),
            address: "127.0.0.1:8001".to_string(),
            weight: 1,
        };
        let node2 = Node {
            id: "node2".to_string(),
            address: "127.0.0.1:8002".to_string(),
            weight: 1,
        };
        ring.add_node(node1.clone());
        ring.add_node(node2.clone());
        // 2 nodes * 150 virtual points.
        assert_eq!(ring.size(), 300);
        let key = "test_key".to_string();
        let node = ring.get_node(&key);
        assert!(node.is_some());
    }

    /// Requesting 3 replicas from a 5-node ring must return exactly 3
    /// distinct physical nodes (no duplicate ids).
    #[test]
    fn test_replication_nodes() {
        let mut ring = ConsistentHashRing::new(150);
        for i in 0..5 {
            ring.add_node(Node {
                id: format!("node{}", i),
                address: format!("127.0.0.1:800{}", i),
                weight: 1,
            });
        }
        let key = "test_key".to_string();
        let nodes = ring.get_nodes(&key, 3);
        assert_eq!(nodes.len(), 3);
        let unique_ids: std::collections::HashSet<_> = nodes.iter().map(|n| &n.id).collect();
        assert_eq!(unique_ids.len(), 3);
    }

    /// With a single node on the ring, every key is stored locally and a
    /// put followed by a get must round-trip.
    #[tokio::test]
    async fn test_distributed_cache() {
        let node = Node {
            id: "test_node".to_string(),
            address: "127.0.0.1:8000".to_string(),
            weight: 1,
        };
        let cache = DistributedCache::new(node, 2);
        let key = "test_key".to_string();
        let value = CacheValue::new(
            Bytes::from("test data"),
            crate::compression::DataType::Binary,
        );
        cache
            .put(key.clone(), value.clone())
            .await
            .expect("put failed");
        let retrieved = cache.get(&key).await.expect("get failed");
        assert!(retrieved.is_some());
    }

    /// After adding a second node and rebalancing, entries this node no
    /// longer replicates are dropped; the count can only shrink.
    #[tokio::test]
    async fn test_cache_rebalance() {
        let node1 = Node {
            id: "node1".to_string(),
            address: "127.0.0.1:8001".to_string(),
            weight: 1,
        };
        let cache = DistributedCache::new(node1.clone(), 2);
        for i in 0..10 {
            let key = format!("key{}", i);
            let value = CacheValue::new(
                Bytes::from(format!("value{}", i)),
                crate::compression::DataType::Binary,
            );
            cache.put(key, value).await.expect("put failed");
        }
        let node2 = Node {
            id: "node2".to_string(),
            address: "127.0.0.1:8002".to_string(),
            weight: 1,
        };
        cache.add_peer(node2).await;
        cache.rebalance().await.expect("rebalance failed");
        let stats = cache.stats().await;
        // replication_factor = 2 covers both nodes, so nothing is strictly
        // required to be evicted — only the upper bound is asserted.
        assert!(stats.item_count <= 10);
    }
}