use super::{Digest, Hashable, NounDecode, NounEncode};
use alloc::vec::Vec;
#[cfg(feature = "wasm")]
use alloc::{boxed::Box, format, string::ToString};
use serde::{Deserialize, Serialize};
/// A Merkle inclusion proof: the root digest a leaf must hash up to, plus the
/// sibling digests needed to recompute that root from the leaf.
///
/// `path` is ordered from the leaf upward: `path[0]` is the leaf's sibling and
/// the last entry is the sibling at the level just below the root. This is the
/// order produced by [`MerkleProof::prove_hashable`] and consumed by
/// [`MerkleProof::verify`].
#[derive(Debug, Clone, NounEncode, NounDecode, Hashable, Serialize, Deserialize)]
#[iris_ztd_derive::wasm_noun_codec]
pub struct MerkleProof {
    /// Digest that verification must reproduce.
    pub root: Digest,
    /// Sibling digests, leaf level first, root-most level last.
    pub path: Vec<Digest>,
}
/// A [`MerkleProof`] paired with the axis (tree address) of the proven leaf.
///
/// Axes number the tree from the root: the root is `1`, and the children of
/// axis `n` are `2n` (left) and `2n + 1` (right). Pass `axis` back to
/// [`MerkleProof::verify`] together with the leaf to check the proof.
#[derive(Debug, Clone, NounEncode, NounDecode, Hashable, Serialize, Deserialize)]
#[iris_ztd_derive::wasm_noun_codec]
pub struct MerkleProvenAxis {
    /// The inclusion proof itself.
    pub proof: MerkleProof,
    /// Axis of the proven leaf within the tree rooted at `proof.root`.
    pub axis: u64,
}
impl MerkleProof {
    /// Builds an inclusion proof for one leaf of `item`'s hash tree.
    ///
    /// Anything whose `hashable_pair()` is `None` is a leaf. A leaf is its own
    /// root, so its proof has an empty path and axis `1`. For a pair, `index`
    /// is compared against the left subtree's `leaf_count()`:
    /// - if it is smaller, the proof descends left with the same `index`;
    /// - otherwise it descends right with `index - leaf_count`.
    ///
    /// On the way back up, each level:
    /// - folds the sibling's digest into the root;
    /// - appends the sibling to `path`, so the path runs leaf first, root-most last;
    /// - re-addresses the axis one level deeper.
    ///
    /// `index` is not range-checked. NOTE(review): an out-of-range index
    /// appears to end at the rightmost leaf instead of failing; confirm
    /// whether callers rely on that.
    ///
    /// # Panics
    ///
    /// Panics if the proven leaf is more than 63 levels below the root,
    /// because its axis would not fit in a `u64`.
    pub fn prove_hashable<T: Hashable>(item: &T, index: usize) -> MerkleProvenAxis {
        let Some((left, right)) = item.hashable_pair() else {
            return MerkleProvenAxis {
                proof: Self {
                    root: item.hash(),
                    path: Vec::new(),
                },
                axis: 1,
            };
        };
        let lc = left.leaf_count();
        if index < lc {
            let mut rec = Self::prove_hashable(&left, index);
            let sib = right.hash();
            rec.proof.root = (rec.proof.root, sib).hash();
            rec.proof.path.push(sib);
            rec.axis = Self::nest_axis(rec.axis, false);
            rec
        } else {
            let mut rec = Self::prove_hashable(&right, index - lc);
            let sib = left.hash();
            rec.proof.root = (sib, rec.proof.root).hash();
            rec.proof.path.push(sib);
            rec.axis = Self::nest_axis(rec.axis, true);
            rec
        }
    }

    /// Re-addresses `axis`, an axis inside a child subtree, as an axis inside
    /// the parent. `is_right` says whether that child is the parent's right
    /// (`true`) or left (`false`) subtree.
    ///
    /// In binary this inserts a `0` (left) or a `1` (right) directly after the
    /// leading `1` bit. For example, child axis `0b10` becomes `0b100` (left)
    /// or `0b110` (right).
    ///
    /// # Panics
    ///
    /// Panics if `axis` already occupies bit 63, since the result would need
    /// 65 bits. Without this check the shift silently drops the new leading
    /// bit and yields a wrong axis.
    fn nest_axis(axis: u64, is_right: bool) -> u64 {
        let alz = axis.leading_zeros();
        assert!(
            alz > 0,
            "Merkle proof is deeper than 63 levels; its axis does not fit in a u64"
        );
        // `top` is the position of the current leading 1 bit.
        // - XOR with `0b11 << top` moves that bit up one place, leaving a 0 behind (left).
        // - XOR with `0b10 << top` adds a new bit above it, leaving a 1 behind (right).
        let top = 63 - alz;
        let mask: u64 = if is_right { 0b10 } else { 0b11 };
        axis ^ (mask << top)
    }

    /// Checks that `hashable` sits at `axis` in the tree rooted at `self.root`.
    ///
    /// The walk starts from `hashable.hash()` and consumes one path entry per
    /// level while halving the axis until it reaches the root (axis `1`):
    /// - an even axis is a left child, so the sibling is hashed on the right;
    /// - an odd axis is a right child, so the sibling is hashed on the left.
    ///
    /// Returns `true` only if the walk ends at axis `1`, uses up the entire
    /// path, and reproduces `self.root`. An `axis` of `0` is always rejected.
    pub fn verify(&self, mut axis: u64, hashable: &impl Hashable) -> bool {
        let mut node = hashable.hash();
        let mut path = &self.path[..];
        while axis > 1 {
            let Some((sib, rest)) = path.split_first() else {
                // Path is shorter than the depth implied by `axis`.
                return false;
            };
            path = rest;
            node = if axis.is_multiple_of(2) {
                (node, sib).hash()
            } else {
                (sib, node).hash()
            };
            axis /= 2;
        }
        // Leftover path entries mean the proof is longer than `axis` allows.
        axis == 1 && self.root == node && path.is_empty()
    }

    /// Performs the same walk as [`MerkleProof::verify`], but also records
    /// every `(axis, digest)` pair the walk touches.
    ///
    /// For each level, from the leaf upward, both children are recorded with
    /// the odd (right) axis first, then the even (left) one. The list ends
    /// with `(1, root)`. Returns `None` whenever `verify` would return `false`.
    pub fn visible_hashes(
        &self,
        mut axis: u64,
        hashable: &impl Hashable,
    ) -> Option<Vec<(u64, Digest)>> {
        let mut hashes = Vec::new();
        let mut node = hashable.hash();
        let mut path = &self.path[..];
        while axis > 1 {
            let (sib, rest) = path.split_first()?;
            path = rest;
            if axis.is_multiple_of(2) {
                hashes.push((axis ^ 1, *sib));
                hashes.push((axis, node));
                node = (node, sib).hash();
            } else {
                hashes.push((axis, node));
                hashes.push((axis ^ 1, *sib));
                node = (sib, node).hash();
            }
            axis /= 2;
        }
        if axis == 1 && self.root == node && path.is_empty() {
            hashes.push((1, node));
            Some(hashes)
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HashableList;
    use alloc::string::{String, ToString};

    /// Rendered digest of `()`. In every two-leaf fixture below the other half
    /// of the tuple is `()`, so this is the single sibling on the path.
    const UNIT_DIGEST: &str = "3Ssr4tiWsbX5CE3AG6p5qPHP51fiyvtt1XEEHmSbGgDjp3qjUew6DFB";

    /// Rendered root of the two-leaf tree `((), ())`.
    const UNIT_PAIR_ROOT: &str = "3LPSS51pUxLaxMD8VjyBSW6S9sotLpfx65zibBvm5k1xu18qt5ZGp3S";

    /// Renders each digest via `Display` so it can be compared to fixtures.
    fn digest_strings(digests: &[Digest]) -> Vec<String> {
        digests.iter().map(|d| d.to_string()).collect()
    }

    /// Runs `visible_hashes`, requires it to succeed, and renders the digests.
    fn rendered_visible(
        proof: &MerkleProof,
        axis: u64,
        item: &impl Hashable,
    ) -> Vec<(u64, String)> {
        proof
            .visible_hashes(axis, item)
            .unwrap()
            .into_iter()
            .map(|(a, h)| (a, h.to_string()))
            .collect()
    }

    /// Proves leaf `index` of `((), ())` and checks every output of the proof.
    fn check_unit_pair(index: usize, expected_axis: u64) {
        let MerkleProvenAxis { proof, axis } = MerkleProof::prove_hashable(&((), ()), index);
        assert_eq!(axis, expected_axis);
        assert_eq!(proof.root.to_string(), UNIT_PAIR_ROOT);
        assert_eq!(digest_strings(&proof.path), [UNIT_DIGEST]);
        assert!(proof.verify(axis, &()));
        // Both leaves are `()`, so the visible hashes match for either index.
        let unit = ().hash().to_string();
        assert_eq!(
            rendered_visible(&proof, axis, &()),
            &[(3, unit.clone()), (2, unit), (1, proof.root.to_string())]
        );
    }

    #[test]
    fn test_empty_proof() {
        let MerkleProvenAxis { proof, axis } = MerkleProof::prove_hashable(&(), 0);
        let unit = ().hash().to_string();
        assert_eq!(axis, 1);
        assert_eq!(proof.root.to_string(), unit);
        assert!(proof.path.is_empty());
        assert!(proof.verify(axis, &()));
        assert_eq!(rendered_visible(&proof, axis, &()), &[(1, unit)]);
    }

    #[test]
    fn test_left_proof() {
        check_unit_pair(0, 2);
    }

    #[test]
    fn test_right_proof() {
        check_unit_pair(1, 3);
    }

    #[test]
    fn test_complex_proof() {
        let tree = ((1u64, 2u64), (3u64, 4u64));
        let MerkleProvenAxis { proof, axis } = MerkleProof::prove_hashable(&tree, 2);
        assert_eq!(
            proof.root.to_string(),
            "9BC9gRQaJ7Ub4SivF6NmPBQrmqfwdKeDSkbkRjmnKf9yYscct3AcohH"
        );
        assert_eq!(
            digest_strings(&proof.path),
            [
                "CdEJceqNNH5iCGYEsWhRf2gHE37zbJXVkVPLpfWW7uYrJjt8magUvgi",
                "BqxDmSrtFP6QsDuoYxjaFxedEzGpy7gfwhmtZnD25FxeedB1ssNPH4t",
            ]
        );
        // Leaf 2 (`3u64`) is the left child of the right subtree (axis 3), i.e. axis 6.
        assert_eq!(axis, 6);
        assert!(proof.verify(axis, &3u64));
        assert_eq!(
            rendered_visible(&proof, axis, &3u64),
            &[
                (7, 4.hash().to_string()),
                (6, 3.hash().to_string()),
                (3, (3, 4).hash().to_string()),
                (2, (1, 2).hash().to_string()),
                (1, proof.root.to_string()),
            ]
        );
    }

    #[test]
    fn test_list_proof() {
        let items = [0u64, 1u64];
        let lst = HashableList(&items[..]);
        let MerkleProvenAxis { proof, axis } = MerkleProof::prove_hashable(&(&lst, ()), 0);
        assert_eq!(axis, 2);
        assert_eq!(
            proof.root.to_string(),
            "8cSTFCsmaL4KTMVq6RQSQaMMMQfb3YpT6xR1YmnRtG7P4WurnhRRDbM"
        );
        assert_eq!(digest_strings(&proof.path), [UNIT_DIGEST]);
        assert!(proof.verify(axis, &lst));
        assert_eq!(
            rendered_visible(&proof, axis, &lst),
            &[
                (3, ().hash().to_string()),
                (2, lst.hash().to_string()),
                (1, proof.root.to_string()),
            ]
        );
    }
}