Skip to main content

ethrex_trie/
trie_sorted.rs

1use crate::{
2    EMPTY_TRIE_HASH, Nibbles, Node, TrieDB, TrieError,
3    node::{BranchNode, ExtensionNode, LeafNode},
4    threadpool::ThreadPool,
5};
6use crossbeam::channel::{Receiver, Sender, bounded};
7use ethereum_types::H256;
8use ethrex_crypto::NativeCrypto;
9use std::{sync::Arc, thread::scope};
10
11/// The elements of the stack represent the branch node that is the parent of the current
12/// parent element. When the current parent is no longer valid (is not the parent of
13/// the current elements), the stack gets popped and this element becomes the parent
14#[derive(Debug, Default, Clone)]
15struct StackElement {
16    path: Nibbles,
17    element: BranchNode,
18}
19
20// The large size isn't a performance problem because we use a single instance of this
21// struct
22#[allow(clippy::large_enum_variant)]
23/// This struct handles the current element that the algorithm is processing. The
24/// current parent is the parent of this element and the next one in the queue.
25/// If that isn't true, we pop the stack and the old parent becomes the new current element
26/// This is an enum because the current element can be a leaf or a branch
27#[derive(Debug, Clone)]
28enum CenterSideElement {
29    Branch { node: BranchNode },
30    Leaf { value: Vec<u8> },
31}
32
33/// The current element and its full path.
34#[derive(Debug, Clone)]
35struct CenterSide {
36    // Full path to the element
37    path: Nibbles,
38    // Element, can be branch or leaf
39    element: CenterSideElement,
40}
41
42/// These errors should never happen on a correctly ordered list, but they can happen if
43/// the iterator used as input has repeated or out of order values
44#[derive(Debug, thiserror::Error)]
45pub enum TrieGenerationError {
46    #[error("When creating a child node, the nibbles diff was empty. Child Node {0:x?}")]
47    IndexNotFound(Nibbles),
48    #[error("When popping from the trie stack it was empty. Current position: {0:x?}")]
49    TrieStackEmpty(Nibbles),
50    #[error(transparent)]
51    FlushToDbError(TrieError),
52    #[error("When joining the write threads, error")]
53    ThreadJoinError(),
54}
55
/// Once a write buffer holds more than this many nodes, it is handed off to be flushed.
pub const SIZE_TO_WRITE_DB: u64 = 20_000;
/// Number of write buffers that may be in circulation at the same time.
///
/// Together with [`SIZE_TO_WRITE_DB`], this bounds the memory held by pending writes.
pub const BUFFER_COUNT: u64 = 32;
61
62impl CenterSide {
63    fn from_value(tuple: (H256, Vec<u8>)) -> CenterSide {
64        CenterSide {
65            path: Nibbles::from_raw(&tuple.0.0, true),
66            element: CenterSideElement::Leaf { value: tuple.1 },
67        }
68    }
69    fn from_stack_element(element: StackElement) -> CenterSide {
70        CenterSide {
71            path: element.path,
72            element: CenterSideElement::Branch {
73                node: element.element,
74            },
75        }
76    }
77}
78
79/// Checks if the stack element is a child node of the element at path `this`
80fn is_child(this: &Nibbles, other: &StackElement) -> bool {
81    this.count_prefix(&other.path) == other.path.len()
82}
83
84/// Creates a parent element that can have as children both the parent and the closest nibbles
85/// That parent is created with no children
86fn create_parent(current_node: &CenterSide, closest_nibbles: &Nibbles) -> StackElement {
87    let new_parent_nibbles = current_node
88        .path
89        .slice(0, current_node.path.count_prefix(closest_nibbles));
90    StackElement {
91        path: new_parent_nibbles,
92        element: BranchNode {
93            choices: BranchNode::EMPTY_CHOICES,
94            value: vec![],
95        },
96    }
97}
98
99/// This function modifies a parent element to include the `current_node` element, and
100/// then adds the `current_node` to the write queue.
101/// When adding the current_node to the write queue we create an extension if needed
102fn add_current_to_parent_and_write_queue(
103    nodes_to_write: &mut Vec<(Nibbles, Node)>,
104    current_node: &CenterSide,
105    parent_element: &mut StackElement,
106) -> Result<(), TrieGenerationError> {
107    let mut nodehash_buffer = Vec::with_capacity(512);
108    let mut path = current_node.path.clone();
109    path.skip_prefix(&parent_element.path);
110    let index = path
111        .next()
112        .ok_or_else(|| TrieGenerationError::IndexNotFound(current_node.path.clone()))?;
113    let top_path = parent_element.path.append_new(index);
114    let (target_path, node): (Nibbles, Node) = match &current_node.element {
115        CenterSideElement::Branch { node } => {
116            if path.is_empty() {
117                (top_path, node.clone().into())
118            } else {
119                let hash = node.compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto);
120                nodes_to_write.push((current_node.path.clone(), node.clone().into()));
121                (
122                    top_path,
123                    ExtensionNode {
124                        prefix: path,
125                        child: hash.into(),
126                    }
127                    .into(),
128                )
129            }
130        }
131        CenterSideElement::Leaf { value } => (
132            top_path,
133            LeafNode {
134                partial: path,
135                value: value.clone(),
136            }
137            .into(),
138        ),
139    };
140    parent_element.element.choices[index as usize] = node
141        .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
142        .into();
143    nodes_to_write.push((target_path, node));
144    Ok(())
145}
146
147/// flush_nodes_to_write writes the nodes into the database, and when it's done it
148/// returns the vector used to write nodes into the channel for future use
149fn flush_nodes_to_write(
150    mut nodes_to_write: Vec<(Nibbles, Node)>,
151    db: &dyn TrieDB,
152    sender: Sender<Vec<(Nibbles, Node)>>,
153) -> Result<(), TrieGenerationError> {
154    db.put_batch_no_alloc(&nodes_to_write)
155        .map_err(TrieGenerationError::FlushToDbError)?;
156    nodes_to_write.clear();
157    let _ = sender.send(nodes_to_write);
158    Ok(())
159}
160
161/// trie_from_sorted_accounts computes and stores into a db a trie from a sorted
162/// iterator of H256 paths and values. This function takes a ThreadPool Arc to send
163/// the writing task to be done concurrently.
164/// To limit the amount of memory this function can use, we use a crossbeam multiproducer
165/// multiconsumer queue, which gives the function a buffer to write nodes into before
166/// flushing to the db.
167pub fn trie_from_sorted_accounts<'scope, T>(
168    db: &'scope dyn TrieDB,
169    data_iter: &mut T,
170    scope: Arc<ThreadPool<'scope>>,
171    buffer_sender: Sender<Vec<(Nibbles, Node)>>,
172    buffer_receiver: Receiver<Vec<(Nibbles, Node)>>,
173) -> Result<H256, TrieGenerationError>
174where
175    T: Iterator<Item = (H256, Vec<u8>)> + Send,
176{
177    let Some(initial_value) = data_iter.next() else {
178        return Ok(*EMPTY_TRIE_HASH);
179    };
180    let mut nodes_to_write: Vec<(Nibbles, Node)> = buffer_receiver
181        .recv()
182        .expect("This channel shouldn't close");
183    // We have a stack of the parents of the current parent
184    let mut trie_stack: Vec<StackElement> = Vec::with_capacity(64); // Optimized for H256
185
186    // This is the current parent of the first element. We assume that the root node
187    // is always a parent, and we fix it afterwards if it's not true
188    // The root is a parent of all nodes
189    let mut nodehash_buffer = Vec::with_capacity(512);
190    let mut current_parent = StackElement::default();
191
192    // The current node that is being used for computing. We compare it with the current
193    // parent and the next value to see where it should be written
194    let mut current_node: CenterSide = CenterSide::from_value(initial_value);
195    let mut next_value_opt: Option<(H256, Vec<u8>)> = data_iter.next();
196
197    while let Some(next_value) = next_value_opt {
198        if nodes_to_write.len() as u64 > SIZE_TO_WRITE_DB {
199            let buffer_sender = buffer_sender.clone();
200            scope.execute_priority(Box::new(move || {
201                let _ = flush_nodes_to_write(nodes_to_write, db, buffer_sender);
202            }));
203            // We wait to get a new buffer to avoid writing too much
204            nodes_to_write = buffer_receiver
205                .recv()
206                .expect("This channel shouldn't close");
207        }
208
209        let next_value_path = Nibbles::from_bytes(next_value.0.as_bytes());
210
211        // If the current parent isn't a parent of the next value, that means
212        // that the current value doesn't have a sibling to the right
213        // As such we write this node and change the current node to the current parent
214        while !is_child(&next_value_path, &current_parent) {
215            add_current_to_parent_and_write_queue(
216                &mut nodes_to_write,
217                &current_node,
218                &mut current_parent,
219            )?;
220            let temp = CenterSide::from_stack_element(current_parent);
221            current_parent = trie_stack
222                .pop()
223                .ok_or_else(|| TrieGenerationError::TrieStackEmpty(current_node.path.clone()))?;
224            current_node = temp;
225        }
226
227        // If the "distance" (same prefix count) between the current and next value is equal to the
228        // parent node, that means that they're both "siblings" of the current parent
229        // Ex: parent=[05] current=[0567] next=[0589]
230        // there is not a branch between the parent and current, so we just write the
231        // current element and change the current with the next value while
232        // advancing the iterator for our next value
233        if current_node.path.count_prefix(&current_parent.path)
234            == current_node.path.count_prefix(&next_value_path)
235        {
236            add_current_to_parent_and_write_queue(
237                &mut nodes_to_write,
238                &current_node,
239                &mut current_parent,
240            )?;
241
242        // If the "distance" between the current and next value is larger than that to
243        // the parent node, that means that there is a closer parent for both of them
244        // Ex: parent=[05] current=[0567] next=[0569]
245        // This means that there is a branch in [056] and current is a child
246        // of that parent
247        // So we create a parent, mark it as current, write the current node to that parent.
248        // The old parent goes into the stack
249        // Then we advance the iterator for our next value
250        } else {
251            let mut element = create_parent(&current_node, &next_value_path);
252            add_current_to_parent_and_write_queue(
253                &mut nodes_to_write,
254                &current_node,
255                &mut element,
256            )?;
257            trie_stack.push(current_parent);
258            current_parent = element;
259        }
260        current_node = CenterSide::from_value(next_value);
261        next_value_opt = data_iter.next();
262    }
263
264    // We empty the stack, where each node is a child of the one in the stack, so we just keep
265    // popping and adding to parent
266    add_current_to_parent_and_write_queue(&mut nodes_to_write, &current_node, &mut current_parent)?;
267    while let Some(mut parent_node) = trie_stack.pop() {
268        add_current_to_parent_and_write_queue(
269            &mut nodes_to_write,
270            &CenterSide::from_stack_element(current_parent),
271            &mut parent_node,
272        )?;
273        current_parent = parent_node;
274    }
275
276    let hash = if current_parent
277        .element
278        .choices
279        .iter()
280        .filter(|choice| choice.is_valid())
281        .count()
282        == 1
283    {
284        let (index, child) = current_parent
285            .element
286            .choices
287            .into_iter()
288            .enumerate()
289            .find(|(_, child)| child.is_valid())
290            .unwrap();
291
292        let (target_path, node_hash_ref) = nodes_to_write.iter_mut().last().unwrap();
293        match node_hash_ref {
294            Node::Branch(_) => {
295                let node: Node = ExtensionNode {
296                    prefix: Nibbles::from_hex(vec![index as u8]),
297                    child,
298                }
299                .into();
300                nodes_to_write.push((Nibbles::default(), node));
301                nodes_to_write
302                    .last()
303                    .expect("we just inserted")
304                    .1
305                    .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
306                    .finalize(&NativeCrypto)
307            }
308            Node::Extension(extension_node) => {
309                extension_node.prefix.prepend(index as u8);
310                // This next works because this target path is always length of 1 element,
311                // and we're just removing that one element
312                target_path.next();
313                extension_node
314                    .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
315                    .finalize(&NativeCrypto)
316            }
317            Node::Leaf(leaf_node) => {
318                leaf_node.partial.prepend(index as u8);
319                // This next works because this target path is always length of 1 element,
320                // and we're just removing that one element
321                target_path.next();
322                leaf_node
323                    .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
324                    .finalize(&NativeCrypto)
325            }
326        }
327    } else {
328        let node: Node = current_parent.element.into();
329        nodes_to_write.push((Nibbles::default(), node));
330        nodes_to_write
331            .last()
332            .expect("we just inserted")
333            .1
334            .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
335            .finalize(&NativeCrypto)
336    };
337
338    let _ = flush_nodes_to_write(nodes_to_write, db, buffer_sender);
339    Ok(hash)
340}
341
342/// Wrapper function for `trie_from_sorted_accounts` that handles concurrency
343/// and memory limits
344pub fn trie_from_sorted_accounts_wrap<T>(
345    db: &dyn TrieDB,
346    accounts_iter: &mut T,
347) -> Result<H256, TrieGenerationError>
348where
349    T: Iterator<Item = (H256, Vec<u8>)> + Send,
350{
351    let (buffer_sender, buffer_receiver) = bounded::<Vec<(Nibbles, Node)>>(BUFFER_COUNT as usize);
352    for _ in 0..BUFFER_COUNT {
353        let _ = buffer_sender.send(Vec::with_capacity(SIZE_TO_WRITE_DB as usize));
354    }
355    scope(|s| {
356        let pool = ThreadPool::new(12, s);
357        trie_from_sorted_accounts(
358            db,
359            accounts_iter,
360            Arc::new(pool),
361            buffer_sender,
362            buffer_receiver,
363        )
364    })
365}
366
#[cfg(test)]
mod test {
    use ethereum_types::U256;
    use ethrex_rlp::encode::RLPEncode;

    use crate::{InMemoryTrieDB, Trie};

    use super::*;
    use std::{collections::BTreeMap, str::FromStr, sync::Mutex};

    fn generate_input_1() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        for string in [
            "68521f7430502aef983fd7568ea179ed0f8d12d5b68883c90573781ae0778ec2",
            "68db10f720d5972738df0d841d64c7117439a1a2ca9ba247e7239b19eb187414",
            "6b7c1458952b903dbe3717bc7579f18e5cb1136be1b11b113cdac0f0791c07d3",
        ] {
            accounts.insert(H256::from_str(string).unwrap(), vec![0, 1]);
        }
        accounts
    }

    fn generate_input_2() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        for string in [
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "14d5df819167b77851220ee266178aee165daada67ca865e9d50faed6b4fdbe3",
            "6908aa86b715fcf221f208a28bb84bf6359ba9c41da04b7e17a925cdb22bf704",
            "90bbe47533cd80b5d9cef6c283415edd90296bf4ac4ede6d2a6b42bb3d5e7d0e",
            "90c2fdad333366cf0f18f0dded9b478590c0563e4c847c79aee0b733b5a9104f",
            "af9e3efce873619102dfdb0504abd44179191bccfb624608961e71492a1ba5b7",
            "b723d5841dc4d6d3fe7de03ad74dd83798c3b68f752bba29c906ec7f5a469452",
            "c2c6fd64de59489f0c27e75443c24327cef6415f1d3ee1659646abefab212113",
            "ca0d791e7a3e0f25d775034acecbaaf9219939288e6282d8291e181b9c3c24b0",
            "f0dcaaa40dfc67925d6e172e48b8f83954ba46cfb1bb522c809f3b93b49205ee",
        ] {
            accounts.insert(H256::from_str(string).unwrap(), vec![0, 1]);
        }
        accounts
    }

    fn generate_input_3() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        for string in [
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0542f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0552f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ] {
            accounts.insert(H256::from_str(string).unwrap(), vec![0, 1]);
        }
        accounts
    }

    fn generate_input_4() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        let string = "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9";
        accounts.insert(H256::from_str(string).unwrap(), vec![0, 1]);
        accounts
    }

    fn generate_input_5() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        for (string, value) in [
            (
                "290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563",
                U256::from_str("1191240792495687806002885977912460542139236513636").unwrap(),
            ),
            (
                "295841a49a1089f4b560f91cfbb0133326654dcbb1041861fc5dde96c724a22f",
                U256::from(480),
            ),
        ] {
            accounts.insert(H256::from_str(string).unwrap(), value.encode_to_vec());
        }
        accounts
    }

    fn generate_input_slots_1() -> BTreeMap<H256, U256> {
        let mut slots: BTreeMap<H256, U256> = BTreeMap::new();
        for string in [
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e8",
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0552f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ] {
            slots.insert(H256::from_str(string).unwrap(), U256::zero());
        }
        slots
    }

    /// Builds the trie with `trie_from_sorted_accounts_wrap` and checks both the root hash
    /// and every stored node against a trie built by regular insertion.
    fn run_test_account_state(accounts: BTreeMap<H256, Vec<u8>>) {
        let computed_data = Arc::new(Mutex::new(BTreeMap::new()));
        let trie = Trie::new(Box::new(InMemoryTrieDB::new(computed_data.clone())));
        let db = trie.db();
        let tested_trie_hash: H256 = trie_from_sorted_accounts_wrap(
            db,
            &mut accounts
                .clone()
                .into_iter()
                .map(|(hash, state)| (hash, state.encode_to_vec())),
        )
        .expect("Shouldn't have errors");

        let expected_data = Arc::new(Mutex::new(BTreeMap::new()));
        let mut trie = Trie::new(Box::new(InMemoryTrieDB::new(expected_data.clone())));
        for account in accounts.iter() {
            trie.insert(account.0.as_bytes().to_vec(), account.1.encode_to_vec())
                .unwrap();
        }

        assert_eq!(tested_trie_hash, trie.hash(&NativeCrypto).unwrap());

        let computed_data = computed_data.lock().unwrap();
        let expected_data = expected_data.lock().unwrap();
        for (k, v) in expected_data.iter() {
            // skip flatkeyvalues, we don't want them
            if k.last().cloned() == Some(16) {
                continue;
            }
            assert_eq!(computed_data.get(k), Some(v), "node at path {k:?} differs");
        }
    }

    /// Checks the root hash produced for storage slots against a regularly built trie.
    fn run_test_storage_slots(slots: BTreeMap<H256, U256>) {
        let trie = Trie::stateless();
        let db = trie.db();
        let tested_trie_hash: H256 = trie_from_sorted_accounts_wrap(
            db,
            &mut slots
                .clone()
                .into_iter()
                .map(|(hash, state)| (hash, state.encode_to_vec())),
        )
        .expect("Shouldn't have errors");

        let mut trie: Trie = Trie::empty_in_memory();
        for account in slots.iter() {
            trie.insert(account.0.as_bytes().to_vec(), account.1.encode_to_vec())
                .unwrap();
        }

        let trie_hash = trie.hash_no_commit(&NativeCrypto);

        assert_eq!(tested_trie_hash, trie_hash);
    }

    #[test]
    fn test_1() {
        run_test_account_state(generate_input_1());
    }

    #[test]
    fn test_2() {
        run_test_account_state(generate_input_2());
    }

    #[test]
    fn test_3() {
        run_test_account_state(generate_input_3());
    }

    #[test]
    fn test_4() {
        run_test_account_state(generate_input_4());
    }

    #[test]
    fn test_5() {
        run_test_account_state(generate_input_5());
    }

    #[test]
    fn test_slots_1() {
        run_test_storage_slots(generate_input_slots_1());
    }

    #[test]
    fn test_empty_input() {
        let trie = Trie::stateless();
        let tested_trie_hash = trie_from_sorted_accounts_wrap(
            trie.db(),
            &mut std::iter::empty::<(H256, Vec<u8>)>(),
        )
        .expect("Shouldn't have errors");
        assert_eq!(tested_trie_hash, *EMPTY_TRIE_HASH);
    }
}
542}