1use std::sync::{Arc, OnceLock};
8
9use tree_sitter::{Language, Query};
10
11pub struct CallConfig {
16 pub language: Language,
18 pub query: Query,
20}
21
22pub struct LangConfig {
27 pub language: Language,
29 pub query: Query,
31}
32
33#[must_use]
38pub fn config_for_extension(ext: &str) -> Option<Arc<LangConfig>> {
39 static CACHE: OnceLock<std::collections::HashMap<&'static str, Arc<LangConfig>>> =
41 OnceLock::new();
42
43 let cache = CACHE.get_or_init(|| {
44 let mut m = std::collections::HashMap::new();
45 for &ext in &[
47 "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx",
48 "hpp", "sh", "bash", "bats", "rb", "tf", "tfvars", "hcl", "kt", "kts", "swift",
49 "scala", "toml",
50 ] {
51 if let Some(cfg) = compile_config(ext) {
52 m.insert(ext, Arc::new(cfg));
53 }
54 }
55 m
56 });
57
58 cache.get(ext).cloned()
59}
60
61#[expect(
63 clippy::too_many_lines,
64 reason = "one match arm per language — flat by design"
65)]
66fn compile_config(ext: &str) -> Option<LangConfig> {
67 let (lang, query_str): (Language, &str) = match ext {
68 "rs" => (
72 tree_sitter_rust::LANGUAGE.into(),
73 concat!(
74 "(function_item name: (identifier) @name) @def\n",
75 "(struct_item name: (type_identifier) @name) @def\n",
76 "(enum_item name: (type_identifier) @name) @def\n",
77 "(type_item name: (type_identifier) @name) @def",
78 ),
79 ),
80 "py" => (
83 tree_sitter_python::LANGUAGE.into(),
84 concat!(
85 "(function_definition name: (identifier) @name) @def\n",
86 "(class_definition name: (identifier) @name body: (block) @def)",
87 ),
88 ),
89 "js" | "jsx" => (
91 tree_sitter_javascript::LANGUAGE.into(),
92 concat!(
93 "(function_declaration name: (identifier) @name) @def\n",
94 "(method_definition name: (property_identifier) @name) @def\n",
95 "(class_declaration name: (identifier) @name) @def",
96 ),
97 ),
98 "ts" => (
99 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
100 concat!(
101 "(function_declaration name: (identifier) @name) @def\n",
102 "(method_definition name: (property_identifier) @name) @def\n",
103 "(class_declaration name: (type_identifier) @name) @def\n",
104 "(interface_declaration name: (type_identifier) @name) @def",
105 ),
106 ),
107 "tsx" => (
108 tree_sitter_typescript::LANGUAGE_TSX.into(),
109 concat!(
110 "(function_declaration name: (identifier) @name) @def\n",
111 "(method_definition name: (property_identifier) @name) @def\n",
112 "(class_declaration name: (type_identifier) @name) @def\n",
113 "(interface_declaration name: (type_identifier) @name) @def",
114 ),
115 ),
116 "go" => (
117 tree_sitter_go::LANGUAGE.into(),
118 concat!(
119 "(function_declaration name: (identifier) @name) @def\n",
120 "(method_declaration name: (field_identifier) @name) @def",
121 ),
122 ),
123 "java" => (
126 tree_sitter_java::LANGUAGE.into(),
127 concat!(
128 "(method_declaration name: (identifier) @name) @def\n",
129 "(class_declaration name: (identifier) @name) @def\n",
130 "(interface_declaration name: (identifier) @name) @def",
131 ),
132 ),
133 "c" | "h" => (
134 tree_sitter_c::LANGUAGE.into(),
135 "(function_definition declarator: (function_declarator declarator: (identifier) @name)) @def",
136 ),
137 "cpp" | "cc" | "cxx" | "hpp" => (
139 tree_sitter_cpp::LANGUAGE.into(),
140 concat!(
141 "(function_definition declarator: (function_declarator declarator: (identifier) @name)) @def\n",
142 "(class_specifier name: (type_identifier) @name) @def",
143 ),
144 ),
145 "sh" | "bash" | "bats" => (
147 tree_sitter_bash::LANGUAGE.into(),
148 "(function_definition name: (word) @name) @def",
149 ),
150 "rb" => (
152 tree_sitter_ruby::LANGUAGE.into(),
153 concat!(
154 "(method name: (identifier) @name) @def\n",
155 "(class name: (constant) @name) @def\n",
156 "(module name: (constant) @name) @def",
157 ),
158 ),
159 "tf" | "tfvars" | "hcl" => (
161 tree_sitter_hcl::LANGUAGE.into(),
162 "(block (identifier) @name) @def",
163 ),
164 "kt" | "kts" => (
166 tree_sitter_kotlin_ng::LANGUAGE.into(),
167 concat!(
168 "(function_declaration name: (identifier) @name) @def\n",
169 "(class_declaration name: (identifier) @name) @def\n",
170 "(object_declaration name: (identifier) @name) @def",
171 ),
172 ),
173 "swift" => (
175 tree_sitter_swift::LANGUAGE.into(),
176 concat!(
177 "(function_declaration name: (simple_identifier) @name) @def\n",
178 "(class_declaration name: (type_identifier) @name) @def\n",
179 "(protocol_declaration name: (type_identifier) @name) @def",
180 ),
181 ),
182 "scala" => (
184 tree_sitter_scala::LANGUAGE.into(),
185 concat!(
186 "(function_definition name: (identifier) @name) @def\n",
187 "(class_definition name: (identifier) @name) @def\n",
188 "(trait_definition name: (identifier) @name) @def\n",
189 "(object_definition name: (identifier) @name) @def",
190 ),
191 ),
192 "toml" => (
194 tree_sitter_toml_ng::LANGUAGE.into(),
195 "(table (bare_key) @name) @def",
196 ),
197 _ => return None,
198 };
199 let query = match Query::new(&lang, query_str) {
200 Ok(q) => q,
201 Err(e) => {
202 tracing::warn!(ext, %e, "tree-sitter query compilation failed — language may be ABI-incompatible");
203 return None;
204 }
205 };
206 Some(LangConfig {
207 language: lang,
208 query,
209 })
210}
211
212#[must_use]
218pub fn call_query_for_extension(ext: &str) -> Option<Arc<CallConfig>> {
219 static CACHE: OnceLock<std::collections::HashMap<&'static str, Arc<CallConfig>>> =
220 OnceLock::new();
221
222 let cache = CACHE.get_or_init(|| {
223 let mut m = std::collections::HashMap::new();
224 for &ext in &[
227 "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx",
228 "hpp", "sh", "bash", "bats", "rb", "tf", "tfvars", "hcl", "kt", "kts", "swift",
229 "scala",
230 ] {
231 if let Some(cfg) = compile_call_config(ext) {
232 m.insert(ext, Arc::new(cfg));
233 }
234 }
235 m
236 });
237
238 cache.get(ext).cloned()
239}
240
241#[expect(
246 clippy::too_many_lines,
247 reason = "one match arm per language — flat by design"
248)]
249fn compile_call_config(ext: &str) -> Option<CallConfig> {
250 let (lang, query_str): (Language, &str) = match ext {
251 "rs" => (
253 tree_sitter_rust::LANGUAGE.into(),
254 concat!(
255 "(call_expression function: (identifier) @callee) @call\n",
256 "(call_expression function: (field_expression field: (field_identifier) @callee)) @call\n",
257 "(call_expression function: (scoped_identifier name: (identifier) @callee)) @call",
258 ),
259 ),
260 "py" => (
262 tree_sitter_python::LANGUAGE.into(),
263 concat!(
264 "(call function: (identifier) @callee) @call\n",
265 "(call function: (attribute attribute: (identifier) @callee)) @call",
266 ),
267 ),
268 "js" | "jsx" => (
270 tree_sitter_javascript::LANGUAGE.into(),
271 concat!(
272 "(call_expression function: (identifier) @callee) @call\n",
273 "(call_expression function: (member_expression property: (property_identifier) @callee)) @call",
274 ),
275 ),
276 "ts" => (
278 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
279 concat!(
280 "(call_expression function: (identifier) @callee) @call\n",
281 "(call_expression function: (member_expression property: (property_identifier) @callee)) @call",
282 ),
283 ),
284 "tsx" => (
286 tree_sitter_typescript::LANGUAGE_TSX.into(),
287 concat!(
288 "(call_expression function: (identifier) @callee) @call\n",
289 "(call_expression function: (member_expression property: (property_identifier) @callee)) @call",
290 ),
291 ),
292 "go" => (
294 tree_sitter_go::LANGUAGE.into(),
295 concat!(
296 "(call_expression function: (identifier) @callee) @call\n",
297 "(call_expression function: (selector_expression field: (field_identifier) @callee)) @call",
298 ),
299 ),
300 "java" => (
302 tree_sitter_java::LANGUAGE.into(),
303 "(method_invocation name: (identifier) @callee) @call",
304 ),
305 "c" | "h" => (
307 tree_sitter_c::LANGUAGE.into(),
308 concat!(
309 "(call_expression function: (identifier) @callee) @call\n",
310 "(call_expression function: (field_expression field: (field_identifier) @callee)) @call",
311 ),
312 ),
313 "cpp" | "cc" | "cxx" | "hpp" => (
315 tree_sitter_cpp::LANGUAGE.into(),
316 concat!(
317 "(call_expression function: (identifier) @callee) @call\n",
318 "(call_expression function: (field_expression field: (field_identifier) @callee)) @call",
319 ),
320 ),
321 "sh" | "bash" | "bats" => (
323 tree_sitter_bash::LANGUAGE.into(),
324 "(command name: (command_name (word) @callee)) @call",
325 ),
326 "rb" => (
328 tree_sitter_ruby::LANGUAGE.into(),
329 "(call method: (identifier) @callee) @call",
330 ),
331 "tf" | "tfvars" | "hcl" => (
333 tree_sitter_hcl::LANGUAGE.into(),
334 "(function_call (identifier) @callee) @call",
335 ),
336 "kt" | "kts" => (
339 tree_sitter_kotlin_ng::LANGUAGE.into(),
340 "(call_expression (identifier) @callee) @call",
341 ),
342 "swift" => (
344 tree_sitter_swift::LANGUAGE.into(),
345 "(call_expression (simple_identifier) @callee) @call",
346 ),
347 "scala" => (
349 tree_sitter_scala::LANGUAGE.into(),
350 concat!(
351 "(call_expression function: (identifier) @callee) @call\n",
352 "(call_expression function: (field_expression field: (identifier) @callee)) @call",
353 ),
354 ),
355 _ => return None,
356 };
357 let query = match Query::new(&lang, query_str) {
358 Ok(q) => q,
359 Err(e) => {
360 tracing::warn!(ext, %e, "tree-sitter call query compilation failed");
361 return None;
362 }
363 };
364 Some(CallConfig {
365 language: lang,
366 query,
367 })
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    // Extensions expected to resolve to a definition-query configuration.
    const DEFINITION_EXTENSIONS: &[&str] = &[
        "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx",
        "hpp", "sh", "bash", "bats", "rb", "tf", "tfvars", "hcl", "kt", "kts", "swift",
        "scala", "toml",
    ];

    // Extensions expected to resolve to a call-query configuration; `toml` is
    // absent here and is checked separately below.
    const CALL_EXTENSIONS: &[&str] = &[
        "rs", "py", "js", "jsx", "ts", "tsx", "go", "java", "c", "h", "cpp", "cc", "cxx",
        "hpp", "sh", "bash", "bats", "rb", "tf", "tfvars", "hcl", "kt", "kts", "swift",
        "scala",
    ];

    /// A single well-known extension resolves.
    #[test]
    fn rust_extension_resolves() {
        let resolved = config_for_extension("rs");
        assert!(resolved.is_some());
    }

    /// A second well-known extension resolves.
    #[test]
    fn python_extension_resolves() {
        let resolved = config_for_extension("py");
        assert!(resolved.is_some());
    }

    /// An extension outside the table yields nothing.
    #[test]
    fn unknown_extension_returns_none() {
        let resolved = config_for_extension("xyz");
        assert!(resolved.is_none());
    }

    /// Every listed extension compiles to a definition-query configuration.
    #[test]
    fn all_supported_extensions() {
        for ext in DEFINITION_EXTENSIONS {
            assert!(config_for_extension(ext).is_some(), "failed for {ext}");
        }
    }

    /// Every listed extension compiles to a call-query configuration.
    #[test]
    fn all_call_query_extensions() {
        for ext in CALL_EXTENSIONS {
            assert!(
                call_query_for_extension(ext).is_some(),
                "call query failed for {ext}"
            );
        }
    }

    /// `toml` has a definition query but no call query.
    #[test]
    fn toml_has_no_call_query() {
        let resolved = call_query_for_extension("toml");
        assert!(resolved.is_none());
    }
}