1use std::path::Path;
2
3use thiserror::Error;
4
5use code_moniker_core::lang::Lang;
6
7#[derive(Debug, Error)]
8pub enum LangError {
9 #[error(
10 "unsupported file extension `.{0}` (known: ts/tsx/js/jsx/mjs/cjs, rs, java, py/pyi, go, cs, sql/plpgsql)"
11 )]
12 UnknownExtension(String),
13 #[error("file has no extension; cannot infer language")]
14 NoExtension,
15}
16
17pub fn path_to_lang(path: &Path) -> Result<Lang, LangError> {
18 let ext = path
19 .extension()
20 .and_then(|s| s.to_str())
21 .map(|s| s.to_ascii_lowercase());
22 let ext = match ext.as_deref() {
23 Some("") | None => return Err(LangError::NoExtension),
24 Some(e) => e,
25 };
26 match ext {
27 "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs" => Ok(Lang::Ts),
28 "rs" => Ok(Lang::Rs),
29 "java" => Ok(Lang::Java),
30 "py" | "pyi" => Ok(Lang::Python),
31 "go" => Ok(Lang::Go),
32 "cs" => Ok(Lang::Cs),
33 "sql" | "plpgsql" => Ok(Lang::Sql),
34 other => Err(LangError::UnknownExtension(other.to_string())),
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Runs `path_to_lang` on a path given as a string.
    fn dispatch(s: &str) -> Result<Lang, LangError> {
        path_to_lang(&PathBuf::from(s))
    }

    #[test]
    fn ts_family_resolves_to_ts() {
        for p in &[
            "x.ts",
            "x.tsx",
            "x.js",
            "x.jsx",
            "x.mjs",
            "x.cjs",
            "a/b/c/x.TS",
        ] {
            assert_eq!(dispatch(p).unwrap(), Lang::Ts, "{p}");
        }
    }

    #[test]
    fn each_supported_extension_resolves() {
        assert_eq!(dispatch("a.rs").unwrap(), Lang::Rs);
        assert_eq!(dispatch("a.java").unwrap(), Lang::Java);
        assert_eq!(dispatch("a.py").unwrap(), Lang::Python);
        assert_eq!(dispatch("a.pyi").unwrap(), Lang::Python);
        assert_eq!(dispatch("a.go").unwrap(), Lang::Go);
        assert_eq!(dispatch("a.cs").unwrap(), Lang::Cs);
    }

    #[test]
    fn unknown_extension_errors() {
        match dispatch("a.txt") {
            Err(LangError::UnknownExtension(s)) => assert_eq!(s, "txt"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn unknown_extension_is_reported_lowercased() {
        match dispatch("A.TXT") {
            Err(LangError::UnknownExtension(s)) => assert_eq!(s, "txt"),
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn unknown_extension_message_names_the_extension() {
        let msg = LangError::UnknownExtension("txt".to_string()).to_string();
        assert!(msg.starts_with("unsupported file extension `.txt`"), "{msg}");
    }

    #[test]
    fn missing_extension_errors() {
        match dispatch("Makefile") {
            Err(LangError::NoExtension) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn empty_or_absent_extension_errors() {
        // `notes.` has an empty extension, `.gitignore` is a dotfile with no
        // extension, and a dot in a directory name must not be mistaken for one.
        for p in &["notes.", ".gitignore", "dir.d/Makefile"] {
            match dispatch(p) {
                Err(LangError::NoExtension) => {}
                other => panic!("{p}: unexpected: {other:?}"),
            }
        }
    }

    #[test]
    fn only_last_extension_counts() {
        assert_eq!(dispatch("x.d.ts").unwrap(), Lang::Ts);
        assert_eq!(dispatch("mod.test.rs").unwrap(), Lang::Rs);
    }

    #[test]
    fn case_is_insensitive() {
        assert_eq!(dispatch("X.JAVA").unwrap(), Lang::Java);
        assert_eq!(dispatch("X.RS").unwrap(), Lang::Rs);
    }

    #[test]
    fn sql_extension_resolves() {
        assert_eq!(dispatch("a.sql").unwrap(), Lang::Sql);
        assert_eq!(dispatch("a.plpgsql").unwrap(), Lang::Sql);
        assert_eq!(dispatch("A.SQL").unwrap(), Lang::Sql);
    }
}