use rayon::prelude::*;
use crate::aligned::AlignedHashArray;
use crate::arena::BatchArena;
use crate::crypto::{hash_pair, Hash};
use crate::event::AuditEvent;
const PARALLEL_THRESHOLD: usize = 64;
const BATCH_THRESHOLD: usize = 8;
pub fn compute_root_optimized(events: &[AuditEvent]) -> Hash {
if events.is_empty() {
return Hash::ZERO;
}
if events.len() == 1 {
return events[0].id().0;
}
if events.len() < PARALLEL_THRESHOLD {
return compute_root_sequential(events);
}
let arena = BatchArena::for_hashes(events.len() * 2);
compute_root_with_arena(events, &arena)
}
pub fn compute_root_with_arena(events: &[AuditEvent], arena: &BatchArena) -> Hash {
if events.is_empty() {
return Hash::ZERO;
}
if events.len() == 1 {
return events[0].id().0;
}
let leaf_hashes: Vec<Hash> = events.par_iter().map(|e| e.id().0).collect();
let padded_len = leaf_hashes.len().next_power_of_two();
let mut current_level = arena.alloc_vec_with_capacity(padded_len);
current_level.extend(leaf_hashes.iter().copied());
if let Some(last) = current_level.last().copied() {
while current_level.len() < padded_len {
current_level.push(last);
}
}
while current_level.len() > 1 {
let pairs = current_level.len() / 2;
if pairs >= PARALLEL_THRESHOLD {
let next_level: Vec<Hash> = current_level
.chunks(2)
.collect::<Vec<_>>()
.par_iter()
.map(|pair| hash_pair(pair[0], pair[1]))
.collect();
current_level.clear();
current_level.extend(next_level);
} else if pairs >= BATCH_THRESHOLD {
let mut next_level = arena.alloc_vec_with_capacity(pairs);
let chunks = current_level.chunks_exact(8);
let remainder = chunks.remainder();
for chunk in chunks {
let h0 = hash_pair(chunk[0], chunk[1]);
let h1 = hash_pair(chunk[2], chunk[3]);
let h2 = hash_pair(chunk[4], chunk[5]);
let h3 = hash_pair(chunk[6], chunk[7]);
next_level.extend([h0, h1, h2, h3]);
}
for pair in remainder.chunks(2) {
if pair.len() == 2 {
next_level.push(hash_pair(pair[0], pair[1]));
}
}
current_level = next_level;
} else {
let mut next_level = arena.alloc_vec_with_capacity(pairs);
for pair in current_level.chunks(2) {
next_level.push(hash_pair(pair[0], pair[1]));
}
current_level = next_level;
}
}
current_level.first().copied().unwrap_or(Hash::ZERO)
}
fn compute_root_sequential(events: &[AuditEvent]) -> Hash {
if events.is_empty() {
return Hash::ZERO;
}
let mut hashes: Vec<Hash> = events.iter().map(|e| e.id().0).collect();
let target_len = hashes.len().next_power_of_two();
if let Some(last) = hashes.last().copied() {
while hashes.len() < target_len {
hashes.push(last);
}
}
while hashes.len() > 1 {
let mut next = Vec::with_capacity(hashes.len() / 2);
for pair in hashes.chunks(2) {
next.push(hash_pair(pair[0], pair[1]));
}
hashes = next;
}
hashes.first().copied().unwrap_or(Hash::ZERO)
}
pub fn compute_roots_batch(event_sets: &[&[AuditEvent]]) -> Vec<Hash> {
if event_sets.is_empty() {
return Vec::new();
}
event_sets
.par_iter()
.map(|events| {
if events.len() < PARALLEL_THRESHOLD {
compute_root_sequential(events)
} else {
let arena = BatchArena::for_hashes(events.len() * 2);
compute_root_with_arena(events, &arena)
}
})
.collect()
}
pub fn compute_proof(events: &[AuditEvent], index: usize) -> Option<Vec<Hash>> {
if index >= events.len() || events.is_empty() {
return None;
}
let mut hashes: Vec<Hash> = events.iter().map(|e| e.id().0).collect();
let target_len = hashes.len().next_power_of_two();
if let Some(last) = hashes.last().copied() {
while hashes.len() < target_len {
hashes.push(last);
}
}
let mut proof = Vec::new();
let mut idx = index;
while hashes.len() > 1 {
let sibling_idx = if idx % 2 == 0 { idx + 1 } else { idx - 1 };
if sibling_idx < hashes.len() {
proof.push(hashes[sibling_idx]);
}
let mut next = Vec::with_capacity(hashes.len() / 2);
for pair in hashes.chunks(2) {
next.push(hash_pair(pair[0], pair[1]));
}
hashes = next;
idx /= 2;
}
Some(proof)
}
pub fn verify_proof(leaf: Hash, proof: &[Hash], index: usize, root: Hash) -> bool {
let mut current = leaf;
let mut idx = index;
for sibling in proof {
if idx % 2 == 0 {
current = hash_pair(current, *sibling);
} else {
current = hash_pair(*sibling, current);
}
idx /= 2;
}
current == root
}
#[allow(dead_code)]
pub struct MerkleTreeBuffer {
level_0: Vec<AlignedHashArray<64>>,
level_1: Vec<AlignedHashArray<32>>,
level_2: Vec<AlignedHashArray<16>>,
level_3: Vec<AlignedHashArray<8>>,
max_leaves: usize,
}
impl MerkleTreeBuffer {
pub fn new(max_leaves: usize) -> Self {
let l0_arrays = max_leaves.div_ceil(64);
let l1_arrays = (max_leaves / 2).div_ceil(32);
let l2_arrays = (max_leaves / 4).div_ceil(16);
let l3_arrays = (max_leaves / 8).div_ceil(8);
Self {
level_0: vec![AlignedHashArray::new(); l0_arrays],
level_1: vec![AlignedHashArray::new(); l1_arrays],
level_2: vec![AlignedHashArray::new(); l2_arrays],
level_3: vec![AlignedHashArray::new(); l3_arrays],
max_leaves,
}
}
pub fn max_leaves(&self) -> usize {
self.max_leaves
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::SecretKey;
use crate::event::{ActorId, ActorKind, EventType, Outcome, ResourceId, ResourceKind};
fn create_test_event(key: &SecretKey, n: u32) -> AuditEvent {
let actor = ActorId::new(key.public_key(), ActorKind::User);
let resource = ResourceId::new(ResourceKind::Repository, format!("repo-{}", n));
AuditEvent::builder()
.now()
.event_type(EventType::Push {
force: false,
commits: n,
})
.actor(actor)
.resource(resource)
.outcome(Outcome::Success)
.sign(key)
.unwrap()
}
#[test]
fn test_merkle_root_empty() {
let events: Vec<AuditEvent> = vec![];
assert_eq!(compute_root_optimized(&events), Hash::ZERO);
}
#[test]
fn test_merkle_root_single() {
let key = SecretKey::generate();
let events = vec![create_test_event(&key, 0)];
let root = compute_root_optimized(&events);
assert_eq!(root, events[0].id().0);
}
#[test]
fn test_merkle_root_multiple() {
let key = SecretKey::generate();
let events: Vec<_> = (0..10).map(|i| create_test_event(&key, i)).collect();
let root = compute_root_optimized(&events);
let root_sequential = compute_root_sequential(&events);
assert_eq!(root, root_sequential);
}
#[test]
fn test_merkle_root_matches_block_impl() {
let key = SecretKey::generate();
let events: Vec<_> = (0..100).map(|i| create_test_event(&key, i)).collect();
let optimized_root = compute_root_optimized(&events);
let block_root = crate::block::compute_events_root(&events);
assert_eq!(optimized_root, block_root);
}
#[test]
fn test_merkle_proof_verify() {
let key = SecretKey::generate();
let events: Vec<_> = (0..16).map(|i| create_test_event(&key, i)).collect();
let root = compute_root_optimized(&events);
for i in 0..events.len() {
let proof = compute_proof(&events, i).expect("proof should exist");
let leaf = events[i].id().0;
assert!(
verify_proof(leaf, &proof, i, root),
"Proof verification failed for index {}",
i
);
}
}
#[test]
fn test_merkle_proof_invalid() {
let key = SecretKey::generate();
let events: Vec<_> = (0..8).map(|i| create_test_event(&key, i)).collect();
let root = compute_root_optimized(&events);
let proof = compute_proof(&events, 0).expect("proof should exist");
let wrong_leaf = events[1].id().0;
assert!(!verify_proof(wrong_leaf, &proof, 0, root));
}
#[test]
fn test_batch_roots() {
let key = SecretKey::generate();
let events1: Vec<_> = (0..10).map(|i| create_test_event(&key, i)).collect();
let events2: Vec<_> = (10..30).map(|i| create_test_event(&key, i)).collect();
let events3: Vec<_> = (30..35).map(|i| create_test_event(&key, i)).collect();
let event_sets: Vec<&[AuditEvent]> = vec![&events1, &events2, &events3];
let roots = compute_roots_batch(&event_sets);
assert_eq!(roots.len(), 3);
assert_eq!(roots[0], compute_root_optimized(&events1));
assert_eq!(roots[1], compute_root_optimized(&events2));
assert_eq!(roots[2], compute_root_optimized(&events3));
}
#[test]
fn test_merkle_tree_buffer() {
let buffer = MerkleTreeBuffer::new(1000);
assert_eq!(buffer.max_leaves(), 1000);
}
}