smoldot/trie/
calculate_root.rs

1// Smoldot
2// Copyright (C) 2023  Pierre Krieger
3// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
4
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13// GNU General Public License for more details.
14
15// You should have received a copy of the GNU General Public License
16// along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
18//! Freestanding function that calculates the root of a radix-16 Merkle-Patricia trie.
19//!
20//! See the parent module documentation for an explanation of what the trie is.
21//!
//! This module is meant to be used in situations where all the nodes of the trie that have a
//! storage value associated with them are known and easily accessible, and where no cache is
//! available.
25//!
26//! # Usage
27//!
28//! Calling the [`root_merkle_value`] function creates a [`RootMerkleValueCalculation`] object
29//! which you have to drive to completion.
30//!
31//! Example:
32//!
33//! ```
34//! use std::{collections::BTreeMap, ops::Bound};
35//! use smoldot::trie::{HashFunction, TrieEntryVersion, calculate_root};
36//!
//! // In this example, the storage consists of a binary tree map.
38//! let mut storage = BTreeMap::<Vec<u8>, (Vec<u8>, TrieEntryVersion)>::new();
39//! storage.insert(b"foo".to_vec(), (b"bar".to_vec(), TrieEntryVersion::V1));
40//!
41//! let trie_root = {
42//!     let mut calculation = calculate_root::root_merkle_value(HashFunction::Blake2);
43//!     loop {
44//!         match calculation {
45//!             calculate_root::RootMerkleValueCalculation::Finished { hash, .. } => break hash,
46//!             calculate_root::RootMerkleValueCalculation::NextKey(next_key) => {
47//!                 let key_before = next_key.key_before().collect::<Vec<_>>();
48//!                 let lower_bound = if next_key.or_equal() {
49//!                     Bound::Included(key_before)
50//!                 } else {
51//!                     Bound::Excluded(key_before)
52//!                 };
53//!                 let outcome = storage
54//!                     .range((lower_bound, Bound::Unbounded))
55//!                     .next()
56//!                     .filter(|(k, _)| {
57//!                         k.iter()
58//!                             .copied()
59//!                             .zip(next_key.prefix())
60//!                             .all(|(a, b)| a == b)
61//!                     })
62//!                     .map(|(k, _)| k);
63//!                 calculation = next_key.inject_key(outcome.map(|k| k.iter().copied()));
64//!             }
65//!             calculate_root::RootMerkleValueCalculation::StorageValue(value_request) => {
66//!                 let key = value_request.key().collect::<Vec<u8>>();
67//!                 calculation = value_request.inject(storage.get(&key).map(|(val, v)| (val, *v)));
68//!             }
69//!         }
70//!     }
71//! };
72//!
73//! assert_eq!(
74//!     trie_root,
75//!     [204, 86, 28, 213, 155, 206, 247, 145, 28, 169, 212, 146, 182, 159, 224, 82,
76//!      116, 162, 143, 156, 19, 43, 183, 8, 41, 178, 204, 69, 41, 37, 224, 91]
77//! );
78//! ```
79//!
80
81use super::{
82    EMPTY_BLAKE2_TRIE_MERKLE_VALUE, EMPTY_KECCAK256_TRIE_MERKLE_VALUE, HashFunction,
83    TrieEntryVersion, branch_search,
84    nibble::{Nibble, nibbles_to_bytes_suffix_extend},
85    trie_node,
86};
87
88use alloc::vec::Vec;
89use core::array;
90
91/// Start calculating the Merkle value of the root node.
92pub fn root_merkle_value(hash_function: HashFunction) -> RootMerkleValueCalculation {
93    CalcInner {
94        hash_function,
95        stack: Vec::with_capacity(8),
96    }
97    .next()
98}
99
/// Current state of the root Merkle value calculation and how to continue it.
///
/// A value of this type is returned by [`root_merkle_value`] and by every method that advances
/// the calculation ([`NextKey::inject_key`] and [`StorageValue::inject`]). The calculation is
/// over once the [`RootMerkleValueCalculation::Finished`] variant is obtained.
#[must_use]
pub enum RootMerkleValueCalculation {
    /// The calculation is finished.
    Finished {
        /// Root hash that has been calculated.
        hash: [u8; 32],
    },

    /// Request to return the key that follows (in lexicographic order) a given one in the storage.
    /// Call [`NextKey::inject_key`] to provide that key.
    NextKey(NextKey),

    /// Request the value of the node with a specific key. Call [`StorageValue::inject`] to
    /// indicate the value.
    StorageValue(StorageValue),
}
117
/// State of a calculation that is ready to continue.
///
/// Shared by all the public-facing structs: both [`NextKey`] and [`StorageValue`] embed a
/// `CalcInner` and hand it back to `CalcInner::next` once the user has answered their request.
struct CalcInner {
    /// Hash function used by the trie.
    hash_function: HashFunction,
    /// Stack of nodes whose value is currently being calculated.
    ///
    /// The full key of the position being iterated is obtained by concatenating, from the
    /// bottom to the top of the stack, the partial key of each node followed by the index of
    /// the child of that node currently being processed.
    stack: Vec<Node>,
}
126
/// Node of the trie whose Merkle value is in the process of being calculated.
#[derive(Debug)]
struct Node {
    /// Partial key of the node currently being calculated.
    partial_key: Vec<Nibble>,
    /// Merkle values of the children of the node. Filled up to 16 elements, then popped. Each
    /// element is `Some` or `None` depending on whether a child exists.
    ///
    /// The number of elements already present is also the index of the next child to process.
    children: arrayvec::ArrayVec<Option<trie_node::MerkleValueOutput>, 16>,
}
135
136impl CalcInner {
137    /// Returns the full key of the node currently being iterated.
138    fn current_iter_node_full_key(&self) -> impl Iterator<Item = Nibble> {
139        self.stack.iter().flat_map(|node| {
140            let child_nibble = if node.children.len() == 16 {
141                None
142            } else {
143                Some(Nibble::try_from(u8::try_from(node.children.len()).unwrap()).unwrap())
144            };
145
146            node.partial_key.iter().copied().chain(child_nibble)
147        })
148    }
149
150    /// Advances the calculation to the next step.
151    fn next(mut self) -> RootMerkleValueCalculation {
152        loop {
153            // If all the children of the node at the end of the stack are known, calculate the Merkle
154            // value of that node. To do so, we need to ask the user for the storage value.
155            if self
156                .stack
157                .last()
158                .map_or(false, |node| node.children.len() == 16)
159            {
160                // If the key has an even number of nibbles, we need to ask the user for the
161                // storage value.
162                if self.current_iter_node_full_key().count() % 2 == 0 {
163                    break RootMerkleValueCalculation::StorageValue(StorageValue {
164                        calculation: self,
165                    });
166                }
167
168                // Otherwise we can calculate immediately.
169                let calculated_elem = self.stack.pop().unwrap();
170
171                // Calculate the Merkle value of the node.
172                let merkle_value = trie_node::calculate_merkle_value(
173                    trie_node::Decoded {
174                        children: array::from_fn(|n| calculated_elem.children[n].as_ref()),
175                        partial_key: calculated_elem.partial_key.iter().copied(),
176                        storage_value: trie_node::StorageValue::None,
177                    },
178                    self.hash_function,
179                    self.stack.is_empty(),
180                )
181                .unwrap_or_else(|_| unreachable!());
182
183                // Insert Merkle value into the stack, or, if no parent, we have our result!
184                if let Some(parent) = self.stack.last_mut() {
185                    parent.children.push(Some(merkle_value));
186                } else {
187                    // Because we pass `is_root_node: true` in the calculation above, it is
188                    // guaranteed that the Merkle value is always 32 bytes.
189                    let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref()).unwrap();
190                    break RootMerkleValueCalculation::Finished { hash };
191                }
192            } else {
193                // Need to find the closest descendant to the first unknown child at the top of the
194                // stack.
195                break RootMerkleValueCalculation::NextKey(NextKey {
196                    branch_search: branch_search::start_branch_search(branch_search::Config {
197                        key_before: self.current_iter_node_full_key(),
198                        or_equal: true,
199                        prefix: self.current_iter_node_full_key(),
200                        no_branch_search: false,
201                    }),
202                    calculation: self,
203                });
204            }
205        }
206    }
207}
208
/// Request to return the key that follows (in lexicographic order) a given one in the storage.
/// Call [`NextKey::inject_key`] to provide that key.
#[must_use]
pub struct NextKey {
    /// State of the calculation, resumed once the branch search below has finished.
    calculation: CalcInner,

    /// Current branch search running to find the closest descendant to the first unknown child
    /// of the node at the top of the stack.
    branch_search: branch_search::NextKey,
}
219
220impl NextKey {
221    /// Returns the key whose next key must be passed back.
222    pub fn key_before(&self) -> impl Iterator<Item = u8> {
223        self.branch_search.key_before()
224    }
225
226    /// If `true`, then the provided value must the one superior or equal to the requested key.
227    /// If `false`, then the provided value must be strictly superior to the requested key.
228    pub fn or_equal(&self) -> bool {
229        self.branch_search.or_equal()
230    }
231
232    /// Returns the prefix the next key must start with. If the next key doesn't start with the
233    /// given prefix, then `None` should be provided.
234    pub fn prefix(&self) -> impl Iterator<Item = u8> {
235        self.branch_search.prefix()
236    }
237
238    /// Injects the key.
239    ///
240    /// # Panic
241    ///
242    /// Panics if the key passed as parameter isn't strictly superior to the requested key.
243    ///
244    pub fn inject_key(
245        mut self,
246        key: Option<impl Iterator<Item = u8>>,
247    ) -> RootMerkleValueCalculation {
248        match self.branch_search.inject(key) {
249            branch_search::BranchSearch::NextKey(next_key) => {
250                RootMerkleValueCalculation::NextKey(NextKey {
251                    calculation: self.calculation,
252                    branch_search: next_key,
253                })
254            }
255            branch_search::BranchSearch::Found {
256                branch_trie_node_key,
257            } => {
258                // Add the closest descendant to the stack.
259                if let Some(branch_trie_node_key) = branch_trie_node_key {
260                    let partial_key = branch_trie_node_key
261                        .skip(self.calculation.current_iter_node_full_key().count())
262                        .collect();
263                    self.calculation.stack.push(Node {
264                        partial_key,
265                        children: arrayvec::ArrayVec::new(),
266                    });
267                    self.calculation.next()
268                } else if let Some(stack_top) = self.calculation.stack.last_mut() {
269                    stack_top.children.push(None);
270                    self.calculation.next()
271                } else {
272                    // Trie is completely empty.
273                    RootMerkleValueCalculation::Finished {
274                        hash: match self.calculation.hash_function {
275                            HashFunction::Blake2 => EMPTY_BLAKE2_TRIE_MERKLE_VALUE,
276                            HashFunction::Keccak256 => EMPTY_KECCAK256_TRIE_MERKLE_VALUE,
277                        },
278                    }
279                }
280            }
281        }
282    }
283}
284
/// Request the value of the node with a specific key. Call [`StorageValue::inject`] to indicate
/// the value.
///
/// The key in question can be obtained with [`StorageValue::key`].
#[must_use]
pub struct StorageValue {
    /// State of the calculation. The node at the top of its stack is the one whose storage
    /// value is requested.
    calculation: CalcInner,
}
291
292impl StorageValue {
293    /// Returns the key whose value is being requested.
294    pub fn key(&self) -> impl Iterator<Item = u8> {
295        // This function can never be reached if the number of nibbles is uneven.
296        debug_assert_eq!(self.calculation.current_iter_node_full_key().count() % 2, 0);
297        nibbles_to_bytes_suffix_extend(self.calculation.current_iter_node_full_key())
298    }
299
300    /// Indicates the storage value and advances the calculation.
301    pub fn inject(
302        mut self,
303        storage_value: Option<(impl AsRef<[u8]>, TrieEntryVersion)>,
304    ) -> RootMerkleValueCalculation {
305        let calculated_elem = self.calculation.stack.pop().unwrap();
306
307        // Due to some borrow checker troubles, we need to calculate the storage value
308        // hash ahead of time if relevant.
309        let storage_value_hash = if let Some((value, TrieEntryVersion::V1)) = storage_value.as_ref()
310        {
311            if value.as_ref().len() >= 33 {
312                Some(blake2_rfc::blake2b::blake2b(32, &[], value.as_ref()))
313            } else {
314                None
315            }
316        } else {
317            None
318        };
319
320        // Calculate the Merkle value of the node.
321        let merkle_value = trie_node::calculate_merkle_value(
322            trie_node::Decoded {
323                children: array::from_fn(|n| calculated_elem.children[n].as_ref()),
324                partial_key: calculated_elem.partial_key.iter().copied(),
325                storage_value: match (storage_value.as_ref(), storage_value_hash.as_ref()) {
326                    (_, Some(storage_value_hash)) => trie_node::StorageValue::Hashed(
327                        <&[u8; 32]>::try_from(storage_value_hash.as_bytes())
328                            .unwrap_or_else(|_| unreachable!()),
329                    ),
330                    (Some((value, _)), _) => trie_node::StorageValue::Unhashed(value.as_ref()),
331                    (None, _) => trie_node::StorageValue::None,
332                },
333            },
334            self.calculation.hash_function,
335            self.calculation.stack.is_empty(),
336        )
337        .unwrap_or_else(|_| unreachable!());
338
339        // Insert Merkle value into the stack, or, if no parent, we have our result!
340        if let Some(parent) = self.calculation.stack.last_mut() {
341            parent.children.push(Some(merkle_value));
342            self.calculation.next()
343        } else {
344            // Because we pass `is_root_node: true` in the calculation above, it is guaranteed
345            // that the Merkle value is always 32 bytes.
346            let hash = *<&[u8; 32]>::try_from(merkle_value.as_ref()).unwrap();
347            RootMerkleValueCalculation::Finished { hash }
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use crate::trie::{HashFunction, TrieEntryVersion};
    use alloc::collections::BTreeMap;
    use core::ops::Bound;

    /// Drives a root calculation to completion, answering every request from `trie`.
    ///
    /// All the storage values are injected with the given `version`.
    fn calculate_root(version: TrieEntryVersion, trie: &BTreeMap<Vec<u8>, Vec<u8>>) -> [u8; 32] {
        let mut state = super::root_merkle_value(HashFunction::Blake2);

        loop {
            state = match state {
                super::RootMerkleValueCalculation::Finished { hash } => break hash,
                super::RootMerkleValueCalculation::NextKey(request) => {
                    let start = request.key_before().collect::<Vec<_>>();
                    let lower_bound = if request.or_equal() {
                        Bound::Included(start)
                    } else {
                        Bound::Excluded(start)
                    };

                    // First key at or after the bound, discarded if it doesn't match the
                    // requested prefix.
                    let found = trie
                        .range((lower_bound, Bound::Unbounded))
                        .next()
                        .filter(|(candidate, _)| {
                            candidate
                                .iter()
                                .copied()
                                .zip(request.prefix())
                                .all(|(a, b)| a == b)
                        })
                        .map(|(candidate, _)| candidate);

                    request.inject_key(found.map(|k| k.iter().copied()))
                }
                super::RootMerkleValueCalculation::StorageValue(request) => {
                    let key = request.key().collect::<Vec<u8>>();
                    request.inject(trie.get(&key).map(|value| (value, version)))
                }
            };
        }
    }

    /// Asserts that the root of `trie` is `expected` for both versions of the trie entries.
    fn assert_root_for_all_versions(trie: &BTreeMap<Vec<u8>, Vec<u8>>, expected: &[u8]) {
        for version in [TrieEntryVersion::V0, TrieEntryVersion::V1] {
            assert_eq!(calculate_root(version, trie), expected);
        }
    }

    #[test]
    fn trie_root_one_node() {
        let mut trie = BTreeMap::new();
        trie.insert(b"abcd".to_vec(), b"hello world".to_vec());

        let expected = [
            122, 177, 134, 89, 211, 178, 120, 158, 242, 64, 13, 16, 113, 4, 199, 212, 251, 147,
            208, 109, 154, 182, 168, 182, 65, 165, 222, 124, 63, 236, 200, 81,
        ];

        assert_root_for_all_versions(&trie, &expected[..]);
    }

    #[test]
    fn trie_root_empty() {
        let trie = BTreeMap::new();
        let expected = blake2_rfc::blake2b::blake2b(32, &[], &[0x0]);

        assert_root_for_all_versions(&trie, expected.as_bytes());
    }

    #[test]
    fn trie_root_single_tuple() {
        let mut trie = BTreeMap::new();
        trie.insert([0xaa].to_vec(), [0xbb].to_vec());

        let encoded_root: &[u8] = &[
            0x42,   // leaf 0x40 (2^6) with (+) key of 2 nibbles (0x02)
            0xaa,   // key data
            1 << 2, // length of value in bytes as Compact
            0xbb,   // value data
        ];
        let expected = blake2_rfc::blake2b::blake2b(32, &[], encoded_root);

        assert_root_for_all_versions(&trie, expected.as_bytes());
    }

    #[test]
    fn trie_root_example() {
        let mut trie = BTreeMap::new();
        trie.insert([0x48, 0x19].to_vec(), [0xfe].to_vec());
        trie.insert([0x13, 0x14].to_vec(), [0xff].to_vec());

        let encoded_root: &[u8] = &[
            0x80,      // branch, no value (0b_10..) no nibble
            0x12,      // slots 1 & 4 are taken from 0-7
            0x00,      // no slots from 8-15
            0x05 << 2, // first slot: LEAF, 5 bytes long.
            0x43,      // leaf 0x40 with 3 nibbles
            0x03,      // first nibble
            0x14,      // second & third nibble
            0x01 << 2, // 1 byte data
            0xff,      // value data
            0x05 << 2, // second slot: LEAF, 5 bytes long.
            0x43,      // leaf with 3 nibbles
            0x08,      // first nibble
            0x19,      // second & third nibble
            0x01 << 2, // 1 byte data
            0xfe,      // value data
        ];
        let expected = blake2_rfc::blake2b::blake2b(32, &[], encoded_root);

        assert_root_for_all_versions(&trie, expected.as_bytes());
    }
}