1use std::collections::{HashMap, BinaryHeap};
8use std::cmp::Ordering;
9use std::iter::FromIterator;
10
11use unicode_normalization::UnicodeNormalization;
12
13use constants::*;
14
/// A searchable collection of pre-analysed lines.
///
/// Built with `SearchBase::new` or collected from any iterator whose items
/// convert into `LineInfo`; searched with `SearchBase::query`.
#[derive(Debug)]
pub struct SearchBase {
    /// Every line that `query` considers, in insertion order.
    lines: Vec<LineInfo>,
}
20
/// One searchable line together with the data precomputed for scoring it.
#[derive(Debug)]
pub struct LineInfo {
    /// The line exactly as it was handed to `LineInfo::new`.
    line: String,
    /// For each non-whitespace character of the NFKC-normalised line, the
    /// ascending character indices at which it occurs. Uppercase characters
    /// are additionally indexed under their lowercase form.
    char_map: HashMap<char, Vec<usize>>,
    /// Heat value of every processed character, indexed like `char_map`'s
    /// positions; summed over the matched positions when scoring.
    heat_map: Vec<f32>,
    /// Caller-supplied weight that is mixed into every score of this line
    /// and breaks ties between equally scored lines.
    factor: f32,
}
29
/// Coarse character classification tracked while `LineInfo::new` walks a line.
#[derive(PartialEq, Eq)]
enum CharClass {
    /// The character satisfies `char::is_whitespace`.
    Whitespace,
    /// The character satisfies `char::is_numeric`.
    Numeric,
    /// The character satisfies `char::is_alphabetic`.
    Alphabetic,
    /// Initial state, before any character of the line has been classified.
    First,
    /// Any character that is neither whitespace, numeric nor alphabetic.
    Other,
}
38
/// A scored candidate line kept in the result heap while a query runs.
///
/// Ordering is by `score`, then by `factor`. `SearchBase::query` stores both
/// values negated, so the greatest `LineMatch` in its max-heap is the weakest
/// match currently retained.
#[derive(Debug)]
struct LineMatch<'a> {
    score: f32,
    factor: f32,
    line: &'a str,
}

impl<'a> LineMatch<'a> {
    /// Total ordering over `f32`: ordinary numeric order, with NaN placed
    /// after every number and all NaNs considered equal to each other.
    ///
    /// `partial_cmp` alone is not enough for `Ord`: treating "incomparable"
    /// as "equal" makes a NaN equal to two values that are themselves
    /// unequal, which breaks transitivity.
    fn cmp_f32(lhs: f32, rhs: f32) -> Ordering {
        match lhs.partial_cmp(&rhs) {
            Some(order) => order,
            // At least one side is NaN: `false < true`, so the NaN side is greater.
            None => lhs.is_nan().cmp(&rhs.is_nan()),
        }
    }
}

impl<'a> Ord for LineMatch<'a> {
    /// Compares by `score` first and falls back to `factor` on a tie.
    fn cmp(&self, other: &Self) -> Ordering {
        match LineMatch::cmp_f32(self.score, other.score) {
            Ordering::Equal => LineMatch::cmp_f32(self.factor, other.factor),
            order => order,
        }
    }
}

impl<'a> PartialOrd for LineMatch<'a> {
    /// Always `Some`, delegating to the total order defined by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> PartialEq for LineMatch<'a> {
    /// Two matches are equal when `cmp` ranks them the same; `line` is not
    /// part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<'a> Eq for LineMatch<'a> {}
72
73impl<T: Into<String>> From<T> for LineInfo {
75 fn from(item: T) -> LineInfo {
76 LineInfo::new(item, 0.0)
77 }
78}
79
80impl<V: Into<LineInfo>> FromIterator<V> for SearchBase {
81 fn from_iter<T: IntoIterator<Item = V>>(iterator: T) -> SearchBase {
82 SearchBase::new(iterator.into_iter().map(|item| item.into()).collect())
83 }
84}
85
86impl SearchBase {
87 pub fn new(lines: Vec<LineInfo>) -> SearchBase {
89 SearchBase { lines: lines }
90 }
91
92 pub fn query<'a, T: AsRef<str>>(&'a self, query: T, number: usize) -> Vec<&'a str> {
99 let query = query.as_ref();
100 if query.is_empty() {
101 return vec![];
103 }
104
105 let mut matches: BinaryHeap<LineMatch> = BinaryHeap::with_capacity(number);
106
107 let composed: Vec<char> = query.nfkc().filter(|ch| !ch.is_whitespace()).collect();
108
109 for item in self.lines.iter() {
110 let score = match item.score(&composed) {
111 None => {
112 continue;
114 }
115 Some(score) => score,
116 };
117
118 let match_item = LineMatch {
119 score: -score,
120 factor: -item.factor,
121 line: &item.line,
122 };
123
124 if matches.len() < number {
125 matches.push(match_item);
126 } else if let Some(mut other_item) = matches.peek_mut() {
127 if &match_item < &*other_item {
128 *other_item = match_item;
130 }
131 } else {
132 unreachable!("No item to peek at, but number of items greater than zero");
133 }
134 }
135
136 matches.into_sorted_vec().into_iter().map(|x| x.line).collect()
137 }
138}
139
140impl LineInfo {
141 pub fn new<T: Into<String>>(item: T, factor: f32) -> LineInfo {
147 let mut map: HashMap<char, Vec<usize>> = HashMap::new();
148 let mut heat = vec![];
149 let line = item.into();
150
151 let mut ws_score: f32 = 0.0;
152 let mut cs_score: f32 = 0.0;
153 let mut cur_class = CharClass::First;
154 let mut cs_change = false;
155
156 for (idx, c) in line.nfkc().enumerate() {
157 if idx > MAX_LEN {
158 break;
159 }
160
161 if !c.is_whitespace() {
162 if cur_class == CharClass::First {
163 cs_score += FIRST_FACTOR;
164 }
165 }
166
167 if c.is_whitespace() {
168 cur_class = CharClass::Whitespace;
169 ws_score = WHITESPACE_FACTOR;
170 } else if c.is_numeric() {
171 if cur_class != CharClass::Numeric {
172 cur_class = CharClass::Numeric;
173 if !cs_change {
174 cs_score += CLASS_FACTOR;
175 cs_change = true;
176 }
177 } else {
178 cs_change = false;
179 }
180 } else if c.is_alphabetic() {
181 if cur_class != CharClass::Alphabetic {
182 cur_class = CharClass::Alphabetic;
183 if !cs_change {
184 cs_score += CLASS_FACTOR;
185 cs_change = true;
186 }
187 } else {
188 cs_change = false;
189 }
190 } else {
191 if cur_class != CharClass::Other {
192 cur_class = CharClass::Other;
193 if !cs_change {
194 cs_score += CLASS_FACTOR;
195 cs_change = true;
196 }
197 } else {
198 cs_change = false;
199 }
200 }
201
202 if cur_class != CharClass::Whitespace {
203 map.entry(c).or_insert(Vec::default()).push(idx);
204 if c.is_uppercase() {
205 for lc in c.to_lowercase() {
206 map.entry(lc).or_insert(Vec::default()).push(idx);
207 }
208 }
209 }
210
211 heat.push(ws_score + cs_score);
212
213 ws_score *= WHITESPACE_REDUCE;
214 if !cs_change {
215 cs_score *= CLASS_REDUCE;
216 }
217 }
218
219 LineInfo {
220 line: line,
221 char_map: map,
222 heat_map: heat,
223 factor: factor,
224 }
225 }
226
227 pub fn set_factor(&mut self, factor: f32) {
231 self.factor = factor;
232 }
233
234 pub fn get_factor(&self) -> f32 {
238 self.factor
239 }
240
241 fn score_position(&self, position: &[usize]) -> f32 {
242 let avg_dist: f32;
243
244 if position.len() < 2 {
245 avg_dist = 0.0;
246 } else {
247 avg_dist = position.windows(2)
248 .map(|pair| pair[1] as f32 - pair[0] as f32)
249 .sum::<f32>() / position.len() as f32;
250 }
251
252 let heat_sum: f32 = position.iter()
253 .map(|idx| self.heat_map[*idx])
254 .sum();
255
256 avg_dist * DIST_WEIGHT + heat_sum * HEAT_WEIGHT + self.factor * FACTOR_REDUCE
257 }
258
259 fn score<'a>(&self, query: &'a [char]) -> Option<f32> {
260 let mut position = vec![0; query.len()];
261
262 let mut lists: Vec<&[usize]> = Vec::with_capacity(query.len());
263
264 if query.iter().any(|ch| {
265 if let Some(list) = self.char_map.get(ch) {
266 lists.push(list);
268 false
269 } else {
270 true
271 }
272 }) {
273 return None;
274 }
275
276 self.score_inner(query, &mut position, 0, &lists)
277 }
278
279 fn score_inner<'a>(&self, query: &'a [char], position: &mut [usize], idx: usize, lists: &[&[usize]]) -> Option<f32> {
280 if idx == query.len() {
281 Some(self.score_position(position))
282 } else {
283 let mut best = None;
284
285 for sub_position in lists[idx].iter() {
286 if idx > 0 && *sub_position <= position[idx - 1] {
287 continue;
289 }
290
291 position[idx] = *sub_position;
292
293 if let Some(score) = self.score_inner(query, position, idx + 1, lists) {
294 if score > best.unwrap_or(::std::f32::NEG_INFINITY) {
295 best = Some(score);
296 }
297 }
298 }
299
300 best
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use std::iter::FromIterator;

    use rand::Rng;

    use rand;
    use test;

    use super::*;

    /// A query whose characters occur in none of the lines yields nothing.
    #[test]
    fn test_matches() {
        let base = SearchBase::from_iter(vec!["test1", "test2", "test3"]);

        let hits = base.query("abc", 1);

        assert!(hits.is_empty());
    }

    /// With a limit of one, the single-character line is the one returned.
    #[test]
    fn test_one_long() {
        let base = SearchBase::from_iter(vec!["a", "b", "ab"]);

        let hits = base.query("a", 1);

        assert!(hits.contains(&"a"));
    }

    /// Lines containing the query are returned; unrelated lines are not.
    #[test]
    fn test_simple_matches() {
        let base = SearchBase::from_iter(vec!["test", "hello", "hello2"]);

        let hits = base.query("hello", 3);

        assert!(hits.contains(&"hello"));
        assert!(hits.contains(&"hello2"));
        assert!(!hits.contains(&"test"));
    }

    /// The result is cut down to the requested number of lines.
    #[test]
    fn test_truncate() {
        let base = SearchBase::from_iter(vec!["test", "toast"]);

        let hits = base.query("tt", 1);

        assert_eq!(hits.len(), 1);
        assert!(hits.contains(&"test"));
    }

    /// Query characters must appear in order: "cb" does not match "abc".
    #[test]
    fn test_order() {
        let base = SearchBase::from_iter(vec!["abc", "def"]);

        let hits = base.query("cb", 1);

        assert_eq!(hits.len(), 0);
    }

    /// Measures a query against 1000 lines drawn at random from three long
    /// sample strings.
    #[bench]
    fn bench_search(b: &mut test::Bencher) {
        let mut rng = rand::thread_rng();

        let corpus = vec!["touaoeuaoeeaoeuaoeuaoeusaoeuaoeuaoeuoeautaoeuaoeuaoeu",
                          "aoeuaoeuhaoeuaoeuaoeueaoeuaoeuaoeulaoeuaoeuaoeuloaeuoaeuoeauooea\
                           ua",
                          "aoeuaoeuahoeuaouaoeuoaeeuaoeuoaeuaoeulaoeuoaeuaoeulaoeuaoeuaoeuo\
                           aoeuoaeuaoeu2aoeuoae"];

        let test_set: Vec<&str> = (0..1000)
            .map(|_| corpus[rng.gen::<usize>() % corpus.len()])
            .collect();

        let base = SearchBase::from_iter(test_set);

        b.iter(|| base.query("hello", 10));
    }
}