1use std::cmp::min;
6
7use crate::tree::MerkleTree;
8use crate::util::{hash_block, hash_hashes, hashes_per_block};
9use crate::{Hash, HashAlgorithm, BLOCK_SIZE};
10
/// Incrementally builds a [`MerkleTree`] from data supplied through
/// [`MerkleTreeBuilder::write`], producing the tree with
/// [`MerkleTreeBuilder::finish`].
#[derive(Clone, Debug)]
pub struct MerkleTreeBuilder {
    /// Bytes of a trailing data block that is not yet `BLOCK_SIZE` long.
    /// Cleared whenever a data-block hash is pushed.
    block: Vec<u8>,
    /// Hash levels of the tree, bottom-up. `levels[0]` holds one hash per data
    /// block; each `levels[i + 1]` holds the hashes of consecutive groups of up
    /// to `hashes_per_block` hashes taken from `levels[i]`.
    levels: Vec<Vec<Hash>>,
    /// Hash function used for both data blocks and interior hash blocks.
    algorithm: HashAlgorithm,
}
35
36impl Default for MerkleTreeBuilder {
37 fn default() -> Self {
38 Self {
39 levels: vec![Vec::new()],
40 block: Vec::with_capacity(BLOCK_SIZE),
41 algorithm: HashAlgorithm::SHA256,
42 }
43 }
44}
45
46impl MerkleTreeBuilder {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn with_algorithm(algorithm: HashAlgorithm) -> Self {
54 Self {
55 levels: vec![Vec::new()],
56 block: Vec::with_capacity(BLOCK_SIZE),
57 algorithm,
58 }
59 }
60
61 pub fn write(&mut self, buf: &[u8]) {
65 let buf = if self.block.is_empty() {
67 buf
68 } else {
69 let left = BLOCK_SIZE - self.block.len();
70 let prefix = min(buf.len(), left);
71 let (buf, rest) = buf.split_at(prefix);
72 self.block.extend_from_slice(buf);
73 if self.block.len() == BLOCK_SIZE {
74 self.push_data_hash(self.hash_block(&self.block[..]));
75 }
76 rest
77 };
78
79 for block in buf.chunks(BLOCK_SIZE) {
81 if block.len() == BLOCK_SIZE {
82 self.push_data_hash(self.hash_block(block));
83 } else {
84 self.block.extend_from_slice(block);
85 }
86 }
87 }
88
89 fn hash_block(&self, block: &[u8]) -> Hash {
92 hash_block(block, self.levels[0].len() * BLOCK_SIZE, self.algorithm)
93 }
94
95 pub fn push_data_hash(&mut self, hash: Hash) {
98 self.block.clear();
99 self.levels[0].push(hash);
100 if self.levels[0].len() % hashes_per_block(self.algorithm) == 0 {
101 self.commit_tail_block(0);
102 }
103 }
104
105 fn commit_tail_block(&mut self, level: usize) {
107 let len = self.levels[level].len();
108 let next_level = level + 1;
109
110 if next_level >= self.levels.len() {
111 self.levels.push(Vec::new());
112 }
113
114 let per_block = hashes_per_block(self.algorithm);
115 let first_hash = if len % per_block == 0 {
116 len - per_block
117 } else {
118 len - (len % per_block)
119 };
120
121 let hash = hash_hashes(
122 &self.levels[level][first_hash..],
123 next_level,
124 self.levels[next_level].len() * BLOCK_SIZE,
125 self.algorithm,
126 );
127
128 self.levels[next_level].push(hash);
129 if self.levels[next_level].len() % per_block == 0 {
130 self.commit_tail_block(next_level);
131 }
132 }
133
134 pub fn finish(mut self) -> MerkleTree {
137 if !self.block.is_empty() || self.levels[0].is_empty() {
141 self.push_data_hash(self.hash_block(&self.block[..]));
142 }
143
144 for level in 0.. {
147 if level >= self.levels.len() {
148 break;
149 }
150
151 let len = self.levels[level].len();
152 if len > 1 && len % hashes_per_block(self.algorithm) != 0 {
153 self.commit_tail_block(level);
154 }
155 }
156
157 MerkleTree::from_levels(self.levels)
158 }
159}
160
161impl From<MerkleTreeBuilder> for MerkleTree {
162 fn from(builder: MerkleTreeBuilder) -> Self {
163 builder.finish()
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::min;
    use test_case::test_case;

    /// Decodes a hex-encoded root hash used as a test expectation.
    fn expected_root(hex_digest: &str) -> Hash {
        hex::decode(hex_digest).unwrap()
    }

    // NOTE(review): presumably silences `clippy::unused_unit` on the code the
    // `test_case` macro generates; confirm whether it is still needed.
    #[allow(clippy::unused_unit)]
    #[test_case(vec![], "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b"; "test_empty")]
    #[test_case(vec![0xFF; 8192], "68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737"; "test_oneblock")]
    #[test_case(vec![0xFF; 65536], "f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf"; "test_small")]
    #[test_case(vec![0xFF; 2105344], "7d75dfb18bfd48e03b5be4e8e9aeea2f89880cb81c1551df855e0d0a0cc59a67"; "test_large")]
    #[test_case(vec![0xFF; 2109440], "7577266aa98ce587922fdc668c186e27f3c742fb1b732737153b70ae46973e43"; "test_unaligned")]
    fn tests(input: Vec<u8>, output: &str) {
        let mut builder = MerkleTreeBuilder::new();
        builder.write(&input);

        let tree = builder.finish();
        assert_eq!(&expected_root(output), tree.root());
    }

    /// A single block delivered in two unequal writes hashes like one write.
    #[test]
    fn test_unaligned_single_block() {
        let data = vec![0xFF; 8192];
        let (head, tail) = data.split_at(1024);

        let mut builder = MerkleTreeBuilder::new();
        builder.write(head);
        builder.write(tail);
        let tree = builder.finish();

        assert_eq!(
            tree.root(),
            &expected_root("68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737")
        );
    }

    /// Multi-block input yields the same root regardless of write granularity.
    #[test]
    fn test_unaligned_n_block() {
        let data = vec![0xFF; 65536];
        let expected =
            expected_root("f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf");

        for &chunk_size in &[1, 100, 1024, 8193] {
            let mut builder = MerkleTreeBuilder::new();
            data.chunks(chunk_size).for_each(|piece| builder.write(piece));
            let tree = builder.finish();

            assert_eq!(tree.root(), &expected);
        }
    }

    /// Hashes 0xff0080 bytes of the repeating 0xff 0x00 0x80 pattern.
    #[test]
    fn test_fuchsia() {
        let pattern: Vec<u8> = [0xff, 0x00, 0x80]
            .iter()
            .copied()
            .cycle()
            .take(3 * BLOCK_SIZE)
            .collect();

        let mut builder = MerkleTreeBuilder::new();
        let mut left_to_write = 0xff0080;
        while left_to_write > 0 {
            let step = min(left_to_write, pattern.len());
            builder.write(&pattern[..step]);
            left_to_write -= step;
        }
        let tree = builder.finish();

        assert_eq!(
            &expected_root("2feb488cffc976061998ac90ce7292241dfa86883c0edc279433b5c4370d0f30"),
            tree.root()
        );
    }
}