use crate::{hash::FieldHasher, CircuitConfig, FieldHasherVariable};
use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};
use serde::{Deserialize, Serialize};
use slop_challenger::IopCtx;
use sp1_core_machine::utils::{log2_strict_usize, reverse_slice_index_bits};
use sp1_hypercube::MerkleProof;
use sp1_primitives::utils::reverse_bits_len;
use sp1_recursion_compiler::ir::Builder;
/// A Merkle tree over digests of type `GC::Digest`, stored as one flat vector of layers.
///
/// Instances produced by `MerkleTree::commit` hold every layer from the (padded, bit-reversed)
/// leaf layer up to the two-element layer directly below the root. The root itself is not
/// stored; `commit` returns it separately.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleTree<GC: IopCtx> {
/// Number of layers below the root. `commit` sets this to the base-2 logarithm of the padded
/// leaf count, and `open` produces a path with exactly this many siblings.
pub height: usize,
/// All stored layers concatenated, leaf layer first. As built by `commit`, layer `i` holds
/// `1 << (height - i)` digests, for a total of `2 * (1 << height) - 2` entries.
pub digest_layers: Vec<GC::Digest>,
}
impl<GC: FieldHasher<Digest: Default>> MerkleTree<GC> {
    /// Builds a Merkle tree over `leaves` and returns `(root, tree)`.
    ///
    /// The leaves are padded with `GC::Digest::default()` up to the next power of two and then
    /// permuted with `reverse_slice_index_bits` before hashing. Each parent is
    /// `GC::constant_compress` of two adjacent digests in the layer below. Every layer from the
    /// leaf layer up to the two-element layer is appended to `digest_layers`; the root is only
    /// returned, not stored.
    ///
    /// # Panics
    ///
    /// Panics if fewer than two leaves are supplied. A single leaf would give `height == 0`,
    /// for which the layer loop bound `height - 1` underflows and the final compression has no
    /// pair of digests to combine, so the precondition is checked explicitly up front.
    pub fn commit(leaves: Vec<GC::Digest>) -> (GC::Digest, Self) {
        assert!(
            leaves.len() >= 2,
            "MerkleTree::commit requires at least two leaves, got {}",
            leaves.len()
        );
        let new_len = leaves.len().next_power_of_two();
        let height = log2_strict_usize(new_len);

        // Total stored digests: new_len + new_len / 2 + ... + 2 = 2 * new_len - 2.
        let mut digest_layers = Vec::with_capacity(2 * new_len - 2);

        // Pad the leaf layer with default digests up to a power of two.
        let mut last_layer = leaves;
        let old_len = last_layer.len();
        for _ in old_len..new_len {
            last_layer.push(GC::Digest::default());
        }

        // Leaves are stored (and hashed) in bit-reversed index order.
        reverse_slice_index_bits(&mut last_layer);
        digest_layers.extend(last_layer.iter());

        // Build and store the remaining layers, stopping at the two-element layer.
        // `height >= 1` is guaranteed by the assertion above, so `height - 1` cannot underflow.
        for _ in 0..height - 1 {
            let mut next_layer = Vec::with_capacity(last_layer.len() / 2);
            last_layer
                .par_chunks_exact(2)
                .map(|chunk| {
                    let [left, right] = chunk.try_into().unwrap();
                    GC::constant_compress([left, right])
                })
                .collect_into_vec(&mut next_layer);
            digest_layers.extend(next_layer.iter());
            last_layer = next_layer;
        }
        debug_assert_eq!(digest_layers.len(), 2 * new_len - 2);

        // `last_layer` now has exactly two digests; compressing them yields the root.
        let root = GC::constant_compress([last_layer[0], last_layer[1]]);
        (root, Self { height, digest_layers })
    }

    /// Opens the leaf at `index`, returning the stored leaf digest and its authentication path.
    ///
    /// `index` is mapped through `reverse_bits_len(index, self.height)` to locate the leaf in
    /// the stored leaf layer. The returned path holds one sibling per layer, ordered from the
    /// leaf layer upwards, and the returned `MerkleProof` carries the caller's original `index`.
    ///
    /// # Panics
    ///
    /// Panics on an out-of-bounds access if `digest_layers` is shorter than `height` implies.
    ///
    /// NOTE(review): `index` is not range-checked here; what happens for
    /// `index >= 1 << self.height` depends on `reverse_bits_len` — confirm against callers.
    pub fn open(&self, index: usize) -> (GC::Digest, MerkleProof<GC>) {
        let mut path = Vec::with_capacity(self.height);
        let mut bit_rev_index = reverse_bits_len(index, self.height);
        let value = self.digest_layers[bit_rev_index];

        // Position in `digest_layers` of the first element of the current layer.
        let mut offset = 0;
        for i in 0..self.height {
            // The sibling is the other element of the adjacent pair in this layer.
            let sibling = if bit_rev_index.is_multiple_of(2) {
                self.digest_layers[offset + bit_rev_index + 1]
            } else {
                self.digest_layers[offset + bit_rev_index - 1]
            };
            path.push(sibling);

            // Move to the parent's position within the next layer up.
            bit_rev_index >>= 1;

            // Layer `i` holds `1 << (height - i)` digests; skip past it.
            offset += 1 << (self.height - i);
        }
        assert_eq!(path.len(), self.height);
        (value, MerkleProof { index, path })
    }
}
/// Constrains, inside the circuit, that hashing `value` up along `path` yields `merkle_root`.
///
/// For each `(sibling, bit)` pair, taken in order from `path` and `index`, the running digest
/// and the sibling are passed to `HV::select_chain_digest` together with the bit, and the
/// resulting pair is hashed with `HV::compress`. The final digest is asserted equal to
/// `merkle_root` via `HV::assert_digest_eq`.
///
/// `path` and `index` are zipped, so if their lengths differ the surplus entries of the longer
/// one are ignored; no length check is performed here.
///
/// NOTE(review): presumably `select_chain_digest` orders the pair according to the bit —
/// confirm against the `FieldHasherVariable` implementation.
pub fn verify<C: CircuitConfig, HV: FieldHasherVariable<C>>(
    builder: &mut Builder<C>,
    path: Vec<HV::DigestVariable>,
    index: Vec<C::Bit>,
    value: HV::DigestVariable,
    merkle_root: HV::DigestVariable,
) {
    // Fold the path from the leaf upwards, threading the running digest through each level.
    let computed_root = path.iter().zip(index.iter()).fold(value, |node, (&sibling, &bit)| {
        let ordered_pair = HV::select_chain_digest(builder, bit, [node, sibling]);
        HV::compress(builder, ordered_pair)
    });
    HV::assert_digest_eq(builder, computed_root, merkle_root);
}