casper_storage/global_state/trie/
mod.rs

1//! Core types for a Merkle Trie
2
3use std::{
4    convert::{TryFrom, TryInto},
5    fmt::{self, Debug, Display, Formatter},
6    iter::Flatten,
7    mem::MaybeUninit,
8    slice,
9};
10
11use datasize::DataSize;
12use num_derive::{FromPrimitive, ToPrimitive};
13use num_traits::{FromPrimitive, ToPrimitive};
14use serde::{
15    de::{self, MapAccess, Visitor},
16    ser::SerializeMap,
17    Deserialize, Deserializer, Serialize, Serializer,
18};
19
20use casper_types::{
21    bytesrepr::{self, Bytes, FromBytes, ToBytes, U8_SERIALIZED_LENGTH},
22    global_state::Pointer,
23    Digest,
24};
25
26#[cfg(test)]
27pub mod gens;
28
29#[cfg(test)]
30mod tests;
31
/// Panic message used when a pointer block index does not fit into a `u8`.
pub(crate) const USIZE_EXCEEDS_U8: &str = "usize exceeds u8";
/// Number of slots in a pointer block, i.e. the length of [`PointerBlockArray`].
pub(crate) const RADIX: usize = 256;

/// A parent is represented as a pair of a child index and a node or extension.
pub type Parents<K, V> = Vec<(u8, Trie<K, V>)>;

/// Type alias for values under pointer blocks: `Some(pointer)` for an occupied slot, `None` for
/// an empty one.
pub type PointerBlockValue = Option<Pointer>;

/// Type alias for arrays of pointer block values; such an array holds exactly `RADIX` slots.
pub type PointerBlockArray = [PointerBlockValue; RADIX];
43
/// Represents the underlying structure of a node in a Merkle Trie: a fixed-size array of `RADIX`
/// optional child pointers, addressed by slot index.
#[derive(Copy, Clone)]
pub struct PointerBlock(PointerBlockArray);
47
48impl Serialize for PointerBlock {
49    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
50    where
51        S: Serializer,
52    {
53        // We are going to use the sparse representation of pointer blocks
54        // non-None entries and their indices will be output
55
56        // Create the sequence serializer, reserving the necessary number of slots
57        let elements_count = self.0.iter().filter(|element| element.is_some()).count();
58        let mut map = serializer.serialize_map(Some(elements_count))?;
59
60        // Store the non-None entries with their indices
61        for (index, maybe_pointer_block) in self.0.iter().enumerate() {
62            if let Some(pointer_block_value) = maybe_pointer_block {
63                map.serialize_entry(&(index as u8), pointer_block_value)?;
64            }
65        }
66        map.end()
67    }
68}
69
70impl<'de> Deserialize<'de> for PointerBlock {
71    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72    where
73        D: Deserializer<'de>,
74    {
75        struct PointerBlockDeserializer;
76
77        impl<'de> Visitor<'de> for PointerBlockDeserializer {
78            type Value = PointerBlock;
79
80            fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
81                formatter.write_str("sparse representation of a PointerBlock")
82            }
83
84            fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
85            where
86                M: MapAccess<'de>,
87            {
88                let mut pointer_block = PointerBlock::new();
89
90                // Unpack the sparse representation
91                while let Some((index, pointer_block_value)) = access.next_entry::<u8, Pointer>()? {
92                    let element = pointer_block.0.get_mut(usize::from(index)).ok_or_else(|| {
93                        de::Error::custom(format!("invalid index {} in pointer block value", index))
94                    })?;
95                    *element = Some(pointer_block_value);
96                }
97
98                Ok(pointer_block)
99            }
100        }
101        deserializer.deserialize_map(PointerBlockDeserializer)
102    }
103}
104
105impl PointerBlock {
106    /// No-arg constructor for `PointerBlock`. Delegates to `Default::default()`.
107    pub fn new() -> Self {
108        Default::default()
109    }
110
111    /// Constructs a `PointerBlock` from a slice of indexed `Pointer`s.
112    pub fn from_indexed_pointers(indexed_pointers: &[(u8, Pointer)]) -> Self {
113        let mut ret = PointerBlock::new();
114        for (idx, ptr) in indexed_pointers.iter() {
115            ret[*idx as usize] = Some(*ptr);
116        }
117        ret
118    }
119
120    /// Deconstructs a `PointerBlock` into an iterator of indexed `Pointer`s.
121    pub fn as_indexed_pointers(&self) -> impl Iterator<Item = (u8, Pointer)> + '_ {
122        self.0
123            .iter()
124            .enumerate()
125            .filter_map(|(index, maybe_pointer)| {
126                maybe_pointer
127                    .map(|value| (index.try_into().expect(USIZE_EXCEEDS_U8), value.to_owned()))
128            })
129    }
130
131    /// Gets the count of children for this `PointerBlock`.
132    pub fn child_count(&self) -> usize {
133        self.as_indexed_pointers().count()
134    }
135}
136
137impl From<PointerBlockArray> for PointerBlock {
138    fn from(src: PointerBlockArray) -> Self {
139        PointerBlock(src)
140    }
141}
142
143impl PartialEq for PointerBlock {
144    #[inline]
145    fn eq(&self, other: &PointerBlock) -> bool {
146        self.0[..] == other.0[..]
147    }
148}
149
150impl Eq for PointerBlock {}
151
152impl Default for PointerBlock {
153    fn default() -> Self {
154        PointerBlock([Default::default(); RADIX])
155    }
156}
157
158impl ToBytes for PointerBlock {
159    fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
160        let mut result = bytesrepr::allocate_buffer(self)?;
161        for pointer in self.0.iter() {
162            result.append(&mut pointer.to_bytes()?);
163        }
164        Ok(result)
165    }
166
167    fn serialized_length(&self) -> usize {
168        self.0.iter().map(ToBytes::serialized_length).sum()
169    }
170
171    fn write_bytes(&self, writer: &mut Vec<u8>) -> Result<(), bytesrepr::Error> {
172        for pointer in self.0.iter() {
173            pointer.write_bytes(writer)?;
174        }
175        Ok(())
176    }
177}
178
179impl FromBytes for PointerBlock {
180    fn from_bytes(mut bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
181        let pointer_block_array = {
182            // With MaybeUninit here we can avoid default initialization of result array below.
183            let mut result: MaybeUninit<PointerBlockArray> = MaybeUninit::uninit();
184            let result_ptr = result.as_mut_ptr() as *mut PointerBlockValue;
185            for i in 0..RADIX {
186                let (t, remainder) = match FromBytes::from_bytes(bytes) {
187                    Ok(success) => success,
188                    Err(error) => {
189                        for j in 0..i {
190                            unsafe { result_ptr.add(j).drop_in_place() }
191                        }
192                        return Err(error);
193                    }
194                };
195                unsafe { result_ptr.add(i).write(t) };
196                bytes = remainder;
197            }
198            unsafe { result.assume_init() }
199        };
200        Ok((PointerBlock(pointer_block_array), bytes))
201    }
202}
203
204impl core::ops::Index<usize> for PointerBlock {
205    type Output = PointerBlockValue;
206
207    #[inline]
208    fn index(&self, index: usize) -> &Self::Output {
209        let PointerBlock(dat) = self;
210        &dat[index]
211    }
212}
213
214impl core::ops::IndexMut<usize> for PointerBlock {
215    #[inline]
216    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
217        let PointerBlock(dat) = self;
218        &mut dat[index]
219    }
220}
221
222impl core::ops::Index<core::ops::Range<usize>> for PointerBlock {
223    type Output = [PointerBlockValue];
224
225    #[inline]
226    fn index(&self, index: core::ops::Range<usize>) -> &[PointerBlockValue] {
227        let PointerBlock(dat) = self;
228        &dat[index]
229    }
230}
231
232impl core::ops::Index<core::ops::RangeTo<usize>> for PointerBlock {
233    type Output = [PointerBlockValue];
234
235    #[inline]
236    fn index(&self, index: core::ops::RangeTo<usize>) -> &[PointerBlockValue] {
237        let PointerBlock(dat) = self;
238        &dat[index]
239    }
240}
241
242impl core::ops::Index<core::ops::RangeFrom<usize>> for PointerBlock {
243    type Output = [PointerBlockValue];
244
245    #[inline]
246    fn index(&self, index: core::ops::RangeFrom<usize>) -> &[PointerBlockValue] {
247        let PointerBlock(dat) = self;
248        &dat[index]
249    }
250}
251
252impl core::ops::Index<core::ops::RangeFull> for PointerBlock {
253    type Output = [PointerBlockValue];
254
255    #[inline]
256    fn index(&self, index: core::ops::RangeFull) -> &[PointerBlockValue] {
257        let PointerBlock(dat) = self;
258        &dat[index]
259    }
260}
261
262impl ::std::fmt::Debug for PointerBlock {
263    #[allow(clippy::assertions_on_constants)]
264    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
265        assert!(RADIX > 1, "RADIX must be > 1");
266        write!(f, "{}([", stringify!(PointerBlock))?;
267        write!(f, "{:?}", self.0[0])?;
268        for item in self.0[1..].iter() {
269            write!(f, ", {:?}", item)?;
270        }
271        write!(f, "])")
272    }
273}
274
/// Newtype representing a trie node in its raw form without deserializing into `Trie`.
///
/// It only wraps the bytes; nothing here checks that they decode to a valid trie.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, DataSize)]
pub struct TrieRaw(Bytes);
278
279impl TrieRaw {
280    /// Constructs an instance of [`TrieRaw`].
281    pub fn new(bytes: Bytes) -> Self {
282        TrieRaw(bytes)
283    }
284
285    /// Consumes self and returns inner bytes.
286    pub fn into_inner(self) -> Bytes {
287        self.0
288    }
289
290    /// Returns a reference inner bytes.
291    pub fn inner(&self) -> &Bytes {
292        &self.0
293    }
294
295    /// Returns a hash of the inner bytes.
296    pub fn hash(&self) -> Digest {
297        Digest::hash_into_chunks_if_necessary(self.inner())
298    }
299}
300
301impl ToBytes for TrieRaw {
302    fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
303        self.0.to_bytes()
304    }
305
306    fn serialized_length(&self) -> usize {
307        self.0.serialized_length()
308    }
309}
310
311impl FromBytes for TrieRaw {
312    fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
313        let (bytes, rem) = Bytes::from_bytes(bytes)?;
314        Ok((TrieRaw(bytes), rem))
315    }
316}
317
/// Represents all possible serialization tags for a [`Trie`] enum.
///
/// The discriminant is the single tag byte that `Trie::write_bytes` emits before the variant's
/// payload and that the `from_bytes` implementations read back first.
#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)]
#[repr(u8)]
pub(crate) enum TrieTag {
    /// Represents a tag for a [`Trie::Leaf`] variant.
    Leaf = 0,
    /// Represents a tag for a [`Trie::Node`] variant.
    Node = 1,
    /// Represents a tag for a [`Trie::Extension`] variant.
    Extension = 2,
}
329
330impl From<TrieTag> for u8 {
331    fn from(value: TrieTag) -> Self {
332        TrieTag::to_u8(&value).unwrap() // SAFETY: TrieTag is represented as u8.
333    }
334}
335
/// Represents a Merkle Trie.
///
/// A value of this type is a single trie element: a leaf holding a key/value pair, a node
/// holding a block of child pointers, or an extension holding an affix and a single pointer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Trie<K, V> {
    /// Trie leaf.
    Leaf {
        /// Leaf key.
        key: K,
        /// Leaf value.
        value: V,
    },
    /// Trie node.
    Node {
        /// Node pointer block (boxed, so the enum itself does not embed the whole array).
        pointer_block: Box<PointerBlock>,
    },
    /// Trie extension node.
    Extension {
        /// Extension node affix bytes.
        affix: Bytes,
        /// Extension node pointer.
        pointer: Pointer,
    },
}
359
360impl<K, V> Display for Trie<K, V>
361where
362    K: Debug,
363    V: Debug,
364{
365    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
366        write!(f, "{:?}", self)
367    }
368}
369
370impl<K, V> Trie<K, V> {
371    fn tag(&self) -> TrieTag {
372        match self {
373            Trie::Leaf { .. } => TrieTag::Leaf,
374            Trie::Node { .. } => TrieTag::Node,
375            Trie::Extension { .. } => TrieTag::Extension,
376        }
377    }
378
379    /// Tag type for current trie element.
380    pub fn tag_type(&self) -> String {
381        match self {
382            Trie::Leaf { .. } => "Leaf".to_string(),
383            Trie::Node { .. } => "Node".to_string(),
384            Trie::Extension { .. } => "Extension".to_string(),
385        }
386    }
387
388    /// Constructs a [`Trie::Leaf`] from a given key and value.
389    pub fn leaf(key: K, value: V) -> Self {
390        Trie::Leaf { key, value }
391    }
392
393    /// Constructs a [`Trie::Node`] from a given slice of indexed pointers.
394    pub fn node(indexed_pointers: &[(u8, Pointer)]) -> Self {
395        let pointer_block = PointerBlock::from_indexed_pointers(indexed_pointers);
396        let pointer_block = Box::new(pointer_block);
397        Trie::Node { pointer_block }
398    }
399
400    /// Constructs a [`Trie::Extension`] from a given affix and pointer.
401    pub fn extension(affix: Vec<u8>, pointer: Pointer) -> Self {
402        Trie::Extension {
403            affix: affix.into(),
404            pointer,
405        }
406    }
407
408    /// Gets a reference to the root key of this Trie.
409    pub fn key(&self) -> Option<&K> {
410        match self {
411            Trie::Leaf { key, .. } => Some(key),
412            _ => None,
413        }
414    }
415
416    /// Returns the hash of this Trie.
417    pub fn trie_hash(&self) -> Result<Digest, bytesrepr::Error>
418    where
419        Self: ToBytes,
420    {
421        self.to_bytes()
422            .map(|bytes| Digest::hash_into_chunks_if_necessary(&bytes))
423    }
424
425    /// Returns bytes representation of this Trie and the hash over those bytes.
426    pub fn trie_hash_and_bytes(&self) -> Result<(Digest, Vec<u8>), bytesrepr::Error>
427    where
428        Self: ToBytes,
429    {
430        self.to_bytes()
431            .map(|bytes| (Digest::hash_into_chunks_if_necessary(&bytes), bytes))
432    }
433
434    /// Returns a pointer block, if possible.
435    pub fn as_pointer_block(&self) -> Option<&PointerBlock> {
436        if let Self::Node { pointer_block } = self {
437            Some(pointer_block.as_ref())
438        } else {
439            None
440        }
441    }
442
443    /// Returns an iterator over descendants of the trie.
444    pub fn iter_children(&self) -> DescendantsIterator {
445        match self {
446            Trie::<K, V>::Leaf { .. } => DescendantsIterator::ZeroOrOne(None),
447            Trie::Node { pointer_block } => DescendantsIterator::PointerBlock {
448                iter: pointer_block.0.iter().flatten(),
449            },
450            Trie::Extension { pointer, .. } => {
451                DescendantsIterator::ZeroOrOne(Some(pointer.into_hash()))
452            }
453        }
454    }
455}
456
/// Bytes representation of a `Trie` that is a `Trie::Leaf` variant.
/// The bytes for this trie leaf also include the leading `TrieTag` byte.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct TrieLeafBytes(Bytes);
461
462impl TrieLeafBytes {
463    pub(crate) fn bytes(&self) -> &Bytes {
464        &self.0
465    }
466
467    pub(crate) fn try_deserialize_leaf_key<K: FromBytes>(
468        &self,
469    ) -> Result<(K, &[u8]), bytesrepr::Error> {
470        let (tag_byte, rem) = u8::from_bytes(&self.0)?;
471        let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
472        assert_eq!(
473            tag,
474            TrieTag::Leaf,
475            "Unexpected layout for trie leaf bytes. Expected `TrieTag::Leaf` but got {:?}",
476            tag
477        );
478        K::from_bytes(rem)
479    }
480}
481
482impl From<&[u8]> for TrieLeafBytes {
483    fn from(value: &[u8]) -> Self {
484        Self(value.into())
485    }
486}
487
488impl From<Vec<u8>> for TrieLeafBytes {
489    fn from(value: Vec<u8>) -> Self {
490        Self(value.into())
491    }
492}
493
/// Like `Trie` but does not deserialize the leaf when constructed.
///
/// Nodes and extensions are stored fully decoded; a leaf is kept as its serialized bytes.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum LazilyDeserializedTrie {
    /// Serialized trie leaf bytes
    Leaf(TrieLeafBytes),
    /// Trie node.
    // `pointer_block` is the node's block of child pointers.
    Node { pointer_block: Box<PointerBlock> },
    /// Trie extension node.
    // `affix` holds the extension's affix bytes, `pointer` its single child pointer.
    Extension { affix: Bytes, pointer: Pointer },
}
504
505impl LazilyDeserializedTrie {
506    pub(crate) fn iter_children(&self) -> DescendantsIterator {
507        match self {
508            LazilyDeserializedTrie::Leaf(_) => {
509                // Leaf bytes does not have any children
510                DescendantsIterator::ZeroOrOne(None)
511            }
512            LazilyDeserializedTrie::Node { pointer_block } => DescendantsIterator::PointerBlock {
513                iter: pointer_block.0.iter().flatten(),
514            },
515            LazilyDeserializedTrie::Extension { pointer, .. } => {
516                DescendantsIterator::ZeroOrOne(Some(pointer.into_hash()))
517            }
518        }
519    }
520}
521
522impl FromBytes for LazilyDeserializedTrie {
523    fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
524        let (tag_byte, rem) = u8::from_bytes(bytes)?;
525        let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
526        match tag {
527            TrieTag::Leaf => Ok((LazilyDeserializedTrie::Leaf(bytes.into()), &[])),
528            TrieTag::Node => {
529                let (pointer_block, rem) = PointerBlock::from_bytes(rem)?;
530                Ok((
531                    LazilyDeserializedTrie::Node {
532                        pointer_block: Box::new(pointer_block),
533                    },
534                    rem,
535                ))
536            }
537            TrieTag::Extension => {
538                let (affix, rem) = FromBytes::from_bytes(rem)?;
539                let (pointer, rem) = Pointer::from_bytes(rem)?;
540                Ok((LazilyDeserializedTrie::Extension { affix, pointer }, rem))
541            }
542        }
543    }
544}
545
546impl<K, V> TryFrom<Trie<K, V>> for LazilyDeserializedTrie
547where
548    K: ToBytes,
549    V: ToBytes,
550{
551    type Error = bytesrepr::Error;
552
553    fn try_from(value: Trie<K, V>) -> Result<Self, Self::Error> {
554        match value {
555            Trie::Leaf { .. } => {
556                let serialized_bytes = ToBytes::to_bytes(&value)?;
557                Ok(LazilyDeserializedTrie::Leaf(serialized_bytes.into()))
558            }
559            Trie::Node { pointer_block } => Ok(LazilyDeserializedTrie::Node { pointer_block }),
560            Trie::Extension { affix, pointer } => {
561                Ok(LazilyDeserializedTrie::Extension { affix, pointer })
562            }
563        }
564    }
565}
566
/// An iterator over the descendants of a trie node.
///
/// It yields the `Digest` of each direct descendant.
pub enum DescendantsIterator<'a> {
    /// A leaf (zero descendants) or extension (one descendant) being iterated.
    /// The contained digest, if any, is yielded once.
    ZeroOrOne(Option<Digest>),
    /// A pointer block being iterated.
    PointerBlock {
        /// An iterator over the non-None entries of the `PointerBlock`.
        iter: Flatten<slice::Iter<'a, Option<Pointer>>>,
    },
}
577
578impl Iterator for DescendantsIterator<'_> {
579    type Item = Digest;
580
581    fn next(&mut self) -> Option<Self::Item> {
582        match *self {
583            DescendantsIterator::ZeroOrOne(ref mut maybe_digest) => maybe_digest.take(),
584            DescendantsIterator::PointerBlock { ref mut iter } => {
585                iter.next().map(|pointer| *pointer.hash())
586            }
587        }
588    }
589}
590
591impl<K, V> ToBytes for Trie<K, V>
592where
593    K: ToBytes,
594    V: ToBytes,
595{
596    fn to_bytes(&self) -> Result<Vec<u8>, bytesrepr::Error> {
597        let mut ret = bytesrepr::allocate_buffer(self)?;
598        self.write_bytes(&mut ret)?;
599        Ok(ret)
600    }
601
602    fn serialized_length(&self) -> usize {
603        U8_SERIALIZED_LENGTH
604            + match self {
605                Trie::Leaf { key, value } => key.serialized_length() + value.serialized_length(),
606                Trie::Node { pointer_block } => pointer_block.serialized_length(),
607                Trie::Extension { affix, pointer } => {
608                    affix.serialized_length() + pointer.serialized_length()
609                }
610            }
611    }
612
613    fn write_bytes(&self, writer: &mut Vec<u8>) -> Result<(), bytesrepr::Error> {
614        // NOTE: When changing this make sure all partial deserializers that are referencing
615        // `LazyTrieLeaf` are also updated.
616        writer.push(u8::from(self.tag()));
617        match self {
618            Trie::Leaf { key, value } => {
619                key.write_bytes(writer)?;
620                value.write_bytes(writer)?;
621            }
622            Trie::Node { pointer_block } => pointer_block.write_bytes(writer)?,
623            Trie::Extension { affix, pointer } => {
624                affix.write_bytes(writer)?;
625                pointer.write_bytes(writer)?;
626            }
627        }
628        Ok(())
629    }
630}
631
632impl<K: FromBytes, V: FromBytes> FromBytes for Trie<K, V> {
633    fn from_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), bytesrepr::Error> {
634        let (tag_byte, rem) = u8::from_bytes(bytes)?;
635        let tag = TrieTag::from_u8(tag_byte).ok_or(bytesrepr::Error::Formatting)?;
636        match tag {
637            TrieTag::Leaf => {
638                let (key, rem) = K::from_bytes(rem)?;
639                let (value, rem) = V::from_bytes(rem)?;
640                Ok((Trie::Leaf { key, value }, rem))
641            }
642            TrieTag::Node => {
643                let (pointer_block, rem) = PointerBlock::from_bytes(rem)?;
644                Ok((
645                    Trie::Node {
646                        pointer_block: Box::new(pointer_block),
647                    },
648                    rem,
649                ))
650            }
651            TrieTag::Extension => {
652                let (affix, rem) = FromBytes::from_bytes(rem)?;
653                let (pointer, rem) = Pointer::from_bytes(rem)?;
654                Ok((Trie::Extension { affix, pointer }, rem))
655            }
656        }
657    }
658}
659
660impl<K: FromBytes, V: FromBytes> TryFrom<LazilyDeserializedTrie> for Trie<K, V> {
661    type Error = bytesrepr::Error;
662
663    fn try_from(value: LazilyDeserializedTrie) -> Result<Self, Self::Error> {
664        match value {
665            LazilyDeserializedTrie::Leaf(leaf_bytes) => {
666                let (key, value_bytes) = leaf_bytes.try_deserialize_leaf_key()?;
667                let value = bytesrepr::deserialize_from_slice(value_bytes)?;
668                Ok(Self::Leaf { key, value })
669            }
670            LazilyDeserializedTrie::Node { pointer_block } => Ok(Self::Node { pointer_block }),
671            LazilyDeserializedTrie::Extension { affix, pointer } => {
672                Ok(Self::Extension { affix, pointer })
673            }
674        }
675    }
676}
677
678pub(crate) mod operations {
679    use casper_types::{
680        bytesrepr::{self, ToBytes},
681        Digest,
682    };
683
684    use crate::global_state::trie::Trie;
685
686    /// Creates a tuple containing an empty root hash and an empty root (a node
687    /// with an empty pointer block)
688    pub fn create_hashed_empty_trie<K: ToBytes, V: ToBytes>(
689    ) -> Result<(Digest, Trie<K, V>), bytesrepr::Error> {
690        let root: Trie<K, V> = Trie::Node {
691            pointer_block: Default::default(),
692        };
693        let root_bytes: Vec<u8> = root.to_bytes()?;
694        Ok((Digest::hash(root_bytes), root))
695    }
696}