ad_editor/
trie.rs

//! A simple trie data structure for supporting key bindings and autocompletions in a composable
//! way with the rest of the ad internal APIs.
3use std::{collections::BTreeMap, fmt, ops::Range, sync::Arc};
4
5/// A singly initialised Trie mapping key sequences to a value.
6///
7/// It is not permitted for values to be mapped to a key that is a prefix of another key also
8/// existing in the same Trie.
9///
10/// There are convenience methods provided for `Trie<char, V>` for when &str values are used as keys.
11#[derive(Clone, PartialEq, Eq)]
12pub struct Trie<K, V>
13where
14    K: Clone + PartialEq + Ord,
15    V: Clone,
16{
17    nodes: Arc<[Node<K>]>,
18    values: Arc<[V]>,
19    n_roots: usize,
20}
21
22impl<K, V> fmt::Debug for Trie<K, V>
23where
24    K: Clone + PartialEq + Ord + fmt::Debug,
25    V: Clone + fmt::Debug,
26{
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("Trie")
29            .field("nodes", &self.nodes)
30            .field("values", &self.values)
31            .field("n_roots", &self.n_roots)
32            .finish()
33    }
34}
35
36impl<K, V> Default for Trie<K, V>
37where
38    K: Clone + PartialEq + Ord,
39    V: Clone,
40{
41    fn default() -> Self {
42        Self {
43            nodes: Arc::from(Vec::new()),
44            values: Arc::from(Vec::new()),
45            n_roots: 0,
46        }
47    }
48}
49
50impl<K, V> Trie<K, V>
51where
52    K: Clone + PartialEq + Ord,
53    V: Clone,
54{
55    /// Construct a new Trie from key-value pairs.
56    ///
57    /// Will panic if there are any key collisions or if there are any sequences that nest
58    /// under a prefix that is already required to hold a value
59    pub fn try_from_iter(it: impl IntoIterator<Item = (Vec<K>, V)>) -> Result<Self, &'static str> {
60        let mut roots = Vec::new();
61        for (key, value) in it.into_iter() {
62            insert(key, value, &mut roots)?;
63        }
64
65        if roots.is_empty() {
66            return Ok(Self::default());
67        }
68
69        let mut nodes = Vec::new();
70        let mut values = Vec::new();
71        let (_, n_roots) = flatten(roots, &mut nodes, &mut values);
72
73        Ok(Trie {
74            nodes: Arc::from(nodes),
75            values: Arc::from(values),
76            n_roots,
77        })
78    }
79
80    /// Merge two Tries.
81    ///
82    /// If the resulting Trie would be invalid to construct directly, an error is returned.
83    ///
84    /// Consumes both inputs.
85    pub fn merge(self, other: Self) -> Result<Self, &'static str> {
86        if self.is_empty() {
87            return Ok(other);
88        } else if other.is_empty() {
89            return Ok(self);
90        }
91
92        let mut pairs = Vec::with_capacity(self.len() + other.len());
93        self.extract_pairs(&mut pairs, Vec::new(), 0..self.n_roots);
94        other.extract_pairs(&mut pairs, Vec::new(), 0..other.n_roots);
95
96        Self::try_from_iter(pairs)
97    }
98
99    /// Merge two Tries preferring keys from `other` in the case of collisions.
100    ///
101    /// If the resulting Trie would be invalid to construct directly, an error is returned.
102    ///
103    /// Consumes both inputs.
104    pub fn merge_overriding(self, other: Self) -> Result<Self, &'static str> {
105        if self.is_empty() {
106            return Ok(other);
107        } else if other.is_empty() {
108            return Ok(self);
109        }
110
111        let mut pairs = Vec::with_capacity(self.len() + other.len());
112        self.extract_pairs(&mut pairs, Vec::new(), 0..self.n_roots);
113
114        let mut m = BTreeMap::from_iter(pairs.drain(..));
115        other.extract_pairs(&mut pairs, Vec::new(), 0..other.n_roots);
116        m.extend(pairs.drain(..));
117
118        Self::try_from_iter(m)
119    }
120
121    fn extract_pairs(&self, pairs: &mut Vec<(Vec<K>, V)>, key: Vec<K>, indices: Range<usize>) {
122        for i in indices {
123            let node = &self.nodes[i];
124            let mut child_key = key.clone();
125            child_key.push(node.key.clone());
126
127            match node.data {
128                Data::Leaf { i } => pairs.push((child_key, self.values[i].clone())),
129
130                Data::Internal {
131                    child_start,
132                    n_children,
133                } => self.extract_pairs(pairs, child_key, child_start..child_start + n_children),
134            }
135        }
136    }
137
138    /// Query this [Trie] for a given key or key prefix
139    ///
140    /// If the key maps to a leaf then the value is returned, if it maps to a sub-trie then
141    /// `Partial` is returned to denote that the given key is a parent of one or more values. If
142    /// the key is not found within the `Trie` then `Missing` is returned.
143    pub fn get<'a>(&'a self, key: &[K]) -> QueryResult<'a, V> {
144        if key.is_empty() {
145            return QueryResult::Missing;
146        }
147
148        let mut indices = 0..self.n_roots;
149        let mut key_index = 0;
150
151        'outer: while key_index < key.len() {
152            let target = &key[key_index];
153
154            // Binary search within the current level (assumes sorted children)
155            for i in indices {
156                let node = &self.nodes[i];
157                if &node.key == target {
158                    key_index += 1;
159
160                    match node.data {
161                        Data::Leaf { i } => {
162                            return if key_index == key.len() {
163                                QueryResult::Val(&self.values[i])
164                            } else {
165                                QueryResult::Missing
166                            };
167                        }
168
169                        Data::Internal {
170                            child_start,
171                            n_children,
172                        } => {
173                            indices = child_start..child_start + n_children;
174                            continue 'outer;
175                        }
176                    }
177                } else if node.key > *target {
178                    // We've moved past where the node would be in sorted order
179                    return QueryResult::Missing;
180                }
181            }
182
183            return QueryResult::Missing;
184        }
185
186        QueryResult::Partial
187    }
188
189    /// Query this [Trie] for a given key or key prefix requiring the key to match exactly.
190    ///
191    /// If the key maps to a leaf then the `Some(value)` is returned, otherwise `None`.
192    pub fn get_exact<'a>(&'a self, key: &[K]) -> Option<&'a V> {
193        self.get(key).into()
194    }
195
196    /// The number of leaf values in this Trie
197    pub fn len(&self) -> usize {
198        self.nodes.iter().filter(|n| n.is_leaf()).count()
199    }
200
201    /// Whether this Trie is empty
202    pub fn is_empty(&self) -> bool {
203        self.nodes.is_empty()
204    }
205}
206
207// Implementation for char-based convenience methods
208impl<V> Trie<char, V>
209where
210    V: Clone,
211{
212    /// Construct a new [Trie] with [char] internal keys from string keys.
213    pub fn from_str_keys(pairs: Vec<(&str, V)>) -> Result<Self, &'static str> {
214        let char_pairs: Vec<(Vec<char>, V)> = pairs
215            .into_iter()
216            .map(|(k, v)| (k.chars().collect(), v))
217            .collect();
218
219        Self::try_from_iter(char_pairs)
220    }
221
222    /// Query this [Trie] using a string key.
223    ///
224    /// Both full and partial matches are possible.
225    pub fn get_str<'a>(&'a self, key: &str) -> QueryResult<'a, V> {
226        self.get(&key.chars().collect::<Vec<_>>())
227    }
228
229    /// Query this [Trie] using a string key.
230    ///
231    /// Only fll matches will be returned.
232    pub fn get_str_exact<'a>(&'a self, key: &str) -> Option<&'a V> {
233        self.get_exact(&key.chars().collect::<Vec<_>>())
234    }
235}
236
237/// A single node within a [Trie].
238///
239/// Contains the last element of the key that traverses down to this node alongside [Data] that
240/// identifies this node as being internal or a leaf.
241#[derive(Debug, Clone, PartialEq, Eq)]
242struct Node<K>
243where
244    K: Clone + PartialEq + Ord,
245{
246    key: K,
247    data: Data,
248}
249
250impl<K> Node<K>
251where
252    K: Clone + PartialEq + PartialOrd + Ord,
253{
254    fn new_internal(key: K, child_start: usize, n_children: usize) -> Self {
255        Self {
256            key,
257            data: Data::Internal {
258                child_start,
259                n_children,
260            },
261        }
262    }
263
264    fn new_leaf(key: K, i: usize) -> Self {
265        Self {
266            key,
267            data: Data::Leaf { i },
268        }
269    }
270
271    fn is_leaf(&self) -> bool {
272        matches!(self.data, Data::Leaf { .. })
273    }
274}
275
/// What a node in a Trie refers to.
///
/// An internal node records where its children are located, a leaf records where the value for
/// the full key path down to it is located.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Data {
    Internal {
        /// Index of the first child node
        child_start: usize,
        /// Number of child nodes
        n_children: usize,
    },
    Leaf {
        /// Index of the value associated with the full key-path down to this node
        i: usize,
    },
}
293
294#[derive(Debug)]
295struct BuildNode<K, V>
296where
297    K: PartialEq + Ord,
298{
299    k: K,
300    data: BuildNodeData<K, V>,
301}
302
303#[derive(Debug)]
304enum BuildNodeData<K, V>
305where
306    K: PartialEq + Ord,
307{
308    Internal(Vec<BuildNode<K, V>>),
309    Leaf(V),
310}
311
312fn insert<K, V>(
313    mut key: Vec<K>,
314    v: V,
315    current: &mut Vec<BuildNode<K, V>>,
316) -> Result<(), &'static str>
317where
318    K: PartialEq + Ord,
319{
320    for n in current.iter_mut() {
321        if key[0] == n.k {
322            if key.len() <= 1 {
323                return Err("duplicate entry for key");
324            }
325
326            key.remove(0);
327            return match &mut n.data {
328                BuildNodeData::Internal(nodes) => insert(key, v, nodes),
329                BuildNodeData::Leaf(_) => Err("attempt to insert into value node"),
330            };
331        }
332    }
333
334    let k = key.remove(0);
335
336    if key.is_empty() {
337        current.push(BuildNode {
338            k,
339            data: BuildNodeData::Leaf(v),
340        });
341    } else {
342        let mut children = vec![];
343        insert(key, v, &mut children)?;
344        current.push(BuildNode {
345            k,
346            data: BuildNodeData::Internal(children),
347        });
348    }
349
350    Ok(())
351}
352
353fn flatten<K, V>(
354    mut roots: Vec<BuildNode<K, V>>,
355    nodes: &mut Vec<Node<K>>,
356    values: &mut Vec<V>,
357) -> (usize, usize)
358where
359    K: Clone + PartialEq + Ord,
360    V: Clone,
361{
362    roots.sort_by(|l, r| l.k.cmp(&r.k));
363
364    let child_start = nodes.len();
365    let n_children = roots.len();
366
367    let mut child_stack = Vec::new();
368
369    // Insert roots first, storing any child nodes that need to be inserted later.
370    for BuildNode { k, data } in roots.into_iter() {
371        match data {
372            BuildNodeData::Internal(children) => {
373                let i = nodes.len();
374                nodes.push(Node::new_internal(k, 0, children.len()));
375                child_stack.push((i, children));
376            }
377
378            BuildNodeData::Leaf(v) => {
379                let i = values.len();
380                values.push(v);
381                nodes.push(Node::new_leaf(k, i))
382            }
383        }
384    }
385
386    // Insert the child nodes for each root node, updating their state now that we know the offsets
387    // of their children.
388    for (i, children) in child_stack.into_iter() {
389        let (start, _) = flatten(children, nodes, values);
390        match &mut nodes[i] {
391            Node {
392                data: Data::Internal { child_start, .. },
393                ..
394            } => {
395                *child_start = start;
396            }
397
398            _ => unreachable!(),
399        }
400    }
401
402    (child_start, n_children)
403}
404
/// A fallback function that maps a key of length one to an `Option<V>`.
///
/// Using one of these avoids the need to explicitly list a large number of single length keys
/// that are all handled in a similar way.
pub type DefaultMapping<K, V> = fn(&K) -> Option<V>;
410
/// The outcome of looking up a key in a [Trie].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult<'a, V> {
    /// The key maps to a leaf: this is the value stored there
    Val(&'a V),
    /// The key is a prefix of one or more longer keys that hold values
    Partial,
    /// The key is not present in the [Trie]
    Missing,
}
421
422impl<'a, V> From<Option<&'a V>> for QueryResult<'a, V> {
423    fn from(opt: Option<&'a V>) -> Self {
424        match opt {
425            Some(v) => QueryResult::Val(v),
426            None => QueryResult::Missing,
427        }
428    }
429}
430
431impl<'a, V> From<QueryResult<'a, V>> for Option<&'a V> {
432    fn from(q: QueryResult<'a, V>) -> Self {
433        match q {
434            QueryResult::Val(v) => Some(v),
435            _ => None,
436        }
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use simple_test_case::test_case;

    /// Providing the same key twice is an error.
    #[test]
    fn duplicate_keys_errors() {
        let res = Trie::try_from_iter(vec![(vec![42], 1), (vec![42], 2)]);

        assert!(res.is_err());
    }

    /// A key may not extend another key that already holds a value.
    #[test]
    fn children_under_a_value_node_errors() {
        let res = Trie::try_from_iter(vec![(vec![42], 1), (vec![42, 69], 2)]);

        assert!(res.is_err());
    }

    /// String lookups distinguish between full matches, prefixes and unknown keys.
    #[test_case("foo", QueryResult::Val(&1); "val 1")]
    #[test_case("bar", QueryResult::Val(&2); "val 2")]
    #[test_case("baz", QueryResult::Val(&3); "val 3")]
    #[test_case("ba", QueryResult::Partial; "partial 1")] // typos:ignore
    #[test_case("fo", QueryResult::Partial; "partial 2")] // typos:ignore
    #[test_case("barf", QueryResult::Missing; "overshoot")]
    #[test_case("have you any wool?", QueryResult::Missing; "fully missing")]
    #[test]
    fn get_works(key: &str, expected: QueryResult<'_, usize>) {
        let trie = Trie::from_str_keys(vec![("foo", 1), ("bar", 2), ("baz", 3)]).unwrap();
        let res = trie.get_str(key);

        assert_eq!(res, expected);
    }

    /// Exact lookups only return a value for a full match.
    #[test_case(&[42], None; "partial should be None")]
    #[test_case(&[144], None; "missing should be None")]
    #[test_case(&[42, 69, 144], None; "overshoot should be None")]
    #[test_case(&[42, 69], Some(1); "exact should be Some")]
    #[test]
    fn get_exact_works(key: &[usize], expected: Option<usize>) {
        let trie = Trie::try_from_iter(vec![(vec![42, 69], 1)]).unwrap();
        let res = trie.get_exact(key);

        assert_eq!(res, expected.as_ref());
    }

    /// Exact string lookups only return a value for a full match.
    #[test_case("fo", None; "partial")] // typos:ignore
    #[test_case("bar", None; "missing")]
    #[test_case("fool", None; "overshoot")]
    #[test_case("foo", Some(1); "found")]
    #[test]
    fn get_str_exact_works(key: &str, expected: Option<usize>) {
        let trie = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        let res = trie.get_str_exact(key);

        assert_eq!(res, expected.as_ref());
    }

    /// Merging tries with distinct keys keeps every entry from both sides.
    #[test]
    fn merge_works() {
        let left = Trie::from_str_keys(vec![("foo", 1), ("bar", 2)]).unwrap();
        let right = Trie::from_str_keys(vec![("baz", 3), ("qux", 4)]).unwrap();

        let merged = left.merge(right).unwrap();

        assert_eq!(merged.get_str_exact("foo"), Some(&1));
        assert_eq!(merged.get_str_exact("bar"), Some(&2));
        assert_eq!(merged.get_str_exact("baz"), Some(&3));
        assert_eq!(merged.get_str_exact("qux"), Some(&4));
        assert_eq!(merged.len(), 4);
    }

    /// A plain merge refuses to combine tries that share a key.
    #[test]
    fn merge_conflicts_error() {
        let left = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        let right = Trie::from_str_keys(vec![("foo", 2)]).unwrap();

        assert!(left.merge(right).is_err());
    }

    /// An overriding merge takes the value from the right hand side for shared keys.
    #[test]
    fn merge_overriding_works() {
        let left = Trie::from_str_keys(vec![("foo", 1), ("bar", 2)]).unwrap();
        let right = Trie::from_str_keys(vec![("baz", 3), ("foo", 4)]).unwrap();

        let merged = left.merge_overriding(right).unwrap();

        assert_eq!(merged.get_str_exact("foo"), Some(&4));
        assert_eq!(merged.get_str_exact("bar"), Some(&2));
        assert_eq!(merged.get_str_exact("baz"), Some(&3));
        assert_eq!(merged.len(), 3);
    }

    /// An overriding merge accepts tries that share a key.
    #[test]
    fn merge_overriding_conflicts_are_ok() {
        let left = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        let right = Trie::from_str_keys(vec![("foo", 2)]).unwrap();

        assert!(left.merge_overriding(right).is_ok());
    }
}