use memchr::memchr;
/// Looks up the language associated with a bare file extension such as `"rs"`.
///
/// Returns the language name, or `None` when the extension is not recognised.
/// The body is spliced in at compile time from `extensions_generated.rs` in
/// `OUT_DIR`, so the actual lookup logic is not visible in this file.
///
/// NOTE(review): the tests in this file pass the extension without a leading dot
/// and expect matching to be case-insensitive — confirm against the generator.
#[inline]
pub fn detect_language_from_extension(ext: &str) -> Option<&'static str> {
// Build-time generated expression; presumably reads `ext` — verify against the build script.
include!(concat!(env!("OUT_DIR"), "/extensions_generated.rs"))
}
pub fn detect_language_from_path(path: &str) -> Option<&'static str> {
let ext = std::path::Path::new(path).extension()?.to_str()?;
detect_language_from_extension(ext)
}
/// Reports whether a file extension is claimed by more than one language.
///
/// On a hit the result is a pair of `(assigned, alternatives)` language names;
/// `None` is returned when the generated table has no entry for the extension,
/// when `ext` contains non-ASCII characters, or when it is longer than 32 bytes.
///
/// The final lookup is spliced in at compile time from
/// `ambiguities_generated.rs` in `OUT_DIR`, so it is not visible in this file.
pub fn extension_ambiguity(ext: &str) -> Option<(&'static str, &'static [&'static str])> {
// Fixed-size stack buffer so lowercasing needs no heap allocation.
let mut buf = [0u8; 32];
let ext_lower = if ext.len() <= buf.len() && ext.is_ascii() {
// ASCII-only input: lowercasing byte by byte keeps the buffer valid UTF-8.
for (i, b) in ext.bytes().enumerate() {
buf[i] = b.to_ascii_lowercase();
}
std::str::from_utf8(&buf[..ext.len()]).ok()?
} else {
// Too long for the buffer, or not ASCII: treated as "not ambiguous".
return None;
};
// Build-time generated expression; presumably matches on `ext_lower` — verify against the build script.
include!(concat!(env!("OUT_DIR"), "/ambiguities_generated.rs"))
}
/// Serialises the result of [`extension_ambiguity`] as a compact JSON object.
///
/// The object has two keys: `"assigned"` (the language the extension resolves
/// to) and `"alternatives"` (the other candidate languages). Returns `None`
/// whenever [`extension_ambiguity`] does.
#[cfg(feature = "serde")]
pub fn extension_ambiguity_json(ext: &str) -> Option<String> {
    let (assigned, alternatives) = extension_ambiguity(ext)?;
    let payload = serde_json::json!({
        "assigned": assigned,
        "alternatives": alternatives,
    });
    Some(payload.to_string())
}
/// Detects a language from a leading shebang (`#!`) line.
///
/// Only the first line of `content` is examined. The interpreter is taken from
/// the first whitespace-separated token after `#!`; when that token is `env`
/// (bare or as the last path component), the interpreter is instead the first
/// following token that is neither an option (starts with `-`) nor an
/// environment assignment (`NAME=VALUE`), mirroring how `env` itself picks the
/// command to run. A trailing version such as the `3.11` in `python3.11` is
/// removed before the name is mapped to a language.
///
/// Returns `None` when `content` does not start with `#!`, when no interpreter
/// token is present, or when the interpreter is not a known one.
///
/// Options of `env` that take a separate argument (e.g. `-u NAME`) are not
/// modelled; their argument would be taken as the interpreter.
pub fn detect_language_from_content(content: &str) -> Option<&'static str> {
    let after_marker = content.strip_prefix("#!")?;
    let line_end = memchr(b'\n', after_marker.as_bytes()).unwrap_or(after_marker.len());
    let shebang_line = after_marker[..line_end].trim_end();

    let mut tokens = shebang_line.split_ascii_whitespace();
    let interpreter_path = tokens.next()?;

    let program = if interpreter_path == "env" || interpreter_path.ends_with("/env") {
        // Skip `env` options and `NAME=VALUE` assignments; the first remaining
        // token is the command `env` would execute.
        tokens.find(|token| !token.starts_with('-') && !token.contains('='))?
    } else {
        // Direct interpreter path: keep only the final path component.
        interpreter_path.rsplit('/').next()?
    };

    map_interpreter_to_language(strip_version_suffix(program))
}
/// Removes a version suffix from an interpreter name.
///
/// Everything from the first ASCII digit onwards is dropped, and a single `.`
/// left dangling at the end is dropped as well, so `python3.11` and `python3`
/// both become `python`. Names without digits are returned unchanged (apart
/// from one trailing `.`, if present).
fn strip_version_suffix(name: &str) -> &str {
    let stem = name
        .find(|c: char| c.is_ascii_digit())
        .map_or(name, |first_digit| &name[..first_digit]);
    // At most one separator dot is removed, e.g. `node.18` -> `node`.
    stem.strip_suffix('.').unwrap_or(stem)
}
/// Maps an interpreter executable name to the language it runs.
///
/// Matching is exact and case-sensitive. POSIX-style shells (`sh`, `bash`,
/// `dash`, `ash`, `zsh`) are all reported as `bash`. Returns `None` for any
/// name not listed here.
fn map_interpreter_to_language(interpreter: &str) -> Option<&'static str> {
    let language = match interpreter {
        "python" | "python2" | "python3" => "python",
        "sh" | "bash" | "dash" | "ash" | "zsh" => "bash",
        "node" | "nodejs" => "javascript",
        "ruby" | "jruby" => "ruby",
        "perl" | "perl5" | "perl6" => "perl",
        "lua" => "lua",
        "php" => "php",
        "elixir" => "elixir",
        "julia" => "julia",
        "r" | "R" | "Rscript" => "r",
        _ => return None,
    };
    Some(language)
}
// Unit tests covering extension lookup, path-based detection, ambiguity
// reporting, shebang-based detection, and a consistency check between the
// JSON language definitions and the generated extension table.
#[cfg(test)]
mod tests {
use super::*;
// Spot-checks the extension -> language mapping for widely used languages.
#[test]
fn test_common_extensions() {
assert_eq!(detect_language_from_extension("py"), Some("python"));
assert_eq!(detect_language_from_extension("pyi"), Some("python"));
assert_eq!(detect_language_from_extension("rs"), Some("rust"));
assert_eq!(detect_language_from_extension("js"), Some("javascript"));
assert_eq!(detect_language_from_extension("ts"), Some("typescript"));
assert_eq!(detect_language_from_extension("c"), Some("c"));
assert_eq!(detect_language_from_extension("h"), Some("c"));
assert_eq!(detect_language_from_extension("cpp"), Some("cpp"));
assert_eq!(detect_language_from_extension("go"), Some("go"));
assert_eq!(detect_language_from_extension("rb"), Some("ruby"));
assert_eq!(detect_language_from_extension("java"), Some("java"));
assert_eq!(detect_language_from_extension("cs"), Some("csharp"));
assert_eq!(detect_language_from_extension("tsx"), Some("tsx"));
assert_eq!(detect_language_from_extension("html"), Some("html"));
assert_eq!(detect_language_from_extension("css"), Some("css"));
assert_eq!(detect_language_from_extension("json"), Some("json"));
assert_eq!(detect_language_from_extension("yaml"), Some("yaml"));
assert_eq!(detect_language_from_extension("toml"), Some("toml"));
assert_eq!(detect_language_from_extension("sql"), Some("sql"));
assert_eq!(detect_language_from_extension("md"), Some("markdown"));
}
// Upper-case and mixed-case extensions must resolve like their lower-case form.
#[test]
fn test_case_insensitive() {
assert_eq!(detect_language_from_extension("PY"), Some("python"));
assert_eq!(detect_language_from_extension("Rs"), Some("rust"));
assert_eq!(detect_language_from_extension("JS"), Some("javascript"));
assert_eq!(detect_language_from_extension("CPP"), Some("cpp"));
assert_eq!(detect_language_from_extension("Tsx"), Some("tsx"));
}
// Unrecognised and empty extensions yield `None`.
#[test]
fn test_unknown() {
assert_eq!(detect_language_from_extension("xyz"), None);
assert_eq!(detect_language_from_extension(""), None);
assert_eq!(detect_language_from_extension("abcdef"), None);
}
// Path-based detection uses only the final extension (`app.test.tsx` -> `tsx`).
#[test]
fn test_path_detection() {
assert_eq!(detect_language_from_path("src/main.rs"), Some("rust"));
assert_eq!(detect_language_from_path("/path/to/file.py"), Some("python"));
assert_eq!(detect_language_from_path("README.md"), Some("markdown"));
assert_eq!(detect_language_from_path("app.test.tsx"), Some("tsx"));
assert_eq!(detect_language_from_path("Cargo.toml"), Some("toml"));
}
// Paths without any extension cannot be classified.
#[test]
fn test_path_no_extension() {
assert_eq!(detect_language_from_path("Makefile"), None);
assert_eq!(detect_language_from_path(""), None);
assert_eq!(detect_language_from_path("/usr/bin/env"), None);
}
// A 33-character extension must yield `None`.
#[test]
fn test_long_extension_rejected() {
let long = "a".repeat(33);
assert_eq!(detect_language_from_extension(&long), None);
}
// Extensions shared by several languages report the assigned language plus
// the competing alternatives.
#[test]
fn test_ambiguity_known() {
let result = extension_ambiguity("m");
assert!(result.is_some(), ".m should be flagged as ambiguous");
let (assigned, alternatives) = result.unwrap();
assert_eq!(assigned, "objc");
assert!(alternatives.contains(&"matlab"));
let result = extension_ambiguity("h");
assert!(result.is_some(), ".h should be flagged as ambiguous");
let (assigned, alternatives) = result.unwrap();
assert_eq!(assigned, "c");
assert!(alternatives.contains(&"cpp"));
let result = extension_ambiguity("v");
assert!(result.is_some(), ".v should be flagged as ambiguous");
let (assigned, alternatives) = result.unwrap();
assert_eq!(assigned, "v");
assert!(alternatives.contains(&"verilog"));
}
// Unambiguous and unknown extensions report no ambiguity.
#[test]
fn test_ambiguity_unambiguous() {
assert!(extension_ambiguity("py").is_none());
assert!(extension_ambiguity("rs").is_none());
assert!(extension_ambiguity("xyz").is_none());
}
// Ambiguity lookup lowercases its input before matching.
#[test]
fn test_ambiguity_case_insensitive() {
assert!(extension_ambiguity("M").is_some());
assert!(extension_ambiguity("H").is_some());
}
// Shebang detection: `env`-launched and directly referenced interpreters.
#[test]
fn test_shebang_env_python3() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env python3\npass\n"),
Some("python")
);
}
#[test]
fn test_shebang_env_python() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env python\npass"),
Some("python")
);
}
#[test]
fn test_shebang_direct_python() {
assert_eq!(detect_language_from_content("#!/usr/bin/python\npass"), Some("python"));
}
#[test]
fn test_shebang_bash() {
assert_eq!(detect_language_from_content("#!/bin/bash\necho hi"), Some("bash"));
}
// Plain `sh` is reported as `bash`.
#[test]
fn test_shebang_sh() {
assert_eq!(detect_language_from_content("#!/bin/sh\necho hi"), Some("bash"));
}
#[test]
fn test_shebang_env_node() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env node\nconsole.log(1)"),
Some("javascript")
);
}
#[test]
fn test_shebang_env_ruby() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env ruby\nputs 'hi'"),
Some("ruby")
);
}
#[test]
fn test_shebang_direct_perl() {
assert_eq!(detect_language_from_content("#!/usr/bin/perl\nprint 1"), Some("perl"));
}
#[test]
fn test_shebang_env_lua() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env lua\nprint(1)"),
Some("lua")
);
}
#[test]
fn test_shebang_env_php() {
assert_eq!(detect_language_from_content("#!/usr/bin/env php\n<?php"), Some("php"));
}
#[test]
fn test_shebang_env_elixir() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env elixir\nIO.puts(1)"),
Some("elixir")
);
}
#[test]
fn test_shebang_env_julia() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env julia\nprintln(1)"),
Some("julia")
);
}
#[test]
fn test_shebang_env_rscript() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env Rscript\nprint(1)"),
Some("r")
);
}
// Options passed to `env` (here `-S`) are skipped when locating the interpreter.
#[test]
fn test_shebang_env_s_flag() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env -S python3\npass"),
Some("python")
);
}
// Version suffixes such as `3.11` are stripped before the interpreter lookup.
#[test]
fn test_shebang_version_suffix() {
assert_eq!(
detect_language_from_content("#!/usr/bin/env python3.11\npass"),
Some("python")
);
assert_eq!(
detect_language_from_content("#!/usr/bin/env ruby3.2\nputs 1"),
Some("ruby")
);
}
// Content that does not begin with `#!` is never classified.
#[test]
fn test_no_shebang() {
assert_eq!(detect_language_from_content("def foo(): pass"), None);
assert_eq!(detect_language_from_content("# not a shebang"), None);
}
#[test]
fn test_empty_content() {
assert_eq!(detect_language_from_content(""), None);
}
// A well-formed shebang naming an unknown interpreter yields `None`.
#[test]
fn test_shebang_unknown_interpreter() {
assert_eq!(detect_language_from_content("#!/usr/bin/env unknownlang\ncode"), None);
assert_eq!(detect_language_from_content("#!/usr/bin/fantasy\ncode"), None);
}
// Consistency check: every extension listed for a language in
// `language_definitions.json` must resolve to that language through the
// generated lookup.
#[test]
fn test_roundtrip_json_to_generated() {
let json_path = concat!(env!("CARGO_MANIFEST_DIR"), "/../../sources/language_definitions.json");
// The check is skipped silently (the test passes) when the JSON file cannot be read.
let json_str = match std::fs::read_to_string(json_path) {
Ok(s) => s,
Err(_) => return, };
let defs: std::collections::BTreeMap<String, serde_json::Value> =
serde_json::from_str(&json_str).expect("Failed to parse language_definitions.json");
for (lang_name, def) in &defs {
// Definitions without an `extensions` array are ignored.
if let Some(extensions) = def.get("extensions").and_then(|v| v.as_array()) {
for ext_val in extensions {
let ext = ext_val.as_str().expect("extension must be a string");
let result = detect_language_from_extension(ext);
assert_eq!(
result,
Some(lang_name.as_str()),
"Extension '{ext}' should map to '{lang_name}' but got {result:?}"
);
}
}
}
}
}