1use crate::hash_hashes;
6use crate::util::hashes_per_block;
7use crate::Hash;
8use crate::HashAlgorithm;
9use crate::BLOCK_SIZE;
10use std::io;
11
12#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
39pub struct MerkleTree {
40 levels: Vec<Vec<Hash>>,
41}
42
43impl MerkleTree {
44 pub fn from_levels(levels: Vec<Vec<Hash>>) -> MerkleTree {
51 MerkleTree { levels }
52 }
53
54 pub fn root(&self) -> &Hash {
56 &self.levels[self.levels.len() - 1][0]
57 }
58
59 pub fn from_reader(mut reader: impl std::io::Read) -> Result<MerkleTree, io::Error> {
72 let mut builder = crate::builder::MerkleTreeBuilder::new();
73 let mut buf = [0u8; BLOCK_SIZE];
74 loop {
75 let size = reader.read(&mut buf)?;
76 if size == 0 {
77 break;
78 }
79 builder.write(&buf[0..size]);
80 }
81 Ok(builder.finish())
82 }
83
84 pub fn from_walked_directory(walked_files: &[Hash], algorithm: HashAlgorithm) -> MerkleTree {
89 let mut levels: Vec<Vec<Hash>> = Vec::new();
90 levels.push(walked_files.to_vec());
91
92 let mut current_level = 0;
93 while levels[current_level].len() > 1 {
94 let mut next_level = Vec::new();
95 for chunk in levels[current_level].chunks(hashes_per_block(algorithm)) {
96 next_level.push(hash_hashes(
97 chunk,
98 current_level + 1,
99 next_level.len() * BLOCK_SIZE,
100 algorithm,
101 ));
102 }
103
104 levels.push(next_level);
105 current_level += 1;
106 }
107
108 MerkleTree::from_levels(levels)
109 }
110
111 pub fn includes_hash(&self, hash: &Hash, algorithm: HashAlgorithm) -> bool {
112 for level in &self.levels {
113 if !level.contains(hash) {
114 continue;
115 }
116
117 let new_root = MerkleTree::from_walked_directory(level, algorithm);
120 return new_root.root() == self.root();
121 }
122
123 false
124 }
125}
126
127impl AsRef<[Vec<Hash>]> for MerkleTree {
128 fn as_ref(&self) -> &[Vec<Hash>] {
129 &self.levels[..]
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    use crate::{
        util::{hash_block, hash_hashes, hashes_per_block},
        HashAlgorithm,
    };

    #[cfg(feature = "xxhash")]
    use rand::Rng;

    /// Hash algorithm used by tests that do not exercise a specific algorithm.
    pub const HASH: HashAlgorithm = HashAlgorithm::SHA256;

    impl MerkleTree {
        /// Returns an owned copy of the leaf hash at index `block`.
        fn clone_leaf_hash(&self, block: usize) -> Hash {
            self.levels[0][block].clone()
        }
    }

    /// Asserts that `tree`'s root equals the hash encoded by the hex string `expected_hex`.
    fn assert_root_eq(tree: &MerkleTree, expected_hex: &str) {
        let expected: Hash = hex::decode(expected_hex).unwrap();
        assert_eq!(tree.root(), &expected);
    }

    #[test]
    fn test_single_full_hash_block() {
        let block = vec![0xFF; BLOCK_SIZE];
        let leafs: Vec<_> = (0..hashes_per_block(HASH))
            .map(|index| hash_block(&block, index * BLOCK_SIZE, HASH))
            .collect();
        let root = hash_hashes(&leafs, 1, 0, HASH);
        let tree = MerkleTree::from_levels(vec![leafs.clone(), vec![root]]);

        // `leafs` holds exactly one full hash block, so every leaf is checked.
        for (index, leaf) in leafs.iter().enumerate() {
            assert_eq!(&tree.clone_leaf_hash(index), leaf);
        }
    }

    #[test]
    fn test_from_reader_empty() {
        let no_bytes = [0x00u8; 0];
        let tree = MerkleTree::from_reader(&no_bytes[..]).unwrap();
        assert_root_eq(
            &tree,
            "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b",
        );
    }

    #[test]
    fn test_from_reader_oneblock() {
        let one_block = [0xffu8; 8192];
        let tree = MerkleTree::from_reader(&one_block[..]).unwrap();
        assert_root_eq(
            &tree,
            "68d131bc271f9c192d4f6dcd8fe61bef90004856da19d0f2f514a7f4098b0737",
        );
    }

    #[test]
    fn test_from_reader_unaligned() {
        let unaligned: Vec<u8> = vec![0xff; 2_109_440];
        let tree = MerkleTree::from_reader(&unaligned[..]).unwrap();
        assert_root_eq(
            &tree,
            "7577266aa98ce587922fdc668c186e27f3c742fb1b732737153b70ae46973e43",
        );
    }

    #[test]
    fn test_from_reader_error_propagation() {
        const CUSTOM_ERROR_MESSAGE: &str = "merkle tree custom error message";

        /// Yields a single zero byte on the first read, then fails on every later read.
        struct OneByteThenFails {
            been_called: bool,
        }

        impl std::io::Read for OneByteThenFails {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                if self.been_called {
                    return Err(io::Error::new(io::ErrorKind::Other, CUSTOM_ERROR_MESSAGE));
                }
                self.been_called = true;
                buf[0] = 0;
                Ok(1)
            }
        }

        let outcome = MerkleTree::from_reader(OneByteThenFails { been_called: false });
        assert_eq!(outcome.unwrap_err().to_string(), CUSTOM_ERROR_MESSAGE);
    }

    #[test_case(0, 1 ; "test_empty")]
    #[test_case(256, 2 ; "test_two")]
    #[test_case(257, 3 ; "test_three")]
    #[test_case(65537, 4 ; "test_four")]
    fn test_from_walked_directory(len: usize, level: usize) {
        let walked: Vec<Hash> = vec![vec![0; 32]; len];
        let tree = MerkleTree::from_walked_directory(&walked, HASH);
        assert_eq!(tree.levels.len(), level);
    }

    #[test]
    #[cfg(feature = "xxhash")]
    fn test_includes_hash() {
        let mut rng = rand::thread_rng();

        let leaves: Vec<Hash> = (0..255)
            .map(|_| {
                let word: u64 = rng.gen_range(0..u64::MAX);
                word.to_le_bytes().to_vec()
            })
            .collect();

        let tree = MerkleTree::from_walked_directory(&leaves, HashAlgorithm::XXHash64);
        assert!(leaves
            .iter()
            .all(|leaf| tree.includes_hash(leaf, HashAlgorithm::XXHash64)));
    }
}