Skip to main content

luct_core/
tree.rs

1use crate::store::{AppendableStore, Hashable, Store};
2pub use crate::tree::{
3    consistency::ConsistencyProof,
4    inclusion::AuditProof,
5    node::{Node, NodeKey},
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use thiserror::Error;
10
11mod consistency;
12mod inclusion;
13mod node;
14
/// Raw 32-byte hash value stored for every tree node (leaf or interior).
/// NOTE(review): the test impls fill it with SHA-256; production hash choice
/// depends on the `Hashable` impls — confirm.
pub(crate) type HashOutput = [u8; 32];
16
/// Errors that can occur while generating a proof.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
pub enum ProofGenerationError {
    /// The requested leaf `index` does not exist in a tree of `tree_size` leaves.
    #[error("Index {index} not found in tree of size {tree_size}")]
    InvalidIndex { tree_size: u64, index: u64 },

    /// The supplied tree size `received` is smaller than `expected`.
    #[error("Invalid tree size {received} smaller than {expected}")]
    InvalidTreeSize { expected: u64, received: u64 },

    /// The node store returned nothing for the given key.
    #[error("Failed to fetch key {0:?} from the store")]
    KeyNotFound(NodeKey),
}
28
29#[derive(Clone, Debug, PartialEq, Eq, Error)]
30pub enum ProofValidationError {
31    #[error("Found an unxexpected hash length (expected: {expected}, received: {received})")]
32    InvalidHashLength { expected: usize, received: usize },
33
34    #[error("Index {index} not found in tree of size {tree_size}")]
35    InvalidIndex { tree_size: u64, index: u64 },
36
37    #[error("Invalid tree size {received} smaller than {expected}")]
38    InvalidTreeSize { expected: u64, received: u64 },
39
40    #[error("Hash mismatch")]
41    HashMismatch,
42
43    #[error("Merkle path was too short")]
44    PathTooShort,
45
46    #[error("Merkle path was too long")]
47    PathTooLong,
48}
49
/// Append-only Merkle tree built on top of two stores.
///
/// Trait bounds live on the `impl` blocks: there, `N` is a store from
/// [`NodeKey`] ranges to node hashes and `L` is an appendable store of leaf
/// values of type `V`, keyed by `u64` index.
#[derive(Debug, Clone)]
pub struct Tree<N, L, V> {
    // Hashes of leaves and interior nodes, keyed by the leaf range they cover.
    nodes: N,
    // The appended leaf values themselves.
    leafs: L,
    // `V` is only used through `L`; this marker ties the type parameter to the struct.
    values: PhantomData<V>,
}
56
57impl<N, L, V> Tree<N, L, V> {
58    pub fn new(node_store: N, leaf_store: L) -> Self {
59        Self {
60            nodes: node_store,
61            leafs: leaf_store,
62            values: PhantomData,
63        }
64    }
65
66    pub fn nodes(&self) -> &N {
67        &self.nodes
68    }
69}
70
71impl<N, L, V> Tree<N, L, V>
72where
73    N: Store<NodeKey, HashOutput>,
74    L: AppendableStore<u64, V>,
75    V: Hashable,
76{
77    pub fn insert_entry(&self, entry: V) {
78        let entry_hash = entry.hash();
79        let idx = self.leafs.append(entry);
80        let entry_key = NodeKey::leaf(idx);
81        self.nodes.insert(entry_key, entry_hash);
82
83        // Already update intermediate nodes, if they are power of twos
84        let end = idx + 1;
85        let mut diff = 2;
86
87        while end.is_multiple_of(diff) {
88            let start = end - diff;
89
90            let key = NodeKey { start, end };
91            let (left, right) = key.split();
92
93            let node = Node {
94                left: self.nodes.get(&left).unwrap(),
95                right: self.nodes.get(&right).unwrap(),
96            };
97
98            self.nodes.insert(key, node.hash());
99
100            diff <<= 1;
101        }
102    }
103
104    pub fn recompute_tree_head(&self) -> TreeHead {
105        let tree_size = self.leafs.len() as u64;
106        let mut current_key = NodeKey::full_range(tree_size);
107        let mut balanced_nodes = vec![];
108
109        while !current_key.is_balanced() {
110            let (left, right) = current_key.split();
111            assert!(left.is_balanced());
112            balanced_nodes.push(left);
113            current_key = right;
114        }
115
116        let mut current_node_hash = self.nodes.get(&current_key).unwrap();
117        while let Some(left_key) = balanced_nodes.pop() {
118            let current_node = Node {
119                left: self.nodes.get(&left_key).unwrap(),
120                right: self.nodes.get(&current_key).unwrap(),
121            };
122
123            current_key = left_key.merge(&current_key).unwrap();
124            current_node_hash = current_node.hash();
125            self.nodes.insert(current_key.clone(), current_node_hash);
126        }
127
128        TreeHead {
129            tree_size,
130            head: current_node_hash,
131        }
132    }
133
134    pub fn get_latest_tree_head(&self) -> Option<TreeHead> {
135        let idx = self.leafs.len() as u64;
136        self.nodes
137            .get(&NodeKey::full_range(idx))
138            .map(|head| TreeHead {
139                tree_size: idx,
140                head,
141            })
142    }
143}
144
/// Root hash of the tree together with the number of leaves it covers.
///
/// The derived `PartialOrd`/`Ord` compare `tree_size` first, then `head`
/// (derive order follows field declaration order).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct TreeHead {
    // Number of leaves the root hash covers.
    pub(crate) tree_size: u64,
    // Root hash over leaves `0..tree_size`.
    pub(crate) head: HashOutput,
}
150
151impl TreeHead {
152    pub fn tree_size(&self) -> u64 {
153        self.tree_size
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};

    /// Test-only [`Hashable`] impl: SHA-256 over the string's UTF-8 bytes.
    impl Hashable for String {
        fn hash(&self) -> HashOutput {
            let digest = Sha256::digest(self.as_bytes());
            digest.into()
        }
    }

    /// Test-only [`Hashable`] impl: SHA-256 over the raw 32 input bytes.
    impl Hashable for HashOutput {
        fn hash(&self) -> HashOutput {
            let digest = Sha256::digest(self.as_slice());
            digest.into()
        }
    }
}
169        fn hash(&self) -> HashOutput {
170            Sha256::digest(self).into()
171        }
172    }
173}