use std::{fmt, mem::swap};
use ark_std::rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
#[cfg(feature = "tracing")]
use tracing::{instrument, span, Level};
use zerocopy::IntoBytes;
use crate::{
engines::EngineId,
hash::{self, Hash, HashEngine, ENGINES},
transcript::{
DuplexSpongeInterface, ProverMessage, ProverState, VerificationError, VerificationResult,
VerifierState,
},
utils::zip_strict,
verify,
};
/// Shape of a binary Merkle tree: how many leaves it commits to and which hash engine is
/// used on each layer of two-to-one hashing.
///
/// `layers[0]` is the layer that produces the root; hashing walks `layers` in reverse,
/// starting from the leaves.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Config {
    /// Number of leaf hashes the tree commits to.
    pub num_leaves: usize,
    /// One entry per layer of hashing; the tree has `2^layers.len()` leaf slots.
    pub layers: Vec<LayerConfig>,
}
/// Per-layer settings of a Merkle tree.
#[derive(
    Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
pub struct LayerConfig {
    /// Identifier of the hash engine (looked up in `ENGINES`) used to hash this layer.
    pub hash_id: EngineId,
}
impl fmt::Display for Config {
    /// Short human-readable summary, e.g. `MerkleTree(num_leaves: 8)`; the per-layer hash
    /// engines are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { num_leaves, .. } = self;
        f.write_fmt(format_args!("MerkleTree(num_leaves: {})", num_leaves))
    }
}
/// Root hash of a committed Merkle tree, as read by `Config::receive_commitment` and
/// checked against by `Config::verify`.
#[derive(
    Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
#[must_use]
pub struct Commitment {
    hash: Hash,
}
/// Prover-side tree data produced by `Config::commit` and consumed by `Config::open`.
///
/// `nodes` holds the padded leaf slots first, then each successively smaller layer, with
/// the root as the last entry.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[must_use]
pub struct Witness {
    nodes: Vec<Hash>,
}
impl Config {
    /// Creates a configuration for `num_leaves` leaves that uses BLAKE3 on every layer.
    pub fn new(num_leaves: usize) -> Self {
        Self::with_hash(hash::BLAKE3, num_leaves)
    }

    /// Creates a configuration for `num_leaves` leaves that uses the hash engine `hash_id`
    /// on every layer, with the minimal layer count given by [`layers_for_size`].
    pub fn with_hash(hash_id: EngineId, num_leaves: usize) -> Self {
        Self {
            num_leaves,
            layers: vec![LayerConfig { hash_id }; layers_for_size(num_leaves)],
        }
    }

    /// Number of nodes in the padded tree: `2^layers.len()` leaf slots plus every interior
    /// node up to and including the root, i.e. `2^(layers.len() + 1) - 1`.
    pub const fn num_nodes(&self) -> usize {
        (1 << (self.layers.len() + 1)) - 1
    }

    /// Panics unless the `2^layers.len()` leaf slots can hold `num_leaves` leaves.
    ///
    /// The fields are public, so a hand-built or deserialized `Config` may have too few
    /// layers. Without this check `commit` would silently leave the excess leaves out of the
    /// root (they get truncated or overwritten in the node buffer), and `open` would index
    /// out of bounds.
    fn assert_leaves_fit(&self) {
        assert!(
            self.num_leaves <= 1 << self.layers.len(),
            "Merkle tree with {} layers cannot hold {} leaves",
            self.layers.len(),
            self.num_leaves
        );
    }

    /// Builds the tree over `leaves`, sends its root to the transcript as a prover message,
    /// and returns the full node buffer as the [`Witness`] needed by [`Config::open`].
    ///
    /// # Panics
    ///
    /// Panics if `leaves.len() != self.num_leaves`, if the configuration has too few layers
    /// for `num_leaves`, or if a layer's hash engine is not found in `ENGINES`.
    #[cfg_attr(feature = "tracing", instrument(skip(prover_state, leaves), fields(self = %self)))]
    pub fn commit<H, R>(&self, prover_state: &mut ProverState<H, R>, leaves: Vec<Hash>) -> Witness
    where
        H: DuplexSpongeInterface,
        R: RngCore + CryptoRng,
        Hash: ProverMessage<[H::U]>,
    {
        assert_eq!(
            leaves.len(),
            self.num_leaves,
            "Expected {} leaf hashes, got {}",
            self.num_leaves,
            leaves.len()
        );
        self.assert_leaves_fit();

        // Node buffer layout: the `2^layers` leaf slots (padded with the default hash),
        // then each successively smaller layer, ending with the single root node.
        let mut nodes = leaves;
        nodes.resize(self.num_nodes(), Hash::default());
        let (mut previous, mut remaining) = nodes.split_at_mut(1 << self.layers.len());
        // `layers[0]` produces the root, so walk the layer configs bottom-up.
        for layer in self.layers.iter().rev() {
            let (current, next_remaining) = remaining.split_at_mut(previous.len() / 2);
            let engine = ENGINES
                .retrieve(layer.hash_id)
                .expect("Hash Engine not found");
            #[cfg(feature = "tracing")]
            let _span = span!(
                Level::DEBUG,
                "layer",
                engine = engine.name().as_ref(),
                count = current.len()
            )
            .entered();
            // Each parent hashes its two adjacent 32-byte children as one 64-byte message.
            parallel_hash(&*engine, 64, previous.as_bytes(), current);
            previous = current;
            remaining = next_remaining;
        }
        prover_state.prover_message(&previous[0]);
        Witness { nodes }
    }

    /// Reads the root hash sent by [`Config::commit`] from the transcript.
    ///
    /// # Errors
    ///
    /// Propagates the transcript error if the prover message cannot be read.
    pub fn receive_commitment<H>(
        &self,
        verifier_state: &mut VerifierState<H>,
    ) -> VerificationResult<Commitment>
    where
        H: DuplexSpongeInterface,
        Hash: ProverMessage<[H::U]>,
    {
        let hash = verifier_state.prover_message()?;
        Ok(Commitment { hash })
    }

    /// Writes the authentication paths for the leaves at `indices` to the transcript as
    /// prover hints.
    ///
    /// The indices are sorted and deduplicated first. Only sibling nodes that cannot be
    /// recomputed from other opened leaves are emitted, so [`Config::verify`] must be given
    /// the same set of indices.
    ///
    /// # Panics
    ///
    /// Panics if `witness` does not have this configuration's node count, if any index is
    /// `>= num_leaves`, or if the configuration has too few layers for `num_leaves`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, fields(num_indices = indices.len())))]
    pub fn open<H, R>(
        &self,
        prover_state: &mut ProverState<H, R>,
        witness: &Witness,
        indices: &[usize],
    ) where
        H: DuplexSpongeInterface,
        R: RngCore + CryptoRng,
        Hash: ProverMessage<[H::U]>,
    {
        assert_eq!(witness.nodes.len(), self.num_nodes());
        self.assert_leaves_fit();
        assert!(indices.iter().all(|&i| i < self.num_leaves));
        let mut indices = indices.to_vec();
        indices.sort_unstable();
        indices.dedup();
        let (mut layer, mut remaining) = witness.nodes.split_at(1 << self.layers.len());
        while layer.len() > 1 {
            let mut next_indices = Vec::with_capacity(indices.len());
            let mut iter = indices.iter().copied().peekable();
            while let Some(a) = iter.next() {
                // Indices are sorted and unique, so a sibling that is also opened is always
                // the very next index. Only siblings the verifier lacks are sent as hints.
                if iter.next_if_eq(&(a ^ 1)).is_none() {
                    prover_state.prover_hint(&layer[a ^ 1]);
                }
                next_indices.push(a >> 1);
            }
            indices = next_indices;
            let (next_layer, next_remaining) = remaining.split_at(layer.len() / 2);
            layer = next_layer;
            remaining = next_remaining;
        }
    }

    /// Checks that `leaf_hashes[k]` is the leaf at `indices[k]` of the tree whose root is
    /// `commitment`. Sibling hints are read from the transcript in the order written by
    /// [`Config::open`].
    ///
    /// Repeated indices are allowed only when every copy claims the same leaf hash.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the inputs are malformed (length mismatch, out-of-range index, or conflicting
    ///   duplicates);
    /// - a hint cannot be read;
    /// - a layer's hash engine is unknown;
    /// - the recomputed root differs from `commitment`.
    pub fn verify<H>(
        &self,
        verifier_state: &mut VerifierState<H>,
        commitment: &Commitment,
        indices: &[usize],
        leaf_hashes: &[Hash],
    ) -> VerificationResult<()>
    where
        H: DuplexSpongeInterface,
        Hash: ProverMessage<[H::U]>,
    {
        verify!(indices.len() == leaf_hashes.len());
        verify!(indices.iter().all(|&i| i < self.num_leaves));
        if indices.is_empty() {
            return Ok(());
        }

        let mut opened = zip_strict(indices.iter().copied(), leaf_hashes.iter().copied())
            .collect::<Vec<(usize, Hash)>>();
        opened.sort_unstable_by_key(|(i, _)| *i);
        // After sorting, equal indices are adjacent; each copy must claim the same hash.
        verify!(opened
            .windows(2)
            .all(|pair| pair[0].0 != pair[1].0 || pair[0].1 == pair[1].1));
        opened.dedup_by_key(|(i, _)| *i);

        let mut indices = opened.iter().map(|(i, _)| *i).collect::<Vec<_>>();
        let mut hashes = opened.iter().map(|(_, h)| *h).collect::<Vec<_>>();
        // Scratch buffers reused across layers.
        let mut next_indices = Vec::with_capacity(opened.len());
        let mut input_hashes = Vec::with_capacity(opened.len() * 2);
        let mut next_hashes = Vec::with_capacity(opened.len());
        for layer in self.layers.iter().rev() {
            next_indices.clear();
            input_hashes.clear();
            next_hashes.clear();
            let mut indices_iter = indices.iter().copied().peekable();
            let mut hashes_iter = hashes.iter().copied();
            while let Some(a) = indices_iter.next() {
                if indices_iter.next_if_eq(&(a ^ 1)).is_some() {
                    // Both children are known; they are already in left/right order.
                    input_hashes.push(hashes_iter.next().unwrap());
                    input_hashes.push(hashes_iter.next().unwrap());
                } else {
                    // The sibling comes from the prover; order the pair by `a`'s parity.
                    let sibling: Hash = verifier_state.prover_hint()?;
                    let own = hashes_iter.next().unwrap();
                    if a & 1 == 0 {
                        input_hashes.push(own);
                        input_hashes.push(sibling);
                    } else {
                        input_hashes.push(sibling);
                        input_hashes.push(own);
                    }
                }
                next_indices.push(a >> 1);
            }
            next_hashes.resize(input_hashes.len() / 2, Hash::default());
            ENGINES
                .retrieve(layer.hash_id)
                .ok_or(VerificationError)?
                .hash_many(64, input_hashes.as_bytes(), &mut next_hashes);
            swap(&mut indices, &mut next_indices);
            swap(&mut hashes, &mut next_hashes);
        }
        verify!(indices == [0]);
        verify!(hashes == [commitment.hash]);
        Ok(())
    }
}
impl Witness {
    /// Number of node hashes stored in this witness. For a witness returned by
    /// `Config::commit` this equals that config's `num_nodes()`.
    pub const fn num_nodes(&self) -> usize {
        let Self { nodes } = self;
        nodes.len()
    }
}
/// Minimal number of hashing layers for a tree with `size` leaves, i.e. the smallest `k`
/// with `2^k >= size` (0 for `size` of 0 or 1).
///
/// Computed as `BITS - leading_zeros(size - 1)`. Unlike `next_power_of_two().ilog2()`, this
/// cannot overflow: it returns `usize::BITS` for sizes above `2^(BITS-1)` instead of
/// panicking.
pub const fn layers_for_size(size: usize) -> usize {
    (usize::BITS - size.saturating_sub(1).leading_zeros()) as usize
}
/// Sequential build (no `parallel` feature): hands the whole buffer to a single
/// `engine.hash_many` call.
///
/// `input` is expected to hold `output.len()` messages of `size` bytes each.
///
/// # Panics
///
/// Panics if `input.len() != size * output.len()`, the same check the `parallel` build
/// performs, so both builds reject malformed buffers identically.
#[cfg(not(feature = "parallel"))]
fn parallel_hash(engine: &dyn HashEngine, size: usize, input: &[u8], output: &mut [Hash]) {
    assert_eq!(input.len(), size * output.len());
    engine.hash_many(size, input, output);
}
/// Parallel build: recursively splits the work in half with `rayon::join` while the input
/// exceeds `workload_size::<u8>()` bytes, then hands each piece to `engine.hash_many`.
///
/// `input` is expected to hold `output.len()` messages of `size` bytes each.
///
/// # Panics
///
/// Panics if `input.len() != size * output.len()`.
#[cfg(feature = "parallel")]
fn parallel_hash(engine: &dyn HashEngine, size: usize, input: &[u8], output: &mut [Hash]) {
    use crate::utils::workload_size;
    assert_eq!(input.len(), size * output.len());
    if input.len() > workload_size::<u8>() && output.len() >= 2 {
        // Split on a message boundary. Halving `input` directly would cut a message in two
        // whenever `output.len()` is odd, and the recursive length check would then fail.
        let mid = output.len() / 2;
        let (input_a, input_b) = input.split_at(mid * size);
        let (output_a, output_b) = output.split_at_mut(mid);
        rayon::join(
            || parallel_hash(engine, size, input_a, output_a),
            || parallel_hash(engine, size, input_b, output_b),
        );
    } else {
        engine.hash_many(size, input, output);
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use proptest::{collection::vec, prelude::Strategy};

    use super::*;
    use crate::{
        hash::{tests::hash_for_size, BLAKE3},
        transcript::{codecs::Empty, DomainSeparator},
    };

    /// Proptest strategy for configs over `num_leaves` leaves.
    ///
    /// The layer count ranges from the minimal count up to three extra layers. Each layer's
    /// engine is drawn from `hash_for_size(64)`.
    pub fn config(num_leaves: usize) -> impl Strategy<Value = Config> {
        let min_layers = layers_for_size(num_leaves);
        let num_layers = min_layers..=min_layers + 3;
        let layer = hash_for_size(64).prop_map(|hash_id| LayerConfig { hash_id });
        vec(layer, num_layers).prop_map(move |layers| Config { num_leaves, layers })
    }

    #[test]
    fn test_merkle_tree() {
        crate::tests::init();
        let config = Config {
            num_leaves: 256,
            layers: vec![LayerConfig { hash_id: BLAKE3 }; 8],
        };
        let leaves = (0..config.num_leaves)
            .map(|i| Hash([i as u8; 32]))
            .collect::<Vec<_>>();
        let ds = DomainSeparator::protocol(&config)
            .session(&format!("Test at {}:{}", file!(), line!()))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let tree = config.commit(&mut prover_state, leaves);
        config.open(&mut prover_state, &tree, &[13, 42]);
        let proof = prover_state.proof();
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let root = config.receive_commitment(&mut verifier_state).unwrap();
        config
            .verify(
                &mut verifier_state,
                &root,
                &[13, 42],
                &[Hash([13; 32]), Hash([42; 32])],
            )
            .unwrap();
    }

    /// Padded (non-power-of-two) tree: duplicate indices with consistent hashes verify;
    /// conflicting duplicates and wrong leaf hashes are rejected.
    #[test]
    fn test_merkle_tree_padding_duplicates_and_rejection() {
        crate::tests::init();
        let config = Config::new(5);
        let leaves = (0..config.num_leaves)
            .map(|i| Hash([i as u8; 32]))
            .collect::<Vec<_>>();
        let ds = DomainSeparator::protocol(&config)
            .session(&format!("Test at {}:{}", file!(), line!()))
            .instance(&Empty);
        let mut prover_state = ProverState::new_std(&ds);
        let tree = config.commit(&mut prover_state, leaves);
        config.open(&mut prover_state, &tree, &[4, 1, 4]);
        let proof = prover_state.proof();

        // Repeated indices are fine when every copy claims the same leaf.
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let root = config.receive_commitment(&mut verifier_state).unwrap();
        config
            .verify(
                &mut verifier_state,
                &root,
                &[4, 1, 4],
                &[Hash([4; 32]), Hash([1; 32]), Hash([4; 32])],
            )
            .unwrap();

        // Conflicting claims for the same index are rejected.
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let root = config.receive_commitment(&mut verifier_state).unwrap();
        assert!(config
            .verify(
                &mut verifier_state,
                &root,
                &[1, 1],
                &[Hash([1; 32]), Hash([2; 32])],
            )
            .is_err());

        // A wrong leaf hash does not reproduce the committed root.
        let mut verifier_state = VerifierState::new_std(&ds, &proof);
        let root = config.receive_commitment(&mut verifier_state).unwrap();
        assert!(config
            .verify(
                &mut verifier_state,
                &root,
                &[1, 4],
                &[Hash([1; 32]), Hash([5; 32])],
            )
            .is_err());
    }

    #[test]
    fn test_layers_for_size() {
        assert_eq!(layers_for_size(0), 0);
        assert_eq!(layers_for_size(1), 0);
        assert_eq!(layers_for_size(2), 1);
        assert_eq!(layers_for_size(3), 2);
        assert_eq!(layers_for_size(4), 2);
        assert_eq!(layers_for_size(5), 3);
        assert_eq!(layers_for_size(6), 3);
        assert_eq!(layers_for_size(7), 3);
        assert_eq!(layers_for_size(8), 3);
    }
}