modppl/
trie.rs

1use std::collections::{HashMap, hash_map};
2use crate::{SplitAddr::{self,Prefix,Term}, AddrMap};
3
4
/// Weighted digital trie.
///
/// Every node may hold an optional inner value, a weight, and any number of
/// child tries keyed by address string.
#[derive(Debug, Clone, PartialEq)]
pub struct Trie<V> {
    /// Direct descendants of this node, keyed by address.
    mapping: HashMap<String, Trie<V>>,
    /// Inner value stored at this node, if any.
    value: Option<V>,
    /// Weight recorded for this node.
    weight: f64,
}
12
13
14impl<V> Trie<V> {
15
16    /// Initialize an empty Trie.
17    pub fn new() -> Self {
18        Trie {
19            mapping: HashMap::new(),
20            value: None,
21            weight: 0.
22        }
23    }
24
25    /// Initialize a Trie with an inner value and weight.
26    pub fn leaf(value: V, weight: f64) -> Self {
27        Trie {
28            mapping: HashMap::new(),
29            value: Some(value),
30            weight: weight
31        }
32    }
33
34    /// Return `true` if `self` is empty (has no inner value nor descendants), otherwise `false`.
35    pub fn is_empty(&self) -> bool {
36        self.mapping.is_empty() && self.value.is_none()
37    }
38
39    /// Return `true` if `self` is a leaf (has an inner value but no descendants), otherwise `false`.
40    pub fn is_leaf(&self) -> bool {
41        self.mapping.is_empty() && self.value.is_some()
42    }
43
44    /// Return the number of _direct_ descendants of the `Trie`.
45    pub fn len(&self) -> usize {
46        self.mapping.len()
47    }
48
49    /// Return some reference to the inner value if there is one, otherwise none.
50    pub fn ref_inner(&self) -> Option<&V> {
51        self.value.as_ref()
52    }
53
54    /// Return some inner value (setting the inner value to none), otherwise just return none.
55    pub fn take_inner(&mut self) -> Option<V> {
56        self.value.take()
57    }
58
59    /// Return some inner value (setting the inner value to `value`), otherwise just return none.
60    pub fn replace_inner(&mut self, value: V) -> Option<V> {
61        self.value.replace(value)
62    }
63
64    /// Return some inner value if there is one, otherwise panic with `msg`.
65    pub fn expect_inner(self, msg: &str) -> V {
66        self.value.expect(msg)
67    }
68
69    /// Iterate through the _direct_ descendants of `self`.
70    pub fn iter(&self) -> hash_map::Iter<'_, String, Trie<V>> {
71        self.mapping.iter()
72    }
73
74    /// Iterate mutably through the _direct_ descendants of `self`.
75    pub fn iter_mut(&mut self) -> hash_map::IterMut<'_, String, Trie<V>> {
76        self.mapping.iter_mut()
77    }
78
79    /// Move `self` into an iterator over the _direct_ descendants of `self`.
80    pub fn into_iter(self) -> hash_map::IntoIter<String, Trie<V>> {
81        self.mapping.into_iter()
82    }
83
84    /// Return the sum of the weight of all descendants.
85    pub fn weight(&self) -> f64 {
86        self.weight
87    }
88
89    /// Return some reference to a descendant at `addr` if present, otherwise none.
90    pub fn search(&self, addr: &str) -> Option<&Trie<V>> {
91        match SplitAddr::from_addr(addr) {
92            Term(addr) => {
93                self.mapping.get(addr)
94            }
95            Prefix(first, rest) => {
96                self.mapping[first].search(rest)
97            }
98        }
99    }
100
101    /// Observe an unweighted `value` at `addr`. Panic if `addr` is occupied.
102    pub fn observe(&mut self, addr: &str, value: V) {
103        match SplitAddr::from_addr(addr) {
104            Term(addr) => {
105                if self.mapping.contains_key(addr) {
106                    panic!("observe: attempted to put into occupied address \"{addr}\"");
107                } else {
108                    self.mapping.insert(addr.to_string(), Trie::leaf(value, 0.0));
109                }
110            }
111            Prefix(first, rest) => {
112                let submap = self.mapping
113                    .entry(first.to_string())
114                    .or_insert(Trie::new());
115                submap.observe(rest, value)
116            }
117        }
118    }
119
120    /// Observe a weighted `value` at `addr`, summing the weight by `weight`. Panic if `addr` is occupied.
121    pub fn w_observe(&mut self, addr: &str, value: V, weight: f64) { 
122        self.weight += weight;
123        match SplitAddr::from_addr(addr) {
124            Term(addr) => {
125                if self.mapping.contains_key(addr) {
126                    panic!("w_observe: attempted to put into occupied address \"{addr}\"");
127                } else {
128                    self.mapping.insert(addr.to_string(), Trie::leaf(value, weight));
129                }
130            }
131            Prefix(first, rest) => {
132                let submap = self.mapping
133                    .entry(first.to_string())
134                    .or_insert(Trie::new());
135                submap.w_observe(rest, value, weight)
136            }
137        }
138    }
139
140    /// Insert a descendant `sub` at `addr`. Panic if `addr` is occupied.
141    pub fn insert(&mut self, addr: &str, sub: Trie<V>) {
142        self.weight += sub.weight;
143        match SplitAddr::from_addr(addr) {
144            Term(addr) => {
145                if self.mapping.contains_key(addr) {
146                    panic!("insert: attempted to put into occupied address \"{addr}\"");
147                } else {
148                    self.mapping.insert(addr.to_string(), sub);
149                }
150            }
151            Prefix(first, rest) => {
152                let submap = self.mapping
153                    .entry(first.to_string())
154                    .or_insert(Trie::new());
155                submap.insert(rest, sub)
156            }
157        }
158    }
159
160    /// Return a descendant at `addr` if present (removing it), otherwise just return none.
161    pub fn remove(&mut self, addr: &str) -> Option<Trie<V>> {
162        if let Some(sub) = match SplitAddr::from_addr(addr) {
163            Term(addr) => {
164                self.mapping.remove(addr)
165            }
166            Prefix(first, rest) => {
167                match self.mapping.get_mut(first) {
168                    Some(node) => {
169                        let leaf = node.remove(rest);
170                        if node.is_empty() {
171                            self.remove(first);
172                        }
173                        leaf
174                    }
175                    None => { None }
176                }
177            }
178        } {
179            self.weight -= sub.weight;
180            Some(sub)
181        } else {
182            None
183        }
184    }
185
186    /// Merge an `other` Trie into `self`, preferentially using the values of `other` at overlapping addresses.
187    pub fn merge(&mut self, other: Self) {
188        for (addr, othersub) in other.into_iter() {
189            if othersub.is_leaf() {
190                self.w_observe(&addr, othersub.value.unwrap(), othersub.weight);
191            } else {
192                match self.mapping.get_mut(&addr) {
193                    Some(sub) => {
194                        sub.merge(othersub);
195                    }
196                    None => {
197                        self.insert(&addr, othersub);
198                    }
199                }
200            }
201        }
202    }
203
204    /// Return an `AddrMap` representing the address schema of `self`.
205    pub fn schema(&self) -> AddrMap {
206        let mut amap = AddrMap::new();
207        for (addr, subtrie) in self.iter() {
208            if subtrie.is_leaf() {
209                amap.visit(addr);
210            } else {
211                amap.insert(addr, subtrie.schema());
212            }
213        }
214        amap
215    }
216
217    /// Collect the set of values identified by `mask` into a new `Trie`,
218    /// leaving values in `self` that are in the complement of `mask`.
219    /// 
220    /// Return the new `self`, the collected value trie, and the weight of the collected value trie.
221    pub fn collect(
222        mut self: Self,
223        mask: &AddrMap
224    ) -> (Self,Self,f64) {
225        let mut collected = Trie::new();
226        if &self.schema() == mask {
227            let weight = self.weight();
228            return (collected, self, weight);
229        } else if !mask.is_leaf() {
230            for (addr, submask) in mask.iter() {
231                let Some(sub) = self.remove(addr) else { unreachable!() };
232                if submask.is_leaf() {
233                    collected.insert(addr, sub);
234                } else {
235                    let (sub, subcollected, _) = sub.collect(submask);
236                    if !sub.is_empty() {
237                        self.insert(addr, sub);
238                    }
239                    if !subcollected.is_empty() {
240                        collected.insert(addr, subcollected);
241                    }
242                }
243            }
244        }
245        let weight = collected.weight();
246        (self, collected, weight)
247    }
248
249}