1use std::cell::RefCell;
8
9use nucleo_matcher::{Config, Matcher};
10
thread_local! {
    /// Per-thread fuzzy matcher shared by every call to `rank` on the same thread, so a
    /// `Matcher` is not constructed per search. Wrapped in a `RefCell` because scoring needs
    /// `&mut Matcher`. Uses nucleo's `Config::DEFAULT` settings.
    static MATCHER: RefCell<Matcher> = RefCell::new(Matcher::new(Config::DEFAULT));
}
14
15pub fn rank<T>(query: &str, candidates: impl IntoIterator<Item = (T, Vec<String>)>) -> Vec<T> {
21 let candidates = candidates.into_iter();
22 if query.is_empty() {
23 return candidates.map(|(item, _)| item).collect();
24 }
25 use nucleo_matcher::Utf32Str;
26 use nucleo_matcher::pattern::{CaseMatching, Normalization, Pattern};
27
28 let pattern = Pattern::parse(query, CaseMatching::Smart, Normalization::Smart);
29 MATCHER.with(|m| {
30 let mut matcher = m.borrow_mut();
31 let mut buf: Vec<char> = Vec::new();
32 let mut scored: Vec<(T, u32)> = Vec::new();
33 for (item, haystacks) in candidates {
34 let mut best = 0u32;
35 for h in &haystacks {
36 buf.clear();
37 if let Some(score) = pattern.score(Utf32Str::new(h, &mut buf), &mut matcher) {
38 best = best.max(score);
39 }
40 }
41 if best > 0 {
42 scored.push((item, best));
43 }
44 }
45 scored.sort_by_key(|(_, score)| std::cmp::Reverse(*score));
47 scored.into_iter().map(|(item, _)| item).collect()
48 })
49}
50
51pub fn rank_host_indices(
57 hosts: &[crate::ssh_config::model::HostEntry],
58 candidates: &[usize],
59 query: &str,
60) -> Vec<usize> {
61 let scored: Vec<(usize, Vec<String>)> = candidates
62 .iter()
63 .filter_map(|&i| hosts.get(i).map(|h| (i, h)))
64 .map(|(i, h)| {
65 let mut haystacks = vec![h.alias.clone(), h.hostname.clone(), h.user.clone()];
66 if let Some(p) = &h.provider {
67 haystacks.push(p.clone());
68 }
69 haystacks.extend(h.tags.iter().cloned());
70 haystacks.extend(h.provider_tags.iter().cloned());
71 (i, haystacks)
72 })
73 .collect();
74 rank(query, scored)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed `(index, fields)` pairs into the owned shape `rank` consumes.
    fn cand(items: &[(usize, &[&str])]) -> Vec<(usize, Vec<String>)> {
        let mut out = Vec::with_capacity(items.len());
        for &(idx, fields) in items {
            let owned: Vec<String> = fields.iter().map(|f| (*f).to_owned()).collect();
            out.push((idx, owned));
        }
        out
    }

    #[test]
    fn empty_query_returns_all_in_order() {
        let ranked = rank("", cand(&[(0, &["alpha"]), (1, &["beta"])]));
        assert_eq!(ranked, vec![0, 1]);
    }

    #[test]
    fn filters_out_non_matches() {
        let ranked = rank(
            "aws",
            cand(&[
                (0, &["aws-api-eu"]),
                (1, &["db-primary"]),
                (2, &["aws-worker"]),
            ]),
        );
        for expected in [0usize, 2] {
            assert!(ranked.contains(&expected));
        }
        assert!(!ranked.contains(&1), "db-primary must not match 'aws'");
    }

    #[test]
    fn subsequence_matches_fuzzily() {
        // "awseu" is a scattered subsequence of "aws-api-eu" only.
        let ranked = rank("awseu", cand(&[(0, &["aws-api-eu"]), (1, &["db-primary"])]));
        assert_eq!(ranked, vec![0]);
    }

    #[test]
    fn matches_any_haystack_field() {
        // The query only hits the second field of candidate 0.
        let ranked = rank(
            "140.82",
            cand(&[(0, &["bastion", "140.82.121.3"]), (1, &["db", "10.0.0.1"])]),
        );
        assert_eq!(ranked, vec![0]);
    }

    #[test]
    fn equal_scores_keep_input_order() {
        // Identical haystacks score identically, so the stable sort must not reorder them.
        let ranked = rank("aws", cand(&[(0, &["aws"]), (1, &["aws"])]));
        assert_eq!(ranked, vec![0, 1]);
    }

    #[test]
    fn rank_host_indices_searches_all_fields_and_respects_candidates() {
        use crate::ssh_config::model::HostEntry;

        let web = HostEntry {
            alias: "web1".into(),
            hostname: "10.0.0.1".into(),
            provider: Some("aws".into()),
            ..Default::default()
        };
        let db = HostEntry {
            alias: "db".into(),
            hostname: "10.0.0.2".into(),
            tags: vec!["prod".into()],
            ..Default::default()
        };
        let cache = HostEntry {
            alias: "cache".into(),
            hostname: "10.0.0.3".into(),
            ..Default::default()
        };
        let hosts = vec![web, db, cache];
        let all = [0usize, 1, 2];

        let check = |subset: &[usize], query: &str, expected: Vec<usize>| {
            assert_eq!(rank_host_indices(&hosts, subset, query), expected);
        };
        check(&all, "", vec![0, 1, 2]);
        // Matches via the provider field only.
        check(&all, "aws", vec![0]);
        // Matches via the tags field only.
        check(&all, "prod", vec![1]);
        // Only the requested subset is returned.
        check(&[1, 2], "", vec![1, 2]);
    }
}
163}