use crate::sync::helpers::{apply_leaf_with_crdt_merge, generate_nonce};
use async_trait::async_trait;
use calimero_node_primitives::sync::{
compare_tree_nodes, create_runtime_env, InitPayload, LeafMetadata, MessagePayload,
StreamMessage, SyncProtocolExecutor, SyncTransport, TreeCompareResult, TreeLeafData, TreeNode,
TreeNodeResponse, MAX_NODES_PER_RESPONSE,
};
use calimero_primitives::context::ContextId;
use calimero_primitives::identity::PublicKey;
use calimero_storage::address::Id;
use calimero_storage::env::with_runtime_env;
use calimero_storage::index::Index;
use calimero_storage::interface::Interface;
use calimero_storage::store::MainStorage;
use calimero_store::Store;
use eyre::{bail, Result};
use tracing::{debug, info, trace, warn};
/// Upper bound on the initiator's pending-comparison stack; the sync is
/// aborted if more node ids than this are queued at once.
const MAX_PENDING_NODES: usize = 10_000;
/// Maximum `max_depth` a responder will honor for a `TreeNodeRequest`;
/// larger values are rejected on the first request and clamped on follow-ups.
pub const MAX_REQUEST_DEPTH: u8 = 16;
/// Maximum number of `TreeNodeRequest`s a responder serves on one stream.
const MAX_HASH_COMPARISON_REQUESTS: u64 = 10_000;
/// Initiator-side configuration for a HashComparison sync.
#[derive(Debug, Clone)]
pub struct HashComparisonConfig {
    // Merkle root hash advertised by the remote peer; comparison starts here.
    pub remote_root_hash: [u8; 32],
}
/// Payload of the first `TreeNodeRequest` that opens a responder session.
#[derive(Debug, Clone)]
pub struct HashComparisonFirstRequest {
    // Id of the tree node being requested.
    pub node_id: [u8; 32],
    // Requested depth of children to include; values above
    // `MAX_REQUEST_DEPTH` cause the responder to bail on the first request.
    pub max_depth: Option<u8>,
}
/// Counters accumulated by the initiator over one HashComparison run.
#[derive(Debug, Default, Clone)]
pub struct HashComparisonStats {
    // Remote tree nodes that passed validation and were compared.
    pub nodes_compared: u64,
    // Leaf entities merged into local storage via CRDT merge.
    pub entities_merged: u64,
    // Subtrees skipped because local and remote hashes matched.
    pub nodes_skipped: u64,
    // TreeNodeRequests sent to the peer.
    pub requests_sent: u64,
}
/// Marker type implementing [`SyncProtocolExecutor`] for the
/// hash-comparison (Merkle diff) sync protocol.
pub struct HashComparisonProtocol;

// NOTE(review): `?Send` presumably because the runtime-env closures used by
// the impl functions are not Send — confirm against the trait definition.
#[async_trait(?Send)]
impl SyncProtocolExecutor for HashComparisonProtocol {
    type Config = HashComparisonConfig;
    type ResponderInit = HashComparisonFirstRequest;
    type Stats = HashComparisonStats;

    /// Runs the initiator side; thin delegation to `run_initiator_impl`.
    async fn run_initiator<T: SyncTransport>(
        transport: &mut T,
        store: &Store,
        context_id: ContextId,
        identity: PublicKey,
        config: Self::Config,
    ) -> Result<Self::Stats> {
        run_initiator_impl(
            transport,
            store,
            context_id,
            identity,
            config.remote_root_hash,
        )
        .await
    }

    /// Runs the responder side; thin delegation to `run_responder_impl`.
    async fn run_responder<T: SyncTransport>(
        transport: &mut T,
        store: &Store,
        context_id: ContextId,
        identity: PublicKey,
        first_request: Self::ResponderInit,
    ) -> Result<()> {
        run_responder_impl(
            transport,
            store,
            context_id,
            identity,
            first_request.node_id,
            first_request.max_depth,
        )
        .await
    }
}
async fn run_initiator_impl<T: SyncTransport>(
transport: &mut T,
store: &Store,
context_id: ContextId,
identity: PublicKey,
remote_root_hash: [u8; 32],
) -> Result<HashComparisonStats> {
info!(%context_id, "Starting HashComparison sync (initiator)");
let mut stats = HashComparisonStats::default();
let runtime_env = create_runtime_env(store, context_id, identity);
let mut to_compare: Vec<([u8; 32], bool)> = vec![(remote_root_hash, true)];
while let Some((node_id, is_root_request)) = to_compare.pop() {
if to_compare.len() > MAX_PENDING_NODES {
bail!(
"HashComparison sync aborted: pending nodes ({}) exceeds limit ({})",
to_compare.len(),
MAX_PENDING_NODES
);
}
let request_msg = StreamMessage::Init {
context_id,
party_id: identity,
payload: InitPayload::TreeNodeRequest {
context_id,
node_id,
max_depth: Some(1),
},
next_nonce: generate_nonce(),
};
transport.send(&request_msg).await?;
stats.requests_sent += 1;
let response = transport
.recv()
.await?
.ok_or_else(|| eyre::eyre!("stream closed unexpectedly"))?;
let StreamMessage::Message { payload, .. } = response else {
bail!("Unexpected response type during HashComparison sync");
};
let (nodes, not_found) = match payload {
MessagePayload::TreeNodeResponse { nodes, not_found } => (nodes, not_found),
MessagePayload::SnapshotError { error } => {
warn!(%context_id, ?error, "Peer returned error");
bail!("Peer error: {:?}", error);
}
_ => bail!("Unexpected payload type"),
};
if nodes.len() > MAX_NODES_PER_RESPONSE {
warn!(%context_id, count = nodes.len(), "Response too large, skipping");
continue;
}
if not_found {
debug!(%context_id, node_id = %hex::encode(node_id), "Node not found on peer");
continue;
}
for remote_node in nodes {
if !remote_node.is_valid() {
warn!(%context_id, "Invalid TreeNode, skipping");
continue;
}
stats.nodes_compared += 1;
if remote_node.is_leaf() {
if let Some(ref leaf_data) = remote_node.leaf_data {
trace!(
%context_id,
key = %hex::encode(leaf_data.key),
"Merging leaf entity"
);
with_runtime_env(runtime_env.clone(), || {
apply_leaf_with_crdt_merge(context_id, leaf_data)
})?;
stats.entities_merged += 1;
}
} else {
let is_this_node_root = is_root_request && remote_node.id == node_id;
let local_version = with_runtime_env(runtime_env.clone(), || {
get_local_tree_node(context_id, &remote_node.id, is_this_node_root)
})?;
match compare_tree_nodes(local_version.as_ref(), Some(&remote_node)) {
TreeCompareResult::Equal => {
stats.nodes_skipped += 1;
trace!(%context_id, "Subtree matches, skipping");
}
TreeCompareResult::LocalMissing => {
for child_id in &remote_node.children {
to_compare.push((*child_id, false));
}
}
TreeCompareResult::Different {
remote_only_children,
common_children,
..
} => {
for child_id in remote_only_children {
to_compare.push((child_id, false));
}
for child_id in common_children {
to_compare.push((child_id, false));
}
}
TreeCompareResult::RemoteMissing => {
}
}
}
}
}
transport.close().await?;
info!(
%context_id,
nodes_compared = stats.nodes_compared,
entities_merged = stats.entities_merged,
nodes_skipped = stats.nodes_skipped,
"HashComparison sync complete"
);
Ok(stats)
}
async fn run_responder_impl<T: SyncTransport>(
transport: &mut T,
store: &Store,
context_id: ContextId,
identity: PublicKey,
first_node_id: [u8; 32],
first_max_depth: Option<u8>,
) -> Result<()> {
info!(%context_id, "Starting HashComparison sync (responder)");
if let Some(depth) = first_max_depth {
if depth > MAX_REQUEST_DEPTH {
bail!(
"First request max_depth {} exceeds maximum {}",
depth,
MAX_REQUEST_DEPTH
);
}
}
let runtime_env = create_runtime_env(store, context_id, identity);
let local_root_hash = with_runtime_env(runtime_env.clone(), || {
Index::<MainStorage>::get_hashes_for(Id::new(*context_id.as_ref()))
.ok()
.flatten()
.map(|(full, _)| full)
.unwrap_or([0; 32])
});
let mut sequence_id = 0u64;
let mut requests_handled = 0u64;
{
let clamped_depth = first_max_depth.map(|d| d.min(MAX_REQUEST_DEPTH));
let is_root_request = first_node_id == local_root_hash;
let local_node = with_runtime_env(runtime_env.clone(), || {
get_local_tree_node(context_id, &first_node_id, is_root_request)
})?;
let response =
build_tree_node_response_internal(context_id, local_node, clamped_depth, &runtime_env)?;
let msg = StreamMessage::Message {
sequence_id,
payload: MessagePayload::TreeNodeResponse {
nodes: response.nodes,
not_found: response.not_found,
},
next_nonce: generate_nonce(),
};
transport.send(&msg).await?;
sequence_id += 1;
requests_handled += 1;
}
loop {
if requests_handled >= MAX_HASH_COMPARISON_REQUESTS {
warn!(
%context_id,
requests_handled,
max = MAX_HASH_COMPARISON_REQUESTS,
"Request limit reached, closing responder"
);
break;
}
let Some(request) = transport.recv().await? else {
debug!(%context_id, requests_handled, "Stream closed, responder done");
break;
};
let StreamMessage::Init { payload, .. } = request else {
debug!(%context_id, "Received non-Init message, ending responder");
break;
};
let InitPayload::TreeNodeRequest {
node_id, max_depth, ..
} = payload
else {
debug!(%context_id, "Received non-TreeNodeRequest, ending responder");
break;
};
trace!(
%context_id,
node_id = %hex::encode(node_id),
?max_depth,
"Handling TreeNodeRequest"
);
let clamped_depth = max_depth.map(|d| d.min(MAX_REQUEST_DEPTH));
let is_root_request = node_id == local_root_hash;
let local_node = with_runtime_env(runtime_env.clone(), || {
get_local_tree_node(context_id, &node_id, is_root_request)
})?;
let response =
build_tree_node_response_internal(context_id, local_node, clamped_depth, &runtime_env)?;
let msg = StreamMessage::Message {
sequence_id,
payload: MessagePayload::TreeNodeResponse {
nodes: response.nodes,
not_found: response.not_found,
},
next_nonce: generate_nonce(),
};
transport.send(&msg).await?;
sequence_id += 1;
requests_handled += 1;
}
info!(%context_id, requests_handled, "HashComparison responder complete");
Ok(())
}
/// Assembles the `TreeNodeResponse` for a resolved local node.
///
/// Returns a not-found response when `local_node` is `None`. Otherwise the
/// node itself is always included; when `clamped_depth` is greater than zero
/// and the node is internal, exactly one level of its children is appended
/// (deeper values do not expand further levels), capped at
/// `MAX_NODES_PER_RESPONSE` entries total.
fn build_tree_node_response_internal(
    context_id: ContextId,
    local_node: Option<TreeNode>,
    clamped_depth: Option<u8>,
    runtime_env: &calimero_storage::env::RuntimeEnv,
) -> Result<TreeNodeResponse> {
    // No local counterpart: tell the peer the node is missing.
    let Some(node) = local_node else {
        return Ok(TreeNodeResponse::not_found());
    };
    let mut collected = vec![node.clone()];
    let expand_children = clamped_depth.unwrap_or(0) > 0 && node.is_internal();
    if expand_children {
        for child_id in &node.children {
            let maybe_child = with_runtime_env(runtime_env.clone(), || {
                get_local_tree_node(context_id, child_id, false)
            })?;
            if let Some(child_node) = maybe_child {
                collected.push(child_node);
                // Stop once the response hits the size cap.
                if collected.len() >= MAX_NODES_PER_RESPONSE {
                    break;
                }
            }
        }
    }
    Ok(TreeNodeResponse::new(collected))
}
/// Looks up a node of the local Merkle index and converts it to a wire
/// `TreeNode`.
///
/// Root requests are addressed by the context id; all other requests by the
/// node's own hash id. Returns `Ok(None)` when no index entry exists or the
/// index read fails (logged, treated as missing). An entry without children
/// becomes a leaf if raw entity data is stored for it, and an empty internal
/// node otherwise.
///
/// # Errors
/// Fails only when a childless entry with stored data lacks CRDT type
/// metadata, which indicates a data-integrity problem.
fn get_local_tree_node(
    context_id: ContextId,
    node_id: &[u8; 32],
    is_root_request: bool,
) -> Result<Option<TreeNode>> {
    let entity_id = match is_root_request {
        true => Id::new(*context_id.as_ref()),
        false => Id::new(*node_id),
    };
    let index = match Index::<MainStorage>::get_index(entity_id) {
        Ok(Some(idx)) => idx,
        Ok(None) => return Ok(None),
        Err(e) => {
            // Read failures are reported but treated as "missing" so the
            // sync can proceed.
            warn!(%context_id, %entity_id, error = %e, "Failed to get index");
            return Ok(None);
        }
    };
    let full_hash = index.full_hash();
    let child_ids: Vec<[u8; 32]> = match index.children() {
        Some(children) => children.iter().map(|c| *c.id().as_bytes()).collect(),
        None => Vec::new(),
    };
    // Any node with children is internal.
    if !child_ids.is_empty() {
        return Ok(Some(TreeNode::internal(
            *entity_id.as_bytes(),
            full_hash,
            child_ids,
        )));
    }
    // Childless and no stored entity data: represent as an empty internal node.
    let Some(entry_data) = Interface::<MainStorage>::find_by_id_raw(entity_id) else {
        return Ok(Some(TreeNode::internal(
            *entity_id.as_bytes(),
            full_hash,
            vec![],
        )));
    };
    // Childless with stored data: a proper leaf. CRDT type metadata is
    // mandatory for leaves so the initiator can merge correctly.
    let crdt_type = index.metadata.crdt_type.clone().ok_or_else(|| {
        eyre::eyre!(
            "Missing CRDT type metadata for leaf entity {}: data integrity issue",
            entity_id
        )
    })?;
    let leaf_metadata = LeafMetadata::new(crdt_type, index.metadata.updated_at(), [0u8; 32]);
    let leaf_payload = TreeLeafData::new(*entity_id.as_bytes(), entry_data, leaf_metadata);
    Ok(Some(TreeNode::leaf(
        *entity_id.as_bytes(),
        full_hash,
        leaf_payload,
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    // The config carries the remote root hash through unchanged.
    #[test]
    fn test_config_creation() {
        let hash = [1u8; 32];
        let config = HashComparisonConfig {
            remote_root_hash: hash,
        };
        assert_eq!(config.remote_root_hash, hash);
    }

    // All counters start at zero.
    #[test]
    fn test_stats_default() {
        let stats = HashComparisonStats::default();
        assert_eq!(
            (stats.nodes_compared, stats.entities_merged),
            (0, 0)
        );
    }
}