trie_of_lists/
lib.rs

1use std::{collections::HashMap, hash::Hash};
2
/// A single node of the trie.
///
/// Every node owns its children, indexed by the next key element, and holds a
/// value only when the path from the root to this node was inserted as a key.
#[derive(Debug)]
struct TrieNode<K, V> {
    /// Child nodes, indexed by the key element that leads to them.
    children: HashMap<K, TrieNode<K, V>>,
    /// Value stored for the key that ends at this node, if any.
    value: Option<V>,
}

// Written by hand rather than derived: `#[derive(Default)]` would demand
// `K: Default` and `V: Default`, which an empty map and `None` do not need.
impl<K, V> Default for TrieNode<K, V> {
    fn default() -> Self {
        TrieNode {
            children: HashMap::new(),
            value: None,
        }
    }
}

/// A trie (prefix tree) mapping keys that are *sequences* of `K` to values of type `V`.
///
/// Keys are passed as anything that implements `IntoIterator<Item = K>`; each
/// element of the sequence selects one level of the tree. The empty sequence is
/// a valid key and addresses the root node.
///
/// Only `K: Hash + Eq` is required to use the trie. Methods that hand out a copy
/// of a stored value (`get`, `best_match`) additionally require `V: Clone`.
#[derive(Debug)]
pub struct Trie<K, V> {
    root: TrieNode<K, V>,
}

impl<K, V> Default for Trie<K, V> {
    /// Creates an empty trie.
    fn default() -> Self {
        Trie {
            root: TrieNode::default(),
        }
    }
}

impl<K, V> Trie<K, V>
where
    K: Hash + Eq,
{
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a clone of the value stored under exactly `key`, or `None` if
    /// `key` was never inserted.
    ///
    /// Visits one node per element of `key`.
    pub fn get<I>(&self, key: I) -> Option<V>
    where
        I: IntoIterator<Item = K>,
        V: Clone,
    {
        self.traverse(key)?.value.clone()
    }

    /// Returns `true` if a value is stored under exactly `key`.
    ///
    /// A key that is merely a prefix of an inserted key does not count.
    /// Visits one node per element of `key`.
    pub fn contains<I>(&self, key: I) -> bool
    where
        I: IntoIterator<Item = K>,
    {
        self.traverse(key)
            .map_or(false, |node| node.value.is_some())
    }

    /// Finds the longest inserted key that is a prefix of `key` and returns that
    /// key together with a clone of its value.
    ///
    /// Returns `None` when no inserted key is a prefix of `key`. A value stored
    /// under the empty key is a prefix of every query and therefore acts as the
    /// fallback match.
    ///
    /// # Example
    ///
    /// Assume the trie contains these entries:
    /// - (four, score, and) -> seven
    /// - (four, score, and, seven) -> years
    ///
    /// Then:
    /// - `best_match (four, score, and, seven, years, ago)` returns
    ///   ((four, score, and, seven), `years`)
    /// - `best_match (four, score, and, twenty)` returns
    ///   ((four, score, and), `seven`)
    pub fn best_match<I>(&self, key: I) -> Option<(Vec<K>, V)>
    where
        I: IntoIterator<Item = K>,
        V: Clone,
    {
        let mut node = &self.root;
        // Every key element consumed so far, including ones past the last match.
        let mut path = Vec::new();
        // Deepest match seen so far: (number of key elements it covers, its value).
        let mut best: Option<(usize, &V)> = self.root.value.as_ref().map(|value| (0, value));

        for part in key {
            match node.children.get(&part) {
                Some(child) => {
                    node = child;
                    path.push(part);
                }
                None => break,
            }
            if let Some(value) = node.value.as_ref() {
                best = Some((path.len(), value));
            }
        }

        let (matched_len, value) = best?;
        // Drop the elements walked beyond the matched entry so the returned key
        // is the key of the returned value, not just the longest path followed.
        path.truncate(matched_len);
        Some((path, value.clone()))
    }

    /// Inserts `value` under `key`, replacing any value previously stored there.
    ///
    /// Missing intermediate nodes are created on the way down.
    pub fn insert<I>(&mut self, key: I, value: V)
    where
        I: IntoIterator<Item = K>,
    {
        let mut node = &mut self.root;
        for part in key {
            // `or_default` only builds a new node when the child is absent.
            node = node.children.entry(part).or_default();
        }
        node.value = Some(value);
    }

    /// Follows `key` from the root and returns the node it ends at, or `None`
    /// as soon as an element has no matching child.
    fn traverse<I>(&self, key: I) -> Option<&TrieNode<K, V>>
    where
        I: IntoIterator<Item = K>,
    {
        let mut node = &self.root;
        for part in key {
            node = node.children.get(&part)?;
        }
        Some(node)
    }
}
110
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use super::*;

    /// Breaks a path into its components, each lossily converted to an owned
    /// `String`, so the result can be used as a trie key.
    fn key_for(path: &str) -> Vec<String> {
        let mut parts = Vec::new();
        for component in Path::new(path).components() {
            parts.push(component.as_os_str().to_string_lossy().into_owned());
        }
        parts
    }

    /// Exercises `insert`, `contains`, `get` and `best_match` with path-shaped keys.
    #[test]
    fn test_all() {
        // Keys that get inserted.
        const ECHOS: &str = "/etc/bin/echos";
        const HELLO: &str = "/etc/bin/echo/hello.txt";
        // A strict prefix of `HELLO` that is never inserted itself.
        const ECHO: &str = "/etc/bin/echo";
        // Extends `HELLO` by one more component.
        const PAST_HELLO: &str = "/etc/bin/echo/hello.txt/jello";

        let cat = PathBuf::from("usr/cat");
        let tar = PathBuf::from("usr/tar");

        let mut trie: Trie<String, PathBuf> = Trie::new();
        trie.insert(key_for(ECHOS), cat.clone());
        trie.insert(key_for(HELLO), tar.clone());

        // Both inserted keys are present and map to their values.
        assert!(trie.contains(key_for(ECHOS)));
        assert!(trie.contains(key_for(HELLO)));
        assert_eq!(trie.get(key_for(ECHOS)), Some(cat.clone()));
        assert_eq!(trie.get(key_for(HELLO)), Some(tar.clone()));

        // An exact key is its own best match.
        let exact = trie.best_match(key_for(ECHOS));
        assert_eq!(exact, Some((key_for(ECHOS), cat.clone())));

        // A longer query falls back to the longest inserted prefix.
        let longest_prefix = trie.best_match(key_for(PAST_HELLO));
        assert_eq!(longest_prefix, Some((key_for(HELLO), tar.clone())));

        // An intermediate node without a value is not a contained key.
        assert!(!trie.contains(key_for(ECHO)));
    }
}