1#![deny(missing_docs)]
23
24#![feature(box_patterns)]
25#![feature(test)]
26
27#[cfg(test)]
28#[macro_use]
29extern crate quickcheck;
30
31#[cfg(test)]
32extern crate test;
33
34#[cfg(test)]
35extern crate tempfile;
36
37extern crate serde;
38#[macro_use]
39extern crate serde_derive;
40extern crate bincode;
41
42use std::fs::OpenOptions;
43use std::io::{self, BufReader, BufWriter};
44use std::collections::HashMap;
45
46use bincode::{serialize_into, deserialize_from, Infinite};
47
48pub use bincode::Error as BincodeError;
49
50type Result<T> = std::result::Result<T, BincodeError>;
51
52#[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
54pub struct Trie<V> {
55 key: Option<char>,
56 children: HashMap<Option<char>, Trie<V>>,
57 contents: Option<V>,
58}
59
60impl<V> Trie<V> {
61 pub fn load_from_file(path: &str) -> Result<Self>
64 where
65 for<'de> V: serde::Serialize + serde::Deserialize<'de>,
66 {
67 let f = OpenOptions::new()
68 .read(true)
69 .write(true)
70 .create(true)
71 .open(path)
72 .expect("Couldn't open trie file");
73 let mut br = BufReader::new(f);
74 match deserialize_from(&mut br, Infinite) {
75 Ok(x) => Ok(x),
76 Err(box bincode::ErrorKind::IoError(e)) => {
77 if e.kind() == io::ErrorKind::UnexpectedEof {
78 return Ok(Trie {
79 key: None,
80 children: HashMap::new(),
81 contents: None,
82 });
83 }
84 Err(Box::new(bincode::ErrorKind::IoError(e)))
85 }
86 Err(e) => Err(e),
87 }
88 }
89
90 pub fn insert(&mut self, key: &str, contents: V) -> Option<V> {
93 let mut chars = key.chars();
94 let mut key_i_need = chars.next();
95 if self.key == key_i_need {
96 if chars.size_hint().0 == 0 {
97 let ret = self.contents.take();
98 self.contents = Some(contents);
99 return ret;
100 }
101 key_i_need = chars.next();
102 }
103 if let Some(c) = self.children.get_mut(&key_i_need) {
104 return c.insert(chars.as_str(), contents);
105 }
106 let mut trie = Trie {
107 key: key_i_need,
108 children: HashMap::new(),
109 contents: None,
110 };
111 trie.insert(chars.as_str(), contents);
112 self.children.insert(key_i_need, trie);
113 None
114 }
115
116 pub fn get(&self, key: &str) -> Option<&V> {
118 let mut chars = key.chars();
119 let mut key_i_need = chars.next();
120 if self.key == key_i_need {
121 if chars.size_hint().0 == 0 {
122 return self.contents.as_ref();
123 }
124 key_i_need = chars.next();
125 }
126 if let Some(c) = self.children.get(&key_i_need) {
127 return c.get(chars.as_str());
128 }
129 None
130 }
131
132 pub fn get_mut(&mut self, key: &str) -> Option<&mut V> {
134 let mut chars = key.chars();
135 let mut key_i_need = chars.next();
136 if self.key == key_i_need {
137 if chars.size_hint().0 == 0 {
138 return self.contents.as_mut();
139 }
140 key_i_need = chars.next();
141 }
142 if let Some(c) = self.children.get_mut(&key_i_need) {
143 return c.get_mut(chars.as_str());
144 }
145 None
146 }
147
148 pub fn save_to_file(&mut self, path: &str) -> Result<()>
150 where
151 for<'de> V: serde::Serialize + serde::Deserialize<'de>,
152 {
153 let f = OpenOptions::new()
154 .read(true)
155 .write(true)
156 .create(true)
157 .open(path)
158 .expect("Couldn't open trie file");
159 let mut bw = BufWriter::new(f);
160 serialize_into(&mut bw, self, Infinite)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;
    use std::collections::BTreeMap;
    use quickcheck::TestResult;
    use test::Bencher;

    /// Inserts every pair into both a `Trie` and a `BTreeMap` and checks that
    /// the two agree on each `insert` return value and on each later lookup.
    ///
    /// With `replace` set, every key is stored with the same fixed value
    /// instead of the value supplied in the pair.
    fn insertion_test_helper(v: Vec<(String, String)>, replace: bool) -> TestResult {
        let pairs: Vec<(&String, &str)> = v.iter()
            .map(|&(ref name, ref value)| {
                let stored: &str = if replace { "this_will_be_replaced" } else { value };
                (name, stored)
            })
            .collect();
        let mut trie = Trie::default();
        let mut reference = BTreeMap::new();
        for &(name, stored) in &pairs {
            assert_eq!(trie.insert(name, stored), reference.insert(name, stored));
        }
        for &(name, _) in &pairs {
            assert_eq!(trie.get(name), reference.get(name));
        }
        TestResult::from_bool(true)
    }

    /// Three fixed keys, two of which share a prefix.
    #[test]
    fn basic_insertion() {
        let fixtures = [
            ("def", "contents1"),
            ("abc", "contents2"),
            ("abf", "contents3"),
        ];
        let testcases: Vec<(String, String)> = fixtures
            .iter()
            .map(|&(name, value)| (String::from(name), String::from(value)))
            .collect();
        insertion_test_helper(testcases, false);
    }

    /// A value changed through `get_mut` is visible through `get`.
    #[test]
    fn test_get_mut() {
        let mut trie = Trie::default();
        trie.insert("abc", vec!["test1"]);

        trie.get_mut("abc").unwrap().push("test2");

        assert_eq!(*trie.get("abc").unwrap(), vec!["test1", "test2"]);
    }

    quickcheck! {
        fn random_insertion(pairs: Vec<(String, String)>) -> TestResult {
            insertion_test_helper(pairs, false)
        }
        fn replace_insertion(pairs: Vec<(String, String)>) -> TestResult {
            insertion_test_helper(pairs, true)
        }
    }

    /// A saved trie loads back equal to the one that was written.
    #[test]
    fn save_to_file_roundtrip() {
        let trie_file = NamedTempFile::new().expect("failed to create temporary file");
        let trie_file_name = trie_file.path().to_str().unwrap();

        let mut trie = Trie::default();
        let entries = [
            ("abc", "contents1"),
            ("abd", "contents2"),
            ("hello", "world"),
        ];
        for &(name, value) in &entries {
            trie.insert(name, String::from(value));
        }

        trie.save_to_file(trie_file_name)
            .expect("Couldn't save trie to file");
        let trie2 = Trie::load_from_file(trie_file_name).expect("Couldn't load trie from file");
        assert_eq!(trie, trie2);
    }

    /// Loading from an empty file yields an empty trie.
    #[test]
    fn load_from_empty_file() {
        let trie_file = NamedTempFile::new().expect("failed to create temporary file");
        let trie_file_name = trie_file.path().to_str().unwrap();

        let loaded: Trie<Trie<String>> =
            Trie::load_from_file(trie_file_name).expect("Couldn't load trie from file");
        assert_eq!(loaded.key, None);
        assert_eq!(loaded.children.len(), 0);
        assert_eq!(loaded.contents, None);
    }

    /// Looks up one key in a trie whose top level has 128 children, one per
    /// single-byte character.
    #[bench]
    fn bench_many_children(b: &mut Bencher) {
        let child_count: u8 = 128;
        let mut trie = Trie::default();
        for byte in 0..child_count {
            let name = String::from_utf8(vec![byte]).unwrap();
            trie.insert(&name, "");
        }
        assert_eq!(child_count as usize, trie.children.len());

        let last_child = String::from_utf8(vec![child_count - 1]).unwrap();
        b.iter(|| trie.get(&last_child));
    }
}