1use crate::utils::{Length, opt_hash, opt_packing_depth, opt_packing_factor};
2use crate::{Arc, Error, Leaf, PackedLeaf, UpdateMap, Value};
3use educe::Educe;
4use ethereum_hashing::{ZERO_HASHES, hash32_concat};
5use parking_lot::RwLock;
6use std::collections::BTreeMap;
7use std::collections::HashMap;
8use std::ops::ControlFlow;
9use tree_hash::Hash256;
10
/// A binary Merkle tree whose updates build new nodes while sharing unchanged subtrees via
/// `Arc`.
///
/// `PartialEq` and `Hash` are derived through `educe` and ignore the cached hash stored in
/// `Node`, so equality depends only on the tree's shape and contents.
#[derive(Debug, Educe)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[educe(PartialEq(bound(T: Value)), Hash)]
pub enum Tree<T: Value> {
    /// A single unpacked value at depth 0.
    Leaf(Leaf<T>),
    /// Several values stored together in one depth-0 leaf. Used when `T` has a packing factor.
    PackedLeaf(PackedLeaf<T>),
    /// An interior node with two child subtrees.
    Node {
        /// Cached hash of this node. `Hash256::ZERO` means "not computed yet"; it is filled in
        /// lazily by `tree_hash`.
        #[educe(PartialEq(ignore), Hash(ignore))]
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_rwlock))]
        hash: RwLock<Hash256>,
        /// Left child subtree.
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_arc))]
        left: Arc<Self>,
        /// Right child subtree.
        #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::utils::arb_arc))]
        right: Arc<Self>,
    },
    /// An empty subtree of the given depth. It holds no values, and its hash comes from
    /// `ZERO_HASHES[depth]`.
    Zero(usize),
}
28
29impl<T: Value> Clone for Tree<T> {
30 fn clone(&self) -> Self {
31 match self {
32 Self::Node { hash, left, right } => Self::Node {
33 hash: RwLock::new(*hash.read()),
34 left: left.clone(),
35 right: right.clone(),
36 },
37 Self::Leaf(leaf) => Self::Leaf(leaf.clone()),
38 Self::PackedLeaf(leaf) => Self::PackedLeaf(leaf.clone()),
39 Self::Zero(depth) => Self::Zero(*depth),
40 }
41 }
42}
43
44impl<T: Value> Tree<T> {
45 pub fn empty(depth: usize) -> Arc<Self> {
46 Self::zero(depth)
47 }
48
49 pub fn node(left: Arc<Self>, right: Arc<Self>, hash: Hash256) -> Arc<Self> {
50 Arc::new(Self::Node {
51 hash: RwLock::new(hash),
52 left,
53 right,
54 })
55 }
56
57 pub fn zero(depth: usize) -> Arc<Self> {
58 Arc::new(Self::Zero(depth))
59 }
60
61 pub fn leaf(value: T) -> Arc<Self> {
62 Arc::new(Self::Leaf(Leaf::new(value)))
63 }
64
65 pub fn leaf_with_hash(value: T, hash: Hash256) -> Arc<Self> {
66 Arc::new(Self::Leaf(Leaf::with_hash(value, hash)))
67 }
68
69 pub fn node_unboxed(left: Arc<Self>, right: Arc<Self>) -> Self {
70 Self::Node {
71 hash: RwLock::new(Hash256::ZERO),
72 left,
73 right,
74 }
75 }
76
77 pub fn zero_unboxed(depth: usize) -> Self {
78 Self::Zero(depth)
79 }
80
81 pub fn leaf_unboxed(value: T) -> Self {
82 Self::Leaf(Leaf::new(value))
83 }
84
85 pub fn get_recursive(&self, index: usize, depth: usize, packing_depth: usize) -> Option<&T> {
86 match self {
87 Self::Leaf(Leaf { value, .. }) if depth == 0 => Some(value),
88 Self::PackedLeaf(PackedLeaf { values, .. }) if depth == 0 => {
89 values.get(index % T::tree_hash_packing_factor())
90 }
91 Self::Node { left, right, .. } if depth > 0 => {
92 let new_depth = depth - 1;
93 if (index >> (new_depth + packing_depth)) & 1 == 0 {
95 left.get_recursive(index, new_depth, packing_depth)
96 }
97 else {
99 right.get_recursive(index, new_depth, packing_depth)
100 }
101 }
102 _ => None,
103 }
104 }
105
106 pub fn with_updated_leaf(
110 &self,
111 index: usize,
112 new_value: T,
113 depth: usize,
114 ) -> Result<Arc<Self>, Error> {
115 match self {
116 Self::Leaf(_) if depth == 0 => Ok(Self::leaf(new_value)),
117 Self::PackedLeaf(leaf) if depth == 0 => Ok(Arc::new(Self::PackedLeaf(
118 leaf.insert_at_index(index, new_value)?,
119 ))),
120 Self::Node { left, right, .. } if depth > 0 => {
121 let packing_depth = opt_packing_depth::<T>().unwrap_or(0);
122 let new_depth = depth - 1;
123 if (index >> (new_depth + packing_depth)) & 1 == 0 {
124 Ok(Self::node(
126 left.with_updated_leaf(index, new_value, new_depth)?,
127 right.clone(),
128 Hash256::ZERO,
129 ))
130 } else {
131 Ok(Self::node(
133 left.clone(),
134 right.with_updated_leaf(index, new_value, new_depth)?,
135 Hash256::ZERO,
136 ))
137 }
138 }
139 Self::Zero(zero_depth) if *zero_depth == depth => {
140 if depth == 0 {
141 if opt_packing_factor::<T>().is_some() {
142 Ok(Arc::new(Self::PackedLeaf(PackedLeaf::single(new_value))))
143 } else {
144 Ok(Self::leaf(new_value))
145 }
146 } else {
147 let new_zero = Self::zero(depth - 1);
150 Self::node(new_zero.clone(), new_zero, Hash256::ZERO)
151 .with_updated_leaf(index, new_value, depth)
152 }
153 }
154 _ => Err(Error::UpdateLeafError),
155 }
156 }
157
158 pub fn with_updated_leaves<U: UpdateMap<T>>(
159 &self,
160 updates: &U,
161 prefix: usize,
162 depth: usize,
163 hashes: Option<&BTreeMap<(usize, usize), Hash256>>,
164 ) -> Result<Arc<Self>, Error> {
165 let hash = opt_hash(hashes, depth, prefix).unwrap_or_default();
166
167 match self {
168 Self::Leaf(_) if depth == 0 => {
169 let index = prefix;
170 let value = updates
171 .get(index)
172 .cloned()
173 .ok_or(Error::LeafUpdateMissing { index })?;
174 Ok(Self::leaf_with_hash(value, hash))
175 }
176 Self::PackedLeaf(packed_leaf) if depth == 0 => Ok(Arc::new(Self::PackedLeaf(
177 packed_leaf.update(prefix, hash, updates)?,
178 ))),
179 Self::Node { left, right, .. } if depth > 0 => {
180 let packing_depth = opt_packing_depth::<T>().unwrap_or(0);
181 let new_depth = depth - 1;
182 let left_prefix = prefix;
183 let right_prefix = prefix | (1 << (new_depth + packing_depth));
184 let right_subtree_end = prefix + (1 << (depth + packing_depth));
185
186 let mut has_left_updates = false;
187 updates.for_each_range(left_prefix, right_prefix, |_, _| {
188 has_left_updates = true;
189 ControlFlow::Break(())
190 })?;
191 let mut has_right_updates = false;
192 updates.for_each_range(right_prefix, right_subtree_end, |_, _| {
193 has_right_updates = true;
194 ControlFlow::Break(())
195 })?;
196
197 if !has_left_updates && !has_right_updates {
199 return Err(Error::NodeUpdatesMissing { prefix });
200 }
201
202 let new_left = if has_left_updates {
203 left.with_updated_leaves(updates, left_prefix, new_depth, hashes)?
204 } else {
205 left.clone()
206 };
207 let new_right = if has_right_updates {
208 right.with_updated_leaves(updates, right_prefix, new_depth, hashes)?
209 } else {
210 right.clone()
211 };
212
213 Ok(Self::node(new_left, new_right, hash))
214 }
215 Self::Zero(zero_depth) if *zero_depth == depth => {
216 if depth == 0 {
217 if opt_packing_factor::<T>().is_some() {
218 let packed_leaf = PackedLeaf::empty().update(prefix, hash, updates)?;
219 Ok(Arc::new(Self::PackedLeaf(packed_leaf)))
220 } else {
221 let index = prefix;
222 let value = updates
223 .get(index)
224 .cloned()
225 .ok_or(Error::LeafUpdateMissing { index })?;
226 Ok(Self::leaf_with_hash(value, hash))
227 }
228 } else {
229 let new_zero = Self::zero(depth - 1);
231 Self::node(new_zero.clone(), new_zero, hash)
232 .with_updated_leaves(updates, prefix, depth, hashes)
233 }
234 }
235 _ => Err(Error::UpdateLeavesError),
236 }
237 }
238
239 pub fn compute_len(&self) -> usize {
244 match self {
245 Self::Leaf(_) => 1,
246 Self::PackedLeaf(leaf) => leaf.values.len(),
247 Self::Node { left, right, .. } => left.compute_len() + right.compute_len(),
248 Self::Zero(_) => 0,
249 }
250 }
251}
252
/// Result of comparing a subtree of the original tree with the matching subtree of the base
/// tree in [`Tree::rebase_on`].
pub enum RebaseAction<'a, T> {
    /// The subtrees differ and nothing from the base could be reused; keep the original as is.
    NotEqualNoop,
    /// The subtrees differ as a whole, but this replacement has the same contents as the
    /// original and reuses some of the base's descendants. The caller should swap it in.
    NotEqualReplace(Arc<T>),
    /// The original and the base are already the same allocation (`Arc::ptr_eq`); nothing to
    /// do.
    EqualNoop,
    /// The subtrees have equal contents, so the original can be replaced by this base subtree.
    EqualReplace(&'a Arc<T>),
}
263
/// Result of [`Tree::intra_rebase`] for a single subtree.
pub enum IntraRebaseAction<T> {
    /// The subtree is kept unchanged.
    Noop,
    /// The subtree should be replaced by this one, which shares previously seen equal
    /// subtrees.
    Replace(Arc<T>),
}
268
impl<T: Value> Tree<T> {
    /// Rebases `orig` onto `base`: wherever a subtree of `orig` has the same contents as the
    /// corresponding subtree of `base`, it reuses `base`'s allocation so the two trees share
    /// memory.
    ///
    /// `lengths` is `Some((orig_length, base_length))` when element counts must also match
    /// for a hash short-cut to apply; `None` is treated as "lengths are equal".
    ///
    /// `full_depth` has two roles:
    /// - it guards recursion into `Node`s;
    /// - it splits `lengths` between the children, with at most `1 << (full_depth - 1)`
    ///   elements assigned to the left.
    ///
    /// NOTE(review): `full_depth` presumably includes any packing depth so that this split is
    /// in element units. Confirm against the callers.
    ///
    /// # Errors
    /// - `Error::InvalidRebaseNode` if two `Node`s are compared with `full_depth == 0`.
    /// - `Error::InvalidRebaseLeaf` if a leaf is paired with a non-`Zero` node of a different
    ///   kind.
    pub fn rebase_on<'a>(
        orig: &'a Arc<Self>,
        base: &'a Arc<Self>,
        lengths: Option<(Length, Length)>,
        full_depth: usize,
    ) -> Result<RebaseAction<'a, Self>, Error> {
        // Same allocation: trivially equal, and nothing needs replacing.
        if Arc::ptr_eq(orig, base) {
            return Ok(RebaseAction::EqualNoop);
        }
        match (&**orig, &**base) {
            (Self::Leaf(l1), Self::Leaf(l2)) => {
                if l1.value == l2.value {
                    Ok(RebaseAction::EqualReplace(base))
                } else {
                    Ok(RebaseAction::NotEqualNoop)
                }
            }
            (Self::PackedLeaf(l1), Self::PackedLeaf(l2)) => {
                if l1.values == l2.values {
                    Ok(RebaseAction::EqualReplace(base))
                } else {
                    Ok(RebaseAction::NotEqualNoop)
                }
            }
            (Self::Zero(z1), Self::Zero(z2)) if z1 == z2 => Ok(RebaseAction::EqualReplace(base)),
            (
                Self::Node {
                    hash: orig_hash_lock,
                    left: l1,
                    right: r1,
                },
                Self::Node {
                    hash: base_hash_lock,
                    left: l2,
                    right: r2,
                },
            ) if full_depth > 0 => {
                use RebaseAction::*;

                let orig_hash = *orig_hash_lock.read();
                let base_hash = *base_hash_lock.read();

                // Short-cut: if both nodes already have a computed (non-zero) hash, the hashes
                // match, and the lengths match (or are not tracked), treat the subtrees as
                // equal and reuse `base` without recursing.
                if !orig_hash.is_zero()
                    && orig_hash == base_hash
                    && lengths.is_none_or(|(orig_length, base_length)| orig_length == base_length)
                {
                    return Ok(EqualReplace(base));
                }

                // Split each tree's length between the children: the left child covers at most
                // `1 << new_full_depth` elements and the rest belongs to the right child.
                let new_full_depth = full_depth - 1;
                let (left_lengths, right_lengths) = lengths
                    .map(|(orig_length, base_length)| {
                        let max_left_length = Length(1 << new_full_depth);
                        let orig_left_length = std::cmp::min(orig_length, max_left_length);
                        let orig_right_length =
                            Length(orig_length.as_usize() - orig_left_length.as_usize());

                        let base_left_length = std::cmp::min(base_length, max_left_length);
                        let base_right_length =
                            Length(base_length.as_usize() - base_left_length.as_usize());
                        (
                            (orig_left_length, base_left_length),
                            (orig_right_length, base_right_length),
                        )
                    })
                    .unzip();

                let left_action = Tree::rebase_on(l1, l2, left_lengths, new_full_depth)?;
                let right_action = Tree::rebase_on(r1, r2, right_lengths, new_full_depth)?;

                // Combine the child results. Every one of the 16 (left, right) combinations is
                // handled. Any replacement node has the same contents as `orig`, so `orig`'s
                // cached hash (possibly still zero/uncomputed) stays valid for it.
                match (left_action, right_action) {
                    // At least one side differs and nothing was reused.
                    (NotEqualNoop, NotEqualNoop | EqualNoop) | (EqualNoop, NotEqualNoop) => {
                        Ok(NotEqualNoop)
                    }
                    // Both children are shared already, so the node needs no change.
                    (EqualNoop, EqualNoop) => Ok(EqualNoop),
                    (NotEqualNoop | EqualNoop, NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: l1.clone(),
                            right: new_right,
                        })))
                    }
                    (NotEqualNoop | EqualNoop, EqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: l1.clone(),
                            right: new_right.clone(),
                        })))
                    }
                    (NotEqualReplace(new_left), NotEqualNoop | EqualNoop) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: r1.clone(),
                        })))
                    }
                    (NotEqualReplace(new_left), NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: new_right,
                        })))
                    }
                    (NotEqualReplace(new_left), EqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left,
                            right: new_right.clone(),
                        })))
                    }
                    (EqualReplace(new_left), NotEqualNoop) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left.clone(),
                            right: r1.clone(),
                        })))
                    }
                    (EqualReplace(new_left), NotEqualReplace(new_right)) => {
                        Ok(NotEqualReplace(Arc::new(Self::Node {
                            hash: RwLock::new(orig_hash),
                            left: new_left.clone(),
                            right: new_right,
                        })))
                    }
                    // Both children equal the base's children, so the whole node equals `base`.
                    (EqualReplace(_), EqualReplace(_)) | (EqualReplace(_), EqualNoop) => {
                        Ok(EqualReplace(base))
                    }
                }
            }
            // A `Zero` paired with anything other than an equal-depth `Zero` is treated
            // conservatively as different.
            (Self::Zero(_), _) | (_, Self::Zero(_)) => Ok(RebaseAction::NotEqualNoop),
            // Two nodes, but `full_depth` is 0: the tree is deeper than the caller said.
            (Self::Node { .. }, Self::Node { .. }) => Err(Error::InvalidRebaseNode),
            // A leaf paired with a different kind of node.
            (Self::Leaf(_) | Self::PackedLeaf(_), _) | (_, Self::Leaf(_) | Self::PackedLeaf(_)) => {
                Err(Error::InvalidRebaseLeaf)
            }
        }
    }

    /// Deduplicates equal subtrees within one tree.
    ///
    /// Nodes are keyed by `(depth, hash)` in `known_subtrees`. When a node's key is already
    /// known, the node is replaced by the stored subtree. Otherwise its children are processed
    /// first (left, then right) and the resulting subtree is recorded. Leaves and `Zero`
    /// subtrees are left untouched.
    ///
    /// # Errors
    /// - `Error::IntraRebaseZeroHash` if a visited node's hash has not been computed yet. The
    ///   tree must be fully hashed beforehand.
    /// - `Error::IntraRebaseRepeatVisit` if a node's key was somehow recorded while its own
    ///   children were being processed (an invariant violation).
    /// - `Error::IntraRebaseZeroDepth` if a `Node` is found at `current_depth == 0`.
    pub fn intra_rebase(
        orig: &Arc<Self>,
        known_subtrees: &mut HashMap<(usize, Hash256), Arc<Self>>,
        current_depth: usize,
    ) -> Result<IntraRebaseAction<Self>, Error> {
        match &**orig {
            Self::Leaf(_) | Self::PackedLeaf(_) | Self::Zero(_) => Ok(IntraRebaseAction::Noop),
            Self::Node { hash, left, right } if current_depth > 0 => {
                let hash = *hash.read();

                // The hash is the dedup key, so it must already be computed.
                if hash.is_zero() {
                    return Err(Error::IntraRebaseZeroHash);
                }

                // An equal subtree was seen earlier in the traversal: reuse it without
                // visiting this node's children.
                if let Some(known_subtree) = known_subtrees.get(&(current_depth, hash)) {
                    return Ok(IntraRebaseAction::Replace(known_subtree.clone()));
                }

                let left_action = Self::intra_rebase(left, known_subtrees, current_depth - 1)?;
                let right_action = Self::intra_rebase(right, known_subtrees, current_depth - 1)?;

                // Rebuild this node only if at least one child was replaced; keep the existing
                // hash because the contents are unchanged.
                let action = match (left_action, right_action) {
                    (IntraRebaseAction::Noop, IntraRebaseAction::Noop) => IntraRebaseAction::Noop,
                    (IntraRebaseAction::Noop, IntraRebaseAction::Replace(new_right)) => {
                        IntraRebaseAction::Replace(Self::node(left.clone(), new_right, hash))
                    }
                    (IntraRebaseAction::Replace(new_left), IntraRebaseAction::Noop) => {
                        IntraRebaseAction::Replace(Self::node(new_left, right.clone(), hash))
                    }
                    (
                        IntraRebaseAction::Replace(new_left),
                        IntraRebaseAction::Replace(new_right),
                    ) => IntraRebaseAction::Replace(Self::node(new_left, new_right, hash)),
                };

                // Record the (possibly rebuilt) version of this node so later equal subtrees
                // can point at it.
                let new_subtree = match &action {
                    IntraRebaseAction::Noop => orig.clone(),
                    IntraRebaseAction::Replace(new) => new.clone(),
                };
                let existing_entry = known_subtrees.insert((current_depth, hash), new_subtree);

                // The lookup above already returned for known keys, and the children only
                // insert keys at smaller depths. So an existing entry here means an invariant
                // was broken.
                if existing_entry.is_some() {
                    return Err(Error::IntraRebaseRepeatVisit);
                }

                Ok(action)
            }
            Self::Node { .. } => Err(Error::IntraRebaseZeroDepth),
        }
    }
}
493
494impl<T: Value + Send + Sync> Tree<T> {
495 pub fn tree_hash(&self) -> Hash256 {
496 match self {
497 Self::Leaf(Leaf { hash, value }) => {
498 let read_lock = hash.read();
500 let existing_hash = *read_lock;
501 drop(read_lock);
502
503 if !existing_hash.is_zero() {
511 existing_hash
512 } else {
513 let tree_hash = value.tree_hash_root();
514 *hash.write() = tree_hash;
515 tree_hash
516 }
517 }
518 Self::PackedLeaf(leaf) => leaf.tree_hash(),
519 Self::Zero(depth) => Hash256::from(ZERO_HASHES[*depth]),
520 Self::Node { hash, left, right } => {
521 let read_lock = hash.read();
522 let existing_hash = *read_lock;
523 drop(read_lock);
524
525 if !existing_hash.is_zero() {
526 existing_hash
527 } else {
528 let (left_hash, right_hash) =
530 rayon::join(|| left.tree_hash(), || right.tree_hash());
531 let tree_hash =
532 Hash256::from(hash32_concat(left_hash.as_slice(), right_hash.as_slice()));
533 *hash.write() = tree_hash;
534 tree_hash
535 }
536 }
537 }
538 }
539}