1use std::sync::{Arc, OnceLock};
8
9use tree_sitter::{Language, Query};
10
/// A tree-sitter grammar paired with the compiled query used to locate
/// definitions in files of that language.
pub struct LangConfig {
    /// Grammar used to parse source files of this language.
    pub language: Language,
    /// Definition query compiled against `language`. Its patterns tag the
    /// defined identifier with the `@name` capture and the associated
    /// definition span with the `@def` capture.
    pub query: Query,
}
21
22#[must_use]
27pub fn config_for_extension(ext: &str) -> Option<Arc<LangConfig>> {
28 static CACHE: OnceLock<std::collections::HashMap<&'static str, Arc<LangConfig>>> =
30 OnceLock::new();
31
32 let cache = CACHE.get_or_init(|| {
33 let mut m = std::collections::HashMap::new();
34 for &ext in &[
36 "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx", "hpp",
37 ] {
38 if let Some(cfg) = compile_config(ext) {
39 m.insert(ext, Arc::new(cfg));
40 }
41 }
42 m
43 });
44
45 cache.get(ext).cloned()
46}
47
48fn compile_config(ext: &str) -> Option<LangConfig> {
50 let (lang, query_str): (Language, &str) = match ext {
51 "rs" => (
55 tree_sitter_rust::LANGUAGE.into(),
56 concat!(
57 "(function_item name: (identifier) @name) @def\n",
58 "(struct_item name: (type_identifier) @name) @def\n",
59 "(enum_item name: (type_identifier) @name) @def\n",
60 "(type_item name: (type_identifier) @name) @def",
61 ),
62 ),
63 "py" => (
66 tree_sitter_python::LANGUAGE.into(),
67 concat!(
68 "(function_definition name: (identifier) @name) @def\n",
69 "(class_definition name: (identifier) @name body: (block) @def)",
70 ),
71 ),
72 "js" | "jsx" => (
74 tree_sitter_javascript::LANGUAGE.into(),
75 concat!(
76 "(function_declaration name: (identifier) @name) @def\n",
77 "(method_definition name: (property_identifier) @name) @def\n",
78 "(class_declaration name: (identifier) @name) @def",
79 ),
80 ),
81 "ts" => (
82 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
83 concat!(
84 "(function_declaration name: (identifier) @name) @def\n",
85 "(method_definition name: (property_identifier) @name) @def\n",
86 "(class_declaration name: (type_identifier) @name) @def\n",
87 "(interface_declaration name: (type_identifier) @name) @def",
88 ),
89 ),
90 "tsx" => (
91 tree_sitter_typescript::LANGUAGE_TSX.into(),
92 concat!(
93 "(function_declaration name: (identifier) @name) @def\n",
94 "(method_definition name: (property_identifier) @name) @def\n",
95 "(class_declaration name: (type_identifier) @name) @def\n",
96 "(interface_declaration name: (type_identifier) @name) @def",
97 ),
98 ),
99 "go" => (
100 tree_sitter_go::LANGUAGE.into(),
101 concat!(
102 "(function_declaration name: (identifier) @name) @def\n",
103 "(method_declaration name: (field_identifier) @name) @def",
104 ),
105 ),
106 "java" => (
109 tree_sitter_java::LANGUAGE.into(),
110 concat!(
111 "(method_declaration name: (identifier) @name) @def\n",
112 "(class_declaration name: (identifier) @name) @def\n",
113 "(interface_declaration name: (identifier) @name) @def",
114 ),
115 ),
116 "c" | "h" => (
117 tree_sitter_c::LANGUAGE.into(),
118 "(function_definition declarator: (function_declarator declarator: (identifier) @name)) @def",
119 ),
120 "cpp" | "cc" | "cxx" | "hpp" => (
122 tree_sitter_cpp::LANGUAGE.into(),
123 concat!(
124 "(function_definition declarator: (function_declarator declarator: (identifier) @name)) @def\n",
125 "(class_specifier name: (type_identifier) @name) @def",
126 ),
127 ),
128 _ => return None,
129 };
130 let query = match Query::new(&lang, query_str) {
131 Ok(q) => q,
132 Err(e) => {
133 tracing::warn!(ext, %e, "tree-sitter query compilation failed — language may be ABI-incompatible");
134 return None;
135 }
136 };
137 Some(LangConfig {
138 language: lang,
139 query,
140 })
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Every extension `config_for_extension` is expected to resolve.
    const SUPPORTED: [&str; 14] = [
        "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx", "hpp",
    ];

    #[test]
    fn rust_extension_resolves() {
        assert!(config_for_extension("rs").is_some());
    }

    #[test]
    fn python_extension_resolves() {
        assert!(config_for_extension("py").is_some());
    }

    #[test]
    fn unknown_extension_returns_none() {
        assert!(config_for_extension("xyz").is_none());
    }

    #[test]
    fn all_supported_extensions() {
        for ext in SUPPORTED {
            assert!(config_for_extension(ext).is_some(), "failed for {ext}");
        }
    }

    /// Consumers read the `@name` and `@def` captures, so every language's
    /// query must define both. A typo in a capture name still compiles, so
    /// it would otherwise go unnoticed.
    #[test]
    fn every_query_exposes_name_and_def_captures() {
        for ext in SUPPORTED {
            let cfg = config_for_extension(ext).unwrap_or_else(|| panic!("no config for {ext}"));
            for capture in ["name", "def"] {
                assert!(
                    cfg.query.capture_index_for_name(capture).is_some(),
                    "{ext} query lacks @{capture} capture"
                );
            }
        }
    }
}