Skip to main content

heliosdb_proxy/distribcache/tiers/
l3_distributed.rs

//! L3 Distributed Cache - Cache mesh with <10ms access time
//!
//! Features:
//! - Consistent hashing for key distribution
//! - Replication for availability
//! - TCP-based peer-to-peer communication
//! - Gossip protocol for peer discovery (planned)
8
use dashmap::DashMap;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

use super::{CacheEntry, TierStats};
use crate::distribcache::QueryFingerprint;
18
/// Cache protocol message types.
///
/// The discriminant of each variant is the one-byte tag that opens a frame on
/// the wire, so the numeric values are part of the protocol and must not be
/// renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum MessageType {
    /// Request for the entry stored under a fingerprint.
    Get = 1,
    /// Reply to a `Get` (presumably; the tag is defined here but the reply
    /// side is not visible in this file).
    GetResponse = 2,
    /// Request to store a fingerprint/entry pair.
    Put = 3,
    /// Reply to a `Put` (presumably; see the note on `GetResponse`).
    PutResponse = 4,
    /// Request to drop the entry stored under a fingerprint.
    Invalidate = 5,
    /// Health probe.
    Ping = 6,
    /// Expected answer to a `Ping`.
    Pong = 7,
}

impl TryFrom<u8> for MessageType {
    type Error = ();

    /// Decodes a wire tag, returning `Err(())` for any byte that is not a
    /// known message type.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            1 => Ok(MessageType::Get),
            2 => Ok(MessageType::GetResponse),
            3 => Ok(MessageType::Put),
            4 => Ok(MessageType::PutResponse),
            5 => Ok(MessageType::Invalidate),
            6 => Ok(MessageType::Ping),
            7 => Ok(MessageType::Pong),
            _ => Err(()),
        }
    }
}
48
/// Peer identifier
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct PeerId(pub u64);

impl PeerId {
    /// Derives a peer id by hashing the peer's socket address (IP and port).
    ///
    /// The same address always yields the same id within one build of the
    /// program.
    ///
    /// NOTE(review): the standard library does not promise that
    /// `DefaultHasher` produces the same values across Rust releases, so ids
    /// computed by binaries built with different toolchains may disagree.
    /// Confirm that all mesh members run the same build, or switch to a
    /// hasher with a specified algorithm.
    pub fn new(addr: &SocketAddr) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        addr.hash(&mut hasher);
        Self(hasher.finish())
    }

    /// Returns the id a node uses for itself, which is always `PeerId(0)`.
    ///
    /// NOTE(review): because this is not derived from the node's own address,
    /// the local id differs from the id other nodes compute for this node via
    /// [`PeerId::new`] — confirm that this is intended.
    pub fn local() -> Self {
        Self(0)
    }
}
67
68/// Consistent hash ring for key distribution
69struct HashRing {
70    /// Ring nodes (virtual nodes -> peer)
71    ring: BTreeMap<u64, PeerId>,
72    /// Number of virtual nodes per peer
73    virtual_nodes: usize,
74}
75
76impl HashRing {
77    fn new(virtual_nodes: usize) -> Self {
78        Self {
79            ring: BTreeMap::new(),
80            virtual_nodes,
81        }
82    }
83
84    fn add_peer(&mut self, peer: PeerId) {
85        for i in 0..self.virtual_nodes {
86            let hash = Self::hash_peer(peer, i);
87            self.ring.insert(hash, peer);
88        }
89    }
90
91    fn remove_peer(&mut self, peer: PeerId) {
92        self.ring.retain(|_, p| *p != peer);
93    }
94
95    fn get_nodes(&self, key: &[u8], count: u32) -> Vec<PeerId> {
96        if self.ring.is_empty() {
97            return Vec::new();
98        }
99
100        let key_hash = Self::hash_key(key);
101        let mut nodes = Vec::new();
102        let mut seen = std::collections::HashSet::new();
103
104        // Find first node >= key_hash
105        let iter = self.ring.range(key_hash..).chain(self.ring.range(..key_hash));
106
107        for (_, peer) in iter {
108            if !seen.contains(peer) {
109                seen.insert(*peer);
110                nodes.push(*peer);
111                if nodes.len() >= count as usize {
112                    break;
113                }
114            }
115        }
116
117        nodes
118    }
119
120    fn hash_peer(peer: PeerId, vnode: usize) -> u64 {
121        use std::hash::{Hash, Hasher};
122        use std::collections::hash_map::DefaultHasher;
123
124        let mut hasher = DefaultHasher::new();
125        peer.0.hash(&mut hasher);
126        vnode.hash(&mut hasher);
127        hasher.finish()
128    }
129
130    fn hash_key(key: &[u8]) -> u64 {
131        use std::hash::{Hash, Hasher};
132        use std::collections::hash_map::DefaultHasher;
133
134        let mut hasher = DefaultHasher::new();
135        key.hash(&mut hasher);
136        hasher.finish()
137    }
138}
139
/// Peer connection state.
///
/// This is a plain record of where a peer lives and how it was last observed;
/// it does not own a socket. Cloning copies every field.
#[derive(Debug, Clone)]
pub struct PeerConnection {
    /// Peer address
    pub addr: SocketAddr,
    /// Connection healthy
    pub healthy: bool,
    /// Last seen timestamp (unit not established by the code visible here)
    pub last_seen: u64,
    /// Round-trip time in microseconds
    pub rtt_us: u64,
    /// Connection timeout in milliseconds
    timeout_ms: u64,
}
166
167impl PeerConnection {
168    fn new(addr: SocketAddr) -> Self {
169        Self {
170            addr,
171            healthy: true,
172            last_seen: 0,
173            rtt_us: 0,
174            timeout_ms: 5000, // 5 second timeout
175        }
176    }
177
178    /// Get entry from peer via TCP
179    pub async fn get(&self, fingerprint: &QueryFingerprint) -> Result<CacheEntry, &'static str> {
180        let _start = std::time::Instant::now();
181
182        // Try to connect with timeout
183        let stream = match tokio::time::timeout(
184            std::time::Duration::from_millis(self.timeout_ms),
185            TcpStream::connect(self.addr),
186        )
187        .await
188        {
189            Ok(Ok(s)) => s,
190            Ok(Err(_)) => return Err("Connection failed"),
191            Err(_) => return Err("Connection timeout"),
192        };
193
194        // Build request message
195        let fp_bytes = match bincode::serialize(fingerprint) {
196            Ok(b) => b,
197            Err(_) => return Err("Serialization failed"),
198        };
199
200        // Send GET request
201        let (mut reader, mut writer) = stream.into_split();
202
203        // Message format: [type: u8][length: u32][data: bytes]
204        let mut header = vec![MessageType::Get as u8];
205        header.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
206
207        if writer.write_all(&header).await.is_err() {
208            return Err("Failed to write header");
209        }
210        if writer.write_all(&fp_bytes).await.is_err() {
211            return Err("Failed to write data");
212        }
213
214        // Read response
215        let mut resp_header = [0u8; 5];
216        if reader.read_exact(&mut resp_header).await.is_err() {
217            return Err("Failed to read response header");
218        }
219
220        let _msg_type = MessageType::try_from(resp_header[0]).map_err(|_| "Invalid message type")?;
221        let length = u32::from_le_bytes([resp_header[1], resp_header[2], resp_header[3], resp_header[4]]) as usize;
222
223        if length == 0 {
224            return Err("Entry not found");
225        }
226
227        let mut data = vec![0u8; length];
228        if reader.read_exact(&mut data).await.is_err() {
229            return Err("Failed to read response data");
230        }
231
232        // Deserialize entry
233        let entry: CacheEntry = bincode::deserialize(&data).map_err(|_| "Deserialization failed")?;
234
235        Ok(entry)
236    }
237
238    /// Insert entry to peer via TCP
239    pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) -> Result<(), &'static str> {
240        // Try to connect with timeout
241        let stream = match tokio::time::timeout(
242            std::time::Duration::from_millis(self.timeout_ms),
243            TcpStream::connect(self.addr),
244        )
245        .await
246        {
247            Ok(Ok(s)) => s,
248            Ok(Err(_)) => return Err("Connection failed"),
249            Err(_) => return Err("Connection timeout"),
250        };
251
252        // Serialize fingerprint and entry
253        let fp_bytes = bincode::serialize(&fingerprint).map_err(|_| "FP serialization failed")?;
254        let entry_bytes = bincode::serialize(&entry).map_err(|_| "Entry serialization failed")?;
255
256        // Build message: [type: u8][fp_len: u32][entry_len: u32][fp_data][entry_data]
257        let mut message = Vec::with_capacity(1 + 4 + 4 + fp_bytes.len() + entry_bytes.len());
258        message.push(MessageType::Put as u8);
259        message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
260        message.extend_from_slice(&(entry_bytes.len() as u32).to_le_bytes());
261        message.extend_from_slice(&fp_bytes);
262        message.extend_from_slice(&entry_bytes);
263
264        let (mut reader, mut writer) = stream.into_split();
265
266        if writer.write_all(&message).await.is_err() {
267            return Err("Failed to write");
268        }
269
270        // Read response (ack)
271        let mut resp_header = [0u8; 5];
272        if reader.read_exact(&mut resp_header).await.is_err() {
273            return Err("Failed to read ack");
274        }
275
276        Ok(())
277    }
278
279    /// Ping peer to check health
280    pub async fn ping(&self) -> bool {
281        let _start = std::time::Instant::now();
282
283        let stream = match tokio::time::timeout(
284            std::time::Duration::from_millis(1000),
285            TcpStream::connect(self.addr),
286        )
287        .await
288        {
289            Ok(Ok(s)) => s,
290            _ => return false,
291        };
292
293        let (mut reader, mut writer) = stream.into_split();
294
295        // Send ping
296        let ping_msg = [MessageType::Ping as u8, 0, 0, 0, 0];
297        if writer.write_all(&ping_msg).await.is_err() {
298            return false;
299        }
300
301        // Wait for pong
302        let mut resp = [0u8; 5];
303        match tokio::time::timeout(
304            std::time::Duration::from_millis(1000),
305            reader.read_exact(&mut resp),
306        )
307        .await
308        {
309            Ok(Ok(_)) => resp[0] == MessageType::Pong as u8,
310            _ => false,
311        }
312    }
313
314    /// Send invalidation message to peer
315    pub async fn invalidate(&self, fingerprint: &QueryFingerprint) -> Result<(), &'static str> {
316        let stream = match tokio::time::timeout(
317            std::time::Duration::from_millis(self.timeout_ms),
318            TcpStream::connect(self.addr),
319        )
320        .await
321        {
322            Ok(Ok(s)) => s,
323            Ok(Err(_)) => return Err("Connection failed"),
324            Err(_) => return Err("Connection timeout"),
325        };
326
327        let fp_bytes = bincode::serialize(fingerprint).map_err(|_| "Serialization failed")?;
328
329        let mut message = vec![MessageType::Invalidate as u8];
330        message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
331        message.extend_from_slice(&fp_bytes);
332
333        let (_, mut writer) = stream.into_split();
334        writer.write_all(&message).await.map_err(|_| "Write failed")?;
335
336        Ok(())
337    }
338}
339
340/// L3 Distributed Cache - Cache mesh with consistent hashing
341pub struct DistributedCache {
342    /// Local peer ID
343    local_peer_id: PeerId,
344
345    /// Consistent hash ring
346    hash_ring: std::sync::RwLock<HashRing>,
347
348    /// Peer connections
349    peers: DashMap<PeerId, PeerConnection>,
350
351    /// Local cache for owned keys
352    local: DashMap<u64, CacheEntry>,
353
354    /// Replication factor
355    replication_factor: u32,
356
357    /// Statistics
358    hits: AtomicU64,
359    misses: AtomicU64,
360    remote_hits: AtomicU64,
361    replication_lag_ms: AtomicU64,
362    healthy_peers: AtomicU32,
363}
364
365impl DistributedCache {
366    /// Create a new distributed cache
367    pub fn new(replication_factor: u32, peer_addrs: Vec<SocketAddr>) -> Self {
368        let local_peer_id = PeerId::local();
369
370        let mut hash_ring = HashRing::new(100); // 100 virtual nodes per peer
371        hash_ring.add_peer(local_peer_id);
372
373        let peers = DashMap::new();
374        for addr in &peer_addrs {
375            let peer_id = PeerId::new(addr);
376            hash_ring.add_peer(peer_id);
377            peers.insert(peer_id, PeerConnection::new(*addr));
378        }
379
380        Self {
381            local_peer_id,
382            hash_ring: std::sync::RwLock::new(hash_ring),
383            peers,
384            local: DashMap::new(),
385            replication_factor,
386            hits: AtomicU64::new(0),
387            misses: AtomicU64::new(0),
388            remote_hits: AtomicU64::new(0),
389            replication_lag_ms: AtomicU64::new(0),
390            healthy_peers: AtomicU32::new(peer_addrs.len() as u32),
391        }
392    }
393
394    /// Get an entry from the distributed cache
395    pub async fn get(&self, fingerprint: &QueryFingerprint) -> Option<CacheEntry> {
396        let key = self.fingerprint_to_hash(fingerprint);
397        let key_bytes = key.to_le_bytes();
398
399        // Determine owners
400        let owners = {
401            let ring = self.hash_ring.read().ok()?;
402            ring.get_nodes(&key_bytes, self.replication_factor)
403        };
404
405        // Check local first if we own it
406        if owners.contains(&self.local_peer_id) {
407            if let Some(entry) = self.local.get(&key) {
408                if !entry.is_expired() {
409                    self.hits.fetch_add(1, Ordering::Relaxed);
410                    return Some(entry.clone());
411                } else {
412                    drop(entry);
413                    self.local.remove(&key);
414                }
415            }
416        }
417
418        // Query remote peers
419        for owner in owners {
420            if owner == self.local_peer_id {
421                continue;
422            }
423
424            if let Some(peer) = self.peers.get(&owner) {
425                if peer.healthy {
426                    if let Ok(entry) = peer.get(fingerprint).await {
427                        // Cache locally
428                        self.local.insert(key, entry.clone());
429                        self.remote_hits.fetch_add(1, Ordering::Relaxed);
430                        self.hits.fetch_add(1, Ordering::Relaxed);
431                        return Some(entry);
432                    }
433                }
434            }
435        }
436
437        self.misses.fetch_add(1, Ordering::Relaxed);
438        None
439    }
440
441    /// Insert an entry into the distributed cache
442    pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) {
443        let key = self.fingerprint_to_hash(&fingerprint);
444        let key_bytes = key.to_le_bytes();
445
446        // Determine owners
447        let owners = {
448            let ring = self.hash_ring.read().unwrap();
449            ring.get_nodes(&key_bytes, self.replication_factor)
450        };
451
452        // Insert locally if we own it
453        if owners.contains(&self.local_peer_id) {
454            self.local.insert(key, entry.clone());
455        }
456
457        // Replicate to other owners (fire and forget for now)
458        for owner in owners {
459            if owner == self.local_peer_id {
460                continue;
461            }
462
463            if let Some(peer) = self.peers.get(&owner) {
464                if peer.healthy {
465                    let fp = fingerprint.clone();
466                    let e = entry.clone();
467                    let _ = peer.insert(fp, e).await;
468                }
469            }
470        }
471    }
472
473    /// Add a peer to the cache mesh
474    pub fn add_peer(&self, addr: SocketAddr) {
475        let peer_id = PeerId::new(&addr);
476
477        if let Ok(mut ring) = self.hash_ring.write() {
478            ring.add_peer(peer_id);
479        }
480
481        self.peers.insert(peer_id, PeerConnection::new(addr));
482        self.healthy_peers.fetch_add(1, Ordering::Relaxed);
483    }
484
485    /// Remove a peer from the cache mesh
486    pub fn remove_peer(&self, addr: &SocketAddr) {
487        let peer_id = PeerId::new(addr);
488
489        if let Ok(mut ring) = self.hash_ring.write() {
490            ring.remove_peer(peer_id);
491        }
492
493        if self.peers.remove(&peer_id).is_some() {
494            self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
495        }
496    }
497
498    /// Mark peer as unhealthy
499    pub fn mark_unhealthy(&self, addr: &SocketAddr) {
500        let peer_id = PeerId::new(addr);
501
502        if let Some(mut peer) = self.peers.get_mut(&peer_id) {
503            if peer.healthy {
504                peer.healthy = false;
505                self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
506            }
507        }
508    }
509
510    /// Mark peer as healthy
511    pub fn mark_healthy(&self, addr: &SocketAddr) {
512        let peer_id = PeerId::new(addr);
513
514        if let Some(mut peer) = self.peers.get_mut(&peer_id) {
515            if !peer.healthy {
516                peer.healthy = true;
517                self.healthy_peers.fetch_add(1, Ordering::Relaxed);
518            }
519        }
520    }
521
522    /// Invalidate an entry across the mesh
523    pub async fn invalidate(&self, fingerprint: &QueryFingerprint) {
524        let key = self.fingerprint_to_hash(fingerprint);
525
526        // Remove locally
527        self.local.remove(&key);
528
529        // Broadcast invalidation to all healthy peers
530        for peer_ref in self.peers.iter() {
531            let peer = peer_ref.value();
532            if peer.healthy {
533                // Fire and forget - don't wait for ack
534                let fp = fingerprint.clone();
535                let peer_clone = peer.clone();
536                tokio::spawn(async move {
537                    let _ = peer_clone.invalidate(&fp).await;
538                });
539            }
540        }
541    }
542
543    /// Convert fingerprint to hash key
544    fn fingerprint_to_hash(&self, fingerprint: &QueryFingerprint) -> u64 {
545        use std::hash::{Hash, Hasher};
546        use std::collections::hash_map::DefaultHasher;
547
548        let mut hasher = DefaultHasher::new();
549        fingerprint.template.hash(&mut hasher);
550        if let Some(param) = fingerprint.param_hash {
551            param.hash(&mut hasher);
552        }
553        hasher.finish()
554    }
555
556    /// Get cache statistics
557    pub fn stats(&self) -> TierStats {
558        let local_size: usize = self.local.iter()
559            .map(|e| e.value().size())
560            .sum();
561
562        TierStats {
563            size_bytes: local_size as u64,
564            max_size_bytes: 0, // Distributed, no single max
565            entry_count: self.local.len() as u64,
566            hits: self.hits.load(Ordering::Relaxed),
567            misses: self.misses.load(Ordering::Relaxed),
568            evictions: 0,
569            compression_ratio: None,
570            peer_count: Some(self.peers.len() as u32 + 1), // +1 for local
571            healthy_peers: Some(self.healthy_peers.load(Ordering::Relaxed) + 1),
572        }
573    }
574
575    /// Get peer addresses
576    pub fn peer_addrs(&self) -> Vec<SocketAddr> {
577        self.peers.iter()
578            .map(|p| p.value().addr)
579            .collect()
580    }
581
582    /// Copy valid entries to another cache (for branch merging)
583    pub fn copy_valid_entries_to(&self, target: &DistributedCache) {
584        for entry in self.local.iter() {
585            if !entry.value().is_expired() {
586                target.local.insert(*entry.key(), entry.value().clone());
587            }
588        }
589    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a ring with ten virtual nodes per peer.
    fn ring_with(ids: &[u64]) -> HashRing {
        let mut ring = HashRing::new(10);
        for id in ids {
            ring.add_peer(PeerId(*id));
        }
        ring
    }

    /// Every key gets exactly the requested number of distinct owners.
    #[test]
    fn test_hash_ring_distribution() {
        let ring = ring_with(&[1, 2, 3]);

        for key in [b"test-key-1", b"test-key-2", b"test-key-3"] {
            let nodes = ring.get_nodes(key, 2);

            // Each should return 2 nodes, and never the same peer twice
            assert_eq!(nodes.len(), 2);
            assert_ne!(nodes[0], nodes[1]);
        }
    }

    /// With two peers and two replicas, both peers own every key.
    #[test]
    fn test_hash_ring_replication() {
        let peer1 = PeerId(1);
        let peer2 = PeerId(2);
        let ring = ring_with(&[1, 2]);

        let nodes = ring.get_nodes(b"replicated-key", 2);

        // Should return both peers
        assert_eq!(nodes.len(), 2);
        assert!(nodes.contains(&peer1));
        assert!(nodes.contains(&peer2));
    }

    /// A single-node cache serves back what was inserted.
    #[tokio::test]
    async fn test_distributed_cache_local_insert_get() {
        let cache = DistributedCache::new(1, Vec::new());

        let fp = QueryFingerprint::from_query("SELECT * FROM users");
        let entry = CacheEntry::new(vec![1, 2, 3], vec!["users".to_string()], 1)
            .with_ttl(Duration::from_secs(300));

        cache.insert(fp.clone(), entry).await;

        let result = cache.get(&fp).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().data, vec![1, 2, 3]);
    }

    /// Peer and healthy-peer counts follow add / mark / remove calls.
    #[test]
    fn test_distributed_cache_peer_management() {
        let cache = DistributedCache::new(2, Vec::new());

        let addr1: SocketAddr = "127.0.0.1:9100".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:9101".parse().unwrap();

        cache.add_peer(addr1);
        cache.add_peer(addr2);

        assert_eq!(cache.stats().peer_count, Some(3)); // 2 remote + 1 local
        assert_eq!(cache.stats().healthy_peers, Some(3)); // 2 remote + 1 local

        cache.mark_unhealthy(&addr1);
        assert_eq!(cache.stats().healthy_peers, Some(2)); // 1 remote + 1 local

        // Marking twice must not count twice, and recovery restores the gauge.
        cache.mark_unhealthy(&addr1);
        assert_eq!(cache.stats().healthy_peers, Some(2));
        cache.mark_healthy(&addr1);
        assert_eq!(cache.stats().healthy_peers, Some(3));
        cache.mark_unhealthy(&addr1);

        cache.remove_peer(&addr1);
        assert_eq!(cache.stats().peer_count, Some(2)); // 1 remote + 1 local
        assert_eq!(cache.peer_addrs(), vec![addr2]);
    }

    /// Hits, misses and entry count are tracked for local lookups.
    #[tokio::test]
    async fn test_distributed_cache_stats() {
        let cache = DistributedCache::new(1, Vec::new());

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders");

        cache
            .insert(
                fp1.clone(),
                CacheEntry::new(vec![1], vec![], 1).with_ttl(Duration::from_secs(300)),
            )
            .await;

        cache.get(&fp1).await; // Hit
        cache.get(&fp2).await; // Miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entry_count, 1);
    }
}
697}