use std::fmt;
use std::sync::Arc;
use calimero_node::sync::metrics::SyncMetricsCollector;
use super::metrics::SimMetrics;
use super::metrics_adapter::SimMetricsCollector;
use super::node::SimNode;
use super::scenarios::Scenario;
use super::sim_runtime::SimRuntime;
/// One row of benchmark output: what a single scenario cost to sync.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Human-readable scenario label (e.g. `"same_state"`).
    pub scenario: String,
    /// `Debug` rendering of the selected sync protocol kind.
    pub protocol: String,
    /// Number of protocol round trips observed.
    pub round_trips: u64,
    /// Number of entities transferred between nodes.
    pub entities_transferred: u64,
    /// Total payload bytes transferred.
    pub bytes_transferred: u64,
    /// Number of merges performed.
    pub merges: u64,
    /// Convergence time in milliseconds (`from_metrics` stores `0` when no time was recorded).
    pub time_to_converge_ms: u64,
    /// Whether the simulation reached convergence.
    pub converged: bool,
}

impl fmt::Display for BenchmarkResult {
    /// Writes a single fixed-width table row.
    ///
    /// Scenario and protocol are left-aligned in 25- and 20-column fields (longer
    /// values are not truncated), numeric counters are right-aligned, and the row
    /// ends with `[OK]` for converged runs or `[FAIL]` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            scenario,
            protocol,
            round_trips,
            entities_transferred,
            bytes_transferred,
            merges,
            time_to_converge_ms,
            converged,
        } = self;

        let status = match *converged {
            true => "OK",
            false => "FAIL",
        };

        write!(
            f,
            "{:<25} {:<20} RT:{:>4} Ent:{:>6} Bytes:{:>8} Merges:{:>6} Time:{:>6}ms [{}]",
            scenario,
            protocol,
            round_trips,
            entities_transferred,
            bytes_transferred,
            merges,
            time_to_converge_ms,
            status,
        )
    }
}
impl BenchmarkResult {
    /// Builds a result row from a metrics snapshot.
    ///
    /// Counters are copied from `metrics.protocol` (`payload_bytes` becomes
    /// `bytes_transferred`, `merges_performed` becomes `merges`). The convergence
    /// time is converted with `as_millis()` and falls back to `0` when the
    /// snapshot holds no `time_to_converge`.
    pub fn from_metrics(
        scenario: impl Into<String>,
        protocol: String,
        metrics: &SimMetrics,
        converged: bool,
    ) -> Self {
        let counters = &metrics.protocol;
        let time_to_converge_ms = metrics
            .convergence
            .time_to_converge
            .map_or(0, |elapsed| elapsed.as_millis());

        Self {
            scenario: scenario.into(),
            protocol,
            round_trips: counters.round_trips,
            entities_transferred: counters.entities_transferred,
            bytes_transferred: counters.payload_bytes,
            merges: counters.merges_performed,
            time_to_converge_ms,
            converged,
        }
    }
}
/// Runs a single two-node benchmark and returns its result row.
///
/// The protocol is chosen up front from both nodes' handshakes; its `Debug`
/// name becomes the row's `protocol`. Both nodes are then added to a
/// [`SimRuntime`] seeded with `42` and simulated via `run_until_converged`.
/// The runtime's aggregate metrics are replayed into a [`SimMetricsCollector`],
/// and the row is built from that collector's snapshot.
pub fn run_two_node_benchmark(
    scenario_name: impl Into<String>,
    mut node_a: SimNode,
    mut node_b: SimNode,
) -> BenchmarkResult {
    use calimero_node_primitives::sync::select_protocol;

    // Handshakes are built while we still own the nodes (they move into the runtime below).
    let (hs_a, hs_b) = (node_a.build_handshake(), node_b.build_handshake());
    let protocol_name = format!("{:?}", select_protocol(&hs_a, &hs_b).protocol.kind());

    let collector = Arc::new(SimMetricsCollector::new());

    let mut runtime = SimRuntime::new(42);
    runtime.add_existing_node(node_a);
    runtime.add_existing_node(node_b);
    let converged = runtime.run_until_converged();

    record_simulation_metrics(&*collector, runtime.metrics(), &protocol_name);
    let snapshot = collector.snapshot();

    BenchmarkResult::from_metrics(scenario_name, protocol_name, &snapshot, converged)
}
/// Replays aggregate simulation counters into an event-driven [`SyncMetricsCollector`].
///
/// Each total in `sim_metrics` is re-emitted as that many individual events, in
/// this order: messages, round trips, entities transferred, merges, comparisons,
/// buffer drops. Payload bytes are split evenly across `messages_sent` messages,
/// and the division remainder is attached to the final message so the recorded
/// sizes sum to `payload_bytes`. Merges are recorded with the type label
/// `"unknown"`. A sync-complete event (context `"benchmark"`) is emitted only
/// when the simulation reports convergence.
fn record_simulation_metrics(
    collector: &dyn SyncMetricsCollector,
    sim_metrics: &SimMetrics,
    protocol: &str,
) {
    let counters = &sim_metrics.protocol;

    if counters.messages_sent > 0 {
        let share = (counters.payload_bytes / counters.messages_sent) as usize;
        let leftover = (counters.payload_bytes % counters.messages_sent) as usize;
        let last = counters.messages_sent - 1;
        for _ in 0..last {
            collector.record_message_sent(protocol, share);
        }
        // The final message carries the remainder so no bytes are lost to integer division.
        collector.record_message_sent(protocol, share + leftover);
    }

    for _ in 0..counters.round_trips {
        collector.record_round_trip(protocol);
    }

    collector.record_entities_transferred(counters.entities_transferred as usize);

    for _ in 0..counters.merges_performed {
        // The simulator does not track CRDT types, so merges carry a placeholder label.
        collector.record_merge("unknown");
    }

    for _ in 0..counters.entities_compared {
        collector.record_comparison();
    }

    for _ in 0..sim_metrics.effects.buffer_drops {
        collector.record_buffer_drop();
    }

    let convergence = &sim_metrics.convergence;
    if !convergence.converged {
        return;
    }

    let elapsed = convergence.time_to_converge.map_or_else(
        std::time::Duration::default,
        |t| std::time::Duration::from_micros(t.as_micros()),
    );
    collector.record_sync_complete(
        "benchmark",
        protocol,
        elapsed,
        counters.entities_transferred as usize,
    );
}
/// Running aggregate over a sequence of [`BenchmarkResult`]s.
#[derive(Debug, Default)]
pub struct BenchmarkSummary {
    /// Number of results added.
    pub total: usize,
    /// Number of added results that converged.
    pub converged: usize,
    /// Converged result with the fewest round trips (first one wins on ties).
    pub lowest_round_trips: Option<BenchmarkResult>,
    /// Result with the most bytes transferred, converged or not (first one wins on ties).
    pub highest_bandwidth: Option<BenchmarkResult>,
    /// Converged result with the smallest non-zero convergence time (first one wins on ties).
    pub fastest_convergence: Option<BenchmarkResult>,
}

impl BenchmarkSummary {
    /// Folds one result into the summary, updating the counters and any "best" slot it beats.
    ///
    /// Results with `time_to_converge_ms == 0` never become `fastest_convergence`.
    pub fn add(&mut self, result: BenchmarkResult) {
        self.total += 1;

        if result.converged {
            self.converged += 1;

            let fewer_round_trips = match &self.lowest_round_trips {
                None => true,
                Some(best) => result.round_trips < best.round_trips,
            };
            if fewer_round_trips {
                self.lowest_round_trips = Some(result.clone());
            }
        }

        let more_bytes = match &self.highest_bandwidth {
            None => true,
            Some(best) => result.bytes_transferred > best.bytes_transferred,
        };
        if more_bytes {
            self.highest_bandwidth = Some(result.clone());
        }

        // A zero time is excluded from the "fastest" slot.
        let quicker = result.converged
            && result.time_to_converge_ms > 0
            && match &self.fastest_convergence {
                None => true,
                Some(best) => result.time_to_converge_ms < best.time_to_converge_ms,
            };
        if quicker {
            self.fastest_convergence = Some(result);
        }
    }
}
impl fmt::Display for BenchmarkSummary {
    /// Writes a multi-line report: a header and the totals line, then one line for
    /// each "best" slot that is populated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total,
            converged,
            lowest_round_trips,
            highest_bandwidth,
            fastest_convergence,
        } = self;

        writeln!(f, "=== Benchmark Summary ===")?;
        writeln!(f, "Total: {} / Converged: {}", total, converged)?;

        if let Some(row) = lowest_round_trips {
            writeln!(
                f,
                "Most efficient (lowest RT): {} ({} round trips)",
                row.scenario, row.round_trips
            )?;
        }
        if let Some(row) = highest_bandwidth {
            writeln!(
                f,
                "Highest bandwidth: {} ({} bytes)",
                row.scenario, row.bytes_transferred
            )?;
        }
        if let Some(row) = fastest_convergence {
            writeln!(
                f,
                "Fastest convergence: {} ({}ms)",
                row.scenario, row.time_to_converge_ms
            )?;
        }

        Ok(())
    }
}
/// Runs every canned two-node scenario in a fixed order.
///
/// All scenario node pairs are constructed first, then each pair is benchmarked.
/// Returns the per-scenario rows in that order, plus a summary built from them.
pub fn run_all_benchmarks() -> (Vec<BenchmarkResult>, BenchmarkSummary) {
    let scenarios: Vec<(&'static str, (SimNode, SimNode))> = vec![
        ("same_state", Scenario::force_none()),
        ("fresh_bootstrap", Scenario::force_snapshot()),
        ("high_divergence", Scenario::force_hash_high_divergence()),
        ("partial_overlap", Scenario::partial_overlap()),
        ("deep_tree_localized", Scenario::force_subtree_prefetch()),
        ("wide_shallow", Scenario::force_levelwise()),
        ("delta_sync", Scenario::force_delta_sync()),
        ("bloom_filter", Scenario::force_bloom_filter()),
    ];

    let mut summary = BenchmarkSummary::default();
    let results: Vec<BenchmarkResult> = scenarios
        .into_iter()
        .map(|(name, (node_a, node_b))| {
            let row = run_two_node_benchmark(name, node_a, node_b);
            summary.add(row.clone());
            row
        })
        .collect();

    (results, summary)
}
/// Benchmarks a diverged two-node setup at each requested entity count.
///
/// For each `count`, node `"a"` is filled with `generate_entities(count / 2, 1)`
/// and node `"b"` with `generate_entities(count, 2)`. The second argument is
/// presumably a generator seed; confirm against `generate_entities`. Each row is
/// named `diverged_{count}`, and rows are returned in input order along with
/// their summary.
pub fn run_scaling_benchmarks(entity_counts: &[usize]) -> (Vec<BenchmarkResult>, BenchmarkSummary) {
    use super::scenarios::deterministic::generate_entities;

    let mut summary = BenchmarkSummary::default();
    let results: Vec<BenchmarkResult> = entity_counts
        .iter()
        .map(|&count| {
            let mut node_a = SimNode::new("a");
            let mut node_b = SimNode::new("b");

            for (id, data, metadata) in generate_entities(count / 2, 1) {
                node_a.insert_entity_with_metadata(id, data, metadata);
            }
            for (id, data, metadata) in generate_entities(count, 2) {
                node_b.insert_entity_with_metadata(id, data, metadata);
            }

            let row = run_two_node_benchmark(format!("diverged_{count}"), node_a, node_b);
            summary.add(row.clone());
            row
        })
        .collect();

    (results, summary)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every canned scenario, prints the table, and checks the summary is consistent.
    #[test]
    fn benchmark_all_scenarios() {
        println!("\n=== Sync Protocol Benchmarks ===\n");
        let (results, summary) = run_all_benchmarks();
        for result in &results {
            println!("{result}");
        }
        println!();
        println!("{summary}");
        // `run_all_benchmarks` feeds every result into the summary exactly once.
        assert_eq!(
            summary.total,
            results.len(),
            "Summary total must match the number of results"
        );
        assert!(
            summary.converged > 0,
            "Expected at least some benchmarks to converge"
        );
    }

    /// Runs the diverged-state benchmark at several sizes; expects one named row per size.
    #[test]
    fn benchmark_scaling() {
        println!("\n=== Scaling Benchmark ===\n");
        let entity_counts = [10, 50, 100, 200, 500];
        let (results, summary) = run_scaling_benchmarks(&entity_counts);
        for result in &results {
            println!("{result}");
        }
        println!();
        println!("{summary}");
        assert_eq!(results.len(), entity_counts.len());
        assert_eq!(summary.total, entity_counts.len());
        for (result, count) in results.iter().zip(entity_counts) {
            assert_eq!(result.scenario, format!("diverged_{count}"));
        }
    }

    #[test]
    fn test_same_state_uses_none_protocol() {
        let (node_a, node_b) = Scenario::force_none();
        let result = run_two_node_benchmark("same_state", node_a, node_b);
        assert!(
            result.protocol.contains("None"),
            "Expected None protocol for same_state, got {}",
            result.protocol
        );
        assert!(result.converged);
    }

    #[test]
    fn test_fresh_bootstrap_uses_snapshot() {
        let (fresh, source) = Scenario::force_snapshot();
        let result = run_two_node_benchmark("fresh_bootstrap", fresh, source);
        assert!(
            result.protocol.contains("Snapshot"),
            "Expected Snapshot protocol for fresh_bootstrap, got {}",
            result.protocol
        );
    }

    #[test]
    fn test_high_divergence_uses_hash_comparison() {
        let (node_a, node_b) = Scenario::force_hash_high_divergence();
        let result = run_two_node_benchmark("high_divergence", node_a, node_b);
        assert!(
            result.protocol.contains("HashComparison"),
            "Expected HashComparison protocol for high_divergence, got {}",
            result.protocol
        );
    }

    #[test]
    fn test_benchmark_result_display() {
        let result = BenchmarkResult {
            scenario: "test_scenario".to_string(),
            protocol: "HashComparison".to_string(),
            round_trips: 5,
            entities_transferred: 100,
            bytes_transferred: 10240,
            merges: 50,
            time_to_converge_ms: 150,
            converged: true,
        };
        let display = result.to_string();
        assert!(display.contains("test_scenario"));
        assert!(display.contains("HashComparison"));
        assert!(display.contains("OK"));
    }

    /// Replaying simulator totals through the collector must reproduce every counter,
    /// including the payload byte total that `record_simulation_metrics` splits across messages.
    #[test]
    fn test_sync_metrics_collector_integration() {
        use calimero_node_primitives::sync::select_protocol;

        let (mut node_a, mut node_b) = Scenario::force_hash_high_divergence();
        let handshake_a = node_a.build_handshake();
        let handshake_b = node_b.build_handshake();
        let selection = select_protocol(&handshake_a, &handshake_b);
        let protocol_name = format!("{:?}", selection.protocol.kind());
        let collector = Arc::new(SimMetricsCollector::new());
        let mut rt = SimRuntime::new(42);
        rt.add_existing_node(node_a);
        rt.add_existing_node(node_b);
        let converged = rt.run_until_converged();
        let sim_metrics = rt.metrics().clone();
        record_simulation_metrics(&*collector, &sim_metrics, &protocol_name);
        let collected = collector.snapshot();
        assert_eq!(
            collected.protocol.messages_sent, sim_metrics.protocol.messages_sent,
            "Message count mismatch"
        );
        // Per-message sizes (base share + remainder on the last) must sum to the original total.
        assert_eq!(
            collected.protocol.payload_bytes, sim_metrics.protocol.payload_bytes,
            "Payload bytes mismatch"
        );
        assert_eq!(
            collected.protocol.round_trips, sim_metrics.protocol.round_trips,
            "Round trip count mismatch"
        );
        assert_eq!(
            collected.protocol.entities_transferred, sim_metrics.protocol.entities_transferred,
            "Entities transferred mismatch"
        );
        assert_eq!(
            collected.protocol.merges_performed, sim_metrics.protocol.merges_performed,
            "Merges performed mismatch"
        );
        assert_eq!(
            collected.protocol.entities_compared, sim_metrics.protocol.entities_compared,
            "Entities compared mismatch"
        );
        assert_eq!(
            collected.effects.buffer_drops, sim_metrics.effects.buffer_drops,
            "Buffer drops mismatch"
        );
        println!("SyncMetricsCollector integration test passed!");
        println!("  Protocol: {protocol_name}");
        println!("  Converged: {converged}");
        println!("  Messages: {}", collected.protocol.messages_sent);
        println!("  Round trips: {}", collected.protocol.round_trips);
        println!("  Entities: {}", collected.protocol.entities_transferred);
        println!("  Merges: {}", collected.protocol.merges_performed);
    }

    /// Exercises every trait method through `&dyn SyncMetricsCollector`.
    #[test]
    fn test_trait_object_usage() {
        let collector = SimMetricsCollector::new();
        fn record_via_trait(metrics: &dyn SyncMetricsCollector) {
            metrics.record_message_sent("TestProtocol", 1024);
            metrics.record_message_sent("TestProtocol", 2048);
            metrics.record_round_trip("TestProtocol");
            metrics.record_entities_transferred(5);
            metrics.record_merge("GCounter");
            metrics.record_comparison();
            metrics.record_buffer_drop();
            let timer = metrics.start_phase("test_phase");
            std::thread::sleep(std::time::Duration::from_millis(1));
            metrics.record_phase_complete(timer);
            metrics.record_sync_start("ctx-123", "TestProtocol", "manual");
            metrics.record_sync_complete(
                "ctx-123",
                "TestProtocol",
                std::time::Duration::from_millis(100),
                5,
            );
            metrics.record_protocol_selected("TestProtocol", "test", 0.5);
        }
        record_via_trait(&collector);
        let metrics = collector.snapshot();
        assert_eq!(metrics.protocol.messages_sent, 2);
        assert_eq!(metrics.protocol.payload_bytes, 3072);
        assert_eq!(metrics.protocol.round_trips, 1);
        assert_eq!(metrics.protocol.entities_transferred, 5);
        assert_eq!(metrics.protocol.merges_performed, 1);
        assert_eq!(metrics.protocol.entities_compared, 1);
        assert_eq!(metrics.effects.buffer_drops, 1);
        println!("Trait object usage test passed!");
        println!("  All metrics recorded correctly through &dyn SyncMetricsCollector");
    }
}