// oxigdal_cache_advanced/distributed.rs

//! Distributed cache protocol
//!
//! Implements distributed caching with:
//! - Consistent hashing for key distribution
//! - Distributed LRU with global coordination
//! - Cache peer discovery
//! - Replication for hot keys
//! - Automatic rebalancing

10use crate::CacheStats;
11use crate::error::Result;
12use crate::multi_tier::{CacheKey, CacheValue};
13use async_trait::async_trait;
14use dashmap::DashMap;
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17use tokio::sync::RwLock;
18
19/// Hash ring node
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct Node {
22    /// Node identifier
23    pub id: String,
24    /// Node address
25    pub address: String,
26    /// Node weight (for distribution)
27    pub weight: usize,
28}
29
30impl Hash for Node {
31    fn hash<H: Hasher>(&self, state: &mut H) {
32        self.id.hash(state);
33    }
34}
35
36/// Consistent hash ring for key distribution
37pub struct ConsistentHashRing {
38    /// Virtual nodes on the ring
39    ring: Vec<(u64, Node)>,
40    /// Number of virtual nodes per physical node
41    virtual_nodes: usize,
42}
43
44impl ConsistentHashRing {
45    /// Create new hash ring
46    pub fn new(virtual_nodes: usize) -> Self {
47        Self {
48            ring: Vec::new(),
49            virtual_nodes,
50        }
51    }
52
53    /// Add node to the ring
54    pub fn add_node(&mut self, node: Node) {
55        for i in 0..self.virtual_nodes {
56            let virtual_key = format!("{}:{}", node.id, i);
57            let hash = self.hash_key(&virtual_key);
58            self.ring.push((hash, node.clone()));
59        }
60
61        // Sort ring by hash values
62        self.ring.sort_by_key(|(hash, _)| *hash);
63    }
64
65    /// Remove node from the ring
66    pub fn remove_node(&mut self, node_id: &str) {
67        self.ring.retain(|(_, node)| node.id != node_id);
68    }
69
70    /// Get node responsible for a key
71    pub fn get_node(&self, key: &CacheKey) -> Option<&Node> {
72        if self.ring.is_empty() {
73            return None;
74        }
75
76        let hash = self.hash_key(key);
77
78        // Binary search for the first node with hash >= key hash
79        let idx = self.ring.partition_point(|(h, _)| *h < hash);
80
81        // Wrap around if needed
82        let node_idx = if idx < self.ring.len() { idx } else { 0 };
83
84        self.ring.get(node_idx).map(|(_, node)| node)
85    }
86
87    /// Get N nodes for replication
88    pub fn get_nodes(&self, key: &CacheKey, n: usize) -> Vec<&Node> {
89        if self.ring.is_empty() {
90            return Vec::new();
91        }
92
93        let hash = self.hash_key(key);
94        let start_idx = self.ring.partition_point(|(h, _)| *h < hash);
95
96        let mut nodes = Vec::new();
97        let mut seen = std::collections::HashSet::new();
98
99        for i in 0..self.ring.len() {
100            let idx = (start_idx + i) % self.ring.len();
101            let (_, node) = &self.ring[idx];
102
103            if !seen.contains(&node.id) {
104                nodes.push(node);
105                seen.insert(node.id.clone());
106
107                if nodes.len() >= n {
108                    break;
109                }
110            }
111        }
112
113        nodes
114    }
115
116    /// Hash a key
117    fn hash_key(&self, key: &str) -> u64 {
118        use std::collections::hash_map::DefaultHasher;
119
120        let mut hasher = DefaultHasher::new();
121        key.hash(&mut hasher);
122        hasher.finish()
123    }
124
125    /// Get all nodes
126    pub fn nodes(&self) -> Vec<Node> {
127        let mut seen = std::collections::HashSet::new();
128        let mut nodes = Vec::new();
129
130        for (_, node) in &self.ring {
131            if !seen.contains(&node.id) {
132                nodes.push(node.clone());
133                seen.insert(node.id.clone());
134            }
135        }
136
137        nodes
138    }
139
140    /// Get ring size
141    pub fn size(&self) -> usize {
142        self.ring.len()
143    }
144}
145
146/// Distributed cache coordinator
147pub struct DistributedCache {
148    /// Local cache
149    local: Arc<DashMap<CacheKey, CacheValue>>,
150    /// Hash ring for distribution
151    ring: Arc<RwLock<ConsistentHashRing>>,
152    /// Current node info
153    local_node: Node,
154    /// Replication factor
155    replication_factor: usize,
156    /// Hot key threshold (access count)
157    hot_key_threshold: u64,
158    /// Statistics
159    stats: Arc<RwLock<CacheStats>>,
160}
161
162impl DistributedCache {
163    /// Create new distributed cache
164    pub fn new(local_node: Node, replication_factor: usize) -> Self {
165        let mut ring = ConsistentHashRing::new(150); // 150 virtual nodes
166        ring.add_node(local_node.clone());
167
168        Self {
169            local: Arc::new(DashMap::new()),
170            ring: Arc::new(RwLock::new(ring)),
171            local_node,
172            replication_factor,
173            hot_key_threshold: 100,
174            stats: Arc::new(RwLock::new(CacheStats::new())),
175        }
176    }
177
178    /// Add peer node
179    pub async fn add_peer(&self, node: Node) {
180        let mut ring = self.ring.write().await;
181        ring.add_node(node);
182    }
183
184    /// Remove peer node
185    pub async fn remove_peer(&self, node_id: &str) {
186        let mut ring = self.ring.write().await;
187        ring.remove_node(node_id);
188    }
189
190    /// Get value from distributed cache
191    pub async fn get(&self, key: &CacheKey) -> Result<Option<CacheValue>> {
192        let ring = self.ring.read().await;
193
194        // Check if this node is responsible
195        if let Some(node) = ring.get_node(key) {
196            if node.id == self.local_node.id {
197                // Local lookup
198                if let Some(mut value) = self.local.get_mut(key) {
199                    value.record_access();
200
201                    let mut stats = self.stats.write().await;
202                    stats.hits += 1;
203
204                    return Ok(Some(value.clone()));
205                } else {
206                    let mut stats = self.stats.write().await;
207                    stats.misses += 1;
208                    return Ok(None);
209                }
210            } else {
211                // Remote lookup (would use network RPC in production)
212                // For now, return None
213                let mut stats = self.stats.write().await;
214                stats.misses += 1;
215                return Ok(None);
216            }
217        }
218
219        Ok(None)
220    }
221
222    /// Put value into distributed cache
223    pub async fn put(&self, key: CacheKey, value: CacheValue) -> Result<()> {
224        let ring = self.ring.read().await;
225
226        // Get nodes for replication
227        let nodes = ring.get_nodes(&key, self.replication_factor);
228
229        // Check if local node should store this key
230        let should_store_locally = nodes.iter().any(|n| n.id == self.local_node.id);
231
232        if should_store_locally {
233            self.local.insert(key.clone(), value.clone());
234
235            let mut stats = self.stats.write().await;
236            stats.bytes_stored += value.size as u64;
237            stats.item_count += 1;
238        }
239
240        // In production, would replicate to other nodes here
241
242        Ok(())
243    }
244
245    /// Remove value from distributed cache
246    pub async fn remove(&self, key: &CacheKey) -> Result<bool> {
247        let removed = self.local.remove(key);
248
249        if let Some((_, value)) = removed {
250            let mut stats = self.stats.write().await;
251            stats.bytes_stored = stats.bytes_stored.saturating_sub(value.size as u64);
252            stats.item_count = stats.item_count.saturating_sub(1);
253
254            Ok(true)
255        } else {
256            Ok(false)
257        }
258    }
259
260    /// Check if key is hot (frequently accessed)
261    pub fn is_hot_key(&self, key: &CacheKey) -> bool {
262        if let Some(value) = self.local.get(key) {
263            value.access_count >= self.hot_key_threshold
264        } else {
265            false
266        }
267    }
268
269    /// Get statistics
270    pub async fn stats(&self) -> CacheStats {
271        self.stats.read().await.clone()
272    }
273
274    /// Get all peer nodes
275    pub async fn peers(&self) -> Vec<Node> {
276        let ring = self.ring.read().await;
277        ring.nodes()
278    }
279
280    /// Rebalance cache after topology change
281    pub async fn rebalance(&self) -> Result<()> {
282        let ring = self.ring.read().await;
283        let mut keys_to_remove = Vec::new();
284
285        // Check all local keys
286        for entry in self.local.iter() {
287            let key = entry.key();
288            let nodes = ring.get_nodes(key, self.replication_factor);
289
290            // If local node is no longer responsible, mark for removal
291            if !nodes.iter().any(|n| n.id == self.local_node.id) {
292                keys_to_remove.push(key.clone());
293            }
294        }
295
296        drop(ring);
297
298        // Remove keys no longer owned
299        for key in keys_to_remove {
300            self.remove(&key).await?;
301        }
302
303        Ok(())
304    }
305
306    /// Clear local cache
307    pub async fn clear(&self) -> Result<()> {
308        self.local.clear();
309
310        let mut stats = self.stats.write().await;
311        *stats = CacheStats::new();
312
313        Ok(())
314    }
315}
316
317/// Distributed cache metadata
318#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
319pub struct CacheMetadata {
320    /// Version number
321    pub version: u64,
322    /// Owner node ID
323    pub owner: String,
324    /// Replica node IDs
325    pub replicas: Vec<String>,
326    /// Last modified timestamp
327    pub last_modified: chrono::DateTime<chrono::Utc>,
328}
329
330/// Cache operation for synchronization
331#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
332pub enum CacheOperation {
333    /// Put operation
334    Put {
335        /// Key
336        key: CacheKey,
337        /// Value
338        value: Vec<u8>,
339        /// Metadata
340        metadata: CacheMetadata,
341    },
342    /// Delete operation
343    Delete {
344        /// Key
345        key: CacheKey,
346        /// Version
347        version: u64,
348    },
349    /// Invalidate operation
350    Invalidate {
351        /// Key
352        key: CacheKey,
353    },
354}
355
356/// Distributed cache protocol
357#[async_trait]
358pub trait DistributedProtocol: Send + Sync {
359    /// Broadcast operation to peers
360    async fn broadcast(&self, operation: CacheOperation) -> Result<()>;
361
362    /// Handle incoming operation
363    async fn handle_operation(&self, operation: CacheOperation) -> Result<()>;
364
365    /// Sync with peer
366    async fn sync_with_peer(&self, peer_id: &str) -> Result<()>;
367}
368
369/// Peer discovery trait
370#[async_trait]
371pub trait PeerDiscovery: Send + Sync {
372    /// Discover peers
373    async fn discover(&self) -> Result<Vec<Node>>;
374
375    /// Register self
376    async fn register(&self, node: Node) -> Result<()>;
377
378    /// Unregister self
379    async fn unregister(&self, node_id: &str) -> Result<()>;
380
381    /// Health check
382    async fn health_check(&self, node_id: &str) -> Result<bool>;
383}
384
385/// Simple static peer list discovery
386pub struct StaticPeerDiscovery {
387    /// Static peer list
388    peers: Vec<Node>,
389}
390
391impl StaticPeerDiscovery {
392    /// Create new static peer discovery
393    pub fn new(peers: Vec<Node>) -> Self {
394        Self { peers }
395    }
396}
397
398#[async_trait]
399impl PeerDiscovery for StaticPeerDiscovery {
400    async fn discover(&self) -> Result<Vec<Node>> {
401        Ok(self.peers.clone())
402    }
403
404    async fn register(&self, _node: Node) -> Result<()> {
405        // Static list doesn't support registration
406        Ok(())
407    }
408
409    async fn unregister(&self, _node_id: &str) -> Result<()> {
410        // Static list doesn't support unregistration
411        Ok(())
412    }
413
414    async fn health_check(&self, _node_id: &str) -> Result<bool> {
415        // Assume all peers are healthy
416        Ok(true)
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Build a weight-1 test node.
    fn test_node(id: &str, address: &str) -> Node {
        Node {
            id: id.to_string(),
            address: address.to_string(),
            weight: 1,
        }
    }

    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(150);
        ring.add_node(test_node("node1", "127.0.0.1:8001"));
        ring.add_node(test_node("node2", "127.0.0.1:8002"));

        // Two physical nodes, 150 virtual nodes each.
        assert_eq!(ring.size(), 300);

        let key = "test_key".to_string();
        assert!(ring.get_node(&key).is_some());
    }

    #[test]
    fn test_replication_nodes() {
        let mut ring = ConsistentHashRing::new(150);
        for i in 0..5 {
            ring.add_node(test_node(
                &format!("node{}", i),
                &format!("127.0.0.1:800{}", i),
            ));
        }

        let key = "test_key".to_string();
        let replicas = ring.get_nodes(&key, 3);
        assert_eq!(replicas.len(), 3);

        // The replica set must not contain the same physical node twice.
        let distinct: std::collections::HashSet<_> =
            replicas.iter().map(|n| &n.id).collect();
        assert_eq!(distinct.len(), 3);
    }

    #[tokio::test]
    async fn test_distributed_cache() {
        let cache = DistributedCache::new(test_node("test_node", "127.0.0.1:8000"), 2);

        let key = "test_key".to_string();
        let value = CacheValue::new(
            Bytes::from("test data"),
            crate::compression::DataType::Binary,
        );

        cache
            .put(key.clone(), value.clone())
            .await
            .expect("put failed");

        // A single-node cluster always owns its keys, so the get must hit.
        assert!(cache.get(&key).await.expect("get failed").is_some());
    }

    #[tokio::test]
    async fn test_cache_rebalance() {
        let cache = DistributedCache::new(test_node("node1", "127.0.0.1:8001"), 2);

        // Populate the single-node cache.
        for i in 0..10 {
            let value = CacheValue::new(
                Bytes::from(format!("value{}", i)),
                crate::compression::DataType::Binary,
            );
            cache.put(format!("key{}", i), value).await.expect("put failed");
        }

        // Grow the ring, then drop any keys the local node no longer owns.
        cache.add_peer(test_node("node2", "127.0.0.1:8002")).await;
        cache.rebalance().await.expect("rebalance failed");

        let stats = cache.stats().await;
        assert!(stats.item_count <= 10);
    }
}