// certified_vars/collections/group.rs

1use crate::hashtree::HashTree::Pruned;
2use crate::hashtree::{fork_hash, labeled_hash, ForkInner};
3use crate::{AsHashTree, Hash, HashTree};
4use std::any::{Any, TypeId};
5use std::borrow::Cow;
6use std::collections::{HashMap, HashSet};
7use std::fmt::Debug;
8
/// Builder API (e.g. `GroupBuilder`) used to assemble a [`Group`] from labeled paths.
pub mod builder;

/// Identifier of a node in a group's shape tree; ids are assigned in pre-order,
/// starting at 0, by `Group::init`.
type NodeId = u64;
12
/// Group is a utility structure to make it easier to deal with multiple nested
/// certified data in one canister.
///
/// Leaves are keyed by the [`TypeId`] of their concrete type, so a group stores at
/// most one leaf per Rust type. Use [`Group::get`] / [`Group::get_mut`] to reach a
/// leaf and [`Group::witness`] to build a [`HashTree`] that reveals only selected
/// leaves.
pub struct Group {
    /// The root node of the group is a shadow of the shape of the group's tree.
    root: GroupNode,
    /// The data in this group, keyed by the `TypeId` of each leaf's concrete type.
    data: HashMap<TypeId, Box<dyn GroupLeaf>>,
    /// Map each typeId used in a Leaf node to all of its ancestors. The stored path
    /// runs from the root and also ends with the leaf's own id; it is rebuilt by
    /// `Group::init`.
    dependencies: HashMap<TypeId, Vec<NodeId>>,
}
23
/// Witness builder for a [`Group`], created by [`Group::witness`].
///
/// Leaves requested through [`Ray::full`] or [`Ray::partial`] are revealed by
/// [`Ray::build`]. Every subtree that is not on the path to a requested leaf is
/// replaced by a pruned node that carries only its root hash.
pub struct Ray<'a> {
    /// The group this ray belongs to.
    group: &'a Group,
    /// The union of all the ancestors of nodes that we're interested in (the ids of
    /// the requested leaves themselves are included too).
    to_visit: HashSet<NodeId>,
    /// The [`HashTree`] that should be used for each leaf that we're interested
    /// in.
    leaves: HashMap<TypeId, HashTree<'a>>,
}
33
/// One node of the shape tree that mirrors how the group's [`HashTree`] is laid out.
#[derive(Debug)]
struct GroupNode {
    /// Id assigned by `GroupNode::visit_node`; unique within the group once
    /// `Group::init` has run.
    id: NodeId,
    /// The kind of node and its children.
    data: GroupNodeInner,
}
39
/// Shape of a [`GroupNode`]. `Fork` and `Labeled` become the matching [`HashTree`]
/// nodes when a witness is built. A `Leaf` is replaced by the hash tree of the
/// group data stored under its `TypeId`.
#[derive(Debug)]
enum GroupNodeInner {
    /// Binary fork with a left and a right subtree.
    Fork(Box<GroupNode>, Box<GroupNode>),
    /// Subtree placed under a string label.
    Labeled(String, Box<GroupNode>),
    /// Leaf whose contents are the group data keyed by this `TypeId`.
    Leaf(TypeId),
}
46
47impl Group {
48    /// Visit all the nodes recursively and assign the ID and extract the dependencies.
49    fn init(&mut self) {
50        self.dependencies.clear();
51        let mut path = Vec::with_capacity(16);
52        self.root.visit_node(0, &mut self.dependencies, &mut path);
53    }
54
55    /// Create a new witness builder that can be used to generate a [`HashTree`] for
56    /// the entire group.
57    #[must_use = "This method does not have any effects on the group."]
58    pub fn witness(&self) -> Ray {
59        Ray::new(self)
60    }
61
62    /// Returns a mutable reference to the leaf node with the given type.
63    ///
64    /// # Panics
65    ///
66    /// This method panics if the group does not contain any leaf nodes with the given
67    /// type.
68    pub fn get_mut<T: GroupLeaf>(&mut self) -> &mut T {
69        let tid = TypeId::of::<T>();
70        self.data
71            .get_mut(&tid)
72            .expect("Group does not contain the type")
73            .downcast_mut()
74            .unwrap()
75    }
76
77    /// Returns a reference to the leaf node with the given type.
78    ///
79    /// # Panics
80    ///
81    /// This method panics if the group does not contain any leaf nodes with the given
82    /// type.
83    pub fn get<T: GroupLeaf>(&self) -> &T {
84        let tid = TypeId::of::<T>();
85        self.data
86            .get(&tid)
87            .expect("Group does not contain the type")
88            .downcast_ref()
89            .unwrap()
90    }
91}
92
93impl GroupNode {
94    /// Assign the ID of this node, this will recursively update the ID of all the child nodes.
95    #[inline]
96    fn visit_node(
97        &mut self,
98        id: NodeId,
99        dependencies: &mut HashMap<TypeId, Vec<NodeId>>,
100        path: &mut Vec<NodeId>,
101    ) -> NodeId {
102        match &mut self.data {
103            GroupNodeInner::Fork(left, right) => {
104                self.id = id;
105                path.push(self.id);
106                let next_id = left.visit_node(id + 1, dependencies, path);
107                let next_id = right.visit_node(next_id, dependencies, path);
108                path.pop();
109                next_id
110            }
111            GroupNodeInner::Leaf(tid) => {
112                path.push(id);
113                dependencies.insert(*tid, path.clone());
114                path.pop();
115                self.id = id;
116                id + 1
117            }
118            GroupNodeInner::Labeled(_, node) => {
119                path.push(id);
120                let next_id = node.visit_node(id + 1, dependencies, path);
121                path.pop();
122                self.id = id;
123                next_id
124            }
125        }
126    }
127
128    fn witness<'r>(&'r self, ray: &mut Ray<'r>) -> HashTree<'r> {
129        if !ray.to_visit.contains(&self.id) {
130            return Pruned(self.root_hash(ray.group));
131        }
132
133        match &self.data {
134            GroupNodeInner::Fork(left, right) => {
135                let l_tree = left.witness(ray);
136                let r_tree = right.witness(ray);
137                HashTree::Fork(Box::new(ForkInner(l_tree, r_tree)))
138            }
139            GroupNodeInner::Labeled(label, n) => {
140                let tree = n.witness(ray);
141                HashTree::Labeled(Cow::Borrowed(label.as_bytes()), Box::new(tree))
142            }
143            GroupNodeInner::Leaf(tid) => ray.leaves.remove(tid).unwrap(),
144        }
145    }
146
147    fn witness_all<'a>(&'a self, group: &'a Group) -> HashTree<'a> {
148        match &self.data {
149            GroupNodeInner::Fork(left, right) => {
150                let l_tree = left.witness_all(group);
151                let r_tree = right.witness_all(group);
152                HashTree::Fork(Box::new(ForkInner(l_tree, r_tree)))
153            }
154            GroupNodeInner::Labeled(label, n) => {
155                let tree = n.witness_all(group);
156                HashTree::Labeled(Cow::Borrowed(label.as_bytes()), Box::new(tree))
157            }
158            GroupNodeInner::Leaf(tid) => group.data.get(tid).unwrap().as_hash_tree(),
159        }
160    }
161
162    fn root_hash(&self, group: &Group) -> Hash {
163        match &self.data {
164            GroupNodeInner::Fork(left, right) => {
165                fork_hash(&left.root_hash(group), &right.root_hash(group))
166            }
167            GroupNodeInner::Labeled(label, node) => {
168                let hash = node.root_hash(group);
169                labeled_hash(label.as_bytes(), &hash)
170            }
171            GroupNodeInner::Leaf(id) => group.data.get(id).unwrap().root_hash(),
172        }
173    }
174}
175
176impl<'a> Ray<'a> {
177    fn new(group: &'a Group) -> Self {
178        Self {
179            group,
180            to_visit: HashSet::with_capacity(16),
181            leaves: HashMap::with_capacity(8),
182        }
183    }
184
185    #[must_use = "Computing a HashTree is a compute heavy operation, with zero effects on the Group."]
186    pub fn build(mut self) -> HashTree<'a> {
187        self.group.root.witness(&mut self)
188    }
189
190    #[must_use]
191    pub fn full<T: GroupLeaf + 'static>(mut self) -> Self {
192        let tid = TypeId::of::<T>();
193
194        for dep in self.group.dependencies.get(&tid).unwrap() {
195            self.to_visit.insert(*dep);
196        }
197
198        let tree = self.group.data.get(&tid).unwrap().as_hash_tree();
199        self.leaves.insert(tid, tree);
200
201        self
202    }
203
204    #[must_use]
205    pub fn partial<T: GroupLeaf + 'static, F: FnOnce(&T) -> HashTree>(mut self, f: F) -> Self {
206        let tid = TypeId::of::<T>();
207
208        for dep in self.group.dependencies.get(&tid).unwrap() {
209            self.to_visit.insert(*dep);
210        }
211
212        let data = self.group.data.get(&tid).unwrap();
213        let tree = f(data.downcast_ref().unwrap());
214        self.leaves.insert(tid, tree);
215
216        self
217    }
218}
219
220pub trait GroupLeaf: Any + AsHashTree {}
221impl<T: Any + AsHashTree> GroupLeaf for T {}
222
223impl dyn GroupLeaf {
224    pub fn is<T: GroupLeaf>(&self) -> bool {
225        let t = TypeId::of::<T>();
226        let concrete = self.type_id();
227        t == concrete
228    }
229
230    pub fn downcast_ref<T: GroupLeaf>(&self) -> Option<&T> {
231        if self.is::<T>() {
232            unsafe { Some(&*(self as *const dyn GroupLeaf as *const T)) }
233        } else {
234            None
235        }
236    }
237
238    pub fn downcast_mut<T: GroupLeaf>(&mut self) -> Option<&mut T> {
239        if self.is::<T>() {
240            unsafe { Some(&mut *(self as *mut dyn GroupLeaf as *mut T)) }
241        } else {
242            None
243        }
244    }
245}
246
247impl AsHashTree for Group {
248    fn root_hash(&self) -> Hash {
249        self.root.root_hash(self)
250    }
251
252    fn as_hash_tree(&self) -> HashTree<'_> {
253        self.root.witness_all(self)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::builder::GroupBuilder;
    use super::*;
    use crate::Map;
    use candid::Principal;

    /// A leaf boxed as `dyn GroupLeaf` can be downcast back to its concrete type and
    /// still produces the same root hash.
    #[test]
    fn dynamic_box() {
        let mut map = Map::<String, i8>::new();
        map.insert("X".to_string(), 17);
        let hash = map.root_hash();
        // Now put it in a dynamic box.
        let data: Box<dyn GroupLeaf> = Box::new(map);
        let as_map = data.downcast_ref::<Map<String, i8>>().unwrap();
        assert_eq!(hash, as_map.root_hash());
    }

    /// Builds a group by hand: a fork of a map labeled "A" and a string leaf. Checks
    /// that every witness reconstructs to the same root hash and reveals only the
    /// requested leaves.
    #[test]
    fn ray() {
        type S2S = Map<String, String>;
        let mut map = S2S::new();
        map.insert("X".to_string(), "x".to_string());
        map.insert("Y".to_string(), "y".to_string());

        // All ids start at 0 here; `group.init()` below assigns the real node ids
        // and fills `dependencies`.
        let mut group = Group {
            root: GroupNode {
                id: 0,
                data: GroupNodeInner::Fork(
                    Box::new(GroupNode {
                        id: 0,
                        data: GroupNodeInner::Labeled(
                            "A".into(),
                            Box::new(GroupNode {
                                id: 0,
                                data: GroupNodeInner::Leaf(TypeId::of::<S2S>()),
                            }),
                        ),
                    }),
                    Box::new(GroupNode {
                        id: 0,
                        data: GroupNodeInner::Leaf(TypeId::of::<String>()),
                    }),
                ),
            },
            data: Default::default(),
            dependencies: Default::default(),
        };

        // NOTE(review): the `String` slot holds a `&'static str`, not a `String`.
        // `group.get::<String>()` would therefore panic on the downcast. This test
        // only reaches the value through `as_hash_tree`, which does not downcast.
        group.data.insert(TypeId::of::<String>(), Box::new("Cap"));
        group.data.insert(TypeId::of::<S2S>(), Box::new(map));
        group.init();

        // t1 reveals nothing, t2/t3 reveal one full leaf each, t4 reveals only key "X".
        let t1 = group.witness().build();
        let t2 = group.witness().full::<String>().build();
        let t3 = group.witness().full::<S2S>().build();
        let t4 = group
            .witness()
            .partial(|map: &S2S| map.witness("X"))
            .build();

        // Pruning must never change the root hash.
        assert_eq!(t1.reconstruct(), t2.reconstruct());
        assert_eq!(t1.reconstruct(), t3.reconstruct());
        assert_eq!(t1.reconstruct(), t4.reconstruct());

        // Only labels on the path to a revealed leaf are visible.
        assert_eq!(t1.get_labels(), Vec::<&[u8]>::new());
        assert_eq!(t2.get_labels(), Vec::<&[u8]>::new());
        assert_eq!(t3.get_labels(), vec![b"A", b"X", b"Y"]);
        assert_eq!(t4.get_labels(), vec![b"A", b"X"]);

        // Only revealed leaves expose their values.
        assert_eq!(t1.get_leaf_values(), Vec::<&[u8]>::new());
        assert_eq!(t2.get_leaf_values(), vec![b"Cap"]);
        assert_eq!(t3.get_leaf_values(), vec![b"x", b"y"]);
        assert_eq!(t4.get_leaf_values(), vec![b"x"]);
    }

    /// Same properties as `ray`, but for a group assembled through `GroupBuilder`,
    /// including witnesses that combine leaves from different branches.
    #[test]
    fn builder() {
        type Ledger = Map<Principal, u64>;
        // Newtypes so that each string leaf has its own distinct `TypeId`.
        struct Name(String);
        struct Owner(String);
        struct Url(String);

        impl AsHashTree for Name {
            fn as_hash_tree(&self) -> HashTree<'_> {
                self.0.as_hash_tree()
            }
        }

        impl AsHashTree for Owner {
            fn as_hash_tree(&self) -> HashTree<'_> {
                self.0.as_hash_tree()
            }
        }

        impl AsHashTree for Url {
            fn as_hash_tree(&self) -> HashTree<'_> {
                self.0.as_hash_tree()
            }
        }

        let mut group = GroupBuilder::new()
            .insert(["ledger"], Ledger::new())
            .insert(["meta", "name"], Name("XTC".to_string()))
            .insert(["meta", "owner"], Owner("Psychedelic".to_string()))
            .insert(["canister", "url"], Url("https://github.com/x".to_string()))
            .build();

        {
            // Principal bytes [65] are the ASCII byte `A`, hence the `A` label below.
            let ledger = group.get_mut::<Ledger>();
            ledger.insert(Principal::from_slice(&[65]), 100);
        }

        let t1 = group.witness().full::<Ledger>().build();
        let t2 = group.witness().full::<Name>().build();
        let t3 = group.witness().full::<Owner>().build();
        let t4 = group.witness().full::<Name>().full::<Owner>().build();
        let t5 = group.witness().full::<Name>().full::<Url>().build();

        // Every witness reconstructs to the group's root hash.
        assert_eq!(group.root_hash(), t1.reconstruct());
        assert_eq!(t1.reconstruct(), t2.reconstruct());
        assert_eq!(t2.reconstruct(), t3.reconstruct());
        assert_eq!(t3.reconstruct(), t4.reconstruct());
        assert_eq!(t4.reconstruct(), t5.reconstruct());

        assert_eq!(t1.get_labels(), vec![b"ledger" as &[u8], b"A"]);
        assert_eq!(t2.get_labels(), vec![b"meta" as &[u8], b"name"]);
        assert_eq!(t3.get_labels(), vec![b"meta" as &[u8], b"owner"]);
        assert_eq!(t4.get_labels(), vec![b"meta" as &[u8], b"name", b"owner"]);
        assert_eq!(
            t5.get_labels(),
            vec![b"canister" as &[u8], b"url", b"meta", b"name"]
        );
    }
}
389        );
390    }
391}