use dashmap::DashMap;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use super::{CacheEntry, TierStats};
use crate::distribcache::QueryFingerprint;
/// Tag byte that opens every frame of the peer-to-peer cache protocol.
///
/// The explicit discriminants are the on-wire values; changing them breaks compatibility
/// with peers.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
enum MessageType {
    /// Request for an entry, carrying a serialized fingerprint.
    Get = 1,
    /// Reply to `Get`; the client treats a zero payload length as "not found".
    GetResponse = 2,
    /// Request to store an entry on the peer.
    Put = 3,
    /// Acknowledgement of a `Put`.
    PutResponse = 4,
    /// Request carrying a fingerprint to invalidate; no reply is awaited by the client.
    Invalidate = 5,
    /// Liveness probe.
    Ping = 6,
    /// Reply to `Ping`.
    Pong = 7,
}

impl TryFrom<u8> for MessageType {
    type Error = ();

    /// Decodes a wire tag byte; any byte outside `1..=7` yields `Err(())`.
    fn try_from(tag: u8) -> Result<Self, Self::Error> {
        Ok(match tag {
            1 => Self::Get,
            2 => Self::GetResponse,
            3 => Self::Put,
            4 => Self::PutResponse,
            5 => Self::Invalidate,
            6 => Self::Ping,
            7 => Self::Pong,
            _ => return Err(()),
        })
    }
}
/// Identifier of a cache node on the hash ring.
///
/// Remote peers derive their id from their socket address; this process's own node always
/// uses the reserved id `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u64);

impl PeerId {
    /// Derives a peer id by hashing `addr` with the standard library's default hasher.
    pub fn new(addr: &SocketAddr) -> Self {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut state = DefaultHasher::default();
        addr.hash(&mut state);
        let digest = state.finish();
        PeerId(digest)
    }

    /// The id reserved for the local node.
    pub fn local() -> Self {
        PeerId(0)
    }
}
/// Consistent-hash ring mapping keys to peers.
///
/// Each peer is inserted at `virtual_nodes` positions; a key is owned by the peers found by
/// walking clockwise from the key's own position.
///
/// NOTE(review): positions are computed with `DefaultHasher`, whose algorithm the standard
/// library does not promise to keep stable across Rust releases. Nodes built with different
/// toolchains could therefore disagree on key placement.
struct HashRing {
    /// Ring position -> peer owning that position.
    ring: BTreeMap<u64, PeerId>,
    /// Number of positions inserted for each peer.
    virtual_nodes: usize,
}

impl HashRing {
    /// Creates an empty ring that will place each peer at `virtual_nodes` positions.
    fn new(virtual_nodes: usize) -> Self {
        Self {
            ring: BTreeMap::new(),
            virtual_nodes,
        }
    }

    /// Inserts all of `peer`'s virtual positions; re-adding a peer rewrites the same positions.
    fn add_peer(&mut self, peer: PeerId) {
        for vnode in 0..self.virtual_nodes {
            self.ring.insert(Self::hash_peer(peer, vnode), peer);
        }
    }

    /// Removes every position owned by `peer`.
    fn remove_peer(&mut self, peer: PeerId) {
        self.ring.retain(|_, owner| *owner != peer);
    }

    /// Returns up to `count` distinct peers responsible for `key`, in ring order.
    ///
    /// Fewer peers are returned when the ring holds fewer distinct peers. An empty vector is
    /// returned when `count` is 0 or the ring is empty.
    fn get_nodes(&self, key: &[u8], count: u32) -> Vec<PeerId> {
        // u32 -> usize is lossless on the 32/64-bit targets this code runs on.
        let wanted = count as usize;
        // Checking `count` up front matters: the loop below pushes before comparing, so
        // without this guard a request for 0 nodes would return 1.
        if wanted == 0 || self.ring.is_empty() {
            return Vec::new();
        }
        let start = Self::hash_key(key);
        // Walk clockwise from the key's position, wrapping past the end of the ring.
        let clockwise = self.ring.range(start..).chain(self.ring.range(..start));
        let mut seen = std::collections::HashSet::new();
        let mut nodes = Vec::new();
        for (_, &peer) in clockwise {
            // A peer occupies many positions; keep only its first occurrence.
            if seen.insert(peer) {
                nodes.push(peer);
                if nodes.len() == wanted {
                    break;
                }
            }
        }
        nodes
    }

    /// Ring position of virtual node `vnode` of `peer`.
    fn hash_peer(peer: PeerId, vnode: usize) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        peer.0.hash(&mut hasher);
        vnode.hash(&mut hasher);
        hasher.finish()
    }

    /// Ring position of a lookup key.
    fn hash_key(key: &[u8]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}
/// Connection metadata for one remote cache peer.
///
/// No socket is kept open here: each request method opens a fresh TCP connection to `addr`.
/// `Clone` is derived; it copies every field, exactly as the former hand-written impl did.
#[derive(Debug, Clone)]
pub struct PeerConnection {
    /// Address the peer listens on.
    pub addr: SocketAddr,
    /// Whether the cache currently routes requests to this peer.
    pub healthy: bool,
    /// Last-contact timestamp. NOTE(review): never updated in this file; units unknown.
    pub last_seen: u64,
    /// Round-trip time, presumably in microseconds. NOTE(review): never updated in this file.
    pub rtt_us: u64,
    /// Per-request timeout in milliseconds used by the request methods.
    timeout_ms: u64,
}
impl PeerConnection {
fn new(addr: SocketAddr) -> Self {
Self {
addr,
healthy: true,
last_seen: 0,
rtt_us: 0,
timeout_ms: 5000, }
}
pub async fn get(&self, fingerprint: &QueryFingerprint) -> Result<CacheEntry, &'static str> {
let _start = std::time::Instant::now();
let stream = match tokio::time::timeout(
std::time::Duration::from_millis(self.timeout_ms),
TcpStream::connect(self.addr),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(_)) => return Err("Connection failed"),
Err(_) => return Err("Connection timeout"),
};
let fp_bytes = match bincode::serialize(fingerprint) {
Ok(b) => b,
Err(_) => return Err("Serialization failed"),
};
let (mut reader, mut writer) = stream.into_split();
let mut header = vec![MessageType::Get as u8];
header.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
if writer.write_all(&header).await.is_err() {
return Err("Failed to write header");
}
if writer.write_all(&fp_bytes).await.is_err() {
return Err("Failed to write data");
}
let mut resp_header = [0u8; 5];
if reader.read_exact(&mut resp_header).await.is_err() {
return Err("Failed to read response header");
}
let _msg_type = MessageType::try_from(resp_header[0]).map_err(|_| "Invalid message type")?;
let length = u32::from_le_bytes([resp_header[1], resp_header[2], resp_header[3], resp_header[4]]) as usize;
if length == 0 {
return Err("Entry not found");
}
let mut data = vec![0u8; length];
if reader.read_exact(&mut data).await.is_err() {
return Err("Failed to read response data");
}
let entry: CacheEntry = bincode::deserialize(&data).map_err(|_| "Deserialization failed")?;
Ok(entry)
}
pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) -> Result<(), &'static str> {
let stream = match tokio::time::timeout(
std::time::Duration::from_millis(self.timeout_ms),
TcpStream::connect(self.addr),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(_)) => return Err("Connection failed"),
Err(_) => return Err("Connection timeout"),
};
let fp_bytes = bincode::serialize(&fingerprint).map_err(|_| "FP serialization failed")?;
let entry_bytes = bincode::serialize(&entry).map_err(|_| "Entry serialization failed")?;
let mut message = Vec::with_capacity(1 + 4 + 4 + fp_bytes.len() + entry_bytes.len());
message.push(MessageType::Put as u8);
message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
message.extend_from_slice(&(entry_bytes.len() as u32).to_le_bytes());
message.extend_from_slice(&fp_bytes);
message.extend_from_slice(&entry_bytes);
let (mut reader, mut writer) = stream.into_split();
if writer.write_all(&message).await.is_err() {
return Err("Failed to write");
}
let mut resp_header = [0u8; 5];
if reader.read_exact(&mut resp_header).await.is_err() {
return Err("Failed to read ack");
}
Ok(())
}
pub async fn ping(&self) -> bool {
let _start = std::time::Instant::now();
let stream = match tokio::time::timeout(
std::time::Duration::from_millis(1000),
TcpStream::connect(self.addr),
)
.await
{
Ok(Ok(s)) => s,
_ => return false,
};
let (mut reader, mut writer) = stream.into_split();
let ping_msg = [MessageType::Ping as u8, 0, 0, 0, 0];
if writer.write_all(&ping_msg).await.is_err() {
return false;
}
let mut resp = [0u8; 5];
match tokio::time::timeout(
std::time::Duration::from_millis(1000),
reader.read_exact(&mut resp),
)
.await
{
Ok(Ok(_)) => resp[0] == MessageType::Pong as u8,
_ => false,
}
}
pub async fn invalidate(&self, fingerprint: &QueryFingerprint) -> Result<(), &'static str> {
let stream = match tokio::time::timeout(
std::time::Duration::from_millis(self.timeout_ms),
TcpStream::connect(self.addr),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(_)) => return Err("Connection failed"),
Err(_) => return Err("Connection timeout"),
};
let fp_bytes = bincode::serialize(fingerprint).map_err(|_| "Serialization failed")?;
let mut message = vec![MessageType::Invalidate as u8];
message.extend_from_slice(&(fp_bytes.len() as u32).to_le_bytes());
message.extend_from_slice(&fp_bytes);
let (_, mut writer) = stream.into_split();
writer.write_all(&message).await.map_err(|_| "Write failed")?;
Ok(())
}
}
/// A cache node that stores a slice of the key space and replicates entries to peers.
///
/// Keys are placed on a consistent-hash ring containing this node and its peers. Each key is
/// written to up to `replication_factor` ring owners.
pub struct DistributedCache {
    // Topology.
    /// Ring identity of this node.
    local_peer_id: PeerId,
    /// Consistent-hash ring over this node and all known peers.
    hash_ring: std::sync::RwLock<HashRing>,
    /// Known remote peers, keyed by ring id.
    peers: DashMap<PeerId, PeerConnection>,
    /// Number of ring owners each key is written to.
    replication_factor: u32,

    // Storage.
    /// Entries held by this node, keyed by fingerprint hash.
    local: DashMap<u64, CacheEntry>,

    // Counters.
    /// Lookups answered (locally or remotely).
    hits: AtomicU64,
    /// Lookups that found nothing.
    misses: AtomicU64,
    /// Subset of `hits` that were answered by a peer.
    remote_hits: AtomicU64,
    /// NOTE(review): initialised but never updated in this file.
    replication_lag_ms: AtomicU64,
    /// Remote peers currently marked healthy (this node is not included).
    healthy_peers: AtomicU32,
}
impl DistributedCache {
pub fn new(replication_factor: u32, peer_addrs: Vec<SocketAddr>) -> Self {
let local_peer_id = PeerId::local();
let mut hash_ring = HashRing::new(100); hash_ring.add_peer(local_peer_id);
let peers = DashMap::new();
for addr in &peer_addrs {
let peer_id = PeerId::new(addr);
hash_ring.add_peer(peer_id);
peers.insert(peer_id, PeerConnection::new(*addr));
}
Self {
local_peer_id,
hash_ring: std::sync::RwLock::new(hash_ring),
peers,
local: DashMap::new(),
replication_factor,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
remote_hits: AtomicU64::new(0),
replication_lag_ms: AtomicU64::new(0),
healthy_peers: AtomicU32::new(peer_addrs.len() as u32),
}
}
pub async fn get(&self, fingerprint: &QueryFingerprint) -> Option<CacheEntry> {
let key = self.fingerprint_to_hash(fingerprint);
let key_bytes = key.to_le_bytes();
let owners = {
let ring = self.hash_ring.read().ok()?;
ring.get_nodes(&key_bytes, self.replication_factor)
};
if owners.contains(&self.local_peer_id) {
if let Some(entry) = self.local.get(&key) {
if !entry.is_expired() {
self.hits.fetch_add(1, Ordering::Relaxed);
return Some(entry.clone());
} else {
drop(entry);
self.local.remove(&key);
}
}
}
for owner in owners {
if owner == self.local_peer_id {
continue;
}
if let Some(peer) = self.peers.get(&owner) {
if peer.healthy {
if let Ok(entry) = peer.get(fingerprint).await {
self.local.insert(key, entry.clone());
self.remote_hits.fetch_add(1, Ordering::Relaxed);
self.hits.fetch_add(1, Ordering::Relaxed);
return Some(entry);
}
}
}
}
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
pub async fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) {
let key = self.fingerprint_to_hash(&fingerprint);
let key_bytes = key.to_le_bytes();
let owners = {
let ring = self.hash_ring.read().unwrap();
ring.get_nodes(&key_bytes, self.replication_factor)
};
if owners.contains(&self.local_peer_id) {
self.local.insert(key, entry.clone());
}
for owner in owners {
if owner == self.local_peer_id {
continue;
}
if let Some(peer) = self.peers.get(&owner) {
if peer.healthy {
let fp = fingerprint.clone();
let e = entry.clone();
let _ = peer.insert(fp, e).await;
}
}
}
}
pub fn add_peer(&self, addr: SocketAddr) {
let peer_id = PeerId::new(&addr);
if let Ok(mut ring) = self.hash_ring.write() {
ring.add_peer(peer_id);
}
self.peers.insert(peer_id, PeerConnection::new(addr));
self.healthy_peers.fetch_add(1, Ordering::Relaxed);
}
pub fn remove_peer(&self, addr: &SocketAddr) {
let peer_id = PeerId::new(addr);
if let Ok(mut ring) = self.hash_ring.write() {
ring.remove_peer(peer_id);
}
if self.peers.remove(&peer_id).is_some() {
self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
}
}
pub fn mark_unhealthy(&self, addr: &SocketAddr) {
let peer_id = PeerId::new(addr);
if let Some(mut peer) = self.peers.get_mut(&peer_id) {
if peer.healthy {
peer.healthy = false;
self.healthy_peers.fetch_sub(1, Ordering::Relaxed);
}
}
}
pub fn mark_healthy(&self, addr: &SocketAddr) {
let peer_id = PeerId::new(addr);
if let Some(mut peer) = self.peers.get_mut(&peer_id) {
if !peer.healthy {
peer.healthy = true;
self.healthy_peers.fetch_add(1, Ordering::Relaxed);
}
}
}
pub async fn invalidate(&self, fingerprint: &QueryFingerprint) {
let key = self.fingerprint_to_hash(fingerprint);
self.local.remove(&key);
for peer_ref in self.peers.iter() {
let peer = peer_ref.value();
if peer.healthy {
let fp = fingerprint.clone();
let peer_clone = peer.clone();
tokio::spawn(async move {
let _ = peer_clone.invalidate(&fp).await;
});
}
}
}
fn fingerprint_to_hash(&self, fingerprint: &QueryFingerprint) -> u64 {
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
fingerprint.template.hash(&mut hasher);
if let Some(param) = fingerprint.param_hash {
param.hash(&mut hasher);
}
hasher.finish()
}
pub fn stats(&self) -> TierStats {
let local_size: usize = self.local.iter()
.map(|e| e.value().size())
.sum();
TierStats {
size_bytes: local_size as u64,
max_size_bytes: 0, entry_count: self.local.len() as u64,
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
evictions: 0,
compression_ratio: None,
peer_count: Some(self.peers.len() as u32 + 1), healthy_peers: Some(self.healthy_peers.load(Ordering::Relaxed) + 1),
}
}
pub fn peer_addrs(&self) -> Vec<SocketAddr> {
self.peers.iter()
.map(|p| p.value().addr)
.collect()
}
pub fn copy_valid_entries_to(&self, target: &DistributedCache) {
for entry in self.local.iter() {
if !entry.value().is_expired() {
target.local.insert(*entry.key(), entry.value().clone());
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a ring with `vnodes` positions per peer for each id in `ids`.
    fn ring_of(vnodes: usize, ids: &[u64]) -> HashRing {
        let mut ring = HashRing::new(vnodes);
        ids.iter().for_each(|&id| ring.add_peer(PeerId(id)));
        ring
    }

    /// A loopback socket address on `port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    #[test]
    fn test_hash_ring_distribution() {
        let ring = ring_of(10, &[1, 2, 3]);
        let keys: [&[u8]; 3] = [b"test-key-1", b"test-key-2", b"test-key-3"];
        let picks: Vec<Vec<PeerId>> = keys.iter().map(|&key| ring.get_nodes(key, 2)).collect();
        for nodes in &picks {
            assert_eq!(nodes.len(), 2);
        }
    }

    #[test]
    fn test_hash_ring_replication() {
        let ring = ring_of(10, &[1, 2]);
        let owners = ring.get_nodes(b"replicated-key", 2);
        assert_eq!(owners.len(), 2);
        for id in [1u64, 2].iter() {
            assert!(owners.contains(&PeerId(*id)));
        }
    }

    #[tokio::test]
    async fn test_distributed_cache_local_insert_get() {
        let cache = DistributedCache::new(1, Vec::new());
        let fingerprint = QueryFingerprint::from_query("SELECT * FROM users");
        let entry = CacheEntry::new(vec![1, 2, 3], vec!["users".to_string()], 1)
            .with_ttl(Duration::from_secs(300));
        cache.insert(fingerprint.clone(), entry).await;
        let fetched = cache
            .get(&fingerprint)
            .await
            .expect("entry should be cached locally");
        assert_eq!(fetched.data, vec![1, 2, 3]);
    }

    #[test]
    fn test_distributed_cache_peer_management() {
        let cache = DistributedCache::new(2, Vec::new());
        let (first, second) = (loopback(9100), loopback(9101));
        cache.add_peer(first);
        cache.add_peer(second);
        assert_eq!(cache.stats().peer_count, Some(3));
        cache.mark_unhealthy(&first);
        assert_eq!(cache.stats().healthy_peers, Some(2));
        cache.remove_peer(&first);
        assert_eq!(cache.stats().peer_count, Some(2));
    }

    #[tokio::test]
    async fn test_distributed_cache_stats() {
        let cache = DistributedCache::new(1, Vec::new());
        let cached = QueryFingerprint::from_query("SELECT * FROM users");
        let uncached = QueryFingerprint::from_query("SELECT * FROM orders");
        cache
            .insert(
                cached.clone(),
                CacheEntry::new(vec![1], vec![], 1).with_ttl(Duration::from_secs(300)),
            )
            .await;
        let _ = cache.get(&cached).await;
        let _ = cache.get(&uncached).await;
        let stats = cache.stats();
        assert_eq!((stats.hits, stats.misses, stats.entry_count), (1, 1, 1));
    }
}