1use crate::CacheStats;
11use crate::error::Result;
12use crate::multi_tier::{CacheKey, CacheValue};
13use async_trait::async_trait;
14use dashmap::DashMap;
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
/// A physical node participating in the distributed cache cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
    /// Unique identifier; this is the node's identity on the hash ring.
    pub id: String,
    /// Network address used to reach the node (e.g. "host:port").
    pub address: String,
    /// Relative capacity weight — intended to scale this node's share of
    /// the ring. NOTE(review): confirm `ConsistentHashRing` honors it.
    pub weight: usize,
}
29
impl Hash for Node {
    /// Hash by `id` only: a node's ring identity is its id, so changing
    /// address/weight does not move it on the ring. The `Hash`/`Eq`
    /// contract still holds — equal nodes (all fields equal) necessarily
    /// share an id and therefore a hash; unequal nodes merely may collide,
    /// which is permitted.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id.hash(state);
    }
}
35
/// A consistent hash ring mapping cache keys to nodes.
///
/// Each physical node is inserted as many virtual points to smooth out
/// key distribution; a lookup walks clockwise to the first virtual point
/// at or after the key's hash, wrapping around at the end.
pub struct ConsistentHashRing {
    /// Virtual points as (point hash, owning node), kept sorted by hash
    /// so lookups can binary-search.
    ring: Vec<(u64, Node)>,
    /// Number of virtual points created per physical node.
    virtual_nodes: usize,
}
43
44impl ConsistentHashRing {
45 pub fn new(virtual_nodes: usize) -> Self {
47 Self {
48 ring: Vec::new(),
49 virtual_nodes,
50 }
51 }
52
53 pub fn add_node(&mut self, node: Node) {
55 for i in 0..self.virtual_nodes {
56 let virtual_key = format!("{}:{}", node.id, i);
57 let hash = self.hash_key(&virtual_key);
58 self.ring.push((hash, node.clone()));
59 }
60
61 self.ring.sort_by_key(|(hash, _)| *hash);
63 }
64
65 pub fn remove_node(&mut self, node_id: &str) {
67 self.ring.retain(|(_, node)| node.id != node_id);
68 }
69
70 pub fn get_node(&self, key: &CacheKey) -> Option<&Node> {
72 if self.ring.is_empty() {
73 return None;
74 }
75
76 let hash = self.hash_key(key);
77
78 let idx = self.ring.partition_point(|(h, _)| *h < hash);
80
81 let node_idx = if idx < self.ring.len() { idx } else { 0 };
83
84 self.ring.get(node_idx).map(|(_, node)| node)
85 }
86
87 pub fn get_nodes(&self, key: &CacheKey, n: usize) -> Vec<&Node> {
89 if self.ring.is_empty() {
90 return Vec::new();
91 }
92
93 let hash = self.hash_key(key);
94 let start_idx = self.ring.partition_point(|(h, _)| *h < hash);
95
96 let mut nodes = Vec::new();
97 let mut seen = std::collections::HashSet::new();
98
99 for i in 0..self.ring.len() {
100 let idx = (start_idx + i) % self.ring.len();
101 let (_, node) = &self.ring[idx];
102
103 if !seen.contains(&node.id) {
104 nodes.push(node);
105 seen.insert(node.id.clone());
106
107 if nodes.len() >= n {
108 break;
109 }
110 }
111 }
112
113 nodes
114 }
115
116 fn hash_key(&self, key: &str) -> u64 {
118 use std::collections::hash_map::DefaultHasher;
119
120 let mut hasher = DefaultHasher::new();
121 key.hash(&mut hasher);
122 hasher.finish()
123 }
124
125 pub fn nodes(&self) -> Vec<Node> {
127 let mut seen = std::collections::HashSet::new();
128 let mut nodes = Vec::new();
129
130 for (_, node) in &self.ring {
131 if !seen.contains(&node.id) {
132 nodes.push(node.clone());
133 seen.insert(node.id.clone());
134 }
135 }
136
137 nodes
138 }
139
140 pub fn size(&self) -> usize {
142 self.ring.len()
143 }
144}
145
/// One node of a distributed cache: stores its share of the key space
/// locally and consults a consistent hash ring for key ownership.
///
/// NOTE(review): remote fetch/replication is not implemented here —
/// keys owned by a peer currently read as local misses (see `get`).
pub struct DistributedCache {
    /// Local in-memory store for keys this node is responsible for.
    local: Arc<DashMap<CacheKey, CacheValue>>,
    /// Cluster membership / key-ownership ring.
    ring: Arc<RwLock<ConsistentHashRing>>,
    /// This node's own identity on the ring.
    local_node: Node,
    /// How many distinct nodes should hold each key.
    replication_factor: usize,
    /// Access count at or above which a key is considered "hot".
    hot_key_threshold: u64,
    /// Aggregate hit/miss/size statistics for the local store.
    stats: Arc<RwLock<CacheStats>>,
}
161
162impl DistributedCache {
163 pub fn new(local_node: Node, replication_factor: usize) -> Self {
165 let mut ring = ConsistentHashRing::new(150); ring.add_node(local_node.clone());
167
168 Self {
169 local: Arc::new(DashMap::new()),
170 ring: Arc::new(RwLock::new(ring)),
171 local_node,
172 replication_factor,
173 hot_key_threshold: 100,
174 stats: Arc::new(RwLock::new(CacheStats::new())),
175 }
176 }
177
178 pub async fn add_peer(&self, node: Node) {
180 let mut ring = self.ring.write().await;
181 ring.add_node(node);
182 }
183
184 pub async fn remove_peer(&self, node_id: &str) {
186 let mut ring = self.ring.write().await;
187 ring.remove_node(node_id);
188 }
189
190 pub async fn get(&self, key: &CacheKey) -> Result<Option<CacheValue>> {
192 let ring = self.ring.read().await;
193
194 if let Some(node) = ring.get_node(key) {
196 if node.id == self.local_node.id {
197 if let Some(mut value) = self.local.get_mut(key) {
199 value.record_access();
200
201 let mut stats = self.stats.write().await;
202 stats.hits += 1;
203
204 return Ok(Some(value.clone()));
205 } else {
206 let mut stats = self.stats.write().await;
207 stats.misses += 1;
208 return Ok(None);
209 }
210 } else {
211 let mut stats = self.stats.write().await;
214 stats.misses += 1;
215 return Ok(None);
216 }
217 }
218
219 Ok(None)
220 }
221
222 pub async fn put(&self, key: CacheKey, value: CacheValue) -> Result<()> {
224 let ring = self.ring.read().await;
225
226 let nodes = ring.get_nodes(&key, self.replication_factor);
228
229 let should_store_locally = nodes.iter().any(|n| n.id == self.local_node.id);
231
232 if should_store_locally {
233 self.local.insert(key.clone(), value.clone());
234
235 let mut stats = self.stats.write().await;
236 stats.bytes_stored += value.size as u64;
237 stats.item_count += 1;
238 }
239
240 Ok(())
243 }
244
245 pub async fn remove(&self, key: &CacheKey) -> Result<bool> {
247 let removed = self.local.remove(key);
248
249 if let Some((_, value)) = removed {
250 let mut stats = self.stats.write().await;
251 stats.bytes_stored = stats.bytes_stored.saturating_sub(value.size as u64);
252 stats.item_count = stats.item_count.saturating_sub(1);
253
254 Ok(true)
255 } else {
256 Ok(false)
257 }
258 }
259
260 pub fn is_hot_key(&self, key: &CacheKey) -> bool {
262 if let Some(value) = self.local.get(key) {
263 value.access_count >= self.hot_key_threshold
264 } else {
265 false
266 }
267 }
268
269 pub async fn stats(&self) -> CacheStats {
271 self.stats.read().await.clone()
272 }
273
274 pub async fn peers(&self) -> Vec<Node> {
276 let ring = self.ring.read().await;
277 ring.nodes()
278 }
279
280 pub async fn rebalance(&self) -> Result<()> {
282 let ring = self.ring.read().await;
283 let mut keys_to_remove = Vec::new();
284
285 for entry in self.local.iter() {
287 let key = entry.key();
288 let nodes = ring.get_nodes(key, self.replication_factor);
289
290 if !nodes.iter().any(|n| n.id == self.local_node.id) {
292 keys_to_remove.push(key.clone());
293 }
294 }
295
296 drop(ring);
297
298 for key in keys_to_remove {
300 self.remove(&key).await?;
301 }
302
303 Ok(())
304 }
305
306 pub async fn clear(&self) -> Result<()> {
308 self.local.clear();
309
310 let mut stats = self.stats.write().await;
311 *stats = CacheStats::new();
312
313 Ok(())
314 }
315}
316
/// Replication metadata attached to a cache entry for cross-node sync.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CacheMetadata {
    /// Entry version — presumably monotonic, used for conflict
    /// resolution between replicas. TODO(review): confirm the scheme.
    pub version: u64,
    /// Id of the node that owns the primary copy.
    pub owner: String,
    /// Ids of the nodes holding replica copies.
    pub replicas: Vec<String>,
    /// UTC timestamp of the last modification.
    pub last_modified: chrono::DateTime<chrono::Utc>,
}
329
/// A replication operation exchanged between cache nodes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum CacheOperation {
    /// Store a value together with its replication metadata.
    Put {
        key: CacheKey,
        /// Serialized value bytes (serialization format not visible here).
        value: Vec<u8>,
        metadata: CacheMetadata,
    },
    /// Delete a key at a specific version.
    Delete {
        key: CacheKey,
        version: u64,
    },
    /// Drop any cached copy of a key, with no version check.
    Invalidate {
        key: CacheKey,
    },
}
355
/// Wire protocol for propagating cache operations between nodes.
#[async_trait]
pub trait DistributedProtocol: Send + Sync {
    /// Sends `operation` to all known peers.
    async fn broadcast(&self, operation: CacheOperation) -> Result<()>;

    /// Applies an operation received from a peer to the local cache.
    async fn handle_operation(&self, operation: CacheOperation) -> Result<()>;

    /// Performs a full state synchronization with the given peer.
    async fn sync_with_peer(&self, peer_id: &str) -> Result<()>;
}
368
/// Source of cluster membership information.
#[async_trait]
pub trait PeerDiscovery: Send + Sync {
    /// Returns the currently known peer nodes.
    async fn discover(&self) -> Result<Vec<Node>>;

    /// Announces `node` to the discovery mechanism.
    async fn register(&self, node: Node) -> Result<()>;

    /// Withdraws a node from the discovery mechanism by id.
    async fn unregister(&self, node_id: &str) -> Result<()>;

    /// Returns whether the given node is currently healthy.
    async fn health_check(&self, node_id: &str) -> Result<bool>;
}
384
/// A `PeerDiscovery` backed by a fixed, preconfigured peer list;
/// registration, unregistration, and health checks are no-ops.
pub struct StaticPeerDiscovery {
    /// The immutable peer list returned by `discover`.
    peers: Vec<Node>,
}
390
391impl StaticPeerDiscovery {
392 pub fn new(peers: Vec<Node>) -> Self {
394 Self { peers }
395 }
396}
397
398#[async_trait]
399impl PeerDiscovery for StaticPeerDiscovery {
400 async fn discover(&self) -> Result<Vec<Node>> {
401 Ok(self.peers.clone())
402 }
403
404 async fn register(&self, _node: Node) -> Result<()> {
405 Ok(())
407 }
408
409 async fn unregister(&self, _node_id: &str) -> Result<()> {
410 Ok(())
412 }
413
414 async fn health_check(&self, _node_id: &str) -> Result<bool> {
415 Ok(true)
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Two nodes at 150 virtual points each should yield 300 ring
    /// entries, and any key must resolve to some node.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(150);

        let node1 = Node {
            id: "node1".to_string(),
            address: "127.0.0.1:8001".to_string(),
            weight: 1,
        };

        let node2 = Node {
            id: "node2".to_string(),
            address: "127.0.0.1:8002".to_string(),
            weight: 1,
        };

        ring.add_node(node1.clone());
        ring.add_node(node2.clone());

        // 2 nodes * 150 virtual points = 300 entries.
        assert_eq!(ring.size(), 300);
        let key = "test_key".to_string();
        let node = ring.get_node(&key);
        assert!(node.is_some());
    }

    /// Requesting 3 replicas from a 5-node ring must return exactly 3
    /// distinct physical nodes.
    #[test]
    fn test_replication_nodes() {
        let mut ring = ConsistentHashRing::new(150);

        for i in 0..5 {
            ring.add_node(Node {
                id: format!("node{}", i),
                address: format!("127.0.0.1:800{}", i),
                weight: 1,
            });
        }

        let key = "test_key".to_string();
        let nodes = ring.get_nodes(&key, 3);

        assert_eq!(nodes.len(), 3);

        // No physical node may appear twice in a replica set.
        let unique_ids: std::collections::HashSet<_> = nodes.iter().map(|n| &n.id).collect();
        assert_eq!(unique_ids.len(), 3);
    }

    /// With a single node the local cache owns every key, so a put
    /// followed by a get must hit.
    #[tokio::test]
    async fn test_distributed_cache() {
        let node = Node {
            id: "test_node".to_string(),
            address: "127.0.0.1:8000".to_string(),
            weight: 1,
        };

        let cache = DistributedCache::new(node, 2);

        let key = "test_key".to_string();
        let value = CacheValue::new(
            Bytes::from("test data"),
            crate::compression::DataType::Binary,
        );

        cache
            .put(key.clone(), value.clone())
            .await
            .expect("put failed");

        let retrieved = cache.get(&key).await.expect("get failed");
        assert!(retrieved.is_some());
    }

    /// After a second node joins, rebalance may evict locally-held keys
    /// whose replica set no longer includes node1; the item count can
    /// only stay the same or shrink.
    #[tokio::test]
    async fn test_cache_rebalance() {
        let node1 = Node {
            id: "node1".to_string(),
            address: "127.0.0.1:8001".to_string(),
            weight: 1,
        };

        let cache = DistributedCache::new(node1.clone(), 2);

        // Seed 10 entries while node1 owns the whole key space.
        for i in 0..10 {
            let key = format!("key{}", i);
            let value = CacheValue::new(
                Bytes::from(format!("value{}", i)),
                crate::compression::DataType::Binary,
            );
            cache.put(key, value).await.expect("put failed");
        }

        let node2 = Node {
            id: "node2".to_string(),
            address: "127.0.0.1:8002".to_string(),
            weight: 1,
        };
        cache.add_peer(node2).await;

        cache.rebalance().await.expect("rebalance failed");

        let stats = cache.stats().await;
        assert!(stats.item_count <= 10);
    }
}