novasmt/
merkmath.rs

1use bitvec::prelude::*;
2use lru::LruCache;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::io::Read;
6
7use crate::{
8    hash::{hash_data, hash_node},
9    Hashed,
10};
11
12pub(crate) fn key_to_path(
13    key: &'_ Hashed,
14) -> impl Iterator<Item = bool> + ExactSizeIterator + DoubleEndedIterator + '_ {
15    // let mut toret = [false; 256];
16    // // enumerate each byte
17    // for (i, k_i) in key.iter().enumerate() {
18    //     // walk through the bits
19    //     for j in 0..8 {
20    //         toret[i * 8 + j] = k_i & (MSB_SET >> j) != 0;
21    //     }
22    // }
23    // toret
24    let bslice = BitSlice::<Msb0, _>::from_slice(key).unwrap();
25    bslice.iter().by_val()
26}
27
28fn singleton_smt_roots(key: Hashed, val: &[u8], out: &mut [Hashed]) {
29    let mut rpath = key_to_path(&key).collect::<Vec<_>>();
30    rpath.reverse();
31    out[0] = hash_data(val);
32    for i in 0..rpath.len() {
33        out[i + 1] = if rpath[i] {
34            hash_node([0; 32], out[i])
35        } else {
36            hash_node(out[i], [0; 32])
37        }
38    }
39}
40
41/// Returns the root hash of a one-element SMT with the given key and value.
42pub(crate) fn singleton_smt_root(height: usize, key: Hashed, val: &[u8]) -> Hashed {
43    thread_local! {
44        static CACHE: std::cell::RefCell<lru::LruCache<(Hashed,Vec<u8>), [Hashed; 258], std::hash::BuildHasherDefault<rustc_hash::FxHasher>>>   = std::cell::RefCell::new(LruCache::with_hasher(128, Default::default()));
45    }
46    CACHE.with(|cache| {
47        let mut cache = cache.borrow_mut();
48        // dbg!(cache.len());
49        if let Some(val) = cache.get(&(key, val.to_vec())) {
50            val[height]
51        } else {
52            let mut buf = [Hashed::default(); 258];
53            singleton_smt_roots(key, val, &mut buf);
54            cache.put((key, val.to_vec()), buf);
55            buf[height]
56        }
57    })
58}
59
60#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
61/// A full proof with 256 levels.
62pub struct FullProof(pub Vec<Hashed>);
63
64impl FullProof {
65    /// Compresses the proof to a serializable form.
66    pub fn compress(&self) -> CompressedProof {
67        let FullProof(proof_nodes) = self;
68        assert_eq!(proof_nodes.len(), 256);
69        // build bitmap
70        let mut bitmap = bitvec![Msb0, u8; 0; 256];
71        for (i, pn) in proof_nodes.iter().enumerate() {
72            if *pn == [0u8; 32] {
73                bitmap.set(i, true);
74            }
75        }
76        let mut bitmap_slice = bitmap.into_vec();
77        for pn in proof_nodes.iter() {
78            if *pn != [0u8; 32] {
79                bitmap_slice.extend_from_slice(pn);
80            }
81        }
82        CompressedProof(bitmap_slice)
83    }
84
85    /// Verifies that this merkle branch is a valid proof of inclusion or non-inclusion. To check proofs of non-inclusion, set val to the empty vector.
86    pub fn verify(&self, root: Hashed, key: Hashed, val: &[u8]) -> bool {
87        assert_eq!(self.0.len(), 256);
88        self.verify_pure(root, key, val)
89    }
90
91    fn verify_pure(&self, root: Hashed, key: Hashed, val: &[u8]) -> bool {
92        let path = key_to_path(&key).collect::<Vec<_>>();
93        let mut my_root = hash_data(val);
94        for (&level, &direction) in self.0.iter().zip(path.iter()).rev() {
95            if direction {
96                // log::trace!(
97                //     "verify: my_root <- hash_node({}, {})",
98                //     hex::encode(&level),
99                //     hex::encode(&my_root)
100                // );
101                my_root = hash_node(level, my_root)
102            } else {
103                // log::trace!(
104                //     "verify: my_root <- hash_node({}, {})",
105                //     hex::encode(&my_root),
106                //     hex::encode(&level)
107                // );
108                my_root = hash_node(my_root, level)
109            }
110        }
111        root == my_root
112    }
113}
114
115#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
116/// A compressed proof.
117pub struct CompressedProof(pub Vec<u8>);
118
119impl CompressedProof {
120    /// Decompresses a compressed proof. Returns None if the format is invalid.
121    pub fn decompress(&self) -> Option<FullProof> {
122        let b = &self.0;
123        if b.len() < 32 || b.len() % 32 != 0 {
124            return None;
125        }
126        let bitmap = BitVec::<Msb0, u8>::from_vec(b[..32].to_vec());
127        let mut b = &b[32..];
128        let mut out = Vec::new();
129        // go through the bitmap. if b is set, insert a zero. otherwise, take 32 bytes from b. if b runs out, we are dead.
130        for is_zero in bitmap {
131            if is_zero {
132                out.push([0u8; 32])
133            } else {
134                let mut buf = [0; 32];
135                b.read_exact(&mut buf).ok()?;
136                out.push(buf);
137            }
138        }
139        Some(FullProof(out))
140    }
141}