certified_vars/collections/
map.rs1use crate::collections::seq::Seq;
2use crate::label::{Label, Prefix};
3use crate::rbtree::entry::Entry;
4use crate::rbtree::iterator::RbTreeIterator;
5use crate::rbtree::RbTree;
6use crate::{AsHashTree, Hash, HashTree};
7use candid::types::{Compound, Field, Label as CLabel, Type};
8use candid::CandidType;
9use serde::de::{MapAccess, Visitor};
10use serde::ser::SerializeMap;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::borrow::Borrow;
13use std::fmt::{self, Debug, Formatter};
14use std::iter::FromIterator;
15use std::marker::PhantomData;
16
17#[derive(Default)]
18pub struct Map<K: 'static + Label, V: AsHashTree + 'static> {
19 pub(crate) inner: RbTree<K, V>,
20}
21
22impl<K: 'static + Label, V: AsHashTree + 'static> Map<K, V> {
23 #[inline]
24 pub fn new() -> Self {
25 Self {
26 inner: RbTree::new(),
27 }
28 }
29
30 #[inline]
32 pub fn is_empty(&self) -> bool {
33 self.inner.is_empty()
34 }
35
36 #[inline]
38 pub fn len(&self) -> usize {
39 self.inner.len()
40 }
41
42 #[inline]
44 pub fn clear(&mut self) {
45 self.inner = RbTree::new();
46 }
47
48 #[inline]
52 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
53 self.inner.insert(key, value).0
54 }
55
56 #[inline]
59 pub fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
60 where
61 K: Borrow<Q>,
62 Q: Ord,
63 {
64 self.inner.delete(key).map(|(_, v)| v)
65 }
66
67 #[inline]
69 pub fn remove_entry<Q: ?Sized>(&mut self, key: &Q) -> Option<(K, V)>
70 where
71 K: Borrow<Q>,
72 Q: Ord,
73 {
74 self.inner.delete(key)
75 }
76
77 #[inline]
78 pub fn entry(&mut self, key: K) -> Entry<K, V> {
79 self.inner.entry(key)
80 }
81
82 #[inline]
84 pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
85 where
86 K: Borrow<Q>,
87 Q: Ord,
88 {
89 self.inner.modify(key, |v| v)
90 }
91
92 #[inline]
94 pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
95 where
96 K: Borrow<Q>,
97 Q: Ord,
98 {
99 self.inner.get(key)
100 }
101
102 #[inline]
104 pub fn iter(&self) -> RbTreeIterator<K, V> {
105 RbTreeIterator::new(&self.inner)
106 }
107
108 #[inline]
110 pub fn witness<Q: ?Sized>(&self, key: &Q) -> HashTree
111 where
112 K: Borrow<Q>,
113 Q: Ord,
114 {
115 self.inner.witness(key)
116 }
117
118 pub fn witness_keys(&self) -> HashTree {
122 self.inner.keys()
123 }
124
125 #[inline]
128 pub fn witness_value_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &K, last: &K) -> HashTree<'_>
129 where
130 K: Borrow<Q1> + Borrow<Q2>,
131 Q1: Ord,
132 Q2: Ord,
133 {
134 self.inner.value_range(first, last)
135 }
136
137 #[inline]
141 pub fn witness_key_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &K, last: &K) -> HashTree<'_>
142 where
143 K: Borrow<Q1> + Borrow<Q2>,
144 Q1: Ord,
145 Q2: Ord,
146 {
147 self.inner.key_range(first, last)
148 }
149
150 #[inline]
153 pub fn witness_keys_with_prefix<P: ?Sized>(&self, prefix: &P) -> HashTree<'_>
154 where
155 K: Prefix<P>,
156 P: Ord,
157 {
158 self.inner.keys_with_prefix(prefix)
159 }
160
161 #[inline]
163 pub fn as_tree(&self) -> &RbTree<K, V> {
164 &self.inner
165 }
166}
167
168impl<K: 'static + Label, V: AsHashTree> Map<K, Seq<V>> {
169 pub fn append_deep(&mut self, key: K, value: V) {
172 let mut value = Some(value);
173
174 self.inner.modify(&key, |seq| {
175 seq.append(value.take().unwrap());
176 });
177
178 if let Some(value) = value.take() {
179 let mut seq = Seq::new();
180 seq.append(value);
181 self.inner.insert(key, seq);
182 }
183 }
184
185 #[inline]
186 pub fn len_deep<Q: ?Sized>(&mut self, key: &Q) -> usize
187 where
188 K: Borrow<Q>,
189 Q: Ord,
190 {
191 self.inner.get(key).map(|seq| seq.len()).unwrap_or(0)
192 }
193}
194
195impl<K: 'static + Label, V: 'static + AsHashTree> AsRef<RbTree<K, V>> for Map<K, V> {
196 #[inline]
197 fn as_ref(&self) -> &RbTree<K, V> {
198 &self.inner
199 }
200}
201
202impl<K: 'static + Label, V: AsHashTree + 'static> Serialize for Map<K, V>
203where
204 K: Serialize,
205 V: Serialize,
206{
207 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208 where
209 S: Serializer,
210 {
211 let mut s = serializer.serialize_map(Some(self.len()))?;
212
213 for (key, value) in self.iter() {
214 s.serialize_entry(key, value)?;
215 }
216
217 s.end()
218 }
219}
220
221impl<'de, K: 'static + Label, V: AsHashTree + 'static> Deserialize<'de> for Map<K, V>
222where
223 K: Deserialize<'de>,
224 V: Deserialize<'de>,
225{
226 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
227 where
228 D: Deserializer<'de>,
229 {
230 deserializer.deserialize_map(MapVisitor(PhantomData::default()))
231 }
232}
233
234struct MapVisitor<K, V>(PhantomData<(K, V)>);
235
236impl<'de, K: 'static + Label, V: AsHashTree + 'static> Visitor<'de> for MapVisitor<K, V>
237where
238 K: Deserialize<'de>,
239 V: Deserialize<'de>,
240{
241 type Value = Map<K, V>;
242
243 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
244 write!(formatter, "expected a map")
245 }
246
247 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
248 where
249 A: MapAccess<'de>,
250 {
251 let mut result = Map::new();
252
253 loop {
254 if let Some((key, value)) = map.next_entry::<K, V>()? {
255 result.insert(key, value);
256 continue;
257 }
258
259 break;
260 }
261
262 Ok(result)
263 }
264}
265
266impl<K: 'static + Label, V: AsHashTree + 'static> CandidType for Map<K, V>
267where
268 K: CandidType,
269 V: CandidType,
270{
271 fn _ty() -> Type {
272 let tuple = Type::Record(vec![
273 Field {
274 id: CLabel::Id(0),
275 ty: K::ty(),
276 },
277 Field {
278 id: CLabel::Id(1),
279 ty: V::ty(),
280 },
281 ]);
282 Type::Vec(Box::new(tuple))
283 }
284
285 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
286 where
287 S: candid::types::Serializer,
288 {
289 let mut ser = serializer.serialize_vec(self.len())?;
290 for e in self.iter() {
291 Compound::serialize_element(&mut ser, &e)?;
292 }
293 Ok(())
294 }
295}
296
297impl<K: 'static + Label, V: AsHashTree + 'static> FromIterator<(K, V)> for Map<K, V> {
298 fn from_iter<I: IntoIterator<Item = (K, V)>>(iter: I) -> Self {
299 let mut result = Map::new();
300
301 for (key, value) in iter {
302 result.insert(key, value);
303 }
304
305 result
306 }
307}
308
309impl<K: 'static + Label, V: AsHashTree + 'static> Debug for Map<K, V>
310where
311 K: Debug,
312 V: Debug,
313{
314 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315 f.debug_map().entries(self.iter()).finish()
316 }
317}
318
319impl<K: 'static + Label, V: AsHashTree + 'static> AsHashTree for Map<K, V> {
320 #[inline]
321 fn root_hash(&self) -> Hash {
322 self.inner.root_hash()
323 }
324
325 #[inline]
326 fn as_hash_tree(&self) -> HashTree<'_> {
327 self.inner.as_hash_tree()
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of keys used by the bulk removal tests.
    const COUNT: u32 = 200;

    /// Map key for `i`: the hex encoding of its big-endian bytes.
    fn key(i: u32) -> String {
        hex::encode(i.to_be_bytes())
    }

    /// Builds a map holding `key(i) -> i` for every `i` in `0..COUNT`.
    fn populated() -> Map<String, u32> {
        let mut map = Map::new();
        for i in 0..COUNT {
            map.insert(key(i), i);
        }
        map
    }

    /// Asserts that no key from `key(0..COUNT)` is left in `map`.
    fn assert_all_gone(map: &Map<String, u32>) {
        for i in 0..COUNT {
            assert_eq!(map.get(&key(i)), None);
        }
    }

    #[test]
    fn insert() {
        let mut map = Map::<String, u32>::new();

        // (key, new value, value `insert` must hand back)
        let steps: [(&str, u32, Option<u32>); 9] = [
            ("A", 0, None),
            ("A", 1, Some(0)),
            ("B", 2, None),
            ("C", 3, None),
            ("B", 4, Some(2)),
            ("C", 5, Some(3)),
            ("B", 6, Some(4)),
            ("C", 7, Some(5)),
            ("A", 8, Some(1)),
        ];
        for &(k, v, previous) in steps.iter() {
            assert_eq!(map.insert(k.to_string(), v), previous);
        }

        let expected: [(&str, Option<&u32>); 4] =
            [("A", Some(&8)), ("B", Some(&6)), ("C", Some(&7)), ("D", None)];
        for &(k, value) in expected.iter() {
            assert_eq!(map.get(k), value);
        }
    }

    #[test]
    fn remove() {
        let mut map = populated();
        for i in 0..COUNT {
            assert_eq!(map.remove(&key(i)), Some(i));
        }
        assert_all_gone(&map);
    }

    #[test]
    fn remove_rev() {
        let mut map = populated();
        for i in (0..COUNT).rev() {
            assert_eq!(map.remove(&key(i)), Some(i));
        }
        assert_all_gone(&map);
    }
}