use std::collections::{BTreeSet, HashMap, HashSet};
use std::sync::OnceLock;
use crate::cli::HeaderMode;
/// Lookup tables derived from the `linguist` language definitions.
///
/// All keys and values are ASCII-lowercased; extensions are stored in the
/// form produced by `normalize_ext` (no leading dot).
#[derive(Debug)]
struct LanguageRegistry {
/// Lowercased language name or alias -> lowercased canonical language name.
/// Every canonical name also maps to itself.
names: HashMap<String, String>,
/// Canonical language name -> set of normalized extensions for that language.
/// Languages without any usable extension have no entry here.
extensions: HashMap<String, HashSet<String>>,
/// Normalized extension -> the single canonical language chosen for it
/// (ties between several languages are resolved when the registry is built).
ext_to_canonical: HashMap<String, String>,
}
/// Process-wide cache for the language registry. Starts empty and is filled
/// exactly once, on the first call to `registry()`, via `OnceLock::get_or_init`.
static REGISTRY: OnceLock<LanguageRegistry> = OnceLock::new();
/// Returns the shared `LanguageRegistry`, building it on first use.
///
/// The registry is built from `linguist::definitions::LANGUAGES` in two steps:
/// 1. For every language, record its lowercased name and aliases in `names`
///    and its normalized extensions in `extensions`, while collecting, per
///    extension, every language that claims it.
/// 2. Reduce each extension's candidate list to one canonical language
///    (see the tie-break order documented inline below).
fn registry() -> &'static LanguageRegistry {
REGISTRY.get_or_init(|| {
let mut names = HashMap::new();
let mut extensions = HashMap::new();
// Intermediate map: extension -> every canonical language that lists it.
let mut ext_to_canonicals: HashMap<String, Vec<String>> = HashMap::new();
for (lang_name, lang) in linguist::definitions::LANGUAGES.iter() {
// The lowercased language name is used as the canonical identifier everywhere.
let canonical = lang_name.to_ascii_lowercase();
names.insert(canonical.clone(), canonical.clone());
// NOTE(review): `insert` overwrites, so if an alias equals another
// language's name (or alias), whichever insert runs last wins.
if let Some(aliases) = &lang.aliases {
for alias in aliases {
names.insert(alias.to_ascii_lowercase(), canonical.clone());
}
}
// Normalize and de-duplicate this language's extensions; entries that
// `normalize_ext` rejects are dropped.
let ext_set: HashSet<String> = match &lang.extensions {
Some(v) => v.iter().filter_map(|e| normalize_ext(e)).collect(),
None => HashSet::new(),
};
// Languages with no usable extension are left out of `extensions` entirely.
if !ext_set.is_empty() {
for ext in &ext_set {
ext_to_canonicals
.entry(ext.clone())
.or_default()
.push(canonical.clone());
}
extensions.insert(canonical.clone(), ext_set);
}
}
// Pick exactly one language per extension. Tie-break order when several
// languages share an extension:
//   1. the hard-coded `EXT_PREFERENCE` override, if that language is a candidate;
//   2. the first language reported by `linguist::detect_language_by_extension`
//      that is also a candidate;
//   3. the lexicographically greatest candidate name.
let ext_to_canonical: HashMap<String, String> = ext_to_canonicals
.into_iter()
.map(|(ext, canonicals)| {
let chosen = if canonicals.len() == 1 {
// Cannot panic: the length was just checked to be 1.
canonicals.into_iter().next().unwrap()
} else {
// Explicit overrides for extensions whose owner should never be ambiguous.
const EXT_PREFERENCE: &[(&str, &str)] = &[("rs", "rust")];
let preferred = EXT_PREFERENCE
.iter()
.find(|(e, _)| *e == ext)
.and_then(|(_, lang)| {
let s = (*lang).to_string();
// Only honor the override if linguist actually lists that language
// for this extension.
canonicals.contains(&s).then_some(s)
});
preferred.unwrap_or_else(|| {
// The detection API is given a path, so fabricate a file name that
// carries the extension. Detection errors are treated as "no answer".
let path = format!("dummy.{}", ext);
linguist::detect_language_by_extension(&path)
.ok()
.and_then(|v| {
v.iter()
.map(|d| d.name.to_ascii_lowercase())
.find(|name| canonicals.contains(name))
})
.unwrap_or_else(|| {
// Last resort: sort so the choice does not depend on the order in
// which languages were visited, then take the greatest name.
// Cannot panic: a candidate list is only created by a `push`,
// so it is never empty.
let mut c = canonicals;
c.sort();
c.into_iter().last().unwrap()
})
})
};
(ext, chosen)
})
.collect();
LanguageRegistry {
names,
extensions,
ext_to_canonical,
}
})
}
/// Resolves a language name or alias to its canonical (lowercased) name.
///
/// The lookup is ASCII-case-insensitive. Returns `None` when `raw` is neither
/// a known language name nor a known alias.
pub fn canonical_language_name(raw: &str) -> Option<String> {
    let key = raw.to_ascii_lowercase();
    let known = &registry().names;
    known.get(&key).map(|canonical| canonical.to_owned())
}
/// Returns the canonical language chosen for a file extension.
///
/// `ext` is first passed through `normalize_ext`, so a leading dot, surrounding
/// whitespace and letter case do not matter. Returns `None` when the extension
/// cannot be normalized or no language is registered for it.
pub fn canonical_language_for_extension(ext: &str) -> Option<String> {
    normalize_ext(ext).and_then(|key| registry().ext_to_canonical.get(&key).cloned())
}
/// Normalizes a file extension into the form used as a registry key.
///
/// Surrounding whitespace is trimmed, every leading `.` is removed and the
/// remainder is ASCII-lowercased, so `".RS"`, `"rs"` and `" ..Rs "` all
/// yield `"rs"`. Interior dots (e.g. `"tar.gz"`) are preserved.
///
/// Returns `None` when nothing is left after trimming: empty or
/// whitespace-only input, and input consisting solely of dots such as `"."`.
pub fn normalize_ext(raw: &str) -> Option<String> {
    // Strip the dots *before* testing for emptiness; checking first would let
    // "." through as `Some("")`, i.e. a bogus empty extension.
    let stem = raw.trim().trim_start_matches('.');
    if stem.is_empty() {
        return None;
    }
    Some(stem.to_ascii_lowercase())
}
/// Adjusts `exts` in place according to how C/C++ header files should be treated.
///
/// The header extensions are `h`, `hh`, `hpp`, `hxx` and `h++`.
///
/// * `HeaderMode::Include` — adds all header extensions to the set.
/// * `HeaderMode::Exclude` — removes every header extension from the set.
/// * `HeaderMode::Only` — keeps only the header extensions already in the set.
pub fn apply_header_mode(exts: &mut HashSet<String>, headers: HeaderMode) {
    const HEADER_EXTS: [&str; 5] = ["h", "hh", "hpp", "hxx", "h++"];
    let is_header = |candidate: &String| HEADER_EXTS.contains(&candidate.as_str());

    match headers {
        HeaderMode::Include => exts.extend(HEADER_EXTS.iter().map(|h| (*h).to_owned())),
        HeaderMode::Exclude => exts.retain(|candidate| !is_header(candidate)),
        HeaderMode::Only => exts.retain(|candidate| is_header(candidate)),
    }
}
/// Computes the set of normalized file extensions selected by the CLI options.
///
/// * `ext` — explicit extensions; each is normalized and added (entries that
///   cannot be normalized are skipped).
/// * `langs` — language names or aliases; the special value `all`
///   (case-insensitive) selects every registered extension. An empty slice is
///   treated as `["all"]`.
/// * `headers` — applied through `apply_header_mode`, but only when `all`,
///   C or C++ was requested.
///
/// # Errors
///
/// Returns `Err` with a message when a language is unknown
/// (`invalid --lang value: …`) or is known but has no registered extensions
/// (`language has no extensions: …`).
pub fn build_extensions(
    langs: &[String],
    ext: &[String],
    headers: HeaderMode,
) -> Result<HashSet<String>, String> {
    let reg = registry();

    // Seed the result with the explicitly requested extensions.
    let mut selected: HashSet<String> = ext.iter().filter_map(|e| normalize_ext(e)).collect();

    // No language given means "every language".
    let fallback = ["all".to_string()];
    let requested: &[String] = if langs.is_empty() { &fallback } else { langs };

    // Header handling only kicks in for `all`, C and C++.
    let mut wants_headers = false;
    for raw in requested {
        if raw.eq_ignore_ascii_case("all") {
            selected.extend(reg.extensions.values().flatten().cloned());
            wants_headers = true;
            continue;
        }

        let Some(canonical) = canonical_language_name(raw) else {
            return Err(format!("invalid --lang value: {raw}"));
        };
        match reg.extensions.get(&canonical) {
            Some(known) => selected.extend(known.iter().cloned()),
            None => return Err(format!("language has no extensions: {raw}")),
        }
        if canonical == "c" || canonical == "c++" {
            wants_headers = true;
        }
    }

    if wants_headers {
        apply_header_mode(&mut selected, headers);
    }
    Ok(selected)
}
/// Renders the requested languages as a human-readable, comma-separated list.
///
/// Each entry is replaced by its canonical name when it is a known language
/// or alias, and otherwise just ASCII-lowercased. The result is de-duplicated
/// and sorted. An empty slice is rendered as `all`.
pub fn display_langs(langs: &[String]) -> String {
    if langs.is_empty() {
        return "all".to_string();
    }

    // A BTreeSet gives both de-duplication and a stable, sorted order.
    let unique: BTreeSet<String> = langs
        .iter()
        .map(|raw| match canonical_language_name(raw) {
            Some(canonical) => canonical,
            None => raw.to_ascii_lowercase(),
        })
        .collect();

    let ordered: Vec<String> = unique.into_iter().collect();
    ordered.join(",")
}