nomt_core/
update.rs

1//! Trie update logic helpers.
2
3use crate::hasher::NodeHasher;
4use crate::trie::{self, KeyPath, LeafData, Node, ValueHash};
5
6use bitvec::prelude::*;
7
8#[cfg(not(feature = "std"))]
9use alloc::vec::Vec;
10
11// TODO: feels extremely out of place.
12pub(crate) fn shared_bits(a: &BitSlice<u8, Msb0>, b: &BitSlice<u8, Msb0>) -> usize {
13    a.iter().zip(b.iter()).take_while(|(a, b)| a == b).count()
14}
15
16/// Creates an iterator of all provided operations, with the leaf value spliced in if its key
17/// does not appear in the original ops list. Then filters out all `None`s.
18pub fn leaf_ops_spliced(
19    leaf: Option<LeafData>,
20    ops: &[(KeyPath, Option<ValueHash>)],
21) -> impl Iterator<Item = (KeyPath, ValueHash)> + Clone + '_ {
22    let splice_index = leaf
23        .as_ref()
24        .and_then(|leaf| ops.binary_search_by_key(&leaf.key_path, |x| x.0).err());
25    let preserve_value = splice_index
26        .zip(leaf)
27        .map(|(_, leaf)| (leaf.key_path, Some(leaf.value_hash)));
28    let splice_index = splice_index.unwrap_or(0);
29
30    // splice: before / item / after
31    // skip deleted.
32    ops[..splice_index]
33        .into_iter()
34        .cloned()
35        .chain(preserve_value)
36        .chain(ops[splice_index..].into_iter().cloned())
37        .filter_map(|(k, o)| o.map(move |value| (k, value)))
38}
39
/// A write command emitted by [`build_trie`] for each node it computes.
///
/// The visitor applies a command by first moving up one level if [`WriteNode::up`] is true,
/// then following [`WriteNode::down`], and finally writing [`WriteNode::node`] at the resulting
/// position.
pub enum WriteNode<'a> {
    /// A leaf node.
    Leaf {
        /// Whether to move up one level before descending along `down`.
        up: bool,
        /// The key-path bits to descend along (after the optional move up) to reach the
        /// position where the leaf is written.
        down: &'a BitSlice<u8, Msb0>,
        /// The preimage of the leaf node.
        leaf_data: LeafData,
        /// The leaf node itself, as hashed from `leaf_data`.
        node: Node,
    },
    /// An internal node, always written after moving up one level from the current position.
    Internal {
        /// The left and right children of this node.
        internal_data: trie::InternalData,
        /// The internal node itself, as hashed from `internal_data`.
        node: Node,
    },
    /// A terminator (empty sub-trie), written at the current position without moving.
    Terminator,
}
53
54impl<'a> WriteNode<'a> {
55    /// Whether to move up a step before writing the node.
56    pub fn up(&self) -> bool {
57        match self {
58            WriteNode::Leaf { up, .. } => *up,
59            WriteNode::Internal { .. } => true,
60            WriteNode::Terminator => false,
61        }
62    }
63
64    /// What path to follow down (after going up) before writing the node.
65    pub fn down(&self) -> &BitSlice<u8, Msb0> {
66        match self {
67            WriteNode::Leaf { down, .. } => down,
68            _ => BitSlice::empty(),
69        }
70    }
71
72    /// The node itself.
73    pub fn node(&self) -> Node {
74        match self {
75            WriteNode::Leaf { node, .. } => *node,
76            WriteNode::Internal { node, .. } => *node,
77            WriteNode::Terminator => trie::TERMINATOR,
78        }
79    }
80}
81
82// Build a trie out of the given prior terminal and operations. Operations should all start
83// with the same prefix of len `skip` and be ordered lexicographically. The root node of the
84// generated trie is the one residing at path `prefix[..skip]`. When skip=0, this is the actual
85// root.
86//
87// Provide a visitor which will be called for each computed node of the trie.
88//
89// The visitor is assumed to have a default position at the root of the trie and from
90// there will be controlled with `WriteNode`. The changes to the position before writing the node
91// can be extracted from the command.
92// The root is always visited at the end. If the written node is a leaf, the leaf-data preimage
93// will be provided.
94pub fn build_trie<H: NodeHasher>(
95    skip: usize,
96    ops: impl IntoIterator<Item = (KeyPath, ValueHash)>,
97    mut visit: impl FnMut(WriteNode),
98) -> Node {
99    // we build a compact addressable sub-trie in-place based on the given set of ordered keys,
100    // ignoring deletions as they are implicit in a fresh sub-trie.
101    //
102    // an algorithm for building the compact sub-trie follows:
103    //
104    // consider any three leaves, A, B, C in sorted order by key, with different keys.
105    // A and B have some number of shared bits n1
106    // B and C have some number of shared bits n2
107    //
108    // We can make an accurate statement about the position of B regardless of what other material
109    // appears in the trie, as long as there is no A' s.t. A < A' < B and no C' s.t. B < C' < C.
110    //
111    // A is a leaf somewhere to the left of B, which is in turn somewhere to the left of C
112    // A and B share an internal node at depth n1, while B and C share an internal node at depth n2.
113    // n1 cannot equal n2, as there are only 2 keys with shared prefix n and a != b != c.
114    // If n1 is less than n2, then B is a leaf at depth n2+1 along its path (always left)
115    // If n2 is less than n1, then B is a leaf at depth n1+1 along its path (always right)
116    // QED
117    //
118    // A similar process applies to the first leaf in the list: it is a leaf on the left of an
119    // internal node at depth n, where n is the number of shared bits with the following key.
120    //
121    // Same for the last leaf in the list: it is on the right of an internal node at depth n,
122    // where n is the number of shared bits with the previous key.
123    //
124    // If the list has a single item, the sub-trie is a single leaf.
125    // And if the list is empty, the sub-trie is a terminator.
126
127    // A left-frontier: all modified nodes are to the left of
128    // `b`, so this stores their layers.
129    let mut pending_siblings: Vec<(Node, usize)> = Vec::new();
130
131    let mut leaf_ops = ops.into_iter();
132
133    let mut a = None;
134    let mut b = leaf_ops.next();
135    let mut c = leaf_ops.next();
136
137    match (b, c) {
138        (None, _) => {
139            // fast path: delete single node.
140            visit(WriteNode::Terminator);
141            return trie::TERMINATOR;
142        }
143        (Some((ref k, ref v)), None) => {
144            // fast path: place single leaf.
145            let leaf_data = trie::LeafData {
146                key_path: *k,
147                value_hash: *v,
148            };
149            let leaf = H::hash_leaf(&leaf_data);
150            visit(WriteNode::Leaf {
151                up: false,
152                down: BitSlice::empty(),
153                leaf_data,
154                node: leaf,
155            });
156
157            return leaf;
158        }
159        _ => {}
160    }
161
162    let common_after_prefix = |k1: &KeyPath, k2: &KeyPath| {
163        let x = &k1.view_bits::<Msb0>()[skip..];
164        let y = &k2.view_bits::<Msb0>()[skip..];
165        shared_bits(x, y)
166    };
167
168    while let Some((this_key, this_val)) = b {
169        let n1 = a.as_ref().map(|(k, _)| common_after_prefix(k, &this_key));
170        let n2 = c.as_ref().map(|(k, _)| common_after_prefix(k, &this_key));
171
172        let leaf_data = trie::LeafData {
173            key_path: this_key,
174            value_hash: this_val,
175        };
176        let leaf = H::hash_leaf(&leaf_data);
177        let (leaf_depth, hash_up_layers) = match (n1, n2) {
178            (None, None) => {
179                // single value - no hashing required.
180                (0, 0)
181            }
182            (None, Some(n2)) => {
183                // first value, n2 ancestor will be affected by next.
184                (n2 + 1, 0)
185            }
186            (Some(n1), None) => {
187                // last value, hash up to sub-trie root.
188                (n1 + 1, n1 + 1)
189            }
190            (Some(n1), Some(n2)) => {
191                // middle value, hash up to incoming ancestor + 1.
192                (core::cmp::max(n1, n2) + 1, n1.saturating_sub(n2))
193            }
194        };
195
196        let mut layer = leaf_depth;
197        let mut last_node = leaf;
198        let down_start = skip + n1.unwrap_or(0);
199        let leaf_end_bit = skip + leaf_depth;
200
201        visit(WriteNode::Leaf {
202            up: n1.is_some(), // previous iterations always get to current layer + 1
203            down: &this_key.view_bits::<Msb0>()[down_start..leaf_end_bit],
204            node: leaf,
205            leaf_data,
206        });
207
208        for bit in this_key.view_bits::<Msb0>()[skip..leaf_end_bit]
209            .iter()
210            .by_vals()
211            .rev()
212            .take(hash_up_layers)
213        {
214            layer -= 1;
215
216            let sibling = if pending_siblings.last().map_or(false, |l| l.1 == layer + 1) {
217                // unwrap: just checked
218                pending_siblings.pop().unwrap().0
219            } else {
220                trie::TERMINATOR
221            };
222
223            let internal_data = if bit {
224                trie::InternalData {
225                    left: sibling,
226                    right: last_node,
227                }
228            } else {
229                trie::InternalData {
230                    left: last_node,
231                    right: sibling,
232                }
233            };
234
235            last_node = H::hash_internal(&internal_data);
236            visit(WriteNode::Internal {
237                internal_data,
238                node: last_node,
239            });
240        }
241        pending_siblings.push((last_node, layer));
242
243        a = Some((this_key, this_val));
244        b = c;
245        c = leaf_ops.next();
246    }
247
248    let new_root = pending_siblings
249        .pop()
250        .map(|n| n.0)
251        .unwrap_or(trie::TERMINATOR);
252    new_root
253}
254
#[cfg(test)]
mod tests {
    use crate::trie::{NodeKind, TERMINATOR};

    use super::{
        bitvec, build_trie, leaf_ops_spliced, trie, BitVec, LeafData, Msb0, Node, NodeHasher,
        WriteNode,
    };

    /// Test hasher that labels leaves by setting the MSB of the hash and internal nodes by
    /// clearing it, so `node_kind` can tell them apart.
    struct DummyNodeHasher;

    impl NodeHasher for DummyNodeHasher {
        fn hash_leaf(data: &trie::LeafData) -> [u8; 32] {
            let mut hasher = blake3::Hasher::new();
            hasher.update(&data.key_path);
            hasher.update(&data.value_hash);
            let mut hash: [u8; 32] = hasher.finalize().into();

            // Set the MSB to label the node as a leaf.
            hash[0] |= 0b10000000;
            hash
        }

        fn hash_internal(data: &trie::InternalData) -> [u8; 32] {
            let mut hasher = blake3::Hasher::new();
            hasher.update(&data.left);
            hasher.update(&data.right);
            let mut hash: [u8; 32] = hasher.finalize().into();

            // Clear the MSB to label the node as internal.
            hash[0] &= 0b01111111;
            hash
        }

        fn node_kind(node: &Node) -> NodeKind {
            if node[0] >> 7 == 1 {
                NodeKind::Leaf
            } else if node == &TERMINATOR {
                NodeKind::Terminator
            } else {
                NodeKind::Internal
            }
        }
    }

    /// Builds a leaf whose key path and value hash are both `[key; 32]`, along with its hash.
    fn leaf(key: u8) -> (LeafData, [u8; 32]) {
        let key = [key; 32];
        let leaf = trie::LeafData {
            key_path: key,
            value_hash: key,
        };

        let hash = DummyNodeHasher::hash_leaf(&leaf);
        (leaf, hash)
    }

    /// Hashes an internal node with the given children.
    fn branch_hash(left: [u8; 32], right: [u8; 32]) -> [u8; 32] {
        DummyNodeHasher::hash_internal(&trie::InternalData { left, right })
    }

    /// Visitor that follows `WriteNode` commands and records each written node with its path.
    #[derive(Default)]
    struct Visited {
        key: BitVec<u8, Msb0>,
        visited: Vec<(BitVec<u8, Msb0>, Node)>,
    }

    impl Visited {
        fn at(key: BitVec<u8, Msb0>) -> Self {
            Visited {
                key,
                visited: Vec::new(),
            }
        }

        fn visit(&mut self, control: WriteNode) {
            let n = self.key.len() - control.up() as usize;
            self.key.truncate(n);
            self.key.extend_from_bitslice(control.down());
            self.visited.push((self.key.clone(), control.node()));
        }
    }

    #[test]
    fn build_empty_trie() {
        let mut visited = Visited::default();
        let root = build_trie::<DummyNodeHasher>(0, vec![], |control| visited.visit(control));

        let visited = visited.visited;

        assert_eq!(visited, vec![(bitvec![u8, Msb0;], [0u8; 32]),],);

        assert_eq!(root, [0u8; 32]);
    }

    #[test]
    fn build_single_value_trie() {
        let mut visited = Visited::default();

        let (leaf, leaf_hash) = leaf(0xff);
        let root =
            build_trie::<DummyNodeHasher>(0, vec![(leaf.key_path, leaf.value_hash)], |control| {
                visited.visit(control)
            });

        let visited = visited.visited;

        assert_eq!(visited, vec![(bitvec![u8, Msb0;], leaf_hash),],);

        assert_eq!(root, leaf_hash);
    }

    #[test]
    fn build_two_value_trie() {
        let (leaf_a, leaf_hash_a) = leaf(0b0000_0000);
        let (leaf_b, leaf_hash_b) = leaf(0b1000_0000);

        let mut visited = Visited::default();

        let ops = [leaf_a, leaf_b]
            .iter()
            .map(|l| (l.key_path, l.value_hash))
            .collect::<Vec<_>>();

        let root = build_trie::<DummyNodeHasher>(0, ops, |control| visited.visit(control));

        let visited = visited.visited;

        let root_branch_hash = branch_hash(leaf_hash_a, leaf_hash_b);

        assert_eq!(
            visited,
            vec![
                (bitvec![u8, Msb0; 0], leaf_hash_a),
                (bitvec![u8, Msb0; 1], leaf_hash_b),
                (bitvec![u8, Msb0;], root_branch_hash),
            ],
        );

        assert_eq!(root, root_branch_hash);
    }

    #[test]
    fn sub_trie() {
        let (leaf_a, leaf_hash_a) = leaf(0b0001_0001);
        let (leaf_b, leaf_hash_b) = leaf(0b0001_0010);
        let (leaf_c, leaf_hash_c) = leaf(0b0001_0100);

        let mut visited = Visited::at(bitvec![u8, Msb0; 0, 0, 0, 1]);

        let ops = [leaf_a, leaf_b, leaf_c]
            .iter()
            .map(|l| (l.key_path, l.value_hash))
            .collect::<Vec<_>>();

        let root = build_trie::<DummyNodeHasher>(4, ops, |control| visited.visit(control));

        let visited = visited.visited;

        let branch_ab_hash = branch_hash(leaf_hash_a, leaf_hash_b);
        let branch_abc_hash = branch_hash(branch_ab_hash, leaf_hash_c);
        let root_branch_hash = branch_hash(branch_abc_hash, [0u8; 32]);

        assert_eq!(
            visited,
            vec![
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0, 0], leaf_hash_a),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0, 1], leaf_hash_b),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 0], branch_ab_hash),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0, 1], leaf_hash_c),
                (bitvec![u8, Msb0; 0, 0, 0, 1, 0], branch_abc_hash),
                (bitvec![u8, Msb0; 0, 0, 0, 1], root_branch_hash),
            ],
        );

        assert_eq!(root, root_branch_hash);
    }

    #[test]
    fn multi_value() {
        let (leaf_a, leaf_hash_a) = leaf(0b0001_0000);
        let (leaf_b, leaf_hash_b) = leaf(0b0010_0000);
        let (leaf_c, leaf_hash_c) = leaf(0b0100_0000);
        let (leaf_d, leaf_hash_d) = leaf(0b1010_0000);
        let (leaf_e, leaf_hash_e) = leaf(0b1011_0000);

        let mut visited = Visited::default();

        let ops = [leaf_a, leaf_b, leaf_c, leaf_d, leaf_e]
            .iter()
            .map(|l| (l.key_path, l.value_hash))
            .collect::<Vec<_>>();

        let root = build_trie::<DummyNodeHasher>(0, ops, |control| visited.visit(control));

        let visited = visited.visited;

        let branch_ab_hash = branch_hash(leaf_hash_a, leaf_hash_b);
        let branch_abc_hash = branch_hash(branch_ab_hash, leaf_hash_c);

        let branch_de_hash_1 = branch_hash(leaf_hash_d, leaf_hash_e);
        let branch_de_hash_2 = branch_hash([0u8; 32], branch_de_hash_1);
        let branch_de_hash_3 = branch_hash(branch_de_hash_2, [0u8; 32]);

        let branch_abc_de_hash = branch_hash(branch_abc_hash, branch_de_hash_3);

        assert_eq!(
            visited,
            vec![
                (bitvec![u8, Msb0; 0, 0, 0], leaf_hash_a),
                (bitvec![u8, Msb0; 0, 0, 1], leaf_hash_b),
                (bitvec![u8, Msb0; 0, 0], branch_ab_hash),
                (bitvec![u8, Msb0; 0, 1], leaf_hash_c),
                (bitvec![u8, Msb0; 0], branch_abc_hash),
                (bitvec![u8, Msb0; 1, 0, 1, 0], leaf_hash_d),
                (bitvec![u8, Msb0; 1, 0, 1, 1], leaf_hash_e),
                (bitvec![u8, Msb0; 1, 0, 1], branch_de_hash_1),
                (bitvec![u8, Msb0; 1, 0], branch_de_hash_2),
                (bitvec![u8, Msb0; 1], branch_de_hash_3),
                (bitvec![u8, Msb0;], branch_abc_de_hash),
            ],
        );

        assert_eq!(root, branch_abc_de_hash);
    }

    #[test]
    fn leaf_ops_spliced_preserves_untouched_leaf() {
        let (prior, _) = leaf(2);
        let (prior_key, prior_value) = (prior.key_path, prior.value_hash);
        let ops = vec![([1u8; 32], Some([10u8; 32])), ([3u8; 32], None)];

        let spliced = leaf_ops_spliced(Some(prior), &ops).collect::<Vec<_>>();

        assert_eq!(
            spliced,
            vec![([1u8; 32], [10u8; 32]), (prior_key, prior_value)]
        );
    }

    #[test]
    fn leaf_ops_spliced_lets_ops_override_leaf() {
        let (prior, _) = leaf(2);
        let prior_key = prior.key_path;
        let ops = vec![(prior_key, None)];

        assert_eq!(leaf_ops_spliced(Some(prior), &ops).count(), 0);
    }
}