extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use lib_q_core::Result;
use lib_q_poseidon::{
Poseidon,
Poseidon128,
};
use crate::air::merkle_inclusion::{
MAX_TREE_DEPTH,
MerkleHash,
};
use crate::air::{
merkle_root_to_bytes,
next_power_of_two,
};
fn empty_leaf_hash() -> MerkleHash {
MerkleHash::hash_data(&[])
}
#[derive(Clone)]
pub struct PoseidonMerkleTree {
layers: Vec<Vec<MerkleHash>>,
num_real_leaves: usize,
}
impl PoseidonMerkleTree {
pub fn from_leaves(leaves: &[&[u8]]) -> Result<Self> {
let n = leaves.len();
if n == 0 {
return Err(lib_q_core::Error::InvalidState {
operation: "PoseidonMerkleTree::from_leaves".into(),
reason: "at least one leaf required".into(),
});
}
let padded = core::cmp::max(2, next_power_of_two(n));
let depth = (padded as u64).trailing_zeros() as usize;
if depth > MAX_TREE_DEPTH {
return Err(lib_q_core::Error::InvalidState {
operation: "PoseidonMerkleTree::from_leaves".into(),
reason: alloc::format!(
"leaf count (padded {}) exceeds 2^{}",
padded,
MAX_TREE_DEPTH
),
});
}
let mut layer0: Vec<MerkleHash> = leaves
.iter()
.map(|leaf| MerkleHash::hash_data(leaf))
.collect();
let empty = empty_leaf_hash();
while layer0.len() < padded {
layer0.push(empty.clone());
}
let mut layers = vec![layer0];
let poseidon = Poseidon128;
for _ in 0..depth {
let prev = layers.last().unwrap();
let mut next = Vec::with_capacity(prev.len().div_ceil(2));
for pair in prev.chunks(2) {
let left = pair[0].as_field();
let right = pair[1].as_field();
let combined = poseidon.hash(&[*left, *right]);
next.push(MerkleHash::from_field(combined[0]));
}
layers.push(next);
}
Ok(Self {
layers,
num_real_leaves: n,
})
}
#[must_use]
pub fn root(&self) -> MerkleHash {
self.layers
.last()
.and_then(|v| v.first())
.cloned()
.unwrap_or_else(empty_leaf_hash)
}
#[must_use]
pub fn root_bytes(&self) -> [u8; 32] {
merkle_root_to_bytes(self.root().as_field())
}
#[must_use]
pub fn depth(&self) -> usize {
self.layers.len().saturating_sub(1)
}
#[must_use]
pub fn num_leaves(&self) -> usize {
self.num_real_leaves
}
pub fn path(&self, leaf_index: usize) -> Result<(Vec<bool>, Vec<MerkleHash>)> {
if leaf_index >= self.num_real_leaves {
return Err(lib_q_core::Error::InvalidState {
operation: "PoseidonMerkleTree::path".into(),
reason: alloc::format!(
"leaf_index {} >= num_leaves {}",
leaf_index,
self.num_real_leaves
),
});
}
let mut path_bits = Vec::with_capacity(self.depth());
let mut siblings = Vec::with_capacity(self.depth());
let mut idx = leaf_index;
for layer in self.layers.iter().take(self.depth()) {
let sibling_idx = idx ^ 1;
let sibling = layer
.get(sibling_idx)
.cloned()
.unwrap_or_else(empty_leaf_hash);
siblings.push(sibling);
path_bits.push((idx & 1) == 1);
idx >>= 1;
}
Ok((path_bits, siblings))
}
#[must_use]
pub fn verify_path(
root: &MerkleHash,
leaf: &[u8],
path_bits: &[bool],
siblings: &[MerkleHash],
) -> bool {
use crate::air::merkle_inclusion::compute_merkle_root;
if path_bits.len() != siblings.len() {
return false;
}
let computed = compute_merkle_root(leaf, path_bits, siblings);
computed.as_field() == root.as_field()
}
}
impl Drop for PoseidonMerkleTree {
fn drop(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear();
}
self.layers.clear();
}
}
impl core::fmt::Debug for PoseidonMerkleTree {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("PoseidonMerkleTree")
.field("depth", &self.depth())
.field("num_real_leaves", &self.num_real_leaves)
.finish_non_exhaustive()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_leaves_rejects_empty_input() {
let result = PoseidonMerkleTree::from_leaves(&[]);
assert!(result.is_err());
}
#[test]
fn test_tree_path_and_verify_path_round_trip() {
let leaves: Vec<&[u8]> = vec![b"a", b"b", b"c"];
let tree = PoseidonMerkleTree::from_leaves(&leaves).expect("tree");
let (path_bits, siblings) = tree.path(1).expect("path for index 1");
assert_eq!(path_bits.len(), tree.depth());
assert_eq!(siblings.len(), tree.depth());
assert_eq!(tree.num_leaves(), 3);
assert!(tree.root_bytes().iter().any(|b| *b != 0u8));
let is_valid =
PoseidonMerkleTree::verify_path(&tree.root(), leaves[1], &path_bits, &siblings);
assert!(is_valid);
}
#[test]
fn test_tree_path_rejects_out_of_bounds_index() {
let leaves: Vec<&[u8]> = vec![b"x", b"y"];
let tree = PoseidonMerkleTree::from_leaves(&leaves).expect("tree");
let result = tree.path(2);
assert!(result.is_err());
}
#[test]
fn test_verify_path_rejects_invalid_inputs() {
let leaves: Vec<&[u8]> = vec![b"left", b"right"];
let tree = PoseidonMerkleTree::from_leaves(&leaves).expect("tree");
let (path_bits, siblings) = tree.path(0).expect("path");
let mut wrong_siblings = siblings.clone();
wrong_siblings[0] = MerkleHash::hash_data(b"wrong");
assert!(!PoseidonMerkleTree::verify_path(
&tree.root(),
leaves[0],
&path_bits,
&wrong_siblings
));
assert!(!PoseidonMerkleTree::verify_path(
&tree.root(),
leaves[0],
&path_bits[..0],
&siblings
));
}
}