lean_imt/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2
3#[cfg(feature = "std")]
4use std::collections::HashMap;
5
6#[cfg(not(feature = "std"))]
7use hashbrown::HashMap;
8
9#[cfg(not(feature = "std"))]
10#[macro_use]
11extern crate alloc;
12
13#[cfg(not(feature = "std"))]
14use alloc::{
15    string::{String, ToString},
16    vec::Vec,
17};
18
/// A tree node (leaf or intermediate hash), represented as a string.
pub type IMTNode = String;

/// Hash function used to combine child nodes into their parent node.
///
/// The tree always calls it with exactly two children: `[left, right]`.
pub type IMTHashFunction = fn(Vec<IMTNode>) -> IMTNode;

/// Lean Incremental Merkle Tree.
///
/// Only the data needed to append new leaves is stored:
///
/// * `side_nodes[level]` for `level < depth` holds the left node of the
///   right-most pair at that level (the node the next insertion may have to
///   hash against); `side_nodes[depth]` holds the root.
/// * `leaves` maps every live leaf to its **1-based** position.
///
/// A node without a right sibling is propagated to the next level unchanged
/// instead of being hashed with a padding value. Removed leaves are replaced
/// by the literal node `"0"`.
#[derive(Debug)]
pub struct LeanIMT {
    /// Number of leaf slots ever allocated (removals do not shrink it).
    size: usize,
    /// Current depth; a tree with a single leaf has depth 0.
    depth: usize,
    /// Cached nodes per level, plus the root at index `depth`.
    side_nodes: HashMap<usize, IMTNode>,
    /// Live leaves and their 1-based index.
    leaves: HashMap<IMTNode, usize>,
    /// Function used to hash a `[left, right]` pair.
    hash: IMTHashFunction,
}

impl LeanIMT {
    /// Creates an empty tree that combines nodes with `hash`.
    pub fn new(hash: IMTHashFunction) -> Self {
        LeanIMT {
            size: 0,
            depth: 0,
            side_nodes: HashMap::new(),
            leaves: HashMap::new(),
            hash,
        }
    }

    /// Appends `leaf` to the tree and returns the new root.
    ///
    /// # Errors
    ///
    /// * `"Leaf already exists"` if the leaf is already stored.
    /// * `"Leaf cannot be zero"` if the leaf is `"0"` (reserved for removed leaves).
    ///
    /// # Panics
    ///
    /// Panics if a side node required by the tree invariant is missing.
    pub fn insert(&mut self, leaf: IMTNode) -> Result<IMTNode, &'static str> {
        if self.leaves.contains_key(&leaf) {
            return Err("Leaf already exists");
        }
        if leaf == "0" {
            return Err("Leaf cannot be zero");
        }

        let index = self.size;

        // Grow the tree by one level when the new leaf does not fit anymore.
        if (1usize << self.depth) < index + 1 {
            self.depth += 1;
        }
        let tree_depth = self.depth;

        let mut node = leaf.clone();
        for level in 0..tree_depth {
            if (index >> level) & 1 == 1 {
                // Right child: combine with the cached left sibling.
                let side_node = self
                    .side_nodes
                    .get(&level)
                    .cloned()
                    .expect("No side node at this level");
                node = (self.hash)(vec![side_node, node]);
            } else {
                // Left child without a sibling yet: cache it and propagate it
                // unchanged to the next level.
                self.side_nodes.insert(level, node.clone());
            }
        }

        self.size = index + 1;
        self.side_nodes.insert(tree_depth, node.clone());
        // Leaf positions are stored 1-based.
        self.leaves.insert(leaf, index + 1);

        Ok(node)
    }

    /// Appends all `leaves` in order and returns the new root.
    ///
    /// The whole batch is validated before the tree is touched, so a failed
    /// call leaves the tree unchanged.
    ///
    /// # Errors
    ///
    /// * `"No leaves to insert"` if `leaves` is empty.
    /// * `"Leaf already exists"` if a leaf is already stored or appears more
    ///   than once in the batch.
    /// * `"Leaf cannot be zero"` if a leaf is `"0"`.
    pub fn insert_many(&mut self, leaves: Vec<IMTNode>) -> Result<IMTNode, &'static str> {
        if leaves.is_empty() {
            return Err("No leaves to insert");
        }

        // Validate against the stored leaves and against the batch itself.
        {
            let mut seen: HashMap<&IMTNode, ()> = HashMap::with_capacity(leaves.len());
            for leaf in &leaves {
                if self.leaves.contains_key(leaf) || seen.insert(leaf, ()).is_some() {
                    return Err("Leaf already exists");
                }
                if leaf.as_str() == "0" {
                    return Err("Leaf cannot be zero");
                }
            }
        }

        let tree_size = self.size;
        let new_size = tree_size + leaves.len();

        // Grow the tree until every new leaf fits.
        while (1usize << self.depth) < new_size {
            self.depth += 1;
        }
        let tree_depth = self.depth;

        // Leaf positions are stored 1-based.
        for (offset, leaf) in leaves.iter().enumerate() {
            self.leaves.insert(leaf.clone(), tree_size + offset + 1);
        }

        // `current_level_new_nodes` holds the nodes of the current level at
        // positions `current_level_start_index..current_level_size`.
        let mut current_level_new_nodes = leaves;
        let mut current_level_start_index = tree_size;
        let mut current_level_size = new_size;
        let mut next_level_start_index = current_level_start_index >> 1;
        let mut next_level_size = ((current_level_size - 1) >> 1) + 1;

        for level in 0..tree_depth {
            let number_of_new_nodes = next_level_size - next_level_start_index;
            let mut next_level_new_nodes = Vec::with_capacity(number_of_new_nodes);

            for i in 0..number_of_new_nodes {
                // Absolute position of the left child on the current level.
                let left_position = (i + next_level_start_index) * 2;

                let left_node = if left_position < current_level_start_index {
                    // The left child predates this batch: it is the cached
                    // side node. Comparing first avoids a usize underflow.
                    self.side_nodes
                        .get(&level)
                        .cloned()
                        .unwrap_or_else(|| "0".to_string())
                } else {
                    current_level_new_nodes[left_position - current_level_start_index].clone()
                };

                // `left_position + 1 >= current_level_start_index` always holds,
                // so this subtraction cannot underflow.
                let right_node =
                    current_level_new_nodes.get(left_position + 1 - current_level_start_index);

                let parent_node = match right_node {
                    Some(right) => (self.hash)(vec![left_node, right.clone()]),
                    // No right sibling: the node moves up unchanged.
                    None => left_node,
                };

                next_level_new_nodes.push(parent_node);
            }

            // Cache the left node of the right-most pair of this level.
            let new_count = current_level_new_nodes.len();
            if (current_level_size & 1) == 1 {
                self.side_nodes
                    .insert(level, current_level_new_nodes[new_count - 1].clone());
            } else if new_count > 1 {
                self.side_nodes
                    .insert(level, current_level_new_nodes[new_count - 2].clone());
            }

            current_level_start_index = next_level_start_index;
            next_level_start_index >>= 1;

            current_level_new_nodes = next_level_new_nodes;
            current_level_size = next_level_size;
            next_level_size = ((next_level_size - 1) >> 1) + 1;
        }

        let root = current_level_new_nodes[0].clone();

        self.size = new_size;
        self.side_nodes.insert(tree_depth, root.clone());

        Ok(root)
    }

    /// Replaces `old_leaf` with `new_leaf` and returns the new root.
    ///
    /// `sibling_nodes` is the Merkle proof of `old_leaf`, ordered from the
    /// leaf level upwards; levels where the node has no sibling are skipped.
    /// The tree is modified only after the proof has been verified against
    /// the current root, so a failed call leaves the tree unchanged.
    ///
    /// # Errors
    ///
    /// * `"Leaf does not exist"` if `old_leaf` is not stored.
    /// * `"New leaf already exists"` if `new_leaf` is already stored.
    /// * `"Not enough sibling nodes"` if the proof is too short.
    /// * `"Wrong sibling nodes"` if the proof does not match the current root.
    pub fn update(
        &mut self,
        old_leaf: &IMTNode,
        new_leaf: IMTNode,
        sibling_nodes: &[IMTNode],
    ) -> Result<IMTNode, &'static str> {
        let index = self.index_of(old_leaf)?;
        if self.leaves.contains_key(&new_leaf) && new_leaf != "0" {
            return Err("New leaf already exists");
        }

        let last_index = self.size - 1;
        let tree_depth = self.depth;

        let mut node = new_leaf.clone();
        let mut old_node = old_leaf.clone();
        let mut siblings = sibling_nodes.iter();
        // Side-node writes are buffered until the proof has been verified.
        let mut pending_side_nodes: Vec<(usize, IMTNode)> = Vec::new();

        for level in 0..tree_depth {
            let position = index >> level;
            let last_position = last_index >> level;

            if (position & 1) == 1 {
                let sibling = siblings.next().ok_or("Not enough sibling nodes")?;
                node = (self.hash)(vec![sibling.clone(), node]);
                old_node = (self.hash)(vec![sibling.clone(), old_node]);
            } else {
                // The cached side node of a level is the left node of the
                // right-most pair. If the path runs through it, refresh it,
                // otherwise later insertions would hash against a stale value.
                if (position | 1) == (last_position | 1) {
                    pending_side_nodes.push((level, node.clone()));
                }

                if position != last_position {
                    let sibling = siblings.next().ok_or("Not enough sibling nodes")?;
                    node = (self.hash)(vec![node, sibling.clone()]);
                    old_node = (self.hash)(vec![old_node, sibling.clone()]);
                }
                // Otherwise this is the last node of the level and it has no
                // right sibling: it moves up unchanged.
            }
        }

        if self.root() != Some(old_node) {
            return Err("Wrong sibling nodes");
        }

        for (level, side_node) in pending_side_nodes {
            self.side_nodes.insert(level, side_node);
        }
        self.side_nodes.insert(tree_depth, node.clone());

        self.leaves.remove(old_leaf);
        if new_leaf != "0" {
            // Keep the 1-based position of the replaced leaf.
            self.leaves.insert(new_leaf, index + 1);
        }

        Ok(node)
    }

    /// Removes `old_leaf` by replacing it with the node `"0"`; returns the new root.
    ///
    /// # Errors
    ///
    /// Same as [`LeanIMT::update`].
    pub fn remove(
        &mut self,
        old_leaf: &IMTNode,
        sibling_nodes: &[IMTNode],
    ) -> Result<IMTNode, &'static str> {
        self.update(old_leaf, "0".to_string(), sibling_nodes)
    }

    /// Returns `true` if `leaf` is currently stored in the tree.
    pub fn has(&self, leaf: &IMTNode) -> bool {
        self.leaves.contains_key(leaf)
    }

    /// Returns the 0-based index of `leaf`.
    ///
    /// # Errors
    ///
    /// `"Leaf does not exist"` if the leaf is not stored.
    pub fn index_of(&self, leaf: &IMTNode) -> Result<usize, &'static str> {
        self.leaves
            .get(leaf)
            .map(|&index| index - 1)
            .ok_or("Leaf does not exist")
    }

    /// Returns the root of the tree, or `None` if nothing was ever inserted.
    pub fn root(&self) -> Option<IMTNode> {
        self.side_nodes.get(&self.depth).cloned()
    }

    /// Returns the number of leaf slots (removed leaves still count).
    pub fn get_size(&self) -> usize {
        self.size
    }

    /// Returns the current depth of the tree.
    pub fn get_depth(&self) -> usize {
        self.depth
    }

    /// Returns a copy of the cached side nodes (debugging aid).
    pub fn get_side_nodes(&self) -> HashMap<usize, IMTNode> {
        self.side_nodes.clone()
    }

    /// Returns a copy of the leaf map with 1-based positions (debugging aid).
    pub fn get_leaves(&self) -> HashMap<IMTNode, usize> {
        self.leaves.clone()
    }
}
289
#[cfg(all(feature = "std", test))]
mod tests {
    use super::*;

    /// Test "hash" that joins its inputs with commas so roots stay readable.
    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    /// Shorthand for building an owned node from a literal.
    fn node(text: &str) -> IMTNode {
        text.to_string()
    }

    /// Empty tree using the comma-joining hash.
    fn new_tree() -> LeanIMT {
        LeanIMT::new(simple_hash_function)
    }

    /// Tree populated by inserting `leaves` one at a time.
    fn tree_with(leaves: &[&str]) -> LeanIMT {
        let mut imt = new_tree();
        for leaf in leaves {
            imt.insert(node(leaf)).unwrap();
        }
        imt
    }

    #[test]
    fn test_new_lean_imt() {
        let imt = new_tree();

        assert_eq!(imt.size, 0);
        assert_eq!(imt.depth, 0);
        assert!(imt.root().is_none());
    }

    #[test]
    fn test_insert() {
        let mut imt = new_tree();

        assert!(imt.insert(node("leaf1")).is_ok());
        assert_eq!(imt.size, 1);
        assert_eq!(imt.depth, 0);
        assert!(imt.has(&node("leaf1")));
        assert_eq!(imt.root().unwrap(), node("leaf1"));
    }

    #[test]
    fn test_insert_() {
        let mut imt = new_tree();

        // (leaf, expected size, expected depth, expected root)
        let steps = [
            ("leaf1", 1, 0, "leaf1"),
            ("leaf2", 2, 1, "leaf1,leaf2"),
            ("leaf3", 3, 2, "leaf1,leaf2,leaf3"),
            ("leaf4", 4, 2, "leaf1,leaf2,leaf3,leaf4"),
            ("leaf5", 5, 3, "leaf1,leaf2,leaf3,leaf4,leaf5"),
        ];

        for (leaf, size, depth, root) in steps {
            assert!(imt.insert(node(leaf)).is_ok());
            assert_eq!(imt.size, size);
            assert_eq!(imt.depth, depth);
            assert!(imt.has(&node(leaf)));
            assert_eq!(imt.root().unwrap(), node(root));
        }
    }

    #[test]
    fn test_insert_many() {
        let mut imt = new_tree();

        let leaves = vec![node("leaf1"), node("leaf2"), node("leaf3")];
        assert!(imt.insert_many(leaves.clone()).is_ok());
        assert_eq!(imt.size, 3);
        assert_eq!(imt.depth, 2);
        assert!(leaves.iter().all(|leaf| imt.has(leaf)));

        // The third leaf has no sibling, so it is paired directly with the
        // hash of the first two.
        let left_pair = simple_hash_function(vec![leaves[0].clone(), leaves[1].clone()]);
        let expected_root = simple_hash_function(vec![left_pair, leaves[2].clone()]);
        assert_eq!(imt.root().unwrap(), expected_root);
    }

    #[test]
    fn test_insert_duplicate_leaf() {
        let mut imt = tree_with(&["leaf1"]);

        let result = imt.insert(node("leaf1"));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf already exists");
    }

    #[test]
    fn test_insert_many_with_duplicate_leaf() {
        let mut imt = tree_with(&["leaf1"]);

        let result = imt.insert_many(vec![node("leaf2"), node("leaf1")]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf already exists");
    }

    #[test]
    fn test_update() {
        let mut imt = tree_with(&["leaf1"]);

        let outcome = imt.update(&node("leaf1"), node("new_leaf1"), &[]);
        assert!(outcome.is_ok());
        assert!(imt.has(&node("new_leaf1")));
        assert!(!imt.has(&node("leaf1")));
        assert_eq!(imt.root().unwrap(), node("new_leaf1"));
    }

    #[test]
    fn test_update_nonexistent_leaf() {
        let mut imt = new_tree();

        let result = imt.update(&node("nonexistent_leaf"), node("new_leaf"), &[]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf does not exist");
    }

    #[test]
    fn test_remove() {
        let mut imt = tree_with(&["leaf1"]);

        assert!(imt.remove(&node("leaf1"), &[]).is_ok());
        assert!(!imt.has(&node("leaf1")));
        assert_eq!(imt.root().unwrap(), node("0"));
    }

    #[test]
    fn test_remove_nonexistent_leaf() {
        let mut imt = new_tree();

        let result = imt.remove(&node("nonexistent_leaf"), &[]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf does not exist");
    }

    #[test]
    fn test_has_and_index_of() {
        let mut imt = new_tree();
        let leaf = node("leaf1");

        assert!(!imt.has(&leaf));
        assert!(imt.index_of(&leaf).is_err());

        imt.insert(leaf.clone()).unwrap();
        assert!(imt.has(&leaf));
        assert_eq!(imt.index_of(&leaf).unwrap(), 0);
    }

    #[test]
    fn test_root_after_operations() {
        let mut imt = new_tree();

        // Nothing inserted yet, so there is no root.
        assert!(imt.root().is_none());

        imt.insert(node("leaf1")).unwrap();
        let root_after_leaf1 = imt.root().unwrap();

        imt.insert(node("leaf2")).unwrap();
        let root_after_leaf2 = imt.root().unwrap();
        assert_ne!(root_after_leaf1, root_after_leaf2);

        // Removing leaf1 replaces it with "0".
        imt.remove(&node("leaf1"), &[node("leaf2")]).unwrap();
        assert_eq!(imt.root().unwrap(), node("0,leaf2"));

        // leaf2's sibling is now the zeroed slot.
        imt.update(&node("leaf2"), node("leaf3"), &[node("0")])
            .unwrap();
        assert_eq!(imt.root().unwrap(), node("0,leaf3"));
    }

    #[test]
    fn test_tree_consistency() {
        let mut imt = tree_with(&["leaf1", "leaf2", "leaf3", "leaf4"]);
        let root_before = imt.root().unwrap();

        // Update leaf2: siblings are leaf1 and the hash of the right pair.
        let proof_for_leaf2 = vec![
            node("leaf1"),
            simple_hash_function(vec![node("leaf3"), node("leaf4")]),
        ];
        imt.update(&node("leaf2"), node("leaf2_updated"), &proof_for_leaf2)
            .unwrap();

        let root_after = imt.root().unwrap();
        assert_ne!(root_before, root_after);

        // Remove leaf3: siblings are leaf4 and the hash of the updated left pair.
        let proof_for_leaf3 = vec![
            node("leaf4"),
            simple_hash_function(vec![node("leaf1"), node("leaf2_updated")]),
        ];
        imt.remove(&node("leaf3"), &proof_for_leaf3).unwrap();

        let root_after_removal = imt.root().unwrap();
        assert_ne!(root_after, root_after_removal);

        // Membership reflects both operations.
        assert!(imt.has(&node("leaf1")));
        assert!(imt.has(&node("leaf2_updated")));
        assert!(!imt.has(&node("leaf2")));
        assert!(!imt.has(&node("leaf3")));
        assert!(imt.has(&node("leaf4")));
    }

    #[test]
    fn test_large_number_of_leaves() {
        // Stand-in hash that just records how nodes were combined.
        let hash: IMTHashFunction = |nodes: Vec<String>| format!("H({})", nodes.join("+"));
        let mut imt = LeanIMT::new(hash);

        let leaves: Vec<_> = (1..=100).map(|i| format!("leaf{}", i)).collect();
        assert!(imt.insert_many(leaves.clone()).is_ok());
        assert_eq!(imt.size, 100);

        assert!(leaves.iter().all(|leaf| imt.has(leaf)));

        // Depth must be ceil(log2(number of leaves)).
        let expected_depth = 100f64.log2().ceil() as usize;
        assert_eq!(imt.depth, expected_depth);
    }

    #[test]
    fn test_insertion_after_removal() {
        let mut imt = tree_with(&["leaf1", "leaf2"]);

        imt.remove(&node("leaf1"), &[node("leaf2")]).unwrap();

        assert!(imt.insert(node("leaf3")).is_ok());

        assert!(!imt.has(&node("leaf1")));
        assert!(imt.has(&node("leaf2")));
        assert!(imt.has(&node("leaf3")));
    }

    #[test]
    fn test_tree_after_all_leaves_removed() {
        let mut imt = tree_with(&["leaf1", "leaf2"]);

        imt.remove(&node("leaf1"), &[node("leaf2")]).unwrap();
        imt.remove(&node("leaf2"), &[node("0")]).unwrap();

        // Removal zeroes the slots but keeps the tree shape.
        assert_eq!(imt.size, 2);
        assert_eq!(imt.depth, 1);
        assert_eq!(imt.root().unwrap(), node("0,0"));
        assert!(!imt.has(&node("leaf1")));
        assert!(!imt.has(&node("leaf2")));
    }

    #[test]
    fn test_insert_after_tree_becomes_empty() {
        let mut imt = tree_with(&["leaf1"]);
        imt.remove(&node("leaf1"), &[]).unwrap();

        assert!(imt.insert(node("leaf2")).is_ok());
        assert!(imt.has(&node("leaf2")));
        assert_eq!(imt.root().unwrap(), node("0,leaf2"));
    }

    #[test]
    fn test_insertion_causes_depth_increase() {
        let mut imt = new_tree();

        // (leaf, depth expected right after inserting it)
        let steps = [
            ("leaf1", 0),
            ("leaf2", 1),
            ("leaf3", 2),
            ("leaf4", 2),
            ("leaf5", 3),
        ];

        for (leaf, depth) in steps {
            imt.insert(node(leaf)).unwrap();
            assert_eq!(imt.depth, depth);
        }
    }

    #[test]
    fn test_invalid_sibling_nodes_on_update() {
        let mut imt = tree_with(&["leaf1", "leaf2"]);

        let result = imt.update(
            &node("leaf1"),
            node("leaf1_updated"),
            &[node("wrong_sibling")],
        );
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Wrong sibling nodes");
    }

    #[test]
    fn test_invalid_sibling_nodes_on_remove() {
        let mut imt = tree_with(&["leaf1", "leaf2"]);

        let result = imt.remove(&node("leaf1"), &[node("wrong_sibling")]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Wrong sibling nodes");
    }
}