//! Distributed cache coordination: cluster membership, replica placement,
//! heartbeat-driven health checks, and redistribution planning when nodes
//! leave the cluster.

use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Shared key -> (serialized value, insertion time) map backing the
/// in-memory client below.
type LocalCacheData = Arc<RwLock<HashMap<String, (Vec<u8>, SystemTime)>>>;
/// Coordinates a cluster of cache nodes: tracks membership, selects replica
/// targets, and plans data redistribution when nodes leave.
#[derive(Debug)]
pub struct DistributedCacheCoordinator {
    node_id: String,
    nodes: Arc<RwLock<HashMap<String, CacheNode>>>,
    replication_factor: u32,
    consistency_level: ConsistencyLevel,
    /// Reserved for a future background sync loop; currently unused.
    #[allow(dead_code)]
    sync_interval: Duration,
}
/// Membership metadata the coordinator tracks for each cache node.
#[derive(Debug, Clone)]
pub struct CacheNode {
    pub id: String,
    pub address: String,
    pub status: NodeStatus,
    pub last_heartbeat: SystemTime,
    pub cache_size_mb: usize,
    /// Relative load; lower values make a node a preferred replica target.
    pub load_factor: f32,
}
/// Liveness state of a node as observed by the coordinator.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    Online,
    Offline,
    Degraded,
    Synchronizing,
}
/// How widely a write must be acknowledged before it is considered synced.
#[derive(Debug, Clone)]
pub enum ConsistencyLevel {
    /// Replicate in the background; reads may briefly observe stale data.
    Eventual,
    /// Every replica must acknowledge the write.
    Strong,
    /// A majority of replicas must acknowledge the write.
    Quorum,
}
/// A versioned cache entry tracked across replicas.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedCacheEntry<T> {
    pub key: String,
    pub value: T,
    pub version: u64,
    pub created_at: SystemTime,
    pub updated_at: SystemTime,
    pub replicas: Vec<String>,
    pub checksum: u64,
}
/// Cluster events exchanged between nodes during synchronization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SyncEvent {
    NodeJoined(String),
    NodeLeft(String),
    /// Key plus its new version number.
    EntryUpdated(String, u64),
    EntryDeleted(String),
    FullSync,
}
/// A single unit of redistribution work: move one entry from a removed node
/// to a target node.
#[derive(Debug, Clone)]
pub struct RedistributionTask {
    pub entry_key: String,
    pub source_node: String,
    pub target_node: String,
    pub priority: RedistributionPriority,
}
/// Scheduling hint for redistribution tasks; not yet consumed by the
/// executor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RedistributionPriority {
    Critical,
    High,
    Normal,
    Low,
}
impl DistributedCacheCoordinator {
    /// Creates a coordinator identified by a freshly generated UUID.
    pub fn new(
replication_factor: u32,
consistency_level: ConsistencyLevel,
sync_interval: Duration,
) -> Self {
Self {
node_id: Uuid::new_v4().to_string(),
nodes: Arc::new(RwLock::new(HashMap::new())),
replication_factor,
consistency_level,
sync_interval,
}
}
    /// Registers a node, replacing any existing entry with the same ID.
    pub async fn add_node(&self, node: CacheNode) -> Result<()> {
let mut nodes = self.nodes.write().await;
nodes.insert(node.id.clone(), node);
Ok(())
}
    /// Removes a node from the cluster and redistributes its data across the
    /// remaining online nodes.
    pub async fn remove_node(&self, node_id: &str) -> Result<()> {
let mut nodes = self.nodes.write().await;
if let Some(removed_node) = nodes.remove(node_id) {
let active_nodes: Vec<CacheNode> = nodes
.values()
.filter(|node| node.status == NodeStatus::Online)
.cloned()
.collect();
drop(nodes);
self.redistribute_node_data(node_id, &removed_node, &active_nodes)
.await?;
tracing::info!(
"Successfully removed node {} and redistributed its data to {} active nodes",
node_id,
active_nodes.len()
);
} else {
tracing::warn!("Attempted to remove non-existent node: {}", node_id);
}
Ok(())
}
pub fn get_node_id(&self) -> &str {
&self.node_id
}
pub async fn get_active_nodes(&self) -> Vec<CacheNode> {
let nodes = self.nodes.read().await;
nodes
.values()
.filter(|node| node.status == NodeStatus::Online)
.cloned()
.collect()
}
    /// Selects up to `replication_factor` online nodes as replica targets,
    /// preferring the least-loaded nodes and optionally excluding one node
    /// (typically the primary).
    pub async fn select_replica_nodes(&self, exclude_node: Option<&str>) -> Vec<String> {
let nodes = self.get_active_nodes().await;
let mut candidates: Vec<_> = nodes
.into_iter()
.filter(|node| Some(node.id.as_str()) != exclude_node)
.collect();
candidates.sort_by(|a, b| {
a.load_factor
.partial_cmp(&b.load_factor)
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates
.into_iter()
.take(self.replication_factor as usize)
.map(|node| node.id)
.collect()
}
    /// Hashes a key with the standard library's `DefaultHasher`. The result
    /// is deterministic within a process, but the algorithm is not guaranteed
    /// to be stable across Rust releases, so these hashes should not be
    /// persisted.
    pub fn calculate_hash(key: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
    /// Maps a key to a primary node by hashing modulo the online node count.
    /// Nodes are sorted by ID first so the mapping does not depend on
    /// `HashMap` iteration order. Note this is simple modulo placement, not a
    /// consistent-hash ring, so most keys remap when membership changes.
    pub async fn find_primary_node(&self, key: &str) -> Option<String> {
        let mut nodes = self.get_active_nodes().await;
        if nodes.is_empty() {
            return None;
        }
        nodes.sort_by(|a, b| a.id.cmp(&b.id));
        let hash = Self::calculate_hash(key);
        let node_index = (hash as usize) % nodes.len();
        Some(nodes[node_index].id.clone())
    }
    /// Propagates an entry to replicas according to the configured
    /// consistency level.
    pub async fn sync_entry<T>(&self, entry: &DistributedCacheEntry<T>) -> Result<()>
where
T: Clone + Serialize + for<'de> Deserialize<'de>,
{
match self.consistency_level {
ConsistencyLevel::Eventual => {
self.sync_entry_async(entry).await
}
ConsistencyLevel::Strong => {
self.sync_entry_strong(entry).await
}
ConsistencyLevel::Quorum => {
self.sync_entry_quorum(entry).await
}
}
}
    // The three strategies below are placeholders: real implementations would
    // fan writes out to replica nodes and, for Strong/Quorum, await the
    // required number of acknowledgements before returning.
    async fn sync_entry_async<T>(&self, _entry: &DistributedCacheEntry<T>) -> Result<()>
    where
        T: Clone + Serialize + for<'de> Deserialize<'de>,
    {
        Ok(())
    }
    async fn sync_entry_strong<T>(&self, _entry: &DistributedCacheEntry<T>) -> Result<()>
    where
        T: Clone + Serialize + for<'de> Deserialize<'de>,
    {
        Ok(())
    }
    async fn sync_entry_quorum<T>(&self, _entry: &DistributedCacheEntry<T>) -> Result<()>
    where
        T: Clone + Serialize + for<'de> Deserialize<'de>,
    {
        Ok(())
    }
    /// Records a heartbeat from `node_id`, marking the node online again if
    /// it had been flagged otherwise. Heartbeats from unknown nodes are
    /// ignored.
    pub async fn handle_heartbeat(&self, node_id: &str) -> Result<()> {
let mut nodes = self.nodes.write().await;
if let Some(node) = nodes.get_mut(node_id) {
node.last_heartbeat = SystemTime::now();
if node.status != NodeStatus::Online {
node.status = NodeStatus::Online;
}
}
Ok(())
}
    /// Marks online nodes as offline if no heartbeat has arrived within the
    /// timeout, and returns the IDs of the nodes that just failed.
    pub async fn check_node_health(&self) -> Result<Vec<String>> {
        const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);
        let mut failed_nodes = Vec::new();
        let mut nodes = self.nodes.write().await;
        let now = SystemTime::now();
        for (node_id, node) in nodes.iter_mut() {
            if let Ok(elapsed) = now.duration_since(node.last_heartbeat) {
                if elapsed > HEARTBEAT_TIMEOUT && node.status == NodeStatus::Online {
                    node.status = NodeStatus::Offline;
                    failed_nodes.push(node_id.clone());
                }
            }
        }
        Ok(failed_nodes)
    }
    /// Plans and executes redistribution of a removed node's entries. Failed
    /// tasks are logged and skipped rather than retried.
    async fn redistribute_node_data(
&self,
removed_node_id: &str,
_removed_node: &CacheNode,
active_nodes: &[CacheNode],
) -> Result<()> {
if active_nodes.is_empty() {
tracing::warn!("No active nodes available for data redistribution");
return Ok(());
}
tracing::info!(
"Starting data redistribution from removed node {} to {} active nodes",
removed_node_id,
active_nodes.len()
);
let redistribution_tasks = self
.plan_redistribution(removed_node_id, active_nodes)
.await?;
for task in redistribution_tasks {
match self.execute_redistribution_task(task).await {
Ok(_) => {
tracing::debug!("Successfully completed redistribution task");
}
Err(e) => {
tracing::error!("Failed to execute redistribution task: {}", e);
}
}
}
tracing::info!(
"Data redistribution from node {} completed",
removed_node_id
);
Ok(())
}
    /// Builds one task per (entry, target node) pair. Entry discovery is
    /// simulated here; a real implementation would consult a replica index.
    async fn plan_redistribution(
&self,
removed_node_id: &str,
active_nodes: &[CacheNode],
) -> Result<Vec<RedistributionTask>> {
let mut tasks = Vec::new();
let simulated_entries = self
.get_entries_requiring_redistribution(removed_node_id)
.await;
for entry_key in simulated_entries {
let target_nodes = self
.select_redistribution_targets(&entry_key, active_nodes)
.await;
for target_node_id in target_nodes {
tasks.push(RedistributionTask {
entry_key: entry_key.clone(),
source_node: removed_node_id.to_string(),
target_node: target_node_id,
priority: RedistributionPriority::Normal,
});
}
}
tracing::debug!("Planned {} redistribution tasks", tasks.len());
Ok(tasks)
}
    /// Executes a single redistribution task. The transfer is simulated with
    /// a short sleep; a real implementation would copy the entry to the
    /// target node over the network.
    async fn execute_redistribution_task(&self, task: RedistributionTask) -> Result<()> {
        tracing::debug!(
            "Executing redistribution: {} from {} to {}",
            task.entry_key,
            task.source_node,
            task.target_node
        );
        tokio::time::sleep(Duration::from_millis(10)).await;
        Ok(())
    }
    /// Placeholder that returns simulated sample keys. A real implementation
    /// would enumerate entries whose replica set included the removed node.
    async fn get_entries_requiring_redistribution(&self, _removed_node_id: &str) -> Vec<String> {
vec![
"user_session_123".to_string(),
"model_weights_abc".to_string(),
"audio_cache_456".to_string(),
]
}
    /// Chooses the least-loaded nodes to host additional replicas of
    /// `entry_key`, topping the entry back up to the replication factor.
    async fn select_redistribution_targets(
&self,
entry_key: &str,
active_nodes: &[CacheNode],
) -> Vec<String> {
if active_nodes.is_empty() {
return Vec::new();
}
let current_replicas = self.count_existing_replicas(entry_key).await;
let needed_replicas = (self.replication_factor as usize).saturating_sub(current_replicas);
if needed_replicas == 0 {
return Vec::new();
}
let mut sorted_nodes = active_nodes.to_vec();
sorted_nodes.sort_by(|a, b| {
a.load_factor
.partial_cmp(&b.load_factor)
.unwrap_or(std::cmp::Ordering::Equal)
});
sorted_nodes
.into_iter()
.take(needed_replicas.min(active_nodes.len()))
.map(|node| node.id)
.collect()
}
    /// Placeholder that assumes exactly one surviving replica per entry.
    async fn count_existing_replicas(&self, _entry_key: &str) -> usize {
        1
    }
    /// Aggregates node counts, total cache capacity, and average load across
    /// every registered node.
    pub async fn get_cluster_stats(&self) -> DistributedCacheStats {
let nodes = self.nodes.read().await;
let total_nodes = nodes.len();
let online_nodes = nodes
.values()
.filter(|n| n.status == NodeStatus::Online)
.count();
let total_cache_mb = nodes.values().map(|n| n.cache_size_mb).sum();
let average_load = if !nodes.is_empty() {
nodes.values().map(|n| n.load_factor).sum::<f32>() / nodes.len() as f32
} else {
0.0
};
DistributedCacheStats {
total_nodes,
online_nodes,
offline_nodes: total_nodes - online_nodes,
total_cache_size_mb: total_cache_mb,
average_load_factor: average_load,
replication_factor: self.replication_factor,
}
}
}
/// Point-in-time snapshot of cluster membership, capacity, and load.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedCacheStats {
pub total_nodes: usize,
pub online_nodes: usize,
pub offline_nodes: usize,
pub total_cache_size_mb: usize,
pub average_load_factor: f32,
pub replication_factor: u32,
}
/// Client-facing cache operations. Values are opaque byte buffers;
/// serialization is left to the caller.
#[async_trait::async_trait]
pub trait DistributedCacheClient: Send + Sync {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> Result<()>;
async fn delete(&self, key: &str) -> Result<bool>;
async fn exists(&self, key: &str) -> Result<bool>;
async fn stats(&self) -> Result<DistributedCacheStats>;
}
/// Single-process stand-in for a distributed cache client, backed by a local
/// map. TTLs are accepted but not enforced.
pub struct InMemoryDistributedCache {
coordinator: Arc<DistributedCacheCoordinator>,
local_cache: LocalCacheData,
}
impl InMemoryDistributedCache {
pub fn new(coordinator: Arc<DistributedCacheCoordinator>) -> Self {
Self {
coordinator,
local_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait::async_trait]
impl DistributedCacheClient for InMemoryDistributedCache {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
let cache = self.local_cache.read().await;
Ok(cache.get(key).map(|(value, _)| value.clone()))
}
    async fn put(&self, key: &str, value: Vec<u8>, _ttl: Option<Duration>) -> Result<()> {
        // The insertion timestamp is stored, but `get` never expires entries;
        // TTL enforcement is out of scope for this in-memory stand-in.
        let mut cache = self.local_cache.write().await;
        cache.insert(key.to_string(), (value, SystemTime::now()));
        Ok(())
    }
async fn delete(&self, key: &str) -> Result<bool> {
let mut cache = self.local_cache.write().await;
Ok(cache.remove(key).is_some())
}
async fn exists(&self, key: &str) -> Result<bool> {
let cache = self.local_cache.read().await;
Ok(cache.contains_key(key))
}
async fn stats(&self) -> Result<DistributedCacheStats> {
Ok(self.coordinator.get_cluster_stats().await)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_distributed_cache_coordinator() {
let coordinator =
DistributedCacheCoordinator::new(3, ConsistencyLevel::Quorum, Duration::from_secs(30));
assert!(!coordinator.get_node_id().is_empty());
assert!(coordinator.get_active_nodes().await.is_empty());
}
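    // Illustrative addition (not in the original suite): with no online
    // nodes there is no primary to route a key to.
    #[tokio::test]
    async fn test_primary_node_empty_cluster() {
        let coordinator =
            DistributedCacheCoordinator::new(2, ConsistencyLevel::Eventual, Duration::from_secs(30));
        assert!(coordinator.find_primary_node("any_key").await.is_none());
    }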
#[tokio::test]
async fn test_add_remove_nodes() {
let coordinator = DistributedCacheCoordinator::new(
2,
ConsistencyLevel::Eventual,
Duration::from_secs(30),
);
let node = CacheNode {
id: "node1".to_string(),
address: "127.0.0.1:8080".to_string(),
status: NodeStatus::Online,
last_heartbeat: SystemTime::now(),
cache_size_mb: 1024,
load_factor: 0.5,
};
coordinator.add_node(node).await.unwrap();
assert_eq!(coordinator.get_active_nodes().await.len(), 1);
coordinator.remove_node("node1").await.unwrap();
assert_eq!(coordinator.get_active_nodes().await.len(), 0);
}
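    // Illustrative addition (not in the original suite): exercises the
    // heartbeat timeout path by backdating a node's last heartbeat past the
    // 60-second threshold used by `check_node_health`.
    #[tokio::test]
    async fn test_heartbeat_and_health_check() {
        let coordinator = DistributedCacheCoordinator::new(
            2,
            ConsistencyLevel::Eventual,
            Duration::from_secs(30),
        );
        let node = CacheNode {
            id: "stale_node".to_string(),
            address: "127.0.0.1:8090".to_string(),
            status: NodeStatus::Online,
            // Backdate the heartbeat so the node looks unresponsive.
            last_heartbeat: SystemTime::now() - Duration::from_secs(120),
            cache_size_mb: 512,
            load_factor: 0.1,
        };
        coordinator.add_node(node).await.unwrap();
        let failed = coordinator.check_node_health().await.unwrap();
        assert_eq!(failed, vec!["stale_node".to_string()]);
        assert!(coordinator.get_active_nodes().await.is_empty());
        // A fresh heartbeat brings the node back online.
        coordinator.handle_heartbeat("stale_node").await.unwrap();
        assert_eq!(coordinator.get_active_nodes().await.len(), 1);
    }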
#[tokio::test]
    async fn test_primary_node_stability() {
let coordinator =
DistributedCacheCoordinator::new(2, ConsistencyLevel::Quorum, Duration::from_secs(30));
for i in 1..=3 {
let node = CacheNode {
id: format!("node{i}"),
address: format!("127.0.0.1:808{i}"),
status: NodeStatus::Online,
last_heartbeat: SystemTime::now(),
cache_size_mb: 1024,
load_factor: 0.3,
};
coordinator.add_node(node).await.unwrap();
}
let key = "test_key";
let primary = coordinator.find_primary_node(key).await;
assert!(primary.is_some());
let primary2 = coordinator.find_primary_node(key).await;
assert_eq!(primary, primary2);
}
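    // Illustrative addition (not in the original suite): replica selection
    // should prefer the least-loaded nodes, honor the exclusion, and return
    // at most `replication_factor` candidates.
    #[tokio::test]
    async fn test_select_replica_nodes() {
        let coordinator = DistributedCacheCoordinator::new(
            2,
            ConsistencyLevel::Eventual,
            Duration::from_secs(30),
        );
        for (i, load) in [0.9_f32, 0.1, 0.5].iter().enumerate() {
            let node = CacheNode {
                id: format!("node{i}"),
                address: format!("127.0.0.1:900{i}"),
                status: NodeStatus::Online,
                last_heartbeat: SystemTime::now(),
                cache_size_mb: 1024,
                load_factor: *load,
            };
            coordinator.add_node(node).await.unwrap();
        }
        // node1 (load 0.1) and node2 (load 0.5) are the two least loaded.
        let replicas = coordinator.select_replica_nodes(None).await;
        assert_eq!(replicas, vec!["node1".to_string(), "node2".to_string()]);
        // Excluding node1 shifts selection to node2 and node0.
        let replicas = coordinator.select_replica_nodes(Some("node1")).await;
        assert_eq!(replicas, vec!["node2".to_string(), "node0".to_string()]);
    }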
#[tokio::test]
async fn test_in_memory_distributed_cache() {
let coordinator = Arc::new(DistributedCacheCoordinator::new(
2,
ConsistencyLevel::Eventual,
Duration::from_secs(30),
));
let cache = InMemoryDistributedCache::new(coordinator);
assert!(!cache.exists("key1").await.unwrap());
cache.put("key1", b"value1".to_vec(), None).await.unwrap();
assert!(cache.exists("key1").await.unwrap());
let value = cache.get("key1").await.unwrap();
assert_eq!(value, Some(b"value1".to_vec()));
assert!(cache.delete("key1").await.unwrap());
assert!(!cache.exists("key1").await.unwrap());
}
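    // Illustrative addition (not in the original suite): cluster stats should
    // aggregate capacity and average load over all registered nodes, counting
    // online and offline nodes separately.
    #[tokio::test]
    async fn test_cluster_stats() {
        let coordinator = DistributedCacheCoordinator::new(
            3,
            ConsistencyLevel::Quorum,
            Duration::from_secs(30),
        );
        for (id, status, load) in [
            ("a", NodeStatus::Online, 0.2_f32),
            ("b", NodeStatus::Offline, 0.6),
        ] {
            let node = CacheNode {
                id: id.to_string(),
                address: "127.0.0.1:7000".to_string(),
                status,
                last_heartbeat: SystemTime::now(),
                cache_size_mb: 256,
                load_factor: load,
            };
            coordinator.add_node(node).await.unwrap();
        }
        let stats = coordinator.get_cluster_stats().await;
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.online_nodes, 1);
        assert_eq!(stats.offline_nodes, 1);
        assert_eq!(stats.total_cache_size_mb, 512);
        assert!((stats.average_load_factor - 0.4).abs() < 1e-6);
        assert_eq!(stats.replication_factor, 3);
    }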
#[test]
fn test_hash_consistency() {
let key = "test_key";
let hash1 = DistributedCacheCoordinator::calculate_hash(key);
let hash2 = DistributedCacheCoordinator::calculate_hash(key);
assert_eq!(hash1, hash2);
let hash3 = DistributedCacheCoordinator::calculate_hash("different_key");
assert_ne!(hash1, hash3);
}
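    // Illustrative addition (not in the original suite): `sync_entry` is
    // currently a no-op for every consistency level, so this is a smoke test
    // that the dispatch compiles and succeeds for a concrete entry type. The
    // checksum is computed with the module's own hash helper for illustration.
    #[tokio::test]
    async fn test_sync_entry_smoke() {
        let coordinator = DistributedCacheCoordinator::new(
            2,
            ConsistencyLevel::Strong,
            Duration::from_secs(30),
        );
        let now = SystemTime::now();
        let entry = DistributedCacheEntry {
            key: "k".to_string(),
            value: "v".to_string(),
            version: 1,
            created_at: now,
            updated_at: now,
            replicas: vec![],
            checksum: DistributedCacheCoordinator::calculate_hash("v"),
        };
        assert!(coordinator.sync_entry(&entry).await.is_ok());
    }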
}