1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use bitvec::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::io::Read;

use crate::{
    hash::{hash_data, hash_node},
    Hashed,
};

/// Converts a key into its bit path through the tree.
///
/// Bits are yielded in `Msb0` order: the most significant bit of `key[0]` comes
/// first, and the least significant bit of the last byte comes last. The
/// iterator is exact-size and double-ended, so callers can walk the path from
/// the deepest level upward with `.rev()` without collecting it first.
pub(crate) fn key_to_path(
    key: &'_ Hashed,
) -> impl Iterator<Item = bool> + ExactSizeIterator + DoubleEndedIterator + '_ {
    // `from_slice` only rejects slices too long for bitvec to address; a
    // fixed-size key is never that long, so this cannot fail.
    let bits = BitSlice::<Msb0, _>::from_slice(key)
        .expect("a fixed-size key always fits in a BitSlice");
    bits.iter().by_val()
}

/// Returns the root hash of a one-element SMT with the given key and value.
///
/// Starts from the leaf hash `hash_data(val)` and walks upward through the last
/// `height` bits of the key's path, deepest bit first. At every level the
/// sibling is the all-zero hash:
/// - a `true` bit places the running hash on the right (`hash_node(zero, acc)`);
/// - a `false` bit places it on the left (`hash_node(acc, zero)`).
///
/// If `height` exceeds the path length, the walk stops at the top of the path.
pub(crate) fn singleton_smt_root(height: usize, key: Hashed, val: &[u8]) -> Hashed {
    // `key_to_path` is double-ended, so walk it backwards directly instead of
    // collecting it into a Vec and reversing it.
    key_to_path(&key)
        .rev()
        .take(height)
        .fold(hash_data(val), |acc, bit| {
            if bit {
                hash_node([0; 32], acc)
            } else {
                hash_node(acc, [0; 32])
            }
        })
}

/// A full proof with 256 levels.
///
/// Node `i` is the sibling hash paired with bit `i` of the key's path when
/// verifying, so node 0 is the sibling closest to the root. All-zero nodes are
/// the ones `compress` elides.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct FullProof(pub Vec<Hashed>);

impl FullProof {
    /// Compresses the proof to a serializable form.
    ///
    /// The output has two parts:
    /// 1. A 32-byte bitmap (256 bits, `Msb0` order) whose bit `i` is set when
    ///    node `i` is all-zero.
    /// 2. The 32 bytes of every non-zero node, in proof order.
    ///
    /// All-zero nodes are therefore not stored.
    ///
    /// # Panics
    ///
    /// Panics if the proof does not contain exactly 256 nodes.
    pub fn compress(&self) -> CompressedProof {
        let FullProof(proof_nodes) = self;
        assert_eq!(proof_nodes.len(), 256);
        // Bit `i` marks node `i` as all-zero, and therefore omitted from the payload.
        let mut bitmap = bitvec![Msb0, u8; 0; 256];
        for (i, pn) in proof_nodes.iter().enumerate() {
            if *pn == [0u8; 32] {
                bitmap.set(i, true);
            }
        }
        let mut out = bitmap.into_vec();
        // Append only the non-zero nodes, preserving their order.
        for pn in proof_nodes.iter().filter(|pn| **pn != [0u8; 32]) {
            out.extend_from_slice(pn);
        }
        CompressedProof(out)
    }

    /// Verifies that this merkle branch is a valid proof of inclusion or
    /// non-inclusion. To check proofs of non-inclusion, set val to the empty
    /// vector.
    ///
    /// Returns `true` iff hashing `val` and folding it up through the proof
    /// nodes along `key`'s path reproduces `root`.
    ///
    /// # Panics
    ///
    /// Panics if the proof does not contain exactly 256 nodes.
    pub fn verify(&self, root: Hashed, key: Hashed, val: &[u8]) -> bool {
        assert_eq!(self.0.len(), 256);
        self.verify_pure(root, key, val)
    }

    /// Recomputes the root from `val`'s leaf hash and compares it with `root`.
    ///
    /// Walks from the last proof node up to the first. At each level, `key`'s
    /// path bit picks the side:
    /// - `true` puts the running hash on the right of the proof node;
    /// - `false` puts it on the left.
    ///
    /// The proof length is not checked here; `verify` does that.
    fn verify_pure(&self, root: Hashed, key: Hashed, val: &[u8]) -> bool {
        // `zip(..).rev()` pairs proof node `i` with path bit `i`, starting from
        // the deepest level. Both sides are exact-size and double-ended, so no
        // intermediate Vec is needed.
        let computed = self
            .0
            .iter()
            .zip(key_to_path(&key))
            .rev()
            .fold(hash_data(val), |acc, (&sibling, go_right)| {
                if go_right {
                    hash_node(sibling, acc)
                } else {
                    hash_node(acc, sibling)
                }
            });
        root == computed
    }
}

/// A compressed proof, as produced by [`FullProof::compress`].
///
/// Layout: a 32-byte bitmap followed by the 32-byte non-zero nodes.
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CompressedProof(pub Vec<u8>);

impl CompressedProof {
    /// Decompresses a compressed proof. Returns None if the format is invalid.
    ///
    /// The first 32 bytes are a 256-bit bitmap in `Msb0` order. For each bit:
    /// - a set bit yields an all-zero node;
    /// - a clear bit takes the next 32 bytes after the bitmap as the node.
    ///
    /// The input is rejected when it:
    /// - is shorter than 32 bytes;
    /// - has a length that is not a multiple of 32;
    /// - runs out of node bytes;
    /// - has node bytes left over once all 256 bitmap bits have been consumed.
    pub fn decompress(&self) -> Option<FullProof> {
        let b = &self.0;
        if b.len() < 32 || b.len() % 32 != 0 {
            return None;
        }
        let bitmap = BitVec::<Msb0, u8>::from_vec(b[..32].to_vec());
        let mut rest = &b[32..];
        let mut out = Vec::with_capacity(256);
        for is_zero in bitmap {
            if is_zero {
                out.push([0u8; 32]);
            } else {
                let mut buf = [0; 32];
                // Fails, rejecting the proof, if the node bytes run out.
                rest.read_exact(&mut buf).ok()?;
                out.push(buf);
            }
        }
        // Reject trailing bytes: `compress` never emits them, and accepting
        // them would let many distinct encodings decode to the same proof.
        if !rest.is_empty() {
            return None;
        }
        Some(FullProof(out))
    }
}