// flx/search.rs

// Copyright 2015 Jerome Rasky <jerome@rasky.co>
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT
// or http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7use std::collections::{HashMap, BinaryHeap};
8use std::cmp::Ordering;
9use std::iter::FromIterator;
10
11use unicode_normalization::UnicodeNormalization;
12
13use constants::*;
14
15/// Contains the searchable database
16#[derive(Debug)]
17pub struct SearchBase {
18    lines: Vec<LineInfo>,
19}
20
/// Parsed information about a line, ready to be searched by a SearchBase.
#[derive(Debug)]
pub struct LineInfo {
    /// The original text of the line, returned verbatim by queries.
    line: String,
    /// For each character of the NFKC-normalized line, the ascending indices
    /// at which it occurs; uppercase characters are also listed under their
    /// lowercase forms.
    char_map: HashMap<char, Vec<usize>>,
    /// Heat value for each indexed character position, used when scoring.
    heat_map: Vec<f32>,
    /// Caller-supplied tie-breaker; larger values favour this line.
    factor: f32,
}
29
/// Character classes used by `LineInfo::new` to detect boundaries while
/// building a line's heat map.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CharClass {
    /// Any whitespace character (`char::is_whitespace`).
    Whitespace,
    /// Numeric characters (`char::is_numeric`).
    Numeric,
    /// Alphabetic characters (`char::is_alphabetic`).
    Alphabetic,
    /// Initial state, before the first character of a line has been seen.
    First,
    /// Everything else (punctuation, symbols, ...).
    Other,
}
38
/// A scored candidate kept in the result heap of `SearchBase::query`.
///
/// Matches are ordered by `score`; when the scores are equal or cannot be
/// compared (NaN), `factor` decides, and an incomparable `factor` is treated
/// as equal.
#[derive(Debug)]
struct LineMatch<'a> {
    score: f32,
    factor: f32,
    line: &'a str,
}

impl<'a> Ord for LineMatch<'a> {
    fn cmp(&self, other: &Self) -> Ordering {
        // A decisive score comparison wins outright.
        if let Some(order) = self.score.partial_cmp(&other.score) {
            if order != Ordering::Equal {
                return order;
            }
        }

        // Tie (or NaN score): fall back to the factor.
        self.factor.partial_cmp(&other.factor).unwrap_or(Ordering::Equal)
    }
}

impl<'a> PartialOrd for LineMatch<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // The ordering above is total, so this always succeeds.
        Some(Ord::cmp(self, other))
    }
}

impl<'a> PartialEq for LineMatch<'a> {
    fn eq(&self, other: &Self) -> bool {
        // Equality agrees with `Ord`, as the heap requires.
        Ord::cmp(self, other) == Ordering::Equal
    }
}

impl<'a> Eq for LineMatch<'a> {}
72
73/// Creates a LineInfo object with a factor of zero
74impl<T: Into<String>> From<T> for LineInfo {
75    fn from(item: T) -> LineInfo {
76        LineInfo::new(item, 0.0)
77    }
78}
79
80impl<V: Into<LineInfo>> FromIterator<V> for SearchBase {
81    fn from_iter<T: IntoIterator<Item = V>>(iterator: T) -> SearchBase {
82        SearchBase::new(iterator.into_iter().map(|item| item.into()).collect())
83    }
84}
85
86impl SearchBase {
87    /// Construct a new SearchBase from a Vec of LineInfos.
88    pub fn new(lines: Vec<LineInfo>) -> SearchBase {
89        SearchBase { lines: lines }
90    }
91
92    /// Perform a query of the SearchBase.
93    ///
94    /// number limits the number of matches returned.
95    ///
96    /// Matches any supersequence of the given query, with heuristics to order
97    /// matches based on how close they are to the given query.
98    pub fn query<'a, T: AsRef<str>>(&'a self, query: T, number: usize) -> Vec<&'a str> {
99        let query = query.as_ref();
100        if query.is_empty() {
101            // non-matching query
102            return vec![];
103        }
104
105        let mut matches: BinaryHeap<LineMatch> = BinaryHeap::with_capacity(number);
106
107        let composed: Vec<char> = query.nfkc().filter(|ch| !ch.is_whitespace()).collect();
108
109        for item in self.lines.iter() {
110            let score = match item.score(&composed) {
111                None => {
112                    // non-matching line
113                    continue;
114                }
115                Some(score) => score,
116            };
117
118            let match_item = LineMatch {
119                score: -score,
120                factor: -item.factor,
121                line: &item.line,
122            };
123
124            if matches.len() < number {
125                matches.push(match_item);
126            } else if let Some(mut other_item) = matches.peek_mut() {
127                if &match_item < &*other_item {
128                    // replace the "greatest" item with ours
129                    *other_item = match_item;
130                }
131            } else {
132                unreachable!("No item to peek at, but number of items greater than zero");
133            }
134        }
135
136        matches.into_sorted_vec().into_iter().map(|x| x.line).collect()
137    }
138}
139
140impl LineInfo {
141    /// Constructs a new LineInfo objects from the given item.
142    ///
143    /// Factor is a "tie-breaker," or something to weight the matches in a way
144    /// beyond the matching already done in flx. The greater the factor, the
145    /// more greatly matching favors the item.
146    pub fn new<T: Into<String>>(item: T, factor: f32) -> LineInfo {
147        let mut map: HashMap<char, Vec<usize>> = HashMap::new();
148        let mut heat = vec![];
149        let line = item.into();
150
151        let mut ws_score: f32 = 0.0;
152        let mut cs_score: f32 = 0.0;
153        let mut cur_class = CharClass::First;
154        let mut cs_change = false;
155
156        for (idx, c) in line.nfkc().enumerate() {
157            if idx > MAX_LEN {
158                break;
159            }
160
161            if !c.is_whitespace() {
162                if cur_class == CharClass::First {
163                    cs_score += FIRST_FACTOR;
164                }
165            }
166
167            if c.is_whitespace() {
168                cur_class = CharClass::Whitespace;
169                ws_score = WHITESPACE_FACTOR;
170            } else if c.is_numeric() {
171                if cur_class != CharClass::Numeric {
172                    cur_class = CharClass::Numeric;
173                    if !cs_change {
174                        cs_score += CLASS_FACTOR;
175                        cs_change = true;
176                    }
177                } else {
178                    cs_change = false;
179                }
180            } else if c.is_alphabetic() {
181                if cur_class != CharClass::Alphabetic {
182                    cur_class = CharClass::Alphabetic;
183                    if !cs_change {
184                        cs_score += CLASS_FACTOR;
185                        cs_change = true;
186                    }
187                } else {
188                    cs_change = false;
189                }
190            } else {
191                if cur_class != CharClass::Other {
192                    cur_class = CharClass::Other;
193                    if !cs_change {
194                        cs_score += CLASS_FACTOR;
195                        cs_change = true;
196                    }
197                } else {
198                    cs_change = false;
199                }
200            }
201
202            if cur_class != CharClass::Whitespace {
203                map.entry(c).or_insert(Vec::default()).push(idx);
204                if c.is_uppercase() {
205                    for lc in c.to_lowercase() {
206                        map.entry(lc).or_insert(Vec::default()).push(idx);
207                    }
208                }
209            }
210
211            heat.push(ws_score + cs_score);
212
213            ws_score *= WHITESPACE_REDUCE;
214            if !cs_change {
215                cs_score *= CLASS_REDUCE;
216            }
217        }
218
219        LineInfo {
220            line: line,
221            char_map: map,
222            heat_map: heat,
223            factor: factor,
224        }
225    }
226
227    /// Sets the factor for the line info
228    ///
229    /// Changes the factor after the creation of the line
230    pub fn set_factor(&mut self, factor: f32) {
231        self.factor = factor;
232    }
233
234    /// Gets the factor for the line info
235    ///
236    /// Produces the factor for the line info
237    pub fn get_factor(&self) -> f32 {
238        self.factor
239    }
240
241    fn score_position(&self, position: &[usize]) -> f32 {
242        let avg_dist: f32;
243
244        if position.len() < 2 {
245            avg_dist = 0.0;
246        } else {
247            avg_dist = position.windows(2)
248                               .map(|pair| pair[1] as f32 - pair[0] as f32)
249                               .sum::<f32>() / position.len() as f32;
250        }
251
252        let heat_sum: f32 = position.iter()
253                                    .map(|idx| self.heat_map[*idx])
254                                    .sum();
255
256        avg_dist * DIST_WEIGHT + heat_sum * HEAT_WEIGHT + self.factor * FACTOR_REDUCE
257    }
258
259    fn score<'a>(&self, query: &'a [char]) -> Option<f32> {
260        let mut position = vec![0; query.len()];
261
262        let mut lists: Vec<&[usize]> = Vec::with_capacity(query.len());
263
264        if query.iter().any(|ch| {
265            if let Some(list) = self.char_map.get(ch) {
266                // Use a side effect here to save time
267                lists.push(list);
268                false
269            } else {
270                true
271            }
272        }) {
273            return None;
274        }
275
276        self.score_inner(query, &mut position, 0, &lists)
277    }
278
279    fn score_inner<'a>(&self, query: &'a [char], position: &mut [usize], idx: usize, lists: &[&[usize]]) -> Option<f32> {
280        if idx == query.len() {
281            Some(self.score_position(position))
282        } else {
283            let mut best = None;
284
285            for sub_position in lists[idx].iter() {
286                if idx > 0 && *sub_position <= position[idx - 1] {
287                    // not a valid position
288                    continue;
289                }
290
291                position[idx] = *sub_position;
292
293                if let Some(score) = self.score_inner(query, position, idx + 1, lists) {
294                    if score > best.unwrap_or(::std::f32::NEG_INFINITY) {
295                        best = Some(score);
296                    }
297                }
298            }
299
300            best
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use rand;
    use rand::Rng;
    use test;

    use super::*;

    /// Builds a search base from plain string slices (factor zero for all).
    fn base_of(lines: Vec<&str>) -> SearchBase {
        SearchBase::from_iter(lines)
    }

    #[test]
    fn test_matches() {
        // No line here contains "abc" as a subsequence.
        let base = base_of(vec!["test1", "test2", "test3"]);
        assert!(base.query("abc", 1).is_empty());
    }

    #[test]
    fn test_one_long() {
        let base = base_of(vec!["a", "b", "ab"]);
        let found = base.query("a", 1);
        assert!(found.contains(&"a"));
    }

    #[test]
    fn test_simple_matches() {
        let base = base_of(vec!["test", "hello", "hello2"]);
        let found = base.query("hello", 3);

        assert!(found.contains(&"hello"));
        assert!(found.contains(&"hello2"));
        assert!(!found.contains(&"test"));
    }

    #[test]
    fn test_truncate() {
        // "tt" is a closer match for "test" than for "toast", and only one
        // result is requested.
        let base = base_of(vec!["test", "toast"]);
        let found = base.query("tt", 1);

        assert_eq!(found.len(), 1);
        assert!(found.contains(&"test"));
    }

    #[test]
    fn test_order() {
        // Query characters must appear in order: "cb" is not a subsequence
        // of "abc".
        let base = base_of(vec!["abc", "def"]);
        assert_eq!(base.query("cb", 1).len(), 0);
    }

    #[bench]
    fn bench_search(b: &mut test::Bencher) {
        let mut rng = rand::thread_rng();

        let samples = vec!["touaoeuaoeeaoeuaoeuaoeusaoeuaoeuaoeuoeautaoeuaoeuaoeu",
                           "aoeuaoeuhaoeuaoeuaoeueaoeuaoeuaoeulaoeuaoeuaoeuloaeuoaeuoeauooea\
                            ua",
                           "aoeuaoeuahoeuaouaoeuoaeeuaoeuoaeuaoeulaoeuoaeuaoeulaoeuaoeuaoeuo\
                            aoeuoaeuaoeu2aoeuoae"];

        // 1000 lines drawn at random from the samples above.
        let test_set: Vec<&str> = (0..1000)
            .map(|_| samples[rng.gen::<usize>() % samples.len()])
            .collect();

        let base = base_of(test_set);

        b.iter(|| base.query("hello", 10));
    }
}