use alloy_primitives::{B256, Keccak256};
use crate::bmt::{Hasher, constants::*, error::BmtError};
use crate::error::Result;
/// Inclusion proof for a single segment, checked by [`Proof::verify`].
#[derive(Clone, Debug)]
pub struct Proof {
/// Position of `segment` among the leaves; its successive halvings decide
/// whether the running hash is the left or right input at each level.
pub segment_index: usize,
/// The leaf value the proof starts from.
pub segment: B256,
/// Sibling values, one per level, ordered from the leaf level upwards.
/// `verify` requires exactly `PROOF_LENGTH` entries.
pub proof_segments: Vec<B256>,
/// Span value; hashed as little-endian bytes in front of the tree root.
pub span: u64,
/// Optional bytes hashed before the span when computing the final root.
pub prefix: Option<Vec<u8>>,
}
impl Proof {
pub const fn new(
segment_index: usize,
segment: B256,
proof_segments: Vec<B256>,
span: u64,
prefix: Option<Vec<u8>>,
) -> Self {
Self {
segment_index,
segment,
proof_segments,
span,
prefix,
}
}
pub fn verify(&self, root_hash: &[u8]) -> Result<bool> {
if self.proof_segments.len() != PROOF_LENGTH {
return Err(
BmtError::invalid_proof_length(PROOF_LENGTH, self.proof_segments.len()).into(),
);
}
let mut current_hash = self.segment;
let mut current_index = self.segment_index;
for proof_segment in &self.proof_segments {
let mut hasher = Keccak256::new();
if current_index.is_multiple_of(2) {
hasher.update(current_hash.as_slice());
hasher.update(proof_segment.as_slice());
} else {
hasher.update(proof_segment.as_slice());
hasher.update(current_hash.as_slice());
}
current_hash = B256::from_slice(hasher.finalize().as_slice());
current_index /= 2;
}
let mut hasher = Keccak256::new();
if let Some(prefix) = &self.prefix {
hasher.update(prefix);
}
hasher.update(self.span.to_le_bytes());
hasher.update(current_hash.as_slice());
let computed_root = B256::from_slice(hasher.finalize().as_slice());
Ok(computed_root.as_slice() == root_hash)
}
}
/// Generation and verification of segment inclusion proofs.
pub trait Prover {
/// Builds a [`Proof`] for the segment at `segment_index` of `data`.
///
/// How `data` is split into segments and which indices are valid is up to
/// the implementor.
fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof>;
/// Checks `proof` against `root_hash`.
///
/// This is an associated function (no `self`); the boolean result
/// presumably reports whether the proof matches the root — see the
/// implementor for the exact contract.
fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool>;
}
impl Prover for Hasher {
fn generate_proof(&self, data: &[u8], segment_index: usize) -> Result<Proof> {
if segment_index >= BRANCHES {
return Err(self::BmtError::invalid_input_size(format!(
"Segment index {segment_index} out of bounds for BRANCHES"
))
.into());
}
let data_len = data.len();
#[cfg(not(target_arch = "wasm32"))]
let segments = {
use rayon::prelude::*;
(0..BRANCHES)
.into_par_iter()
.map(|i| {
let start = i * SEGMENT_SIZE;
let mut segment = [0u8; SEGMENT_SIZE];
if start < data_len {
let end = (start + SEGMENT_SIZE).min(data_len);
let copy_len = end - start;
segment[..copy_len].copy_from_slice(&data[start..end]);
}
B256::from_slice(&segment)
})
.collect::<Vec<_>>()
};
#[cfg(target_arch = "wasm32")]
let segments = {
let mut segs = Vec::with_capacity(BRANCHES);
for i in 0..BRANCHES {
let start = i * SEGMENT_SIZE;
let mut segment = [0u8; SEGMENT_SIZE];
if start < data_len {
let end = (start + SEGMENT_SIZE).min(data_len);
let copy_len = end - start;
segment[..copy_len].copy_from_slice(&data[start..end]);
}
segs.push(B256::from_slice(&segment));
}
segs
};
let segment = segments[segment_index];
let mut proof_segments = Vec::with_capacity(PROOF_LENGTH);
let mut current_level = segments;
let mut current_index = segment_index;
while proof_segments.len() < PROOF_LENGTH {
let sibling_index = if current_index.is_multiple_of(2) {
current_index + 1
} else {
current_index - 1
};
if sibling_index < current_level.len() {
proof_segments.push(current_level[sibling_index]);
} else {
proof_segments.push(B256::ZERO);
}
let mut next_level = Vec::with_capacity(current_level.len().div_ceil(2));
for i in (0..current_level.len()).step_by(2) {
let left = ¤t_level[i];
let right = if i + 1 < current_level.len() {
¤t_level[i + 1]
} else {
&B256::ZERO
};
let mut hasher = Keccak256::new();
hasher.update(left.as_slice());
hasher.update(right.as_slice());
let parent = B256::from_slice(hasher.finalize().as_slice());
next_level.push(parent);
}
current_level = next_level;
current_index /= 2;
if current_level.len() <= 1 {
break;
}
}
while proof_segments.len() < PROOF_LENGTH {
proof_segments.push(B256::ZERO);
}
let prefix = if !self.prefix().is_empty() {
Some(self.prefix().to_vec())
} else {
None
};
Ok(Proof::new(
segment_index,
segment,
proof_segments,
self.span(),
prefix,
))
}
fn verify_proof(proof: &Proof, root_hash: &[u8]) -> Result<bool> {
proof.verify(root_hash)
}
}