use super::automerge_sync::{SyncBatch, SyncDirection};
use iroh::EndpointId;
use lru::LruCache;
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
/// Capacity of the LRU cache of already-forwarded batch IDs used by
/// `SyncForwarder::new`.
const DEFAULT_DEDUP_CACHE_SIZE: usize = 1000;
/// Decides which connected peers a received `SyncBatch` should be relayed to.
///
/// The decision uses this node's position in a parent/child hierarchy
/// (`parent_id`, `children`), the batch's TTL, and a bounded cache of batch
/// IDs that were already forwarded. All mutable state sits behind `RwLock`s,
/// so topology and dedup state can be updated through `&self`.
pub struct SyncForwarder {
/// This node's own ID; never selected as a lateral or broadcast target.
local_node_id: EndpointId,
/// The parent node, or `None` if no parent has been set.
parent_id: RwLock<Option<EndpointId>>,
/// Direct children of this node.
children: RwLock<HashSet<EndpointId>>,
/// Bounded LRU set of forwarded batch IDs (the `()` value carries no data).
// NOTE(review): the `Arc` is not cloned anywhere visible here; confirm no other
// code in this module shares the cache before simplifying it away.
forwarded_batches: Arc<RwLock<LruCache<u64, ()>>>,
}
impl SyncForwarder {
    /// Creates a forwarder for `local_node_id` with no parent, no children,
    /// and a dedup cache holding up to `DEFAULT_DEDUP_CACHE_SIZE` batch IDs.
    pub fn new(local_node_id: EndpointId) -> Self {
        let capacity = NonZeroUsize::new(DEFAULT_DEDUP_CACHE_SIZE)
            .expect("DEFAULT_DEDUP_CACHE_SIZE must be non-zero");
        Self::with_dedup_cache_size(local_node_id, capacity)
    }

    /// Like [`SyncForwarder::new`], but with a caller-chosen capacity for the
    /// LRU cache of forwarded batch IDs.
    ///
    /// Once the cache is full, marking a new batch evicts the least recently
    /// used ID (per `LruCache::put`), after which [`SyncForwarder::was_forwarded`]
    /// no longer reports the evicted batch.
    pub fn with_dedup_cache_size(local_node_id: EndpointId, capacity: NonZeroUsize) -> Self {
        Self {
            local_node_id,
            parent_id: RwLock::new(None),
            children: RwLock::new(HashSet::new()),
            forwarded_batches: Arc::new(RwLock::new(LruCache::new(capacity))),
        }
    }

    /// Sets (or clears, with `None`) this node's parent.
    ///
    /// Like every accessor here, a poisoned lock is recovered via
    /// `into_inner` rather than propagated as a panic.
    pub fn set_parent(&self, parent_id: Option<EndpointId>) {
        *self.parent_id.write().unwrap_or_else(|e| e.into_inner()) = parent_id;
    }

    /// Registers `child_id` as a direct child. Adding an existing child is a no-op.
    pub fn add_child(&self, child_id: EndpointId) {
        self.children
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(child_id);
    }

    /// Removes `child_id` from the children set, if present.
    pub fn remove_child(&self, child_id: &EndpointId) {
        self.children
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .remove(child_id);
    }

    /// Returns the current parent, if any.
    pub fn parent_id(&self) -> Option<EndpointId> {
        *self.parent_id.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Returns a snapshot of the current children, in unspecified order.
    pub fn children(&self) -> Vec<EndpointId> {
        self.children
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .iter()
            .copied()
            .collect()
    }

    /// Returns `true` if `batch_id` is currently in the dedup cache.
    ///
    /// Uses `LruCache::contains`, which does not refresh the entry's recency,
    /// so a read lock suffices.
    pub fn was_forwarded(&self, batch_id: u64) -> bool {
        self.forwarded_batches
            .read()
            .unwrap_or_else(|e| e.into_inner())
            .contains(&batch_id)
    }

    /// Records `batch_id` as forwarded, refreshing it if already present and
    /// possibly evicting the least recently used ID when the cache is full.
    pub fn mark_forwarded(&self, batch_id: u64) {
        self.forwarded_batches
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .put(batch_id, ());
    }

    /// Computes which connected peers `batch` should be relayed to.
    ///
    /// Returns `None` when the batch must not be relayed: its ID is already in
    /// the dedup cache, or its `ttl` is zero. Otherwise returns the (possibly
    /// empty) de-duplicated target list, never containing `source_peer`, chosen
    /// by the most permissive direction among the batch's entries:
    ///
    /// - `Upward`: the parent, if it is connected.
    /// - `Downward`: connected children.
    /// - `Lateral`: connected peers that are neither the parent, a child, nor this node.
    /// - `Broadcast`: every connected peer other than this node.
    ///
    /// The order of the returned peers is unspecified. This method only reads
    /// the dedup cache; pair it with [`SyncForwarder::mark_forwarded`] to
    /// suppress later duplicates.
    pub fn forward_targets(
        &self,
        batch: &SyncBatch,
        source_peer: EndpointId,
        connected_peers: &[EndpointId],
    ) -> Option<Vec<EndpointId>> {
        if self.was_forwarded(batch.batch_id) {
            tracing::trace!(
                batch_id = batch.batch_id,
                "Batch already forwarded, skipping"
            );
            return None;
        }
        if batch.ttl == 0 {
            tracing::trace!(
                batch_id = batch.batch_id,
                "Batch TTL expired, not forwarding"
            );
            return None;
        }

        let direction = self.determine_batch_direction(batch);
        // A set, so duplicate entries in `connected_peers` yield one target.
        let mut targets: HashSet<EndpointId> = HashSet::new();
        match direction {
            SyncDirection::Upward => {
                targets.extend(self.parent_id().filter(|parent| {
                    *parent != source_peer && connected_peers.contains(parent)
                }));
            }
            SyncDirection::Downward => {
                // Probe the children set under one read guard instead of
                // scanning `connected_peers` once per child.
                let children = self.children.read().unwrap_or_else(|e| e.into_inner());
                targets.extend(
                    connected_peers
                        .iter()
                        .copied()
                        .filter(|peer| *peer != source_peer && children.contains(peer)),
                );
            }
            SyncDirection::Lateral => {
                let parent = self.parent_id();
                // Hash lookups rather than a per-peer linear scan of a cloned Vec.
                let children = self.children.read().unwrap_or_else(|e| e.into_inner());
                targets.extend(connected_peers.iter().copied().filter(|peer| {
                    *peer != source_peer
                        && *peer != self.local_node_id
                        && Some(*peer) != parent
                        && !children.contains(peer)
                }));
            }
            SyncDirection::Broadcast => {
                targets.extend(
                    connected_peers
                        .iter()
                        .copied()
                        .filter(|peer| *peer != source_peer && *peer != self.local_node_id),
                );
            }
        }
        tracing::debug!(
            batch_id = batch.batch_id,
            direction = ?direction,
            ttl = batch.ttl,
            source = %hex::encode(source_peer.as_bytes()),
            target_count = targets.len(),
            "Determined forward targets"
        );
        Some(targets.into_iter().collect())
    }

    /// Returns the most permissive direction among the batch's entries, where
    /// `Upward < Downward < Lateral < Broadcast`. An empty batch is `Upward`.
    fn determine_batch_direction(&self, batch: &SyncBatch) -> SyncDirection {
        // Permissiveness rank: a higher value reaches a wider set of peers.
        fn rank(direction: &SyncDirection) -> u8 {
            match direction {
                SyncDirection::Upward => 0,
                SyncDirection::Downward => 1,
                SyncDirection::Lateral => 2,
                SyncDirection::Broadcast => 3,
            }
        }

        let mut most_permissive = SyncDirection::Upward;
        for entry in &batch.entries {
            let dir = SyncDirection::from_doc_key(&entry.doc_key);
            if rank(&dir) > rank(&most_permissive) {
                most_permissive = dir;
            }
            // Nothing outranks Broadcast; skip the remaining entries.
            if matches!(most_permissive, SyncDirection::Broadcast) {
                break;
            }
        }
        most_permissive
    }

    /// Returns a copy of `batch` with `ttl` decremented by one, ready to relay,
    /// or `None` if the TTL is already zero. `batch` itself is left unchanged.
    pub fn prepare_for_forward(&self, batch: &SyncBatch) -> Option<SyncBatch> {
        if batch.ttl == 0 {
            return None;
        }
        let mut forwarded = batch.clone();
        forwarded.ttl = batch.ttl.saturating_sub(1);
        Some(forwarded)
    }
}
/// Counters describing batch-forwarding activity.
///
/// NOTE(review): nothing in this file updates these fields; presumably the
/// owner of the `SyncForwarder` increments them. Confirm against callers.
#[derive(Debug, Clone, Default)]
pub struct ForwardingStats {
/// Batches received (presumably every inbound batch, before any filtering).
pub batches_received: u64,
/// Batches relayed onward — TODO confirm whether batches with zero targets count.
pub batches_forwarded: u64,
/// Batches dropped as duplicates (presumably the `was_forwarded` path).
pub batches_deduplicated: u64,
/// Batches dropped because their TTL was zero (presumably).
pub batches_ttl_expired: u64,
}
// Unit tests for `SyncForwarder`; compiled only for test builds with the
// `automerge-backend` feature enabled.
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
use super::*;
use crate::storage::automerge_sync::{SyncEntry, SyncMessageType};
// Returns the public key of a freshly generated random iroh secret key.
fn create_test_peer_id() -> EndpointId {
use iroh::SecretKey;
let mut rng = rand::rng();
SecretKey::generate(&mut rng).public()
}
// `_n` is ignored: every call returns a new random ID, so the number is only a
// readability label. Distinct calls are assumed never to collide.
fn test_endpoint_id(_n: u8) -> EndpointId {
create_test_peer_id()
}
// A fresh forwarder has no parent and no children.
#[test]
fn test_forwarder_new() {
let local_id = test_endpoint_id(1);
let forwarder = SyncForwarder::new(local_id);
assert!(forwarder.parent_id().is_none());
assert!(forwarder.children().is_empty());
}
// Parent/child setters round-trip through the getters; remove_child removes.
#[test]
fn test_set_parent_and_children() {
let local_id = test_endpoint_id(1);
let parent_id = test_endpoint_id(2);
let child_id = test_endpoint_id(3);
let forwarder = SyncForwarder::new(local_id);
forwarder.set_parent(Some(parent_id));
forwarder.add_child(child_id);
assert_eq!(forwarder.parent_id(), Some(parent_id));
assert!(forwarder.children().contains(&child_id));
forwarder.remove_child(&child_id);
assert!(!forwarder.children().contains(&child_id));
}
// mark_forwarded makes was_forwarded report true for that batch ID.
#[test]
fn test_deduplication() {
let local_id = test_endpoint_id(1);
let forwarder = SyncForwarder::new(local_id);
let batch_id = 12345;
assert!(!forwarder.was_forwarded(batch_id));
forwarder.mark_forwarded(batch_id);
assert!(forwarder.was_forwarded(batch_id));
}
// Every connected peer except the source should be targeted.
// NOTE(review): with no parent/children set, a Lateral classification would
// produce the same targets, so this does not prove `alerts:` maps to Broadcast.
#[test]
fn test_forward_targets_broadcast() {
let local_id = test_endpoint_id(1);
let source_id = test_endpoint_id(2);
let peer_a = test_endpoint_id(3);
let peer_b = test_endpoint_id(4);
let forwarder = SyncForwarder::new(local_id);
let connected = vec![source_id, peer_a, peer_b];
let mut batch = SyncBatch::with_id(1);
batch.entries.push(SyncEntry::new(
"alerts:alert-1".to_string(),
SyncMessageType::DeltaSync,
vec![1, 2, 3],
));
let targets = forwarder
.forward_targets(&batch, source_id, &connected)
.unwrap();
assert_eq!(targets.len(), 2);
assert!(targets.contains(&peer_a));
assert!(targets.contains(&peer_b));
assert!(!targets.contains(&source_id));
}
// A `nodes:` entry arriving from a child must reach only the parent; this
// excludes the Lateral, Downward and Broadcast outcomes.
#[test]
fn test_forward_targets_upward() {
let local_id = test_endpoint_id(1);
let parent_id = test_endpoint_id(2);
let child_id = test_endpoint_id(3);
let peer_id = test_endpoint_id(4);
let forwarder = SyncForwarder::new(local_id);
forwarder.set_parent(Some(parent_id));
forwarder.add_child(child_id);
let connected = vec![parent_id, child_id, peer_id];
let mut batch = SyncBatch::with_id(2);
batch.entries.push(SyncEntry::new(
"nodes:node-1".to_string(),
SyncMessageType::DeltaSync,
vec![1, 2, 3],
));
let targets = forwarder
.forward_targets(&batch, child_id, &connected)
.unwrap();
assert_eq!(targets.len(), 1);
assert!(targets.contains(&parent_id));
}
// A zero TTL suppresses forwarding entirely (None, not an empty list).
#[test]
fn test_forward_targets_ttl_expired() {
let local_id = test_endpoint_id(1);
let source_id = test_endpoint_id(2);
let peer_id = test_endpoint_id(3);
let forwarder = SyncForwarder::new(local_id);
let connected = vec![source_id, peer_id];
let mut batch = SyncBatch::with_id(3);
batch.ttl = 0;
batch.entries.push(SyncEntry::new(
"alerts:alert-1".to_string(),
SyncMessageType::DeltaSync,
vec![1, 2, 3],
));
let targets = forwarder.forward_targets(&batch, source_id, &connected);
assert!(targets.is_none());
}
// prepare_for_forward decrements the TTL on a copy and leaves the input intact.
#[test]
fn test_prepare_for_forward() {
let local_id = test_endpoint_id(1);
let forwarder = SyncForwarder::new(local_id);
let mut batch = SyncBatch::with_id(4);
batch.ttl = 3;
let forwarded = forwarder.prepare_for_forward(&batch).unwrap();
assert_eq!(forwarded.ttl, 2);
assert_eq!(batch.ttl, 3);
}
// A batch whose TTL is already zero cannot be prepared for forwarding.
#[test]
fn test_prepare_for_forward_ttl_zero() {
let local_id = test_endpoint_id(1);
let forwarder = SyncForwarder::new(local_id);
let mut batch = SyncBatch::with_id(5);
batch.ttl = 0;
let forwarded = forwarder.prepare_for_forward(&batch);
assert!(forwarded.is_none());
}
}