certified_vars/collections/map.rs

1use crate::collections::seq::Seq;
2use crate::label::{Label, Prefix};
3use crate::rbtree::entry::Entry;
4use crate::rbtree::iterator::RbTreeIterator;
5use crate::rbtree::RbTree;
6use crate::{AsHashTree, Hash, HashTree};
7use candid::types::{Compound, Field, Label as CLabel, Type};
8use candid::CandidType;
9use serde::de::{MapAccess, Visitor};
10use serde::ser::SerializeMap;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::borrow::Borrow;
13use std::fmt::{self, Debug, Formatter};
14use std::iter::FromIterator;
15use std::marker::PhantomData;
16
17#[derive(Default)]
18pub struct Map<K: 'static + Label, V: AsHashTree + 'static> {
19    pub(crate) inner: RbTree<K, V>,
20}
21
22impl<K: 'static + Label, V: AsHashTree + 'static> Map<K, V> {
23    #[inline]
24    pub fn new() -> Self {
25        Self {
26            inner: RbTree::new(),
27        }
28    }
29
30    /// Returns `true` if the map does not contain any values.
31    #[inline]
32    pub fn is_empty(&self) -> bool {
33        self.inner.is_empty()
34    }
35
36    /// Returns the number of elements in the map.
37    #[inline]
38    pub fn len(&self) -> usize {
39        self.inner.len()
40    }
41
42    /// Clear the map.
43    #[inline]
44    pub fn clear(&mut self) {
45        self.inner = RbTree::new();
46    }
47
48    /// Insert a key-value pair into the map. Returns [`None`] if the key did not
49    /// exists in the map, otherwise the previous value associated with the provided
50    /// key will be returned.
51    #[inline]
52    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
53        self.inner.insert(key, value).0
54    }
55
56    /// Remove the value associated with the given key from the map, returns the
57    /// previous value associated with the key.
58    #[inline]
59    pub fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
60    where
61        K: Borrow<Q>,
62        Q: Ord,
63    {
64        self.inner.delete(key).map(|(_, v)| v)
65    }
66
67    /// Remove an entry from the map and return the key and value.
68    #[inline]
69    pub fn remove_entry<Q: ?Sized>(&mut self, key: &Q) -> Option<(K, V)>
70    where
71        K: Borrow<Q>,
72        Q: Ord,
73    {
74        self.inner.delete(key)
75    }
76
77    #[inline]
78    pub fn entry(&mut self, key: K) -> Entry<K, V> {
79        self.inner.entry(key)
80    }
81
82    /// Returns a mutable reference to the value corresponding to the key.
83    #[inline]
84    pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
85    where
86        K: Borrow<Q>,
87        Q: Ord,
88    {
89        self.inner.modify(key, |v| v)
90    }
91
92    /// Return the value associated with the given key.
93    #[inline]
94    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
95    where
96        K: Borrow<Q>,
97        Q: Ord,
98    {
99        self.inner.get(key)
100    }
101
102    /// Return an iterator over the key-values in the map.
103    #[inline]
104    pub fn iter(&self) -> RbTreeIterator<K, V> {
105        RbTreeIterator::new(&self.inner)
106    }
107
108    /// Create a HashTree witness for the value associated with given key.
109    #[inline]
110    pub fn witness<Q: ?Sized>(&self, key: &Q) -> HashTree
111    where
112        K: Borrow<Q>,
113        Q: Ord,
114    {
115        self.inner.witness(key)
116    }
117
118    /// Returns a witness enumerating all the keys in this map.  The
119    /// resulting tree doesn't include values, they are replaced with
120    /// "Pruned" nodes.
121    pub fn witness_keys(&self) -> HashTree {
122        self.inner.keys()
123    }
124
125    /// Returns a witness for the key-value pairs in the specified range.
126    /// The resulting tree contains both keys and values.
127    #[inline]
128    pub fn witness_value_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &K, last: &K) -> HashTree<'_>
129    where
130        K: Borrow<Q1> + Borrow<Q2>,
131        Q1: Ord,
132        Q2: Ord,
133    {
134        self.inner.value_range(first, last)
135    }
136
137    /// Returns a witness for the keys in the specified range.
138    /// The resulting tree only contains the keys, and the values are replaced with
139    /// "Pruned" nodes.
140    #[inline]
141    pub fn witness_key_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &K, last: &K) -> HashTree<'_>
142    where
143        K: Borrow<Q1> + Borrow<Q2>,
144        Q1: Ord,
145        Q2: Ord,
146    {
147        self.inner.key_range(first, last)
148    }
149
150    /// Returns a witness for the keys with the given prefix, this replaces the values with
151    /// "Pruned" nodes.
152    #[inline]
153    pub fn witness_keys_with_prefix<P: ?Sized>(&self, prefix: &P) -> HashTree<'_>
154    where
155        K: Prefix<P>,
156        P: Ord,
157    {
158        self.inner.keys_with_prefix(prefix)
159    }
160
161    /// Return the underlying [`RbTree`] for this map.
162    #[inline]
163    pub fn as_tree(&self) -> &RbTree<K, V> {
164        &self.inner
165    }
166}
167
168impl<K: 'static + Label, V: AsHashTree> Map<K, Seq<V>> {
169    /// Perform a [`Seq::append`] on the seq associated with the give value, if
170    /// the seq does not exists, creates an empty one and inserts it to the map.
171    pub fn append_deep(&mut self, key: K, value: V) {
172        let mut value = Some(value);
173
174        self.inner.modify(&key, |seq| {
175            seq.append(value.take().unwrap());
176        });
177
178        if let Some(value) = value.take() {
179            let mut seq = Seq::new();
180            seq.append(value);
181            self.inner.insert(key, seq);
182        }
183    }
184
185    #[inline]
186    pub fn len_deep<Q: ?Sized>(&mut self, key: &Q) -> usize
187    where
188        K: Borrow<Q>,
189        Q: Ord,
190    {
191        self.inner.get(key).map(|seq| seq.len()).unwrap_or(0)
192    }
193}
194
195impl<K: 'static + Label, V: 'static + AsHashTree> AsRef<RbTree<K, V>> for Map<K, V> {
196    #[inline]
197    fn as_ref(&self) -> &RbTree<K, V> {
198        &self.inner
199    }
200}
201
202impl<K: 'static + Label, V: AsHashTree + 'static> Serialize for Map<K, V>
203where
204    K: Serialize,
205    V: Serialize,
206{
207    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208    where
209        S: Serializer,
210    {
211        let mut s = serializer.serialize_map(Some(self.len()))?;
212
213        for (key, value) in self.iter() {
214            s.serialize_entry(key, value)?;
215        }
216
217        s.end()
218    }
219}
220
221impl<'de, K: 'static + Label, V: AsHashTree + 'static> Deserialize<'de> for Map<K, V>
222where
223    K: Deserialize<'de>,
224    V: Deserialize<'de>,
225{
226    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
227    where
228        D: Deserializer<'de>,
229    {
230        deserializer.deserialize_map(MapVisitor(PhantomData::default()))
231    }
232}
233
234struct MapVisitor<K, V>(PhantomData<(K, V)>);
235
236impl<'de, K: 'static + Label, V: AsHashTree + 'static> Visitor<'de> for MapVisitor<K, V>
237where
238    K: Deserialize<'de>,
239    V: Deserialize<'de>,
240{
241    type Value = Map<K, V>;
242
243    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
244        write!(formatter, "expected a map")
245    }
246
247    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
248    where
249        A: MapAccess<'de>,
250    {
251        let mut result = Map::new();
252
253        loop {
254            if let Some((key, value)) = map.next_entry::<K, V>()? {
255                result.insert(key, value);
256                continue;
257            }
258
259            break;
260        }
261
262        Ok(result)
263    }
264}
265
266impl<K: 'static + Label, V: AsHashTree + 'static> CandidType for Map<K, V>
267where
268    K: CandidType,
269    V: CandidType,
270{
271    fn _ty() -> Type {
272        let tuple = Type::Record(vec![
273            Field {
274                id: CLabel::Id(0),
275                ty: K::ty(),
276            },
277            Field {
278                id: CLabel::Id(1),
279                ty: V::ty(),
280            },
281        ]);
282        Type::Vec(Box::new(tuple))
283    }
284
285    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
286    where
287        S: candid::types::Serializer,
288    {
289        let mut ser = serializer.serialize_vec(self.len())?;
290        for e in self.iter() {
291            Compound::serialize_element(&mut ser, &e)?;
292        }
293        Ok(())
294    }
295}
296
297impl<K: 'static + Label, V: AsHashTree + 'static> FromIterator<(K, V)> for Map<K, V> {
298    fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
299        let mut result = Map::new();
300
301        for (key, value) in iter {
302            result.insert(key, value);
303        }
304
305        result
306    }
307}
308
309impl<K: 'static + Label, V: AsHashTree + 'static> Debug for Map<K, V>
310where
311    K: Debug,
312    V: Debug,
313{
314    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315        f.debug_map().entries(self.iter()).finish()
316    }
317}
318
319impl<K: 'static + Label, V: AsHashTree + 'static> AsHashTree for Map<K, V> {
320    #[inline]
321    fn root_hash(&self) -> Hash {
322        self.inner.root_hash()
323    }
324
325    #[inline]
326    fn as_hash_tree(&self) -> HashTree<'_> {
327        self.inner.as_hash_tree()
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of keys used by the bulk insert/remove tests.
    const COUNT: u32 = 200;

    /// Hex-encoded big-endian bytes of `i`, used as a string key.
    fn key(i: u32) -> String {
        hex::encode(i.to_be_bytes())
    }

    /// A map holding `key(i) -> i` for every `i` in `0..COUNT`.
    fn filled_map() -> Map<String, u32> {
        let mut map = Map::new();
        for i in 0..COUNT {
            map.insert(key(i), i);
        }
        map
    }

    /// Assert that none of the `0..COUNT` keys is present any more.
    fn assert_all_removed(map: &Map<String, u32>) {
        for i in 0..COUNT {
            assert_eq!(map.get(&key(i)), None);
        }
    }

    #[test]
    fn insert() {
        let mut map = Map::<String, u32>::new();

        // (key, value to insert, value expected to be displaced)
        let steps: [(&str, u32, Option<u32>); 9] = [
            ("A", 0, None),
            ("A", 1, Some(0)),
            ("B", 2, None),
            ("C", 3, None),
            ("B", 4, Some(2)),
            ("C", 5, Some(3)),
            ("B", 6, Some(4)),
            ("C", 7, Some(5)),
            ("A", 8, Some(1)),
        ];
        for &(k, v, displaced) in steps.iter() {
            assert_eq!(map.insert(k.into(), v), displaced);
        }

        assert_eq!(map.get("A"), Some(&8));
        assert_eq!(map.get("B"), Some(&6));
        assert_eq!(map.get("C"), Some(&7));
        assert_eq!(map.get("D"), None);
    }

    #[test]
    fn remove() {
        let mut map = filled_map();
        for i in 0..COUNT {
            assert_eq!(map.remove(&key(i)), Some(i));
        }
        assert_all_removed(&map);
    }

    #[test]
    fn remove_rev() {
        let mut map = filled_map();
        for i in (0..COUNT).rev() {
            assert_eq!(map.remove(&key(i)), Some(i));
        }
        assert_all_removed(&map);
    }
}