use crate::sync::helpers::{
apply_leaf_with_crdt_merge, generate_nonce, handle_entity_push, MAX_ENTITIES_PER_PUSH,
};
use async_trait::async_trait;
use calimero_node_primitives::sync::{
compare_tree_nodes, create_runtime_env, InitPayload, LeafMetadata, MessagePayload,
StreamMessage, SyncProtocolExecutor, SyncTransport, TreeCompareResult, TreeLeafData, TreeNode,
TreeNodeResponse, MAX_NODES_PER_RESPONSE,
};
use calimero_primitives::context::ContextId;
use calimero_primitives::identity::PublicKey;
use calimero_storage::address::Id;
use calimero_storage::env::with_runtime_env;
use calimero_storage::index::Index;
use calimero_storage::interface::Interface;
use calimero_storage::store::MainStorage;
use calimero_store::Store;
use eyre::{bail, Result};
use tracing::{debug, info, trace, warn};
/// Upper bound on the initiator's pending-comparison queue; exceeding it aborts the sync.
const MAX_PENDING_NODES: usize = 10_000;
/// Maximum tree depth a responder will serve for a single `TreeNodeRequest`.
pub const MAX_REQUEST_DEPTH: u8 = 16;
/// Maximum number of requests a responder handles in one session before closing.
pub const MAX_HASH_COMPARISON_REQUESTS: u64 = 10_000;
/// Initiator-side configuration for a HashComparison sync session.
#[derive(Debug, Clone)]
pub struct HashComparisonConfig {
    /// Root hash of the remote peer's state tree; the comparison starts here.
    pub remote_root_hash: [u8; 32],
}
/// Parameters of the initial `TreeNodeRequest` handed to the responder.
#[derive(Debug, Clone)]
pub struct HashComparisonFirstRequest {
    /// Id of the tree node being requested.
    pub node_id: [u8; 32],
    /// Requested traversal depth; validated against `MAX_REQUEST_DEPTH` and clamped.
    pub max_depth: Option<u8>,
}
/// Counters accumulated by the initiator over one HashComparison session.
#[derive(Debug, Default, Clone)]
pub struct HashComparisonStats {
    /// Remote tree nodes examined (valid leaf or internal nodes).
    pub nodes_compared: u64,
    /// Remote leaf entities merged into local state via CRDT merge.
    pub entities_merged: u64,
    /// Local entities pushed to the peer (sum of acknowledged apply counts).
    pub entities_pushed: u64,
    /// Subtrees skipped because local and remote hashes matched.
    pub nodes_skipped: u64,
    /// Messages sent to the peer (tree-node requests and entity-push chunks).
    pub requests_sent: u64,
}
pub struct HashComparisonProtocol;
// Protocol wiring: adapts the free functions below to the
// `SyncProtocolExecutor` trait. All actual logic lives in
// `run_initiator_impl` / `run_responder_impl`.
#[async_trait(?Send)]
impl SyncProtocolExecutor for HashComparisonProtocol {
    type Config = HashComparisonConfig;
    type ResponderInit = HashComparisonFirstRequest;
    type Stats = HashComparisonStats;

    // Initiator entry point: drives the comparison from the remote root hash.
    async fn run_initiator<T: SyncTransport>(
        transport: &mut T,
        store: &Store,
        context_id: ContextId,
        identity: PublicKey,
        config: Self::Config,
    ) -> Result<Self::Stats> {
        run_initiator_impl(
            transport,
            store,
            context_id,
            identity,
            config.remote_root_hash,
        )
        .await
    }

    // Responder entry point: serves requests starting from the already-parsed
    // first request.
    async fn run_responder<T: SyncTransport>(
        transport: &mut T,
        store: &Store,
        context_id: ContextId,
        identity: PublicKey,
        first_request: Self::ResponderInit,
    ) -> Result<()> {
        run_responder_impl(
            transport,
            store,
            context_id,
            identity,
            first_request.node_id,
            first_request.max_depth,
        )
        .await
    }
}
/// Initiator side of HashComparison sync: iteratively walks the remote tree
/// starting at `remote_root_hash`, merging remote leaves into local state and
/// pushing local-only subtrees back to the peer.
///
/// Returns the accumulated [`HashComparisonStats`]; bails if the pending-node
/// queue exceeds `MAX_PENDING_NODES` or the peer sends an unexpected message.
async fn run_initiator_impl<T: SyncTransport>(
    transport: &mut T,
    store: &Store,
    context_id: ContextId,
    identity: PublicKey,
    remote_root_hash: [u8; 32],
) -> Result<HashComparisonStats> {
    info!(%context_id, "Starting HashComparison sync (initiator)");
    let mut stats = HashComparisonStats::default();
    let runtime_env = create_runtime_env(store, context_id, identity);
    // Work stack of (node_id, is_root_request); the root flag matters because
    // the root entity is looked up by context id rather than by node hash.
    let mut to_compare: Vec<([u8; 32], bool)> = vec![(remote_root_hash, true)];
    while let Some((node_id, is_root_request)) = to_compare.pop() {
        // Safety valve against unbounded queue growth. Checked after the pop,
        // so the effective cap is MAX_PENDING_NODES + 1.
        if to_compare.len() > MAX_PENDING_NODES {
            bail!(
                "HashComparison sync aborted: pending nodes ({}) exceeds limit ({})",
                to_compare.len(),
                MAX_PENDING_NODES
            );
        }
        // Ask the peer for this node plus its direct children (depth 1).
        let request_msg = StreamMessage::Init {
            context_id,
            party_id: identity,
            payload: InitPayload::TreeNodeRequest {
                context_id,
                node_id,
                max_depth: Some(1),
            },
            next_nonce: generate_nonce(),
        };
        transport.send(&request_msg).await?;
        stats.requests_sent += 1;
        let response = transport
            .recv()
            .await?
            .ok_or_else(|| eyre::eyre!("stream closed unexpectedly"))?;
        let StreamMessage::Message { payload, .. } = response else {
            bail!("Unexpected response type during HashComparison sync");
        };
        let (nodes, not_found) = match payload {
            MessagePayload::TreeNodeResponse { nodes, not_found } => (nodes, not_found),
            MessagePayload::SnapshotError { error } => {
                warn!(%context_id, ?error, "Peer returned error");
                bail!("Peer error: {:?}", error);
            }
            _ => bail!("Unexpected payload type"),
        };
        // Don't trust oversized responses; skip them rather than process.
        if nodes.len() > MAX_NODES_PER_RESPONSE {
            warn!(%context_id, count = nodes.len(), "Response too large, skipping");
            continue;
        }
        if not_found {
            debug!(%context_id, node_id = %hex::encode(node_id), "Node not found on peer");
            continue;
        }
        for remote_node in nodes {
            if !remote_node.is_valid() {
                warn!(%context_id, "Invalid TreeNode, skipping");
                continue;
            }
            stats.nodes_compared += 1;
            if remote_node.is_leaf() {
                // Remote leaf: merge its entity into local state via CRDT merge.
                if let Some(ref leaf_data) = remote_node.leaf_data {
                    trace!(
                        %context_id,
                        key = %hex::encode(leaf_data.key),
                        "Merging leaf entity"
                    );
                    with_runtime_env(runtime_env.clone(), || {
                        apply_leaf_with_crdt_merge(context_id, leaf_data)
                    })?;
                    stats.entities_merged += 1;
                }
            } else {
                // Internal node: compare against our local version of the
                // same node and act on the difference.
                let is_this_node_root = is_root_request && remote_node.id == node_id;
                let local_version = with_runtime_env(runtime_env.clone(), || {
                    get_local_tree_node(context_id, &remote_node.id, is_this_node_root)
                })?;
                match compare_tree_nodes(local_version.as_ref(), Some(&remote_node)) {
                    TreeCompareResult::Equal => {
                        // Hashes match: the whole subtree is already in sync.
                        stats.nodes_skipped += 1;
                        trace!(%context_id, "Subtree matches, skipping");
                    }
                    TreeCompareResult::LocalMissing => {
                        // We lack this subtree entirely; fetch every child.
                        for child_id in &remote_node.children {
                            to_compare.push((*child_id, false));
                        }
                    }
                    TreeCompareResult::Different {
                        remote_only_children,
                        local_only_children,
                        common_children,
                    } => {
                        // Fetch children only the peer has, and re-compare the
                        // children both sides have.
                        for child_id in remote_only_children {
                            to_compare.push((child_id, false));
                        }
                        for child_id in common_children {
                            to_compare.push((child_id, false));
                        }
                        // Send the peer the subtrees only we have.
                        if !local_only_children.is_empty() {
                            let pushed = push_local_subtrees(
                                transport,
                                &runtime_env,
                                context_id,
                                identity,
                                &local_only_children,
                                &mut stats,
                            )
                            .await?;
                            debug!(
                                %context_id,
                                local_only = local_only_children.len(),
                                entities_pushed = pushed,
                                "Pushed local-only children to peer"
                            );
                        }
                    }
                    TreeCompareResult::RemoteMissing => {
                        // Peer lacks this subtree; push all our leaves under it.
                        if let Some(ref local_node) = local_version {
                            let leaves = with_runtime_env(runtime_env.clone(), || {
                                collect_local_leaves(context_id, &local_node.id, is_this_node_root)
                            })?;
                            if !leaves.is_empty() {
                                push_entities(transport, context_id, identity, &leaves, &mut stats)
                                    .await?;
                            }
                        }
                    }
                }
            }
        }
    }
    transport.close().await?;
    info!(
        %context_id,
        nodes_compared = stats.nodes_compared,
        entities_merged = stats.entities_merged,
        entities_pushed = stats.entities_pushed,
        nodes_skipped = stats.nodes_skipped,
        "HashComparison sync complete"
    );
    Ok(stats)
}
/// Responder side of HashComparison sync: answers `TreeNodeRequest`s and
/// applies `EntityPush`es until the stream closes, a non-Init message
/// arrives, or `MAX_HASH_COMPARISON_REQUESTS` requests have been handled.
async fn run_responder_impl<T: SyncTransport>(
    transport: &mut T,
    store: &Store,
    context_id: ContextId,
    identity: PublicKey,
    first_node_id: [u8; 32],
    first_max_depth: Option<u8>,
) -> Result<()> {
    info!(%context_id, "Starting HashComparison sync (responder)");
    // NOTE(review): the first request bails on an excessive depth, while later
    // requests are silently clamped below — confirm the asymmetry is intended.
    if let Some(depth) = first_max_depth {
        if depth > MAX_REQUEST_DEPTH {
            bail!(
                "First request max_depth {} exceeds maximum {}",
                depth,
                MAX_REQUEST_DEPTH
            );
        }
    }
    let runtime_env = create_runtime_env(store, context_id, identity);
    // Our root hash, used to recognize requests addressing the tree root (the
    // root entity is keyed by context id, not by its hash). Missing or
    // unreadable hash entries fall back to all-zeros.
    let local_root_hash = with_runtime_env(runtime_env.clone(), || {
        Index::<MainStorage>::get_hashes_for(Id::new(*context_id.as_ref()))
            .ok()
            .flatten()
            .map(|(full, _)| full)
            .unwrap_or([0; 32])
    });
    let mut sequence_id = 0u64;
    let mut requests_handled = 0u64;
    // Serve the first request (already parsed by the caller) before entering
    // the receive loop.
    {
        let clamped_depth = first_max_depth.map(|d| d.min(MAX_REQUEST_DEPTH));
        let is_root_request = first_node_id == local_root_hash;
        let local_node = with_runtime_env(runtime_env.clone(), || {
            get_local_tree_node(context_id, &first_node_id, is_root_request)
        })?;
        let response =
            build_tree_node_response_internal(context_id, local_node, clamped_depth, &runtime_env)?;
        let msg = StreamMessage::Message {
            sequence_id,
            payload: MessagePayload::TreeNodeResponse {
                nodes: response.nodes,
                not_found: response.not_found,
            },
            next_nonce: generate_nonce(),
        };
        transport.send(&msg).await?;
        sequence_id += 1;
        requests_handled += 1;
    }
    loop {
        // Hard cap on requests served per session.
        if requests_handled >= MAX_HASH_COMPARISON_REQUESTS {
            warn!(
                %context_id,
                requests_handled,
                max = MAX_HASH_COMPARISON_REQUESTS,
                "Request limit reached, closing responder"
            );
            break;
        }
        // A closed stream is the normal end-of-session signal.
        let Some(request) = transport.recv().await? else {
            debug!(%context_id, requests_handled, "Stream closed, responder done");
            break;
        };
        let StreamMessage::Init { payload, .. } = request else {
            debug!(%context_id, "Received non-Init message, ending responder");
            break;
        };
        match payload {
            InitPayload::TreeNodeRequest {
                node_id, max_depth, ..
            } => {
                trace!(
                    %context_id,
                    node_id = %hex::encode(node_id),
                    ?max_depth,
                    "Handling TreeNodeRequest"
                );
                // Clamp (rather than reject) the requested depth.
                let clamped_depth = max_depth.map(|d| d.min(MAX_REQUEST_DEPTH));
                let is_root_request = node_id == local_root_hash;
                let local_node = with_runtime_env(runtime_env.clone(), || {
                    get_local_tree_node(context_id, &node_id, is_root_request)
                })?;
                let response = build_tree_node_response_internal(
                    context_id,
                    local_node,
                    clamped_depth,
                    &runtime_env,
                )?;
                let msg = StreamMessage::Message {
                    sequence_id,
                    payload: MessagePayload::TreeNodeResponse {
                        nodes: response.nodes,
                        not_found: response.not_found,
                    },
                    next_nonce: generate_nonce(),
                };
                transport.send(&msg).await?;
                sequence_id += 1;
                requests_handled += 1;
            }
            InitPayload::EntityPush { entities, .. } => {
                // Apply pushed entities via CRDT merge and acknowledge how
                // many were actually applied.
                let entity_count = entities.len();
                trace!(%context_id, entity_count, "Handling EntityPush from initiator");
                let applied = handle_entity_push(&runtime_env, context_id, &entities);
                let msg = StreamMessage::Message {
                    sequence_id,
                    payload: MessagePayload::EntityPushAck {
                        applied_count: applied,
                    },
                    next_nonce: generate_nonce(),
                };
                transport.send(&msg).await?;
                sequence_id += 1;
                requests_handled += 1;
                info!(
                    %context_id,
                    applied,
                    total = entity_count,
                    "Applied pushed entities via CRDT merge"
                );
            }
            _ => {
                debug!(%context_id, "Received unknown payload, ending responder");
                break;
            }
        }
    }
    info!(%context_id, requests_handled, "HashComparison responder complete");
    Ok(())
}
/// Builds the `TreeNodeResponse` for one requested node: the node itself,
/// plus (when `clamped_depth > 0` and the node is internal) its direct
/// children, capped at `MAX_NODES_PER_RESPONSE` total nodes. A missing node
/// yields a `not_found` response.
fn build_tree_node_response_internal(
    context_id: ContextId,
    local_node: Option<TreeNode>,
    clamped_depth: Option<u8>,
    runtime_env: &calimero_storage::env::RuntimeEnv,
) -> Result<TreeNodeResponse> {
    let Some(node) = local_node else {
        return Ok(TreeNodeResponse::not_found());
    };

    let mut nodes = vec![node.clone()];
    let want_children = clamped_depth.unwrap_or(0) > 0 && node.is_internal();
    if want_children {
        for child_id in &node.children {
            let child = with_runtime_env(runtime_env.clone(), || {
                get_local_tree_node(context_id, child_id, false)
            })?;
            if let Some(child) = child {
                nodes.push(child);
                // Respect the response size cap.
                if nodes.len() >= MAX_NODES_PER_RESPONSE {
                    break;
                }
            }
        }
    }
    Ok(TreeNodeResponse::new(nodes))
}
/// Maximum recursion depth when collecting leaves from a local subtree.
const MAX_COLLECT_DEPTH: u32 = 64;
/// Maximum number of leaves collected from a single subtree before truncating.
const MAX_LEAVES_PER_SUBTREE: usize = 10_000;
/// Gathers every leaf entity reachable from `node_id` into a freshly
/// allocated vector; see [`collect_leaves_recursive`] for traversal bounds.
fn collect_local_leaves(
    context_id: ContextId,
    node_id: &[u8; 32],
    is_root: bool,
) -> Result<Vec<TreeLeafData>> {
    let mut collected = Vec::new();
    collect_leaves_recursive(context_id, node_id, is_root, &mut collected, 0)?;
    Ok(collected)
}
/// Depth-first traversal appending all leaf entities under `node_id` to `leaves`.
///
/// Bounded by `MAX_COLLECT_DEPTH` levels and `MAX_LEAVES_PER_SUBTREE` collected
/// leaves; hitting either bound truncates the traversal (with a warning) rather
/// than failing. Index read errors and leaves missing a CRDT type are skipped
/// with a warning.
fn collect_leaves_recursive(
    context_id: ContextId,
    node_id: &[u8; 32],
    is_root: bool,
    leaves: &mut Vec<TreeLeafData>,
    depth: u32,
) -> Result<()> {
    if depth >= MAX_COLLECT_DEPTH {
        warn!(
            depth,
            node_id = %hex::encode(node_id),
            "collect_leaves_recursive: max depth reached, truncating"
        );
        return Ok(());
    }
    // Fix: was `>`, which allowed the vector to reach MAX_LEAVES_PER_SUBTREE + 1;
    // `>=` enforces the cap exactly. Truncation is now logged instead of silent.
    if leaves.len() >= MAX_LEAVES_PER_SUBTREE {
        warn!(
            collected = leaves.len(),
            node_id = %hex::encode(node_id),
            "collect_leaves_recursive: leaf limit reached, truncating"
        );
        return Ok(());
    }
    // The root entity is addressed by the context id; all other nodes by hash.
    let entity_id = if is_root {
        Id::new(*context_id.as_ref())
    } else {
        Id::new(*node_id)
    };
    let index = match Index::<MainStorage>::get_index(entity_id) {
        Ok(Some(idx)) => idx,
        // No index entry: nothing to collect here.
        Ok(None) => return Ok(()),
        Err(e) => {
            warn!(
                %entity_id,
                error = %e,
                "collect_leaves_recursive: failed to read index, skipping subtree"
            );
            return Ok(());
        }
    };
    let children_ids: Vec<[u8; 32]> = index
        .children()
        .map(|children| children.iter().map(|c| *c.id().as_bytes()).collect())
        .unwrap_or_default();
    if children_ids.is_empty() {
        // Leaf: emit its raw entry bytes plus the CRDT merge metadata.
        if let Some(entry_data) = Interface::<MainStorage>::find_by_id_raw(entity_id) {
            if let Some(ref crdt_type) = index.metadata.crdt_type {
                let metadata =
                    LeafMetadata::new(crdt_type.clone(), index.metadata.updated_at(), [0u8; 32]);
                let leaf_data = TreeLeafData::new(*entity_id.as_bytes(), entry_data, metadata);
                leaves.push(leaf_data);
            } else {
                warn!(
                    %entity_id,
                    "collect_leaves_recursive: leaf missing crdt_type, skipping"
                );
            }
        }
    } else {
        // Internal node: recurse into each child.
        for child_id in &children_ids {
            collect_leaves_recursive(context_id, child_id, false, leaves, depth + 1)?;
        }
    }
    Ok(())
}
/// Collects and pushes the leaf entities of every subtree in
/// `local_only_children` to the peer, returning the number of entities the
/// peer acknowledged applying.
async fn push_local_subtrees<T: SyncTransport>(
    transport: &mut T,
    runtime_env: &calimero_storage::env::RuntimeEnv,
    context_id: ContextId,
    identity: PublicKey,
    local_only_children: &[[u8; 32]],
    stats: &mut HashComparisonStats,
) -> Result<u64> {
    let mut pushed_total = 0u64;
    for subtree_root in local_only_children {
        let subtree_leaves = with_runtime_env(runtime_env.clone(), || {
            collect_local_leaves(context_id, subtree_root, false)
        })?;
        // Empty subtrees produce no push traffic.
        if subtree_leaves.is_empty() {
            continue;
        }
        pushed_total +=
            push_entities(transport, context_id, identity, &subtree_leaves, stats).await?;
    }
    Ok(pushed_total)
}
/// Streams `leaves` to the peer in `MAX_ENTITIES_PER_PUSH`-sized chunks,
/// waiting for an `EntityPushAck` after each chunk. Returns the sum of the
/// peer's acknowledged apply counts (also added to `stats.entities_pushed`).
async fn push_entities<T: SyncTransport>(
    transport: &mut T,
    context_id: ContextId,
    identity: PublicKey,
    leaves: &[TreeLeafData],
    stats: &mut HashComparisonStats,
) -> Result<u64> {
    let mut acked = 0u64;
    for batch in leaves.chunks(MAX_ENTITIES_PER_PUSH) {
        let push = StreamMessage::Init {
            context_id,
            party_id: identity,
            payload: InitPayload::EntityPush {
                context_id,
                entities: batch.to_vec(),
            },
            next_nonce: generate_nonce(),
        };
        transport.send(&push).await?;
        stats.requests_sent += 1;

        // Each chunk must be acknowledged before the next is sent.
        let reply = transport
            .recv()
            .await?
            .ok_or_else(|| eyre::eyre!("stream closed while waiting for EntityPushAck"))?;
        let StreamMessage::Message {
            payload: MessagePayload::EntityPushAck { applied_count },
            ..
        } = reply
        else {
            bail!("Unexpected response to EntityPush (peer may not support bidirectional sync)");
        };
        acked += u64::from(applied_count);
    }
    stats.entities_pushed += acked;
    Ok(acked)
}
/// Reads the local index entry for `node_id` and materializes it as a
/// `TreeNode`. A root request is addressed by the context id rather than the
/// node hash. Returns `Ok(None)` when no index entry exists; index read
/// errors are logged and also mapped to `None`.
fn get_local_tree_node(
    context_id: ContextId,
    node_id: &[u8; 32],
    is_root_request: bool,
) -> Result<Option<TreeNode>> {
    let entity_id = if is_root_request {
        Id::new(*context_id.as_ref())
    } else {
        Id::new(*node_id)
    };
    let index = match Index::<MainStorage>::get_index(entity_id) {
        Ok(Some(idx)) => idx,
        Ok(None) => return Ok(None),
        Err(e) => {
            warn!(%context_id, %entity_id, error = %e, "Failed to get index");
            return Ok(None);
        }
    };
    let full_hash = index.full_hash();
    let children_ids: Vec<[u8; 32]> = index
        .children()
        .map(|children| children.iter().map(|c| *c.id().as_bytes()).collect())
        .unwrap_or_default();

    // Any node with children is internal.
    if !children_ids.is_empty() {
        return Ok(Some(TreeNode::internal(
            *entity_id.as_bytes(),
            full_hash,
            children_ids,
        )));
    }

    // Childless: a true leaf if raw entry data and a CRDT type exist;
    // otherwise an empty internal node.
    let Some(entry_data) = Interface::<MainStorage>::find_by_id_raw(entity_id) else {
        return Ok(Some(TreeNode::internal(
            *entity_id.as_bytes(),
            full_hash,
            vec![],
        )));
    };
    let Some(crdt_type) = index.metadata.crdt_type.clone() else {
        warn!(
            %entity_id,
            "leaf has no CRDT type, treating as opaque node"
        );
        return Ok(Some(TreeNode::internal(
            *entity_id.as_bytes(),
            full_hash,
            vec![],
        )));
    };
    let metadata = LeafMetadata::new(crdt_type, index.metadata.updated_at(), [0u8; 32]);
    let leaf_data = TreeLeafData::new(*entity_id.as_bytes(), entry_data, metadata);
    Ok(Some(TreeNode::leaf(
        *entity_id.as_bytes(),
        full_hash,
        leaf_data,
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config holds the remote root hash verbatim and survives a clone.
    #[test]
    fn test_config_creation() {
        let config = HashComparisonConfig {
            remote_root_hash: [1u8; 32],
        };
        assert_eq!(config.remote_root_hash, [1u8; 32]);
        assert_eq!(config.clone().remote_root_hash, [1u8; 32]);
    }

    /// Default stats start every counter at zero — previously only the first
    /// two of five counters were asserted.
    #[test]
    fn test_stats_default() {
        let stats = HashComparisonStats::default();
        assert_eq!(stats.nodes_compared, 0);
        assert_eq!(stats.entities_merged, 0);
        assert_eq!(stats.entities_pushed, 0);
        assert_eq!(stats.nodes_skipped, 0);
        assert_eq!(stats.requests_sent, 0);
    }
}