diem_types/nibble/nibble_path/
mod.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3
//! The NibblePath library simplifies operations on nibbles stored in a compact format for the
//! modified sparse Merkle tree by providing powerful iterators that advance by either bit or nibble.
6
7#[cfg(test)]
8mod nibble_path_test;
9
10use crate::nibble::{Nibble, ROOT_NIBBLE_HEIGHT};
11use mirai_annotations::*;
12#[cfg(any(test, feature = "fuzzing"))]
13use proptest::{collection::vec, prelude::*};
14use serde::{Deserialize, Serialize};
15use std::{fmt, iter::FromIterator};
16
17/// NibblePath defines a path in Merkle tree in the unit of nibble (4 bits).
18#[derive(Clone, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
19pub struct NibblePath {
20    /// Indicates the total number of nibbles in bytes. Either `bytes.len() * 2 - 1` or
21    /// `bytes.len() * 2`.
22    // Guarantees intended ordering based on the top-to-bottom declaration order of the struct's
23    // members.
24    num_nibbles: usize,
25    /// The underlying bytes that stores the path, 2 nibbles per byte. If the number of nibbles is
26    /// odd, the second half of the last byte must be 0.
27    bytes: Vec<u8>,
28    // invariant num_nibbles <= ROOT_NIBBLE_HEIGHT
29}
30
31/// Supports debug format by concatenating nibbles literally. For example, [0x12, 0xa0] with 3
32/// nibbles will be printed as "12a".
33impl fmt::Debug for NibblePath {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        self.nibbles().try_for_each(|x| write!(f, "{:x}", x))
36    }
37}
38
39/// Convert a vector of bytes into `NibblePath` using the lower 4 bits of each byte as nibble.
40impl FromIterator<Nibble> for NibblePath {
41    fn from_iter<I: IntoIterator<Item = Nibble>>(iter: I) -> Self {
42        let mut nibble_path = NibblePath::new(vec![]);
43        for nibble in iter {
44            nibble_path.push(nibble);
45        }
46        nibble_path
47    }
48}
49
/// Lets proptest generate arbitrary `NibblePath`s by boxing the `arb_nibble_path` strategy.
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for NibblePath {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        arb_nibble_path().boxed()
    }
}
58
#[cfg(any(test, feature = "fuzzing"))]
prop_compose! {
    // Strategy producing a `NibblePath` from at most `ROOT_NIBBLE_HEIGHT / 2` random bytes.
    // When `is_odd` is drawn and there is at least one byte, the low nibble of the final byte
    // is zeroed and the path is built with `new_odd`, giving it an odd nibble count.
    fn arb_nibble_path()(
        mut bytes in vec(any::<u8>(), 0..=ROOT_NIBBLE_HEIGHT / 2),
        is_odd in any::<bool>()
    ) -> NibblePath {
        if is_odd && !bytes.is_empty() {
            let last_index = bytes.len() - 1;
            bytes[last_index] &= 0xf0;
            NibblePath::new_odd(bytes)
        } else {
            NibblePath::new(bytes)
        }
    }
}
74
#[cfg(any(test, feature = "fuzzing"))]
prop_compose! {
    // Strategy for non-leaf paths: any `arb_nibble_path` output that is strictly shorter than
    // `ROOT_NIBBLE_HEIGHT` nibbles.
    fn arb_internal_nibble_path()(
        path in arb_nibble_path().prop_filter(
            "Filter out leaf paths.",
            |candidate| candidate.num_nibbles() < ROOT_NIBBLE_HEIGHT,
        )
    ) -> NibblePath {
        path
    }
}
86
87impl NibblePath {
88    /// Creates a new `NibblePath` from a vector of bytes assuming each byte has 2 nibbles.
89    pub fn new(bytes: Vec<u8>) -> Self {
90        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
91        let num_nibbles = bytes.len() * 2;
92        NibblePath { num_nibbles, bytes }
93    }
94
95    /// Similar to `new()` but assumes that the bytes have one less nibble.
96    pub fn new_odd(bytes: Vec<u8>) -> Self {
97        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
98        assert_eq!(
99            bytes.last().expect("Should have odd number of nibbles.") & 0x0f,
100            0,
101            "Last nibble must be 0."
102        );
103        let num_nibbles = bytes.len() * 2 - 1;
104        NibblePath { num_nibbles, bytes }
105    }
106
107    /// Adds a nibble to the end of the nibble path.
108    pub fn push(&mut self, nibble: Nibble) {
109        assert!(ROOT_NIBBLE_HEIGHT > self.num_nibbles);
110        if self.num_nibbles % 2 == 0 {
111            self.bytes.push(u8::from(nibble) << 4);
112        } else {
113            self.bytes[self.num_nibbles / 2] |= u8::from(nibble);
114        }
115        self.num_nibbles += 1;
116    }
117
118    /// Pops a nibble from the end of the nibble path.
119    pub fn pop(&mut self) -> Option<Nibble> {
120        let poped_nibble = if self.num_nibbles % 2 == 0 {
121            self.bytes.last_mut().map(|last_byte| {
122                let nibble = *last_byte & 0x0f;
123                *last_byte &= 0xf0;
124                Nibble::from(nibble)
125            })
126        } else {
127            self.bytes.pop().map(|byte| Nibble::from(byte >> 4))
128        };
129        if poped_nibble.is_some() {
130            self.num_nibbles -= 1;
131        }
132        poped_nibble
133    }
134
135    /// Returns the last nibble.
136    pub fn last(&self) -> Option<Nibble> {
137        let last_byte_option = self.bytes.last();
138        if self.num_nibbles % 2 == 0 {
139            last_byte_option.map(|last_byte| Nibble::from(*last_byte & 0x0f))
140        } else {
141            let last_byte = last_byte_option.expect("Last byte must exist if num_nibbles is odd.");
142            Some(Nibble::from(*last_byte >> 4))
143        }
144    }
145
146    /// Get the i-th bit.
147    fn get_bit(&self, i: usize) -> bool {
148        assert!(i < self.num_nibbles * 4);
149        let pos = i / 8;
150        let bit = 7 - i % 8;
151        ((self.bytes[pos] >> bit) & 1) != 0
152    }
153
154    /// Get the i-th nibble.
155    pub fn get_nibble(&self, i: usize) -> Nibble {
156        assert!(i < self.num_nibbles);
157        Nibble::from((self.bytes[i / 2] >> (if i % 2 == 1 { 0 } else { 4 })) & 0xf)
158    }
159
160    /// Get a bit iterator iterates over the whole nibble path.
161    pub fn bits(&self) -> BitIterator {
162        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
163        BitIterator {
164            nibble_path: self,
165            pos: (0..self.num_nibbles * 4),
166        }
167    }
168
169    /// Get a nibble iterator iterates over the whole nibble path.
170    pub fn nibbles(&self) -> NibbleIterator {
171        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
172        NibbleIterator::new(self, 0, self.num_nibbles)
173    }
174
175    /// Get the total number of nibbles stored.
176    pub fn num_nibbles(&self) -> usize {
177        self.num_nibbles
178    }
179
180    ///  Returns `true` if the nibbles contains no elements.
181    pub fn is_empty(&self) -> bool {
182        self.num_nibbles() == 0
183    }
184
185    /// Get the underlying bytes storing nibbles.
186    pub fn bytes(&self) -> &[u8] {
187        &self.bytes
188    }
189}
190
/// An iterator that can report the item `next()` would yield, without consuming it.
///
/// Unlike `std::iter::Peekable`, `peek` takes `&self` and hands the item back by value.
pub trait Peekable: Iterator {
    /// Returns the `next()` value without advancing the iterator.
    fn peek(&self) -> Option<<Self as Iterator>::Item>;
}
195
196/// BitIterator iterates a nibble path by bit.
197pub struct BitIterator<'a> {
198    nibble_path: &'a NibblePath,
199    pos: std::ops::Range<usize>,
200}
201
202impl<'a> Peekable for BitIterator<'a> {
203    /// Returns the `next()` value without advancing the iterator.
204    fn peek(&self) -> Option<Self::Item> {
205        if self.pos.start < self.pos.end {
206            Some(self.nibble_path.get_bit(self.pos.start))
207        } else {
208            None
209        }
210    }
211}
212
213/// BitIterator spits out a boolean each time. True/false denotes 1/0.
214impl<'a> Iterator for BitIterator<'a> {
215    type Item = bool;
216    fn next(&mut self) -> Option<Self::Item> {
217        self.pos.next().map(|i| self.nibble_path.get_bit(i))
218    }
219}
220
221/// Support iterating bits in reversed order.
222impl<'a> DoubleEndedIterator for BitIterator<'a> {
223    fn next_back(&mut self) -> Option<Self::Item> {
224        self.pos.next_back().map(|i| self.nibble_path.get_bit(i))
225    }
226}
227
228/// NibbleIterator iterates a nibble path by nibble.
229#[derive(Debug)]
230pub struct NibbleIterator<'a> {
231    /// The underlying nibble path that stores the nibbles
232    nibble_path: &'a NibblePath,
233
234    /// The current index, `pos.start`, will bump by 1 after calling `next()` until `pos.start ==
235    /// pos.end`.
236    pos: std::ops::Range<usize>,
237
238    /// The start index of the iterator. At the beginning, `pos.start == start`. [start, pos.end)
239    /// defines the range of `nibble_path` this iterator iterates over. `nibble_path` refers to
240    /// the entire underlying buffer but the range may only be partial.
241    start: usize,
242    // invariant self.start <= self.pos.start;
243    // invariant self.pos.start <= self.pos.end;
244    // invariant self.pos.end <= ROOT_NIBBLE_HEIGHT;
245}
246
247/// NibbleIterator spits out a byte each time. Each byte must be in range [0, 16).
248impl<'a> Iterator for NibbleIterator<'a> {
249    type Item = Nibble;
250    fn next(&mut self) -> Option<Self::Item> {
251        self.pos.next().map(|i| self.nibble_path.get_nibble(i))
252    }
253}
254
255impl<'a> Peekable for NibbleIterator<'a> {
256    /// Returns the `next()` value without advancing the iterator.
257    fn peek(&self) -> Option<Self::Item> {
258        if self.pos.start < self.pos.end {
259            Some(self.nibble_path.get_nibble(self.pos.start))
260        } else {
261            None
262        }
263    }
264}
265
266impl<'a> NibbleIterator<'a> {
267    fn new(nibble_path: &'a NibblePath, start: usize, end: usize) -> Self {
268        precondition!(start <= end);
269        precondition!(start <= ROOT_NIBBLE_HEIGHT);
270        precondition!(end <= ROOT_NIBBLE_HEIGHT);
271        Self {
272            nibble_path,
273            pos: (start..end),
274            start,
275        }
276    }
277
278    /// Returns a nibble iterator that iterates all visited nibbles.
279    pub fn visited_nibbles(&self) -> NibbleIterator<'a> {
280        assume!(self.start <= self.pos.start); // invariant
281        assume!(self.pos.start <= ROOT_NIBBLE_HEIGHT); // invariant
282        Self::new(self.nibble_path, self.start, self.pos.start)
283    }
284
285    /// Returns a nibble iterator that iterates all remaining nibbles.
286    pub fn remaining_nibbles(&self) -> NibbleIterator<'a> {
287        assume!(self.pos.start <= self.pos.end); // invariant
288        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
289        Self::new(self.nibble_path, self.pos.start, self.pos.end)
290    }
291
292    /// Turn it into a `BitIterator`.
293    pub fn bits(&self) -> BitIterator<'a> {
294        assume!(self.pos.start <= self.pos.end); // invariant
295        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
296        BitIterator {
297            nibble_path: self.nibble_path,
298            pos: (self.pos.start * 4..self.pos.end * 4),
299        }
300    }
301
302    /// Cut and return the range of the underlying `nibble_path` that this iterator is iterating
303    /// over as a new `NibblePath`
304    pub fn get_nibble_path(&self) -> NibblePath {
305        self.visited_nibbles()
306            .chain(self.remaining_nibbles())
307            .collect()
308    }
309
310    /// Get the number of nibbles that this iterator covers.
311    pub fn num_nibbles(&self) -> usize {
312        assume!(self.start <= self.pos.end); // invariant
313        self.pos.end - self.start
314    }
315
316    /// Return `true` if the iteration is over.
317    pub fn is_finished(&self) -> bool {
318        self.peek().is_none()
319    }
320}
321
322/// Advance both iterators if their next nibbles are the same until either reaches the end or
323/// the find a mismatch. Return the number of matched nibbles.
324pub fn skip_common_prefix<I1, I2>(x: &mut I1, y: &mut I2) -> usize
325where
326    I1: Iterator + Peekable,
327    I2: Iterator + Peekable,
328    <I1 as Iterator>::Item: std::cmp::PartialEq<<I2 as Iterator>::Item>,
329{
330    let mut count = 0;
331    loop {
332        let x_peek = x.peek();
333        let y_peek = y.peek();
334        if x_peek.is_none()
335            || y_peek.is_none()
336            || x_peek.expect("cannot be none") != y_peek.expect("cannot be none")
337        {
338            break;
339        }
340        count += 1;
341        x.next();
342        y.next();
343    }
344    count
345}