use std::collections::BTreeMap;
use crate::config::SourceDef;
use crate::expand::expand_and_normalize;
use crate::os_detect::Os;
/// One configured source whose needle was found in the haystack by [`find`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Match {
    /// Key of the matching source in the sources map.
    pub name: String,
    /// Byte length of the needle that matched; `find` ranks results by this,
    /// longest first.
    pub needle_len: usize,
}
/// Collects every source whose needle occurs in `haystack`, most specific first.
///
/// For each source, the path selected by `def.path_for(os)` is passed through
/// `expand_and_normalize` to obtain the needle. Sources with no path for `os`,
/// or whose needle comes out empty, are skipped. A source is reported when
/// `needle_aligned_to_boundary` accepts its needle.
///
/// The result is ordered by descending needle length (in bytes). The sort is
/// stable, so sources with equally long needles keep the map's key order.
pub fn find(haystack: &str, sources: &BTreeMap<String, SourceDef>, os: Os) -> Vec<Match> {
    let mut ranked: Vec<Match> = sources
        .iter()
        .filter_map(|(name, def)| {
            let needle = expand_and_normalize(def.path_for(os)?);
            let matched = !needle.is_empty() && needle_aligned_to_boundary(haystack, &needle);
            matched.then(|| Match {
                name: name.clone(),
                needle_len: needle.len(),
            })
        })
        .collect();
    // Longest needle first; `sort_by` is stable, so ties stay in map order.
    ranked.sort_by(|a, b| b.needle_len.cmp(&a.needle_len));
    ranked
}
/// Reports whether `needle` occurs in `haystack` ending on a path-segment boundary.
///
/// A needle that already ends in `/` carries its own boundary, so any
/// occurrence counts. Otherwise an occurrence counts only when it is
/// immediately followed by `/` or by the end of `haystack`; this keeps
/// `/a/bin` from matching `/a/binx/tool`.
///
/// Only the trailing edge of the occurrence is checked; nothing is required of
/// the character before it.
///
/// NOTE(review): only `/` is treated as a separator here; presumably both
/// inputs are already normalized to forward slashes. Confirm for Windows paths.
fn needle_aligned_to_boundary(haystack: &str, needle: &str) -> bool {
    if needle.ends_with('/') {
        return haystack.contains(needle);
    }
    // "Followed by end of input" is exactly a suffix match.
    if haystack.ends_with(needle) {
        return true;
    }
    // "Followed by `/`" means the text just before some `/` ends with the
    // needle. Probing each separator (instead of scanning with
    // `match_indices(needle)`, which yields only non-overlapping occurrences)
    // ensures an occurrence that overlaps an earlier, non-aligned one is still
    // found. `/` is ASCII, so `sep` is always a valid char boundary.
    haystack
        .match_indices('/')
        .any(|(sep, _)| haystack[..sep].ends_with(needle))
}
/// A configured source whose needle was flagged by [`validate_sources`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SourceWarning {
    /// Key of the offending source in the sources map.
    pub name: String,
    /// The needle as produced by `expand_and_normalize`, i.e. the value that
    /// was actually classified.
    pub needle: String,
    /// Why the needle was flagged.
    pub reason: SourceWarningReason,
}
/// The reason a source needle is reported in a [`SourceWarning`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SourceWarningReason {
    /// The needle is a filesystem root: `/`, `\`, or a Windows drive root
    /// such as `C:` or `C:\`.
    RootPath,
    /// The needle is shorter than three bytes.
    NeedleTooShort,
}
/// Checks every source's needle for `os` and reports the suspicious ones.
///
/// The needle is derived as in `find`: `def.path_for(os)` passed through
/// `expand_and_normalize`. Sources with no path for `os`, or with an empty
/// needle, are skipped without a warning. Each remaining needle is handed to
/// `classify_needle`; a warning is emitted when that returns a reason.
///
/// Warnings come out in the map's iteration (key) order.
pub fn validate_sources(sources: &BTreeMap<String, SourceDef>, os: Os) -> Vec<SourceWarning> {
    sources
        .iter()
        .filter_map(|(name, def)| {
            let needle = expand_and_normalize(def.path_for(os)?);
            if needle.is_empty() {
                return None;
            }
            let reason = classify_needle(&needle)?;
            Some(SourceWarning {
                name: name.clone(),
                needle,
                reason,
            })
        })
        .collect()
}
/// Decides whether `needle` deserves a warning, and which one.
///
/// Root paths are checked first, so `/`, `\` and drive roots such as `C:` are
/// reported as `RootPath` even though they are also shorter than three bytes.
/// Any other needle shorter than three bytes is `NeedleTooShort`. Returns
/// `None` when neither rule applies.
fn classify_needle(needle: &str) -> Option<SourceWarningReason> {
    let is_root = matches!(needle, "/" | "\\") || is_windows_drive_root(needle);
    if is_root {
        Some(SourceWarningReason::RootPath)
    } else if needle.len() < 3 {
        // `len()` counts bytes, not characters.
        Some(SourceWarningReason::NeedleTooShort)
    } else {
        None
    }
}
/// Returns `true` when `needle` is exactly a Windows drive root.
///
/// Accepted shapes are an ASCII letter followed by `:`, optionally followed by
/// a single `/` or `\` (for example `C:`, `d:/`, `E:\`). Anything longer, such
/// as `C:\Users`, is not a root.
fn is_windows_drive_root(needle: &str) -> bool {
    match needle.as_bytes() {
        [letter, b':'] | [letter, b':', b'/' | b'\\'] => letter.is_ascii_alphabetic(),
        _ => false,
    }
}
/// Same lookup as [`find`], reduced to the matching source names.
///
/// The names are returned in the order `find` produced its matches; the needle
/// lengths are dropped.
pub fn names_only(haystack: &str, sources: &BTreeMap<String, SourceDef>, os: Os) -> Vec<String> {
    let hits = find(haystack, sources, os);
    let mut names = Vec::with_capacity(hits.len());
    for hit in hits {
        names.push(hit.name);
    }
    names
}
// Unit tests for `find`, `names_only` and `validate_sources`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `SourceDef` with only its `unix` path set; every other field takes
// its default value.
fn src(unix: &str) -> SourceDef {
SourceDef {
unix: Some(unix.into()),
..Default::default()
}
}
// Builds a sources map from `(name, definition)` pairs.
fn cat(entries: &[(&str, SourceDef)]) -> BTreeMap<String, SourceDef> {
entries
.iter()
.map(|(n, d)| (n.to_string(), d.clone()))
.collect()
}
// A haystack that does not contain the configured path yields no matches.
#[test]
fn find_returns_empty_when_no_source_matches() {
let sources = cat(&[("cargo", src("/home/u/.cargo/bin"))]);
let out = find("/usr/local/bin/rg", &sources, Os::Linux);
assert!(out.is_empty());
}
// A source that defines only a `windows` path must yield nothing when the
// lookup is done for `Os::Linux`.
#[test]
fn find_skips_sources_with_no_path_for_current_os() {
let def = SourceDef {
windows: Some("WinGet".into()),
..Default::default()
};
let sources = cat(&[("winget", def)]);
let out = find("/home/u/.cargo/bin/rg", &sources, Os::Linux);
assert!(out.is_empty());
}
// When two sources both match, the one with the longer needle comes first.
#[test]
fn find_ranks_longer_needle_first() {
let sources = cat(&[
("mise", src("/home/u/.local/share/mise")),
("mise_installs", src("/home/u/.local/share/mise/installs")),
]);
let out = find(
"/home/u/.local/share/mise/installs/python/3.14/bin/python",
&sources,
Os::Linux,
);
assert_eq!(out.len(), 2);
assert_eq!(out[0].name, "mise_installs", "longer needle should lead");
assert_eq!(out[1].name, "mise");
assert!(out[0].needle_len > out[1].needle_len);
}
// An empty configured path must never be reported as a match.
#[test]
fn find_skips_empty_needles() {
let sources = cat(&[("empty", src(""))]);
let out = find("/anywhere/at/all", &sources, Os::Linux);
assert!(out.is_empty(), "empty needle must not match");
}
// `names_only` returns just the names, in the same longest-needle-first order
// that `find` produces.
#[test]
fn names_only_strips_specificity_but_keeps_order() {
let sources = cat(&[
("mise", src("/home/u/.local/share/mise")),
("mise_installs", src("/home/u/.local/share/mise/installs")),
]);
let out = names_only(
"/home/u/.local/share/mise/installs/python/3.14/bin/python",
&sources,
Os::Linux,
);
assert_eq!(out, vec!["mise_installs".to_string(), "mise".to_string()]);
}
// The needle `.../bin` must not match a path whose segment is `binx`.
#[test]
fn find_does_not_match_partial_segment() {
let sources = cat(&[("cargo", src("/home/u/.cargo/bin"))]);
let out = find("/home/u/.cargo/binx/rg", &sources, Os::Linux);
assert!(
out.is_empty(),
"needle ending mid-segment must not match: {out:?}"
);
}
// A haystack that is exactly the needle counts as a match.
#[test]
fn find_matches_when_needle_ends_haystack_exactly() {
let sources = cat(&[("cargo", src("/home/u/.cargo/bin"))]);
let out = find("/home/u/.cargo/bin", &sources, Os::Linux);
assert_eq!(out.len(), 1);
assert_eq!(out[0].name, "cargo");
}
// A needle immediately followed by `/` in the haystack counts as a match.
#[test]
fn find_matches_when_needle_is_followed_by_separator() {
let sources = cat(&[("cargo", src("/home/u/.cargo/bin"))]);
let out = find("/home/u/.cargo/bin/rg", &sources, Os::Linux);
assert_eq!(out.len(), 1);
assert_eq!(out[0].name, "cargo");
}
// A source configured as `/` produces exactly one warning, attributed to that
// source. The warning's reason is not asserted here.
#[test]
fn validate_sources_rejects_root_path() {
let sources = cat(&[("evil", src("/"))]);
let warnings = validate_sources(&sources, Os::Linux);
assert_eq!(warnings.len(), 1);
assert_eq!(warnings[0].name, "evil");
}
// A Windows drive root with a trailing backslash is flagged as `RootPath`.
#[test]
fn validate_sources_rejects_windows_drive_root() {
let def = SourceDef {
windows: Some("C:\\".into()),
..Default::default()
};
let sources = cat(&[("evil_drive", def)]);
let warnings = validate_sources(&sources, Os::Windows);
assert_eq!(warnings.len(), 1);
assert_eq!(warnings[0].reason, SourceWarningReason::RootPath);
}
// A bare, lowercase drive letter with a colon is also flagged as `RootPath`.
#[test]
fn validate_sources_rejects_bare_drive_letter() {
let def = SourceDef {
windows: Some("d:".into()),
..Default::default()
};
let sources = cat(&[("evil_d", def)]);
let warnings = validate_sources(&sources, Os::Windows);
assert_eq!(warnings.len(), 1);
assert_eq!(warnings[0].reason, SourceWarningReason::RootPath);
}
// A two-character path produces exactly one warning. The warning's reason is
// not asserted here.
#[test]
fn validate_sources_rejects_too_short_needle() {
let sources = cat(&[("ev", src(".x"))]);
let warnings = validate_sources(&sources, Os::Linux);
assert_eq!(warnings.len(), 1);
}
// Ordinary absolute directories produce no warnings.
#[test]
fn validate_sources_accepts_normal_paths() {
let sources = cat(&[
("cargo", src("/home/u/.cargo/bin")),
("apt", src("/usr/bin")),
]);
let warnings = validate_sources(&sources, Os::Linux);
assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
}
// A source that defines only a `windows` path is not validated for
// `Os::Linux`, so it produces no warnings.
#[test]
fn validate_sources_skips_sources_without_path_for_os() {
let def = SourceDef {
windows: Some("WinGet".into()),
..Default::default()
};
let sources = cat(&[("winget", def)]);
let warnings = validate_sources(&sources, Os::Linux);
assert!(warnings.is_empty());
}
}