// lean_imt/lib.rs

#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;

#[cfg(not(feature = "std"))]
#[macro_use]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::{
    string::{String, ToString},
    vec::Vec,
};

/// Value of a tree node: a leaf, an intermediate node, or the root.
///
/// The string `"0"` is reserved: it cannot be inserted as a leaf, and it is the
/// value a leaf is replaced with by [`LeanIMT::remove`].
pub type IMTNode = String;

/// Function that combines child nodes into their parent node.
///
/// [`LeanIMT`] always calls it with exactly two nodes, `[left, right]`.
pub type IMTHashFunction = fn(Vec<IMTNode>) -> IMTNode;

/// A Lean Incremental Merkle Tree (LeanIMT).
///
/// Leaves are appended left to right. A parent without a right child takes its
/// left child's value instead of being hashed, so the tree needs no zero
/// padding. Only the right edge of the tree is stored. For each level,
/// `side_nodes` caches the node that an incoming right child at that level is
/// hashed with, and `side_nodes[depth]` is the root.
#[derive(Debug)]
pub struct LeanIMT {
    /// Number of leaf positions filled so far; removals do not decrease it.
    size: usize,
    /// Current depth; insertions may increase it, nothing decreases it.
    depth: usize,
    /// Right-edge nodes keyed by level; the entry at `depth` is the root.
    side_nodes: HashMap<usize, IMTNode>,
    /// Each live leaf mapped to its 1-based position (see [`LeanIMT::index_of`]).
    leaves: HashMap<IMTNode, usize>,
    /// Two-to-one hash function.
    hash: IMTHashFunction,
}

impl LeanIMT {
    /// Creates an empty tree whose nodes are combined with `hash`.
    pub fn new(hash: IMTHashFunction) -> Self {
        LeanIMT {
            size: 0,
            depth: 0,
            side_nodes: HashMap::new(),
            leaves: HashMap::new(),
            hash,
        }
    }

    /// Appends `leaf` at the next free position and returns the new root.
    ///
    /// # Errors
    ///
    /// * `"Leaf already exists"` if `leaf` is already in the tree.
    /// * `"Leaf cannot be zero"` if `leaf` is the reserved value `"0"`.
    ///
    /// # Panics
    ///
    /// Panics if the side node needed to hash a right child is missing, which
    /// would mean the internal state is inconsistent.
    pub fn insert(&mut self, leaf: IMTNode) -> Result<IMTNode, &'static str> {
        if self.leaves.contains_key(&leaf) {
            return Err("Leaf already exists");
        }
        if leaf == "0" {
            return Err("Leaf cannot be zero");
        }

        let index = self.size;

        // A single insertion adds at most one level: exactly when the current
        // depth cannot hold `index + 1` leaves.
        if (1 << self.depth) < index + 1 {
            self.depth += 1;
        }
        let tree_depth = self.depth;

        let mut node = leaf.clone();
        for level in 0..tree_depth {
            if (index >> level) & 1 == 1 {
                // Right child: hash with the cached left sibling.
                let side_node = self
                    .side_nodes
                    .get(&level)
                    .cloned()
                    .expect("No side node at this level");
                node = (self.hash)(vec![side_node, node]);
            } else {
                // Left child: cache it for the insertion that will become its
                // right sibling. It has no right sibling yet, so its parent takes
                // its value unchanged and we must keep climbing to the root
                // (stopping here would leave a stale root and stale side nodes).
                self.side_nodes.insert(level, node.clone());
            }
        }

        self.size = index + 1;
        self.side_nodes.insert(tree_depth, node.clone());
        self.leaves.insert(leaf, self.size);

        Ok(node)
    }

    /// Appends all `leaves` in order and returns the new root.
    ///
    /// Produces the same tree as calling [`LeanIMT::insert`] once per leaf, but
    /// hashes each level of the changed region only once.
    ///
    /// # Errors
    ///
    /// The tree is left unchanged when an error is returned:
    /// * `"There are no leaves to add"` if `leaves` is empty.
    /// * `"Leaf already exists"` if a leaf is already in the tree or appears
    ///   more than once in `leaves`.
    /// * `"Leaf cannot be zero"` if a leaf is the reserved value `"0"`.
    pub fn insert_many(&mut self, leaves: Vec<IMTNode>) -> Result<IMTNode, &'static str> {
        self.validate_new_leaves(&leaves)?;

        let tree_size = self.size;
        let new_size = tree_size + leaves.len();

        // Unlike `insert`, a batch can add more than one level.
        let mut tree_depth = self.depth;
        while (1 << tree_depth) < new_size {
            tree_depth += 1;
        }
        self.depth = tree_depth;

        // Record 1-based leaf positions now, before `leaves` becomes level 0.
        for (offset, leaf) in leaves.iter().enumerate() {
            self.leaves.insert(leaf.clone(), tree_size + offset + 1);
        }

        // Only the changed suffix of each level is materialized. The first
        // entry of `current_level_new_nodes` sits at `current_level_start_index`
        // within a level of `current_level_size` nodes.
        let mut current_level_new_nodes = leaves;
        let mut current_level_start_index = tree_size;
        let mut current_level_size = new_size;
        let mut next_level_start_index = current_level_start_index >> 1;
        let mut next_level_size = ((current_level_size - 1) >> 1) + 1;

        for level in 0..tree_depth {
            let mut next_level_new_nodes =
                Vec::with_capacity(next_level_size - next_level_start_index);

            for parent_index in next_level_start_index..next_level_size {
                let left_index = parent_index * 2;
                let right_index = left_index + 1;

                // A left child before the changed suffix is unchanged and is read
                // from the side node. Compare before subtracting, so the offset
                // into the new nodes never underflows.
                let left_node = if left_index < current_level_start_index {
                    self.side_nodes
                        .get(&level)
                        .cloned()
                        .unwrap_or_else(|| "0".to_string())
                } else {
                    current_level_new_nodes[left_index - current_level_start_index].clone()
                };

                let right_node = if right_index < current_level_size {
                    current_level_new_nodes[right_index - current_level_start_index].clone()
                } else {
                    "0".to_string()
                };

                // A parent without a right child takes its left child's value.
                let parent_node = if right_node != "0" {
                    (self.hash)(vec![left_node, right_node])
                } else {
                    left_node
                };

                next_level_new_nodes.push(parent_node);
            }

            // Refresh this level's side node. With an odd level size it is the
            // last node. With an even size it is the second-to-last node, if that
            // node changed. Otherwise the cached value is still correct.
            let changed = current_level_new_nodes.len();
            if current_level_size & 1 == 1 {
                self.side_nodes
                    .insert(level, current_level_new_nodes[changed - 1].clone());
            } else if changed > 1 {
                self.side_nodes
                    .insert(level, current_level_new_nodes[changed - 2].clone());
            }

            // Move one level up: a node's parent sits at half its index.
            current_level_start_index = next_level_start_index;
            next_level_start_index >>= 1;
            current_level_new_nodes = next_level_new_nodes;
            current_level_size = next_level_size;
            next_level_size = ((next_level_size - 1) >> 1) + 1;
        }

        // After the last level exactly one node remains: the new root.
        let root = current_level_new_nodes[0].clone();
        self.size = new_size;
        self.side_nodes.insert(tree_depth, root.clone());

        Ok(root)
    }

    /// Checks a batch for [`LeanIMT::insert_many`] without modifying the tree.
    ///
    /// Rejects an empty batch, zero leaves, and leaves that already exist in
    /// the tree or appear more than once in the batch.
    fn validate_new_leaves(&self, leaves: &[IMTNode]) -> Result<(), &'static str> {
        if leaves.is_empty() {
            return Err("There are no leaves to add");
        }

        let mut seen: HashMap<&IMTNode, ()> = HashMap::with_capacity(leaves.len());
        for leaf in leaves {
            if self.leaves.contains_key(leaf) || seen.insert(leaf, ()).is_some() {
                return Err("Leaf already exists");
            }
            if leaf == "0" {
                return Err("Leaf cannot be zero");
            }
        }
        Ok(())
    }

    /// Replaces `old_leaf` with `new_leaf` and returns the new root.
    ///
    /// `sibling_nodes` are the siblings on the path from the leaf to the root,
    /// bottom-up. A level is skipped when the node on the path is the last node
    /// of its level and has no right sibling. The siblings are checked by
    /// recomputing the current root from `old_leaf`.
    /// A `new_leaf` of `"0"` removes the leaf (see [`LeanIMT::remove`]).
    ///
    /// # Errors
    ///
    /// The tree is left unchanged when an error is returned:
    /// * `"Leaf does not exist"` if `old_leaf` is not in the tree.
    /// * `"New leaf already exists"` if `new_leaf` is already in the tree.
    /// * `"Not enough sibling nodes"` if `sibling_nodes` is too short.
    /// * `"Wrong sibling nodes"` if the siblings do not reproduce the root.
    pub fn update(
        &mut self,
        old_leaf: &IMTNode,
        new_leaf: IMTNode,
        sibling_nodes: &[IMTNode],
    ) -> Result<IMTNode, &'static str> {
        let position = *self.leaves.get(old_leaf).ok_or("Leaf does not exist")?;
        if self.leaves.contains_key(&new_leaf) && new_leaf != "0" {
            return Err("New leaf already exists");
        }

        let index = position - 1;
        let last_index = self.size - 1;
        let tree_depth = self.depth;

        let mut node = new_leaf.clone();
        let mut old_root = old_leaf.clone();
        let mut siblings = sibling_nodes.iter();
        // Side-node changes are applied only after the proof is verified, so a
        // rejected update cannot corrupt the tree.
        let mut side_node_updates = Vec::new();

        for level in 0..tree_depth {
            if (index >> level) & 1 == 1 {
                let sibling = siblings.next().ok_or("Not enough sibling nodes")?;
                node = (self.hash)(vec![sibling.clone(), node]);
                old_root = (self.hash)(vec![sibling.clone(), old_root]);
            } else if (index >> level) != (last_index >> level) {
                let sibling = siblings.next().ok_or("Not enough sibling nodes")?;
                node = (self.hash)(vec![node, sibling.clone()]);
                old_root = (self.hash)(vec![old_root, sibling.clone()]);
            } else {
                // Last node of its level with no right sibling: the node moves
                // up unchanged and becomes this level's side node.
                side_node_updates.push((level, node.clone()));
            }
        }

        if self.side_nodes.get(&tree_depth) != Some(&old_root) {
            return Err("Wrong sibling nodes");
        }

        for (level, side_node) in side_node_updates {
            self.side_nodes.insert(level, side_node);
        }
        self.side_nodes.insert(tree_depth, node.clone());

        if new_leaf != "0" {
            self.leaves.insert(new_leaf, position);
        }
        self.leaves.remove(old_leaf);

        Ok(node)
    }

    /// Removes `old_leaf` by replacing it with `"0"` and returns the new root.
    ///
    /// Size and depth are unchanged. `sibling_nodes` and the errors are as for
    /// [`LeanIMT::update`].
    pub fn remove(
        &mut self,
        old_leaf: &IMTNode,
        sibling_nodes: &[IMTNode],
    ) -> Result<IMTNode, &'static str> {
        self.update(old_leaf, "0".to_string(), sibling_nodes)
    }

    /// Returns `true` if `leaf` is currently in the tree.
    pub fn has(&self, leaf: &IMTNode) -> bool {
        self.leaves.contains_key(leaf)
    }

    /// Returns the 0-based position of `leaf`, or `Err("Leaf does not exist")`.
    pub fn index_of(&self, leaf: &IMTNode) -> Result<usize, &'static str> {
        self.leaves
            .get(leaf)
            .map(|&index| index - 1)
            .ok_or("Leaf does not exist")
    }

    /// Returns the root, or `None` if nothing was ever inserted.
    pub fn root(&self) -> Option<IMTNode> {
        self.side_nodes.get(&self.depth).cloned()
    }

    /// Returns the number of filled leaf positions (debugging getter).
    pub fn get_size(&self) -> usize {
        self.size
    }

    /// Returns the current depth (debugging getter).
    pub fn get_depth(&self) -> usize {
        self.depth
    }

    /// Returns a copy of the side nodes keyed by level (debugging getter).
    pub fn get_side_nodes(&self) -> HashMap<usize, IMTNode> {
        self.side_nodes.clone()
    }

    /// Returns a copy of the leaf-to-1-based-position map (debugging getter).
    pub fn get_leaves(&self) -> HashMap<IMTNode, usize> {
        self.leaves.clone()
    }
}

/// Unit tests for [`LeanIMT`]. Most use a comma-joining "hash" so expected
/// roots can be written out by hand.
#[cfg(all(feature = "std", test))]
mod tests {
    use super::*;

    fn simple_hash_function(nodes: Vec<String>) -> String {
        nodes.join(",")
    }

    #[test]
    fn test_new_lean_imt() {
        let hash: IMTHashFunction = simple_hash_function;
        let imt = LeanIMT::new(hash);

        assert_eq!(imt.size, 0);
        assert_eq!(imt.depth, 0);
        assert!(imt.root().is_none());
    }

    #[test]
    fn test_insert() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        assert!(imt.insert("leaf1".to_string()).is_ok());
        assert_eq!(imt.size, 1);
        assert_eq!(imt.depth, 0);
        assert!(imt.has(&"leaf1".to_string()));
        assert_eq!(imt.root().unwrap(), "leaf1".to_string());
    }

    #[test]
    fn test_insert_many() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        let leaves = vec!["leaf1".to_string(), "leaf2".to_string(), "leaf3".to_string()];
        assert!(imt.insert_many(leaves.clone()).is_ok());
        assert_eq!(imt.size, 3);
        assert_eq!(imt.depth, 2);
        for leaf in &leaves {
            assert!(imt.has(leaf));
        }
        // The lone third leaf is promoted unchanged and hashed with the first pair.
        let expected_root = simple_hash_function(vec![
            simple_hash_function(vec![leaves[0].clone(), leaves[1].clone()]),
            leaves[2].clone(),
        ]);
        assert_eq!(imt.root().unwrap(), expected_root);
    }

    #[test]
    fn test_insert_duplicate_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        imt.insert("leaf1".to_string()).unwrap();
        let result = imt.insert("leaf1".to_string());
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf already exists");
    }

    #[test]
    fn test_insert_many_with_duplicate_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        imt.insert("leaf1".to_string()).unwrap();
        let leaves = vec!["leaf2".to_string(), "leaf1".to_string()];
        let result = imt.insert_many(leaves);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf already exists");
    }

    #[test]
    fn test_update() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        imt.insert("leaf1".to_string()).unwrap();
        // A single-leaf tree has depth 0, so no siblings are needed.
        let sibling_nodes = vec![];
        assert!(imt
            .update(&"leaf1".to_string(), "new_leaf1".to_string(), &sibling_nodes)
            .is_ok());
        assert!(imt.has(&"new_leaf1".to_string()));
        assert!(!imt.has(&"leaf1".to_string()));
        assert_eq!(imt.root().unwrap(), "new_leaf1".to_string());
    }

    #[test]
    fn test_update_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        let sibling_nodes = vec![];
        let result = imt.update(
            &"nonexistent_leaf".to_string(),
            "new_leaf".to_string(),
            &sibling_nodes,
        );
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf does not exist");
    }

    #[test]
    fn test_remove() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        imt.insert("leaf1".to_string()).unwrap();
        let sibling_nodes = vec![];
        assert!(imt.remove(&"leaf1".to_string(), &sibling_nodes).is_ok());
        assert!(!imt.has(&"leaf1".to_string()));
        // Removal replaces the leaf with the reserved zero value.
        assert_eq!(imt.root().unwrap(), "0".to_string());
    }

    #[test]
    fn test_remove_nonexistent_leaf() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        let sibling_nodes = vec![];
        let result = imt.remove(&"nonexistent_leaf".to_string(), &sibling_nodes);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Leaf does not exist");
    }

    #[test]
    fn test_has_and_index_of() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        assert!(!imt.has(&"leaf1".to_string()));
        assert!(imt.index_of(&"leaf1".to_string()).is_err());

        imt.insert("leaf1".to_string()).unwrap();
        assert!(imt.has(&"leaf1".to_string()));
        assert_eq!(imt.index_of(&"leaf1".to_string()).unwrap(), 0);
    }

    #[test]
    fn test_root_after_operations() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Initially empty tree
        assert!(imt.root().is_none());

        // Insert leaf1
        imt.insert("leaf1".to_string()).unwrap();
        let root_after_leaf1 = imt.root().unwrap();

        // Insert leaf2
        imt.insert("leaf2".to_string()).unwrap();
        let root_after_leaf2 = imt.root().unwrap();
        assert_ne!(root_after_leaf1, root_after_leaf2);

        // Remove leaf1 (its sibling is leaf2)
        let sibling_nodes = vec!["leaf2".to_string()];
        imt.remove(&"leaf1".to_string(), &sibling_nodes).unwrap();
        let root_after_removal = imt.root().unwrap();
        assert_eq!(root_after_removal, "0,leaf2".to_string());

        // Update leaf2 (its sibling is now the removed leaf, "0")
        let sibling_nodes = vec!["0".to_string()];
        imt.update(&"leaf2".to_string(), "leaf3".to_string(), &sibling_nodes)
            .unwrap();
        let root_after_update = imt.root().unwrap();
        assert_eq!(root_after_update, "0,leaf3".to_string());
    }

    #[test]
    fn test_tree_consistency() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert leaves
        imt.insert("leaf1".to_string()).unwrap();
        imt.insert("leaf2".to_string()).unwrap();
        imt.insert("leaf3".to_string()).unwrap();
        imt.insert("leaf4".to_string()).unwrap();

        // Current root
        let root_before = imt.root().unwrap();

        // Update leaf2: siblings bottom-up are leaf1, then hash(leaf3, leaf4)
        let sibling_nodes = vec![
            "leaf1".to_string(),
            simple_hash_function(vec!["leaf3".to_string(), "leaf4".to_string()]),
        ];
        imt.update(
            &"leaf2".to_string(),
            "leaf2_updated".to_string(),
            &sibling_nodes,
        )
        .unwrap();

        // New root should be different
        let root_after = imt.root().unwrap();
        assert_ne!(root_before, root_after);

        // Remove leaf3: siblings bottom-up are leaf4, then hash(leaf1, leaf2_updated)
        let sibling_nodes = vec![
            "leaf4".to_string(),
            simple_hash_function(vec!["leaf1".to_string(), "leaf2_updated".to_string()]),
        ];
        imt.remove(&"leaf3".to_string(), &sibling_nodes).unwrap();

        // Root should change again
        let root_after_removal = imt.root().unwrap();
        assert_ne!(root_after, root_after_removal);

        // Check that leaves are correctly updated
        assert!(imt.has(&"leaf1".to_string()));
        assert!(imt.has(&"leaf2_updated".to_string()));
        assert!(!imt.has(&"leaf2".to_string()));
        assert!(!imt.has(&"leaf3".to_string()));
        assert!(imt.has(&"leaf4".to_string()));
    }

    #[test]
    fn test_large_number_of_leaves() {
        let hash: IMTHashFunction = |nodes: Vec<String>| {
            // Simple hash function that simulates combining nodes
            format!("H({})", nodes.join("+"))
        };
        let mut imt = LeanIMT::new(hash);

        // Insert 100 leaves
        let leaves: Vec<_> = (1..=100).map(|i| format!("leaf{}", i)).collect();
        assert!(imt.insert_many(leaves.clone()).is_ok());
        assert_eq!(imt.size, 100);

        // Check that all leaves are present
        for leaf in &leaves {
            assert!(imt.has(leaf));
        }

        // The depth is the smallest d with 2^d >= 100
        let expected_depth = (100 as f64).log2().ceil() as usize;
        assert_eq!(imt.depth, expected_depth);
    }

    #[test]
    fn test_insertion_after_removal() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert leaves
        imt.insert("leaf1".to_string()).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        // Remove leaf1
        let sibling_nodes = vec!["leaf2".to_string()];
        imt.remove(&"leaf1".to_string(), &sibling_nodes).unwrap();

        // Insert new leaf
        assert!(imt.insert("leaf3".to_string()).is_ok());

        // Check that leaves are correctly updated
        assert!(!imt.has(&"leaf1".to_string()));
        assert!(imt.has(&"leaf2".to_string()));
        assert!(imt.has(&"leaf3".to_string()));
    }

    #[test]
    fn test_tree_after_all_leaves_removed() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert leaves
        imt.insert("leaf1".to_string()).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        // Remove all leaves
        let sibling_nodes = vec!["leaf2".to_string()];
        imt.remove(&"leaf1".to_string(), &sibling_nodes).unwrap();

        let sibling_nodes = vec!["0".to_string()];
        imt.remove(&"leaf2".to_string(), &sibling_nodes).unwrap();

        // No live leaves remain, but removal only zeroes leaves: size and depth
        // are unchanged and the root is the hash of two zero leaves.
        assert_eq!(imt.size, 2);
        assert_eq!(imt.depth, 1);
        assert_eq!(imt.root().unwrap(), "0,0".to_string());
        assert!(!imt.has(&"leaf1".to_string()));
        assert!(!imt.has(&"leaf2".to_string()));
    }

    #[test]
    fn test_insert_after_tree_becomes_empty() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert and remove leaves
        imt.insert("leaf1".to_string()).unwrap();
        let sibling_nodes = vec![];
        imt.remove(&"leaf1".to_string(), &sibling_nodes).unwrap();

        // Insert new leaf: it takes the next position, next to the zeroed one
        assert!(imt.insert("leaf2".to_string()).is_ok());
        assert!(imt.has(&"leaf2".to_string()));
        assert_eq!(imt.root().unwrap(), "0,leaf2".to_string());
    }

    #[test]
    fn test_insertion_causes_depth_increase() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // One leaf fits in a tree of depth 0
        imt.insert("leaf1".to_string()).unwrap();
        assert_eq!(imt.depth, 0);

        // Two leaves need depth 1
        imt.insert("leaf2".to_string()).unwrap();
        assert_eq!(imt.depth, 1);

        // Insert another leaf, depth should increase
        imt.insert("leaf3".to_string()).unwrap();
        assert_eq!(imt.depth, 2);

        // Four leaves still fit in depth 2
        imt.insert("leaf4".to_string()).unwrap();
        assert_eq!(imt.depth, 2);

        // Insert another leaf, depth should increase
        imt.insert("leaf5".to_string()).unwrap();
        assert_eq!(imt.depth, 3);
    }

    #[test]
    fn test_invalid_sibling_nodes_on_update() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert leaves
        imt.insert("leaf1".to_string()).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        // Try to update with incorrect sibling nodes
        let sibling_nodes = vec!["wrong_sibling".to_string()];
        let result = imt.update(
            &"leaf1".to_string(),
            "leaf1_updated".to_string(),
            &sibling_nodes,
        );
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Wrong sibling nodes");
    }

    #[test]
    fn test_invalid_sibling_nodes_on_remove() {
        let hash: IMTHashFunction = simple_hash_function;
        let mut imt = LeanIMT::new(hash);

        // Insert leaves
        imt.insert("leaf1".to_string()).unwrap();
        imt.insert("leaf2".to_string()).unwrap();

        // Try to remove with incorrect sibling nodes
        let sibling_nodes = vec!["wrong_sibling".to_string()];
        let result = imt.remove(&"leaf1".to_string(), &sibling_nodes);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Wrong sibling nodes");
    }
}