1use crate::{
2 algorithms::{SequenceMatcher, Similarity, SimilarityMetric},
3 processors::{NullStringProcessor, StringProcessor},
4};
5use std::cmp::Reverse;
6use std::collections::BinaryHeap;
7
8pub fn get_top_n<'a>(
42 query: &str,
43 choices: &[&'a str],
44 cutoff: Option<f64>,
45 n: Option<usize>,
46 processor: Option<&dyn StringProcessor>,
47 scorer: Option<&dyn SimilarityMetric>,
48) -> Vec<&'a str> {
49 let mut matches = BinaryHeap::new();
50 let n = n.unwrap_or(3);
51 let cutoff = cutoff.unwrap_or(0.7);
52 let scorer = match scorer {
53 Some(scorer_trait) => scorer_trait,
54 None => &SequenceMatcher,
55 };
56 let processor = match processor {
57 Some(some_processor) => some_processor,
58 None => &NullStringProcessor,
59 };
60 let processed_query = processor.process(query);
61
62 for &choice in choices {
63 let processed_choice = processor.process(choice);
64 let raw_ratio = scorer.compute_metric(processed_query.as_str(), processed_choice.as_str());
65 let ratio = match raw_ratio {
66 Similarity::Usize(r) => r as f64,
67 Similarity::Float(r) => r,
68 };
69 if ratio >= cutoff {
70 let int_ratio = match raw_ratio {
71 Similarity::Usize(r) => r as i64,
72 Similarity::Float(r) => (r * std::u32::MAX as f64) as i64,
73 };
74 matches.push((int_ratio, Reverse(choice)));
77 }
78 }
79 let mut rv = vec![];
80 for _ in 0..n {
81 if let Some((_, elt)) = matches.pop() {
82 rv.push(elt.0);
83 } else {
84 break;
85 }
86 }
87 rv
88}
89
#[cfg(test)]
mod tests {
    // Table-driven checks for `get_top_n`. Each case varies the cutoff, the result
    // count, the string processor and the scorer, and lists the exact ordered output
    // expected for the fixed query "brazil" against a fixed list of candidates.
    use rstest::rstest;

    use super::get_top_n;
    use crate::algorithms::jaro::JaroWinkler;
    use crate::algorithms::SimilarityMetric;
    use crate::processors::{LowerAlphaNumStringProcessor, StringProcessor};

    // Case order: default scorer/processor with n = 3; a strict cutoff; the
    // Jaro-Winkler scorer; the lower-alphanumeric processor.
    #[rstest]
    #[case(Some(0.7), Some(3), None, None, &["brazil", "braziu", "trazil"])]
    #[case(Some(0.9), Some(5), None, None, &["brazil"])]
    #[case(Some(0.7), Some(2), None, Some(&JaroWinkler as &dyn SimilarityMetric), &["brazil", "braziu"])]
    #[case(Some(0.7), Some(2), Some(&LowerAlphaNumStringProcessor as &dyn StringProcessor), None, &["brazil", "BRA ZIL"])]
    fn test_get_top_n<'a>(
        #[case] cutoff: Option<f64>,
        #[case] n: Option<usize>,
        #[case] processor: Option<&dyn StringProcessor>,
        #[case] scorer: Option<&dyn SimilarityMetric>,
        #[case] expected: &[&'a str],
    ) {
        // Candidates mix an exact match, near misses, an unrelated word and a
        // differently-cased/spaced variant of the query.
        let candidates: [&str; 5] = ["trazil", "BRA ZIL", "brazil", "spain", "braziu"];
        let top = get_top_n("brazil", &candidates, cutoff, n, processor, scorer);
        assert_eq!(top, expected);
    }
}