ad_editor/
trie.rs

//! A trie data structure with modifications for supporting key bindings and autocompletion in a
//! composable way with the rest of the ad internal APIs.
3use std::{cmp, fmt};
4
5/// A singly initialised Trie mapping sequences to a value
6///
7/// NOTE: It is not permitted for values to be mapped to a key that is a prefix of another key also
8/// existing in the same Try.
9///
10/// There are convenience methods provided for `Trie<char, V>` for when &str values are used as keys.
11#[derive(Clone, PartialEq, Eq)]
12pub struct Trie<K, V>
13where
14    K: Clone + PartialEq,
15    V: Clone,
16{
17    parent_key: Option<Vec<K>>,
18    roots: Vec<Node<K, V>>,
19    default: Option<DefaultMapping<K, V>>,
20}
21
22impl<K, V> fmt::Debug for Trie<K, V>
23where
24    K: Clone + PartialEq + fmt::Debug,
25    V: Clone + fmt::Debug,
26{
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("Trie")
29            .field("parent_key", &self.parent_key)
30            .field("roots", &self.roots)
31            .field("default", &stringify!(self.default))
32            .finish()
33    }
34}
35
36impl<K, V> Trie<K, V>
37where
38    K: Clone + PartialEq,
39    V: Clone,
40{
41    /// Will panic if there are any key collisions or if there are any sequences that nest
42    /// under a prefix that is already required to hold a value
43    pub fn from_pairs(pairs: Vec<(Vec<K>, V)>) -> Self {
44        let mut roots = Vec::new();
45
46        for (k, v) in pairs.into_iter() {
47            insert(k, v, &mut roots)
48        }
49
50        Self {
51            parent_key: None,
52            roots,
53            default: None,
54        }
55    }
56
57    /// Set the default handler for unmatched single element keys
58    pub fn set_default(&mut self, default: DefaultMapping<K, V>) {
59        self.default = Some(default);
60    }
61
62    /// Query this Try for a given key or key prefix.
63    ///
64    /// If the key maps to a leaf then the value is returned, if it maps to a sub-trie then
65    /// `Partial` is returned to denote that the given key is a parent of one or more values. If
66    /// the key is not found within the `Try` then `Missing` is returned.
67    pub fn get(&self, key: &[K]) -> QueryResult<V> {
68        match get_node(key, &self.roots) {
69            Some(Node {
70                d: Data::Val(v), ..
71            }) => QueryResult::Val(v.clone()),
72
73            Some(_) => QueryResult::Partial,
74
75            None => match self.default {
76                Some(f) if key.len() == 1 => f(&key[0]).into(),
77                _ => QueryResult::Missing,
78            },
79        }
80    }
81
82    /// Query this Try for a given key or key prefix requiring the key to match exactly.
83    ///
84    /// If the key maps to a leaf then the `Some(value)` is returned, otherwise `None`.
85    pub fn get_exact(&self, key: &[K]) -> Option<V> {
86        match get_node(key, &self.roots) {
87            Some(Node {
88                d: Data::Val(v), ..
89            }) => Some(v.clone()),
90
91            Some(_) => None,
92
93            None => match self.default {
94                Some(f) if key.len() == 1 => f(&key[0]),
95                _ => None,
96            },
97        }
98    }
99
100    /// Find all candidate keys with the given key as a prefix (up to and including the given
101    /// prefix itself).
102    pub fn candidates(&self, key: &[K]) -> Vec<Vec<K>> {
103        match get_node(key, &self.roots) {
104            None => vec![],
105            Some(n) => n.resolved_keys(key),
106        }
107    }
108
109    /// The number of leaf nodes in this Try
110    pub fn len(&self) -> usize {
111        self.roots.iter().map(|r| r.len()).sum()
112    }
113
114    /// The number of leaf nodes in this Try
115    pub fn is_empty(&self) -> bool {
116        self.roots.is_empty()
117    }
118
119    /// Wether the given key is present in this [Trie] either as a full key or a partial prefix to
120    /// multiple keys.
121    pub fn contains_key_or_prefix(&self, key: &[K]) -> bool {
122        !matches!(self.get(key), QueryResult::Missing)
123    }
124}
125
126impl<V> Trie<char, V>
127where
128    V: Clone,
129{
130    /// Construct a new [Trie] with char internal keys from the given string keys.
131    pub fn from_str_keys(pairs: Vec<(&str, V)>) -> Self {
132        let mut roots = Vec::new();
133
134        for (k, v) in pairs.into_iter() {
135            insert(k.chars().collect(), v, &mut roots)
136        }
137
138        Self {
139            parent_key: None,
140            roots,
141            default: None,
142        }
143    }
144
145    /// Query this [Trie] using a string key.
146    ///
147    /// Both full and partial matches are possible.
148    pub fn get_str(&self, key: &str) -> QueryResult<V> {
149        self.get(&key.chars().collect::<Vec<_>>())
150    }
151
152    /// Query this [Trie] using a string key.
153    ///
154    /// Only fll matches will be returned.
155    pub fn get_str_exact(&self, key: &str) -> Option<V> {
156        self.get_exact(&key.chars().collect::<Vec<_>>())
157    }
158
159    /// Show all partial and full matches for the given key.
160    pub fn candidate_strings(&self, key: &str) -> Vec<String> {
161        let raw = self.candidates(&key.chars().collect::<Vec<_>>());
162        let mut strings: Vec<String> = raw.into_iter().map(|v| v.into_iter().collect()).collect();
163        strings.sort();
164
165        strings
166    }
167}
168
/// Fallback used for single element keys that are not stored in a [Trie].
///
/// It saves registering many one-element keys that should all be handled the same way:
/// returning `None` means the key is treated as missing.
pub type DefaultMapping<K, V> = for<'a> fn(&'a K) -> Option<V>;
174
/// The result of querying a [Trie] for a particular key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult<V> {
    /// A leaf value associated with the key used in the query
    Val(V),
    /// The key used in the query is a proper prefix of one or more stored keys
    Partial,
    /// The key does not exist within the [Trie] and no default mapping produced a value
    Missing,
}
185
186impl<V> From<Option<V>> for QueryResult<V> {
187    fn from(opt: Option<V>) -> Self {
188        match opt {
189            Some(v) => QueryResult::Val(v),
190            None => QueryResult::Missing,
191        }
192    }
193}
194
195impl<V> From<QueryResult<V>> for Option<V> {
196    fn from(q: QueryResult<V>) -> Self {
197        match q {
198            QueryResult::Val(v) => Some(v),
199            _ => None,
200        }
201    }
202}
203
204impl<V> QueryResult<V> {
205    pub fn map<F, U>(self, f: F) -> QueryResult<U>
206    where
207        F: Fn(V) -> U,
208    {
209        match self {
210            Self::Val(v) => QueryResult::Val(f(v)),
211            Self::Partial => QueryResult::Partial,
212            Self::Missing => QueryResult::Missing,
213        }
214    }
215}
216
217#[derive(Debug, Clone, PartialEq, Eq)]
218enum Data<K, V>
219where
220    K: Clone + PartialEq,
221{
222    Val(V),
223    Children(Vec<Node<K, V>>),
224}
225
226#[derive(Debug, Clone, PartialEq, Eq)]
227struct Node<K, V>
228where
229    K: Clone + PartialEq,
230{
231    k: K,
232    d: Data<K, V>,
233}
234
235impl<K, V> Node<K, V>
236where
237    K: Clone + PartialEq,
238{
239    fn len(&self) -> usize {
240        match &self.d {
241            Data::Children(nodes) => nodes.iter().map(|n| n.len()).sum(),
242            Data::Val(_) => 1,
243        }
244    }
245
246    // Panics if this node already holds a value
247    fn insert(&mut self, k: Vec<K>, v: V) {
248        match &mut self.d {
249            Data::Children(nodes) => insert(k, v, nodes),
250            Data::Val(_) => panic!("attempt to insert into value node"),
251        }
252    }
253
254    fn get_child<'s>(&'s self, k: &[K]) -> Option<&'s Node<K, V>> {
255        match &self.d {
256            Data::Children(nodes) => get_node(k, nodes),
257            Data::Val(_) => None,
258        }
259    }
260
261    fn resolved_keys(&self, prefix: &[K]) -> Vec<Vec<K>> {
262        match &self.d {
263            Data::Val(_) => vec![prefix.to_vec()],
264            Data::Children(nodes) => nodes
265                .iter()
266                .flat_map(|n| {
267                    let mut so_far = prefix.to_vec();
268                    so_far.push(n.k.clone());
269                    n.resolved_keys(&so_far)
270                })
271                .collect(),
272        }
273    }
274}
275
276fn insert<K, V>(mut key: Vec<K>, v: V, current: &mut Vec<Node<K, V>>)
277where
278    K: Clone + PartialEq,
279{
280    for n in current.iter_mut() {
281        if key[0] == n.k {
282            if key.len() > 1 {
283                key.remove(0);
284                n.insert(key, v);
285                return;
286            }
287            panic!("duplicate entry for key")
288        }
289    }
290
291    let k = key.remove(0);
292
293    // No matching root so create a new one
294    if key.is_empty() {
295        current.push(Node { k, d: Data::Val(v) });
296    } else {
297        let mut children = vec![];
298        insert(key, v, &mut children);
299
300        let d = Data::Children(children);
301        current.push(Node { k, d });
302    }
303}
304
305fn get_node<'n, K, V>(key: &[K], nodes: &'n [Node<K, V>]) -> Option<&'n Node<K, V>>
306where
307    K: Clone + PartialEq,
308{
309    if key.is_empty() {
310        return None;
311    }
312
313    for n in nodes.iter() {
314        if key[0] == n.k {
315            return if key.len() == 1 {
316                Some(n)
317            } else {
318                n.get_child(&key[1..])
319            };
320        }
321    }
322
323    None
324}
325
/// Predicate deciding whether a single key element is accepted by a wildcard match.
pub type WildcardFn<K> = for<'a> fn(&'a K) -> bool;
328
329/// A wildcard node in a key sequence that can conditionally match a single key element.
330#[derive(Debug)]
331pub enum WildCard<K> {
332    /// A literal key
333    Lit(K),
334    /// A predicate function for checking whether a key should be considered a match
335    Wild(WildcardFn<K>),
336}
337
338impl<K> cmp::PartialEq<K> for WildCard<K>
339where
340    K: PartialEq,
341{
342    fn eq(&self, other: &K) -> bool {
343        match self {
344            WildCard::Lit(k) => k == other,
345            WildCard::Wild(f) => f(other),
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use simple_test_case::test_case;

    #[test]
    #[should_panic(expected = "duplicate entry for key")]
    fn duplicate_keys_panic() {
        Trie::from_pairs(vec![(vec![42], 1), (vec![42], 2)]);
    }

    #[test]
    #[should_panic(expected = "attempt to insert into value node")]
    fn children_under_a_value_node_panics() {
        Trie::from_pairs(vec![(vec![42], 1), (vec![42, 69], 2)]);
    }

    #[test_case(&[42], None; "partial should be None")]
    #[test_case(&[144], None; "missing should be None")]
    #[test_case(&[42, 69, 144], None; "overshoot should be None")]
    #[test_case(&[42, 69], Some(1); "exact should be Some")]
    #[test]
    fn get_exact_works(k: &[usize], expected: Option<usize>) {
        let t = Trie::from_pairs(vec![(vec![42, 69], 1)]);

        assert_eq!(t.get_exact(k), expected);
    }

    #[test_case("fo", None; "partial should be None")]
    #[test_case("bar", None; "missing should be None")]
    #[test_case("fooo", None; "overshoot should be None")]
    #[test_case("foo", Some(1); "exact should be Some")]
    #[test]
    fn get_str_exact_works(k: &str, expected: Option<usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1)]);

        assert_eq!(t.get_str_exact(k), expected);
    }

    #[test_case("ba", QueryResult::Partial; "partial match")]
    #[test_case("bar", QueryResult::Val(2); "exact match")]
    #[test_case("baz", QueryResult::Val(3); "exact match with shared prefix")]
    #[test_case("barf", QueryResult::Missing; "overshot known key")]
    #[test_case("have you any wool?", QueryResult::Missing; "completely missing")]
    #[test]
    fn get_works(k: &str, expected: QueryResult<usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1), ("bar", 2), ("baz", 3)]);

        assert_eq!(t.get_str(k), expected);
    }

    #[test_case("f", &["fold", "food", "fool"]; "first char")]
    #[test_case("fo", &["fold", "food", "fool"]; "shared prefix")]
    #[test_case("foo", &["food", "fool"]; "shared prefix not all match")]
    #[test_case("food", &["food"]; "exact match")]
    #[test_case("foods", &[]; "overshot")]
    #[test_case("q", &[]; "unknown first char")]
    #[test_case("quux", &[]; "unknown full key")]
    #[test_case("", &[]; "empty string")]
    #[test]
    fn candidate_strings_works(k: &str, expected: &[&str]) {
        let expected: Vec<String> = expected.iter().map(|s| s.to_string()).collect();
        let t = Trie::from_str_keys(
            ["fool", "fold", "food"]
                .into_iter()
                .enumerate()
                .map(|(i, s)| (s, i))
                .collect(),
        );

        assert_eq!(t.candidate_strings(k), expected);
    }

    fn usize_default_handler(n: &usize) -> Option<usize> {
        Some(n + 1)
    }

    #[test_case(&[42], QueryResult::Val(1); "exact single should match from the Try")]
    #[test_case(&[12, 13], QueryResult::Val(2); "exact multi should match from the Try")]
    #[test_case(&[69], QueryResult::Val(70); "missing single should be defaulted")]
    #[test_case(&[69, 420], QueryResult::Missing; "missing multi should always be missing")]
    #[test_case(&[12], QueryResult::Partial; "partial should remain partial")]
    #[test]
    fn default_handlers_work(k: &[usize], expected: QueryResult<usize>) {
        let mut t = Trie::from_pairs(vec![(vec![42], 1), (vec![12, 13], 2)]);
        t.set_default(usize_default_handler);

        assert_eq!(t.get(k), expected);

        let expected_opt: Option<usize> = expected.into();
        assert_eq!(t.get_exact(k), expected_opt);
    }

    #[test]
    fn len_and_is_empty_count_leaves() {
        let empty: Trie<char, usize> = Trie::from_str_keys(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let t = Trie::from_str_keys(vec![("foo", 1), ("fob", 2), ("bar", 3)]);
        assert!(!t.is_empty());
        assert_eq!(t.len(), 3);
    }

    #[test_case("fo", true; "prefix")]
    #[test_case("foo", true; "full key")]
    #[test_case("foox", false; "overshoot")]
    #[test_case("x", false; "missing")]
    #[test]
    fn contains_key_or_prefix_works(k: &str, expected: bool) {
        let t = Trie::from_str_keys(vec![("foo", 1)]);
        let key: Vec<char> = k.chars().collect();

        assert_eq!(t.contains_key_or_prefix(&key), expected);
    }

    #[test]
    fn query_result_map_preserves_variant() {
        assert_eq!(QueryResult::Val(2usize).map(|n| n * 10), QueryResult::Val(20));
        assert_eq!(QueryResult::<usize>::Partial.map(|n| n * 10), QueryResult::Partial);
        assert_eq!(QueryResult::<usize>::Missing.map(|n| n * 10), QueryResult::Missing);
    }

    #[test]
    fn wildcards_match_correctly() {
        fn is_valid(c: &char) -> bool {
            *c == 'i' || *c == 'a'
        }

        let w = WildCard::Wild(is_valid);

        assert!(w == 'a');
        assert!(w != 'b');
    }

    #[test]
    fn literal_wildcards_match_by_equality() {
        let w = WildCard::Lit('x');

        assert!(w == 'x');
        assert!(w != 'y');
    }
}