Skip to main content

purple_ssh/
fuzzy.rs

1//! Small reusable fuzzy ranking using the same nucleo configuration as the jump
2//! palette (smart case and normalization). Lets pickers offer type-to-filter
3//! without duplicating the full jump scoring pipeline. The matcher is cached in
4//! a thread-local so its scratch buffers are reused across keystrokes and frames
5//! instead of being reallocated on every call.
6
7use std::cell::RefCell;
8
9use nucleo_matcher::{Config, Matcher};
10
11thread_local! {
12    static MATCHER: RefCell<Matcher> = RefCell::new(Matcher::new(Config::DEFAULT));
13}
14
15/// Fuzzy-rank `candidates` against `query`, keeping only matches, best score
16/// first. Each candidate carries its searchable haystacks (e.g. alias,
17/// hostname, provider, tags); the best-scoring haystack decides the rank.
18/// Ties keep input order. An empty query returns every candidate in input
19/// order (no filtering).
20pub fn rank<T>(query: &str, candidates: impl IntoIterator<Item = (T, Vec<String>)>) -> Vec<T> {
21    let candidates = candidates.into_iter();
22    if query.is_empty() {
23        return candidates.map(|(item, _)| item).collect();
24    }
25    use nucleo_matcher::Utf32Str;
26    use nucleo_matcher::pattern::{CaseMatching, Normalization, Pattern};
27
28    let pattern = Pattern::parse(query, CaseMatching::Smart, Normalization::Smart);
29    MATCHER.with(|m| {
30        let mut matcher = m.borrow_mut();
31        let mut buf: Vec<char> = Vec::new();
32        let mut scored: Vec<(T, u32)> = Vec::new();
33        for (item, haystacks) in candidates {
34            let mut best = 0u32;
35            for h in &haystacks {
36                buf.clear();
37                if let Some(score) = pattern.score(Utf32Str::new(h, &mut buf), &mut matcher) {
38                    best = best.max(score);
39                }
40            }
41            if best > 0 {
42                scored.push((item, best));
43            }
44        }
45        // Stable sort by score descending; ties keep input (config) order.
46        scored.sort_by_key(|(_, score)| std::cmp::Reverse(*score));
47        scored.into_iter().map(|(item, _)| item).collect()
48    })
49}
50
51/// Fuzzy-rank the candidate host indices by the standard host search fields
52/// (alias, hostname, user, provider, and provider/user tags), best match first.
53/// An empty query keeps the candidate order. Shared by every host picker
54/// (tunnel, container, snippet) so type-to-filter behaves identically across the
55/// app and searches the same fields everywhere.
56pub fn rank_host_indices(
57    hosts: &[crate::ssh_config::model::HostEntry],
58    candidates: &[usize],
59    query: &str,
60) -> Vec<usize> {
61    let scored: Vec<(usize, Vec<String>)> = candidates
62        .iter()
63        .filter_map(|&i| hosts.get(i).map(|h| (i, h)))
64        .map(|(i, h)| {
65            let mut haystacks = vec![h.alias.clone(), h.hostname.clone(), h.user.clone()];
66            if let Some(p) = &h.provider {
67                haystacks.push(p.clone());
68            }
69            haystacks.extend(h.tags.iter().cloned());
70            haystacks.extend(h.provider_tags.iter().cloned());
71            (i, haystacks)
72        })
73        .collect();
74    rank(query, scored)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns borrowed `(id, haystacks)` literals into the owned candidate list
    /// that `rank` consumes.
    fn cand(items: &[(usize, &[&str])]) -> Vec<(usize, Vec<String>)> {
        let mut owned = Vec::with_capacity(items.len());
        for &(id, haystacks) in items {
            let strings: Vec<String> = haystacks.iter().map(|&s| String::from(s)).collect();
            owned.push((id, strings));
        }
        owned
    }

    #[test]
    fn empty_query_returns_all_in_order() {
        let ranked = rank("", cand(&[(0, &["alpha"]), (1, &["beta"])]));
        assert_eq!(ranked, vec![0, 1]);
    }

    #[test]
    fn filters_out_non_matches() {
        let input = cand(&[
            (0, &["aws-api-eu"]),
            (1, &["db-primary"]),
            (2, &["aws-worker"]),
        ]);
        let ranked = rank("aws", input);
        // Both aws-* candidates survive, in whatever order the scores give.
        for expected in [0usize, 2] {
            assert!(ranked.contains(&expected));
        }
        assert!(!ranked.contains(&1), "db-primary must not match 'aws'");
    }

    #[test]
    fn subsequence_matches_fuzzily() {
        // The characters of 'awseu' appear in order inside 'aws-api-eu'.
        let input = cand(&[(0, &["aws-api-eu"]), (1, &["db-primary"])]);
        assert_eq!(rank("awseu", input), vec![0]);
    }

    #[test]
    fn matches_any_haystack_field() {
        // Only the second haystack (the hostname) of candidate 0 can match.
        let input = cand(&[(0, &["bastion", "140.82.121.3"]), (1, &["db", "10.0.0.1"])]);
        assert_eq!(rank("140.82", input), vec![0]);
    }

    #[test]
    fn equal_scores_keep_input_order() {
        // Identical haystacks give identical scores, so only the stability of
        // the sort decides the order.
        let input = cand(&[(0, &["aws"]), (1, &["aws"])]);
        let ranked = rank("aws", input);
        assert_eq!(ranked, vec![0, 1]);
    }

    #[test]
    fn rank_host_indices_searches_all_fields_and_respects_candidates() {
        use crate::ssh_config::model::HostEntry;

        let web = HostEntry {
            alias: "web1".into(),
            hostname: "10.0.0.1".into(),
            provider: Some("aws".into()),
            ..Default::default()
        };
        let db = HostEntry {
            alias: "db".into(),
            hostname: "10.0.0.2".into(),
            tags: vec!["prod".into()],
            ..Default::default()
        };
        let cache = HostEntry {
            alias: "cache".into(),
            hostname: "10.0.0.3".into(),
            ..Default::default()
        };
        let hosts = vec![web, db, cache];
        let every_index = [0usize, 1, 2];

        // An empty query leaves the candidate order untouched.
        let unfiltered = rank_host_indices(&hosts, &every_index, "");
        assert_eq!(unfiltered, vec![0, 1, 2]);

        // The provider field is searched, not only alias/hostname.
        let by_provider = rank_host_indices(&hosts, &every_index, "aws");
        assert_eq!(by_provider, vec![0]);

        // Tags are searched as well.
        let by_tag = rank_host_indices(&hosts, &every_index, "prod");
        assert_eq!(by_tag, vec![1]);

        // Hosts outside the candidate set (index 0 here) never appear.
        let restricted = rank_host_indices(&hosts, &[1, 2], "");
        assert_eq!(restricted, vec![1, 2]);
    }
}
163}