use crate::{CrdtEntry, CrdtKv, CrdtStore};
use bytes::Bytes;
use pollen_membership::Membership;
use pollen_transport::{Envelope, MessageType, Transport};
use pollen_types::{NodeId, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tracing::{debug, info, warn};
/// Tuning knobs for [`CrdtSyncService`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CrdtSyncConfig {
    /// Period between anti-entropy rounds; each round syncs with one randomly chosen peer.
    pub sync_interval: Duration,
    /// Intended spacing of delta broadcasts.
    ///
    /// NOTE(review): not read by `CrdtSyncService` in this file — deltas are pushed as soon
    /// as a store event arrives. Confirm whether anything else uses it.
    pub broadcast_interval: Duration,
    /// Maximum number of entries packed into one `CrdtFullSync` message when pushing a key
    /// range to a peer during anti-entropy.
    pub max_entries_per_msg: usize,
    /// Upper bound on each request/response round trip (Merkle and data-range requests).
    pub sync_timeout: Duration,
}

impl Default for CrdtSyncConfig {
    /// 10 s anti-entropy period, 100 ms broadcast interval, 100 entries per message and a
    /// 5 s round-trip timeout.
    fn default() -> Self {
        Self {
            sync_interval: Duration::from_secs(10),
            broadcast_interval: Duration::from_millis(100),
            max_entries_per_msg: 100,
            sync_timeout: Duration::from_secs(5),
        }
    }
}
/// Merkle-tree comparison request, sent as `MessageType::MerkleTreeRequest`.
///
/// Encoded with bincode, which writes fields in declaration order — reordering or retyping
/// fields changes the wire format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MerkleRequest {
/// Merkle-tree level the hashes belong to (`sync_with_peer` always sends `0`).
pub level: usize,
/// `(bucket name, hash)` pairs; `sync_with_peer` sends a single `("root", merkle_root)` pair.
pub hashes: Vec<(String, Bytes)>,
}
/// Reply to a [`MerkleRequest`], sent as `MessageType::MerkleTreeResponse`.
///
/// Built by `CrdtSyncService::handle_merkle_request`. Encoded with bincode in field
/// declaration order — do not reorder or retype fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MerkleResponse {
/// Echo of the requested level.
pub level: usize,
/// The responder's own `(bucket name, hash)` pairs at `level`.
pub hashes: Vec<(String, Bytes)>,
/// Bucket names whose hashes differ or that exist on only one side; the requester turns
/// each name into a key range (via `range_from_bucket`) to pull and push data.
pub differing_buckets: Vec<String>,
}
/// Request for all entries in a key range, sent as `MessageType::DataRangeRequest`.
///
/// The responder replies with `MessageType::DataRangeResponse` carrying
/// `entries_in_range(start, end)` as a bincode-encoded `Vec<CrdtEntry>`. Field order is part
/// of the bincode wire format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DataRangeRequest {
/// Lower bound of the key range (a bucket name when built by `range_from_bucket`).
pub start: String,
/// Upper bound of the key range (`start` followed by `'~'` when built by `range_from_bucket`).
/// NOTE(review): inclusive vs. exclusive is decided by `CrdtStore::entries_in_range` — confirm.
pub end: String,
}
/// Keeps this node's [`CrdtStore`] replica converged with its peers.
///
/// `start` spawns two background loops: periodic anti-entropy against one random peer, and
/// forwarding of local store change events to all alive peers.
pub struct CrdtSyncService {
    /// Identity of this node: the sender of every envelope, and skipped in peer lists.
    node_id: NodeId,
    /// Tuning parameters (intervals, timeouts, batch size).
    config: CrdtSyncConfig,
    /// Local replica being synchronised.
    crdt_store: Arc<CrdtStore>,
    /// Network used for one-way sends and request/response round trips.
    transport: Arc<dyn Transport>,
    /// Source of the current list of alive peers.
    membership: Arc<dyn Membership>,
    /// Set to `true` by `shutdown` to stop the background loops.
    shutdown: watch::Sender<bool>,
}
impl CrdtSyncService {
/// Builds an idle service; no background work happens until [`CrdtSyncService::start`].
///
/// The shutdown flag starts as `false`. The receiver half created with it is dropped right
/// away; the background loops subscribe to the sender themselves.
pub fn new(
    node_id: NodeId,
    crdt_store: Arc<CrdtStore>,
    transport: Arc<dyn Transport>,
    membership: Arc<dyn Membership>,
    config: CrdtSyncConfig,
) -> Self {
    let shutdown = watch::channel(false).0;
    Self {
        shutdown,
        config,
        membership,
        transport,
        crdt_store,
        node_id,
    }
}
/// Spawns the anti-entropy loop and the delta-broadcast loop, then returns immediately.
///
/// Each task owns its own `Arc` of the service. The `JoinHandle`s are dropped, so both tasks
/// are detached and run until their loops observe the shutdown flag.
///
/// # Panics
/// `tokio::spawn` panics when called outside a Tokio runtime.
pub fn start(self: Arc<Self>) {
    let anti_entropy_task = Arc::clone(&self);
    let broadcast_task = Arc::clone(&self);
    tokio::spawn(async move { anti_entropy_task.run_anti_entropy().await });
    tokio::spawn(async move { broadcast_task.run_delta_broadcast().await });
    info!("CRDT sync service started");
}
async fn run_anti_entropy(&self) {
let mut interval = tokio::time::interval(self.config.sync_interval);
let mut shutdown_rx = self.shutdown.subscribe();
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
_ = interval.tick() => {
if let Err(e) = self.anti_entropy_round().await {
warn!("Anti-entropy round failed: {}", e);
}
}
}
}
}
/// Picks one alive peer other than this node at random and syncs with it.
///
/// Returns `Ok(())` without doing anything when no other alive peer exists.
async fn anti_entropy_round(&self) -> Result<()> {
    let peers = self.membership.alive_members();
    // Filtering out our own id is the only "nobody to talk to" check needed. A separate
    // `peers.len() <= 1` guard would wrongly skip the sole peer whenever `alive_members`
    // does not include this node.
    let other_peers: Vec<_> = peers.iter().filter(|p| p.id != self.node_id).collect();
    if other_peers.is_empty() {
        return Ok(());
    }
    let peer = other_peers[rand::random::<usize>() % other_peers.len()];
    debug!("Starting anti-entropy sync with {:?}", peer.id);
    self.sync_with_peer(peer.id, peer.addr).await
}
/// Compares Merkle roots with `peer_id`; if the peer reports differing buckets, reconciles
/// them through `sync_differing_ranges`.
///
/// Transport errors, timeouts and unexpected reply types are treated as "nothing to do" and
/// yield `Ok(())`. Only (de)serialization errors and errors from `sync_differing_ranges`
/// are propagated.
///
/// NOTE(review): the request names its single hash `"root"`, while the responder compares it
/// against `merkle_level(0)`, and the returned bucket names are used as key-range prefixes.
/// Confirm that level 0 of the store's Merkle tree is keyed so that these line up.
async fn sync_with_peer(&self, peer_id: NodeId, peer_addr: std::net::SocketAddr) -> Result<()> {
    let request = MerkleRequest {
        level: 0,
        hashes: vec![("root".to_string(), self.crdt_store.merkle_root().clone())],
    };
    let envelope = Envelope::new(
        self.node_id,
        peer_id,
        MessageType::MerkleTreeRequest,
        Bytes::from(bincode::serialize(&request)?),
        pollen_clock::Timestamp::zero(),
    );
    let round_trip = self.transport.send_recv(peer_addr, envelope);
    let response = match tokio::time::timeout(self.config.sync_timeout, round_trip).await {
        Err(_elapsed) => {
            debug!("Merkle request to {:?} timed out", peer_id);
            return Ok(());
        }
        Ok(Err(e)) => {
            debug!("Failed to get Merkle response from {:?}: {}", peer_id, e);
            return Ok(());
        }
        Ok(Ok(reply)) => reply,
    };
    if response.msg_type != MessageType::MerkleTreeResponse {
        return Ok(());
    }
    let reply: MerkleResponse = bincode::deserialize(&response.payload)?;
    if reply.differing_buckets.is_empty() {
        debug!("In sync with {:?}", peer_id);
        Ok(())
    } else {
        self.sync_differing_ranges(peer_id, peer_addr, &reply.differing_buckets)
            .await
    }
}
async fn sync_differing_ranges(
&self,
peer_id: NodeId,
peer_addr: std::net::SocketAddr,
ranges: &[String],
) -> Result<()> {
for range in ranges {
let (start, end) = self.range_from_bucket(range);
let request = DataRangeRequest {
start: start.clone(),
end: end.clone(),
};
let payload = bincode::serialize(&request)?;
let envelope = Envelope::new(
self.node_id,
peer_id,
MessageType::DataRangeRequest,
Bytes::from(payload),
pollen_clock::Timestamp::zero(),
);
let response = match tokio::time::timeout(
self.config.sync_timeout,
self.transport.send_recv(peer_addr, envelope),
)
.await
{
Ok(Ok(resp)) => resp,
Ok(Err(e)) => {
warn!("Failed to get data range from {:?}: {}", peer_id, e);
continue;
}
Err(_) => {
warn!("Data range request to {:?} timed out", peer_id);
continue;
}
};
if response.msg_type != MessageType::DataRangeResponse {
continue;
}
let entries: Vec<CrdtEntry> = bincode::deserialize(&response.payload)?;
for entry in entries {
if let Err(e) = self.crdt_store.apply_delta(entry).await {
warn!("Failed to apply synced entry: {}", e);
}
}
let our_entries = self.crdt_store.entries_in_range(&start, &end);
if !our_entries.is_empty() {
for chunk in our_entries.chunks(self.config.max_entries_per_msg) {
let payload = bincode::serialize(&chunk.to_vec())?;
let envelope = Envelope::new(
self.node_id,
peer_id,
MessageType::CrdtFullSync,
Bytes::from(payload),
pollen_clock::Timestamp::zero(),
);
let _ = self.transport.send(peer_addr, envelope).await;
}
}
}
Ok(())
}
/// Maps a Merkle bucket name to the key range `(bucket, bucket + "~")` used for range
/// queries.
///
/// NOTE(review): assuming keys compare lexicographically, keys that continue after the
/// prefix with a character sorting above `'~'` fall outside this range. That includes
/// non-ASCII characters, and `"~"` followed by more text. Confirm that keys stay within
/// that alphabet.
fn range_from_bucket(&self, bucket: &str) -> (String, String) {
    let mut upper = String::with_capacity(bucket.len() + 1);
    upper.push_str(bucket);
    upper.push('~');
    (bucket.to_owned(), upper)
}
/// Forwards every local store change to all alive peers until shutdown is requested.
///
/// Subscribes to the store with an empty prefix (presumably every key — confirm against
/// `CrdtStore::subscribe`). It calls `broadcast_key` for each `Updated` / `Deleted` event.
/// Broadcast failures are logged and do not stop the loop.
///
/// NOTE(review): assumes `CrdtStore::subscribe` returns a `tokio::sync::broadcast::Receiver`.
async fn run_delta_broadcast(&self) {
    let mut rx = self.crdt_store.subscribe("");
    let mut shutdown_rx = self.shutdown.subscribe();
    // `subscribe` marks the current value as already seen, so a shutdown requested before
    // this point would never wake `changed()`; check it explicitly.
    if *shutdown_rx.borrow() {
        return;
    }
    loop {
        tokio::select! {
            changed = shutdown_rx.changed() => {
                // `Err` means the sender is gone; stop rather than spin on it.
                if changed.is_err() || *shutdown_rx.borrow() {
                    break;
                }
            }
            event = rx.recv() => {
                match event {
                    Ok(crate::CrdtEvent::Updated { key }) => {
                        if let Err(e) = self.broadcast_key(&key).await {
                            warn!("Failed to broadcast delta for {}: {}", key, e);
                        }
                    }
                    Ok(crate::CrdtEvent::Deleted { key }) => {
                        if let Err(e) = self.broadcast_key(&key).await {
                            warn!("Failed to broadcast deletion for {}: {}", key, e);
                        }
                    }
                    // Skipped events are repaired by periodic anti-entropy; keep going.
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        warn!("Delta broadcast lagged; skipped {} store events", skipped);
                    }
                    // The store dropped its sender: `recv` would now fail immediately forever,
                    // turning this loop into a busy spin.
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                }
            }
        }
    }
}
/// Sends the current entry for `key`, as a `CrdtDelta` message, to every alive peer other
/// than this node.
///
/// Each send runs in a detached task and its result is ignored (fire-and-forget). Returns
/// `Ok(())` without sending anything when the range lookup finds no entry.
///
/// NOTE(review): the entry is the first result of the range `key..key~`, not an exact-key
/// match. Suppose `key` itself is absent (e.g. a delete removed it) while a longer key with
/// the same prefix exists (`task:1` vs `task:10`). Then that other entry is broadcast
/// instead. Consider matching the entry's key exactly.
async fn broadcast_key(&self, key: &str) -> Result<()> {
    let entries = self.crdt_store.entries_in_range(key, &format!("{}~", key));
    let entry = match entries.first() {
        Some(entry) => entry,
        None => return Ok(()),
    };
    // Encode once and share the buffer: `Bytes` clones are reference-counted, so peers do
    // not each pay for a deep clone and a fresh serialization. A one-element slice encodes
    // exactly like the one-element `Vec` the receiver decodes.
    let payload = Bytes::from(bincode::serialize(std::slice::from_ref(entry))?);
    for peer in self.membership.alive_members() {
        if peer.id == self.node_id {
            continue;
        }
        let envelope = Envelope::new(
            self.node_id,
            peer.id,
            MessageType::CrdtDelta,
            payload.clone(),
            pollen_clock::Timestamp::zero(),
        );
        let transport = Arc::clone(&self.transport);
        let addr = peer.addr;
        tokio::spawn(async move {
            let _ = transport.send(addr, envelope).await;
        });
    }
    Ok(())
}
pub async fn handle_message(&self, envelope: Envelope) -> Result<Option<Envelope>> {
match envelope.msg_type {
MessageType::CrdtDelta | MessageType::CrdtFullSync => {
let entries: Vec<CrdtEntry> = bincode::deserialize(&envelope.payload)?;
for entry in entries {
if let Err(e) = self.crdt_store.apply_delta(entry).await {
warn!("Failed to apply delta: {}", e);
}
}
Ok(None)
}
MessageType::MerkleTreeRequest => {
let request: MerkleRequest = bincode::deserialize(&envelope.payload)?;
let response = self.handle_merkle_request(request);
let payload = bincode::serialize(&response)?;
Ok(Some(Envelope::new(
self.node_id,
envelope.from,
MessageType::MerkleTreeResponse,
Bytes::from(payload),
pollen_clock::Timestamp::zero(),
)))
}
MessageType::DataRangeRequest => {
let request: DataRangeRequest = bincode::deserialize(&envelope.payload)?;
let entries = self.crdt_store.entries_in_range(&request.start, &request.end);
let payload = bincode::serialize(&entries)?;
Ok(Some(Envelope::new(
self.node_id,
envelope.from,
MessageType::DataRangeResponse,
Bytes::from(payload),
pollen_clock::Timestamp::zero(),
)))
}
_ => Ok(None),
}
}
fn handle_merkle_request(&self, request: MerkleRequest) -> MerkleResponse {
let my_hashes = self.crdt_store.merkle_level(request.level);
let mut differing = Vec::new();
for (bucket, their_hash) in &request.hashes {
let my_hash = my_hashes
.iter()
.find(|(b, _)| b == bucket)
.map(|(_, h)| h.clone());
match my_hash {
Some(h) if h != *their_hash => {
differing.push(bucket.clone());
}
None => {
differing.push(bucket.clone());
}
_ => {}
}
}
for (bucket, _) in &my_hashes {
if !request.hashes.iter().any(|(b, _)| b == bucket) {
differing.push(bucket.clone());
}
}
MerkleResponse {
level: request.level,
hashes: my_hashes,
differing_buckets: differing,
}
}
/// Signals both background loops to stop.
///
/// Uses `send_replace` instead of `send` because `watch::Sender::send` fails, without
/// storing the value, when no receiver exists. The receiver created in `new` is dropped
/// immediately, and the loops only subscribe once they run, so a shutdown requested before
/// that point would be lost. `send_replace` always records `true`.
pub fn shutdown(&self) {
    self.shutdown.send_replace(true);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` with bincode and decodes it back, panicking on any failure.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let encoded = bincode::serialize(value).expect("bincode encode");
        bincode::deserialize(&encoded).expect("bincode decode")
    }

    #[test]
    fn test_config_default() {
        let CrdtSyncConfig {
            sync_interval,
            sync_timeout,
            ..
        } = CrdtSyncConfig::default();
        assert_eq!(sync_interval, Duration::from_secs(10));
        assert_eq!(sync_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_merkle_request_serialization() {
        let original = MerkleRequest {
            level: 1,
            hashes: vec![
                ("bucket1".to_string(), Bytes::from_static(b"hash1")),
                ("bucket2".to_string(), Bytes::from_static(b"hash2")),
            ],
        };
        let decoded = round_trip(&original);
        assert_eq!(decoded.level, 1);
        assert_eq!(decoded.hashes.len(), 2);
    }

    #[test]
    fn test_data_range_request_serialization() {
        let decoded = round_trip(&DataRangeRequest {
            start: "task:".into(),
            end: "task:~".into(),
        });
        assert_eq!(decoded.start, "task:");
        assert_eq!(decoded.end, "task:~");
    }
}