use crate::merkle::{Bagging, Error, Family, Location, Position};
use alloc::vec::Vec;
use commonware_cryptography::{Digest, Hasher as CHasher};
use core::marker::PhantomData;
/// Hashing interface for computing the leaf, internal-node, and root digests of a Merkle
/// structure of family `F`.
///
/// Implementors must provide [`Hasher::hash`] and [`Hasher::root_bagging`]; every other method
/// has a default implementation built on top of those two.
pub trait Hasher<F: Family>: Clone + Send + Sync {
    /// Digest type produced by this hasher.
    type Digest: Digest;

    /// Hashes `parts`, which are fed to the underlying hash function in iteration order.
    fn hash<'a>(&self, parts: impl IntoIterator<Item = &'a [u8]>) -> Self::Digest;

    /// Strategy used to bag the active peaks into a root
    /// (see [`Hasher::root_with_folded_peaks`]).
    fn root_bagging(&self) -> Bagging;

    /// Digest of the internal node at `pos` with children `left` and `right`, computed as
    /// `hash(pos_be || left || right)` where `pos_be` is the big-endian encoding of `pos`.
    fn node_digest(
        &self,
        pos: Position<F>,
        left: &Self::Digest,
        right: &Self::Digest,
    ) -> Self::Digest {
        self.hash([
            (*pos).to_be_bytes().as_slice(),
            left.as_ref(),
            right.as_ref(),
        ])
    }

    /// Digest of the leaf at `pos` holding `element`, computed as `hash(pos_be || element)`.
    fn leaf_digest(&self, pos: Position<F>, element: &[u8]) -> Self::Digest {
        self.hash([(*pos).to_be_bytes().as_slice(), element])
    }

    /// Digest of a single byte slice.
    fn digest(&self, data: &[u8]) -> Self::Digest {
        self.hash(core::iter::once(data))
    }

    /// Combines an accumulator with the next peak: `hash(acc || peak)`.
    fn fold(&self, acc: &Self::Digest, peak: &Self::Digest) -> Self::Digest {
        self.hash([acc.as_ref(), peak.as_ref()])
    }

    /// Computes the root over `peak_digests` for a structure with `leaves` leaves. The first
    /// `inactive_peaks` peaks are folded into one accumulator, and that count is committed to
    /// in the root (see [`Hasher::root_with_folded_peaks`]).
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInactivePeaks`] if `inactive_peaks` exceeds the number of peaks.
    fn root<'a, I>(
        &self,
        leaves: Location<F>,
        inactive_peaks: usize,
        peak_digests: I,
    ) -> Result<Self::Digest, Error<F>>
    where
        I: IntoIterator<Item = &'a Self::Digest>,
        I::IntoIter: ExactSizeIterator,
    {
        let iter = peak_digests.into_iter();
        // Captured before the iterator is consumed so the error can report it.
        let peaks = iter.len();
        self.root_with_folded_peaks(leaves, inactive_peaks, inactive_peaks, iter)
            .ok_or(Error::InvalidInactivePeaks {
                requested: inactive_peaks,
                peaks,
            })
    }

    /// Computes a root, folding the first `inactive_peaks_to_fold` peaks left-to-right into a
    /// single accumulator while committing to `committed_inactive_peaks` in the root preimage.
    /// The two counts are independent; this method does not check one against the other.
    ///
    /// The accumulator and the remaining peaks are then bagged per [`Hasher::root_bagging`]:
    /// - `ForwardFold`: `fold(...fold(fold(acc, p1), p2)..., pn)`
    /// - `BackwardFold`: `fold(acc, fold(p1, ... fold(pn-1, pn)))`
    ///
    /// The root is `hash(leaves_be || bagged)` when `committed_inactive_peaks` is zero and
    /// `hash(leaves_be || committed_be || bagged)` otherwise, where `committed_be` is the
    /// 8-byte big-endian count. With no peaks at all the root is `digest(leaves_be)`.
    ///
    /// Returns `None` if `inactive_peaks_to_fold` exceeds the number of peaks, or if there are
    /// no peaks and `committed_inactive_peaks` is non-zero.
    fn root_with_folded_peaks<'a>(
        &self,
        leaves: Location<F>,
        inactive_peaks_to_fold: usize,
        committed_inactive_peaks: usize,
        peak_digests: impl IntoIterator<Item = &'a Self::Digest>,
    ) -> Option<Self::Digest> {
        let mut peak_digests = peak_digests.into_iter();
        let Some(first) = peak_digests.next() else {
            // With no peaks there is nothing to fold or commit to.
            return (inactive_peaks_to_fold == 0 && committed_inactive_peaks == 0)
                .then(|| self.digest(&(*leaves).to_be_bytes()));
        };

        // Collapse the inactive prefix left-to-right. Folding zero or one peak leaves `first`
        // as-is; running out of peaks means the requested prefix is too long.
        let mut acc = *first;
        for _ in 0..inactive_peaks_to_fold.saturating_sub(1) {
            let peak = peak_digests.next()?;
            acc = self.fold(&acc, peak);
        }

        let folded_peaks = match self.root_bagging() {
            Bagging::ForwardFold => {
                for peak in peak_digests {
                    acc = self.fold(&acc, peak);
                }
                acc
            }
            Bagging::BackwardFold => {
                // Right-to-left bagging needs the remaining peaks buffered because the input is
                // only walkable forward; `collect` sizes the buffer without any manual
                // capacity arithmetic on the iterator's size hint.
                let rest: Vec<Self::Digest> = peak_digests.copied().collect();
                // `reduce` seeds with the last peak and folds each earlier peak onto it.
                match rest
                    .into_iter()
                    .rev()
                    .reduce(|folded, peak| self.fold(&peak, &folded))
                {
                    Some(folded_rest) => self.fold(&acc, &folded_rest),
                    None => acc,
                }
            }
        };

        let leaves_be = (*leaves).to_be_bytes();
        if committed_inactive_peaks == 0 {
            Some(self.hash([leaves_be.as_slice(), folded_peaks.as_ref()]))
        } else {
            // `usize` -> `u64` is lossless wherever `usize` is at most 64 bits wide.
            let committed_be = (committed_inactive_peaks as u64).to_be_bytes();
            Some(self.hash([
                leaves_be.as_slice(),
                committed_be.as_slice(),
                folded_peaks.as_ref(),
            ]))
        }
    }
}
/// Default [`Hasher`] backed by the `commonware_cryptography` hash function `H`.
///
/// Apart from the bagging strategy the hasher holds no state: every [`Standard::hash`] call
/// starts from a freshly constructed `H`.
#[derive(Clone)]
pub struct Standard<H: CHasher> {
    /// Strategy reported by [`Standard::root_bagging`].
    bagging: Bagging,
    /// Ties the struct to `H` without storing an instance of it.
    _hasher: PhantomData<H>,
}

impl<H: CHasher> Standard<H> {
    /// Creates a hasher that bags root peaks with `bagging`.
    pub const fn new(bagging: Bagging) -> Self {
        Self {
            bagging,
            _hasher: PhantomData,
        }
    }

    /// Returns the peak-bagging strategy chosen at construction.
    pub const fn root_bagging(&self) -> Bagging {
        self.bagging
    }

    /// Hashes `parts` with a fresh `H`, feeding each slice in iteration order.
    pub fn hash<'a>(&self, parts: impl IntoIterator<Item = &'a [u8]>) -> H::Digest {
        let mut state = H::new();
        parts.into_iter().for_each(|chunk| {
            state.update(chunk);
        });
        state.finalize()
    }

    /// Hashes a single byte slice.
    pub fn digest(&self, data: &[u8]) -> H::Digest {
        self.hash([data])
    }
}
impl<F: Family, H: CHasher> Hasher<F> for Standard<H> {
type Digest = H::Digest;
fn hash<'a>(&self, parts: impl IntoIterator<Item = &'a [u8]>) -> H::Digest {
Self::hash(self, parts)
}
fn root_bagging(&self) -> Bagging {
Self::root_bagging(self)
}
}
impl<F: Family, T: Hasher<F>> Hasher<F> for &T {
type Digest = T::Digest;
fn hash<'a>(&self, parts: impl IntoIterator<Item = &'a [u8]>) -> Self::Digest {
(**self).hash(parts)
}
fn root_bagging(&self) -> Bagging {
(**self).root_bagging()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::{
        mmr::{Location, Position, StandardHasher as Standard},
        Bagging::{BackwardFold, ForwardFold},
    };
    use alloc::vec::Vec;
    use commonware_cryptography::{sha256, Hasher as CHasher, Sha256};

    /// Merkle family used where a trait method's family cannot be inferred from its arguments.
    type MmrFamily = crate::merkle::mmr::Family;

    #[test]
    fn test_leaf_digest_sha256() {
        test_leaf_digest::<Sha256>();
    }

    #[test]
    fn test_node_digest_sha256() {
        test_node_digest::<Sha256>();
    }

    #[test]
    fn test_root_sha256() {
        test_root::<Sha256>();
    }

    #[test]
    fn test_invalid_inactive_prefix_returns_err() {
        let mmr_hasher: Standard<Sha256> = Standard::new(BackwardFold);
        let d1 = test_digest::<Sha256>(1);
        let d2 = test_digest::<Sha256>(2);
        let digests = [d1, d2];
        assert!(matches!(
            <Standard<Sha256> as Hasher<crate::merkle::mmr::Family>>::root(
                &mmr_hasher,
                Location::new(2),
                3,
                digests.iter()
            ),
            Err(crate::merkle::Error::InvalidInactivePeaks {
                requested: 3,
                peaks: 2
            })
        ));
        assert!(
            <Standard<Sha256> as Hasher<crate::merkle::mmr::Family>>::root_with_folded_peaks(
                &mmr_hasher,
                Location::new(2),
                3,
                3,
                digests.iter()
            )
            .is_none()
        );
        assert!(matches!(
            <Standard<Sha256> as Hasher<crate::merkle::mmr::Family>>::root(
                &mmr_hasher,
                Location::new(0),
                1,
                Vec::<sha256::Digest>::new().iter()
            ),
            Err(crate::merkle::Error::InvalidInactivePeaks {
                requested: 1,
                peaks: 0
            })
        ));
    }

    /// Backward bagging nests to the right, so it must agree with forward bagging over the
    /// pre-folded right-hand side, and disagree with plain forward bagging for 3+ peaks.
    #[test]
    fn test_root_backward_fold_sha256() {
        let forward: Standard<Sha256> = Standard::new(ForwardFold);
        let backward: Standard<Sha256> = Standard::new(BackwardFold);
        let d1 = test_digest::<Sha256>(1);
        let d2 = test_digest::<Sha256>(2);
        let d3 = test_digest::<Sha256>(3);

        // [d1, d2, d3] bagged backward is fold(d1, fold(d2, d3)), which is exactly what
        // forward bagging yields for [d1, fold(d2, d3)].
        let right = <Standard<Sha256> as Hasher<MmrFamily>>::fold(&backward, &d2, &d3);
        let backward_root = backward
            .root(Location::new(10), 0, [d1, d2, d3].iter())
            .expect("zero inactive peaks is always valid");
        let forward_equivalent = forward
            .root(Location::new(10), 0, [d1, right].iter())
            .expect("zero inactive peaks is always valid");
        assert_eq!(backward_root, forward_equivalent);

        let forward_root = forward
            .root(Location::new(10), 0, [d1, d2, d3].iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(
            backward_root, forward_root,
            "bagging direction should affect the root for three or more peaks"
        );
    }

    /// Inactive peaks are folded left-to-right before bagging, and their count is committed.
    #[test]
    fn test_root_inactive_prefix_sha256() {
        let mmr_hasher: Standard<Sha256> = Standard::new(BackwardFold);
        let d1 = test_digest::<Sha256>(1);
        let d2 = test_digest::<Sha256>(2);
        let d3 = test_digest::<Sha256>(3);
        let d4 = test_digest::<Sha256>(4);
        let digests = [d1, d2, d3, d4];

        // Folding two inactive peaks must match bagging the pre-folded prefix while still
        // committing to a count of two.
        let prefix = <Standard<Sha256> as Hasher<MmrFamily>>::fold(&mmr_hasher, &d1, &d2);
        let root = mmr_hasher
            .root(Location::new(10), 2, digests.iter())
            .expect("two inactive peaks out of four is valid");
        let expected = <Standard<Sha256> as Hasher<MmrFamily>>::root_with_folded_peaks(
            &mmr_hasher,
            Location::new(10),
            0,
            2,
            [prefix, d3, d4].iter(),
        )
        .expect("folding zero peaks is always valid");
        assert_eq!(root, expected);

        // Folding a single peak is a no-op, but the committed count still changes the root.
        let zero = mmr_hasher
            .root(Location::new(10), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        let one = mmr_hasher
            .root(Location::new(10), 1, digests.iter())
            .expect("one inactive peak out of four is valid");
        assert_ne!(zero, one, "committed inactive count should be part of the root");

        // Every peak may be inactive.
        assert!(mmr_hasher
            .root(Location::new(10), 4, digests.iter())
            .is_ok());
    }

    /// A borrowed hasher must produce the same digests as the owned one.
    #[test]
    fn test_reference_hasher_matches_owned_sha256() {
        let mmr_hasher: Standard<Sha256> = Standard::new(BackwardFold);
        let by_ref = &mmr_hasher;
        let d1 = test_digest::<Sha256>(1);
        let d2 = test_digest::<Sha256>(2);
        let d3 = test_digest::<Sha256>(3);
        let digests = [d1, d2, d3];

        assert_eq!(
            <&Standard<Sha256> as Hasher<MmrFamily>>::leaf_digest(
                &by_ref,
                Position::new(3),
                &d1
            ),
            mmr_hasher.leaf_digest(Position::new(3), &d1)
        );
        assert_eq!(
            <&Standard<Sha256> as Hasher<MmrFamily>>::root(
                &by_ref,
                Location::new(5),
                1,
                digests.iter()
            )
            .expect("one inactive peak out of three is valid"),
            mmr_hasher
                .root(Location::new(5), 1, digests.iter())
                .expect("one inactive peak out of three is valid")
        );
    }

    fn test_digest<H: CHasher>(value: u8) -> H::Digest {
        H::hash(&[value])
    }

    fn test_leaf_digest<H: CHasher>() {
        let mmr_hasher: Standard<H> = Standard::new(ForwardFold);
        let digest1 = test_digest::<H>(1);
        let digest2 = test_digest::<H>(2);
        let out = mmr_hasher.leaf_digest(Position::new(0), &digest1);
        assert_ne!(out, test_digest::<H>(0), "hash should be non-zero");
        let mut out2 = mmr_hasher.leaf_digest(Position::new(0), &digest1);
        assert_eq!(out, out2, "hash should be re-computed consistently");
        out2 = mmr_hasher.leaf_digest(Position::new(1), &digest1);
        assert_ne!(out, out2, "hash should change with different pos");
        out2 = mmr_hasher.leaf_digest(Position::new(0), &digest2);
        assert_ne!(out, out2, "hash should change with different input digest");
    }

    fn test_node_digest<H: CHasher>() {
        let mmr_hasher: Standard<H> = Standard::new(ForwardFold);
        let d1 = test_digest::<H>(1);
        let d2 = test_digest::<H>(2);
        let d3 = test_digest::<H>(3);
        let out = mmr_hasher.node_digest(Position::new(0), &d1, &d2);
        assert_ne!(out, test_digest::<H>(0), "hash should be non-zero");
        let mut out2 = mmr_hasher.node_digest(Position::new(0), &d1, &d2);
        assert_eq!(out, out2, "hash should be re-computed consistently");
        out2 = mmr_hasher.node_digest(Position::new(1), &d1, &d2);
        assert_ne!(out, out2, "hash should change with different pos");
        out2 = mmr_hasher.node_digest(Position::new(0), &d3, &d2);
        assert_ne!(
            out, out2,
            "hash should change with different first input hash"
        );
        out2 = mmr_hasher.node_digest(Position::new(0), &d1, &d3);
        assert_ne!(
            out, out2,
            "hash should change with different second input hash"
        );
        out2 = mmr_hasher.node_digest(Position::new(0), &d2, &d1);
        assert_ne!(
            out, out2,
            "hash should change when swapping order of inputs"
        );
    }

    fn test_root<H: CHasher>() {
        let mmr_hasher: Standard<H> = Standard::new(ForwardFold);
        let d1 = test_digest::<H>(1);
        let d2 = test_digest::<H>(2);
        let d3 = test_digest::<H>(3);
        let d4 = test_digest::<H>(4);
        let empty_vec: Vec<H::Digest> = Vec::new();
        let empty_out = mmr_hasher
            .root(Location::new(0), 0, empty_vec.iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(
            empty_out,
            test_digest::<H>(0),
            "root of empty MMR should be non-zero"
        );
        assert_eq!(
            empty_out,
            mmr_hasher
                .root(Location::new(0), 0, empty_vec.iter())
                .expect("zero inactive peaks is always valid")
        );
        let digests = [d1, d2, d3, d4];
        let out = mmr_hasher
            .root(Location::new(10), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(out, test_digest::<H>(0), "root should be non-zero");
        assert_ne!(out, empty_out, "root should differ from empty MMR");
        let mut out2 = mmr_hasher
            .root(Location::new(10), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        assert_eq!(out, out2, "root should be computed consistently");
        out2 = mmr_hasher
            .root(Location::new(11), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(out, out2, "root should change with different position");
        let digests = [d1, d2, d4, d3];
        out2 = mmr_hasher
            .root(Location::new(10), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(out, out2, "root should change with different digest order");
        let digests = [d1, d2, d3];
        out2 = mmr_hasher
            .root(Location::new(10), 0, digests.iter())
            .expect("zero inactive peaks is always valid");
        assert_ne!(
            out, out2,
            "root should change with different number of hashes"
        );
    }
}