gen_rs/
trie.rs

1use std::ops::Index;
2use std::collections::{HashMap, hash_map};
3use crate::SplitAddr::{self,Prefix,Term};
4
5
/// Hierarchical prefix tree.
///
/// Every level of the tree holds two independent maps: one from an address component to a
/// stored value (a *leaf node*) and one from an address component to a nested `Trie` (an
/// *internal node*). Because the maps are independent, the same component may name both a leaf
/// and a subtrie at the same level.
#[derive(Debug, Clone, PartialEq)]
pub struct Trie<V> {
    /// Values stored directly at this level.
    leaf_nodes: HashMap<String, V>,
    /// Nested subtries reachable from this level.
    internal_nodes: HashMap<String, Trie<V>>,
}
12
13impl<V> Trie<V> {
14    /// Construct an empty Trie.
15    pub fn new() -> Self {
16        Trie {
17            leaf_nodes: HashMap::new(),
18            internal_nodes: HashMap::new()
19        }
20    }
21
22    /// Return `true` if a Trie is empty (has no leaf or internal nodes), otherwise `false`.
23    pub fn is_empty(&self) -> bool {
24        self.leaf_nodes.is_empty() && self.internal_nodes.is_empty()
25    }
26
27    /// Return `true` if a Trie has a leaf node at `addr`, otherwise `false`.
28    pub fn has_leaf_node(&self, addr: &str) -> bool {
29        match SplitAddr::from_addr(addr) {
30            Term(addr) => {
31                self.leaf_nodes.contains_key(addr)
32            }
33            Prefix(first, rest) => {
34                if self.internal_nodes.contains_key(first) {
35                    self.internal_nodes[first].has_leaf_node(rest)
36                } else {
37                    false
38                }
39            }
40        }
41    }
42
43    /// Return `Some(&value)` if `self` contains a `value` located at `addr`, otherwise `None`.
44    pub fn get_leaf_node(&self, addr: &str) -> Option<&V> {
45        match SplitAddr::from_addr(addr) {
46            Term(addr) => {
47                self.leaf_nodes.get(addr)
48            }
49            Prefix(first, rest) => {
50                self.internal_nodes[first].get_leaf_node(rest)
51            }
52        }
53    }
54
55    /// Insert `value` as a leaf node located at `addr`.
56    /// 
57    /// If there was a value `prev` located at `addr`, return `Some(prev)`, otherwise `None`.
58    pub fn insert_leaf_node(&mut self, addr: &str, value: V) -> Option<V> {
59        match SplitAddr::from_addr(addr) {
60            Term(addr) => {
61                self.leaf_nodes.insert(addr.to_string(), value)
62            }
63            Prefix(first, rest) => {
64                let node = self.internal_nodes
65                    .entry(first.to_string())
66                    .or_insert(Trie::new());
67                node.insert_leaf_node(rest, value)
68            }
69        }
70    }
71
72    /// Return `Some(value)` if `self` contains a `value` located at `addr` and remove `value` from the leaf nodes, otherwise return `None`.
73    pub fn remove_leaf_node(&mut self, addr: &str) -> Option<V> {
74        match SplitAddr::from_addr(addr) {
75            Term(addr) => {
76                self.leaf_nodes.remove(addr)
77            }
78            Prefix(first, rest) => {
79                let node = self.internal_nodes.get_mut(first).unwrap();
80                let leaf = node.remove_leaf_node(rest);
81                if node.is_empty() {
82                    self.remove_internal_node(first);
83                }
84                leaf
85            }
86        }
87    }
88
89    /// Return an iterator over a Trie's leaf nodes.
90    pub fn leaf_iter(&self) -> hash_map::Iter<'_, String, V> {
91        self.leaf_nodes.iter()
92    }
93
94    /// Return `true` if a Trie has an internal node at `addr`, otherwise `false`.
95    pub fn has_internal_node(&self, addr: &str) -> bool {
96        match SplitAddr::from_addr(addr) {
97            Term(addr) => {
98                self.internal_nodes.contains_key(addr)
99            }
100            Prefix(first, rest) => {
101                if self.internal_nodes.contains_key(first) {
102                    self.internal_nodes[first].has_internal_node(rest)
103                } else {
104                    false
105                }
106            }
107        }
108    }
109
110    /// Return `Some(&subtrie)` if `self` contains a `subtrie` located at `addr`, otherwise `None`.
111    pub fn get_internal_node(&self, addr: &str) -> Option<&Self> {
112        match SplitAddr::from_addr(addr) {
113            Term(addr) => {
114                self.internal_nodes.get(addr)
115            }
116            Prefix(first, rest) => {
117                self.internal_nodes[first].get_internal_node(rest)
118            }
119        }
120    }
121
122    /// Insert `subtrie` as an internal node located at `addr`.
123    /// 
124    /// If there was a value `prev_subtrie` located at `addr`, return `Some(prev_subtrie)`, otherwise `None`.
125    /// Panics if `subtrie.is_empty()`.
126    pub fn insert_internal_node(&mut self, addr: &str, new_node: Self) -> Option<Trie<V>> {
127        match SplitAddr::from_addr(addr) {
128            Term(addr) => {
129                if !new_node.is_empty() {
130                    self.internal_nodes.insert(addr.to_string(), new_node)
131                } else {
132                    panic!("attempted to insert empty inode")
133                }
134            }
135            Prefix(first, rest) => {
136                let node = self.internal_nodes
137                    .entry(first.to_string())
138                    .or_insert(Trie::new());
139                node.insert_internal_node(rest, new_node)
140            }
141        }
142    }
143
144    /// Return `Some(subtrie)` if `self` contains a `subtrie` located at `addr` and remove `subtrie` from the internal nodes, otherwise return `None`.
145    pub fn remove_internal_node(&mut self, addr: &str) -> Option<Trie<V>> {
146        match SplitAddr::from_addr(addr) {
147            Term(addr) => {
148                self.internal_nodes.remove(addr)
149            }
150            Prefix(first, rest) => {
151                let node = self.internal_nodes.get_mut(first).unwrap();
152                node.remove_internal_node(rest)
153            }
154        }
155    }
156
157    /// Return an iterator over a Trie's internal nodes.
158    pub fn internal_iter(&self) -> hash_map::Iter<'_, String, Trie<V>> {
159        self.internal_nodes.iter()
160    }
161
162    /// Merge `other` into `self`, freeing previous values and subtries at each `addr` in `self` if `other` also has an entry at `addr`.
163    /// 
164    /// Returns the mutated `self`.
165    pub fn merge(mut self, other: Self) -> Self {
166        for (addr, value) in other.leaf_nodes.into_iter() {
167            self.insert_leaf_node(&addr, value);
168        }
169        for (addr, subtrie) in other.internal_nodes.into_iter() {
170            self.insert_internal_node(&addr, subtrie);
171        }
172        self
173    }
174}
175
176
177// specializations
178
179impl<V> Trie<(V,f64)> {
180    /// Return the sum of all the weights of the leaf nodes and the recursive sum of all internal nodes.
181    pub fn sum(&self) -> f64 {
182        self.internal_nodes.values().fold(0., |acc, t| acc + t.sum()) +
183        self.leaf_nodes.values().fold(0., |acc, v| acc + v.1)
184    }
185
186    /// Convert a weighted `Trie` into the equivalent unweighted version by discarding all the weights.
187    pub fn into_unweighted(self) -> Trie<V> {
188        Trie {
189            internal_nodes: self.internal_nodes.into_iter().map(|(addr, t)| (addr, t.into_unweighted())).collect::<_>(),
190            leaf_nodes: self.leaf_nodes.into_iter().map(|(addr, v)| (addr, v.0)).collect::<_>()
191        }
192    }
193
194    /// Convert an unweighted `Trie` into the equivalent weighted version by adding a weight of `0.` to all leaf nodes.
195    pub fn from_unweighted(trie: Trie<V>) -> Self {
196        Trie {
197            internal_nodes: trie.internal_nodes.into_iter().map(|(addr, t)| (addr, Self::from_unweighted(t))).collect::<_>(),
198            leaf_nodes: trie.leaf_nodes.into_iter().map(|(addr, v)| (addr, (v, 0.))).collect::<_>()
199        }
200    }
201}
202
203use std::{rc::Rc,any::Any};
204
205impl Trie<Rc<dyn Any>> {
206    /// Optimistically casts the reference-counted `dyn Any` at `addr` into type `V`, and returns a cloned value.
207    pub fn read<V: 'static + Clone>(&self, addr: &str) -> V {
208        self.get_leaf_node(addr)
209            .unwrap()
210            .clone()
211            .downcast::<V>()
212            .ok()
213            .unwrap()
214            .as_ref()
215            .clone()
216    }
217}
218
219impl Trie<(Rc<dyn Any>,f64)> {
220    /// Optimistically casts the reference-counted `dyn Any` at `addr` into type `V`, and returns a cloned value.
221    pub fn read<V: 'static + Clone>(&self, addr: &str) -> V {
222        self.get_leaf_node(addr)
223            .unwrap().0
224            .clone()
225            .downcast::<V>()
226            .ok()
227            .unwrap()
228            .as_ref()
229            .clone()
230    }
231}
232
233impl<V> Index<&str> for Trie<V> {
234    type Output = V;
235
236    fn index(&self, index: &str) -> &Self::Output {
237        self.get_leaf_node(index).unwrap()
238    }
239}