#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::slice;
use plonky2_maybe_rayon::*;
pub use qp_plonky2_core::MerkleCap;
use crate::hash::hash_types::RichField;
use crate::hash::merkle_proofs::MerkleProof;
use crate::plonk::config::Hasher;
use crate::util::log2_strict;
/// A Merkle tree over `leaves` whose top is truncated into a cap of subtree roots.
///
/// Built by [`MerkleTree::new`]: the leaves are split into `cap.len()` equal,
/// contiguous subtrees; each subtree's root is stored in `cap`, and all of its
/// other node digests are stored in `digests`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
    /// Raw leaf data; `leaves[i]` is hashed to form leaf `i`.
    pub leaves: Vec<Vec<F>>,
    /// For each cap subtree in order, every node digest except the subtree root,
    /// in the layout written by `fill_digests_buf` / `fill_subtree`.
    pub digests: Vec<H::Hash>,
    /// The roots of the cap subtrees.
    pub cap: MerkleCap<F, H>,
}
impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
    /// A tree with no leaves, no digests and `MerkleCap::default()` as its cap.
    fn default() -> Self {
        let (leaves, digests) = (Vec::default(), Vec::default());
        Self {
            leaves,
            digests,
            cap: MerkleCap::default(),
        }
    }
}
/// Views the first `len` slots of `v`'s allocation as uninitialized storage.
///
/// Intended for filling a freshly allocated `Vec` in place: the caller writes
/// every slot through the returned slice and then calls `v.set_len(len)`.
/// Slots below `v.len()` that are already initialized are overwritten without
/// being dropped, so callers should pass an empty vector (as
/// `MerkleTree::new` does).
///
/// # Panics
/// Panics if `v.capacity() < len`.
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
    assert!(v.capacity() >= len);
    let base: *mut MaybeUninit<T> = v.as_mut_ptr().cast();
    // SAFETY: `base` comes from the Vec's own buffer, so it is non-null and
    // aligned (a dangling-but-aligned pointer when nothing is allocated, which
    // is valid for a zero-length or zero-sized-type slice). The allocation holds
    // at least `capacity() >= len` elements, `MaybeUninit<T>` has the same layout
    // as `T`, and any bit pattern is a valid `MaybeUninit<T>`. The result
    // reborrows `v` mutably, so nothing else can alias it while it lives.
    unsafe { slice::from_raw_parts_mut(base, len) }
}
/// Hashes `leaves` into a binary Merkle subtree and returns its root.
///
/// Every node digest of the subtree except the root is written into
/// `digests_buf`, leaf hashes included, so the buffer holds
/// `2 * (leaves.len() - 1)` digests. The layout is recursive:
/// - The first half of the buffer holds the left subtree's digests, followed
///   by the left child's digest.
/// - The second half starts with the right child's digest, followed by the
///   right subtree's digests.
///
/// The two children of a node therefore sit next to each other. The two
/// halves are filled in parallel via `plonky2_maybe_rayon::join`.
///
/// # Panics
/// Panics unless `leaves.len() == digests_buf.len() / 2 + 1`.
pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
    digests_buf: &mut [MaybeUninit<H::Hash>],
    leaves: &[Vec<F>],
) -> H::Hash {
    assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);

    // A lone leaf is its own subtree: nothing to store, its hash is the root.
    if digests_buf.is_empty() {
        return H::hash_or_noop(&leaves[0]);
    }

    let half = digests_buf.len() / 2;
    let (lhs_buf, rhs_buf) = digests_buf.split_at_mut(half);
    // The child digests live at the inner edges of the two halves.
    let (left_slot, lhs_rest) = lhs_buf.split_last_mut().unwrap();
    let (right_slot, rhs_rest) = rhs_buf.split_first_mut().unwrap();
    let (lhs_leaves, rhs_leaves) = leaves.split_at(leaves.len() / 2);

    let (left_root, right_root) = plonky2_maybe_rayon::join(
        || fill_subtree::<F, H>(lhs_rest, lhs_leaves),
        || fill_subtree::<F, H>(rhs_rest, rhs_leaves),
    );

    left_slot.write(left_root);
    right_slot.write(right_root);
    H::two_to_one(left_root, right_root)
}
/// Fills `digests_buf` and `cap_buf` for a Merkle tree over `leaves` whose cap
/// has `2^cap_height` entries.
///
/// The leaves are split into `cap_buf.len()` equal, contiguous subtrees:
/// - Subtree `i`'s node digests (its root excluded) go into the `i`-th
///   equal-length chunk of `digests_buf`, laid out as by `fill_subtree`.
/// - Subtree `i`'s root goes into `cap_buf[i]`.
///
/// When `digests_buf` is empty, the cap sits at the leaf layer and each cap
/// entry is the hash of a single leaf.
///
/// Every slot of both buffers is written before returning. Callers such as
/// `MerkleTree::new` rely on that before calling `Vec::set_len`, which is why
/// the shape checks below are hard asserts rather than debug-only ones.
///
/// # Panics
/// Panics if the buffer lengths do not match the shape implied by
/// `leaves.len()` and `cap_height`.
pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
    digests_buf: &mut [MaybeUninit<H::Hash>],
    cap_buf: &mut [MaybeUninit<H::Hash>],
    leaves: &[Vec<F>],
    cap_height: usize,
) {
    if digests_buf.is_empty() {
        // A length mismatch would make `zip` stop early and leave cap slots
        // uninitialized, so this must be checked in release builds too.
        assert_eq!(cap_buf.len(), leaves.len());
        cap_buf
            .par_iter_mut()
            .zip(leaves)
            .for_each(|(cap_slot, leaf)| {
                cap_slot.write(H::hash_or_noop(leaf));
            });
        return;
    }

    let subtree_digests_len = digests_buf.len() >> cap_height;
    let subtree_leaves_len = leaves.len() >> cap_height;
    // `par_chunks_exact*` silently drop a trailing remainder. Require an exact
    // partition so that no digest slot is left unwritten and no leaf is ignored.
    assert_eq!(subtree_digests_len << cap_height, digests_buf.len());
    assert_eq!(subtree_leaves_len << cap_height, leaves.len());

    let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
    let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
    assert_eq!(digests_chunks.len(), cap_buf.len());
    assert_eq!(digests_chunks.len(), leaves_chunks.len());
    digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
        |((subtree_digests, subtree_cap), subtree_leaves)| {
            subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
        },
    );
}
/// Returns the authentication path for leaf `leaf_index`.
///
/// The path has one sibling digest per layer, from the leaf layer up to, but
/// not including, the cap.
///
/// `digests` must be the digest buffer built by `fill_digests_buf` for a tree
/// with `leaves_len` leaves and a cap of `2^cap_height` entries. It consists of
/// `2^cap_height` equal-length chunks, one per cap subtree, each laid out as
/// `fill_subtree` writes it.
///
/// # Panics
/// Panics if:
/// - `cap_height > log2(leaves_len)`;
/// - `leaf_index >= leaves_len`;
/// - `digests.len()` does not match the tree shape.
///
/// `leaves_len` is expected to be a power of two (checked by `log2_strict`).
pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
    leaf_index: usize,
    leaves_len: usize,
    cap_height: usize,
    digests: &[H::Hash],
) -> Vec<H::Hash> {
    let log2_leaves_len = log2_strict(leaves_len);
    assert!(
        cap_height <= log2_leaves_len,
        "cap_height={} should be at most log2(leaves_len)={}",
        cap_height,
        log2_leaves_len
    );
    let num_layers = log2_leaves_len - cap_height;
    // Checked in release builds too: with `num_layers == 0`, an out-of-range
    // index would otherwise silently yield an empty (bogus) proof.
    assert!(
        leaf_index < leaves_len,
        "leaf_index={} out of range for {} leaves",
        leaf_index,
        leaves_len
    );

    let digest_len = 2 * (leaves_len - (1 << cap_height));
    assert_eq!(digest_len, digests.len());

    // Select the chunk of `digests` that belongs to the leaf's cap subtree.
    let digest_tree: &[H::Hash] = {
        let tree_index = leaf_index >> num_layers;
        let tree_len = digest_len >> cap_height;
        &digests[tree_len * tree_index..tree_len * (tree_index + 1)]
    };

    // Position of the current node within its layer of the subtree.
    let mut pair_index = leaf_index & ((1 << num_layers) - 1);
    (0..num_layers)
        .map(|i| {
            let parity = pair_index & 1;
            pair_index >>= 1;
            // Sibling pairs are stored adjacently: the left child at
            // `2 * siblings_index`, the right child right after it. Pick the
            // member of the pair that is not on our path.
            let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
            let sibling_index = 2 * siblings_index + (1 - parity);
            digest_tree[sibling_index]
        })
        .collect()
}
impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
    /// Builds a Merkle tree over `leaves` whose cap holds `2^cap_height`
    /// subtree roots.
    ///
    /// # Panics
    /// Panics if `cap_height > log2(leaves.len())`. `leaves.len()` is expected
    /// to be a power of two (checked by `log2_strict`).
    pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
        let log2_leaves_len = log2_strict(leaves.len());
        assert!(
            cap_height <= log2_leaves_len,
            "cap_height={} should be at most log2(leaves.len())={}",
            cap_height,
            log2_leaves_len
        );

        // There is one cap entry per subtree. Each subtree stores all of its
        // node digests except its root, i.e. 2 * (leaves_per_subtree - 1).
        let cap_len = 1 << cap_height;
        let digests_len = 2 * (leaves.len() - cap_len);

        let mut digests = Vec::with_capacity(digests_len);
        let mut cap = Vec::with_capacity(cap_len);
        fill_digests_buf::<F, H>(
            capacity_up_to_mut(&mut digests, digests_len),
            capacity_up_to_mut(&mut cap, cap_len),
            &leaves,
            cap_height,
        );
        // SAFETY: both vectors were allocated with exactly these capacities.
        // `fill_digests_buf` is required to initialize every slot of the
        // buffers it was handed, so the new lengths cover only initialized
        // elements.
        unsafe {
            digests.set_len(digests_len);
            cap.set_len(cap_len);
        }

        Self {
            leaves,
            digests,
            cap: MerkleCap(cap),
        }
    }

    /// Returns the data of leaf `i`.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    pub fn get(&self, i: usize) -> &[F] {
        self.leaves[i].as_slice()
    }

    /// Builds the proof that leaf `leaf_index` belongs to this tree's cap.
    pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
        // The cap has `2^cap_height` entries, so its length encodes the height.
        let cap_height = log2_strict(self.cap.len());
        MerkleProof {
            siblings: merkle_tree_prove::<F, H>(
                leaf_index,
                self.leaves.len(),
                cap_height,
                &self.digests,
            ),
        }
    }
}
#[cfg(test)]
#[cfg(feature = "rand")]
pub(crate) mod tests {
    use anyhow::Result;

    use super::*;
    use crate::field::extension::Extendable;
    use crate::hash::merkle_proofs::verify_merkle_proof_to_cap;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Generates `n` random leaves, each holding `k` field elements.
    pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
        core::iter::repeat_with(|| F::rand_vec(k)).take(n).collect()
    }

    /// Builds a tree over `leaves` and checks that every leaf's proof verifies
    /// against the cap, stopping at the first failure.
    fn verify_all_leaves<
        F: RichField + Extendable<D>,
        C: GenericConfig<D, F = F>,
        const D: usize,
    >(
        leaves: Vec<Vec<F>>,
        cap_height: usize,
    ) -> Result<()> {
        let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
        leaves
            .into_iter()
            .enumerate()
            .try_for_each(|(index, leaf)| -> Result<()> {
                let proof = tree.prove(index);
                verify_merkle_proof_to_cap(leaf, index, &tree.cap, &proof)?;
                Ok(())
            })
    }

    #[test]
    #[should_panic]
    fn test_cap_height_too_big() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        type H = <C as GenericConfig<D>>::Hasher;

        let log_n = 8;
        // One more than log2(leaves.len()) must be rejected by `MerkleTree::new`.
        let leaves = random_data::<F>(1 << log_n, 7);
        let _ = MerkleTree::<F, H>::new(leaves, log_n + 1);
    }

    #[test]
    fn test_cap_height_eq_log2_len() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        // With the cap at the leaf layer, every proof is empty.
        let log_n = 8;
        verify_all_leaves::<F, C, D>(random_data::<F>(1 << log_n, 7), log_n)
    }

    #[test]
    fn test_merkle_trees() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        verify_all_leaves::<F, C, D>(random_data::<F>(1 << log_n, 7), 1)
    }
}