use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use scirs2_core::profiling::Profiler;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// 32-byte SHA-256 digest used for leaves, internal nodes, and roots.
pub type MerkleHash = [u8; 32];
#[allow(dead_code)]
fn hash_to_hex(hash: &MerkleHash) -> String {
hex::encode(hash)
}
/// A node of the Merkle tree: either a hashed leaf bound to a data key,
/// or an internal node covering exactly two children.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MerkleNode {
    /// Leaf node: SHA-256 of the stored data plus the key it belongs to.
    Leaf { hash: MerkleHash, data_key: String },
    /// Internal node: `hash` is SHA-256 over `left.hash() || right.hash()`.
    Internal {
        hash: MerkleHash,
        left: Box<MerkleNode>,
        right: Box<MerkleNode>,
    },
}
impl MerkleNode {
pub fn hash(&self) -> &MerkleHash {
match self {
MerkleNode::Leaf { hash, .. } => hash,
MerkleNode::Internal { hash, .. } => hash,
}
}
pub fn is_leaf(&self) -> bool {
matches!(self, MerkleNode::Leaf { .. })
}
pub fn depth(&self) -> usize {
match self {
MerkleNode::Leaf { .. } => 0,
MerkleNode::Internal { left, right, .. } => {
1 + std::cmp::max(left.depth(), right.depth())
}
}
}
}
/// Concurrency-friendly Merkle tree over string-keyed string data.
///
/// Leaves live in a `BTreeMap`, so the tree layout is deterministic for a
/// given key set; the whole tree is rebuilt after every mutation. Cloning
/// the struct clones the `Arc`s, so clones share the same underlying tree.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    /// Current root node; `None` while the tree is empty.
    root: Arc<RwLock<Option<MerkleNode>>>,
    /// Leaf hashes keyed by data key (sorted, making rebuilds deterministic).
    leaves: Arc<RwLock<BTreeMap<String, MerkleHash>>>,
    /// Aggregated shape/verification/rebuild statistics.
    stats: Arc<RwLock<MerkleTreeStats>>,
    /// Total number of SHA-256 invocations performed by this tree.
    hash_counter: Arc<AtomicU64>,
    /// Cumulative wall-clock time spent in `rebuild`, in nanoseconds.
    rebuild_time_ns: Arc<AtomicU64>,
    /// Profiler handle exposed via `profiler()` / `get_profiling_report()`.
    profiler: Arc<Profiler>,
}
/// Aggregate counters describing tree shape and usage.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MerkleTreeStats {
    /// Number of leaves as of the last rebuild.
    pub leaf_count: usize,
    /// Tree depth as of the last rebuild (0 for empty or single-leaf trees).
    pub depth: usize,
    /// Total `verify` calls.
    pub total_verifications: u64,
    /// `verify` calls where the data matched the stored leaf hash.
    pub successful_verifications: u64,
    /// `verify` calls that failed (mismatch or unknown key).
    pub failed_verifications: u64,
    /// Number of times the tree has been rebuilt.
    pub total_rebuilds: u64,
}
/// Inclusion proof for a single key, checkable via `MerkleTree::verify_proof`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleProof {
    /// The key this proof was generated for.
    pub data_key: String,
    /// Hash of the data stored under `data_key`.
    pub leaf_hash: MerkleHash,
    /// Sibling hashes ordered from leaf level up to the root; the flag is
    /// `true` when the sibling sits on the LEFT of the running hash.
    pub path: Vec<(MerkleHash, bool)>,
    /// Root hash this proof commits to.
    pub root_hash: MerkleHash,
}
impl MerkleTree {
pub fn new() -> Self {
Self {
root: Arc::new(RwLock::new(None)),
leaves: Arc::new(RwLock::new(BTreeMap::new())),
stats: Arc::new(RwLock::new(MerkleTreeStats::default())),
hash_counter: Arc::new(AtomicU64::new(0)),
rebuild_time_ns: Arc::new(AtomicU64::new(0)),
profiler: Arc::new(Profiler::new()),
}
}
pub fn hash_operations(&self) -> u64 {
self.hash_counter.load(Ordering::Relaxed)
}
pub async fn average_rebuild_time_us(&self) -> f64 {
let stats = self.stats.read().await;
if stats.total_rebuilds == 0 {
return 0.0;
}
let total_ns = self.rebuild_time_ns.load(Ordering::Relaxed);
(total_ns as f64) / (stats.total_rebuilds as f64) / 1000.0
}
fn hash_data(&self, data: &str) -> MerkleHash {
self.hash_counter.fetch_add(1, Ordering::Relaxed);
let mut hasher = Sha256::new();
hasher.update(data.as_bytes());
hasher.finalize().into()
}
fn hash_nodes(&self, left: &MerkleHash, right: &MerkleHash) -> MerkleHash {
self.hash_counter.fetch_add(1, Ordering::Relaxed);
let mut hasher = Sha256::new();
hasher.update(left);
hasher.update(right);
hasher.finalize().into()
}
fn batch_hash_data(&self, items: &[(String, String)]) -> Vec<(String, MerkleHash)> {
use rayon::prelude::*;
let results: Vec<(String, MerkleHash)> = items
.par_iter()
.map(|(key, data)| {
self.hash_counter.fetch_add(1, Ordering::Relaxed);
let mut hasher = Sha256::new();
hasher.update(data.as_bytes());
let hash: MerkleHash = hasher.finalize().into();
(key.clone(), hash)
})
.collect();
results
}
pub async fn insert(&self, key: String, data: &str) {
let hash = self.hash_data(data);
let mut leaves = self.leaves.write().await;
leaves.insert(key, hash);
drop(leaves);
self.rebuild().await;
}
pub async fn insert_batch(&self, items: Vec<(String, String)>) {
if items.is_empty() {
return;
}
let hashed_items = self.batch_hash_data(&items);
let mut leaves = self.leaves.write().await;
for (key, hash) in hashed_items {
leaves.insert(key, hash);
}
drop(leaves);
self.rebuild().await;
}
pub async fn remove(&self, key: &str) {
let mut leaves = self.leaves.write().await;
leaves.remove(key);
drop(leaves);
self.rebuild().await;
}
async fn rebuild(&self) {
let start = Instant::now();
let leaves = self.leaves.read().await;
if leaves.is_empty() {
*self.root.write().await = None;
let mut stats = self.stats.write().await;
stats.leaf_count = 0;
stats.depth = 0;
stats.total_rebuilds += 1;
let elapsed_ns = start.elapsed().as_nanos() as u64;
self.rebuild_time_ns
.fetch_add(elapsed_ns, Ordering::Relaxed);
return;
}
let mut nodes: Vec<MerkleNode> = leaves
.iter()
.map(|(key, hash)| MerkleNode::Leaf {
hash: *hash,
data_key: key.clone(),
})
.collect();
const PARALLEL_THRESHOLD: usize = 256;
while nodes.len() > 1 {
let next_level = if nodes.len() >= PARALLEL_THRESHOLD {
use rayon::prelude::*;
nodes
.par_chunks(2)
.map(|chunk| {
if chunk.len() == 2 {
let hash = self.hash_nodes(chunk[0].hash(), chunk[1].hash());
MerkleNode::Internal {
hash,
left: Box::new(chunk[0].clone()),
right: Box::new(chunk[1].clone()),
}
} else {
chunk[0].clone()
}
})
.collect()
} else {
let mut next_level = Vec::new();
for chunk in nodes.chunks(2) {
if chunk.len() == 2 {
let hash = self.hash_nodes(chunk[0].hash(), chunk[1].hash());
next_level.push(MerkleNode::Internal {
hash,
left: Box::new(chunk[0].clone()),
right: Box::new(chunk[1].clone()),
});
} else {
next_level.push(chunk[0].clone());
}
}
next_level
};
nodes = next_level;
}
let root_node = nodes.into_iter().next();
let depth = root_node.as_ref().map(|n| n.depth()).unwrap_or(0);
*self.root.write().await = root_node;
let mut stats = self.stats.write().await;
stats.leaf_count = leaves.len();
stats.depth = depth;
stats.total_rebuilds += 1;
let elapsed_ns = start.elapsed().as_nanos() as u64;
self.rebuild_time_ns
.fetch_add(elapsed_ns, Ordering::Relaxed);
}
pub async fn root_hash(&self) -> Option<MerkleHash> {
self.root.read().await.as_ref().map(|node| *node.hash())
}
pub async fn verify(&self, key: &str, data: &str) -> bool {
let hash = self.hash_data(data);
let leaves = self.leaves.read().await;
let result = leaves
.get(key)
.map(|stored_hash| *stored_hash == hash)
.unwrap_or(false);
let mut stats = self.stats.write().await;
stats.total_verifications += 1;
if result {
stats.successful_verifications += 1;
} else {
stats.failed_verifications += 1;
}
result
}
pub async fn generate_proof(&self, key: &str) -> Option<MerkleProof> {
let leaves = self.leaves.read().await;
let leaf_hash = *leaves.get(key)?;
let root = self.root.read().await;
let root_node = root.as_ref()?;
let root_hash = *root_node.hash();
let path = self.find_proof_path(root_node, key);
Some(MerkleProof {
data_key: key.to_string(),
leaf_hash,
path,
root_hash,
})
}
fn find_proof_path(&self, node: &MerkleNode, key: &str) -> Vec<(MerkleHash, bool)> {
match node {
MerkleNode::Leaf { data_key, .. } => {
if data_key == key {
Vec::new()
} else {
Vec::new()
}
}
MerkleNode::Internal { left, right, .. } => {
if self.contains_key(left, key) {
let mut path = self.find_proof_path(left, key);
path.push((*right.hash(), false));
path
} else {
let mut path = self.find_proof_path(right, key);
path.push((*left.hash(), true));
path
}
}
}
}
fn contains_key(&self, node: &MerkleNode, key: &str) -> bool {
match node {
MerkleNode::Leaf { data_key, .. } => data_key == key,
MerkleNode::Internal { left, right, .. } => {
self.contains_key(left, key) || self.contains_key(right, key)
}
}
}
pub fn verify_proof(&self, proof: &MerkleProof, data: &str) -> bool {
let computed_hash = self.hash_data(data);
if computed_hash != proof.leaf_hash {
return false;
}
let mut current_hash = proof.leaf_hash;
for (sibling_hash, is_left_sibling) in &proof.path {
current_hash = if *is_left_sibling {
self.hash_nodes(sibling_hash, ¤t_hash)
} else {
self.hash_nodes(¤t_hash, sibling_hash)
};
}
current_hash == proof.root_hash
}
pub async fn compare(&self, other: &MerkleTree) -> MerkleComparison {
let our_root = self.root_hash().await;
let their_root = other.root_hash().await;
if our_root == their_root {
return MerkleComparison::Identical;
}
let our_leaves = self.leaves.read().await;
let their_leaves = other.leaves.read().await;
let mut missing_keys = Vec::new();
let mut extra_keys = Vec::new();
let mut conflicting_keys = Vec::new();
for key in our_leaves.keys() {
if !their_leaves.contains_key(key) {
extra_keys.push(key.clone());
}
}
for (key, their_hash) in their_leaves.iter() {
if let Some(our_hash) = our_leaves.get(key) {
if our_hash != their_hash {
conflicting_keys.push(key.clone());
}
} else {
missing_keys.push(key.clone());
}
}
MerkleComparison::Different {
missing_keys,
extra_keys,
conflicting_keys,
}
}
pub async fn get_stats(&self) -> MerkleTreeStats {
self.stats.read().await.clone()
}
pub fn get_profiling_report(&self) -> String {
self.profiler.get_report()
}
pub fn profiler(&self) -> &Profiler {
&self.profiler
}
pub async fn get_keys(&self) -> Vec<String> {
self.leaves.read().await.keys().cloned().collect()
}
pub async fn len(&self) -> usize {
self.leaves.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.leaves.read().await.is_empty()
}
pub async fn clear(&self) {
self.leaves.write().await.clear();
*self.root.write().await = None;
let mut stats = self.stats.write().await;
stats.leaf_count = 0;
stats.depth = 0;
}
}
impl Default for MerkleTree {
fn default() -> Self {
Self::new()
}
}
/// Result of `MerkleTree::compare` between two trees.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MerkleComparison {
    /// Both trees have the same root hash.
    Identical,
    /// The trees diverge; key lists are relative to the tree `compare`
    /// was called on (`extra_keys` only here, `missing_keys` only there,
    /// `conflicting_keys` in both with different hashes).
    Different {
        missing_keys: Vec<String>,
        extra_keys: Vec<String>,
        conflicting_keys: Vec<String>,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A fresh tree is empty and has no root hash.
    #[tokio::test]
    async fn test_merkle_tree_creation() {
        let tree = MerkleTree::new();
        assert!(tree.is_empty().await);
        assert_eq!(tree.len().await, 0);
        assert!(tree.root_hash().await.is_none());
    }
    /// Inserted keys verify against their data; wrong data fails.
    #[tokio::test]
    async fn test_insert_and_verify() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        assert_eq!(tree.len().await, 2);
        assert!(tree.root_hash().await.is_some());
        assert!(tree.verify("key1", "value1").await);
        assert!(tree.verify("key2", "value2").await);
        assert!(!tree.verify("key1", "wrong_value").await);
    }
    /// Removing a key shrinks the tree and invalidates its verification.
    #[tokio::test]
    async fn test_remove() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        assert_eq!(tree.len().await, 2);
        tree.remove("key1").await;
        assert_eq!(tree.len().await, 1);
        assert!(!tree.verify("key1", "value1").await);
        assert!(tree.verify("key2", "value2").await);
    }
    /// Every mutation changes the root hash.
    #[tokio::test]
    async fn test_root_hash_changes() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        let hash1 = tree.root_hash().await;
        tree.insert("key2".to_string(), "value2").await;
        let hash2 = tree.root_hash().await;
        assert_ne!(hash1, hash2);
    }
    /// A generated proof verifies the right data and rejects wrong data.
    #[tokio::test]
    async fn test_merkle_proof() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        tree.insert("key3".to_string(), "value3").await;
        let proof = tree.generate_proof("key2").await;
        assert!(proof.is_some());
        let proof = proof.unwrap();
        assert_eq!(proof.data_key, "key2");
        assert!(tree.verify_proof(&proof, "value2"));
        assert!(!tree.verify_proof(&proof, "wrong_value"));
    }
    /// Trees with identical content compare as Identical.
    #[tokio::test]
    async fn test_compare_identical_trees() {
        let tree1 = MerkleTree::new();
        let tree2 = MerkleTree::new();
        tree1.insert("key1".to_string(), "value1").await;
        tree1.insert("key2".to_string(), "value2").await;
        tree2.insert("key1".to_string(), "value1").await;
        tree2.insert("key2".to_string(), "value2").await;
        let comparison = tree1.compare(&tree2).await;
        assert_eq!(comparison, MerkleComparison::Identical);
    }
    /// Disjoint keys surface as missing (theirs only) / extra (ours only).
    #[tokio::test]
    async fn test_compare_different_trees() {
        let tree1 = MerkleTree::new();
        let tree2 = MerkleTree::new();
        tree1.insert("key1".to_string(), "value1").await;
        tree1.insert("key2".to_string(), "value2").await;
        tree2.insert("key2".to_string(), "value2").await;
        tree2.insert("key3".to_string(), "value3").await;
        let comparison = tree1.compare(&tree2).await;
        match comparison {
            MerkleComparison::Different {
                missing_keys,
                extra_keys,
                conflicting_keys,
            } => {
                assert_eq!(missing_keys, vec!["key3"]);
                assert_eq!(extra_keys, vec!["key1"]);
                assert!(conflicting_keys.is_empty());
            }
            _ => panic!("Expected different trees"),
        }
    }
    /// Same key with different data surfaces as a conflict.
    #[tokio::test]
    async fn test_compare_conflicting_trees() {
        let tree1 = MerkleTree::new();
        let tree2 = MerkleTree::new();
        tree1.insert("key1".to_string(), "value1").await;
        tree2.insert("key1".to_string(), "different_value").await;
        let comparison = tree1.compare(&tree2).await;
        match comparison {
            MerkleComparison::Different {
                missing_keys,
                extra_keys,
                conflicting_keys,
            } => {
                assert!(missing_keys.is_empty());
                assert!(extra_keys.is_empty());
                assert_eq!(conflicting_keys, vec!["key1"]);
            }
            _ => panic!("Expected different trees"),
        }
    }
    /// Verification outcomes and rebuild counts are recorded in stats.
    #[tokio::test]
    async fn test_stats_tracking() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        tree.verify("key1", "value1").await;
        tree.verify("key2", "wrong_value").await;
        let stats = tree.get_stats().await;
        assert_eq!(stats.leaf_count, 2);
        assert_eq!(stats.total_verifications, 2);
        assert_eq!(stats.successful_verifications, 1);
        assert_eq!(stats.failed_verifications, 1);
        assert!(stats.total_rebuilds > 0);
    }
    /// Clearing empties the tree and drops the root.
    #[tokio::test]
    async fn test_clear() {
        let tree = MerkleTree::new();
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        assert_eq!(tree.len().await, 2);
        tree.clear().await;
        assert_eq!(tree.len().await, 0);
        assert!(tree.is_empty().await);
        assert!(tree.root_hash().await.is_none());
    }
    /// 100 sequential inserts all remain verifiable; depth is positive.
    #[tokio::test]
    async fn test_large_tree() {
        let tree = MerkleTree::new();
        for i in 0..100 {
            tree.insert(format!("key{}", i), &format!("value{}", i))
                .await;
        }
        assert_eq!(tree.len().await, 100);
        let stats = tree.get_stats().await;
        assert_eq!(stats.leaf_count, 100);
        assert!(stats.depth > 0);
        for i in 0..100 {
            assert!(
                tree.verify(&format!("key{}", i), &format!("value{}", i))
                    .await
            );
        }
    }
    /// Batch insert produces the same verifiable state as singles would.
    #[tokio::test]
    async fn test_batch_insert() {
        let tree = MerkleTree::new();
        let items: Vec<(String, String)> = (0..50)
            .map(|i| (format!("batch_key{}", i), format!("batch_value{}", i)))
            .collect();
        tree.insert_batch(items).await;
        assert_eq!(tree.len().await, 50);
        for i in 0..50 {
            assert!(
                tree.verify(&format!("batch_key{}", i), &format!("batch_value{}", i))
                    .await
            );
        }
    }
    /// The hash-operation counter starts at zero and grows on insert/verify.
    #[tokio::test]
    async fn test_hash_operation_metrics() {
        let tree = MerkleTree::new();
        assert_eq!(tree.hash_operations(), 0);
        tree.insert("key1".to_string(), "value1").await;
        tree.insert("key2".to_string(), "value2").await;
        tree.insert("key3".to_string(), "value3").await;
        let hash_ops = tree.hash_operations();
        assert!(hash_ops > 0, "Hash operations should be tracked");
        tree.verify("key1", "value1").await;
        assert!(tree.hash_operations() > hash_ops);
    }
    /// Average rebuild time becomes positive after some rebuilds.
    #[tokio::test]
    async fn test_rebuild_time_metrics() {
        let tree = MerkleTree::new();
        for i in 0..10 {
            tree.insert(format!("key{}", i), &format!("value{}", i))
                .await;
        }
        let avg_rebuild_time = tree.average_rebuild_time_us().await;
        assert!(avg_rebuild_time > 0.0, "Rebuild time should be tracked");
        let stats = tree.get_stats().await;
        assert!(stats.total_rebuilds > 0);
    }
    /// Informal timing comparison of batch vs sequential insertion; only
    /// the final tree sizes are asserted, the speedup is just printed.
    #[tokio::test]
    async fn test_batch_vs_sequential_performance() {
        use std::time::Instant;
        let tree_seq = MerkleTree::new();
        let start_seq = Instant::now();
        for i in 0..100 {
            tree_seq
                .insert(format!("seq_key{}", i), &format!("seq_value{}", i))
                .await;
        }
        let seq_duration = start_seq.elapsed();
        let tree_batch = MerkleTree::new();
        let items: Vec<(String, String)> = (0..100)
            .map(|i| (format!("batch_key{}", i), format!("batch_value{}", i)))
            .collect();
        let start_batch = Instant::now();
        tree_batch.insert_batch(items).await;
        let batch_duration = start_batch.elapsed();
        assert_eq!(tree_seq.len().await, 100);
        assert_eq!(tree_batch.len().await, 100);
        println!(
            "Sequential: {:?}, Batch: {:?}, Speedup: {:.2}x",
            seq_duration,
            batch_duration,
            seq_duration.as_secs_f64() / batch_duration.as_secs_f64()
        );
    }
}