zeekit/merkle/
mod.rs

1#[cfg(feature = "plonk")]
2pub mod plonk;
3
4#[cfg(feature = "groth16")]
5pub mod groth16;
6
7use super::config::LOG_TREE_SIZE;
8use crate::{mimc, Fr};
9use ff::Field;
10use std::collections::HashMap;
11
12pub struct SparseTree {
13    defaults: Vec<Fr>,
14    levels: Vec<HashMap<u64, Fr>>,
15}
16
17#[derive(Debug, Clone)]
18pub struct Proof(pub [Fr; LOG_TREE_SIZE]);
19impl Default for Proof {
20    fn default() -> Self {
21        Self([Fr::zero(); LOG_TREE_SIZE])
22    }
23}
24
25impl SparseTree {
26    pub fn new(default_leaf: Fr) -> Self {
27        let mut defaults = vec![default_leaf];
28        for i in 0..LOG_TREE_SIZE {
29            defaults.push(mimc::mimc(&[defaults[i], defaults[i]]));
30        }
31        Self {
32            defaults,
33            levels: vec![HashMap::new(); LOG_TREE_SIZE + 1],
34        }
35    }
36    pub fn root(&self) -> Fr {
37        self.get(LOG_TREE_SIZE, 0)
38    }
39    fn get(&self, level: usize, index: u64) -> Fr {
40        self.levels[level]
41            .get(&index)
42            .cloned()
43            .unwrap_or(self.defaults[level])
44    }
45    pub fn prove(&self, mut index: u64) -> Proof {
46        let mut proof = [Fr::zero(); LOG_TREE_SIZE];
47        for level in 0..LOG_TREE_SIZE {
48            let neigh = if index & 1 == 0 { index + 1 } else { index - 1 };
49            proof[level] = self.get(level, neigh);
50            index = index >> 1;
51        }
52        Proof(proof)
53    }
54    pub fn verify(mut index: u64, mut value: Fr, proof: Proof, root: Fr) -> bool {
55        for p in proof.0 {
56            value = if index & 1 == 0 {
57                mimc::double_mimc(value, p)
58            } else {
59                mimc::double_mimc(p, value)
60            };
61            index = index >> 1;
62        }
63        value == root
64    }
65    pub fn set(&mut self, mut index: u64, mut value: Fr) {
66        for level in 0..(LOG_TREE_SIZE + 1) {
67            self.levels[level].insert(index, value);
68            let neigh = if index & 1 == 0 { index + 1 } else { index - 1 };
69            let neigh_val = self.get(level, neigh);
70            value = if index & 1 == 0 {
71                mimc::double_mimc(value, neigh_val)
72            } else {
73                mimc::double_mimc(neigh_val, value)
74            };
75            index = index >> 1;
76        }
77    }
78}