1use super::automerge_sync::{SyncBatch, SyncDirection};
32use iroh::EndpointId;
33use lru::LruCache;
34use std::collections::HashSet;
35use std::num::NonZeroUsize;
36use std::sync::{Arc, RwLock};
37
38const DEFAULT_DEDUP_CACHE_SIZE: usize = 1000;
40
41pub struct SyncForwarder {
46 local_node_id: EndpointId,
48
49 parent_id: RwLock<Option<EndpointId>>,
51
52 children: RwLock<HashSet<EndpointId>>,
54
55 forwarded_batches: Arc<RwLock<LruCache<u64, ()>>>,
57}
58
59impl SyncForwarder {
60 pub fn new(local_node_id: EndpointId) -> Self {
62 Self {
63 local_node_id,
64 parent_id: RwLock::new(None),
65 children: RwLock::new(HashSet::new()),
66 forwarded_batches: Arc::new(RwLock::new(LruCache::new(
67 NonZeroUsize::new(DEFAULT_DEDUP_CACHE_SIZE).unwrap(),
68 ))),
69 }
70 }
71
72 pub fn set_parent(&self, parent_id: Option<EndpointId>) {
74 *self.parent_id.write().unwrap_or_else(|e| e.into_inner()) = parent_id;
75 }
76
77 pub fn add_child(&self, child_id: EndpointId) {
79 self.children
80 .write()
81 .unwrap_or_else(|e| e.into_inner())
82 .insert(child_id);
83 }
84
85 pub fn remove_child(&self, child_id: &EndpointId) {
87 self.children
88 .write()
89 .unwrap_or_else(|e| e.into_inner())
90 .remove(child_id);
91 }
92
93 pub fn parent_id(&self) -> Option<EndpointId> {
95 *self.parent_id.read().unwrap_or_else(|e| e.into_inner())
96 }
97
98 pub fn children(&self) -> Vec<EndpointId> {
100 self.children
101 .read()
102 .unwrap_or_else(|e| e.into_inner())
103 .iter()
104 .copied()
105 .collect()
106 }
107
108 pub fn was_forwarded(&self, batch_id: u64) -> bool {
110 self.forwarded_batches
111 .read()
112 .unwrap_or_else(|e| e.into_inner())
113 .contains(&batch_id)
114 }
115
116 pub fn mark_forwarded(&self, batch_id: u64) {
118 self.forwarded_batches
119 .write()
120 .unwrap_or_else(|e| e.into_inner())
121 .put(batch_id, ());
122 }
123
124 pub fn forward_targets(
133 &self,
134 batch: &SyncBatch,
135 source_peer: EndpointId,
136 connected_peers: &[EndpointId],
137 ) -> Option<Vec<EndpointId>> {
138 if self.was_forwarded(batch.batch_id) {
140 tracing::trace!(
141 batch_id = batch.batch_id,
142 "Batch already forwarded, skipping"
143 );
144 return None;
145 }
146
147 if batch.ttl == 0 {
149 tracing::trace!(
150 batch_id = batch.batch_id,
151 "Batch TTL expired, not forwarding"
152 );
153 return None;
154 }
155
156 let direction = self.determine_batch_direction(batch);
160
161 let mut targets = HashSet::new();
162
163 match direction {
164 SyncDirection::Upward => {
165 if let Some(parent) = self.parent_id() {
167 if parent != source_peer && connected_peers.contains(&parent) {
168 targets.insert(parent);
169 }
170 }
171 }
172 SyncDirection::Downward => {
173 for child in self.children() {
175 if child != source_peer && connected_peers.contains(&child) {
176 targets.insert(child);
177 }
178 }
179 }
180 SyncDirection::Lateral => {
181 let parent = self.parent_id();
183 let children = self.children();
184 for peer in connected_peers {
185 if *peer != source_peer
186 && *peer != self.local_node_id
187 && Some(*peer) != parent
188 && !children.contains(peer)
189 {
190 targets.insert(*peer);
191 }
192 }
193 }
194 SyncDirection::Broadcast => {
195 for peer in connected_peers {
197 if *peer != source_peer && *peer != self.local_node_id {
198 targets.insert(*peer);
199 }
200 }
201 }
202 }
203
204 tracing::debug!(
205 batch_id = batch.batch_id,
206 direction = ?direction,
207 ttl = batch.ttl,
208 source = %hex::encode(source_peer.as_bytes()),
209 target_count = targets.len(),
210 "Determined forward targets"
211 );
212
213 Some(targets.into_iter().collect())
214 }
215
216 fn determine_batch_direction(&self, batch: &SyncBatch) -> SyncDirection {
218 let mut most_permissive = SyncDirection::Upward;
219
220 for entry in &batch.entries {
221 let dir = SyncDirection::from_doc_key(&entry.doc_key);
222 most_permissive = match (&most_permissive, &dir) {
224 (_, SyncDirection::Broadcast) => SyncDirection::Broadcast,
225 (SyncDirection::Broadcast, _) => SyncDirection::Broadcast,
226 (_, SyncDirection::Lateral) => SyncDirection::Lateral,
227 (SyncDirection::Lateral, _) => SyncDirection::Lateral,
228 (_, SyncDirection::Downward) => SyncDirection::Downward,
229 (SyncDirection::Downward, _) => SyncDirection::Downward,
230 _ => SyncDirection::Upward,
231 };
232
233 if matches!(most_permissive, SyncDirection::Broadcast) {
235 break;
236 }
237 }
238
239 most_permissive
240 }
241
242 pub fn prepare_for_forward(&self, batch: &SyncBatch) -> Option<SyncBatch> {
246 if batch.ttl == 0 {
247 return None;
248 }
249
250 let mut forwarded = batch.clone();
251 forwarded.ttl = batch.ttl.saturating_sub(1);
252 Some(forwarded)
253 }
254}
255
/// Counters describing what happened to the sync batches a node handled.
///
/// NOTE(review): nothing in this file increments these counters. Presumably a
/// caller maintains them around `SyncForwarder`; confirm at the call sites.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ForwardingStats {
    /// Presumably the number of batches received from peers.
    pub batches_received: u64,
    /// Presumably the number of batches relayed onward.
    pub batches_forwarded: u64,
    /// Presumably the number of batches dropped as already forwarded
    /// (dedup-cache hits).
    pub batches_deduplicated: u64,
    /// Presumably the number of batches dropped because their TTL was 0.
    pub batches_ttl_expired: u64,
}
268
#[cfg(all(test, feature = "automerge-backend"))]
mod tests {
    use super::*;
    use crate::storage::automerge_sync::{SyncEntry, SyncMessageType};
    use iroh::SecretKey;

    /// Returns a deterministic endpoint ID derived from `n`.
    ///
    /// The same `n` always yields the same ID, and different values of `n`
    /// yield different IDs, which keeps test failures reproducible.
    fn test_endpoint_id(n: u8) -> EndpointId {
        SecretKey::from_bytes(&[n; 32]).public()
    }

    /// Builds a batch with `batch_id` that has one `DeltaSync` entry per key
    /// in `doc_keys`.
    fn batch_with_keys(batch_id: u64, doc_keys: &[&str]) -> SyncBatch {
        let mut batch = SyncBatch::with_id(batch_id);
        for key in doc_keys {
            batch.entries.push(SyncEntry::new(
                (*key).to_string(),
                SyncMessageType::DeltaSync,
                vec![1, 2, 3],
            ));
        }
        batch
    }

    #[test]
    fn test_forwarder_new() {
        let forwarder = SyncForwarder::new(test_endpoint_id(1));

        assert!(forwarder.parent_id().is_none());
        assert!(forwarder.children().is_empty());
    }

    #[test]
    fn test_set_parent_and_children() {
        let parent_id = test_endpoint_id(2);
        let child_id = test_endpoint_id(3);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        forwarder.set_parent(Some(parent_id));
        forwarder.add_child(child_id);

        assert_eq!(forwarder.parent_id(), Some(parent_id));
        assert!(forwarder.children().contains(&child_id));

        forwarder.remove_child(&child_id);
        assert!(!forwarder.children().contains(&child_id));
    }

    #[test]
    fn test_deduplication() {
        let forwarder = SyncForwarder::new(test_endpoint_id(1));

        let batch_id = 12345;
        assert!(!forwarder.was_forwarded(batch_id));

        forwarder.mark_forwarded(batch_id);
        assert!(forwarder.was_forwarded(batch_id));
    }

    #[test]
    fn test_forward_targets_skips_already_forwarded() {
        let source_id = test_endpoint_id(2);
        let peer_id = test_endpoint_id(3);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        let batch = batch_with_keys(7, &["alerts:alert-1"]);
        forwarder.mark_forwarded(batch.batch_id);

        let targets = forwarder.forward_targets(&batch, source_id, &[source_id, peer_id]);
        assert!(targets.is_none());
    }

    #[test]
    fn test_forward_targets_broadcast() {
        let source_id = test_endpoint_id(2);
        let peer_a = test_endpoint_id(3);
        let peer_b = test_endpoint_id(4);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        let connected = vec![source_id, peer_a, peer_b];
        let batch = batch_with_keys(1, &["alerts:alert-1"]);

        let targets = forwarder
            .forward_targets(&batch, source_id, &connected)
            .unwrap();

        assert_eq!(targets.len(), 2);
        assert!(targets.contains(&peer_a));
        assert!(targets.contains(&peer_b));
        assert!(!targets.contains(&source_id));
    }

    #[test]
    fn test_forward_targets_broadcast_excludes_local_node() {
        let local_id = test_endpoint_id(1);
        let source_id = test_endpoint_id(2);
        let peer_a = test_endpoint_id(3);

        let forwarder = SyncForwarder::new(local_id);
        let connected = vec![source_id, local_id, peer_a];
        let batch = batch_with_keys(8, &["alerts:alert-1"]);

        let targets = forwarder
            .forward_targets(&batch, source_id, &connected)
            .unwrap();

        assert_eq!(targets, vec![peer_a]);
    }

    #[test]
    fn test_forward_targets_mixed_entries_use_most_permissive() {
        let source_id = test_endpoint_id(2);
        let peer_a = test_endpoint_id(3);
        let peer_b = test_endpoint_id(4);

        // No parent is set. An upward-only batch would therefore have no
        // targets, so reaching both peers proves the broadcast entry won.
        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        let connected = vec![source_id, peer_a, peer_b];
        let batch = batch_with_keys(9, &["nodes:node-1", "alerts:alert-1"]);

        let targets = forwarder
            .forward_targets(&batch, source_id, &connected)
            .unwrap();

        assert_eq!(targets.len(), 2);
        assert!(targets.contains(&peer_a));
        assert!(targets.contains(&peer_b));
    }

    #[test]
    fn test_forward_targets_upward() {
        let parent_id = test_endpoint_id(2);
        let child_id = test_endpoint_id(3);
        let peer_id = test_endpoint_id(4);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        forwarder.set_parent(Some(parent_id));
        forwarder.add_child(child_id);

        let connected = vec![parent_id, child_id, peer_id];
        let batch = batch_with_keys(2, &["nodes:node-1"]);

        let targets = forwarder
            .forward_targets(&batch, child_id, &connected)
            .unwrap();

        assert_eq!(targets.len(), 1);
        assert!(targets.contains(&parent_id));
    }

    #[test]
    fn test_forward_targets_upward_without_parent() {
        let child_id = test_endpoint_id(3);
        let peer_id = test_endpoint_id(4);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        forwarder.add_child(child_id);

        let batch = batch_with_keys(10, &["nodes:node-1"]);
        let targets = forwarder
            .forward_targets(&batch, child_id, &[child_id, peer_id])
            .unwrap();

        assert!(targets.is_empty());
    }

    #[test]
    fn test_forward_targets_ttl_expired() {
        let source_id = test_endpoint_id(2);
        let peer_id = test_endpoint_id(3);

        let forwarder = SyncForwarder::new(test_endpoint_id(1));
        let connected = vec![source_id, peer_id];

        let mut batch = batch_with_keys(3, &["alerts:alert-1"]);
        batch.ttl = 0;

        let targets = forwarder.forward_targets(&batch, source_id, &connected);
        assert!(targets.is_none());
    }

    #[test]
    fn test_prepare_for_forward() {
        let forwarder = SyncForwarder::new(test_endpoint_id(1));

        let mut batch = SyncBatch::with_id(4);
        batch.ttl = 3;

        let forwarded = forwarder.prepare_for_forward(&batch).unwrap();
        assert_eq!(forwarded.ttl, 2);

        // The input batch must be left untouched.
        assert_eq!(batch.ttl, 3);
    }

    #[test]
    fn test_prepare_for_forward_ttl_zero() {
        let forwarder = SyncForwarder::new(test_endpoint_id(1));

        let mut batch = SyncBatch::with_id(5);
        batch.ttl = 0;

        let forwarded = forwarder.prepare_for_forward(&batch);
        assert!(forwarded.is_none());
    }
}