1use crate::{
2 EMPTY_TRIE_HASH, Nibbles, Node, TrieDB, TrieError,
3 node::{BranchNode, ExtensionNode, LeafNode},
4 threadpool::ThreadPool,
5};
6use crossbeam::channel::{Receiver, Sender, bounded};
7use ethereum_types::H256;
8use ethrex_crypto::NativeCrypto;
9use std::{sync::Arc, thread::scope};
10
11#[derive(Debug, Default, Clone)]
15struct StackElement {
16 path: Nibbles,
17 element: BranchNode,
18}
19
20#[allow(clippy::large_enum_variant)]
23#[derive(Debug, Clone)]
28enum CenterSideElement {
29 Branch { node: BranchNode },
30 Leaf { value: Vec<u8> },
31}
32
33#[derive(Debug, Clone)]
35struct CenterSide {
36 path: Nibbles,
38 element: CenterSideElement,
40}
41
42#[derive(Debug, thiserror::Error)]
45pub enum TrieGenerationError {
46 #[error("When creating a child node, the nibbles diff was empty. Child Node {0:x?}")]
47 IndexNotFound(Nibbles),
48 #[error("When popping from the trie stack it was empty. Current position: {0:x?}")]
49 TrieStackEmpty(Nibbles),
50 #[error(transparent)]
51 FlushToDbError(TrieError),
52 #[error("When joining the write threads, error")]
53 ThreadJoinError(),
54}
55
/// Once the in-progress node buffer holds more than this many entries, it is handed off to be
/// written to the database; this is also the initial capacity of each pooled buffer.
pub const SIZE_TO_WRITE_DB: u64 = 20_000;
/// Number of node buffers pre-allocated for the write pool (and the capacity of the channel
/// that recycles them).
pub const BUFFER_COUNT: u64 = 32;
61
62impl CenterSide {
63 fn from_value(tuple: (H256, Vec<u8>)) -> CenterSide {
64 CenterSide {
65 path: Nibbles::from_raw(&tuple.0.0, true),
66 element: CenterSideElement::Leaf { value: tuple.1 },
67 }
68 }
69 fn from_stack_element(element: StackElement) -> CenterSide {
70 CenterSide {
71 path: element.path,
72 element: CenterSideElement::Branch {
73 node: element.element,
74 },
75 }
76 }
77}
78
79fn is_child(this: &Nibbles, other: &StackElement) -> bool {
81 this.count_prefix(&other.path) == other.path.len()
82}
83
84fn create_parent(current_node: &CenterSide, closest_nibbles: &Nibbles) -> StackElement {
87 let new_parent_nibbles = current_node
88 .path
89 .slice(0, current_node.path.count_prefix(closest_nibbles));
90 StackElement {
91 path: new_parent_nibbles,
92 element: BranchNode {
93 choices: BranchNode::EMPTY_CHOICES,
94 value: vec![],
95 },
96 }
97}
98
99fn add_current_to_parent_and_write_queue(
103 nodes_to_write: &mut Vec<(Nibbles, Node)>,
104 current_node: &CenterSide,
105 parent_element: &mut StackElement,
106) -> Result<(), TrieGenerationError> {
107 let mut nodehash_buffer = Vec::with_capacity(512);
108 let mut path = current_node.path.clone();
109 path.skip_prefix(&parent_element.path);
110 let index = path
111 .next()
112 .ok_or_else(|| TrieGenerationError::IndexNotFound(current_node.path.clone()))?;
113 let top_path = parent_element.path.append_new(index);
114 let (target_path, node): (Nibbles, Node) = match ¤t_node.element {
115 CenterSideElement::Branch { node } => {
116 if path.is_empty() {
117 (top_path, node.clone().into())
118 } else {
119 let hash = node.compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto);
120 nodes_to_write.push((current_node.path.clone(), node.clone().into()));
121 (
122 top_path,
123 ExtensionNode {
124 prefix: path,
125 child: hash.into(),
126 }
127 .into(),
128 )
129 }
130 }
131 CenterSideElement::Leaf { value } => (
132 top_path,
133 LeafNode {
134 partial: path,
135 value: value.clone(),
136 }
137 .into(),
138 ),
139 };
140 parent_element.element.choices[index as usize] = node
141 .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
142 .into();
143 nodes_to_write.push((target_path, node));
144 Ok(())
145}
146
147fn flush_nodes_to_write(
150 mut nodes_to_write: Vec<(Nibbles, Node)>,
151 db: &dyn TrieDB,
152 sender: Sender<Vec<(Nibbles, Node)>>,
153) -> Result<(), TrieGenerationError> {
154 db.put_batch_no_alloc(&nodes_to_write)
155 .map_err(TrieGenerationError::FlushToDbError)?;
156 nodes_to_write.clear();
157 let _ = sender.send(nodes_to_write);
158 Ok(())
159}
160
161pub fn trie_from_sorted_accounts<'scope, T>(
168 db: &'scope dyn TrieDB,
169 data_iter: &mut T,
170 scope: Arc<ThreadPool<'scope>>,
171 buffer_sender: Sender<Vec<(Nibbles, Node)>>,
172 buffer_receiver: Receiver<Vec<(Nibbles, Node)>>,
173) -> Result<H256, TrieGenerationError>
174where
175 T: Iterator<Item = (H256, Vec<u8>)> + Send,
176{
177 let Some(initial_value) = data_iter.next() else {
178 return Ok(*EMPTY_TRIE_HASH);
179 };
180 let mut nodes_to_write: Vec<(Nibbles, Node)> = buffer_receiver
181 .recv()
182 .expect("This channel shouldn't close");
183 let mut trie_stack: Vec<StackElement> = Vec::with_capacity(64); let mut nodehash_buffer = Vec::with_capacity(512);
190 let mut current_parent = StackElement::default();
191
192 let mut current_node: CenterSide = CenterSide::from_value(initial_value);
195 let mut next_value_opt: Option<(H256, Vec<u8>)> = data_iter.next();
196
197 while let Some(next_value) = next_value_opt {
198 if nodes_to_write.len() as u64 > SIZE_TO_WRITE_DB {
199 let buffer_sender = buffer_sender.clone();
200 scope.execute_priority(Box::new(move || {
201 let _ = flush_nodes_to_write(nodes_to_write, db, buffer_sender);
202 }));
203 nodes_to_write = buffer_receiver
205 .recv()
206 .expect("This channel shouldn't close");
207 }
208
209 let next_value_path = Nibbles::from_bytes(next_value.0.as_bytes());
210
211 while !is_child(&next_value_path, ¤t_parent) {
215 add_current_to_parent_and_write_queue(
216 &mut nodes_to_write,
217 ¤t_node,
218 &mut current_parent,
219 )?;
220 let temp = CenterSide::from_stack_element(current_parent);
221 current_parent = trie_stack
222 .pop()
223 .ok_or_else(|| TrieGenerationError::TrieStackEmpty(current_node.path.clone()))?;
224 current_node = temp;
225 }
226
227 if current_node.path.count_prefix(¤t_parent.path)
234 == current_node.path.count_prefix(&next_value_path)
235 {
236 add_current_to_parent_and_write_queue(
237 &mut nodes_to_write,
238 ¤t_node,
239 &mut current_parent,
240 )?;
241
242 } else {
251 let mut element = create_parent(¤t_node, &next_value_path);
252 add_current_to_parent_and_write_queue(
253 &mut nodes_to_write,
254 ¤t_node,
255 &mut element,
256 )?;
257 trie_stack.push(current_parent);
258 current_parent = element;
259 }
260 current_node = CenterSide::from_value(next_value);
261 next_value_opt = data_iter.next();
262 }
263
264 add_current_to_parent_and_write_queue(&mut nodes_to_write, ¤t_node, &mut current_parent)?;
267 while let Some(mut parent_node) = trie_stack.pop() {
268 add_current_to_parent_and_write_queue(
269 &mut nodes_to_write,
270 &CenterSide::from_stack_element(current_parent),
271 &mut parent_node,
272 )?;
273 current_parent = parent_node;
274 }
275
276 let hash = if current_parent
277 .element
278 .choices
279 .iter()
280 .filter(|choice| choice.is_valid())
281 .count()
282 == 1
283 {
284 let (index, child) = current_parent
285 .element
286 .choices
287 .into_iter()
288 .enumerate()
289 .find(|(_, child)| child.is_valid())
290 .unwrap();
291
292 let (target_path, node_hash_ref) = nodes_to_write.iter_mut().last().unwrap();
293 match node_hash_ref {
294 Node::Branch(_) => {
295 let node: Node = ExtensionNode {
296 prefix: Nibbles::from_hex(vec![index as u8]),
297 child,
298 }
299 .into();
300 nodes_to_write.push((Nibbles::default(), node));
301 nodes_to_write
302 .last()
303 .expect("we just inserted")
304 .1
305 .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
306 .finalize(&NativeCrypto)
307 }
308 Node::Extension(extension_node) => {
309 extension_node.prefix.prepend(index as u8);
310 target_path.next();
313 extension_node
314 .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
315 .finalize(&NativeCrypto)
316 }
317 Node::Leaf(leaf_node) => {
318 leaf_node.partial.prepend(index as u8);
319 target_path.next();
322 leaf_node
323 .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
324 .finalize(&NativeCrypto)
325 }
326 }
327 } else {
328 let node: Node = current_parent.element.into();
329 nodes_to_write.push((Nibbles::default(), node));
330 nodes_to_write
331 .last()
332 .expect("we just inserted")
333 .1
334 .compute_hash_no_alloc(&mut nodehash_buffer, &NativeCrypto)
335 .finalize(&NativeCrypto)
336 };
337
338 let _ = flush_nodes_to_write(nodes_to_write, db, buffer_sender);
339 Ok(hash)
340}
341
342pub fn trie_from_sorted_accounts_wrap<T>(
345 db: &dyn TrieDB,
346 accounts_iter: &mut T,
347) -> Result<H256, TrieGenerationError>
348where
349 T: Iterator<Item = (H256, Vec<u8>)> + Send,
350{
351 let (buffer_sender, buffer_receiver) = bounded::<Vec<(Nibbles, Node)>>(BUFFER_COUNT as usize);
352 for _ in 0..BUFFER_COUNT {
353 let _ = buffer_sender.send(Vec::with_capacity(SIZE_TO_WRITE_DB as usize));
354 }
355 scope(|s| {
356 let pool = ThreadPool::new(12, s);
357 trie_from_sorted_accounts(
358 db,
359 accounts_iter,
360 Arc::new(pool),
361 buffer_sender,
362 buffer_receiver,
363 )
364 })
365}
366
#[cfg(test)]
mod test {
    use ethereum_types::U256;
    use ethrex_rlp::encode::RLPEncode;

    use crate::{InMemoryTrieDB, Trie};

    use super::*;
    use std::{collections::BTreeMap, str::FromStr, sync::Mutex};

    /// Maps every hex-encoded key in `keys` to the fixed value `[0, 1]`.
    fn accounts_with_fixed_value(keys: &[&str]) -> BTreeMap<H256, Vec<u8>> {
        keys.iter()
            .map(|key| (H256::from_str(key).unwrap(), vec![0, 1]))
            .collect()
    }

    fn generate_input_1() -> BTreeMap<H256, Vec<u8>> {
        accounts_with_fixed_value(&[
            "68521f7430502aef983fd7568ea179ed0f8d12d5b68883c90573781ae0778ec2",
            "68db10f720d5972738df0d841d64c7117439a1a2ca9ba247e7239b19eb187414",
            "6b7c1458952b903dbe3717bc7579f18e5cb1136be1b11b113cdac0f0791c07d3",
        ])
    }

    fn generate_input_2() -> BTreeMap<H256, Vec<u8>> {
        accounts_with_fixed_value(&[
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "14d5df819167b77851220ee266178aee165daada67ca865e9d50faed6b4fdbe3",
            "6908aa86b715fcf221f208a28bb84bf6359ba9c41da04b7e17a925cdb22bf704",
            "90bbe47533cd80b5d9cef6c283415edd90296bf4ac4ede6d2a6b42bb3d5e7d0e",
            "90c2fdad333366cf0f18f0dded9b478590c0563e4c847c79aee0b733b5a9104f",
            "af9e3efce873619102dfdb0504abd44179191bccfb624608961e71492a1ba5b7",
            "b723d5841dc4d6d3fe7de03ad74dd83798c3b68f752bba29c906ec7f5a469452",
            "c2c6fd64de59489f0c27e75443c24327cef6415f1d3ee1659646abefab212113",
            "ca0d791e7a3e0f25d775034acecbaaf9219939288e6282d8291e181b9c3c24b0",
            "f0dcaaa40dfc67925d6e172e48b8f83954ba46cfb1bb522c809f3b93b49205ee",
        ])
    }

    fn generate_input_3() -> BTreeMap<H256, Vec<u8>> {
        accounts_with_fixed_value(&[
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0542f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0552f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ])
    }

    fn generate_input_4() -> BTreeMap<H256, Vec<u8>> {
        accounts_with_fixed_value(&[
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ])
    }

    fn generate_input_5() -> BTreeMap<H256, Vec<u8>> {
        let mut accounts: BTreeMap<H256, Vec<u8>> = BTreeMap::new();
        for (string, value) in [
            (
                "290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563",
                U256::from_str("1191240792495687806002885977912460542139236513636").unwrap(),
            ),
            (
                "295841a49a1089f4b560f91cfbb0133326654dcbb1041861fc5dde96c724a22f",
                U256::from(480),
            ),
        ] {
            accounts.insert(H256::from_str(string).unwrap(), value.encode_to_vec());
        }
        accounts
    }

    /// Keys share only their first nibble, so the root's single child is a branch placed
    /// directly in the root slot (no extension in between).
    fn generate_input_6() -> BTreeMap<H256, Vec<u8>> {
        accounts_with_fixed_value(&[
            "0a32f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0b32f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ])
    }

    fn generate_input_slots_1() -> BTreeMap<H256, U256> {
        let mut slots: BTreeMap<H256, U256> = BTreeMap::new();
        for string in [
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e8",
            "0532f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
            "0552f23d3bd5277790ece5a6cb6fc684bc473a91ffe3a0334049527c4f6987e9",
        ] {
            slots.insert(H256::from_str(string).unwrap(), U256::zero());
        }
        slots
    }

    /// Builds the trie with the sorted generator and with regular inserts, then compares the
    /// root hashes and the stored nodes.
    pub fn run_test_account_state(accounts: BTreeMap<H256, Vec<u8>>) {
        let computed_data = Arc::new(Mutex::new(BTreeMap::new()));
        let trie = Trie::new(Box::new(InMemoryTrieDB::new(computed_data.clone())));
        let db = trie.db();
        let tested_trie_hash: H256 = trie_from_sorted_accounts_wrap(
            db,
            &mut accounts
                .clone()
                .into_iter()
                .map(|(hash, state)| (hash, state.encode_to_vec())),
        )
        .expect("Shouldn't have errors");

        let expected_data = Arc::new(Mutex::new(BTreeMap::new()));
        let mut trie = Trie::new(Box::new(InMemoryTrieDB::new(expected_data.clone())));
        for account in accounts.iter() {
            trie.insert(account.0.as_bytes().to_vec(), account.1.encode_to_vec())
                .unwrap();
        }

        assert_eq!(tested_trie_hash, trie.hash(&NativeCrypto).unwrap());

        let computed_data = computed_data.lock().unwrap();
        let expected_data = expected_data.lock().unwrap();
        for (k, v) in expected_data.iter() {
            // Keys ending in the leaf flag (16) are not compared. The sorted generator queues
            // leaves under "parent path + slot nibble" rather than under the full key, so these
            // entries are presumably stored at different keys — NOTE(review): confirm against the
            // reference Trie's storage layout.
            if k.last().cloned() == Some(16) {
                continue;
            }
            assert!(computed_data.contains_key(k));
            assert_eq!(*v, computed_data[k]);
        }
    }

    /// Compares the sorted generator's root hash against a trie built by regular inserts.
    pub fn run_test_storage_slots(slots: BTreeMap<H256, U256>) {
        let trie = Trie::stateless();
        let db = trie.db();
        let tested_trie_hash: H256 = trie_from_sorted_accounts_wrap(
            db,
            &mut slots
                .clone()
                .into_iter()
                .map(|(hash, state)| (hash, state.encode_to_vec())),
        )
        .expect("Shouldn't have errors");

        let mut trie: Trie = Trie::empty_in_memory();
        for account in slots.iter() {
            trie.insert(account.0.as_bytes().to_vec(), account.1.encode_to_vec())
                .unwrap();
        }

        let trie_hash = trie.hash_no_commit(&NativeCrypto);

        assert_eq!(tested_trie_hash, trie_hash);
    }

    #[test]
    fn test_1() {
        run_test_account_state(generate_input_1());
    }

    #[test]
    fn test_2() {
        run_test_account_state(generate_input_2());
    }

    #[test]
    fn test_3() {
        run_test_account_state(generate_input_3());
    }

    #[test]
    fn test_4() {
        run_test_account_state(generate_input_4());
    }

    #[test]
    fn test_5() {
        run_test_account_state(generate_input_5());
    }

    #[test]
    fn test_6_root_single_branch_child() {
        run_test_account_state(generate_input_6());
    }

    #[test]
    fn test_empty_input() {
        let trie = Trie::stateless();
        let hash = trie_from_sorted_accounts_wrap(
            trie.db(),
            &mut std::iter::empty::<(H256, Vec<u8>)>(),
        )
        .expect("Shouldn't have errors");
        assert_eq!(hash, *EMPTY_TRIE_HASH);
    }

    #[test]
    fn test_slots_1() {
        run_test_storage_slots(generate_input_slots_1());
    }
}