1use crate::store::{AppendableStore, Hashable, Store};
2pub use crate::tree::{
3 consistency::ConsistencyProof,
4 inclusion::AuditProof,
5 node::{Node, NodeKey},
6};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use thiserror::Error;
10
11mod consistency;
12mod inclusion;
13mod node;
14
/// 32-byte digest produced by [`Hashable::hash`] and stored for every leaf and node of the tree.
pub(crate) type HashOutput = [u8; 32];
16
17#[derive(Clone, Debug, PartialEq, Eq, Error)]
18pub enum ProofGenerationError {
19 #[error("Index {index} not found in tree of size {tree_size}")]
20 InvalidIndex { tree_size: u64, index: u64 },
21
22 #[error("Invalid tree size {received} smaller than {expected}")]
23 InvalidTreeSize { expected: u64, received: u64 },
24
25 #[error("Failed to fetch key {0:?} from the store")]
26 KeyNotFound(NodeKey),
27}
28
29#[derive(Clone, Debug, PartialEq, Eq, Error)]
30pub enum ProofValidationError {
31 #[error("Found an unxexpected hash length (expected: {expected}, received: {received})")]
32 InvalidHashLength { expected: usize, received: usize },
33
34 #[error("Index {index} not found in tree of size {tree_size}")]
35 InvalidIndex { tree_size: u64, index: u64 },
36
37 #[error("Invalid tree size {received} smaller than {expected}")]
38 InvalidTreeSize { expected: u64, received: u64 },
39
40 #[error("Hash mismatch")]
41 HashMismatch,
42
43 #[error("Merkle path was too short")]
44 PathTooShort,
45
46 #[error("Merkle path was too long")]
47 PathTooLong,
48}
49
/// A Merkle tree whose leaf values live in a leaf store and whose node hashes live in a node store.
///
/// The hashing methods require `N: Store<NodeKey, HashOutput>`, `L: AppendableStore<u64, V>`
/// and `V: Hashable`; the struct itself places no bounds on its parameters.
#[derive(Clone, Debug)]
pub struct Tree<N, L, V> {
    /// Node-hash store.
    nodes: N,
    /// Leaf-value store.
    leafs: L,
    /// Ties the leaf value type `V` to the tree without storing a value.
    values: PhantomData<V>,
}
56
57impl<N, L, V> Tree<N, L, V> {
58 pub fn new(node_store: N, leaf_store: L) -> Self {
59 Self {
60 nodes: node_store,
61 leafs: leaf_store,
62 values: PhantomData,
63 }
64 }
65
66 pub fn nodes(&self) -> &N {
67 &self.nodes
68 }
69}
70
71impl<N, L, V> Tree<N, L, V>
72where
73 N: Store<NodeKey, HashOutput>,
74 L: AppendableStore<u64, V>,
75 V: Hashable,
76{
77 pub fn insert_entry(&self, entry: V) {
78 let entry_hash = entry.hash();
79 let idx = self.leafs.append(entry);
80 let entry_key = NodeKey::leaf(idx);
81 self.nodes.insert(entry_key, entry_hash);
82
83 let end = idx + 1;
85 let mut diff = 2;
86
87 while end.is_multiple_of(diff) {
88 let start = end - diff;
89
90 let key = NodeKey { start, end };
91 let (left, right) = key.split();
92
93 let node = Node {
94 left: self.nodes.get(&left).unwrap(),
95 right: self.nodes.get(&right).unwrap(),
96 };
97
98 self.nodes.insert(key, node.hash());
99
100 diff <<= 1;
101 }
102 }
103
104 pub fn recompute_tree_head(&self) -> TreeHead {
105 let tree_size = self.leafs.len() as u64;
106 let mut current_key = NodeKey::full_range(tree_size);
107 let mut balanced_nodes = vec![];
108
109 while !current_key.is_balanced() {
110 let (left, right) = current_key.split();
111 assert!(left.is_balanced());
112 balanced_nodes.push(left);
113 current_key = right;
114 }
115
116 let mut current_node_hash = self.nodes.get(¤t_key).unwrap();
117 while let Some(left_key) = balanced_nodes.pop() {
118 let current_node = Node {
119 left: self.nodes.get(&left_key).unwrap(),
120 right: self.nodes.get(¤t_key).unwrap(),
121 };
122
123 current_key = left_key.merge(¤t_key).unwrap();
124 current_node_hash = current_node.hash();
125 self.nodes.insert(current_key.clone(), current_node_hash);
126 }
127
128 TreeHead {
129 tree_size,
130 head: current_node_hash,
131 }
132 }
133
134 pub fn get_latest_tree_head(&self) -> Option<TreeHead> {
135 let idx = self.leafs.len() as u64;
136 self.nodes
137 .get(&NodeKey::full_range(idx))
138 .map(|head| TreeHead {
139 tree_size: idx,
140 head,
141 })
142 }
143}
144
145#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
146pub struct TreeHead {
147 pub(crate) tree_size: u64,
148 pub(crate) head: HashOutput,
149}
150
151impl TreeHead {
152 pub fn tree_size(&self) -> u64 {
153 self.tree_size
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};

    /// Test-only leaf hashing: SHA-256 over the string's UTF-8 bytes.
    impl Hashable for String {
        fn hash(&self) -> HashOutput {
            let digest = Sha256::digest(self.as_bytes());
            digest.into()
        }
    }

    /// Test-only hashing of a digest: SHA-256 over its 32 raw bytes.
    impl Hashable for HashOutput {
        fn hash(&self) -> HashOutput {
            let digest = Sha256::digest(self);
            digest.into()
        }
    }
}