Skip to main content

peat_mesh/storage/
sync_forwarding.rs

1//! Sync message forwarding for multi-hop mesh networks
2//!
3//! This module provides the `SyncForwarder` which enables sync messages to
4//! propagate through intermediate nodes in a mesh network, allowing full
5//! data coverage with O(k) connections per node instead of O(n) connections.
6//!
7//! # Architecture
8//!
9//! When a node receives a sync batch from a peer, the forwarder determines
10//! which other connected peers should also receive the batch based on:
11//! - Sync direction (Upward, Downward, Lateral, Broadcast)
12//! - TTL (time-to-live) hop count
13//! - Deduplication to prevent infinite forwarding loops
14//!
//! # Example
//!
//! ```ignore
//! use peat_protocol::storage::sync_forwarding::SyncForwarder;
//!
//! let forwarder = SyncForwarder::new(local_node_id);
//!
//! // When receiving a batch
//! if let Some(targets) = forwarder.forward_targets(&batch, source_peer, &connected_peers) {
//!     // Send the TTL-decremented copy so the hop limit shrinks at every relay.
//!     if let Some(outgoing) = forwarder.prepare_for_forward(&batch) {
//!         for target in targets {
//!             send_batch_to_peer(&outgoing, target).await;
//!         }
//!     }
//!     forwarder.mark_forwarded(batch.batch_id);
//! }
//! ```
30
31use super::automerge_sync::{SyncBatch, SyncDirection};
32use iroh::EndpointId;
33use lru::LruCache;
34use std::collections::HashSet;
35use std::num::NonZeroUsize;
36use std::sync::{Arc, RwLock};
37
38/// Default capacity for the forwarded batch deduplication cache
39const DEFAULT_DEDUP_CACHE_SIZE: usize = 1000;
40
41/// Sync message forwarder for multi-hop mesh networks
42///
43/// The forwarder tracks which batches have been forwarded to prevent loops
44/// and determines the appropriate forwarding targets based on sync direction.
45pub struct SyncForwarder {
46    /// Local node ID for filtering
47    local_node_id: EndpointId,
48
49    /// Parent node ID for upward forwarding (if known)
50    parent_id: RwLock<Option<EndpointId>>,
51
52    /// Child node IDs for downward forwarding
53    children: RwLock<HashSet<EndpointId>>,
54
55    /// Cache of forwarded batch IDs for deduplication
56    forwarded_batches: Arc<RwLock<LruCache<u64, ()>>>,
57}
58
59impl SyncForwarder {
60    /// Create a new forwarder
61    pub fn new(local_node_id: EndpointId) -> Self {
62        Self {
63            local_node_id,
64            parent_id: RwLock::new(None),
65            children: RwLock::new(HashSet::new()),
66            forwarded_batches: Arc::new(RwLock::new(LruCache::new(
67                NonZeroUsize::new(DEFAULT_DEDUP_CACHE_SIZE).unwrap(),
68            ))),
69        }
70    }
71
72    /// Set the parent node for upward forwarding
73    pub fn set_parent(&self, parent_id: Option<EndpointId>) {
74        *self.parent_id.write().unwrap_or_else(|e| e.into_inner()) = parent_id;
75    }
76
77    /// Add a child node for downward forwarding
78    pub fn add_child(&self, child_id: EndpointId) {
79        self.children
80            .write()
81            .unwrap_or_else(|e| e.into_inner())
82            .insert(child_id);
83    }
84
85    /// Remove a child node
86    pub fn remove_child(&self, child_id: &EndpointId) {
87        self.children
88            .write()
89            .unwrap_or_else(|e| e.into_inner())
90            .remove(child_id);
91    }
92
93    /// Get the parent node ID
94    pub fn parent_id(&self) -> Option<EndpointId> {
95        *self.parent_id.read().unwrap_or_else(|e| e.into_inner())
96    }
97
98    /// Get child node IDs
99    pub fn children(&self) -> Vec<EndpointId> {
100        self.children
101            .read()
102            .unwrap_or_else(|e| e.into_inner())
103            .iter()
104            .copied()
105            .collect()
106    }
107
108    /// Check if a batch has already been forwarded
109    pub fn was_forwarded(&self, batch_id: u64) -> bool {
110        self.forwarded_batches
111            .read()
112            .unwrap_or_else(|e| e.into_inner())
113            .contains(&batch_id)
114    }
115
116    /// Mark a batch as forwarded (for deduplication)
117    pub fn mark_forwarded(&self, batch_id: u64) {
118        self.forwarded_batches
119            .write()
120            .unwrap_or_else(|e| e.into_inner())
121            .put(batch_id, ());
122    }
123
124    /// Determine forwarding targets for a received batch
125    ///
126    /// Returns None if:
127    /// - Batch was already forwarded (duplicate)
128    /// - Batch TTL is 0 (expired)
129    ///
130    /// Returns Some(empty vec) if no forwarding needed.
131    /// Returns Some(targets) with the peers to forward to.
132    pub fn forward_targets(
133        &self,
134        batch: &SyncBatch,
135        source_peer: EndpointId,
136        connected_peers: &[EndpointId],
137    ) -> Option<Vec<EndpointId>> {
138        // Check if already forwarded (dedup)
139        if self.was_forwarded(batch.batch_id) {
140            tracing::trace!(
141                batch_id = batch.batch_id,
142                "Batch already forwarded, skipping"
143            );
144            return None;
145        }
146
147        // Check TTL
148        if batch.ttl == 0 {
149            tracing::trace!(
150                batch_id = batch.batch_id,
151                "Batch TTL expired, not forwarding"
152            );
153            return None;
154        }
155
156        // Determine sync direction from batch entries
157        // Use the most permissive direction if multiple entries have different
158        // directions
159        let direction = self.determine_batch_direction(batch);
160
161        let mut targets = HashSet::new();
162
163        match direction {
164            SyncDirection::Upward => {
165                // Forward to parent only
166                if let Some(parent) = self.parent_id() {
167                    if parent != source_peer && connected_peers.contains(&parent) {
168                        targets.insert(parent);
169                    }
170                }
171            }
172            SyncDirection::Downward => {
173                // Forward to children only
174                for child in self.children() {
175                    if child != source_peer && connected_peers.contains(&child) {
176                        targets.insert(child);
177                    }
178                }
179            }
180            SyncDirection::Lateral => {
181                // Forward to peers at same level (excluding source and parent/children)
182                let parent = self.parent_id();
183                let children = self.children();
184                for peer in connected_peers {
185                    if *peer != source_peer
186                        && *peer != self.local_node_id
187                        && Some(*peer) != parent
188                        && !children.contains(peer)
189                    {
190                        targets.insert(*peer);
191                    }
192                }
193            }
194            SyncDirection::Broadcast => {
195                // Forward to all connected peers except source
196                for peer in connected_peers {
197                    if *peer != source_peer && *peer != self.local_node_id {
198                        targets.insert(*peer);
199                    }
200                }
201            }
202        }
203
204        tracing::debug!(
205            batch_id = batch.batch_id,
206            direction = ?direction,
207            ttl = batch.ttl,
208            source = %hex::encode(source_peer.as_bytes()),
209            target_count = targets.len(),
210            "Determined forward targets"
211        );
212
213        Some(targets.into_iter().collect())
214    }
215
216    /// Determine the sync direction for a batch based on its entries
217    fn determine_batch_direction(&self, batch: &SyncBatch) -> SyncDirection {
218        let mut most_permissive = SyncDirection::Upward;
219
220        for entry in &batch.entries {
221            let dir = SyncDirection::from_doc_key(&entry.doc_key);
222            // Broadcast is most permissive, then Lateral, then Downward, then Upward
223            most_permissive = match (&most_permissive, &dir) {
224                (_, SyncDirection::Broadcast) => SyncDirection::Broadcast,
225                (SyncDirection::Broadcast, _) => SyncDirection::Broadcast,
226                (_, SyncDirection::Lateral) => SyncDirection::Lateral,
227                (SyncDirection::Lateral, _) => SyncDirection::Lateral,
228                (_, SyncDirection::Downward) => SyncDirection::Downward,
229                (SyncDirection::Downward, _) => SyncDirection::Downward,
230                _ => SyncDirection::Upward,
231            };
232
233            // Short-circuit if we hit Broadcast
234            if matches!(most_permissive, SyncDirection::Broadcast) {
235                break;
236            }
237        }
238
239        most_permissive
240    }
241
242    /// Prepare a batch for forwarding by decrementing TTL
243    ///
244    /// Returns a cloned batch with decremented TTL, or None if TTL would be 0.
245    pub fn prepare_for_forward(&self, batch: &SyncBatch) -> Option<SyncBatch> {
246        if batch.ttl == 0 {
247            return None;
248        }
249
250        let mut forwarded = batch.clone();
251        forwarded.ttl = batch.ttl.saturating_sub(1);
252        Some(forwarded)
253    }
254}
255
256/// Statistics for sync forwarding
257#[derive(Debug, Clone, Default)]
258pub struct ForwardingStats {
259    /// Total batches received
260    pub batches_received: u64,
261    /// Batches forwarded to other peers
262    pub batches_forwarded: u64,
263    /// Batches dropped due to deduplication
264    pub batches_deduplicated: u64,
265    /// Batches dropped due to TTL expiry
266    pub batches_ttl_expired: u64,
267}
268
269#[cfg(all(test, feature = "automerge-backend"))]
270mod tests {
271    use super::*;
272    use crate::storage::automerge_sync::{SyncEntry, SyncMessageType};
273
274    fn create_test_peer_id() -> EndpointId {
275        use iroh::SecretKey;
276        let mut rng = rand::rng();
277        SecretKey::generate(&mut rng).public()
278    }
279
280    fn test_endpoint_id(_n: u8) -> EndpointId {
281        // Generate a valid ed25519 public key
282        create_test_peer_id()
283    }
284
285    #[test]
286    fn test_forwarder_new() {
287        let local_id = test_endpoint_id(1);
288        let forwarder = SyncForwarder::new(local_id);
289
290        assert!(forwarder.parent_id().is_none());
291        assert!(forwarder.children().is_empty());
292    }
293
294    #[test]
295    fn test_set_parent_and_children() {
296        let local_id = test_endpoint_id(1);
297        let parent_id = test_endpoint_id(2);
298        let child_id = test_endpoint_id(3);
299
300        let forwarder = SyncForwarder::new(local_id);
301        forwarder.set_parent(Some(parent_id));
302        forwarder.add_child(child_id);
303
304        assert_eq!(forwarder.parent_id(), Some(parent_id));
305        assert!(forwarder.children().contains(&child_id));
306
307        forwarder.remove_child(&child_id);
308        assert!(!forwarder.children().contains(&child_id));
309    }
310
311    #[test]
312    fn test_deduplication() {
313        let local_id = test_endpoint_id(1);
314        let forwarder = SyncForwarder::new(local_id);
315
316        let batch_id = 12345;
317        assert!(!forwarder.was_forwarded(batch_id));
318
319        forwarder.mark_forwarded(batch_id);
320        assert!(forwarder.was_forwarded(batch_id));
321    }
322
323    #[test]
324    fn test_forward_targets_broadcast() {
325        let local_id = test_endpoint_id(1);
326        let source_id = test_endpoint_id(2);
327        let peer_a = test_endpoint_id(3);
328        let peer_b = test_endpoint_id(4);
329
330        let forwarder = SyncForwarder::new(local_id);
331        let connected = vec![source_id, peer_a, peer_b];
332
333        // Create a broadcast batch (alerts)
334        let mut batch = SyncBatch::with_id(1);
335        batch.entries.push(SyncEntry::new(
336            "alerts:alert-1".to_string(),
337            SyncMessageType::DeltaSync,
338            vec![1, 2, 3],
339        ));
340
341        let targets = forwarder
342            .forward_targets(&batch, source_id, &connected)
343            .unwrap();
344
345        // Should forward to peer_a and peer_b, but not source
346        assert_eq!(targets.len(), 2);
347        assert!(targets.contains(&peer_a));
348        assert!(targets.contains(&peer_b));
349        assert!(!targets.contains(&source_id));
350    }
351
352    #[test]
353    fn test_forward_targets_upward() {
354        let local_id = test_endpoint_id(1);
355        let parent_id = test_endpoint_id(2);
356        let child_id = test_endpoint_id(3);
357        let peer_id = test_endpoint_id(4);
358
359        let forwarder = SyncForwarder::new(local_id);
360        forwarder.set_parent(Some(parent_id));
361        forwarder.add_child(child_id);
362
363        let connected = vec![parent_id, child_id, peer_id];
364
365        // Create an upward batch (nodes)
366        let mut batch = SyncBatch::with_id(2);
367        batch.entries.push(SyncEntry::new(
368            "nodes:node-1".to_string(),
369            SyncMessageType::DeltaSync,
370            vec![1, 2, 3],
371        ));
372
373        let targets = forwarder
374            .forward_targets(&batch, child_id, &connected)
375            .unwrap();
376
377        // Should forward to parent only
378        assert_eq!(targets.len(), 1);
379        assert!(targets.contains(&parent_id));
380    }
381
382    #[test]
383    fn test_forward_targets_ttl_expired() {
384        let local_id = test_endpoint_id(1);
385        let source_id = test_endpoint_id(2);
386        let peer_id = test_endpoint_id(3);
387
388        let forwarder = SyncForwarder::new(local_id);
389        let connected = vec![source_id, peer_id];
390
391        // Create a batch with TTL = 0
392        let mut batch = SyncBatch::with_id(3);
393        batch.ttl = 0;
394        batch.entries.push(SyncEntry::new(
395            "alerts:alert-1".to_string(),
396            SyncMessageType::DeltaSync,
397            vec![1, 2, 3],
398        ));
399
400        let targets = forwarder.forward_targets(&batch, source_id, &connected);
401        assert!(targets.is_none());
402    }
403
404    #[test]
405    fn test_prepare_for_forward() {
406        let local_id = test_endpoint_id(1);
407        let forwarder = SyncForwarder::new(local_id);
408
409        let mut batch = SyncBatch::with_id(4);
410        batch.ttl = 3;
411
412        let forwarded = forwarder.prepare_for_forward(&batch).unwrap();
413        assert_eq!(forwarded.ttl, 2);
414
415        // Original unchanged
416        assert_eq!(batch.ttl, 3);
417    }
418
419    #[test]
420    fn test_prepare_for_forward_ttl_zero() {
421        let local_id = test_endpoint_id(1);
422        let forwarder = SyncForwarder::new(local_id);
423
424        let mut batch = SyncBatch::with_id(5);
425        batch.ttl = 0;
426
427        let forwarded = forwarder.prepare_for_forward(&batch);
428        assert!(forwarded.is_none());
429    }
430}