1use std::sync::Arc;
9
10use ethereum_types::H256;
11use thiserror::Error;
12
13use crate::{
14 nibbles::Nibbles,
15 partial_trie::{Node, PartialTrie, WrappedNode},
16 utils::TrieNodeType,
17};
18
/// Result type returned by the trie-subset creation functions in this module.
pub type SubsetTrieResult<T> = Result<T, SubsetTrieError>;

#[derive(Debug, Error)]
/// Errors that may occur when creating a subset [`PartialTrie`].
pub enum SubsetTrieError {
    /// A key still had nibbles left to consume when it reached a hash node, so
    /// the part of the trie it points into is not available. Carries the
    /// remaining key nibbles and a debug dump of the tracked node.
    #[error("Tried to mark nodes in a tracked trie for a key that does not exist! (Key: {0}, trie: {1})")]
    UnexpectedKey(Nibbles, String),
}
27
/// Shape of a node in the tracking tree.
///
/// Mirrors the variants of [`Node`], but only branch and extension nodes keep
/// (tracked) children. All other data (nibbles, values, hashes) is read back
/// from the underlying node stored in the accompanying [`TrackedNodeInfo`].
#[derive(Debug)]
enum TrackedNodeIntern<N: PartialTrie> {
    Empty,
    Hash,
    Branch(Box<[TrackedNode<N>; 16]>),
    Extension(Box<TrackedNode<N>>),
    Leaf,
}
36
/// A node of a tree that mirrors an existing trie and records, per node,
/// whether it was visited while marking the keys of a subset.
#[derive(Debug)]
struct TrackedNode<N: PartialTrie> {
    // Structural mirror of the underlying node (children are tracked too).
    node: TrackedNodeIntern<N>,
    // The underlying node itself plus its `touched` flag.
    info: TrackedNodeInfo<N>,
}
42
43impl<N: Clone + PartialTrie> TrackedNode<N> {
44 fn new(underlying_node: &N) -> Self {
45 Self {
46 node: match &**underlying_node {
47 Node::Empty => TrackedNodeIntern::Empty,
48 Node::Hash(_) => TrackedNodeIntern::Hash,
49 Node::Branch { ref children, .. } => {
50 TrackedNodeIntern::Branch(Box::new(tracked_branch(children)))
51 }
52 Node::Extension { child, .. } => {
53 TrackedNodeIntern::Extension(Box::new(TrackedNode::new(child)))
54 }
55 Node::Leaf { .. } => TrackedNodeIntern::Leaf,
56 },
57 info: TrackedNodeInfo::new(underlying_node.clone()),
58 }
59 }
60}
61
62fn tracked_branch<N: PartialTrie>(
63 underlying_children: &[WrappedNode<N>; 16],
64) -> [TrackedNode<N>; 16] {
65 [
66 TrackedNode::new(&underlying_children[0]),
67 TrackedNode::new(&underlying_children[1]),
68 TrackedNode::new(&underlying_children[2]),
69 TrackedNode::new(&underlying_children[3]),
70 TrackedNode::new(&underlying_children[4]),
71 TrackedNode::new(&underlying_children[5]),
72 TrackedNode::new(&underlying_children[6]),
73 TrackedNode::new(&underlying_children[7]),
74 TrackedNode::new(&underlying_children[8]),
75 TrackedNode::new(&underlying_children[9]),
76 TrackedNode::new(&underlying_children[10]),
77 TrackedNode::new(&underlying_children[11]),
78 TrackedNode::new(&underlying_children[12]),
79 TrackedNode::new(&underlying_children[13]),
80 TrackedNode::new(&underlying_children[14]),
81 TrackedNode::new(&underlying_children[15]),
82 ]
83}
84
85fn partial_trie_extension<N: PartialTrie>(nibbles: Nibbles, child: &TrackedNode<N>) -> N {
86 N::new(Node::Extension {
87 nibbles,
88 child: Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
89 child,
90 ))),
91 })
92}
93
94fn partial_trie_branch<N: PartialTrie>(
95 underlying_children: &[TrackedNode<N>; 16],
96 value: &[u8],
97) -> N {
98 let children = [
99 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
100 &underlying_children[0],
101 ))),
102 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
103 &underlying_children[1],
104 ))),
105 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
106 &underlying_children[2],
107 ))),
108 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
109 &underlying_children[3],
110 ))),
111 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
112 &underlying_children[4],
113 ))),
114 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
115 &underlying_children[5],
116 ))),
117 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
118 &underlying_children[6],
119 ))),
120 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
121 &underlying_children[7],
122 ))),
123 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
124 &underlying_children[8],
125 ))),
126 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
127 &underlying_children[9],
128 ))),
129 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
130 &underlying_children[10],
131 ))),
132 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
133 &underlying_children[11],
134 ))),
135 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
136 &underlying_children[12],
137 ))),
138 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
139 &underlying_children[13],
140 ))),
141 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
142 &underlying_children[14],
143 ))),
144 Arc::new(Box::new(create_partial_trie_subset_from_tracked_trie(
145 &underlying_children[15],
146 ))),
147 ];
148
149 N::new(Node::Branch {
150 children,
151 value: value.to_owned(),
152 })
153}
154
/// Per-node tracking state: the original trie node and whether it was
/// visited while marking keys.
#[derive(Debug)]
struct TrackedNodeInfo<N: PartialTrie> {
    // Clone of the original trie node; its data (nibbles, value, hash) is
    // copied into the subset when this node is kept.
    underlying_node: N,
    // Set when a key traversal passes through this node. Untouched nodes are
    // replaced by a hash node in the resulting subset.
    touched: bool,
}
160
161impl<N: PartialTrie> TrackedNodeInfo<N> {
162 fn new(underlying_node: N) -> Self {
163 Self {
164 underlying_node,
165 touched: false,
166 }
167 }
168
169 fn reset(&mut self) {
170 self.touched = false;
171 }
172
173 fn get_nibbles_expected(&self) -> &Nibbles {
174 match &*self.underlying_node {
175 Node::Extension { nibbles, .. } => nibbles,
176 Node::Leaf { nibbles, .. } => nibbles,
177 _ => unreachable!(
178 "Tried getting the nibbles field from a {} node!",
179 TrieNodeType::from(&*self.underlying_node)
180 ),
181 }
182 }
183
184 fn get_hash_node_hash_expected(&self) -> H256 {
185 match *self.underlying_node {
186 Node::Hash(h) => h,
187 _ => unreachable!("Expected an underlying hash node!"),
188 }
189 }
190
191 fn get_branch_value_expected(&self) -> &Vec<u8> {
192 match &*self.underlying_node {
193 Node::Branch { value, .. } => value,
194 _ => unreachable!("Expected an underlying branch node!"),
195 }
196 }
197
198 fn get_leaf_nibbles_and_value_expected(&self) -> (&Nibbles, &Vec<u8>) {
199 match &*self.underlying_node {
200 Node::Leaf { nibbles, value } => (nibbles, value),
201 _ => unreachable!("Expected an underlying leaf node!"),
202 }
203 }
204}
205
206pub fn create_trie_subset<N, K, I>(trie: &N, keys_involved: I) -> SubsetTrieResult<N>
212where
213 N: PartialTrie,
214 K: Into<Nibbles>,
215 I: IntoIterator<Item = K>,
216{
217 let mut tracked_trie = TrackedNode::new(trie);
218 create_trie_subset_intern(&mut tracked_trie, keys_involved.into_iter())
219}
220
221pub fn create_trie_subsets<N, K, I, O>(base_trie: &N, keys_involved: O) -> SubsetTrieResult<Vec<N>>
225where
226 N: PartialTrie,
227 K: Into<Nibbles>,
228 I: IntoIterator<Item = K>,
229 O: IntoIterator<Item = I>,
230{
231 let mut tracked_trie = TrackedNode::new(base_trie);
232
233 keys_involved
234 .into_iter()
235 .map(|ks| {
236 let res = create_trie_subset_intern(&mut tracked_trie, ks.into_iter())?;
237 reset_tracked_trie_state(&mut tracked_trie);
238
239 Ok(res)
240 })
241 .collect::<SubsetTrieResult<_>>()
242}
243
244fn create_trie_subset_intern<N, K>(
245 tracked_trie: &mut TrackedNode<N>,
246 keys_involved: impl Iterator<Item = K>,
247) -> SubsetTrieResult<N>
248where
249 N: PartialTrie,
250 K: Into<Nibbles>,
251{
252 for k in keys_involved {
253 mark_nodes_that_are_needed(tracked_trie, &mut k.into())?;
254 }
255
256 Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie))
257}
258
259fn mark_nodes_that_are_needed<N: PartialTrie>(
260 trie: &mut TrackedNode<N>,
261 curr_nibbles: &mut Nibbles,
262) -> SubsetTrieResult<()> {
263 trie.info.touched = true;
264
265 match &mut trie.node {
266 TrackedNodeIntern::Empty => Ok(()),
267 TrackedNodeIntern::Hash => match curr_nibbles.is_empty() {
268 false => Err(SubsetTrieError::UnexpectedKey(
269 *curr_nibbles,
270 format!("{:?}", trie),
271 )),
272 true => Ok(()),
273 },
274 TrackedNodeIntern::Branch(children) => {
276 if curr_nibbles.is_empty() {
278 return Ok(());
279 }
280
281 let nib = curr_nibbles.pop_next_nibble_front();
282 mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles)
283 }
284 TrackedNodeIntern::Extension(child) => {
285 let nibbles = trie.info.get_nibbles_expected();
286 let r = curr_nibbles.pop_nibbles_front(nibbles.count);
287
288 match r.nibbles_are_identical_up_to_smallest_count(nibbles) {
289 false => Ok(()),
290 true => mark_nodes_that_are_needed(child, curr_nibbles),
291 }
292 }
293 TrackedNodeIntern::Leaf => Ok(()),
294 }
295}
296
297fn create_partial_trie_subset_from_tracked_trie<N: PartialTrie>(
298 tracked_node: &TrackedNode<N>,
299) -> N {
300 match tracked_node.info.touched {
301 false => N::new(Node::Hash(tracked_node.info.underlying_node.hash())),
302 true => match &tracked_node.node {
303 TrackedNodeIntern::Empty => N::new(Node::Empty),
304 TrackedNodeIntern::Hash => {
305 N::new(Node::Hash(tracked_node.info.get_hash_node_hash_expected()))
306 }
307 TrackedNodeIntern::Branch(children) => {
308 partial_trie_branch(children, tracked_node.info.get_branch_value_expected())
309 }
310 TrackedNodeIntern::Extension(child) => {
311 partial_trie_extension(*tracked_node.info.get_nibbles_expected(), child)
312 }
313 TrackedNodeIntern::Leaf => {
314 let (nibbles, value) = tracked_node.info.get_leaf_nibbles_and_value_expected();
315 N::new(Node::Leaf {
316 nibbles: *nibbles,
317 value: value.clone(),
318 })
319 }
320 },
321 }
322}
323
324fn reset_tracked_trie_state<N: PartialTrie>(tracked_node: &mut TrackedNode<N>) {
325 match tracked_node.node {
326 TrackedNodeIntern::Branch(ref mut children) => {
327 children.iter_mut().for_each(|c| c.info.reset())
328 }
329 TrackedNodeIntern::Extension(ref mut child) => child.info.reset(),
330 TrackedNodeIntern::Empty | TrackedNodeIntern::Hash | TrackedNodeIntern::Leaf => {
331 tracked_node.info.reset()
332 }
333 }
334}
335
336#[cfg(test)]
337mod tests {
338 use std::{collections::HashSet, iter::once};
339
340 use ethereum_types::H256;
341
342 use super::{create_trie_subset, create_trie_subsets};
343 use crate::{
344 nibbles::Nibbles,
345 partial_trie::{HashedPartialTrie, Node, PartialTrie},
346 testing_utils::generate_n_random_fixed_trie_entries,
347 trie_ops::ValOrHash,
348 utils::TrieNodeType,
349 };
350
351 type TrieType = HashedPartialTrie;
352
353 const MASSIVE_TEST_NUM_SUB_TRIES: usize = 10;
354 const MASSIVE_TEST_NUM_SUB_TRIE_SIZE: usize = 5000;
355
356 #[derive(Debug, Eq, PartialEq)]
357 struct NodeFullNibbles {
358 n_type: TrieNodeType,
359 nibbles: Nibbles,
360 }
361
362 impl NodeFullNibbles {
363 fn new_from_node<N: PartialTrie>(node: &Node<N>, nibbles: Nibbles) -> Self {
364 Self {
365 n_type: node.into(),
366 nibbles,
367 }
368 }
369
370 fn new_from_node_type<K: Into<Nibbles>>(n_type: TrieNodeType, nibbles: K) -> Self {
371 Self {
372 n_type,
373 nibbles: nibbles.into(),
374 }
375 }
376 }
377
378 fn get_all_non_empty_and_hash_nodes_in_trie(trie: &TrieType) -> Vec<NodeFullNibbles> {
379 let mut nodes = Vec::new();
380 get_all_non_empty_and_hash_nodes_in_trie_intern(trie, Nibbles::default(), &mut nodes);
381
382 nodes
383 }
384
385 fn get_all_non_empty_and_hash_nodes_in_trie_intern(
386 trie: &TrieType,
387 mut curr_nibbles: Nibbles,
388 nodes: &mut Vec<NodeFullNibbles>,
389 ) {
390 match &trie.node {
391 Node::Empty | Node::Hash(_) => return,
392 Node::Branch { children, .. } => {
393 for (i, c) in children.iter().enumerate() {
394 get_all_non_empty_and_hash_nodes_in_trie_intern(
395 c,
396 curr_nibbles.merge_nibble(i as u8),
397 nodes,
398 )
399 }
400 }
401 Node::Extension { nibbles, child } => get_all_non_empty_and_hash_nodes_in_trie_intern(
402 child,
403 curr_nibbles.merge_nibbles(nibbles),
404 nodes,
405 ),
406 Node::Leaf { nibbles, .. } => curr_nibbles = curr_nibbles.merge_nibbles(nibbles),
407 };
408
409 nodes.push(NodeFullNibbles::new_from_node(trie, curr_nibbles.reverse()));
410 }
411
412 fn get_all_nibbles_of_leaf_nodes_in_trie(trie: &TrieType) -> HashSet<Nibbles> {
413 trie.items()
414 .filter_map(|(n, v_or_h)| matches!(v_or_h, ValOrHash::Val(_)).then(|| n))
415 .collect()
416 }
417
418 #[test]
419 fn empty_trie_does_not_return_err_on_query() {
420 let trie = TrieType::default();
421 let nibbles: Nibbles = 0x1234.into();
422 let res = create_trie_subset(&trie, once(nibbles));
423
424 assert!(res.is_ok());
425 }
426
427 #[test]
428 fn non_existent_key_does_not_return_err() {
429 let mut trie = TrieType::default();
430 trie.insert(0x1234, vec![0, 1, 2]);
431 let res = create_trie_subset(&trie, once(0x5678));
432
433 assert!(res.is_ok());
434 }
435
436 #[test]
437 fn encountering_a_hash_node_returns_err() {
438 let trie = HashedPartialTrie::new(Node::Hash(H256::zero()));
439 let res = create_trie_subset(&trie, once(0x1234));
440
441 assert!(res.is_err())
442 }
443
444 #[test]
445 fn single_node_trie_is_queryable() {
446 let mut trie = TrieType::default();
447 trie.insert(0x1234, vec![0, 1, 2]);
448 let trie_subset = create_trie_subset(&trie, once(0x1234)).unwrap();
449
450 assert_eq!(trie, trie_subset);
451 }
452
453 #[test]
454 fn multi_node_trie_returns_proper_subset() {
455 let mut trie = TrieType::default();
456 trie.insert(0x1234, vec![0]);
457 trie.insert(0x56, vec![1]);
458 trie.insert(0x12345, vec![2]);
459
460 let trie_subset = create_trie_subset(&trie, vec![0x1234, 0x56].into_iter()).unwrap();
461 let leaf_keys = get_all_nibbles_of_leaf_nodes_in_trie(&trie_subset);
462
463 assert!(leaf_keys.contains(&(Nibbles::from(0x1234))));
464 assert!(leaf_keys.contains(&(Nibbles::from(0x56))));
465 assert!(!leaf_keys.contains(&Nibbles::from(0x12345)));
466 }
467
468 #[test]
469 fn intermediate_nodes_are_included_in_subset() {
470 let mut trie = TrieType::default();
471 let inserts = vec![
472 (0x1234_u64.into(), vec![0]),
473 (0x1324_u64.into(), vec![1]),
474 (0x132400005_u64.into(), vec![2]),
475 (0x2001_u64.into(), vec![3]),
476 (0x2002_u64.into(), vec![4]),
477 ];
478
479 for (k, v) in inserts.iter() {
492 trie.insert(*k, v.clone());
493 }
494
495 let ks: Vec<_> = inserts.iter().map(|(k, _)| k).cloned().collect();
496 let trie_subset_all = create_trie_subset(&trie, ks.iter().cloned()).unwrap();
497
498 let subset_keys = get_all_nibbles_of_leaf_nodes_in_trie(&trie_subset_all);
499 assert!(subset_keys.iter().all(|k| ks.contains(k)));
500 assert!(ks.iter().all(|k| subset_keys.contains(k)));
501
502 let all_non_empty_and_hash_nodes =
503 get_all_non_empty_and_hash_nodes_in_trie(&trie_subset_all);
504
505 assert_node_exists(
506 &all_non_empty_and_hash_nodes,
507 TrieNodeType::Branch,
508 Nibbles::default(),
509 );
510 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Branch, 0x1);
511 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Leaf, 0x1234);
512
513 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Extension, 0x13);
514 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Branch, 0x1324);
515 assert_node_exists(
516 &all_non_empty_and_hash_nodes,
517 TrieNodeType::Leaf,
518 0x132400005_u64,
519 );
520
521 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Extension, 0x2);
522 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Branch, 0x200);
523 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Leaf, 0x2001);
524 assert_node_exists(&all_non_empty_and_hash_nodes, TrieNodeType::Leaf, 0x2002);
525
526 assert_eq!(all_non_empty_and_hash_nodes.len(), 10);
527
528 let all_non_empty_and_hash_nodes_partial = get_all_non_empty_and_hash_nodes_in_trie(
530 &create_trie_subset(&trie, once(0x2001)).unwrap(),
531 );
532 assert_node_exists(
533 &all_non_empty_and_hash_nodes_partial,
534 TrieNodeType::Branch,
535 Nibbles::default(),
536 );
537 assert_node_exists(
538 &all_non_empty_and_hash_nodes_partial,
539 TrieNodeType::Extension,
540 0x2,
541 );
542 assert_node_exists(
543 &all_non_empty_and_hash_nodes_partial,
544 TrieNodeType::Branch,
545 0x200,
546 );
547 assert_node_exists(
548 &all_non_empty_and_hash_nodes_partial,
549 TrieNodeType::Leaf,
550 0x2001,
551 );
552 assert_eq!(all_non_empty_and_hash_nodes_partial.len(), 4);
553
554 let all_non_empty_and_hash_nodes_partial = get_all_non_empty_and_hash_nodes_in_trie(
555 &create_trie_subset(&trie, once(0x1324)).unwrap(),
556 );
557 assert_node_exists(
558 &all_non_empty_and_hash_nodes_partial,
559 TrieNodeType::Branch,
560 Nibbles::default(),
561 );
562 assert_node_exists(
563 &all_non_empty_and_hash_nodes_partial,
564 TrieNodeType::Branch,
565 0x1,
566 );
567 assert_node_exists(
568 &all_non_empty_and_hash_nodes_partial,
569 TrieNodeType::Extension,
570 0x13,
571 );
572 assert_node_exists(
573 &all_non_empty_and_hash_nodes_partial,
574 TrieNodeType::Branch,
575 0x1324,
576 );
577 assert_eq!(all_non_empty_and_hash_nodes_partial.len(), 4);
578 }
579
580 fn assert_node_exists<K: Into<Nibbles>>(
581 nodes: &[NodeFullNibbles],
582 n_type: TrieNodeType,
583 nibbles: K,
584 ) {
585 assert!(nodes.contains(&NodeFullNibbles::new_from_node_type(
586 n_type,
587 nibbles.into().reverse()
588 )));
589 }
590
591 #[test]
592 fn all_leafs_of_keys_to_create_subset_are_included_in_subset_for_giant_trie() {
593 let trie_size = MASSIVE_TEST_NUM_SUB_TRIES * MASSIVE_TEST_NUM_SUB_TRIE_SIZE;
594
595 let random_entries: Vec<_> =
596 generate_n_random_fixed_trie_entries(trie_size, 9009).collect();
597 let entry_keys: Vec<_> = random_entries.iter().map(|(k, _)| k).cloned().collect();
598 let trie = TrieType::from_iter(random_entries);
599
600 let keys_of_subsets: Vec<Vec<_>> = (0..MASSIVE_TEST_NUM_SUB_TRIES)
601 .map(|i| {
602 let entry_range_start = i * MASSIVE_TEST_NUM_SUB_TRIE_SIZE;
603 let entry_range_end = entry_range_start + MASSIVE_TEST_NUM_SUB_TRIE_SIZE;
604 entry_keys[entry_range_start..entry_range_end].to_vec()
605 })
606 .collect();
607
608 let trie_subsets =
609 create_trie_subsets(&trie, keys_of_subsets.iter().map(|v| v.iter().cloned())).unwrap();
610
611 for (sub_trie, ks_used) in trie_subsets.into_iter().zip(keys_of_subsets.into_iter()) {
612 let leaf_nibbles = get_all_nibbles_of_leaf_nodes_in_trie(&sub_trie);
613 assert!(ks_used.into_iter().all(|k| leaf_nibbles.contains(&k)));
614 }
615 }
616}