miden_crypto/merkle/smt/full/
leaf.rs

1use alloc::{string::ToString, vec::Vec};
2use core::cmp::Ordering;
3
4use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
5use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
6
7#[derive(Clone, Debug, PartialEq, Eq)]
8#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
9pub enum SmtLeaf {
10    Empty(LeafIndex<SMT_DEPTH>),
11    Single((RpoDigest, Word)),
12    Multiple(Vec<(RpoDigest, Word)>),
13}
14
15impl SmtLeaf {
16    // CONSTRUCTORS
17    // ---------------------------------------------------------------------------------------------
18
19    /// Returns a new leaf with the specified entries
20    ///
21    /// # Errors
22    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
23    ///   - Returns an error if 1 or more keys in `entries` map to a leaf index different from
24    ///     `leaf_index`
25    pub fn new(
26        entries: Vec<(RpoDigest, Word)>,
27        leaf_index: LeafIndex<SMT_DEPTH>,
28    ) -> Result<Self, SmtLeafError> {
29        match entries.len() {
30            0 => Ok(Self::new_empty(leaf_index)),
31            1 => {
32                let (key, value) = entries[0];
33
34                let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
35                if computed_index != leaf_index {
36                    return Err(SmtLeafError::InconsistentSingleLeafIndices {
37                        key,
38                        expected_leaf_index: leaf_index,
39                        actual_leaf_index: computed_index,
40                    });
41                }
42
43                Ok(Self::new_single(key, value))
44            },
45            _ => {
46                let leaf = Self::new_multiple(entries)?;
47
48                // `new_multiple()` checked that all keys map to the same leaf index. We still need
49                // to ensure that that leaf index is `leaf_index`.
50                if leaf.index() != leaf_index {
51                    Err(SmtLeafError::InconsistentMultipleLeafIndices {
52                        leaf_index_from_keys: leaf.index(),
53                        leaf_index_supplied: leaf_index,
54                    })
55                } else {
56                    Ok(leaf)
57                }
58            },
59        }
60    }
61
62    /// Returns a new empty leaf with the specified leaf index
63    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
64        Self::Empty(leaf_index)
65    }
66
67    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
68    /// entry's key.
69    pub fn new_single(key: RpoDigest, value: Word) -> Self {
70        Self::Single((key, value))
71    }
72
73    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
74    /// entries' keys.
75    ///
76    /// # Errors
77    ///   - Returns an error if 2 keys in `entries` map to a different leaf index
78    pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
79        if entries.len() < 2 {
80            return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
81        }
82
83        // Check that all keys map to the same leaf index
84        {
85            let mut keys = entries.iter().map(|(key, _)| key);
86
87            let first_key = *keys.next().expect("ensured at least 2 entries");
88            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
89
90            for &next_key in keys {
91                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
92
93                if next_leaf_index != first_leaf_index {
94                    return Err(SmtLeafError::InconsistentMultipleLeafKeys {
95                        key_1: first_key,
96                        key_2: next_key,
97                    });
98                }
99            }
100        }
101
102        Ok(Self::Multiple(entries))
103    }
104
105    // PUBLIC ACCESSORS
106    // ---------------------------------------------------------------------------------------------
107
108    /// Returns true if the leaf is empty
109    pub fn is_empty(&self) -> bool {
110        matches!(self, Self::Empty(_))
111    }
112
113    /// Returns the leaf's index in the [`super::Smt`]
114    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
115        match self {
116            SmtLeaf::Empty(leaf_index) => *leaf_index,
117            SmtLeaf::Single((key, _)) => key.into(),
118            SmtLeaf::Multiple(entries) => {
119                // Note: All keys are guaranteed to have the same leaf index
120                let (first_key, _) = entries[0];
121                first_key.into()
122            },
123        }
124    }
125
126    /// Returns the number of entries stored in the leaf
127    pub fn num_entries(&self) -> u64 {
128        match self {
129            SmtLeaf::Empty(_) => 0,
130            SmtLeaf::Single(_) => 1,
131            SmtLeaf::Multiple(entries) => {
132                entries.len().try_into().expect("shouldn't have more than 2^64 entries")
133            },
134        }
135    }
136
137    /// Computes the hash of the leaf
138    pub fn hash(&self) -> RpoDigest {
139        match self {
140            SmtLeaf::Empty(_) => EMPTY_WORD.into(),
141            SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
142            SmtLeaf::Multiple(kvs) => {
143                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
144                Rpo256::hash_elements(&elements)
145            },
146        }
147    }
148
149    // ITERATORS
150    // ---------------------------------------------------------------------------------------------
151
152    /// Returns the key-value pairs in the leaf
153    pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
154        match self {
155            SmtLeaf::Empty(_) => Vec::new(),
156            SmtLeaf::Single(kv_pair) => vec![kv_pair],
157            SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
158        }
159    }
160
161    // CONVERSIONS
162    // ---------------------------------------------------------------------------------------------
163
164    /// Converts a leaf to a list of field elements
165    pub fn to_elements(&self) -> Vec<Felt> {
166        self.clone().into_elements()
167    }
168
169    /// Converts a leaf to a list of field elements
170    pub fn into_elements(self) -> Vec<Felt> {
171        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
172    }
173
174    /// Converts a leaf the key-value pairs in the leaf
175    pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
176        match self {
177            SmtLeaf::Empty(_) => Vec::new(),
178            SmtLeaf::Single(kv_pair) => vec![kv_pair],
179            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
180        }
181    }
182
183    // HELPERS
184    // ---------------------------------------------------------------------------------------------
185
186    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
187    /// leaf.
188    pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
189        // Ensure that `key` maps to this leaf
190        if self.index() != key.into() {
191            return None;
192        }
193
194        match self {
195            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
196            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
197                if key == key_in_leaf {
198                    Some(*value_in_leaf)
199                } else {
200                    Some(EMPTY_WORD)
201                }
202            },
203            SmtLeaf::Multiple(kv_pairs) => {
204                for (key_in_leaf, value_in_leaf) in kv_pairs {
205                    if key == key_in_leaf {
206                        return Some(*value_in_leaf);
207                    }
208                }
209
210                Some(EMPTY_WORD)
211            },
212        }
213    }
214
215    /// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
216    /// any.
217    ///
218    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
219    pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
220        match self {
221            SmtLeaf::Empty(_) => {
222                *self = SmtLeaf::new_single(key, value);
223                None
224            },
225            SmtLeaf::Single(kv_pair) => {
226                if kv_pair.0 == key {
227                    // the key is already in this leaf. Update the value and return the previous
228                    // value
229                    let old_value = kv_pair.1;
230                    kv_pair.1 = value;
231                    Some(old_value)
232                } else {
233                    // Another entry is present in this leaf. Transform the entry into a list
234                    // entry, and make sure the key-value pairs are sorted by key
235                    let mut pairs = vec![*kv_pair, (key, value)];
236                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
237
238                    *self = SmtLeaf::Multiple(pairs);
239
240                    None
241                }
242            },
243            SmtLeaf::Multiple(kv_pairs) => {
244                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
245                    Ok(pos) => {
246                        let old_value = kv_pairs[pos].1;
247                        kv_pairs[pos].1 = value;
248
249                        Some(old_value)
250                    },
251                    Err(pos) => {
252                        kv_pairs.insert(pos, (key, value));
253
254                        None
255                    },
256                }
257            },
258        }
259    }
260
261    /// Removes key-value pair from the leaf stored at key; returns the previous value associated
262    /// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
263    /// empty, and must be removed from the data structure it is contained in.
264    pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
265        match self {
266            SmtLeaf::Empty(_) => (None, false),
267            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
268                if *key_at_leaf == key {
269                    // our key was indeed stored in the leaf, so we return the value that was stored
270                    // in it, and indicate that the leaf should be removed
271                    let old_value = *value_at_leaf;
272
273                    // Note: this is not strictly needed, since the caller is expected to drop this
274                    // `SmtLeaf` object.
275                    *self = SmtLeaf::new_empty(key.into());
276
277                    (Some(old_value), true)
278                } else {
279                    // another key is stored at leaf; nothing to update
280                    (None, false)
281                }
282            },
283            SmtLeaf::Multiple(kv_pairs) => {
284                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
285                    Ok(pos) => {
286                        let old_value = kv_pairs[pos].1;
287
288                        kv_pairs.remove(pos);
289                        debug_assert!(!kv_pairs.is_empty());
290
291                        if kv_pairs.len() == 1 {
292                            // convert the leaf into `Single`
293                            *self = SmtLeaf::Single(kv_pairs[0]);
294                        }
295
296                        (Some(old_value), false)
297                    },
298                    Err(_) => {
299                        // other keys are stored at leaf; nothing to update
300                        (None, false)
301                    },
302                }
303            },
304        }
305    }
306}
307
308impl Serializable for SmtLeaf {
309    fn write_into<W: ByteWriter>(&self, target: &mut W) {
310        // Write: num entries
311        self.num_entries().write_into(target);
312
313        // Write: leaf index
314        let leaf_index: u64 = self.index().value();
315        leaf_index.write_into(target);
316
317        // Write: entries
318        for (key, value) in self.entries() {
319            key.write_into(target);
320            value.write_into(target);
321        }
322    }
323}
324
325impl Deserializable for SmtLeaf {
326    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
327        // Read: num entries
328        let num_entries = source.read_u64()?;
329
330        // Read: leaf index
331        let leaf_index: LeafIndex<SMT_DEPTH> = {
332            let value = source.read_u64()?;
333            LeafIndex::new_max_depth(value)
334        };
335
336        // Read: entries
337        let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
338        for _ in 0..num_entries {
339            let key: RpoDigest = source.read()?;
340            let value: Word = source.read()?;
341
342            entries.push((key, value));
343        }
344
345        Self::new(entries, leaf_index)
346            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
347    }
348}
349
350// HELPER FUNCTIONS
351// ================================================================================================
352
353/// Converts a key-value tuple to an iterator of `Felt`s
354pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
355    let key_elements = key.into_iter();
356    let value_elements = value.into_iter();
357
358    key_elements.chain(value_elements)
359}
360
361/// Compares two keys, compared element-by-element using their integer representations starting with
362/// the most significant element.
363pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
364    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
365        let v1 = v1.as_int();
366        let v2 = v2.as_int();
367        if v1 != v2 {
368            return v1.cmp(&v2);
369        }
370    }
371
372    Ordering::Equal
373}