bk_tree/
lib.rs

1#[cfg(feature = "serde")]
2extern crate serde;
3pub mod metrics;
4
5use std::{
6    borrow::Borrow,
7    collections::VecDeque,
8    fmt::{Debug, Formatter, Result as FmtResult},
9    iter::Extend,
10};
11
12#[cfg(feature = "enable-fnv")]
13extern crate fnv;
14#[cfg(feature = "enable-fnv")]
15use fnv::FnvHashMap;
16
17#[cfg(not(feature = "enable-fnv"))]
18use std::collections::HashMap;
19
/// A trait for a *metric* (distance function).
///
/// Implementations should follow the metric axioms:
///
/// * **Zero**: `distance(a, b) == 0` if and only if `a == b`
/// * **Symmetry**: `distance(a, b) == distance(b, a)`
/// * **Triangle inequality**: `distance(a, c) <= distance(a, b) + distance(b, c)`
///
/// If any of these rules are broken, then the BK-tree may give unexpected
/// results.
pub trait Metric<K: ?Sized> {
    /// Returns the distance between `a` and `b`.
    fn distance(&self, a: &K, b: &K) -> u32;

    /// Returns `Some(distance(a, b))` when that distance is at most
    /// `threshold`, and `None` otherwise.
    ///
    /// The provided implementation computes the full distance and compares it
    /// against `threshold`. Metrics that can stop early once the threshold is
    /// exceeded may override it, but an override must still return `Some(d)`
    /// whenever the true distance `d` satisfies `d <= threshold`: the tree's
    /// search treats `None` as "out of range" and skips the node.
    fn threshold_distance(&self, a: &K, b: &K, threshold: u32) -> Option<u32> {
        let dist = self.distance(a, b);
        if dist <= threshold {
            Some(dist)
        } else {
            None
        }
    }
}
34
/// Map from edge distance to the child node stored at that distance
/// (FNV-hashed variant, selected by the `enable-fnv` feature).
#[cfg(feature = "enable-fnv")]
type BKChildren<K> = FnvHashMap<u32, BKNode<K>>;
/// Map from edge distance to the child node stored at that distance
/// (standard-library variant, used when `enable-fnv` is disabled).
#[cfg(not(feature = "enable-fnv"))]
type BKChildren<K> = HashMap<u32, BKNode<K>>;

/// A single node of the [BK-tree](https://en.wikipedia.org/wiki/BK-tree).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct BKNode<K> {
    /// The key stored at this node.
    key: K,
    /// Child nodes, indexed by their distance from this node's key according
    /// to the metric used by the tree.
    children: BKChildren<K>,
    /// The largest distance under which a child has been added, or `None`
    /// while the node has no children.
    max_child_distance: Option<u32>,
}
48
49impl<K> BKNode<K> {
50    /// Constructs a new `BKNode<K>`.
51    pub fn new(key: K) -> BKNode<K> {
52        BKNode {
53            key,
54            #[cfg(feature = "enable-fnv")]
55            children: fnv::FnvHashMap::default(),
56            #[cfg(not(feature = "enable-fnv"))]
57            children: HashMap::default(),
58            max_child_distance: None,
59        }
60    }
61
62    /// Add a child to the node.
63    ///
64    /// Given the distance from this node's key, add the given key as a child
65    /// node. *Warning:* this does not test the invariant that the distance as
66    /// measured by the tree between this node's key and the provided key
67    /// actually matches the distance passed in.
68    ///
69    /// # Examples
70    ///
71    /// ```ignore
72    /// use bk_tree::BKNode;
73    ///
74    /// let mut foo = BKNode::new("foo");
75    /// foo.add_child(1, "fop");
76    /// ```
77    pub fn add_child(&mut self, distance: u32, key: K) {
78        self.children.insert(distance, BKNode::new(key));
79        self.max_child_distance = self.max_child_distance.max(Some(distance));
80    }
81}
82
83impl<K> Debug for BKNode<K>
84where
85    K: Debug,
86{
87    fn fmt(&self, f: &mut Formatter) -> FmtResult {
88        f.debug_map().entry(&self.key, &self.children).finish()
89    }
90}
91
92/// A representation of a [BK-tree](https://en.wikipedia.org/wiki/BK-tree).
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct BKTree<K, M = metrics::Levenshtein> {
95    /// The root node. May be empty if nothing has been put in the tree yet.
96    root: Option<BKNode<K>>,
97    /// The metric being used to determine the distance between nodes on the
98    /// tree.
99    metric: M,
100}
101
102impl<K, M> BKTree<K, M>
103where
104    M: Metric<K>,
105{
106    /// Constructs a new `BKTree<K>` using the provided metric.
107    ///
108    /// Note that we make no assumptions about the metric function provided.
109    /// *Ideally* it is actually a
110    /// [valid metric](https://en.wikipedia.org/wiki/Metric_(mathematics)),
111    /// but you may choose to use one that is not technically a valid metric.
112    /// If you do not use a valid metric, however, you may find that the tree
113    /// behaves confusingly for some values.
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use bk_tree::{BKTree, metrics};
119    ///
120    /// let tree: BKTree<&str> = BKTree::new(metrics::Levenshtein);
121    /// ```
122    pub fn new(metric: M) -> BKTree<K, M> {
123        BKTree { root: None, metric }
124    }
125
126    /// Adds a key to the tree.
127    ///
128    /// If the tree is empty, this simply sets the root to
129    /// `Some(BKNode::new(key))`. Otherwise, we iterate downwards through the
130    /// tree until we see a node that does not have a child with the same
131    /// distance. If we encounter a node that is exactly the same distance from
132    /// the root node, then the new key is the same as that node's key and so we
133    /// do nothing. **Note**: This means that if your metric allows for unequal
134    /// keys to return 0, you will see improper behavior!
135    ///
136    /// # Examples
137    ///
138    /// ```
139    /// use bk_tree::{BKTree, metrics};
140    ///
141    /// let mut tree: BKTree<&str> = BKTree::new(metrics::Levenshtein);
142    ///
143    /// tree.add("foo");
144    /// tree.add("bar");
145    /// ```
146    pub fn add(&mut self, key: K) {
147        match self.root {
148            Some(ref mut root) => {
149                let mut cur_node = root;
150                let mut cur_dist = self.metric.distance(&cur_node.key, &key);
151                while cur_node.children.contains_key(&cur_dist) && cur_dist > 0 {
152                    // We have to do some moving around here to safely get the
153                    // child corresponding to the current distance away without
154                    // accidentally trying to mutate the wrong thing.
155                    let current = cur_node;
156                    let next_node = current.children.get_mut(&cur_dist).unwrap();
157
158                    cur_node = next_node;
159                    cur_dist = self.metric.distance(&cur_node.key, &key);
160                }
161                // If cur_dist == 0, we have landed on a node with the same key.
162                if cur_dist > 0 {
163                    cur_node.add_child(cur_dist, key);
164                }
165            }
166            None => {
167                self.root = Some(BKNode::new(key));
168            }
169        }
170    }
171
172    /// Searches for a key in the BK-tree given a certain tolerance.
173    ///
174    /// This traverses the tree searching for all keys with distance within
175    /// `tolerance` of of the key provided. The tolerance may be zero, in which
176    /// case this searches for exact matches. The results are returned as an
177    /// iterator of `(distance, key)` pairs.
178    ///
179    /// *Note:* There is no guarantee on the order of elements yielded by the
180    /// iterator. The elements returned may or may not be sorted in terms of
181    /// distance from the provided key.
182    ///
183    /// # Examples
184    /// ```
185    /// use bk_tree::{BKTree, metrics};
186    ///
187    /// let mut tree: BKTree<&str> = BKTree::new(metrics::Levenshtein);
188    ///
189    /// tree.add("foo");
190    /// tree.add("fop");
191    /// tree.add("bar");
192    ///
193    /// assert_eq!(tree.find("foo", 0).collect::<Vec<_>>(), vec![(0, &"foo")]);
194    /// assert_eq!(tree.find("foo", 1).collect::<Vec<_>>(), vec![(0, &"foo"), (1, &"fop")]);
195    /// assert!(tree.find("foz", 0).next().is_none());
196    /// ```
197    pub fn find<'a, 'q, Q: ?Sized>(&'a self, key: &'q Q, tolerance: u32) -> Find<'a, 'q, K, Q, M>
198    where
199        K: Borrow<Q>,
200        M: Metric<Q>,
201    {
202        let candidates = if let Some(root) = &self.root {
203            VecDeque::from(vec![root])
204        } else {
205            VecDeque::new()
206        };
207        Find {
208            candidates,
209            tolerance,
210            metric: &self.metric,
211            key,
212        }
213    }
214
215    /// Searches for an exact match in the tree.
216    ///
217    /// This is equivalent to calling `find` with a tolerance of 0, then picking
218    /// out the first result.
219    ///
220    /// # Examples
221    /// ```
222    /// use bk_tree::{BKTree, metrics};
223    ///
224    /// let mut tree: BKTree<&str> = BKTree::new(metrics::Levenshtein);
225    ///
226    /// tree.add("foo");
227    /// tree.add("fop");
228    /// tree.add("bar");
229    ///
230    /// assert_eq!(tree.find_exact("foz"), None);
231    /// assert_eq!(tree.find_exact("foo"), Some(&"foo"));
232    /// ```
233    pub fn find_exact<Q: ?Sized>(&self, key: &Q) -> Option<&K>
234    where
235        K: Borrow<Q>,
236        M: Metric<Q>,
237    {
238        self.find(key, 0).next().map(|(_, found_key)| found_key)
239    }
240}
241
242impl<K, M: Metric<K>> Extend<K> for BKTree<K, M> {
243    /// Adds multiple keys to the tree.
244    ///
245    /// Given an iterator with items of type `K`, this method simply adds every
246    /// item to the tree.
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use bk_tree::{BKTree, metrics};
252    ///
253    /// let mut tree: BKTree<&str> = BKTree::new(metrics::Levenshtein);
254    ///
255    /// tree.extend(vec!["foo", "bar"]);
256    /// ```
257    fn extend<I: IntoIterator<Item = K>>(&mut self, keys: I) {
258        for key in keys {
259            self.add(key);
260        }
261    }
262}
263
264impl<K: AsRef<str>> Default for BKTree<K> {
265    fn default() -> BKTree<K> {
266        BKTree::new(metrics::Levenshtein)
267    }
268}
269
270/// Iterator for the results of `BKTree::find`.
271pub struct Find<'a, 'q, K: 'a, Q: 'q + ?Sized, M: 'a> {
272    /// Iterator stack. Because of the inversion of control in play, we must
273    /// implement the traversal using an explicit stack.
274    candidates: VecDeque<&'a BKNode<K>>,
275    tolerance: u32,
276    metric: &'a M,
277    key: &'q Q,
278}
279
280impl<'a, 'q, K, Q: ?Sized, M> Iterator for Find<'a, 'q, K, Q, M>
281where
282    K: Borrow<Q>,
283    M: Metric<Q>,
284{
285    type Item = (u32, &'a K);
286
287    fn next(&mut self) -> Option<(u32, &'a K)> {
288        while let Some(current) = self.candidates.pop_front() {
289            let BKNode {
290                key,
291                children,
292                max_child_distance,
293            } = current;
294            let distance_cutoff = max_child_distance.unwrap_or(0) + self.tolerance;
295            let cur_dist = self.metric.threshold_distance(
296                self.key,
297                current.key.borrow() as &Q,
298                distance_cutoff,
299            );
300            if let Some(dist) = cur_dist {
301                // Find the first child node within an appropriate distance
302                let min_dist = dist.saturating_sub(self.tolerance);
303                let max_dist = dist.saturating_add(self.tolerance);
304                for (dist, child_node) in &mut children.iter() {
305                    if min_dist <= *dist && *dist <= max_dist {
306                        self.candidates.push_back(child_node);
307                    }
308                }
309                // If this node is also close enough to the key, yield it
310                if dist <= self.tolerance {
311                    return Some((dist, &key));
312                }
313            }
314        }
315        None
316    }
317}
318
#[cfg(test)]
mod tests {
    extern crate bincode;

    use std::fmt::Debug;
    use super::{BKNode, BKTree};

    /// The words of the example tree from
    /// https://nullwords.wordpress.com/2013/03/13/the-bk-tree-a-data-structure-for-spell-checking/
    /// in the order in which the tests insert them.
    const WORDS: [&str; 8] = [
        "book", "books", "cake", "boo", "cape", "boon", "cook", "cart",
    ];

    /// Builds the example tree by adding every entry of `WORDS`, in order.
    fn sample_tree() -> BKTree<&'static str> {
        let mut tree: BKTree<&str> = Default::default();
        for word in WORDS.iter() {
            tree.add(*word);
        }
        tree
    }

    /// Asserts that `left` yields exactly the pairs in `right`, ignoring the
    /// order in which they are produced.
    fn assert_eq_sorted<'t, T: 't, I>(left: I, right: &[(u32, T)])
    where
        T: Ord + Debug,
        I: Iterator<Item = (u32, &'t T)>,
    {
        let mut actual: Vec<_> = left.collect();
        actual.sort();

        let mut expected: Vec<_> = right.iter().map(|pair| (pair.0, &pair.1)).collect();
        expected.sort();

        assert_eq!(actual, expected);
    }

    /// Runs the exact, fuzzy and false-positive queries that the serialization
    /// test performs both before and after the round trip.
    #[cfg(feature = "serde")]
    fn assert_sample_queries(tree: &BKTree<&str>) {
        // Test exact search (zero tolerance)
        for word in WORDS.iter() {
            assert_eq_sorted(tree.find(*word, 0), &[(0, *word)]);
        }

        // Test fuzzy search
        assert_eq_sorted(
            tree.find("book", 1),
            &[
                (0, "book"),
                (1, "books"),
                (1, "boo"),
                (1, "boon"),
                (1, "cook"),
            ],
        );

        // Test for false positives
        assert_eq!(None, tree.find_exact("This &str hasn't been added"));
    }

    #[test]
    fn node_construct() {
        let node: BKNode<&str> = BKNode::new("foo");
        assert_eq!(node.key, "foo");
        assert!(node.children.is_empty());
    }

    #[test]
    fn tree_construct() {
        let tree: BKTree<&str> = Default::default();
        assert!(tree.root.is_none());
    }

    #[test]
    fn tree_add() {
        let mut tree: BKTree<&str> = Default::default();
        tree.add("foo");
        assert_eq!(tree.root.as_ref().map(|root| root.key), Some("foo"));

        tree.add("fop");
        tree.add("f\u{e9}\u{e9}");
        let root = tree.root.as_ref().unwrap();
        assert_eq!(root.children.get(&1).unwrap().key, "fop");
        assert_eq!(root.children.get(&4).unwrap().key, "f\u{e9}\u{e9}");
    }

    #[test]
    fn tree_extend() {
        let mut tree: BKTree<&str> = Default::default();
        tree.extend(vec!["foo", "fop"]);

        let root = tree.root.as_ref().unwrap();
        assert_eq!(root.key, "foo");
        assert_eq!(root.children.get(&1).unwrap().key, "fop");
    }

    #[test]
    fn tree_find() {
        let tree = sample_tree();
        assert_eq_sorted(tree.find("caqe", 1), &[(1, "cake"), (1, "cape")]);
        assert_eq_sorted(tree.find("cape", 1), &[(1, "cake"), (0, "cape")]);
        assert_eq_sorted(
            tree.find("book", 1),
            &[
                (0, "book"),
                (1, "books"),
                (1, "boo"),
                (1, "boon"),
                (1, "cook"),
            ],
        );
        assert_eq_sorted(tree.find("book", 0), &[(0, "book")]);
        assert!(tree.find("foobar", 1).next().is_none());
    }

    #[test]
    fn tree_find_exact() {
        let tree = sample_tree();
        assert_eq!(tree.find_exact("caqe"), None);
        assert_eq!(tree.find_exact("cape"), Some(&"cape"));
        assert_eq!(tree.find_exact("book"), Some(&"book"));
    }

    #[test]
    fn one_node_tree() {
        let mut tree: BKTree<&str> = Default::default();
        tree.add("book");
        tree.add("book");
        // Adding the same key twice must not create a child.
        assert_eq!(tree.root.as_ref().unwrap().children.len(), 0);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialization() {
        let tree = sample_tree();
        assert_sample_queries(&tree);

        let encoded_tree: Vec<u8> = bincode::serialize(&tree).unwrap();
        let decoded_tree: BKTree<&str> = bincode::deserialize(&encoded_tree[..]).unwrap();

        // The decoded tree must answer every query like the original.
        assert_sample_queries(&decoded_tree);
    }
}