curv/cryptographic_primitives/hashing/
merkle_tree.rs1use std::marker::PhantomData;
9
10use digest::{Digest, Output};
11use merkle_cbt::merkle_tree::{Merge, MerkleProof, MerkleTree, CBMT};
12use serde::{Deserialize, Serialize};
13
14use crate::cryptographic_primitives::hashing::DigestExt;
15use crate::cryptographic_primitives::proofs::ProofError;
16use crate::elliptic::curves::{Curve, Point};
17
/// Merkle tree whose leaves are `H`-digests of elliptic-curve points.
///
/// Leaves are hashed with a fresh `H` via `DigestExt::chain_point`, and interior
/// nodes are combined with `MergeDigest<H>` (see `create_tree`). The original
/// points are retained so that `build_proof` can locate a leaf by value.
pub struct MT256<E: Curve, H: Digest> {
    /// The underlying `merkle_cbt` tree built over the leaf digests.
    tree: MerkleTree<Output<H>, MergeDigest<H>>,
    /// The points the tree was built from, in the same order as the digests
    /// passed to `build_merkle_tree`.
    leaves: Vec<Point<E>>,
}

23impl<E: Curve, H: Digest + Clone> MT256<E, H> {
24 pub fn create_tree(leaves: Vec<Point<E>>) -> Self {
25 let hashes = leaves
26 .iter()
27 .map(|leaf| H::new().chain_point(leaf).finalize())
28 .collect::<Vec<_>>();
29
30 MT256 {
31 tree: CBMT::<Output<H>, MergeDigest<H>>::build_merkle_tree(&hashes),
32 leaves,
33 }
34 }
35
36 pub fn build_proof(&self, point: Point<E>) -> Option<Proof<E, H>> {
37 let index = (0u32..)
38 .zip(&self.leaves)
39 .find(|(_, leaf)| **leaf == point)
40 .map(|(i, _)| i)?;
41 let proof = self.tree.build_proof(&[index])?;
42 Some(Proof {
43 index: proof.indices()[0],
44 lemmas: proof.lemmas().to_vec(),
45 point,
46 })
47 }
48
49 pub fn get_root(&self) -> Output<H> {
50 self.tree.root()
51 }
52}
53
/// Inclusion proof for a single point in an `MT256` tree.
///
/// Produced by `MT256::build_proof` and checked with `Proof::verify` against a
/// tree root. The serde bounds are given explicitly and mention only
/// `Output<H>`; they replace the bounds the derive would otherwise infer.
/// Field order is part of the serialized form in positional formats, so it
/// should not be changed casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Output<H>: Serialize",
    deserialize = "Output<H>: Deserialize<'de>"
))]
pub struct Proof<E: Curve, H: Digest> {
    /// Index taken from the `merkle_cbt` proof (`indices()[0]`) and passed back
    /// to `MerkleProof::new` during verification.
    /// NOTE(review): merkle_cbt appears to store tree-node indices here rather
    /// than the leaf's position in the input vector — confirm before treating
    /// this as a leaf position.
    pub index: u32,
    /// Lemma digests copied from the `merkle_cbt` proof (presumably the sibling
    /// nodes on the path to the root).
    pub lemmas: Vec<Output<H>>,
    /// The leaf point whose membership this proof attests.
    pub point: Point<E>,
}

65impl<E: Curve, H: Digest + Clone> Proof<E, H> {
66 pub fn verify(&self, root: &Output<H>) -> Result<(), ProofError> {
67 let leaf = H::new().chain_point(&self.point).finalize();
68 let valid =
69 MerkleProof::<Output<H>, MergeDigest<H>>::new(vec![self.index], self.lemmas.clone())
70 .verify(root, &[leaf]);
71 if valid {
72 Ok(())
73 } else {
74 Err(ProofError)
75 }
76 }
77}
78
/// `merkle_cbt` merge strategy that combines two child digests using `D`.
///
/// Zero-sized: the `PhantomData<D>` only carries the digest type.
struct MergeDigest<D>(PhantomData<D>);

81impl<D> Merge for MergeDigest<D>
82where
83 D: Digest,
84{
85 type Item = Output<D>;
86
87 fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item {
88 D::new().chain(left).chain(right).finalize()
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::MT256;
    use crate::elliptic::curves::{Curve, Point};

    use crate::test_for_all_curves;

    test_for_all_curves!(test_mt_functionality_four_leaves);

    /// A proof for a leaf in a four-leaf tree (with a duplicated leaf) verifies
    /// against that tree's root.
    fn test_mt_functionality_four_leaves<E: Curve>() {
        let ge1: Point<E> = Point::generator().to_point();
        let ge2: Point<E> = ge1.clone();
        let ge3: Point<E> = &ge1 + &ge2;
        let ge4: Point<E> = &ge1 + &ge3;
        let ge_vec = vec![ge1.clone(), ge2, ge3, ge4];
        let mt256 = MT256::<_, sha3::Keccak256>::create_tree(ge_vec);
        let proof1 = mt256.build_proof(ge1).unwrap();
        let root = mt256.get_root();
        proof1.verify(&root).expect("proof is invalid");
    }

    test_for_all_curves!(test_mt_functionality_three_leaves);

    /// Same as above, but with an odd number of leaves.
    fn test_mt_functionality_three_leaves<E: Curve>() {
        let ge1: Point<E> = Point::generator().to_point();
        let ge2: Point<E> = ge1.clone();
        let ge3: Point<E> = &ge1 + &ge2;

        let ge_vec = vec![ge1.clone(), ge2, ge3];
        let mt256 = MT256::<_, sha3::Keccak256>::create_tree(ge_vec);
        let proof1 = mt256.build_proof(ge1).unwrap();
        let root = mt256.get_root();
        proof1.verify(&root).expect("proof is invalid");
    }

    test_for_all_curves!(test_mt_rejects_wrong_root);

    /// A proof must not verify against the root of a different tree.
    fn test_mt_rejects_wrong_root<E: Curve>() {
        let g: Point<E> = Point::generator().to_point();
        let g2: Point<E> = &g + &g;
        let g3: Point<E> = &g + &g2;

        let tree = MT256::<_, sha3::Keccak256>::create_tree(vec![g.clone(), g2.clone()]);
        let other = MT256::<_, sha3::Keccak256>::create_tree(vec![g3, g2]);
        let proof = tree.build_proof(g).unwrap();
        assert!(proof.verify(&other.get_root()).is_err());
    }

    test_for_all_curves!(test_mt_rejects_tampered_point);

    /// Swapping the proved point for another one must invalidate the proof.
    fn test_mt_rejects_tampered_point<E: Curve>() {
        let g: Point<E> = Point::generator().to_point();
        let g2: Point<E> = &g + &g;
        let g3: Point<E> = &g + &g2;

        let tree = MT256::<_, sha3::Keccak256>::create_tree(vec![g.clone(), g2]);
        let mut proof = tree.build_proof(g).unwrap();
        proof.point = g3;
        assert!(proof.verify(&tree.get_root()).is_err());
    }

    test_for_all_curves!(test_mt_build_proof_for_missing_point);

    /// Requesting a proof for a point that is not a leaf yields `None`.
    fn test_mt_build_proof_for_missing_point<E: Curve>() {
        let g: Point<E> = Point::generator().to_point();
        let g2: Point<E> = &g + &g;

        let tree = MT256::<_, sha3::Keccak256>::create_tree(vec![g]);
        assert!(tree.build_proof(g2).is_none());
    }
}