1#![deny(warnings)]
2
3use std::mem;
4
5#[cfg(feature = "node4")]
6mod node4;
7
8#[cfg(feature = "node4")]
9use self::node4::Node4 as DefaultNode;
10
11#[cfg(feature = "node16")]
12mod node16;
13
14#[cfg(all(not(feature = "node4"), feature = "node16"))]
15use self::node16::Node16 as DefaultNode;
16
17#[cfg(feature = "node48")]
18mod node48;
19
20#[cfg(all(not(feature = "node4"), not(feature = "node16"), feature = "node48"))]
21use self::node48::Node48 as DefaultNode;
22
23mod node256;
25
26#[cfg(all(not(feature = "node4"), not(feature = "node16"), not(feature = "node48")))]
27use self::node256::Node256 as DefaultNode;
28
29pub struct Trie<'a, T> {
30 root: Option<Child<'a, T>>,
31 term: u8,
32}
33
/// Error returned by the checked [`Trie`] methods when a key contains the trie's terminator
/// byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KeyContainsTerminator;

impl std::fmt::Display for KeyContainsTerminator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("key contains the trie's terminator byte")
    }
}

// No underlying cause to expose, so the default `source()` (returning `None`) is kept.
impl std::error::Error for KeyContainsTerminator {}
36
37impl<'a, T> Trie<'a, T> {
38 pub fn with_terminator(term: u8) -> Trie<'a, T> {
39 Trie {
40 root: None,
41 term: term,
42 }
43 }
44
45 pub fn for_ascii() -> Trie<'a, T> {
46 Self::with_terminator(0)
47 }
48
49 pub fn for_utf8() -> Trie<'a, T> {
50 Self::with_terminator(0xff)
51 }
52
53 pub fn insert(&mut self, key: &[u8], value: T) -> Result<Option<T>, KeyContainsTerminator> {
54 if !key.contains(&self.term) {
55 Ok(self.insert_impl(key, value))
56 } else {
57 Err(KeyContainsTerminator)
58 }
59 }
60
61 pub unsafe fn insert_unchecked(&mut self, key: &[u8], value: T) -> Option<T> {
62 self.insert_impl(key, value)
63 }
64
65 fn insert_impl(&mut self, key: &[u8], value: T) -> Option<T> {
66 match self.root {
67 None => {
68 let mut node = Node::new();
69 let inserted = node.insert(key, value, self.term);
70 self.root = Some(Child::Node(node));
71 inserted
72 }
73 Some(Child::Node(ref mut node)) => node.insert(key, value, self.term),
74 Some(Child::Leaf(_)) => unreachable!(),
75 }
76 }
77
78 pub fn contains(&self, key: &[u8]) -> Result<bool, KeyContainsTerminator> {
79 if !key.contains(&self.term) {
80 Ok(self.contains_impl(key))
81 } else {
82 Err(KeyContainsTerminator)
83 }
84 }
85
86 pub unsafe fn contains_unchecked(&self, key: &[u8]) -> bool {
87 self.contains_impl(key)
88 }
89
90 fn contains_impl(&self, key: &[u8]) -> bool {
91 match self.root {
92 None => false,
93 Some(Child::Node(ref node)) => node.contains(key, self.term),
94 Some(Child::Leaf(_)) => unreachable!(),
95 }
96 }
97
98 pub fn get(&self, key: &[u8]) -> Result<Option<&T>, KeyContainsTerminator> {
99 if !key.contains(&self.term) {
100 Ok(self.get_impl(key))
101 } else {
102 Err(KeyContainsTerminator)
103 }
104 }
105
106 pub unsafe fn get_unchecked(&self, key: &[u8]) -> Option<&T> {
107 self.get_impl(key)
108 }
109
110 fn get_impl(&self, key: &[u8]) -> Option<&T> {
111 match self.root {
112 None => None,
113 Some(Child::Node(ref node)) => node.get(key, self.term),
114 Some(Child::Leaf(_)) => unreachable!(),
115 }
116 }
117
118 pub fn is_empty(&self) -> bool {
119 self.root.is_none()
120 }
121}
122
123struct Node<'a, T: 'a>(Box<dyn NodeImpl<'a, T> + 'a>);
124
/// Storage behind a `Node`: maps a single key byte to a `Child`.
///
/// Implementations live in the `node4`, `node16`, `node48` and `node256` modules; which one a
/// fresh `Node` starts as is chosen by cargo features (see the `DefaultNode` aliases). When an
/// implementation hands a child back as `Err`, `Node` calls `upgrade` and retries the same
/// call on the upgraded storage.
trait NodeImpl<'a, T> {
    /// Stores `child` under `key`.
    ///
    /// `Ok(Some(old))` carries the child that was displaced from `key` (the trie reports it as
    /// the previous value), `Ok(None)` means the slot was vacant. `Err(child)` returns the
    /// child untouched when this storage cannot accept it — presumably because it is full;
    /// NOTE(review): confirm against the implementations.
    fn insert_child(&mut self, key: u8, child: Child<'a, T>) -> Result<Option<Child<'a, T>>, Child<'a, T>>;

    /// Stores `child` under `key` if the slot is vacant.
    ///
    /// Callers rely on `child` being stored when `key` is vacant. NOTE(review): the
    /// `it_doesnt_overwrite_entries_with_a_common_prefix` test implies an existing child is
    /// kept rather than replaced — confirm in the implementations. `Err(child)` returns the
    /// child untouched when this storage cannot accept it.
    fn update_child(&mut self, key: u8, child: Child<'a, T>) -> Result<(), Child<'a, T>>;

    /// Returns the child stored under `key`, if any.
    fn find_child(&self, key: u8) -> Option<&Child<'a, T>>;

    /// Consumes this storage and returns a replacement holding the same children, presumably
    /// a larger variant; called after an `Err` from `insert_child` or `update_child`.
    fn upgrade(self: Box<Self>) -> Box<dyn NodeImpl<'a, T> + 'a>;
}
134
135impl<'a, T> Node<'a, T> {
136 fn new() -> Self {
137 Node(Box::new(DefaultNode::default()))
138 }
139
140 fn insert(&mut self, key: &[u8], value: T, term: u8) -> Option<T> {
141 if key.is_empty() {
142 self.insert_child(term, Child::Leaf(value))
143 .map(|n| n.to_leaf().unwrap())
144 } else {
145 self.update_child(key[0], Child::Node(Node::new()));
146 let child = self.find_child_mut(key[0]).unwrap().as_node_mut().unwrap();
147 child.insert(&key[1..], value, term)
148 }
149 }
150
151 fn contains(&self, key: &[u8], term: u8) -> bool {
152 self.get(key, term).is_some()
153 }
154
155 fn get(&self, key: &[u8], term: u8) -> Option<&T> {
156 if key.is_empty() {
157 self.find_child(term)
158 .map(|n| n.as_leaf().unwrap())
159 } else {
160 self.find_child(key[0])
161 .and_then(|n| n.as_node())
162 .and_then(|node| node.get(&key[1..], term))
163 }
164 }
165
166 fn insert_child(&mut self, key: u8, child: Child<'a, T>) -> Option<Child<'a, T>> {
167 let result = self.0.insert_child(key, child);
168 match result {
169 Ok(replaced_child) => replaced_child,
170 Err(child) => {
171 self.upgrade();
172 self.insert_child(key, child)
173 }
174 }
175 }
176
177 fn update_child(&mut self, key: u8, child: Child<'a, T>) {
178 let result = self.0.update_child(key, child);
179 if let Err(child) = result {
180 self.upgrade();
181 self.update_child(key, child)
182 }
183 }
184
185 fn find_child(&self, key: u8) -> Option<&Child<'a, T>> {
186 self.0.find_child(key)
187 }
188
189 fn upgrade(&mut self) {
190 take_mut::take(&mut self.0, NodeImpl::upgrade);
191 }
192
193 fn find_child_mut(&mut self, key: u8) -> Option<&mut Child<'_, T>> {
194 unsafe { mem::transmute(self.find_child(key)) }
195 }
196}
197
198enum Child<'a, T: 'a> {
199 Node(Node<'a, T>),
200 Leaf(T),
201}
202
203impl<'a, T> Child<'a, T> {
204 fn as_node(&self) -> Option<&Node<'a, T>> {
205 if let Child::Node(ref node) = self {
206 Some(node)
207 } else {
208 None
209 }
210 }
211
212 fn as_node_mut(&mut self) -> Option<&mut Node<'a, T>> {
213 if let Child::Node(ref mut node) = self {
214 Some(node)
215 } else {
216 None
217 }
218 }
219
220 fn as_leaf(&self) -> Option<&T> {
221 if let Child::Leaf(ref value) = self {
222 Some(value)
223 } else {
224 None
225 }
226 }
227
228 fn to_leaf(self) -> Option<T> {
229 if let Child::Leaf(value) = self {
230 Some(value)
231 } else {
232 None
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    #[test]
    fn test_readme_insert_lookup_example() {
        let mut map = Trie::for_utf8();
        map.insert(b"a", 0).unwrap();
        map.insert(b"ac", 1).unwrap();

        assert_eq!(map.get(b"a").unwrap(), Some(&0));
        assert_eq!(map.get(b"ac").unwrap(), Some(&1));
        assert_eq!(map.get(b"ab").unwrap(), None);
    }

    /// Assertion helpers for exercising a `Trie` through its checked API.
    trait TrieTestExtensions<T: Clone + PartialEq + Debug> {
        /// Inserts `value` under `key`, then asserts it can be read back.
        fn check_insertion(&mut self, key: &[u8], value: T);

        /// Asserts that `key` currently maps to `value`.
        fn check_existence(&self, key: &[u8], value: T);
    }

    impl<'a, T: 'a + Clone + PartialEq + Debug> TrieTestExtensions<T> for Trie<'a, T> {
        fn check_insertion(&mut self, key: &[u8], value: T) {
            self.insert(key, value.clone()).unwrap();
            self.check_existence(key, value);
        }

        fn check_existence(&self, key: &[u8], value: T) {
            assert_eq!(self.get(key).unwrap(), Some(&value));
        }
    }

    /// Inserts one single-byte key per element of `keys` (value = its index) into a fresh
    /// trie, then re-checks every entry once all are inserted, so that entries survive the
    /// root node growing as it fills up.
    fn check_parallel_entries(keys: &[u8]) {
        let mut trie = Trie::for_utf8();
        for (value, &key) in keys.iter().enumerate() {
            trie.check_insertion(&[key], value);
        }
        for (value, &key) in keys.iter().enumerate() {
            trie.check_existence(&[key], value);
        }
    }

    #[test]
    fn it_works() {
        let mut trie = Trie::for_utf8();
        trie.check_insertion(b"the answer", 42);
    }

    #[test]
    fn it_works_for_empty_strings() {
        let mut trie = Trie::for_utf8();
        trie.check_insertion(b"", 1);
    }

    #[test]
    fn it_is_empty_by_default() {
        let trie = Trie::<()>::for_utf8();
        assert!(trie.is_empty());
    }

    #[test]
    fn it_doesnt_overwrite_entries_with_a_common_prefix() {
        let mut trie = Trie::for_utf8();
        trie.insert(b"a", 1).unwrap();
        trie.insert(b"ab", 2).unwrap();
        assert_eq!(trie.get(b"a").unwrap(), Some(&1));
        assert_eq!(trie.get(b"ab").unwrap(), Some(&2));
    }

    #[test]
    fn it_can_store_more_than_4_parallel_entries() {
        check_parallel_entries(b"abcde");
    }

    #[test]
    fn it_can_store_more_than_16_parallel_entries() {
        check_parallel_entries(b"acdefghijklmnopqr");
    }

    #[test]
    fn it_can_store_every_non_terminator_byte_in_parallel() {
        // 0x00..=0xFE: every byte except the UTF-8 trie's 0xFF terminator.
        let keys: Vec<u8> = (0..0xff).collect();
        check_parallel_entries(&keys);
    }

    #[test]
    fn it_returns_the_previous_value_on_overwrite() {
        let mut trie = Trie::for_utf8();
        assert_eq!(trie.insert(b"key", 1).unwrap(), None);
        assert_eq!(trie.insert(b"key", 2).unwrap(), Some(1));
        trie.check_existence(b"key", 2);
    }

    #[test]
    fn it_rejects_keys_containing_the_terminator() {
        let mut utf8 = Trie::for_utf8();
        assert!(utf8.insert(b"a\xffb", 1).is_err());
        assert!(utf8.is_empty());
        assert!(utf8.get(b"\xff").is_err());
        assert!(utf8.contains(b"\xff").is_err());

        let mut ascii = Trie::for_ascii();
        assert!(ascii.insert(b"a\0", 1).is_err());
        assert!(ascii.is_empty());
    }

    #[test]
    fn contains_reports_only_inserted_keys() {
        let mut trie = Trie::for_utf8();
        trie.insert(b"ab", ()).unwrap();
        assert!(!trie.is_empty());
        assert!(trie.contains(b"ab").unwrap());
        assert!(!trie.contains(b"a").unwrap());
        assert!(!trie.contains(b"abc").unwrap());
    }

    #[test]
    fn unchecked_methods_agree_with_checked_ones_for_valid_keys() {
        let mut trie = Trie::for_ascii();
        // SAFETY: the key does not contain the NUL terminator used by `for_ascii`.
        unsafe {
            assert_eq!(trie.insert_unchecked(b"abc", 7), None);
            assert!(trie.contains_unchecked(b"abc"));
            assert_eq!(trie.get_unchecked(b"abc"), Some(&7));
        }
        assert_eq!(trie.get(b"abc").unwrap(), Some(&7));
    }
}