1use borsh::{BorshDeserialize, BorshSerialize};
6
7use super::handshake::{SyncCapabilities, SyncHandshake};
8use super::levelwise::should_use_levelwise;
9
10#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, BorshSerialize, BorshDeserialize)]
24pub enum SyncProtocolKind {
25 None,
27 DeltaSync,
29 HashComparison,
31 Snapshot,
33 BloomFilter,
35 SubtreePrefetch,
37 LevelWise,
39}
40
41#[derive(Clone, Debug, PartialEq, BorshSerialize, BorshDeserialize)]
52pub enum SyncProtocol {
53 None,
55
56 DeltaSync {
60 missing_delta_ids: Vec<[u8; 32]>,
62 },
63
64 HashComparison {
68 root_hash: [u8; 32],
70 divergent_subtrees: Vec<[u8; 32]>,
72 },
73
74 Snapshot {
79 compressed: bool,
81 verified: bool,
83 },
84
85 BloomFilter {
89 filter_size: u64,
91 false_positive_rate: f64,
97 },
98
99 SubtreePrefetch {
103 subtree_roots: Vec<[u8; 32]>,
105 },
106
107 LevelWise {
111 max_depth: u32,
113 },
114}
115
116impl Default for SyncProtocol {
117 fn default() -> Self {
118 Self::None
119 }
120}
121
122impl SyncProtocol {
123 #[must_use]
127 pub fn kind(&self) -> SyncProtocolKind {
128 SyncProtocolKind::from(self)
129 }
130}
131
132impl From<&SyncProtocol> for SyncProtocolKind {
133 fn from(protocol: &SyncProtocol) -> Self {
134 match protocol {
135 SyncProtocol::None => Self::None,
136 SyncProtocol::DeltaSync { .. } => Self::DeltaSync,
137 SyncProtocol::HashComparison { .. } => Self::HashComparison,
138 SyncProtocol::Snapshot { .. } => Self::Snapshot,
139 SyncProtocol::BloomFilter { .. } => Self::BloomFilter,
140 SyncProtocol::SubtreePrefetch { .. } => Self::SubtreePrefetch,
141 SyncProtocol::LevelWise { .. } => Self::LevelWise,
142 }
143 }
144}
145
146#[derive(Clone, Debug)]
152pub struct ProtocolSelection {
153 pub protocol: SyncProtocol,
155 pub reason: &'static str,
157}
158
159#[must_use]
168pub fn calculate_divergence(local: &SyncHandshake, remote: &SyncHandshake) -> f64 {
169 let diff = local.entity_count.abs_diff(remote.entity_count);
171 let denominator = remote.entity_count.max(1);
172 diff as f64 / denominator as f64
173}
174
175#[must_use]
192pub fn select_protocol(local: &SyncHandshake, remote: &SyncHandshake) -> ProtocolSelection {
193 if local.root_hash == remote.root_hash {
195 return ProtocolSelection {
196 protocol: SyncProtocol::None,
197 reason: "root hashes match, already in sync",
198 };
199 }
200
201 if !local.is_version_compatible(remote) {
203 return ProtocolSelection {
205 protocol: SyncProtocol::HashComparison {
206 root_hash: remote.root_hash,
207 divergent_subtrees: vec![],
208 },
209 reason: "version mismatch, using safe fallback",
210 };
211 }
212
213 if !local.has_state {
216 return ProtocolSelection {
217 protocol: SyncProtocol::Snapshot {
218 compressed: remote.entity_count > 100,
219 verified: true,
220 },
221 reason: "fresh node bootstrap via snapshot",
222 };
223 }
224
225 let divergence = calculate_divergence(local, remote);
227
228 if divergence > 0.5 {
230 return ProtocolSelection {
231 protocol: SyncProtocol::HashComparison {
232 root_hash: remote.root_hash,
233 divergent_subtrees: vec![],
234 },
235 reason: "high divergence (>50%), using hash comparison with CRDT merge",
236 };
237 }
238
239 if remote.max_depth > 3 && divergence < 0.2 {
241 return ProtocolSelection {
242 protocol: SyncProtocol::SubtreePrefetch {
243 subtree_roots: vec![], },
245 reason: "deep tree with low divergence, using subtree prefetch",
246 };
247 }
248
249 if remote.entity_count > 50 && divergence < 0.1 {
251 return ProtocolSelection {
252 protocol: SyncProtocol::BloomFilter {
253 filter_size: remote.entity_count.saturating_mul(10).min(10_000),
255 false_positive_rate: 0.01,
256 },
257 reason: "large tree with small divergence, using bloom filter",
258 };
259 }
260
261 let max_depth_usize = remote.max_depth as usize;
264 let avg_children_per_level = if remote.max_depth > 0 {
265 (remote.entity_count / u64::from(remote.max_depth)) as usize
266 } else {
267 0
268 };
269 if should_use_levelwise(max_depth_usize, avg_children_per_level) {
270 return ProtocolSelection {
271 protocol: SyncProtocol::LevelWise {
272 max_depth: remote.max_depth,
273 },
274 reason: "wide shallow tree, using level-wise sync",
275 };
276 }
277
278 ProtocolSelection {
280 protocol: SyncProtocol::HashComparison {
281 root_hash: remote.root_hash,
282 divergent_subtrees: vec![],
283 },
284 reason: "default: using hash comparison",
285 }
286}
287
288#[must_use]
290pub fn is_protocol_supported(protocol: &SyncProtocol, capabilities: &SyncCapabilities) -> bool {
291 capabilities.supported_protocols.contains(&protocol.kind())
292}
293
294#[must_use]
299pub fn select_protocol_with_fallback(
300 local: &SyncHandshake,
301 remote: &SyncHandshake,
302 remote_capabilities: &SyncCapabilities,
303) -> ProtocolSelection {
304 let preferred = select_protocol(local, remote);
305
306 if is_protocol_supported(&preferred.protocol, remote_capabilities) {
308 return preferred;
309 }
310
311 if local.has_state {
313 let fallback = SyncProtocol::HashComparison {
314 root_hash: remote.root_hash,
315 divergent_subtrees: vec![],
316 };
317 if is_protocol_supported(&fallback, remote_capabilities) {
318 return ProtocolSelection {
319 protocol: fallback,
320 reason: "fallback to hash comparison (preferred not supported)",
321 };
322 }
323 }
324
325 ProtocolSelection {
327 protocol: SyncProtocol::None,
328 reason: "no mutually supported protocol found",
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sync::handshake::SYNC_PROTOCOL_VERSION;

    /// Builds a handshake whose root hash is `seed` repeated 32 times and whose
    /// last constructor argument is an empty vector.
    fn handshake(seed: u8, entity_count: u64, max_depth: u32) -> SyncHandshake {
        SyncHandshake::new([seed; 32], entity_count, max_depth, vec![])
    }

    /// Serializes `value` with Borsh, reads it back, and checks it is unchanged.
    fn assert_borsh_roundtrip<T>(value: &T)
    where
        T: BorshSerialize + BorshDeserialize + PartialEq + std::fmt::Debug,
    {
        let bytes = borsh::to_vec(value).expect("serialize");
        let restored: T = borsh::from_slice(&bytes).expect("deserialize");
        assert_eq!(*value, restored);
    }

    #[test]
    fn test_sync_protocol_roundtrip() {
        // One sample per variant, each with a non-trivial payload.
        let samples = [
            SyncProtocol::None,
            SyncProtocol::DeltaSync {
                missing_delta_ids: vec![[1; 32], [2; 32]],
            },
            SyncProtocol::HashComparison {
                root_hash: [3; 32],
                divergent_subtrees: vec![[4; 32]],
            },
            SyncProtocol::Snapshot {
                compressed: true,
                verified: false,
            },
            SyncProtocol::BloomFilter {
                filter_size: 1024,
                false_positive_rate: 0.01,
            },
            SyncProtocol::SubtreePrefetch {
                subtree_roots: vec![[5; 32], [6; 32]],
            },
            SyncProtocol::LevelWise { max_depth: 3 },
        ];

        for sample in &samples {
            assert_borsh_roundtrip(sample);
        }
    }

    #[test]
    fn test_sync_protocol_kind_roundtrip() {
        let all_kinds = [
            SyncProtocolKind::None,
            SyncProtocolKind::DeltaSync,
            SyncProtocolKind::HashComparison,
            SyncProtocolKind::Snapshot,
            SyncProtocolKind::BloomFilter,
            SyncProtocolKind::SubtreePrefetch,
            SyncProtocolKind::LevelWise,
        ];

        for kind in &all_kinds {
            assert_borsh_roundtrip(kind);
        }
    }

    #[test]
    fn test_sync_protocol_kind_conversion() {
        // Every variant must map to the kind of the same name via `kind()`.
        let expectations = [
            (SyncProtocol::None, SyncProtocolKind::None),
            (
                SyncProtocol::DeltaSync {
                    missing_delta_ids: vec![[1; 32]],
                },
                SyncProtocolKind::DeltaSync,
            ),
            (
                SyncProtocol::HashComparison {
                    root_hash: [2; 32],
                    divergent_subtrees: vec![],
                },
                SyncProtocolKind::HashComparison,
            ),
            (
                SyncProtocol::Snapshot {
                    compressed: true,
                    verified: true,
                },
                SyncProtocolKind::Snapshot,
            ),
            (
                SyncProtocol::BloomFilter {
                    filter_size: 1024,
                    false_positive_rate: 0.01,
                },
                SyncProtocolKind::BloomFilter,
            ),
            (
                SyncProtocol::SubtreePrefetch {
                    subtree_roots: vec![],
                },
                SyncProtocolKind::SubtreePrefetch,
            ),
            (
                SyncProtocol::LevelWise { max_depth: 5 },
                SyncProtocolKind::LevelWise,
            ),
        ];
        for (protocol, expected_kind) in &expectations {
            assert_eq!(protocol.kind(), *expected_kind);
        }

        // The `From<&SyncProtocol>` impl is reachable through `.into()` as well.
        let by_ref = SyncProtocol::HashComparison {
            root_hash: [3; 32],
            divergent_subtrees: vec![],
        };
        let converted: SyncProtocolKind = (&by_ref).into();
        assert_eq!(converted, SyncProtocolKind::HashComparison);
    }

    #[test]
    fn test_calculate_divergence() {
        // ((local entities, local depth), (remote entities, remote depth), expected)
        let cases: [((u64, u32), (u64, u32), f64); 4] = [
            // Same entity count on both sides.
            ((100, 5), (100, 5), 0.0),
            // Local holds half of what the remote has.
            ((50, 5), (100, 5), 0.5),
            // Local is empty.
            ((0, 0), (100, 5), 1.0),
            // Remote is empty: its count is treated as 1, so the result is 100.
            ((100, 5), (0, 0), 100.0),
        ];

        for ((local_count, local_depth), (remote_count, remote_depth), expected) in cases {
            let local = handshake(1, local_count, local_depth);
            let remote = handshake(2, remote_count, remote_depth);
            let actual = calculate_divergence(&local, &remote);
            assert!((actual - expected).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_select_protocol_rule1_already_synced() {
        // Identical root hashes win even though counts and depths differ.
        let ours = handshake(42, 100, 5);
        let theirs = handshake(42, 200, 3);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(selection.protocol, SyncProtocol::None));
        assert!(selection.reason.contains("already in sync"));
    }

    #[test]
    fn test_select_protocol_rule2_fresh_node_gets_snapshot() {
        // All-zero root hash, no entities, no depth on the local side.
        let ours = handshake(0, 0, 0);
        let theirs = handshake(42, 200, 5);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(selection.protocol, SyncProtocol::Snapshot { .. }));
        assert!(selection.reason.contains("fresh node"));
    }

    #[test]
    fn test_select_protocol_rule3_initialized_node_never_gets_snapshot() {
        // A single local entity is enough to rule out a snapshot.
        let ours = handshake(1, 1, 1);
        let theirs = handshake(42, 200, 5);

        let selection = select_protocol(&ours, &theirs);
        assert!(!matches!(selection.protocol, SyncProtocol::Snapshot { .. }));
    }

    #[test]
    fn test_select_protocol_rule3_high_divergence_uses_hash_comparison() {
        // |10 - 100| / 100 = 0.9, above the 0.5 threshold.
        let ours = handshake(1, 10, 2);
        let theirs = handshake(2, 100, 5);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(
            selection.protocol,
            SyncProtocol::HashComparison { .. }
        ));
        assert!(selection.reason.contains("divergence"));
    }

    #[test]
    fn test_select_protocol_rule4_deep_tree_uses_subtree_prefetch() {
        // Remote depth 5 (> 3) with divergence 0.1 (< 0.2).
        let ours = handshake(1, 90, 5);
        let theirs = handshake(2, 100, 5);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(
            selection.protocol,
            SyncProtocol::SubtreePrefetch { .. }
        ));
        assert!(selection.reason.contains("subtree"));
    }

    #[test]
    fn test_select_protocol_rule5_large_tree_small_diff_uses_bloom() {
        // Shallow remote (depth 2), 100 entities (> 50), divergence 0.05 (< 0.1).
        let ours = handshake(1, 95, 2);
        let theirs = handshake(2, 100, 2);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(
            selection.protocol,
            SyncProtocol::BloomFilter { .. }
        ));
        assert!(selection.reason.contains("bloom"));
    }

    #[test]
    fn test_select_protocol_rule6_wide_shallow_uses_levelwise() {
        // Depth 2 with 40 entities: too shallow for prefetch, too small for bloom.
        let ours = handshake(1, 40, 2);
        let theirs = handshake(2, 40, 2);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(selection.protocol, SyncProtocol::LevelWise { .. }));
        assert!(selection.reason.contains("level"));
    }

    #[test]
    fn test_select_protocol_rule7_default_uses_hash_comparison() {
        // Divergence 0.25, remote depth 3, 40 entities: expected to match no
        // specialised rule and fall through to the default.
        let ours = handshake(1, 30, 2);
        let theirs = handshake(2, 40, 3);

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(
            selection.protocol,
            SyncProtocol::HashComparison { .. }
        ));
        assert!(selection.reason.contains("default"));
    }

    #[test]
    fn test_select_protocol_version_mismatch_uses_safe_fallback() {
        let ours = handshake(1, 100, 5);
        let mut theirs = handshake(2, 100, 5);
        // Advertise a version one past the one this build uses.
        theirs.version = SYNC_PROTOCOL_VERSION + 1;

        let selection = select_protocol(&ours, &theirs);
        assert!(matches!(
            selection.protocol,
            SyncProtocol::HashComparison { .. }
        ));
        assert!(selection.reason.contains("version mismatch"));
    }

    #[test]
    fn test_is_protocol_supported() {
        let default_caps = SyncCapabilities::default();

        // Supported by the default capability set.
        assert!(is_protocol_supported(&SyncProtocol::None, &default_caps));
        assert!(is_protocol_supported(
            &SyncProtocol::HashComparison {
                root_hash: [0; 32],
                divergent_subtrees: vec![],
            },
            &default_caps
        ));

        // Not supported by the default capability set.
        assert!(!is_protocol_supported(
            &SyncProtocol::SubtreePrefetch {
                subtree_roots: vec![],
            },
            &default_caps
        ));

        // Supported by the default capability set.
        assert!(is_protocol_supported(
            &SyncProtocol::LevelWise { max_depth: 2 },
            &default_caps
        ));
    }

    #[test]
    fn test_select_protocol_with_fallback() {
        // These handshakes prefer subtree prefetch, which the default
        // capabilities do not list, so hash comparison must be substituted.
        let ours = handshake(1, 90, 5);
        let theirs = handshake(2, 100, 5);
        let default_caps = SyncCapabilities::default();

        let selection = select_protocol_with_fallback(&ours, &theirs, &default_caps);

        assert!(matches!(
            selection.protocol,
            SyncProtocol::HashComparison { .. }
        ));
        assert!(selection.reason.contains("fallback"));
    }
}