1use std::{
4 convert::{TryFrom, TryInto},
5 fmt::{self, Debug, Display, Formatter},
6 iter::Flatten,
7 mem::MaybeUninit,
8 slice,
9};
10
11use datasize::DataSize;
12use num_derive::{FromPrimitive, ToPrimitive};
13use num_traits::{FromPrimitive, ToPrimitive};
14use serde::{
15 de::{self, MapAccess, Visitor},
16 ser::SerializeMap,
17 Deserialize, Deserializer, Serialize, Serializer,
18};
19
20use casper_types::{
21 bytesrepr::{self, Bytes, FromBytes, ToBytes, U8_SERIALIZED_LENGTH},
22 global_state::Pointer,
23 Digest,
24};
25
// Test-only module `gens` (presumably value generators used by the tests — confirm).
#[cfg(test)]
pub mod gens;

// Unit tests for this module; compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;
31
/// Panic message used when a slot index unexpectedly does not fit into a `u8`.
pub(crate) const USIZE_EXCEEDS_U8: &str = "usize exceeds u8";
/// Number of slots in a pointer block: exactly one per possible `u8` value (256).
pub(crate) const RADIX: usize = u8::MAX as usize + 1;
34
35pub type Parents<K, V> = Vec<(u8, Trie<K, V>)>;
37
38pub type PointerBlockValue = Option<Pointer>;
40
41pub type PointerBlockArray = [PointerBlockValue; RADIX];
43
/// Fixed-size array of optional child pointers with one slot per `u8` index (`RADIX` slots).
///
/// Slot `i` holds `Some(pointer)` when a child is stored under index `i` and `None` otherwise.
#[derive(Copy, Clone)]
pub struct PointerBlock(PointerBlockArray);
47
48impl Serialize for PointerBlock {
49 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
50 where
51 S: Serializer,
52 {
53 let elements_count = self.0.iter().filter(|element| element.is_some()).count();
58 let mut map = serializer.serialize_map(Some(elements_count))?;
59
60 for (index, maybe_pointer_block) in self.0.iter().enumerate() {
62 if let Some(pointer_block_value) = maybe_pointer_block {
63 map.serialize_entry(&(index as u8), pointer_block_value)?;
64 }
65 }
66 map.end()
67 }
68}
69
70impl<'de> Deserialize<'de> for PointerBlock {
71 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72 where
73 D: Deserializer<'de>,
74 {
75 struct PointerBlockDeserializer;
76
77 impl<'de> Visitor<'de> for PointerBlockDeserializer {
78 type Value = PointerBlock;
79
80 fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
81 formatter.write_str("sparse representation of a PointerBlock")
82 }
83
84 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
85 where
86 M: MapAccess<'de>,
87 {
88 let mut pointer_block = PointerBlock::new();
89
90 while let Some((index, pointer_block_value)) = access.next_entry::<u8, Pointer>()? {
92 let element = pointer_block.0.get_mut(usize::from(index)).ok_or_else(|| {
93 de::Error::custom(format!("invalid index {} in pointer block value", index))
94 })?;
95 *element = Some(pointer_block_value);
96 }
97
98 Ok(pointer_block)
99 }
100 }
101 deserializer.deserialize_map(PointerBlockDeserializer)
102 }
103}
104
105impl PointerBlock {
106 pub fn new() -> Self {
108 Default::default()
109 }
110
111 pub fn from_indexed_pointers(indexed_pointers: &[(u8, Pointer)]) -> Self {
113 let mut ret = PointerBlock::new();
114 for (idx, ptr) in indexed_pointers.iter() {
115 ret[*idx as usize] = Some(*ptr);
116 }
117 ret
118 }
119
120 pub fn as_indexed_pointers(&self) -> impl Iterator<Item = (u8, Pointer)> + '_ {
122 self.0
123 .iter()
124 .enumerate()
125 .filter_map(|(index, maybe_pointer)| {
126 maybe_pointer
127 .map(|value| (index.try_into().expect(USIZE_EXCEEDS_U8), value.to_owned()))
128 })
129 }
130
131 pub fn child_count(&self) -> usize {
133 self.as_indexed_pointers().count()
134 }
135}
136
137impl From<PointerBlockArray> for PointerBlock {
138 fn from(src: PointerBlockArray) -> Self {
139 PointerBlock(src)
140 }
141}
142
143impl PartialEq for PointerBlock {
144 #[inline]
145 fn eq(&self, other: &PointerBlock) -> bool {
146 self.0[..] == other.0[..]
147 }
148}
149
150impl Eq for PointerBlock {}
151
152impl Default for PointerBlock {
153 fn default() -> Self {
154 PointerBlock([Default::default(); RADIX])
155 }
156}
157
158impl ToBytes for PointerBlock {
159 fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
160 let mut result = bytesrepr::allocate_buffer(self)?;
161 for pointer in self.0.iter() {
162 result.append(&mut pointer.to_bytes()?);
163 }
164 Ok(result)
165 }
166
167 fn serialized_length(&self) -> usize {
168 self.0.iter().map(ToBytes::serialized_length).sum()
169 }
170
171 fn write_bytes(&self, writer: &mut Vec<u8>) -> Result<(), bytesrepr::Error> {
172 for pointer in self.0.iter() {
173 pointer.write_bytes(writer)?;
174 }
175 Ok(())
176 }
177}
178
179impl FromBytes for PointerBlock {
180 fn from_bytes(mut bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
181 let pointer_block_array = {
182 let mut result: MaybeUninit<PointerBlockArray> = MaybeUninit::uninit();
184 let result_ptr = result.as_mut_ptr() as *mut PointerBlockValue;
185 for i in 0..RADIX {
186 let (t, remainder) = match FromBytes::from_bytes(bytes) {
187 Ok(success) => success,
188 Err(error) => {
189 for j in 0..i {
190 unsafe { result_ptr.add(j).drop_in_place() }
191 }
192 return Err(error);
193 }
194 };
195 unsafe { result_ptr.add(i).write(t) };
196 bytes = remainder;
197 }
198 unsafe { result.assume_init() }
199 };
200 Ok((PointerBlock(pointer_block_array), bytes))
201 }
202}
203
204impl core::ops::Index<usize> for PointerBlock {
205 type Output = PointerBlockValue;
206
207 #[inline]
208 fn index(&self, index: usize) -> &Self::Output {
209 let PointerBlock(dat) = self;
210 &dat[index]
211 }
212}
213
214impl core::ops::IndexMut<usize> for PointerBlock {
215 #[inline]
216 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
217 let PointerBlock(dat) = self;
218 &mut dat[index]
219 }
220}
221
222impl core::ops::Index<core::ops::Range<usize>> for PointerBlock {
223 type Output = [PointerBlockValue];
224
225 #[inline]
226 fn index(&self, index: core::ops::Range<usize>) -> &[PointerBlockValue] {
227 let PointerBlock(dat) = self;
228 &dat[index]
229 }
230}
231
232impl core::ops::Index<core::ops::RangeTo<usize>> for PointerBlock {
233 type Output = [PointerBlockValue];
234
235 #[inline]
236 fn index(&self, index: core::ops::RangeTo<usize>) -> &[PointerBlockValue] {
237 let PointerBlock(dat) = self;
238 &dat[index]
239 }
240}
241
242impl core::ops::Index<core::ops::RangeFrom<usize>> for PointerBlock {
243 type Output = [PointerBlockValue];
244
245 #[inline]
246 fn index(&self, index: core::ops::RangeFrom<usize>) -> &[PointerBlockValue] {
247 let PointerBlock(dat) = self;
248 &dat[index]
249 }
250}
251
252impl core::ops::Index<core::ops::RangeFull> for PointerBlock {
253 type Output = [PointerBlockValue];
254
255 #[inline]
256 fn index(&self, index: core::ops::RangeFull) -> &[PointerBlockValue] {
257 let PointerBlock(dat) = self;
258 &dat[index]
259 }
260}
261
262impl ::std::fmt::Debug for PointerBlock {
263 #[allow(clippy::assertions_on_constants)]
264 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
265 assert!(RADIX > 1, "RADIX must be > 1");
266 write!(f, "{}([", stringify!(PointerBlock))?;
267 write!(f, "{:?}", self.0[0])?;
268 for item in self.0[1..].iter() {
269 write!(f, ", {:?}", item)?;
270 }
271 write!(f, "])")
272 }
273}
274
/// Opaque byte payload, presumably a serialized `Trie`: its `hash` hashes the bytes the same
/// way `Trie::trie_hash` hashes a serialized trie. Confirm at the call sites.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, DataSize)]
pub struct TrieRaw(Bytes);
278
279impl TrieRaw {
280 pub fn new(bytes: Bytes) -> Self {
282 TrieRaw(bytes)
283 }
284
285 pub fn into_inner(self) -> Bytes {
287 self.0
288 }
289
290 pub fn inner(&self) -> &Bytes {
292 &self.0
293 }
294
295 pub fn hash(&self) -> Digest {
297 Digest::hash_into_chunks_if_necessary(self.inner())
298 }
299}
300
301impl ToBytes for TrieRaw {
302 fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
303 self.0.to_bytes()
304 }
305
306 fn serialized_length(&self) -> usize {
307 self.0.serialized_length()
308 }
309}
310
311impl FromBytes for TrieRaw {
312 fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
313 let (bytes, rem) = Bytes::from_bytes(bytes)?;
314 Ok((TrieRaw(bytes), rem))
315 }
316}
317
/// Discriminant byte written at the start of every serialized [`Trie`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
#[repr(u8)]
pub(crate) enum TrieTag {
    /// Tag of `Trie::Leaf`.
    Leaf = 0,
    /// Tag of `Trie::Node`.
    Node = 1,
    /// Tag of `Trie::Extension`.
    Extension = 2,
}
329
330impl From<TrieTag> for u8 {
331 fn from(value: TrieTag) -> Self {
332 TrieTag::to_u8(&value).unwrap() }
334}
335
/// A trie node, generic over the key type `K` and value type `V`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Trie<K, V> {
    /// Terminal node holding a key and its value.
    Leaf {
        /// The leaf's key.
        key: K,
        /// The value stored under `key`.
        value: V,
    },
    /// Branching node with up to `RADIX` children.
    Node {
        /// Child pointers, one optional slot per `u8` index (boxed; the block stores `RADIX`
        /// slots inline).
        pointer_block: Box<PointerBlock>,
    },
    /// Node holding an affix of bytes and a single child pointer.
    Extension {
        /// Affix bytes carried by this node.
        affix: Bytes,
        /// Pointer to the single child.
        pointer: Pointer,
    },
}
359
360impl<K, V> Display for Trie<K, V>
361where
362 K: Debug,
363 V: Debug,
364{
365 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
366 write!(f, "{:?}", self)
367 }
368}
369
370impl<K, V> Trie<K, V> {
371 fn tag(&self) -> TrieTag {
372 match self {
373 Trie::Leaf { .. } => TrieTag::Leaf,
374 Trie::Node { .. } => TrieTag::Node,
375 Trie::Extension { .. } => TrieTag::Extension,
376 }
377 }
378
379 pub fn tag_type(&self) -> String {
381 match self {
382 Trie::Leaf { .. } => "Leaf".to_string(),
383 Trie::Node { .. } => "Node".to_string(),
384 Trie::Extension { .. } => "Extension".to_string(),
385 }
386 }
387
388 pub fn leaf(key: K, value: V) -> Self {
390 Trie::Leaf { key, value }
391 }
392
393 pub fn node(indexed_pointers: &[(u8, Pointer)]) -> Self {
395 let pointer_block = PointerBlock::from_indexed_pointers(indexed_pointers);
396 let pointer_block = Box::new(pointer_block);
397 Trie::Node { pointer_block }
398 }
399
400 pub fn extension(affix: Vec<u8>, pointer: Pointer) -> Self {
402 Trie::Extension {
403 affix: affix.into(),
404 pointer,
405 }
406 }
407
408 pub fn key(&self) -> Option<&K> {
410 match self {
411 Trie::Leaf { key, .. } => Some(key),
412 _ => None,
413 }
414 }
415
416 pub fn trie_hash(&self) -> Result<Digest, bytesrepr::Error>
418 where
419 Self: ToBytes,
420 {
421 self.to_bytes()
422 .map(|bytes| Digest::hash_into_chunks_if_necessary(&bytes))
423 }
424
425 pub fn trie_hash_and_bytes(&self) -> Result<(Digest, Vec<u8>), bytesrepr::Error>
427 where
428 Self: ToBytes,
429 {
430 self.to_bytes()
431 .map(|bytes| (Digest::hash_into_chunks_if_necessary(&bytes), bytes))
432 }
433
434 pub fn as_pointer_block(&self) -> Option<&PointerBlock> {
436 if let Self::Node { pointer_block } = self {
437 Some(pointer_block.as_ref())
438 } else {
439 None
440 }
441 }
442
443 pub fn iter_children(&self) -> DescendantsIterator {
445 match self {
446 Trie::<K, V>::Leaf { .. } => DescendantsIterator::ZeroOrOne(None),
447 Trie::Node { pointer_block } => DescendantsIterator::PointerBlock {
448 iter: pointer_block.0.iter().flatten(),
449 },
450 Trie::Extension { pointer, .. } => {
451 DescendantsIterator::ZeroOrOne(Some(pointer.into_hash()))
452 }
453 }
454 }
455}
456
/// Undecoded serialized form of a `Trie::Leaf`, starting with its tag byte and followed by
/// the key and value encodings.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct TrieLeafBytes(Bytes);
461
462impl TrieLeafBytes {
463 pub(crate) fn bytes(&self) -> &Bytes {
464 &self.0
465 }
466
467 pub(crate) fn try_deserialize_leaf_key<K: FromBytes>(
468 &self,
469 ) -> Result<(K, &[u8]), bytesrepr::Error> {
470 let (tag_byte, rem) = u8::from_bytes(&self.0)?;
471 let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
472 assert_eq!(
473 tag,
474 TrieTag::Leaf,
475 "Unexpected layout for trie leaf bytes. Expected `TrieTag::Leaf` but got {:?}",
476 tag
477 );
478 K::from_bytes(rem)
479 }
480}
481
482impl From<&[u8]> for TrieLeafBytes {
483 fn from(value: &[u8]) -> Self {
484 Self(value.into())
485 }
486}
487
488impl From<Vec<u8>> for TrieLeafBytes {
489 fn from(value: Vec<u8>) -> Self {
490 Self(value.into())
491 }
492}
493
/// A trie node whose leaf is kept in serialized form, while `Node` and `Extension` are fully
/// decoded.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum LazilyDeserializedTrie {
    /// Entire serialized leaf (tag byte, key and value), not yet decoded.
    Leaf(TrieLeafBytes),
    /// Decoded branching node.
    Node { pointer_block: Box<PointerBlock> },
    /// Decoded extension: affix bytes plus a single child pointer.
    Extension { affix: Bytes, pointer: Pointer },
}
504
505impl LazilyDeserializedTrie {
506 pub(crate) fn iter_children(&self) -> DescendantsIterator {
507 match self {
508 LazilyDeserializedTrie::Leaf(_) => {
509 DescendantsIterator::ZeroOrOne(None)
511 }
512 LazilyDeserializedTrie::Node { pointer_block } => DescendantsIterator::PointerBlock {
513 iter: pointer_block.0.iter().flatten(),
514 },
515 LazilyDeserializedTrie::Extension { pointer, .. } => {
516 DescendantsIterator::ZeroOrOne(Some(pointer.into_hash()))
517 }
518 }
519 }
520}
521
522impl FromBytes for LazilyDeserializedTrie {
523 fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
524 let (tag_byte, rem) = u8::from_bytes(bytes)?;
525 let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
526 match tag {
527 TrieTag::Leaf => Ok((LazilyDeserializedTrie::Leaf(bytes.into()), &[])),
528 TrieTag::Node => {
529 let (pointer_block, rem) = PointerBlock::from_bytes(rem)?;
530 Ok((
531 LazilyDeserializedTrie::Node {
532 pointer_block: Box::new(pointer_block),
533 },
534 rem,
535 ))
536 }
537 TrieTag::Extension => {
538 let (affix, rem) = FromBytes::from_bytes(rem)?;
539 let (pointer, rem) = Pointer::from_bytes(rem)?;
540 Ok((LazilyDeserializedTrie::Extension { affix, pointer }, rem))
541 }
542 }
543 }
544}
545
546impl<K, V> TryFrom<Trie<K, V>> for LazilyDeserializedTrie
547where
548 K: ToBytes,
549 V: ToBytes,
550{
551 type Error = bytesrepr::Error;
552
553 fn try_from(value: Trie<K, V>) -> Result<Self, Self::Error> {
554 match value {
555 Trie::Leaf { .. } => {
556 let serialized_bytes = ToBytes::to_bytes(&value)?;
557 Ok(LazilyDeserializedTrie::Leaf(serialized_bytes.into()))
558 }
559 Trie::Node { pointer_block } => Ok(LazilyDeserializedTrie::Node { pointer_block }),
560 Trie::Extension { affix, pointer } => {
561 Ok(LazilyDeserializedTrie::Extension { affix, pointer })
562 }
563 }
564 }
565}
566
/// Iterator over the digests of a trie node's direct children.
pub enum DescendantsIterator<'a> {
    /// At most one child digest (used for leaves and extensions).
    ZeroOrOne(Option<Digest>),
    /// Children taken from a pointer block.
    PointerBlock {
        /// Iterator over the occupied (`Some`) slots only.
        iter: Flatten<slice::Iter<'a, Option<Pointer>>>,
    },
}
577
578impl Iterator for DescendantsIterator<'_> {
579 type Item = Digest;
580
581 fn next(&mut self) -> Option<Self::Item> {
582 match *self {
583 DescendantsIterator::ZeroOrOne(ref mut maybe_digest) => maybe_digest.take(),
584 DescendantsIterator::PointerBlock { ref mut iter } => {
585 iter.next().map(|pointer| *pointer.hash())
586 }
587 }
588 }
589}
590
591impl<K, V> ToBytes for Trie<K, V>
592where
593 K: ToBytes,
594 V: ToBytes,
595{
596 fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
597 let mut ret = bytesrepr::allocate_buffer(self)?;
598 self.write_bytes(&mut ret)?;
599 Ok(ret)
600 }
601
602 fn serialized_length(&self) -> usize {
603 U8_SERIALIZED_LENGTH
604 + match self {
605 Trie::Leaf { key, value } => key.serialized_length() + value.serialized_length(),
606 Trie::Node { pointer_block } => pointer_block.serialized_length(),
607 Trie::Extension { affix, pointer } => {
608 affix.serialized_length() + pointer.serialized_length()
609 }
610 }
611 }
612
613 fn write_bytes(&self, writer: &mut Vec<u8>) -> Result<(), bytesrepr::Error> {
614 writer.push(u8::from(self.tag()));
617 match self {
618 Trie::Leaf { key, value } => {
619 key.write_bytes(writer)?;
620 value.write_bytes(writer)?;
621 }
622 Trie::Node { pointer_block } => pointer_block.write_bytes(writer)?,
623 Trie::Extension { affix, pointer } => {
624 affix.write_bytes(writer)?;
625 pointer.write_bytes(writer)?;
626 }
627 }
628 Ok(())
629 }
630}
631
632impl<K: FromBytes, V: FromBytes> FromBytes for Trie<K, V> {
633 fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
634 let (tag_byte, rem) = u8::from_bytes(bytes)?;
635 let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
636 match tag {
637 TrieTag::Leaf => {
638 let (key, rem) = K::from_bytes(rem)?;
639 let (value, rem) = V::from_bytes(rem)?;
640 Ok((Trie::Leaf { key, value }, rem))
641 }
642 TrieTag::Node => {
643 let (pointer_block, rem) = PointerBlock::from_bytes(rem)?;
644 Ok((
645 Trie::Node {
646 pointer_block: Box::new(pointer_block),
647 },
648 rem,
649 ))
650 }
651 TrieTag::Extension => {
652 let (affix, rem) = FromBytes::from_bytes(rem)?;
653 let (pointer, rem) = Pointer::from_bytes(rem)?;
654 Ok((Trie::Extension { affix, pointer }, rem))
655 }
656 }
657 }
658}
659
660impl<K: FromBytes, V: FromBytes> TryFrom<LazilyDeserializedTrie> for Trie<K, V> {
661 type Error = bytesrepr::Error;
662
663 fn try_from(value: LazilyDeserializedTrie) -> Result<Self, Self::Error> {
664 match value {
665 LazilyDeserializedTrie::Leaf(leaf_bytes) => {
666 let (key, value_bytes) = leaf_bytes.try_deserialize_leaf_key()?;
667 let value = bytesrepr::deserialize_from_slice(value_bytes)?;
668 Ok(Self::Leaf { key, value })
669 }
670 LazilyDeserializedTrie::Node { pointer_block } => Ok(Self::Node { pointer_block }),
671 LazilyDeserializedTrie::Extension { affix, pointer } => {
672 Ok(Self::Extension { affix, pointer })
673 }
674 }
675 }
676}
677
678pub(crate) mod operations {
679 use casper_types::{
680 bytesrepr::{self, ToBytes},
681 Digest,
682 };
683
684 use crate::global_state::trie::Trie;
685
686 pub fn create_hashed_empty_trie<K: ToBytes, V: ToBytes>(
689 ) -> Result<(Digest, Trie<K, V>), bytesrepr::Error> {
690 let root: Trie<K, V> = Trie::Node {
691 pointer_block: Default::default(),
692 };
693 let root_bytes: Vec<u8> = root.to_bytes()?;
694 Ok((Digest::hash(root_bytes), root))
695 }
696}