1use crate::hasher::NodeHasher;
4use crate::trie::{self, KeyPath, LeafData, Node, ValueHash};
5
6use bitvec::prelude::*;
7
8#[cfg(not(feature = "std"))]
9use alloc::vec::Vec;
10
11pub(crate) fn shared_bits(a: &BitSlice<u8, Msb0>, b: &BitSlice<u8, Msb0>) -> usize {
13 a.iter().zip(b.iter()).take_while(|(a, b)| a == b).count()
14}
15
16pub fn leaf_ops_spliced(
19 leaf: Option<LeafData>,
20 ops: &[(KeyPath, Option<ValueHash>)],
21) -> impl Iterator<Item = (KeyPath, ValueHash)> + Clone + '_ {
22 let splice_index = leaf
23 .as_ref()
24 .and_then(|leaf| ops.binary_search_by_key(&leaf.key_path, |x| x.0).err());
25 let preserve_value = splice_index
26 .zip(leaf)
27 .map(|(_, leaf)| (leaf.key_path, Some(leaf.value_hash)));
28 let splice_index = splice_index.unwrap_or(0);
29
30 ops[..splice_index]
33 .into_iter()
34 .cloned()
35 .chain(preserve_value)
36 .chain(ops[splice_index..].into_iter().cloned())
37 .filter_map(|(k, o)| o.map(move |value| (k, value)))
38}
39
40pub enum WriteNode<'a> {
41 Leaf {
42 up: bool,
43 down: &'a BitSlice<u8, Msb0>,
44 leaf_data: LeafData,
45 node: Node,
46 },
47 Internal {
48 internal_data: trie::InternalData,
49 node: Node,
50 },
51 Terminator,
52}
53
54impl<'a> WriteNode<'a> {
55 pub fn up(&self) -> bool {
57 match self {
58 WriteNode::Leaf { up, .. } => *up,
59 WriteNode::Internal { .. } => true,
60 WriteNode::Terminator => false,
61 }
62 }
63
64 pub fn down(&self) -> &BitSlice<u8, Msb0> {
66 match self {
67 WriteNode::Leaf { down, .. } => down,
68 _ => BitSlice::empty(),
69 }
70 }
71
72 pub fn node(&self) -> Node {
74 match self {
75 WriteNode::Leaf { node, .. } => *node,
76 WriteNode::Internal { node, .. } => *node,
77 WriteNode::Terminator => trie::TERMINATOR,
78 }
79 }
80}
81
82pub fn build_trie<H: NodeHasher>(
95 skip: usize,
96 ops: impl IntoIterator<Item = (KeyPath, ValueHash)>,
97 mut visit: impl FnMut(WriteNode),
98) -> Node {
99 let mut pending_siblings: Vec<(Node, usize)> = Vec::new();
130
131 let mut leaf_ops = ops.into_iter();
132
133 let mut a = None;
134 let mut b = leaf_ops.next();
135 let mut c = leaf_ops.next();
136
137 match (b, c) {
138 (None, _) => {
139 visit(WriteNode::Terminator);
141 return trie::TERMINATOR;
142 }
143 (Some((ref k, ref v)), None) => {
144 let leaf_data = trie::LeafData {
146 key_path: *k,
147 value_hash: *v,
148 };
149 let leaf = H::hash_leaf(&leaf_data);
150 visit(WriteNode::Leaf {
151 up: false,
152 down: BitSlice::empty(),
153 leaf_data,
154 node: leaf,
155 });
156
157 return leaf;
158 }
159 _ => {}
160 }
161
162 let common_after_prefix = |k1: &KeyPath, k2: &KeyPath| {
163 let x = &k1.view_bits::<Msb0>()[skip..];
164 let y = &k2.view_bits::<Msb0>()[skip..];
165 shared_bits(x, y)
166 };
167
168 while let Some((this_key, this_val)) = b {
169 let n1 = a.as_ref().map(|(k, _)| common_after_prefix(k, &this_key));
170 let n2 = c.as_ref().map(|(k, _)| common_after_prefix(k, &this_key));
171
172 let leaf_data = trie::LeafData {
173 key_path: this_key,
174 value_hash: this_val,
175 };
176 let leaf = H::hash_leaf(&leaf_data);
177 let (leaf_depth, hash_up_layers) = match (n1, n2) {
178 (None, None) => {
179 (0, 0)
181 }
182 (None, Some(n2)) => {
183 (n2 + 1, 0)
185 }
186 (Some(n1), None) => {
187 (n1 + 1, n1 + 1)
189 }
190 (Some(n1), Some(n2)) => {
191 (core::cmp::max(n1, n2) + 1, n1.saturating_sub(n2))
193 }
194 };
195
196 let mut layer = leaf_depth;
197 let mut last_node = leaf;
198 let down_start = skip + n1.unwrap_or(0);
199 let leaf_end_bit = skip + leaf_depth;
200
201 visit(WriteNode::Leaf {
202 up: n1.is_some(), down: &this_key.view_bits::<Msb0>()[down_start..leaf_end_bit],
204 node: leaf,
205 leaf_data,
206 });
207
208 for bit in this_key.view_bits::<Msb0>()[skip..leaf_end_bit]
209 .iter()
210 .by_vals()
211 .rev()
212 .take(hash_up_layers)
213 {
214 layer -= 1;
215
216 let sibling = if pending_siblings.last().map_or(false, |l| l.1 == layer + 1) {
217 pending_siblings.pop().unwrap().0
219 } else {
220 trie::TERMINATOR
221 };
222
223 let internal_data = if bit {
224 trie::InternalData {
225 left: sibling,
226 right: last_node,
227 }
228 } else {
229 trie::InternalData {
230 left: last_node,
231 right: sibling,
232 }
233 };
234
235 last_node = H::hash_internal(&internal_data);
236 visit(WriteNode::Internal {
237 internal_data,
238 node: last_node,
239 });
240 }
241 pending_siblings.push((last_node, layer));
242
243 a = Some((this_key, this_val));
244 b = c;
245 c = leaf_ops.next();
246 }
247
248 let new_root = pending_siblings
249 .pop()
250 .map(|n| n.0)
251 .unwrap_or(trie::TERMINATOR);
252 new_root
253}
254
#[cfg(test)]
mod tests {
    use crate::trie::{NodeKind, TERMINATOR};

    use super::{
        bitvec, build_trie, leaf_ops_spliced, trie, BitVec, LeafData, Msb0, Node, NodeHasher,
        WriteNode,
    };

    /// Test hasher that hashes node contents with blake3.
    ///
    /// The top bit of the first byte is set for leaves and cleared for internal
    /// nodes, so `node_kind` can tell them apart.
    struct DummyNodeHasher;

    impl NodeHasher for DummyNodeHasher {
        fn hash_leaf(data: &trie::LeafData) -> [u8; 32] {
            let mut hasher = blake3::Hasher::new();
            hasher.update(&data.key_path);
            hasher.update(&data.value_hash);
            let mut hash: [u8; 32] = hasher.finalize().into();

            hash[0] |= 0b10000000;
            hash
        }

        fn hash_internal(data: &trie::InternalData) -> [u8; 32] {
            let mut hasher = blake3::Hasher::new();
            hasher.update(&data.left);
            hasher.update(&data.right);
            let mut hash: [u8; 32] = hasher.finalize().into();

            hash[0] &= 0b01111111;
            hash
        }

        fn node_kind(node: &Node) -> NodeKind {
            if node[0] >> 7 == 1 {
                NodeKind::Leaf
            } else if node == &TERMINATOR {
                NodeKind::Terminator
            } else {
                NodeKind::Internal
            }
        }
    }

    /// Builds a leaf whose key path and value hash are both `[key; 32]`, and
    /// returns it together with its hash.
    fn leaf(key: u8) -> (LeafData, [u8; 32]) {
        let key = [key; 32];
        let leaf = trie::LeafData {
            key_path: key,
            value_hash: key,
        };

        let hash = DummyNodeHasher::hash_leaf(&leaf);
        (leaf, hash)
    }

    /// Hash of an internal node with the given children.
    fn branch_hash(left: [u8; 32], right: [u8; 32]) -> [u8; 32] {
        DummyNodeHasher::hash_internal(&trie::InternalData { left, right })
    }

    /// Visitor that replays the `up`/`down` protocol to record each node's position.
    #[derive(Default)]
    struct Visited {
        /// Current position in the trie.
        key: BitVec<u8, Msb0>,
        /// Every visited node, paired with its position.
        visited: Vec<(BitVec<u8, Msb0>, Node)>,
    }

    impl Visited {
        fn at(key: BitVec<u8, Msb0>) -> Self {
            Visited {
                key,
                visited: Vec::new(),
            }
        }

        fn visit(&mut self, control: WriteNode) {
            let n = self.key.len() - control.up() as usize;
            self.key.truncate(n);
            self.key.extend_from_bitslice(control.down());
            self.visited.push((self.key.clone(), control.node()));
        }
    }

    #[test]
    fn build_empty_trie() {
        let mut visited = Visited::default();
        let root = build_trie::<DummyNodeHasher>(0, vec![], |control| visited.visit(control));

        let visited = visited.visited;

        assert_eq!(visited, vec![(bitvec![u8, Msb0;], [0u8; 32]),],);

        assert_eq!(root, [0u8; 32]);
    }

    #[test]
    fn build_single_value_trie() {
        let mut visited = Visited::default();

        let (leaf, leaf_hash) = leaf(0xff);
        let root =
            build_trie::<DummyNodeHasher>(0, vec![(leaf.key_path, leaf.value_hash)], |control| {
                visited.visit(control)
            });

        let visited = visited.visited;

        assert_eq!(visited, vec![(bitvec![u8, Msb0;], leaf_hash),],);

        assert_eq!(root, leaf_hash);
    }

    #[test]
    fn sub_trie() {
        let (leaf_a, leaf_hash_a) = leaf(0b0001_0001);
        let (leaf_b, leaf_hash_b) = leaf(0b0001_0010);
        let (leaf_c, leaf_hash_c) = leaf(0b0001_0100);

        let mut visited = Visited::at(bitvec![u8, Msb0; 0, 0, 0, 1]);

        let ops = [leaf_a, leaf_b, leaf_c]
            .iter()
            .map(|l| (l.key_path, l.value_hash))
            .collect::<Vec<_>>();

        let root = build_trie::<DummyNodeHasher>(4, ops, |control| visited.visit(control));

        let visited = visited.visited;

        let branch_ab_hash = branch_hash(leaf_hash_a, leaf_hash_b);
        let branch_abc_hash = branch_hash(branch_ab_hash, leaf_hash_c);
        let root_branch_hash = branch_hash(branch_abc_hash, [0u8; 32]);

        assert_eq!(
            visited,
            vec![
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0, 0], leaf_hash_a),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0, 1], leaf_hash_b),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0], branch_ab_hash),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 1], leaf_hash_c),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0], branch_abc_hash),
                (bitvec![u8, Msb0; 0, 0, 0, 1], root_branch_hash),
            ],
        );

        assert_eq!(root, root_branch_hash);
    }

    #[test]
    fn multi_value() {
        let (leaf_a, leaf_hash_a) = leaf(0b0001_0000);
        let (leaf_b, leaf_hash_b) = leaf(0b0010_0000);
        let (leaf_c, leaf_hash_c) = leaf(0b0100_0000);
        let (leaf_d, leaf_hash_d) = leaf(0b1010_0000);
        let (leaf_e, leaf_hash_e) = leaf(0b1011_0000);

        let mut visited = Visited::default();

        let ops = [leaf_a, leaf_b, leaf_c, leaf_d, leaf_e]
            .iter()
            .map(|l| (l.key_path, l.value_hash))
            .collect::<Vec<_>>();

        let root = build_trie::<DummyNodeHasher>(0, ops, |control| visited.visit(control));

        let visited = visited.visited;

        let branch_ab_hash = branch_hash(leaf_hash_a, leaf_hash_b);
        let branch_abc_hash = branch_hash(branch_ab_hash, leaf_hash_c);

        let branch_de_hash_1 = branch_hash(leaf_hash_d, leaf_hash_e);
        let branch_de_hash_2 = branch_hash([0u8; 32], branch_de_hash_1);
        let branch_de_hash_3 = branch_hash(branch_de_hash_2, [0u8; 32]);

        let branch_abc_de_hash = branch_hash(branch_abc_hash, branch_de_hash_3);

        assert_eq!(
            visited,
            vec![
                (bitvec![u8, Msb0; 0, 0, 0], leaf_hash_a),
                (bitvec![u8, Msb0; 0, 0, 1], leaf_hash_b),
                (bitvec![u8, Msb0; 0, 0], branch_ab_hash),
                (bitvec![u8, Msb0; 0, 1], leaf_hash_c),
                (bitvec![u8, Msb0; 0], branch_abc_hash),
                (bitvec![u8, Msb0; 1, 0, 1, 0], leaf_hash_d),
                (bitvec![u8, Msb0; 1, 0, 1, 1], leaf_hash_e),
                (bitvec![u8, Msb0; 1, 0, 1], branch_de_hash_1),
                (bitvec![u8, Msb0; 1, 0], branch_de_hash_2),
                (bitvec![u8, Msb0; 1], branch_de_hash_3),
                (bitvec![u8, Msb0;], branch_abc_de_hash),
            ],
        );

        assert_eq!(root, branch_abc_de_hash);
    }

    #[test]
    fn leaf_ops_spliced_merges_existing_leaf() {
        let (leaf_a, _) = leaf(0x10);
        let (leaf_b, _) = leaf(0x20);
        let (leaf_c, _) = leaf(0x30);
        let (key_a, value_a) = (leaf_a.key_path, leaf_a.value_hash);
        let (key_b, value_b) = (leaf_b.key_path, leaf_b.value_hash);
        let key_c = leaf_c.key_path;

        let ops = vec![(key_a, Some(value_a)), (key_c, None)];

        // A leaf whose key no op touches is kept, in key order.
        let spliced: Vec<_> = leaf_ops_spliced(Some(leaf_b), &ops).collect();
        assert_eq!(spliced, vec![(key_a, value_a), (key_b, value_b)]);

        // An op on the leaf's own key overrides it; here the op deletes it.
        let spliced: Vec<_> = leaf_ops_spliced(Some(leaf_c), &ops).collect();
        assert_eq!(spliced, vec![(key_a, value_a)]);

        // An op on the leaf's own key can also replace its value.
        let replacing = vec![(key_a, Some(value_b))];
        let spliced: Vec<_> = leaf_ops_spliced(Some(leaf_a), &replacing).collect();
        assert_eq!(spliced, vec![(key_a, value_b)]);

        // Without an existing leaf, only the surviving ops remain.
        let spliced: Vec<_> = leaf_ops_spliced(None, &ops).collect();
        assert_eq!(spliced, vec![(key_a, value_a)]);
    }
}