use calimero_node_primitives::sync::handshake::SyncHandshake;
use calimero_node_primitives::sync::protocol::{select_protocol, SyncProtocol};
use calimero_node_primitives::sync::state_machine::{
build_handshake, build_handshake_from_raw, estimate_entity_count, estimate_max_depth,
LocalSyncState,
};
use crate::sync_sim::prelude::*;
use crate::sync_sim::scenarios::deterministic::Scenario;
/// A freshly created `SimNode` must expose an "empty" view through every
/// `LocalSyncState` accessor, and `build_handshake` must echo those values.
#[test]
fn test_simnode_implements_local_sync_state() {
    let node = SimNode::new("test");

    // Use fully-qualified trait syntax so this exercises the `LocalSyncState` impl itself.
    let (hash, count, depth, heads, populated) = (
        LocalSyncState::root_hash(&node),
        LocalSyncState::entity_count(&node),
        LocalSyncState::max_depth(&node),
        LocalSyncState::dag_heads(&node),
        LocalSyncState::has_state(&node),
    );

    assert_eq!(hash, [0; 32], "Fresh node should have zero root hash");
    assert_eq!(count, 0, "Fresh node should have zero entities");
    assert_eq!(depth, 0, "Fresh node should have zero depth");
    assert!(!populated, "Fresh node should have has_state=false");
    assert!(
        !heads.is_empty(),
        "SimNode initializes with at least one DAG head"
    );

    // The handshake built from the trait must carry the same values.
    let hs = build_handshake(&node);
    assert_eq!(hs.root_hash, hash);
    assert_eq!(hs.entity_count, count);
    assert_eq!(hs.max_depth, depth);
    assert!(!hs.has_state);
}
/// `SimNode::build_handshake` (inherent method) and the free `build_handshake`
/// function must yield handshakes with identical fields.
#[test]
fn test_simnode_build_handshake_uses_trait() {
    let mut node = SimNode::new("test");
    let from_method = node.build_handshake();
    let from_fn = build_handshake(&node);

    assert_eq!(from_method.root_hash, from_fn.root_hash);
    assert_eq!(from_method.entity_count, from_fn.entity_count);
    assert_eq!(from_method.max_depth, from_fn.max_depth);
    assert_eq!(from_method.dag_heads, from_fn.dag_heads);
    assert_eq!(from_method.has_state, from_fn.has_state);
}
/// For a fresh `SimNode`, the handshake built from its `LocalSyncState` and a
/// manager-style handshake rebuilt from the *same* root hash and DAG heads must
/// agree on `root_hash`, `has_state`, `entity_count` and `max_depth`.
#[test]
fn test_handshake_algorithm_consistency_fresh_node() {
    let sim_node = SimNode::new("fresh");
    let sim_hs = build_handshake(&sim_node);

    // Feed the manager path the sim node's own raw inputs rather than hard-coded
    // ones: a fresh SimNode starts with at least one DAG head, so an empty head
    // list would compare the two algorithms on different inputs.
    let manager_hs = build_manager_style_handshake(sim_hs.root_hash, &sim_hs.dag_heads);

    assert_eq!(
        sim_hs.root_hash, manager_hs.root_hash,
        "root_hash mismatch for fresh node"
    );
    assert_eq!(
        sim_hs.has_state, manager_hs.has_state,
        "has_state mismatch for fresh node"
    );
    assert_eq!(
        sim_hs.entity_count, manager_hs.entity_count,
        "entity_count mismatch for fresh node"
    );
    assert_eq!(
        sim_hs.max_depth, manager_hs.max_depth,
        "max_depth mismatch for fresh node"
    );
}
/// For an initialized `SimNode`, a manager-style handshake rebuilt from the
/// node's root hash and DAG heads must agree on `has_state` and pass the root
/// hash and DAG heads through unchanged.
#[test]
fn test_handshake_algorithm_consistency_initialized() {
    let (mut node, _) = Scenario::both_initialized();
    let sim_hs = node.build_handshake();
    assert!(sim_hs.has_state, "SimNode should have state");

    // NOTE(review): entity_count / max_depth are not compared here, presumably
    // because the manager path only estimates them.
    let manager_hs = build_manager_style_handshake(sim_hs.root_hash, &sim_hs.dag_heads);

    assert_eq!(
        sim_hs.has_state, manager_hs.has_state,
        "has_state mismatch for initialized node"
    );
    assert_eq!(sim_hs.root_hash, manager_hs.root_hash, "root_hash mismatch");
    assert_eq!(sim_hs.dag_heads, manager_hs.dag_heads, "dag_heads mismatch");
}
/// Cross-checks protocol selection between `SimNode` handshakes and manager-style
/// handshakes rebuilt (via `build_manager_style_handshake`) from the same root hash
/// and DAG heads, across several deterministic scenarios.
///
/// Invariants checked per scenario:
/// - both handshake paths agree on `has_state` for each side;
/// - both paths agree on whether any sync is needed (`SyncProtocol::None`), in
///   either direction;
/// - I5: when the local node already has state, neither path selects `Snapshot`.
#[test]
fn test_protocol_selection_critical_invariants_with_manager_handshakes() {
    let scenarios: Vec<(&str, (SimNode, SimNode))> = vec![
        ("force_none", Scenario::force_none()),
        ("force_snapshot", Scenario::force_snapshot()),
        ("both_initialized", Scenario::both_initialized()),
        ("partial_overlap", Scenario::partial_overlap()),
    ];
    for (name, (mut a, mut b)) in scenarios {
        let sim_hs_a = a.build_handshake();
        let sim_hs_b = b.build_handshake();
        let mgr_hs_a = build_manager_style_handshake(sim_hs_a.root_hash, &sim_hs_a.dag_heads);
        let mgr_hs_b = build_manager_style_handshake(sim_hs_b.root_hash, &sim_hs_b.dag_heads);
        assert_eq!(
            sim_hs_a.has_state, mgr_hs_a.has_state,
            "has_state mismatch for {} (local)",
            name
        );
        assert_eq!(
            sim_hs_b.has_state, mgr_hs_b.has_state,
            "has_state mismatch for {} (remote)",
            name
        );

        let sim_selection = select_protocol(&sim_hs_a, &sim_hs_b);
        let mgr_selection = select_protocol(&mgr_hs_a, &mgr_hs_b);

        // Check `None` agreement in both directions: a `None` on the manager path
        // while SimNode still wants to sync would (presumably) make a real node skip
        // a sync it needs, which is as harmful as the reverse.
        let sim_is_none = matches!(sim_selection.protocol, SyncProtocol::None);
        let mgr_is_none = matches!(mgr_selection.protocol, SyncProtocol::None);
        assert_eq!(
            sim_is_none, mgr_is_none,
            "None mismatch in scenario '{}': SimNode={:?}, Manager={:?}",
            name, sim_selection.protocol, mgr_selection.protocol
        );

        // I5: an initialized node must never be offered a Snapshot, on either path.
        // `has_state` was asserted equal across paths above, so one guard suffices.
        if sim_hs_a.has_state {
            assert!(
                !matches!(sim_selection.protocol, SyncProtocol::Snapshot { .. }),
                "I5 VIOLATION: Snapshot selected for initialized node in '{}' (SimNode path)",
                name
            );
            assert!(
                !matches!(mgr_selection.protocol, SyncProtocol::Snapshot { .. }),
                "I5 VIOLATION: Snapshot selected for initialized node in '{}' (manager path)",
                name
            );
        }
    }
}
/// A stateless node paired with a peer that has state must be dispatched to
/// `Snapshot`, and the selection reason must mention the fresh node.
#[test]
fn test_dispatch_fresh_to_initialized_selects_snapshot() {
    let (mut joiner, mut donor) = Scenario::force_snapshot();
    let joiner_hs = joiner.build_handshake();
    let donor_hs = donor.build_handshake();

    // Guard the scenario shape before judging the selector.
    assert!(!joiner_hs.has_state, "Precondition: fresh has no state");
    assert!(donor_hs.has_state, "Precondition: source has state");

    let decision = select_protocol(&joiner_hs, &donor_hs);
    let is_snapshot = matches!(decision.protocol, SyncProtocol::Snapshot { .. });
    assert!(
        is_snapshot,
        "Expected Snapshot dispatch, got {:?}",
        decision.protocol
    );
    assert!(
        decision.reason.contains("fresh"),
        "Reason should mention fresh node: {}",
        decision.reason
    );
}
/// Two nodes with identical root hashes must be dispatched to `None`, with a
/// reason that says they are already in sync / match.
#[test]
fn test_dispatch_same_hash_selects_none() {
    let (mut left, mut right) = Scenario::force_none();
    let left_hs = left.build_handshake();
    let right_hs = right.build_handshake();
    assert_eq!(
        left_hs.root_hash, right_hs.root_hash,
        "Precondition: same root hash"
    );

    let decision = select_protocol(&left_hs, &right_hs);
    assert!(
        matches!(decision.protocol, SyncProtocol::None),
        "Expected None dispatch (already synced), got {:?}",
        decision.protocol
    );

    // Either phrasing is accepted; `any` short-circuits like `||`.
    let explains_sync = ["already in sync", "match"]
        .iter()
        .any(|&phrase| decision.reason.contains(phrase));
    assert!(
        explains_sync,
        "Reason should mention already synced: {}",
        decision.reason
    );
}
/// Two initialized nodes whose root hashes differ must not be dispatched to
/// `Snapshot`; an incremental protocol is expected instead.
#[test]
fn test_dispatch_diverged_initialized_avoids_snapshot() {
    let (mut left, mut right) = Scenario::both_initialized();
    let left_hs = left.build_handshake();
    let right_hs = right.build_handshake();

    // Preconditions: both sides hold state, and that state has diverged.
    assert!(left_hs.has_state);
    assert!(right_hs.has_state);
    assert_ne!(left_hs.root_hash, right_hs.root_hash);

    let decision = select_protocol(&left_hs, &right_hs);
    let chose_snapshot = matches!(decision.protocol, SyncProtocol::Snapshot { .. });
    assert!(
        !chose_snapshot,
        "VIOLATION: Snapshot selected for initialized nodes!\n\
         Should use HashComparison/DeltaSync/etc., got {:?}",
        decision.protocol
    );
}
/// Runs the scenarios intended to force `SubtreePrefetch` and `LevelWise` and,
/// when the selector picks that protocol, checks that the reason explains why.
///
/// NOTE(review): the reason checks are conditional. If the selector picks a
/// different protocol for a scenario, nothing is asserted for it.
#[test]
fn test_dispatch_identifies_unimplemented_protocols() {
    let (mut deep_local, mut deep_remote) = Scenario::force_subtree_prefetch();
    let deep_local_hs = deep_local.build_handshake();
    let deep_remote_hs = deep_remote.build_handshake();
    let deep_choice = select_protocol(&deep_local_hs, &deep_remote_hs);
    if let SyncProtocol::SubtreePrefetch { .. } = deep_choice.protocol {
        let explained = ["subtree", "deep"]
            .iter()
            .any(|&keyword| deep_choice.reason.contains(keyword));
        assert!(
            explained,
            "Reason should explain subtree selection: {}",
            deep_choice.reason
        );
    }

    let (mut wide_local, mut wide_remote) = Scenario::force_levelwise();
    let wide_local_hs = wide_local.build_handshake();
    let wide_remote_hs = wide_remote.build_handshake();
    let wide_choice = select_protocol(&wide_local_hs, &wide_remote_hs);
    if let SyncProtocol::LevelWise { .. } = wide_choice.protocol {
        let explained = ["level", "wide"]
            .iter()
            .any(|&keyword| wide_choice.reason.contains(keyword));
        assert!(
            explained,
            "Reason should explain levelwise selection: {}",
            wide_choice.reason
        );
    }
}
/// Every `select_protocol` decision must carry a non-trivial, human-readable
/// reason, across same-hash, fresh-to-initialized and diverged handshake pairs.
#[test]
fn test_all_selections_have_reasons() {
    // (case label, local handshake, remote handshake)
    let cases = [
        (
            "same_hash",
            SyncHandshake::new([42; 32], 100, 5, vec![]),
            SyncHandshake::new([42; 32], 100, 5, vec![]),
        ),
        (
            "fresh_to_init",
            SyncHandshake::new([0; 32], 0, 0, vec![]),
            SyncHandshake::new([42; 32], 100, 5, vec![]),
        ),
        (
            "high_divergence",
            SyncHandshake::new([1; 32], 10, 2, vec![]),
            SyncHandshake::new([2; 32], 100, 5, vec![]),
        ),
        (
            "low_divergence_deep",
            SyncHandshake::new([1; 32], 90, 5, vec![]),
            SyncHandshake::new([2; 32], 100, 5, vec![]),
        ),
    ];

    for (label, ours, theirs) in cases {
        let decision = select_protocol(&ours, &theirs);
        let why = &decision.reason;
        assert!(!why.is_empty(), "Selection for '{}' has empty reason", label);
        assert!(
            why.len() > 5,
            "Selection reason for '{}' is too short: '{}'",
            label,
            why
        );
    }
}
/// Builds a `SyncHandshake` from only a root hash and DAG heads, deriving the
/// entity count with `estimate_entity_count` and the depth with
/// `estimate_max_depth` instead of reading real tree statistics.
///
/// Tests use this as the "manager" path to compare against `SimNode` handshakes.
/// NOTE(review): presumably mirrors how the node manager builds handshakes;
/// confirm against the manager code.
fn build_manager_style_handshake(root_hash: [u8; 32], dag_heads: &[[u8; 32]]) -> SyncHandshake {
    let estimated_entities = estimate_entity_count(root_hash, dag_heads.len());
    build_handshake_from_raw(
        root_hash,
        estimated_entities,
        estimate_max_depth(estimated_entities),
        dag_heads.to_vec(),
    )
}