use std::fmt::Debug;
use std::marker::PhantomData;
use crate::node_template::{HashNodeSmt, MTreeNodeSmt, SumNodeSmt};
use crate::pad_secret::ALL_ZEROS_SECRET;
use crate::{
index::{TreeIndex, MAX_HEIGHT},
node_template,
proof::{MerkleProof, RandomSamplingProof},
traits::{
InclusionProvable, Mergeable, Paddable, PaddingProvable, ProofExtractable, Rand,
RandomSampleable, Serializable, TypeName,
},
tree::SparseMerkleTree,
utils::{generate_sorted_index_value_pairs, print_output},
};
/// Short alias for the sparse Merkle tree type exercised throughout these tests.
type SMT<Node> = SparseMerkleTree<Node>;
/// Zero-sized driver type; its `impl` block hosts the generic SMT test suite
/// parameterised over the node type `Node`.
pub struct Tester<Node> {
    _phantom: PhantomData<Node>,
}
/// Height passed to `SMT::new` for every randomly generated test tree.
const TREE_HEIGHT: usize = 8;
/// Number of (index, value) pairs generated for each randomised test tree.
const LEAF_NUM: usize = 100;
/// Creating a `TreeIndex` for a height beyond `MAX_HEIGHT` is expected to panic.
#[test]
#[should_panic]
fn test_index_exceed_max_height() {
    let too_tall = MAX_HEIGHT + 1;
    let _oversized = TreeIndex::zero(too_tall);
}
/// Constructing a tree taller than `MAX_HEIGHT` is expected to panic.
#[test]
#[should_panic]
fn test_tree_exceed_max_height() {
    let too_tall = MAX_HEIGHT + 1;
    let _oversized: SMT<SumNodeSmt> = SMT::new(too_tall);
}
/// For many randomised positions, checks that padding nodes of `SumNodeSmt`
/// and `HashNodeSmt<blake3::Hasher>` yield padding proofs that verify against
/// the node's own proof node.
#[test]
fn test_padding_provable() {
    // Number of random positions probed.
    const ROUNDS: usize = 1000;

    let secret = &ALL_ZEROS_SECRET;
    let mut idx = TreeIndex::zero(256);
    for _ in 0..ROUNDS {
        // Re-randomise the index in place for this round.
        idx.randomize();

        let sum = SumNodeSmt::padding(&idx, secret);
        assert!(SumNodeSmt::verify_padding_node(
            &sum.get_proof_node(),
            &sum.prove_padding_node(&idx, secret),
            &idx,
        ));

        // Same check for hash nodes; `secret` is already a reference, so it is
        // passed as-is (consistent with the `SumNodeSmt` call above).
        let node = HashNodeSmt::<blake3::Hasher>::padding(&idx, secret);
        assert!(HashNodeSmt::<blake3::Hasher>::verify_padding_node(
            &node.get_proof_node(),
            &node.prove_padding_node(&idx, secret),
            &idx,
        ));
    }
}
impl<
P: Default
+ Clone
+ Mergeable
+ Paddable
+ ProofExtractable
+ Rand
+ TypeName
+ PaddingProvable,
> Tester<P>
where
<P as ProofExtractable>::ProofNode:
Debug + Clone + Default + Eq + Debug + Mergeable + Serializable,
<P as PaddingProvable>::PaddingProof: Clone + Default + Eq + Serializable,
{
fn test_building_smt(list: &[(TreeIndex, P)]) -> SMT<P> {
let secret = &ALL_ZEROS_SECRET;
let mut build_tree = SMT::new(TREE_HEIGHT);
build_tree.build(&list, secret);
let mut update_tree = SMT::new(TREE_HEIGHT);
for item in list.iter() {
update_tree.update(&item.0, item.1.clone(), secret);
}
assert_eq!(build_tree.get_root(), update_tree.get_root());
assert_eq!(
build_tree.get_leaves().len(),
update_tree.get_leaves().len()
);
assert_eq!(
build_tree.get_paddings().len(),
update_tree.get_paddings().len()
);
assert_eq!(
build_tree.get_internals().len(),
update_tree.get_internals().len()
);
build_tree
}
fn merkle_proof_existing(tree: &SMT<P>, leaves: &[P::ProofNode], list: &[TreeIndex]) -> bool {
let proof = MerkleProof::<P>::generate_inclusion_proof(&tree, list);
match proof {
None => unreachable!(),
Some(proof) => {
let serialized_proof = proof.serialize();
let deserialized_proof = MerkleProof::<P>::deserialize(&serialized_proof).unwrap();
deserialized_proof.verify_inclusion_proof(leaves, &tree.get_root())
}
}
}
fn test_merkle_proof(list: &[(TreeIndex, P)], tree: &SMT<P>) {
for item in list.iter() {
assert!(Tester::<P>::merkle_proof_existing(
tree,
&[item.1.get_proof_node()],
&[item.0]
));
}
}
fn test_merkle_proof_batch(list: &[(TreeIndex, P)], tree: &SMT<P>) {
assert!(Tester::<P>::merkle_proof_existing(tree, &[], &[]));
for batch_size in &[1, 100, list.len()] {
for i in 0..LEAF_NUM / batch_size {
let mut proof_list = Vec::new();
let mut leaves = Vec::new();
for j in 0..*batch_size {
proof_list.push(list[i * batch_size + j].0);
leaves.push(list[i * batch_size + j].1.get_proof_node());
}
assert!(Tester::<P>::merkle_proof_existing(
tree,
&leaves,
&proof_list
));
}
}
}
fn random_sampling(tree: &SMT<P>, idx: &TreeIndex) -> bool {
let secret = &ALL_ZEROS_SECRET;
let proof = RandomSamplingProof::<P>::random_sampling(tree, idx, secret);
let serialized = proof.serialize();
let deserialized = RandomSamplingProof::<P>::deserialize(&serialized).unwrap();
deserialized.verify_random_sampling_proof(&tree.get_root())
}
fn test_random_sampling(list: &[(TreeIndex, P)], tree: &SMT<P>) {
let secret = &ALL_ZEROS_SECRET;
for item in list.iter() {
assert!(Tester::<P>::random_sampling(tree, &item.0));
}
let empty_tree: SMT<P> = SMT::new(TREE_HEIGHT);
assert!(Tester::<P>::random_sampling(&empty_tree, &list[0].0));
let index = list[0].0.get_left_index();
if let Some(index) = index {
assert!(Tester::<P>::random_sampling(tree, &index));
let proof = RandomSamplingProof::<P>::random_sampling(tree, &index, secret);
assert_eq!(proof.get_merkle_proof().get_indexes().len(), 1);
assert_eq!(proof.get_merkle_proof().get_indexes()[0], list[0].0);
}
let index = list[list.len() - 1].0.get_right_index();
if let Some(index) = index {
assert!(Tester::<P>::random_sampling(tree, &index));
let proof = RandomSamplingProof::<P>::random_sampling(tree, &index, secret);
assert_eq!(proof.get_merkle_proof().get_indexes().len(), 1);
assert_eq!(
proof.get_merkle_proof().get_indexes()[0],
list[list.len() - 1].0
);
}
for i in 1..list.len() {
let index = list[i].0.get_left_index().unwrap();
if index > list[i - 1].0 {
assert!(Tester::<P>::random_sampling(tree, &index));
let proof = RandomSamplingProof::<P>::random_sampling(tree, &index, secret);
assert_eq!(proof.get_merkle_proof().get_indexes().len(), 2);
assert_eq!(proof.get_merkle_proof().get_indexes()[0], list[i - 1].0);
assert_eq!(proof.get_merkle_proof().get_indexes()[1], list[i].0);
}
let index = list[i - 1].0.get_right_index().unwrap();
if index < list[i].0 {
assert!(Tester::<P>::random_sampling(tree, &index));
let proof = RandomSamplingProof::<P>::random_sampling(tree, &index, secret);
assert_eq!(proof.get_merkle_proof().get_indexes().len(), 2);
assert_eq!(proof.get_merkle_proof().get_indexes()[0], list[i - 1].0);
assert_eq!(proof.get_merkle_proof().get_indexes()[1], list[i].0);
}
}
}
pub fn test() {
for _iter in 0..10 {
println!(
"Test #{} for SMT({}) with {} leaves of {} starts!",
_iter,
TREE_HEIGHT,
LEAF_NUM,
P::get_name()
);
let list: Vec<(TreeIndex, P)> =
generate_sorted_index_value_pairs(TREE_HEIGHT, LEAF_NUM);
let tree = Tester::<P>::test_building_smt(&list);
Tester::<P>::test_merkle_proof(&list, &tree);
Tester::<P>::test_merkle_proof_batch(&list, &tree);
Tester::<P>::test_random_sampling(&list, &tree);
println!("Succeed!");
}
}
}
/// Runs the generic `Tester` suite for each node type used in this crate's tests.
#[test]
fn test_smt() {
    // Summing nodes.
    Tester::<SumNodeSmt>::test();

    // Hash nodes, once per digest implementation.
    Tester::<HashNodeSmt<blake3::Hasher>>::test();
    Tester::<HashNodeSmt<blake2::Blake2b>>::test();
    Tester::<HashNodeSmt<sha2::Sha256>>::test();
    Tester::<HashNodeSmt<sha3::Sha3_256>>::test();
}
/// Builds a tree with `new_merkle_tree` from five identical leaves and checks:
/// - its height and padding count;
/// - an inclusion proof for index 2;
/// - that the proof survives a serialization round-trip.
#[test]
fn test_merkle_tree() {
    type Node = MTreeNodeSmt<blake3::Hasher>;

    let example_leaf = Node::new(vec![0; 32]);
    let list: Vec<Node> = vec![example_leaf.clone(); 5];
    let tree = SMT::<Node>::new_merkle_tree(&list);

    // Expected shape for five leaves.
    assert_eq!(tree.get_height(), 3);
    assert_eq!(tree.get_paddings().len(), 2);

    let index_list = [TreeIndex::from_u64(tree.get_height(), 2)];
    let proof = MerkleProof::<Node>::generate_inclusion_proof(&tree, &index_list)
        .expect("inclusion proof for leaf index 2 should exist");
    assert!(proof.verify(&example_leaf, &tree.get_root()));

    // Serialization must round-trip to identical bytes and keep the proof valid.
    let serialized_proof = proof.serialize();
    let deserialized_proof = MerkleProof::<Node>::deserialize(&serialized_proof).unwrap();
    assert_eq!(serialized_proof, deserialized_proof.serialize());
    assert!(deserialized_proof.verify(&example_leaf, &tree.get_root()));

    print_output(&tree);
}