1#[cfg(feature = "serde")]
2extern crate serde;
3pub mod metrics;
4
5use std::{
6 borrow::Borrow,
7 collections::VecDeque,
8 fmt::{Debug, Formatter, Result as FmtResult},
9 iter::Extend,
10};
11
12#[cfg(feature = "enable-fnv")]
13extern crate fnv;
14#[cfg(feature = "enable-fnv")]
15use fnv::FnvHashMap;
16
17#[cfg(not(feature = "enable-fnv"))]
18use std::collections::HashMap;
19
/// A distance function over keys of type `K`, used to arrange and search a [`BKTree`].
///
/// The tree relies on the usual metric axioms:
///
/// * `distance(a, b) == 0` exactly when `a` and `b` are the same key (`BKTree::add` drops a
///   key whose distance to an existing node is 0).
/// * Symmetry: `add` measures `distance(node, key)` while `find` measures
///   `threshold_distance(query, node, ..)`, so both orders must agree.
/// * The triangle inequality: `find` discards whole subtrees using it, so search results are
///   only complete if it holds.
pub trait Metric<K: ?Sized> {
    /// Returns the distance between `a` and `b`.
    fn distance(&self, a: &K, b: &K) -> u32;

    /// Returns `Some(d)` with `d == distance(a, b)` when `d <= threshold`, and `None` otherwise.
    ///
    /// Searches call this instead of `distance` so that an implementation can stop early once
    /// the threshold is known to be exceeded. Any `Some` value must be the exact distance,
    /// because it is reported back to the caller of `find`.
    ///
    /// The default computes the full distance with [`Metric::distance`] and compares it, so a
    /// simple metric only has to implement `distance`.
    fn threshold_distance(&self, a: &K, b: &K, threshold: u32) -> Option<u32> {
        let dist = self.distance(a, b);
        if dist <= threshold {
            Some(dist)
        } else {
            None
        }
    }
}
34
/// Map from edge label (a child's distance to its parent's key) to the child subtree.
/// Backed by `FnvHashMap` because the `enable-fnv` feature is on.
#[cfg(feature = "enable-fnv")]
type Children<K> = FnvHashMap<u32, BKNode<K>>;

/// Map from edge label (a child's distance to its parent's key) to the child subtree.
/// Backed by the standard `HashMap` because the `enable-fnv` feature is off.
#[cfg(not(feature = "enable-fnv"))]
type Children<K> = HashMap<u32, BKNode<K>>;

/// One node of a BK-tree: a stored key plus the subtrees hanging below it, each reached by an
/// edge labelled with that child's distance to this node's key.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct BKNode<K> {
    /// The key stored at this node.
    key: K,
    /// Child subtrees, keyed by their distance from `key`.
    children: Children<K>,
    /// Largest edge label present in `children`; `None` while the node has no children.
    max_child_distance: Option<u32>,
}
48
49impl<K> BKNode<K> {
50 pub fn new(key: K) -> BKNode<K> {
52 BKNode {
53 key,
54 #[cfg(feature = "enable-fnv")]
55 children: fnv::FnvHashMap::default(),
56 #[cfg(not(feature = "enable-fnv"))]
57 children: HashMap::default(),
58 max_child_distance: None,
59 }
60 }
61
62 pub fn add_child(&mut self, distance: u32, key: K) {
78 self.children.insert(distance, BKNode::new(key));
79 self.max_child_distance = self.max_child_distance.max(Some(distance));
80 }
81}
82
83impl<K> Debug for BKNode<K>
84where
85 K: Debug,
86{
87 fn fmt(&self, f: &mut Formatter) -> FmtResult {
88 f.debug_map().entry(&self.key, &self.children).finish()
89 }
90}
91
92#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct BKTree<K, M = metrics::Levenshtein> {
95 root: Option<BKNode<K>>,
97 metric: M,
100}
101
102impl<K, M> BKTree<K, M>
103where
104 M: Metric<K>,
105{
106 pub fn new(metric: M) -> BKTree<K, M> {
123 BKTree { root: None, metric }
124 }
125
126 pub fn add(&mut self, key: K) {
147 match self.root {
148 Some(ref mut root) => {
149 let mut cur_node = root;
150 let mut cur_dist = self.metric.distance(&cur_node.key, &key);
151 while cur_node.children.contains_key(&cur_dist) && cur_dist > 0 {
152 let current = cur_node;
156 let next_node = current.children.get_mut(&cur_dist).unwrap();
157
158 cur_node = next_node;
159 cur_dist = self.metric.distance(&cur_node.key, &key);
160 }
161 if cur_dist > 0 {
163 cur_node.add_child(cur_dist, key);
164 }
165 }
166 None => {
167 self.root = Some(BKNode::new(key));
168 }
169 }
170 }
171
172 pub fn find<'a, 'q, Q: ?Sized>(&'a self, key: &'q Q, tolerance: u32) -> Find<'a, 'q, K, Q, M>
198 where
199 K: Borrow<Q>,
200 M: Metric<Q>,
201 {
202 let candidates = if let Some(root) = &self.root {
203 VecDeque::from(vec![root])
204 } else {
205 VecDeque::new()
206 };
207 Find {
208 candidates,
209 tolerance,
210 metric: &self.metric,
211 key,
212 }
213 }
214
215 pub fn find_exact<Q: ?Sized>(&self, key: &Q) -> Option<&K>
234 where
235 K: Borrow<Q>,
236 M: Metric<Q>,
237 {
238 self.find(key, 0).next().map(|(_, found_key)| found_key)
239 }
240}
241
242impl<K, M: Metric<K>> Extend<K> for BKTree<K, M> {
243 fn extend<I: IntoIterator<Item = K>>(&mut self, keys: I) {
258 for key in keys {
259 self.add(key);
260 }
261 }
262}
263
264impl<K: AsRef<str>> Default for BKTree<K> {
265 fn default() -> BKTree<K> {
266 BKTree::new(metrics::Levenshtein)
267 }
268}
269
270pub struct Find<'a, 'q, K: 'a, Q: 'q + ?Sized, M: 'a> {
272 candidates: VecDeque<&'a BKNode<K>>,
275 tolerance: u32,
276 metric: &'a M,
277 key: &'q Q,
278}
279
280impl<'a, 'q, K, Q: ?Sized, M> Iterator for Find<'a, 'q, K, Q, M>
281where
282 K: Borrow<Q>,
283 M: Metric<Q>,
284{
285 type Item = (u32, &'a K);
286
287 fn next(&mut self) -> Option<(u32, &'a K)> {
288 while let Some(current) = self.candidates.pop_front() {
289 let BKNode {
290 key,
291 children,
292 max_child_distance,
293 } = current;
294 let distance_cutoff = max_child_distance.unwrap_or(0) + self.tolerance;
295 let cur_dist = self.metric.threshold_distance(
296 self.key,
297 current.key.borrow() as &Q,
298 distance_cutoff,
299 );
300 if let Some(dist) = cur_dist {
301 let min_dist = dist.saturating_sub(self.tolerance);
303 let max_dist = dist.saturating_add(self.tolerance);
304 for (dist, child_node) in &mut children.iter() {
305 if min_dist <= *dist && *dist <= max_dist {
306 self.candidates.push_back(child_node);
307 }
308 }
309 if dist <= self.tolerance {
311 return Some((dist, &key));
312 }
313 }
314 }
315 None
316 }
317}
318
#[cfg(test)]
mod tests {
    extern crate bincode;

    use super::{BKNode, BKTree};
    use std::fmt::Debug;

    /// Words inserted, in this order, by `sample_tree`.
    const WORDS: &[&str] = &["book", "books", "cake", "boo", "cape", "boon", "cook", "cart"];

    /// Asserts that `left` yields exactly the `(distance, key)` pairs in `right`, in any order.
    fn assert_eq_sorted<'t, T: 't, I>(left: I, right: &[(u32, T)])
    where
        T: Ord + Debug,
        I: Iterator<Item = (u32, &'t T)>,
    {
        let mut left_mut: Vec<_> = left.collect();
        let mut right_mut: Vec<_> = right.iter().map(|&(dist, ref key)| (dist, key)).collect();

        left_mut.sort();
        right_mut.sort();

        assert_eq!(left_mut, right_mut);
    }

    /// Builds a default (Levenshtein) tree by adding each of `WORDS` one at a time.
    fn sample_tree() -> BKTree<&'static str> {
        let mut tree: BKTree<&'static str> = Default::default();
        for &word in WORDS {
            tree.add(word);
        }
        tree
    }

    #[test]
    fn node_construct() {
        let node: BKNode<&str> = BKNode::new("foo");
        assert_eq!(node.key, "foo");
        assert!(node.children.is_empty());
    }

    #[test]
    fn tree_construct() {
        let tree: BKTree<&str> = Default::default();
        assert!(tree.root.is_none());
    }

    #[test]
    fn tree_add() {
        let mut tree: BKTree<&str> = Default::default();
        tree.add("foo");
        assert_eq!(tree.root.as_ref().expect("first add sets the root").key, "foo");

        tree.add("fop");
        tree.add("f\u{e9}\u{e9}");
        let root = tree.root.as_ref().expect("first add sets the root");
        assert_eq!(root.children.get(&1).unwrap().key, "fop");
        assert_eq!(root.children.get(&4).unwrap().key, "f\u{e9}\u{e9}");
    }

    #[test]
    fn tree_extend() {
        let mut tree: BKTree<&str> = Default::default();
        tree.extend(vec!["foo", "fop"]);
        let root = tree.root.expect("extending with keys sets the root");
        assert_eq!(root.key, "foo");
        assert_eq!(root.children.get(&1).unwrap().key, "fop");
    }

    #[test]
    fn tree_find() {
        let tree = sample_tree();
        assert_eq_sorted(tree.find("caqe", 1), &[(1, "cake"), (1, "cape")]);
        assert_eq_sorted(tree.find("cape", 1), &[(1, "cake"), (0, "cape")]);
        assert_eq_sorted(
            tree.find("book", 1),
            &[
                (0, "book"),
                (1, "books"),
                (1, "boo"),
                (1, "boon"),
                (1, "cook"),
            ],
        );
        assert_eq_sorted(tree.find("book", 0), &[(0, "book")]);
        assert!(tree.find("foobar", 1).next().is_none());
    }

    #[test]
    fn tree_find_exact() {
        let tree = sample_tree();
        assert_eq!(tree.find_exact("caqe"), None);
        assert_eq!(tree.find_exact("cape"), Some(&"cape"));
        assert_eq!(tree.find_exact("book"), Some(&"book"));
    }

    #[test]
    fn one_node_tree() {
        let mut tree: BKTree<&str> = Default::default();
        tree.add("book");
        tree.add("book");
        assert_eq!(tree.root.unwrap().children.len(), 0);
    }

    /// Checks the lookups expected of a tree holding exactly `WORDS`.
    #[cfg(feature = "serde")]
    fn check_sample_tree(tree: &BKTree<&str>) {
        for &word in WORDS {
            assert_eq_sorted(tree.find(word, 0), &[(0, word)]);
        }
        assert_eq_sorted(
            tree.find("book", 1),
            &[
                (0, "book"),
                (1, "books"),
                (1, "boo"),
                (1, "boon"),
                (1, "cook"),
            ],
        );
        assert_eq!(None, tree.find_exact("This &str hasn't been added"));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialization() {
        let tree = sample_tree();
        check_sample_tree(&tree);

        let encoded_tree: Vec<u8> = bincode::serialize(&tree).unwrap();
        let decoded_tree: BKTree<&str> = bincode::deserialize(&encoded_tree[..]).unwrap();
        check_sample_tree(&decoded_tree);
    }
}