artsy/
lib.rs

1#![deny(warnings)]
2
3use std::mem;
4
5#[cfg(feature = "node4")]
6mod node4;
7
8#[cfg(feature = "node4")]
9use self::node4::Node4 as DefaultNode;
10
11#[cfg(feature = "node16")]
12mod node16;
13
14#[cfg(all(not(feature = "node4"), feature = "node16"))]
15use self::node16::Node16 as DefaultNode;
16
17#[cfg(feature = "node48")]
18mod node48;
19
20#[cfg(all(not(feature = "node4"), not(feature = "node16"), feature = "node48"))]
21use self::node48::Node48 as DefaultNode;
22
23// always included
24mod node256;
25
26#[cfg(all(not(feature = "node4"), not(feature = "node16"), not(feature = "node48")))]
27use self::node256::Node256 as DefaultNode;
28
/// A trie keyed by byte strings (`&[u8]`) that stores values of type `T`.
///
/// Every key byte selects one child on the way down the tree; the value that
/// belongs to a key is stored as a leaf under the reserved terminator byte
/// `term`. That is why the checked operations reject keys containing `term`.
pub struct Trie<'a, T> {
    /// Root of the tree. `None` until the first insertion; afterwards it is
    /// always a `Child::Node` (the `Child::Leaf` arms are `unreachable!()`).
    root: Option<Child<'a, T>>,
    /// Reserved byte under which the value leaf of a key is stored.
    term: u8,
}
33
/// Error returned by the checked `Trie` operations when the supplied key
/// contains the trie's reserved terminator byte.
///
/// The type implements [`std::error::Error`] so it can be boxed, wrapped or
/// propagated with `?` like any other error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyContainsTerminator;

impl std::fmt::Display for KeyContainsTerminator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("key contains the trie's terminator byte")
    }
}

impl std::error::Error for KeyContainsTerminator {}
36
37impl<'a, T> Trie<'a, T> {
38    pub fn with_terminator(term: u8) -> Trie<'a, T> {
39        Trie {
40            root: None,
41            term: term,
42        }
43    }
44
45    pub fn for_ascii() -> Trie<'a, T> {
46        Self::with_terminator(0)
47    }
48
49    pub fn for_utf8() -> Trie<'a, T> {
50        Self::with_terminator(0xff)
51    }
52
53    pub fn insert(&mut self, key: &[u8], value: T) -> Result<Option<T>, KeyContainsTerminator> {
54        if !key.contains(&self.term) {
55            Ok(self.insert_impl(key, value))
56        } else {
57            Err(KeyContainsTerminator)
58        }
59    }
60
61    pub unsafe fn insert_unchecked(&mut self, key: &[u8], value: T) -> Option<T> {
62        self.insert_impl(key, value)
63    }
64
65    fn insert_impl(&mut self, key: &[u8], value: T) -> Option<T> {
66        match self.root {
67            None => {
68                let mut node = Node::new();
69                let inserted = node.insert(key, value, self.term);
70                self.root = Some(Child::Node(node));
71                inserted
72            }
73            Some(Child::Node(ref mut node)) => node.insert(key, value, self.term),
74            Some(Child::Leaf(_))            => unreachable!(),
75        }
76    }
77
78    pub fn contains(&self, key: &[u8]) -> Result<bool, KeyContainsTerminator> {
79        if !key.contains(&self.term) {
80            Ok(self.contains_impl(key))
81        } else {
82            Err(KeyContainsTerminator)
83        }
84    }
85
86    pub unsafe fn contains_unchecked(&self, key: &[u8]) -> bool {
87        self.contains_impl(key)
88    }
89
90    fn contains_impl(&self, key: &[u8]) -> bool {
91        match self.root {
92            None                        => false,
93            Some(Child::Node(ref node)) => node.contains(key, self.term),
94            Some(Child::Leaf(_))        => unreachable!(),
95        }
96    }
97
98    pub fn get(&self, key: &[u8]) -> Result<Option<&T>, KeyContainsTerminator> {
99        if !key.contains(&self.term) {
100            Ok(self.get_impl(key))
101        } else {
102            Err(KeyContainsTerminator)
103        }
104    }
105
106    pub unsafe fn get_unchecked(&self, key: &[u8]) -> Option<&T> {
107        self.get_impl(key)
108    }
109
110    fn get_impl(&self, key: &[u8]) -> Option<&T> {
111        match self.root {
112            None                        => None,
113            Some(Child::Node(ref node)) => node.get(key, self.term),
114            Some(Child::Leaf(_))        => unreachable!(),
115        }
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.root.is_none()
120    }
121}
122
/// One inner node of the trie.
///
/// Wraps a boxed, dynamically dispatched [`NodeImpl`] so that the concrete
/// representation behind the node can be exchanged in place (see
/// `Node::upgrade`) without the parent noticing.
struct Node<'a, T: 'a>(Box<dyn NodeImpl<'a, T> + 'a>);
124
/// Storage backend of a [`Node`]: a map from a single key byte to a `Child`.
///
/// The implementors live in the `node*` modules and are not visible from this
/// file, so the notes below describe the contract as the `Node` wrapper uses
/// it.
trait NodeImpl<'a, T> {
    /// Stores `child` under `key`.
    ///
    /// `Ok(previous)` carries the child that was replaced, if any. `Err(child)`
    /// hands the child back unstored; the `Node` wrapper reacts to that by
    /// upgrading the node and retrying, so implementors presumably return it
    /// when they have no room left.
    fn insert_child(&mut self, key: u8, child: Child<'a, T>) -> Result<Option<Child<'a, T>>, Child<'a, T>>;

    /// Stores `child` under `key` without reporting a previous child.
    ///
    /// Uses the same `Err(child)` protocol as `insert_child`.
    /// NOTE(review): the behaviour for an already occupied `key` is defined by
    /// the implementors — presumably the existing child is kept; confirm.
    fn update_child(&mut self, key: u8, child: Child<'a, T>) -> Result<(), Child<'a, T>>;

    /// Returns the child stored under `key`, if there is one.
    fn find_child(&self, key: u8) -> Option<&Child<'a, T>>;

    /// Consumes this boxed implementation and returns the implementation that
    /// replaces it — presumably a larger-capacity node holding the same
    /// children (TODO confirm against the implementors).
    fn upgrade(self: Box<Self>) -> Box<dyn NodeImpl<'a, T> + 'a>;
}
134
135impl<'a, T> Node<'a, T> {
136    fn new() -> Self {
137        Node(Box::new(DefaultNode::default()))
138    }
139
140    fn insert(&mut self, key: &[u8], value: T, term: u8) -> Option<T> {
141        if key.is_empty() {
142            self.insert_child(term, Child::Leaf(value))
143                .map(|n| n.to_leaf().unwrap())
144        } else {
145            self.update_child(key[0], Child::Node(Node::new()));
146            let child = self.find_child_mut(key[0]).unwrap().as_node_mut().unwrap();
147            child.insert(&key[1..], value, term)
148        }
149    }
150
151    fn contains(&self, key: &[u8], term: u8) -> bool {
152        self.get(key, term).is_some()
153    }
154
155    fn get(&self, key: &[u8], term: u8) -> Option<&T> {
156        if key.is_empty() {
157            self.find_child(term)
158                .map(|n| n.as_leaf().unwrap())
159        } else {
160            self.find_child(key[0])
161                .and_then(|n| n.as_node())
162                .and_then(|node| node.get(&key[1..], term))
163        }
164    }
165
166    fn insert_child(&mut self, key: u8, child: Child<'a, T>) -> Option<Child<'a, T>> {
167        let result = self.0.insert_child(key, child);
168        match result {
169            Ok(replaced_child) => replaced_child,
170            Err(child)         => {
171                self.upgrade();
172                self.insert_child(key, child)
173            }
174        }
175    }
176
177    fn update_child(&mut self, key: u8, child: Child<'a, T>) {
178        let result = self.0.update_child(key, child);
179        if let Err(child) = result {
180            self.upgrade();
181            self.update_child(key, child)
182        }
183    }
184
185    fn find_child(&self, key: u8) -> Option<&Child<'a, T>> {
186        self.0.find_child(key)
187    }
188
189    fn upgrade(&mut self) {
190        take_mut::take(&mut self.0, NodeImpl::upgrade);
191    }
192
193    fn find_child_mut(&mut self, key: u8) -> Option<&mut Child<'_, T>> {
194        unsafe { mem::transmute(self.find_child(key)) }
195    }
196}
197
/// What a node stores under a single key byte.
enum Child<'a, T: 'a> {
    /// A nested inner node; reached by consuming one more byte of the key.
    Node(Node<'a, T>),
    /// A stored value; `Node::insert` places it under the terminator byte.
    Leaf(T),
}
202
203impl<'a, T> Child<'a, T> {
204    fn as_node(&self) -> Option<&Node<'a, T>> {
205        if let Child::Node(ref node) = self {
206            Some(node)
207        } else {
208            None
209        }
210    }
211
212    fn as_node_mut(&mut self) -> Option<&mut Node<'a, T>> {
213        if let Child::Node(ref mut node) = self {
214            Some(node)
215        } else {
216            None
217        }
218    }
219
220    fn as_leaf(&self) -> Option<&T> {
221        if let Child::Leaf(ref value) = self {
222            Some(value)
223        } else {
224            None
225        }
226    }
227
228    fn to_leaf(self) -> Option<T> {
229        if let Child::Leaf(value) = self {
230            Some(value)
231        } else {
232            None
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Assertion helpers bolted onto `Trie` for the tests below.
    trait TrieTestExtensions<T: Clone + PartialEq + Debug> {
        /// Inserts `value` under `key` and immediately verifies it can be read
        /// back.
        fn check_insertion(&mut self, key: &[u8], value: T);

        /// Asserts that `key` currently maps to `value`.
        fn check_existence(&mut self, key: &[u8], value: T);
    }

    impl<'a, T: 'a + Clone + PartialEq + Debug> TrieTestExtensions<T> for Trie<'a, T> {
        fn check_insertion(&mut self, key: &[u8], value: T) {
            let stored = value.clone();
            self.insert(key, stored).unwrap();
            self.check_existence(key, value);
        }

        fn check_existence(&mut self, key: &[u8], value: T) {
            let found = self.get(key).unwrap();
            assert_eq!(found, Some(&value));
        }
    }

    /// Inserts every byte of `keys` as a one-byte key (with the values
    /// 1, 2, 3, ... in order), then re-verifies all of them once the trie is
    /// fully populated.
    fn check_single_byte_keys(keys: &[u8]) {
        let mut trie = Trie::for_utf8();
        // 1) insert
        for (&byte, value) in keys.iter().zip(1..) {
            trie.check_insertion(&[byte], value);
        }
        // 2) verify
        for (&byte, value) in keys.iter().zip(1..) {
            trie.check_existence(&[byte], value);
        }
    }

    #[test]
    fn test_readme_insert_lookup_example() {
        let mut map = Trie::for_utf8();
        map.insert(b"a", 0).unwrap();
        map.insert(b"ac", 1).unwrap();

        let lookup_a = map.get(b"a").unwrap();
        let lookup_ac = map.get(b"ac").unwrap();
        let lookup_ab = map.get(b"ab").unwrap();

        assert_eq!(lookup_a, Some(&0));
        assert_eq!(lookup_ac, Some(&1));
        assert_eq!(lookup_ab, None);
    }

    #[test]
    fn it_works() {
        let mut trie = Trie::for_utf8();
        trie.check_insertion(b"the answer", 42);
    }

    #[test]
    fn it_works_for_empty_strings() {
        let mut trie = Trie::for_utf8();
        trie.check_insertion(b"", 1);
    }

    #[test]
    fn it_is_empty_by_default() {
        let trie = Trie::<()>::for_utf8();
        assert!(trie.is_empty());
    }

    #[test]
    fn it_doesnt_overwrite_entries_with_a_common_prefix() {
        let mut trie = Trie::for_utf8();
        trie.insert(b"a", 1).unwrap();
        trie.insert(b"ab", 2).unwrap();

        let short = trie.get(b"a").unwrap();
        let long = trie.get(b"ab").unwrap();

        assert_eq!(short, Some(&1));
        assert_eq!(long, Some(&2));
    }

    #[test]
    fn it_can_store_more_than_4_parallel_entries() {
        // Five siblings under the root: "a" through "e".
        check_single_byte_keys(b"abcde");
    }

    #[test]
    fn it_can_store_more_than_16_parallel_entries() {
        // Seventeen siblings under the root: "a", then "c" through "r".
        check_single_byte_keys(b"acdefghijklmnopqr");
    }
}