merkle_root/
builder.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cmp::min;
6
7use crate::tree::MerkleTree;
8use crate::util::{hash_block, hash_hashes, hashes_per_block};
9use crate::{Hash, HashAlgorithm, BLOCK_SIZE};
10
11/// A `MerkleTreeBuilder` generates a [`MerkleTree`] from one or more write calls.
12///
13/// # Examples
14/// ```
15/// # use merkle_root::*;
16/// let data = vec![0xff; 8192];
17/// let mut builder = MerkleTreeBuilder::new();
18/// for i in 0..8 {
19///     builder.write(&data[..]);
20/// }
21/// assert_eq!(
22///     builder.finish().root(),
23///     &hex::decode("f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf")
24///         .unwrap()
25/// );
26/// ```
27#[derive(Clone, Debug)]
28pub struct MerkleTreeBuilder {
29    /// Buffer to hold a partial block of data between [`MerkleTreeBuilder::write`] calls.
30    /// `block.len()` will never exceed [`BLOCK_SIZE`].
31    block: Vec<u8>,
32    levels: Vec<Vec<Hash>>,
33    algorithm: HashAlgorithm,
34}
35
36impl Default for MerkleTreeBuilder {
37    fn default() -> Self {
38        Self {
39            levels: vec![Vec::new()],
40            block: Vec::with_capacity(BLOCK_SIZE),
41            algorithm: HashAlgorithm::SHA256,
42        }
43    }
44}
45
46impl MerkleTreeBuilder {
47    /// Creates a new, empty `MerkleTreeBuilder`.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Creates a new, empty `MerkleTreeBuilder` with a specific algorithm.
53    pub fn with_algorithm(algorithm: HashAlgorithm) -> Self {
54        Self {
55            levels: vec![Vec::new()],
56            block: Vec::with_capacity(BLOCK_SIZE),
57            algorithm,
58        }
59    }
60
61    /// Append a buffer of bytes to the merkle tree.
62    ///
63    /// No internal buffering is required if all writes are [`BLOCK_SIZE`] aligned.
64    pub fn write(&mut self, buf: &[u8]) {
65        // Fill the current partial block, if it exists.
66        let buf = if self.block.is_empty() {
67            buf
68        } else {
69            let left = BLOCK_SIZE - self.block.len();
70            let prefix = min(buf.len(), left);
71            let (buf, rest) = buf.split_at(prefix);
72            self.block.extend_from_slice(buf);
73            if self.block.len() == BLOCK_SIZE {
74                self.push_data_hash(self.hash_block(&self.block[..]));
75            }
76            rest
77        };
78
79        // Write full blocks, saving any final partial block for later writes.
80        for block in buf.chunks(BLOCK_SIZE) {
81            if block.len() == BLOCK_SIZE {
82                self.push_data_hash(self.hash_block(block));
83            } else {
84                self.block.extend_from_slice(block);
85            }
86        }
87    }
88
89    /// Hash a block of data (level 0), using an offset based on the current number of level 0
90    /// hashes.
91    fn hash_block(&self, block: &[u8]) -> Hash {
92        hash_block(block, self.levels[0].len() * BLOCK_SIZE, self.algorithm)
93    }
94
95    /// Save a data block hash, propagating full blocks of hashes to higher layers. Also clear a
96    /// stored data block.
97    pub fn push_data_hash(&mut self, hash: Hash) {
98        self.block.clear();
99        self.levels[0].push(hash);
100        if self.levels[0].len() % hashes_per_block(self.algorithm) == 0 {
101            self.commit_tail_block(0);
102        }
103    }
104
105    /// Hash a complete (or final partial) block of hashes, chaining to higher levels as needed.
106    fn commit_tail_block(&mut self, level: usize) {
107        let len = self.levels[level].len();
108        let next_level = level + 1;
109
110        if next_level >= self.levels.len() {
111            self.levels.push(Vec::new());
112        }
113
114        let per_block = hashes_per_block(self.algorithm);
115        let first_hash = if len % per_block == 0 {
116            len - per_block
117        } else {
118            len - (len % per_block)
119        };
120
121        let hash = hash_hashes(
122            &self.levels[level][first_hash..],
123            next_level,
124            self.levels[next_level].len() * BLOCK_SIZE,
125            self.algorithm,
126        );
127
128        self.levels[next_level].push(hash);
129        if self.levels[next_level].len() % per_block == 0 {
130            self.commit_tail_block(next_level);
131        }
132    }
133
134    /// Finalize all levels of the merkle tree, converting this `MerkleTreeBuilder` instance to a
135    /// [`MerkleTree`].
136    pub fn finish(mut self) -> MerkleTree {
137        // The data protected by the tree may not be BLOCK_SIZE aligned. Commit a partial data
138        // block before finalizing the hash levels.
139        // Also, an empty tree consists of a single, empty block. Handle that case now as well.
140        if !self.block.is_empty() || self.levels[0].is_empty() {
141            self.push_data_hash(self.hash_block(&self.block[..]));
142        }
143
144        // Enumerate the hash levels, finalizing any that have a partial block of hashes.
145        // `commit_tail_block` may add new levels to the tree, so don't assume a length up front.
146        for level in 0.. {
147            if level >= self.levels.len() {
148                break;
149            }
150
151            let len = self.levels[level].len();
152            if len > 1 && len % hashes_per_block(self.algorithm) != 0 {
153                self.commit_tail_block(level);
154            }
155        }
156
157        MerkleTree::from_levels(self.levels)
158    }
159}
160
161impl From<MerkleTreeBuilder> for MerkleTree {
162    fn from(builder: MerkleTreeBuilder) -> Self {
163        builder.finish()
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::min;
    use test_case::test_case;

    /// Writes `input` in a single call and checks the resulting root against the hex digest
    /// `output`.
    // NOTE(review): the lint allowance is presumably needed for the code `test_case` expands
    // to — confirm before removing.
    #[allow(clippy::unused_unit)]
    #[test_case(vec![], "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" ; "test_empty")]
    #[test_case(vec![0xFF; 8192], "68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737"; "test_oneblock")]
    #[test_case(vec![0xFF; 65536], "f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf"; "test_small")]
    #[test_case(vec![0xFF; 2105344], "7d75dfb18bfd48e03b5be4e8e9aeea2f89880cb81c1551df855e0d0a0cc59a67"; "test_large")]
    #[test_case(vec![0xFF; 2109440], "7577266aa98ce587922fdc668c186e27f3c742fb1b732737153b70ae46973e43"; "test_unaligned")]
    fn tests(input: Vec<u8>, output: &str) {
        let mut builder = MerkleTreeBuilder::new();
        builder.write(&input);
        let tree = builder.finish();

        let expected: Hash = hex::decode(output).unwrap();
        assert_eq!(&expected, tree.root());
    }

    /// One block delivered as two unaligned writes must hash like a single aligned write.
    #[test]
    fn test_unaligned_single_block() {
        let expected =
            hex::decode("68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737")
                .unwrap();

        let data = vec![0xFF; 8192];
        let (head, tail) = data.split_at(1024);

        let mut builder = MerkleTreeBuilder::new();
        builder.write(head);
        builder.write(tail);
        let tree = builder.finish();

        assert_eq!(tree.root(), &expected);
    }

    /// The root must not depend on how a multi-block input is sliced across writes.
    #[test]
    fn test_unaligned_n_block() {
        let expected =
            hex::decode("f75f59a944d2433bc6830ec243bfefa457704d2aed12f30539cd4f18bf1d62cf")
                .unwrap();
        let data = vec![0xFF; 65536];

        for &chunk_size in &[1, 100, 1024, 8193] {
            let mut builder = MerkleTreeBuilder::new();
            data.chunks(chunk_size).for_each(|piece| builder.write(piece));

            let tree = builder.finish();
            assert_eq!(tree.root(), &expected);
        }
    }

    /// Streams 0xff0080 bytes of a repeating `ff 00 80` pattern through the builder, at most
    /// three blocks per write, and checks the resulting root.
    #[test]
    fn test_fuchsia() {
        let pattern: Vec<u8> = [0xff, 0x00, 0x80]
            .iter()
            .copied()
            .cycle()
            .take(3 * BLOCK_SIZE)
            .collect();

        let mut builder = MerkleTreeBuilder::new();
        let mut remaining = 0xff0080;
        while remaining != 0 {
            let n = min(remaining, pattern.len());
            builder.write(&pattern[..n]);
            remaining -= n;
        }

        let tree = builder.finish();
        let expected: Hash =
            hex::decode("2feb488cffc976061998ac90ce7292241dfa86883c0edc279433b5c4370d0f30")
                .unwrap();
        assert_eq!(&expected, tree.root());
    }
}