miden_crypto/merkle/smt/full/
leaf.rs

1use alloc::{string::ToString, vec::Vec};
2use core::cmp::Ordering;
3
4use super::EMPTY_WORD;
5use crate::{
6    Felt, Word,
7    field::PrimeField64,
8    hash::rpo::Rpo256,
9    merkle::smt::{LeafIndex, MAX_LEAF_ENTRIES, SMT_DEPTH, SmtLeafError},
10    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
/// The number of field elements occupied by one key-value pair in the flat element encoding of a
/// leaf: a key `Word` followed by a value `Word`, four `Felt`s each.
const DOUBLE_WORD_LEN: usize = 8;
15
/// Represents a leaf node in the Sparse Merkle Tree.
///
/// A leaf can be empty, hold a single key-value pair, or multiple key-value pairs. The leaf index
/// is stored explicitly only for the empty variant; for the other variants it is derived from the
/// stored key(s).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
    /// An empty leaf at the specified index.
    Empty(LeafIndex<SMT_DEPTH>),
    /// A leaf containing a single key-value pair, stored as `(key, value)`.
    Single((Word, Word)),
    /// A leaf containing multiple key-value pairs, each stored as `(key, value)`.
    ///
    /// All keys are expected to map to the same leaf index; the index of the leaf is derived from
    /// the first entry's key.
    Multiple(Vec<(Word, Word)>),
}
29
30impl SmtLeaf {
31    // CONSTRUCTORS
32    // ---------------------------------------------------------------------------------------------
33
34    /// Returns a new leaf with the specified entries
35    ///
36    /// # Errors
37    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
38    ///   - Returns an error if 1 or more keys in `entries` map to a leaf index different from
39    ///     `leaf_index`
40    pub fn new(
41        entries: Vec<(Word, Word)>,
42        leaf_index: LeafIndex<SMT_DEPTH>,
43    ) -> Result<Self, SmtLeafError> {
44        match entries.len() {
45            0 => Ok(Self::new_empty(leaf_index)),
46            1 => {
47                let (key, value) = entries[0];
48
49                let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
50                if computed_index != leaf_index {
51                    return Err(SmtLeafError::InconsistentSingleLeafIndices {
52                        key,
53                        expected_leaf_index: leaf_index,
54                        actual_leaf_index: computed_index,
55                    });
56                }
57
58                Ok(Self::new_single(key, value))
59            },
60            _ => {
61                let leaf = Self::new_multiple(entries)?;
62
63                // `new_multiple()` checked that all keys map to the same leaf index. We still need
64                // to ensure that leaf index is `leaf_index`.
65                if leaf.index() != leaf_index {
66                    Err(SmtLeafError::InconsistentMultipleLeafIndices {
67                        leaf_index_from_keys: leaf.index(),
68                        leaf_index_supplied: leaf_index,
69                    })
70                } else {
71                    Ok(leaf)
72                }
73            },
74        }
75    }
76
77    /// Returns a new empty leaf with the specified leaf index
78    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
79        Self::Empty(leaf_index)
80    }
81
82    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
83    /// entry's key.
84    pub fn new_single(key: Word, value: Word) -> Self {
85        Self::Single((key, value))
86    }
87
88    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
89    /// entries' keys.
90    ///
91    /// # Errors
92    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
93    ///   - Returns an error if the number of entries exceeds [`MAX_LEAF_ENTRIES`]
94    pub fn new_multiple(entries: Vec<(Word, Word)>) -> Result<Self, SmtLeafError> {
95        if entries.len() < 2 {
96            return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
97        }
98
99        if entries.len() > MAX_LEAF_ENTRIES {
100            return Err(SmtLeafError::TooManyLeafEntries { actual: entries.len() });
101        }
102
103        // Check that all keys map to the same leaf index
104        {
105            let mut keys = entries.iter().map(|(key, _)| key);
106
107            let first_key = *keys.next().expect("ensured at least 2 entries");
108            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
109
110            for &next_key in keys {
111                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
112
113                if next_leaf_index != first_leaf_index {
114                    return Err(SmtLeafError::InconsistentMultipleLeafKeys {
115                        key_1: first_key,
116                        key_2: next_key,
117                    });
118                }
119            }
120        }
121
122        Ok(Self::Multiple(entries))
123    }
124
125    // PUBLIC ACCESSORS
126    // ---------------------------------------------------------------------------------------------
127
128    /// Returns true if the leaf is empty
129    pub fn is_empty(&self) -> bool {
130        matches!(self, Self::Empty(_))
131    }
132
133    /// Returns the leaf's index in the [`super::Smt`]
134    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
135        match self {
136            SmtLeaf::Empty(leaf_index) => *leaf_index,
137            SmtLeaf::Single((key, _)) => (*key).into(),
138            SmtLeaf::Multiple(entries) => {
139                // Note: All keys are guaranteed to have the same leaf index
140                let (first_key, _) = entries[0];
141                first_key.into()
142            },
143        }
144    }
145
146    /// Returns the number of entries stored in the leaf
147    pub fn num_entries(&self) -> usize {
148        match self {
149            SmtLeaf::Empty(_) => 0,
150            SmtLeaf::Single(_) => 1,
151            SmtLeaf::Multiple(entries) => entries.len(),
152        }
153    }
154
155    /// Computes the hash of the leaf
156    pub fn hash(&self) -> Word {
157        match self {
158            SmtLeaf::Empty(_) => EMPTY_WORD,
159            SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, *value]),
160            SmtLeaf::Multiple(kvs) => {
161                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
162                Rpo256::hash_elements(&elements)
163            },
164        }
165    }
166
167    // ITERATORS
168    // ---------------------------------------------------------------------------------------------
169
170    /// Returns a slice with key-value pairs in the leaf.
171    pub fn entries(&self) -> &[(Word, Word)] {
172        match self {
173            SmtLeaf::Empty(_) => &[],
174            SmtLeaf::Single(kv_pair) => core::slice::from_ref(kv_pair),
175            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
176        }
177    }
178
179    // CONVERSIONS
180    // ---------------------------------------------------------------------------------------------
181
182    /// Returns an iterator over the field elements representing this leaf.
183    pub fn to_elements(&self) -> impl Iterator<Item = Felt> + '_ {
184        self.entries().iter().copied().flat_map(kv_to_elements)
185    }
186
187    /// Converts a leaf to a list of field elements.
188    pub fn into_elements(self) -> Vec<Felt> {
189        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
190    }
191
192    /// Converts a leaf the key-value pairs in the leaf
193    pub fn into_entries(self) -> Vec<(Word, Word)> {
194        match self {
195            SmtLeaf::Empty(_) => Vec::new(),
196            SmtLeaf::Single(kv_pair) => vec![kv_pair],
197            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
198        }
199    }
200
201    /// Converts a list of elements into a leaf
202    pub fn try_from_elements(
203        elements: &[Felt],
204        leaf_index: LeafIndex<SMT_DEPTH>,
205    ) -> Result<SmtLeaf, SmtLeafError> {
206        if elements.is_empty() {
207            return Ok(SmtLeaf::new_empty(leaf_index));
208        }
209
210        // Elements should be organized into a contiguous array of K/V Words (4 Felts each).
211        if !elements.len().is_multiple_of(DOUBLE_WORD_LEN) {
212            return Err(SmtLeafError::DecodingError(
213                "elements length is not a multiple of 8".into(),
214            ));
215        }
216
217        let num_entries = elements.len() / DOUBLE_WORD_LEN;
218
219        if num_entries == 1 {
220            // Single entry.
221            let key = Word::new([elements[0], elements[1], elements[2], elements[3]]);
222            let value = Word::new([elements[4], elements[5], elements[6], elements[7]]);
223            Ok(SmtLeaf::new_single(key, value))
224        } else {
225            // Multiple entries.
226            let mut entries = Vec::with_capacity(num_entries);
227            // Read k/v pairs from each entry.
228            for i in 0..num_entries {
229                let base_idx = i * DOUBLE_WORD_LEN;
230                let key = Word::new([
231                    elements[base_idx],
232                    elements[base_idx + 1],
233                    elements[base_idx + 2],
234                    elements[base_idx + 3],
235                ]);
236                let value = Word::new([
237                    elements[base_idx + 4],
238                    elements[base_idx + 5],
239                    elements[base_idx + 6],
240                    elements[base_idx + 7],
241                ]);
242                entries.push((key, value));
243            }
244            let leaf = SmtLeaf::new_multiple(entries)?;
245            Ok(leaf)
246        }
247    }
248
249    // HELPERS
250    // ---------------------------------------------------------------------------------------------
251
252    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
253    /// leaf.
254    pub(in crate::merkle::smt) fn get_value(&self, key: &Word) -> Option<Word> {
255        // Ensure that `key` maps to this leaf
256        if self.index() != (*key).into() {
257            return None;
258        }
259
260        match self {
261            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
262            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
263                if key == key_in_leaf {
264                    Some(*value_in_leaf)
265                } else {
266                    Some(EMPTY_WORD)
267                }
268            },
269            SmtLeaf::Multiple(kv_pairs) => {
270                for (key_in_leaf, value_in_leaf) in kv_pairs {
271                    if key == key_in_leaf {
272                        return Some(*value_in_leaf);
273                    }
274                }
275
276                Some(EMPTY_WORD)
277            },
278        }
279    }
280
281    /// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
282    /// any.
283    ///
284    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
285    ///
286    /// # Errors
287    /// Returns an error if inserting the key-value pair would exceed [`MAX_LEAF_ENTRIES`] (1024
288    /// entries) in the leaf.
289    pub(in crate::merkle::smt) fn insert(
290        &mut self,
291        key: Word,
292        value: Word,
293    ) -> Result<Option<Word>, SmtLeafError> {
294        match self {
295            SmtLeaf::Empty(_) => {
296                *self = SmtLeaf::new_single(key, value);
297                Ok(None)
298            },
299            SmtLeaf::Single(kv_pair) => {
300                if kv_pair.0 == key {
301                    // the key is already in this leaf. Update the value and return the previous
302                    // value
303                    let old_value = kv_pair.1;
304                    kv_pair.1 = value;
305                    Ok(Some(old_value))
306                } else {
307                    // Another entry is present in this leaf. Transform the entry into a list
308                    // entry, and make sure the key-value pairs are sorted by key
309                    // This stays within MAX_LEAF_ENTRIES limit. We're only adding one entry to a
310                    // single leaf
311                    let mut pairs = vec![*kv_pair, (key, value)];
312                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
313                    *self = SmtLeaf::Multiple(pairs);
314                    Ok(None)
315                }
316            },
317            SmtLeaf::Multiple(kv_pairs) => {
318                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
319                    Ok(pos) => {
320                        let old_value = kv_pairs[pos].1;
321                        kv_pairs[pos].1 = value;
322                        Ok(Some(old_value))
323                    },
324                    Err(pos) => {
325                        if kv_pairs.len() >= MAX_LEAF_ENTRIES {
326                            return Err(SmtLeafError::TooManyLeafEntries {
327                                actual: kv_pairs.len() + 1,
328                            });
329                        }
330                        kv_pairs.insert(pos, (key, value));
331                        Ok(None)
332                    },
333                }
334            },
335        }
336    }
337
338    /// Removes key-value pair from the leaf stored at key; returns the previous value associated
339    /// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
340    /// empty, and must be removed from the data structure it is contained in.
341    pub(in crate::merkle::smt) fn remove(&mut self, key: Word) -> (Option<Word>, bool) {
342        match self {
343            SmtLeaf::Empty(_) => (None, false),
344            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
345                if *key_at_leaf == key {
346                    // our key was indeed stored in the leaf, so we return the value that was stored
347                    // in it, and indicate that the leaf should be removed
348                    let old_value = *value_at_leaf;
349
350                    // Note: this is not strictly needed, since the caller is expected to drop this
351                    // `SmtLeaf` object.
352                    *self = SmtLeaf::new_empty(key.into());
353
354                    (Some(old_value), true)
355                } else {
356                    // another key is stored at leaf; nothing to update
357                    (None, false)
358                }
359            },
360            SmtLeaf::Multiple(kv_pairs) => {
361                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
362                    Ok(pos) => {
363                        let old_value = kv_pairs[pos].1;
364
365                        let _ = kv_pairs.remove(pos);
366                        debug_assert!(!kv_pairs.is_empty());
367
368                        if kv_pairs.len() == 1 {
369                            // convert the leaf into `Single`
370                            *self = SmtLeaf::Single(kv_pairs[0]);
371                        }
372
373                        (Some(old_value), false)
374                    },
375                    Err(_) => {
376                        // other keys are stored at leaf; nothing to update
377                        (None, false)
378                    },
379                }
380            },
381        }
382    }
383}
384
385impl Serializable for SmtLeaf {
386    fn write_into<W: ByteWriter>(&self, target: &mut W) {
387        // Write: num entries
388        self.num_entries().write_into(target);
389
390        // Write: leaf index
391        let leaf_index: u64 = self.index().value();
392        leaf_index.write_into(target);
393
394        // Write: entries
395        for (key, value) in self.entries() {
396            key.write_into(target);
397            value.write_into(target);
398        }
399    }
400}
401
402impl Deserializable for SmtLeaf {
403    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
404        // Read: num entries
405        let num_entries = source.read_usize()?;
406
407        // Read: leaf index
408        let leaf_index: LeafIndex<SMT_DEPTH> = {
409            let value = source.read_u64()?;
410            LeafIndex::new_max_depth(value)
411        };
412
413        // Read: entries
414        let mut entries: Vec<(Word, Word)> = Vec::new();
415        for _ in 0..num_entries {
416            let key: Word = source.read()?;
417            let value: Word = source.read()?;
418
419            entries.push((key, value));
420        }
421
422        Self::new(entries, leaf_index)
423            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
424    }
425}
426
427// HELPER FUNCTIONS
428// ================================================================================================
429
430/// Converts a key-value tuple to an iterator of `Felt`s
431pub(crate) fn kv_to_elements((key, value): (Word, Word)) -> impl Iterator<Item = Felt> {
432    let key_elements = key.into_iter();
433    let value_elements = value.into_iter();
434
435    key_elements.chain(value_elements)
436}
437
438/// Compares two keys, compared element-by-element using their integer representations starting with
439/// the most significant element.
440pub(crate) fn cmp_keys(key_1: Word, key_2: Word) -> Ordering {
441    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
442        let v1 = (*v1).as_canonical_u64();
443        let v2 = (*v2).as_canonical_u64();
444        if v1 != v2 {
445            return v1.cmp(&v2);
446        }
447    }
448
449    Ordering::Equal
450}