1use dashmap::DashMap;
10use std::collections::BTreeMap;
11use std::net::SocketAddr;
12use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::net::TcpStream;
15
16use super::{CacheEntry, TierStats};
17use crate::distribcache::QueryFingerprint;
18
/// Kinds of frames exchanged between cache peers.
///
/// The discriminant is written as the first byte of each frame built in this
/// file, so the numeric values must stay in sync with the `TryFrom<u8>`
/// decoder for this type.
///
/// `PartialEq`/`Eq` are derived so a decoded frame kind can be compared
/// directly instead of casting both sides to `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum MessageType {
    /// Lookup request; sent by `PeerConnection::get`.
    Get = 1,
    /// Presumably the reply to `Get` (the sender is not visible in this file).
    GetResponse = 2,
    /// Store request; sent by `PeerConnection::insert`.
    Put = 3,
    /// Presumably the acknowledgement of `Put` (the sender is not visible in this file).
    PutResponse = 4,
    /// Invalidation request; sent by `PeerConnection::invalidate`.
    Invalidate = 5,
    /// Health probe; sent by `PeerConnection::ping`.
    Ping = 6,
    /// Reply that `PeerConnection::ping` expects to a `Ping`.
    Pong = 7,
}
31
32impl TryFrom<u8> for MessageType {
33 type Error = ();
34
35 fn try_from(value: u8) -> Result<Self, Self::Error> {
36 match value {
37 1 => Ok(MessageType::Get),
38 2 => Ok(MessageType::GetResponse),
39 3 => Ok(MessageType::Put),
40 4 => Ok(MessageType::PutResponse),
41 5 => Ok(MessageType::Invalidate),
42 6 => Ok(MessageType::Ping),
43 7 => Ok(MessageType::Pong),
44 _ => Err(()),
45 }
46 }
47}
48
/// Identifier of a cache node; used as the value stored on the hash ring and
/// as the key of the peer table.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct PeerId(pub u64);

impl PeerId {
    /// Derives an identifier by hashing the peer's socket address.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is documented as unspecified
    /// across Rust releases; if nodes built with different toolchains must
    /// agree on identifiers, confirm this is acceptable.
    pub fn new(addr: &SocketAddr) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut state = DefaultHasher::new();
        addr.hash(&mut state);
        let digest = state.finish();
        PeerId(digest)
    }

    /// Identifier used for the local node; always wraps `0`.
    pub fn local() -> Self {
        PeerId(0)
    }
}
67
68struct HashRing {
70 ring: BTreeMap<u64, PeerId>,
72 virtual_nodes: usize,
74}
75
76impl HashRing {
77 fn new(virtual_nodes: usize) -> Self {
78 Self {
79 ring: BTreeMap::new(),
80 virtual_nodes,
81 }
82 }
83
84 fn add_peer(&mut self, peer: PeerId) {
85 for i in 0..self.virtual_nodes {
86 let hash = Self::hash_peer(peer, i);
87 self.ring.insert(hash, peer);
88 }
89 }
90
91 fn remove_peer(&mut self, peer: PeerId) {
92 self.ring.retain(|_, p| *p != peer);
93 }
94
95 fn get_nodes(&self, key: &[u8], count: u32) -> Vec<PeerId> {
96 if self.ring.is_empty() {
97 return Vec::new();
98 }
99
100 let key_hash = Self::hash_key(key);
101 let mut nodes = Vec::new();
102 let mut seen = std::collections::HashSet::new();
103
104 let iter = self.ring.range(key_hash..).chain(self.ring.range(..key_hash));
106
107 for (_, peer) in iter {
108 if !seen.contains(peer) {
109 seen.insert(*peer);
110 nodes.push(*peer);
111 if nodes.len() >= count as usize {
112 break;
113 }
114 }
115 }
116
117 nodes
118 }
119
120 fn hash_peer(peer: PeerId, vnode: usize) -> u64 {
121 use std::hash::{Hash, Hasher};
122 use std::collections::hash_map::DefaultHasher;
123
124 let mut hasher = DefaultHasher::new();
125 peer.0.hash(&mut hasher);
126 vnode.hash(&mut hasher);
127 hasher.finish()
128 }
129
130 fn hash_key(key: &[u8]) -> u64 {
131 use std::hash::{Hash, Hasher};
132 use std::collections::hash_map::DefaultHasher;
133
134 let mut hasher = DefaultHasher::new();
135 key.hash(&mut hasher);
136 hasher.finish()
137 }
138}
139
/// Bookkeeping for one remote cache peer.
///
/// The type holds no live socket: its request methods call
/// `TcpStream::connect` on every call, so a value is plain data and cloning
/// it is a field-by-field copy, which is why `Clone` is derived rather than
/// hand-written.
#[derive(Debug, Clone)]
pub struct PeerConnection {
    /// Address the peer is dialled at.
    pub addr: SocketAddr,
    /// Whether requests should currently be routed to this peer.
    pub healthy: bool,
    /// Initialised to `0`; not updated anywhere in this file.
    /// NOTE(review): unit/epoch is not established here — confirm with the writer.
    pub last_seen: u64,
    /// Initialised to `0`; presumably a round-trip time in microseconds
    /// (per the name), but nothing in this file writes it.
    pub rtt_us: u64,
    /// Timeout for peer requests, in milliseconds.
    timeout_ms: u64,
}
166
167impl PeerConnection {
168 fn new(addr: SocketAddr) -> Self {
169 Self {
170 addr,
171 healthy: true,
172 last_seen: 0,
173 rtt_us: 0,
174 timeout_ms: 5000, }
176 }
177
178 pub async fn get(&self, fingerprint: &QueryFingerprint) -> Result<CacheEntry, &'static str> {
180 let _start = std::time::Instant::now();
181
182 let stream = match tokio::time::timeout(
184 std::time::Duration::from_millis(self.timeout_ms),
185 TcpStream::connect(self.addr),
186 )
187 .await
188 {
189 Ok(Ok(s)) => s,
190 Ok(Err(_)) => return Err("Connection failed"),
191 Err(_) => return Err("Connection timeout"),
192 };
193
194 let fp_bytes = match bincode::serialize(fingerprint) {
196 Ok(b) => b,
197 Err(_) => return Err("Serialization failed"),
198 };
199
200 let (mut reader, mut writer) = stream.into_split();
202
203 let mut header = vec![MessageType::Get as u8];
205 header.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
206
207 if writer.write_all(&header).await.is_err() {
208 return Err("Failed to write header");
209 }
210 if writer.write_all(&fp_bytes).await.is_err() {
211 return Err("Failed to write data");
212 }
213
214 let mut resp_header = [0u8; 5];
216 if reader.read_exact(&mut resp_header).await.is_err() {
217 return Err("Failed to read response header");
218 }
219
220 let _msg_type = MessageType::try_from(resp_header[0]).map_err(|_| "Invalid message type")?;
221 let length = u32::from_le_bytes([resp_header[1], resp_header[2], resp_header[3], resp_header[4]]) as usize;
222
223 if length == 0 {
224 return Err("Entry not found");
225 }
226
227 let mut data = vec![0u8; length];
228 if reader.read_exact(&mut data).await.is_err() {
229 return Err("Failed to read response data");
230 }
231
232 let entry: CacheEntry = bincode::deserialize(&data).map_err(|_| "Deserialization failed")?;
234
235 Ok(entry)
236 }
237
238 pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) -> Result<(), &'static str> {
240 let stream = match tokio::time::timeout(
242 std::time::Duration::from_millis(self.timeout_ms),
243 TcpStream::connect(self.addr),
244 )
245 .await
246 {
247 Ok(Ok(s)) => s,
248 Ok(Err(_)) => return Err("Connection failed"),
249 Err(_) => return Err("Connection timeout"),
250 };
251
252 let fp_bytes = bincode::serialize(&fingerprint).map_err(|_| "FP serialization failed")?;
254 let entry_bytes = bincode::serialize(&entry).map_err(|_| "Entry serialization failed")?;
255
256 let mut message = Vec::with_capacity(1 + 4 + 4 + fp_bytes.len() + entry_bytes.len());
258 message.push(MessageType::Put as u8);
259 message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
260 message.extend_from_slice(&(entry_bytes.len() as u32).to_le_bytes());
261 message.extend_from_slice(&fp_bytes);
262 message.extend_from_slice(&entry_bytes);
263
264 let (mut reader, mut writer) = stream.into_split();
265
266 if writer.write_all(&message).await.is_err() {
267 return Err("Failed to write");
268 }
269
270 let mut resp_header = [0u8; 5];
272 if reader.read_exact(&mut resp_header).await.is_err() {
273 return Err("Failed to read ack");
274 }
275
276 Ok(())
277 }
278
279 pub async fn ping(&self) -> bool {
281 let _start = std::time::Instant::now();
282
283 let stream = match tokio::time::timeout(
284 std::time::Duration::from_millis(1000),
285 TcpStream::connect(self.addr),
286 )
287 .await
288 {
289 Ok(Ok(s)) => s,
290 _ => return false,
291 };
292
293 let (mut reader, mut writer) = stream.into_split();
294
295 let ping_msg = [MessageType::Ping as u8, 0, 0, 0, 0];
297 if writer.write_all(&ping_msg).await.is_err() {
298 return false;
299 }
300
301 let mut resp = [0u8; 5];
303 match tokio::time::timeout(
304 std::time::Duration::from_millis(1000),
305 reader.read_exact(&mut resp),
306 )
307 .await
308 {
309 Ok(Ok(_)) => resp[0] == MessageType::Pong as u8,
310 _ => false,
311 }
312 }
313
314 pub async fn invalidate(&self, fingerprint: &QueryFingerprint) -> Result<(), &'static str> {
316 let stream = match tokio::time::timeout(
317 std::time::Duration::from_millis(self.timeout_ms),
318 TcpStream::connect(self.addr),
319 )
320 .await
321 {
322 Ok(Ok(s)) => s,
323 Ok(Err(_)) => return Err("Connection failed"),
324 Err(_) => return Err("Connection timeout"),
325 };
326
327 let fp_bytes = bincode::serialize(fingerprint).map_err(|_| "Serialization failed")?;
328
329 let mut message = vec![MessageType::Invalidate as u8];
330 message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
331 message.extend_from_slice(&fp_bytes);
332
333 let (_, mut writer) = stream.into_split();
334 writer.write_all(&message).await.map_err(|_| "Write failed")?;
335
336 Ok(())
337 }
338}
339
340pub struct DistributedCache {
342 local_peer_id: PeerId,
344
345 hash_ring: std::sync::RwLock<HashRing>,
347
348 peers: DashMap<PeerId, PeerConnection>,
350
351 local: DashMap<u64, CacheEntry>,
353
354 replication_factor: u32,
356
357 hits: AtomicU64,
359 misses: AtomicU64,
360 remote_hits: AtomicU64,
361 replication_lag_ms: AtomicU64,
362 healthy_peers: AtomicU32,
363}
364
365impl DistributedCache {
366 pub fn new(replication_factor: u32, peer_addrs: Vec<SocketAddr>) -> Self {
368 let local_peer_id = PeerId::local();
369
370 let mut hash_ring = HashRing::new(100); hash_ring.add_peer(local_peer_id);
372
373 let peers = DashMap::new();
374 for addr in &peer_addrs {
375 let peer_id = PeerId::new(addr);
376 hash_ring.add_peer(peer_id);
377 peers.insert(peer_id, PeerConnection::new(*addr));
378 }
379
380 Self {
381 local_peer_id,
382 hash_ring: std::sync::RwLock::new(hash_ring),
383 peers,
384 local: DashMap::new(),
385 replication_factor,
386 hits: AtomicU64::new(0),
387 misses: AtomicU64::new(0),
388 remote_hits: AtomicU64::new(0),
389 replication_lag_ms: AtomicU64::new(0),
390 healthy_peers: AtomicU32::new(peer_addrs.len() as u32),
391 }
392 }
393
394 pub async fn get(&self, fingerprint: &QueryFingerprint) -> Option<CacheEntry> {
396 let key = self.fingerprint_to_hash(fingerprint);
397 let key_bytes = key.to_le_bytes();
398
399 let owners = {
401 let ring = self.hash_ring.read().ok()?;
402 ring.get_nodes(&key_bytes, self.replication_factor)
403 };
404
405 if owners.contains(&self.local_peer_id) {
407 if let Some(entry) = self.local.get(&key) {
408 if !entry.is_expired() {
409 self.hits.fetch_add(1, Ordering::Relaxed);
410 return Some(entry.clone());
411 } else {
412 drop(entry);
413 self.local.remove(&key);
414 }
415 }
416 }
417
418 for owner in owners {
420 if owner == self.local_peer_id {
421 continue;
422 }
423
424 if let Some(peer) = self.peers.get(&owner) {
425 if peer.healthy {
426 if let Ok(entry) = peer.get(fingerprint).await {
427 self.local.insert(key, entry.clone());
429 self.remote_hits.fetch_add(1, Ordering::Relaxed);
430 self.hits.fetch_add(1, Ordering::Relaxed);
431 return Some(entry);
432 }
433 }
434 }
435 }
436
437 self.misses.fetch_add(1, Ordering::Relaxed);
438 None
439 }
440
441 pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) {
443 let key = self.fingerprint_to_hash(&fingerprint);
444 let key_bytes = key.to_le_bytes();
445
446 let owners = {
448 let ring = self.hash_ring.read().unwrap();
449 ring.get_nodes(&key_bytes, self.replication_factor)
450 };
451
452 if owners.contains(&self.local_peer_id) {
454 self.local.insert(key, entry.clone());
455 }
456
457 for owner in owners {
459 if owner == self.local_peer_id {
460 continue;
461 }
462
463 if let Some(peer) = self.peers.get(&owner) {
464 if peer.healthy {
465 let fp = fingerprint.clone();
466 let e = entry.clone();
467 let _ = peer.insert(fp, e).await;
468 }
469 }
470 }
471 }
472
473 pub fn add_peer(&self, addr: SocketAddr) {
475 let peer_id = PeerId::new(&addr);
476
477 if let Ok(mut ring) = self.hash_ring.write() {
478 ring.add_peer(peer_id);
479 }
480
481 self.peers.insert(peer_id, PeerConnection::new(addr));
482 self.healthy_peers.fetch_add(1, Ordering::Relaxed);
483 }
484
485 pub fn remove_peer(&self, addr: &SocketAddr) {
487 let peer_id = PeerId::new(addr);
488
489 if let Ok(mut ring) = self.hash_ring.write() {
490 ring.remove_peer(peer_id);
491 }
492
493 if self.peers.remove(&peer_id).is_some() {
494 self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
495 }
496 }
497
498 pub fn mark_unhealthy(&self, addr: &SocketAddr) {
500 let peer_id = PeerId::new(addr);
501
502 if let Some(mut peer) = self.peers.get_mut(&peer_id) {
503 if peer.healthy {
504 peer.healthy = false;
505 self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
506 }
507 }
508 }
509
510 pub fn mark_healthy(&self, addr: &SocketAddr) {
512 let peer_id = PeerId::new(addr);
513
514 if let Some(mut peer) = self.peers.get_mut(&peer_id) {
515 if !peer.healthy {
516 peer.healthy = true;
517 self.healthy_peers.fetch_add(1, Ordering::Relaxed);
518 }
519 }
520 }
521
522 pub async fn invalidate(&self, fingerprint: &QueryFingerprint) {
524 let key = self.fingerprint_to_hash(fingerprint);
525
526 self.local.remove(&key);
528
529 for peer_ref in self.peers.iter() {
531 let peer = peer_ref.value();
532 if peer.healthy {
533 let fp = fingerprint.clone();
535 let peer_clone = peer.clone();
536 tokio::spawn(async move {
537 let _ = peer_clone.invalidate(&fp).await;
538 });
539 }
540 }
541 }
542
543 fn fingerprint_to_hash(&self, fingerprint: &QueryFingerprint) -> u64 {
545 use std::hash::{Hash, Hasher};
546 use std::collections::hash_map::DefaultHasher;
547
548 let mut hasher = DefaultHasher::new();
549 fingerprint.template.hash(&mut hasher);
550 if let Some(param) = fingerprint.param_hash {
551 param.hash(&mut hasher);
552 }
553 hasher.finish()
554 }
555
556 pub fn stats(&self) -> TierStats {
558 let local_size: usize = self.local.iter()
559 .map(|e| e.value().size())
560 .sum();
561
562 TierStats {
563 size_bytes: local_size as u64,
564 max_size_bytes: 0, entry_count: self.local.len() as u64,
566 hits: self.hits.load(Ordering::Relaxed),
567 misses: self.misses.load(Ordering::Relaxed),
568 evictions: 0,
569 compression_ratio: None,
570 peer_count: Some(self.peers.len() as u32 + 1), healthy_peers: Some(self.healthy_peers.load(Ordering::Relaxed) + 1),
572 }
573 }
574
575 pub fn peer_addrs(&self) -> Vec<SocketAddr> {
577 self.peers.iter()
578 .map(|p| p.value().addr)
579 .collect()
580 }
581
582 pub fn copy_valid_entries_to(&self, target: &DistributedCache) {
584 for entry in self.local.iter() {
585 if !entry.value().is_expired() {
586 target.local.insert(*entry.key(), entry.value().clone());
587 }
588 }
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// With three peers on the ring, every key resolves to exactly the two
    /// owners that were requested.
    #[test]
    fn test_hash_ring_distribution() {
        let mut ring = HashRing::new(10);
        for id in 1..=3u64 {
            ring.add_peer(PeerId(id));
        }

        let keys: [&[u8]; 3] = [b"test-key-1", b"test-key-2", b"test-key-3"];
        for key in keys.iter().copied() {
            let owners = ring.get_nodes(key, 2);
            assert_eq!(owners.len(), 2);
        }
    }

    /// With exactly two peers, asking for two owners returns both of them.
    #[test]
    fn test_hash_ring_replication() {
        let first = PeerId(1);
        let second = PeerId(2);

        let mut ring = HashRing::new(10);
        ring.add_peer(first);
        ring.add_peer(second);

        let owners = ring.get_nodes(b"replicated-key", 2);

        assert_eq!(owners.len(), 2);
        assert!(owners.contains(&first));
        assert!(owners.contains(&second));
    }

    /// A cache without remote peers serves what was inserted locally.
    #[tokio::test]
    async fn test_distributed_cache_local_insert_get() {
        let cache = DistributedCache::new(1, Vec::new());
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");

        let stored = CacheEntry::new(vec![1, 2, 3], vec!["users".to_string()], 1)
            .with_ttl(Duration::from_secs(300));
        cache.insert(fingerprint.clone(), stored).await;

        let fetched = cache.get(&fingerprint).await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().data, vec![1, 2, 3]);
    }

    /// Peer counts reported by `stats` include the local node.
    #[test]
    fn test_distributed_cache_peer_management() {
        let cache = DistributedCache::new(2, Vec::new());

        let first: SocketAddr = "127.0.0.1:9100".parse().unwrap();
        let second: SocketAddr = "127.0.0.1:9101".parse().unwrap();
        cache.add_peer(first);
        cache.add_peer(second);

        // Two remote peers plus the local node.
        assert_eq!(cache.stats().peer_count, Some(3));

        // One remote peer left healthy, plus the local node.
        cache.mark_unhealthy(&first);
        assert_eq!(cache.stats().healthy_peers, Some(2));

        // One remote peer left, plus the local node.
        cache.remove_peer(&first);
        assert_eq!(cache.stats().peer_count, Some(2));
    }

    /// One stored key looked up once and one unknown key looked up once
    /// yield one hit, one miss and one entry.
    #[tokio::test]
    async fn test_distributed_cache_stats() {
        let cache = DistributedCache::new(1, Vec::new());

        let known = QueryFingerprint::from_query("SELECT * FROM users");
        let unknown = QueryFingerprint::from_query("SELECT * FROM orders");

        let entry = CacheEntry::new(vec![1], vec![], 1).with_ttl(Duration::from_secs(300));
        cache.insert(known.clone(), entry).await;

        // Expected hit.
        cache.get(&known).await;
        // Expected miss.
        cache.get(&unknown).await;

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.entry_count, 1);
    }
}