1use std::path::Path;
5
6use tree_sitter::Language;
7
/// Tree-sitter query bundle describing how to extract structure from one language.
///
/// Every query is written against the grammar returned by [`get_ts_language`] (or
/// [`get_ts_language_for_path`]) for the same language name. An empty query string means
/// that kind of extraction is not configured for the language.
#[derive(Debug, Clone, Copy)]
pub struct LanguageSpec {
    /// File extensions, including the leading `.`, that map to this language. They must be
    /// lowercase because [`detect_language`] lowercases the path's extension before matching.
    pub extensions: &'static [&'static str],
    /// Definitions query: `@name` captures the identifier, `@definition.<kind>` the whole node.
    pub symbol_query: &'static str,
    /// Import query: each import-like statement is captured as `@import`.
    pub import_query: &'static str,
    /// Call query: `@name` captures the callee identifier, `@call` the whole call node.
    pub call_query: &'static str,
}
15
/// Python sources (`.py`) and stub files (`.pyi`).
///
/// Symbols: functions and classes. Imports: both `import x` and `from x import y` statements.
/// Calls: plain-name calls (`f()`) and attribute calls (`obj.f()`), capturing the final name.
const PYTHON: LanguageSpec = LanguageSpec {
    extensions: &[".py", ".pyi"],
    symbol_query: r#"
    (function_definition name: (identifier) @name) @definition.function
    (class_definition name: (identifier) @name) @definition.class
    "#,
    import_query: r#"
    (import_statement) @import
    (import_from_statement) @import
    "#,
    call_query: r#"
    (call function: (identifier) @name) @call
    (call function: (attribute attribute: (identifier) @name)) @call
    "#,
};
33
/// JavaScript, including JSX and the CommonJS/ES-module extensions.
///
/// Arrow functions bound through a `lexical_declaration` (`const f = () => ...`) are reported
/// as functions. Only ES `import` statements are captured as imports; `require(...)` calls
/// are not matched by `import_query`.
const JAVASCRIPT: LanguageSpec = LanguageSpec {
    extensions: &[".js", ".jsx", ".cjs", ".mjs"],
    symbol_query: r#"
    (function_declaration name: (identifier) @name) @definition.function
    (class_declaration name: (identifier) @name) @definition.class
    (method_definition name: (property_identifier) @name) @definition.method
    (lexical_declaration (variable_declarator name: (identifier) @name value: (arrow_function))) @definition.function
    "#,
    import_query: r#"
    (import_statement) @import
    "#,
    call_query: r#"
    (call_expression function: (identifier) @name) @call
    (call_expression function: (member_expression property: (property_identifier) @name)) @call
    "#,
};
50
51const TYPESCRIPT: LanguageSpec = LanguageSpec {
52 extensions: &[".ts", ".tsx"],
53 symbol_query: r#"
54 (function_declaration name: (identifier) @name) @definition.function
55 (class_declaration name: (type_identifier) @name) @definition.class
56 (method_definition name: (property_identifier) @name) @definition.method
57 (interface_declaration name: (type_identifier) @name) @definition.type
58 (type_alias_declaration name: (type_identifier) @name) @definition.type
59 (enum_declaration name: (identifier) @name) @definition.type
60 (lexical_declaration (variable_declarator name: (identifier) @name value: (arrow_function))) @definition.function
61 "#,
62 import_query: r#"
63 (import_statement) @import
64 "#,
65 call_query: r#"
66 (call_expression function: (identifier) @name) @call
67 (call_expression function: (member_expression property: (property_identifier) @name)) @call
68 "#,
69};
70
71const GO: LanguageSpec = LanguageSpec {
72 extensions: &[".go"],
73 symbol_query: r#"
74 (function_declaration name: (identifier) @name) @definition.function
75 (method_declaration name: (field_identifier) @name) @definition.method
76 (type_declaration (type_spec name: (type_identifier) @name)) @definition.type
77 "#,
78 import_query: r#"
79 (import_declaration) @import
80 "#,
81 call_query: r#"
82 (call_expression function: (identifier) @name) @call
83 (call_expression function: (selector_expression field: (field_identifier) @name)) @call
84 "#,
85};
86
87const RUST: LanguageSpec = LanguageSpec {
88 extensions: &[".rs"],
89 symbol_query: r#"
90 (function_item name: (identifier) @name) @definition.function
91 (struct_item name: (type_identifier) @name) @definition.class
92 (enum_item name: (type_identifier) @name) @definition.type
93 (trait_item name: (type_identifier) @name) @definition.type
94 (impl_item type: (type_identifier) @name) @definition.class
95 (type_item name: (type_identifier) @name) @definition.type
96 "#,
97 import_query: r#"
98 (use_declaration) @import
99 "#,
100 call_query: r#"
101 (call_expression function: (identifier) @name) @call
102 (call_expression function: (scoped_identifier name: (identifier) @name)) @call
103 (call_expression function: (field_expression field: (field_identifier) @name)) @call
104 "#,
105};
106
107const JAVA: LanguageSpec = LanguageSpec {
108 extensions: &[".java"],
109 symbol_query: r#"
110 (method_declaration name: (identifier) @name) @definition.method
111 (class_declaration name: (identifier) @name) @definition.class
112 (interface_declaration name: (identifier) @name) @definition.type
113 (enum_declaration name: (identifier) @name) @definition.type
114 (constructor_declaration name: (identifier) @name) @definition.method
115 "#,
116 import_query: r#"
117 (import_declaration) @import
118 "#,
119 call_query: r#"
120 (method_invocation name: (identifier) @name) @call
121 (object_creation_expression type: (type_identifier) @name) @call
122 "#,
123};
124
/// PHP.
///
/// Imports are `use` statements (`namespace_use_declaration`). Calls cover free functions
/// (plain or namespace-qualified), static calls (`Foo::bar()`), instance calls (`$x->bar()`)
/// and `new Foo()` object creation.
const PHP: LanguageSpec = LanguageSpec {
    extensions: &[".php"],
    symbol_query: r#"
    (function_definition name: (name) @name) @definition.function
    (class_declaration name: (name) @name) @definition.class
    (method_declaration name: (name) @name) @definition.method
    (interface_declaration name: (name) @name) @definition.type
    (trait_declaration name: (name) @name) @definition.type
    "#,
    import_query: r#"
    (namespace_use_declaration) @import
    "#,
    call_query: r#"
    (function_call_expression function: (name) @name) @call
    (function_call_expression function: (qualified_name) @name) @call
    (scoped_call_expression scope: [(name) (qualified_name)] name: (name) @name) @call
    (member_call_expression name: (name) @name) @call
    (object_creation_expression [(name) (qualified_name)] @name) @call
    "#,
};
145
/// Dart.
///
/// `call_query` is intentionally empty: calls are not extracted for Dart. Because
/// `import_query` is non-empty, [`is_data_language`] still treats Dart as code.
const DART: LanguageSpec = LanguageSpec {
    extensions: &[".dart"],
    symbol_query: r#"
    (function_declaration signature: (function_signature name: (identifier) @name)) @definition.function
    (external_function_declaration signature: (function_signature name: (identifier) @name)) @definition.function
    (class_declaration name: (identifier) @name) @definition.class
    (method_declaration signature: (method_signature (function_signature name: (identifier) @name))) @definition.method
    (enum_declaration name: (identifier) @name) @definition.type
    "#,
    import_query: r#"
    (import_or_export) @import
    "#,
    call_query: "",
};
162
/// C#.
///
/// Imports are `using` directives. Calls cover plain invocations (`Foo()`) and member
/// invocations (`x.Foo()`), capturing the invoked member's name.
const CSHARP: LanguageSpec = LanguageSpec {
    extensions: &[".cs"],
    symbol_query: r#"
    (method_declaration name: (identifier) @name) @definition.method
    (class_declaration name: (identifier) @name) @definition.class
    (interface_declaration name: (identifier) @name) @definition.type
    (struct_declaration name: (identifier) @name) @definition.type
    (enum_declaration name: (identifier) @name) @definition.type
    (constructor_declaration name: (identifier) @name) @definition.method
    "#,
    import_query: r#"
    (using_directive) @import
    "#,
    call_query: r#"
    (invocation_expression function: (identifier) @name) @call
    (invocation_expression function: (member_access_expression name: (identifier) @name)) @call
    "#,
};
181
182const C_LANG: LanguageSpec = LanguageSpec {
183 extensions: &[".c", ".h"],
184 symbol_query: r#"
185 (function_definition declarator: (function_declarator declarator: (identifier) @name)) @definition.function
186 (struct_specifier name: (type_identifier) @name) @definition.type
187 (enum_specifier name: (type_identifier) @name) @definition.type
188 (type_definition declarator: (type_identifier) @name) @definition.type
189 "#,
190 import_query: r#"
191 (preproc_include) @import
192 "#,
193 call_query: r#"
194 (call_expression function: (identifier) @name) @call
195 "#,
196};
197
198const CPP: LanguageSpec = LanguageSpec {
199 extensions: &[".cpp", ".cc", ".cxx", ".hpp", ".hxx", ".hh"],
200 symbol_query: r#"
201 (function_definition declarator: (function_declarator declarator: (identifier) @name)) @definition.function
202 (function_definition declarator: (function_declarator declarator: (qualified_identifier name: (identifier) @name))) @definition.function
203 (class_specifier name: (type_identifier) @name) @definition.class
204 (struct_specifier name: (type_identifier) @name) @definition.type
205 "#,
206 import_query: r#"
207 (preproc_include) @import
208 "#,
209 call_query: r#"
210 (call_expression function: (identifier) @name) @call
211 (call_expression function: (field_expression field: (field_identifier) @name)) @call
212 "#,
213};
214
215const OBJC: LanguageSpec = LanguageSpec {
216 extensions: &[".m", ".mm"],
219 symbol_query: r#"
220 (class_interface "@interface" . (identifier) @name) @definition.class
221 (class_implementation "@implementation" . (identifier) @name) @definition.class
222 (protocol_declaration "@protocol" . (identifier) @name) @definition.type
223 (method_declaration (identifier) @name) @definition.method
224 (method_declaration (method_identifier (identifier) @name)) @definition.method
225 (method_definition (identifier) @name) @definition.method
226 (method_definition (method_identifier (identifier) @name)) @definition.method
227 (function_definition declarator: (function_declarator declarator: (identifier) @name)) @definition.function
228 (declaration declarator: (function_declarator declarator: (identifier) @name)) @definition.function
229 (struct_specifier name: (type_identifier) @name) @definition.type
230 (enum_specifier name: (type_identifier) @name) @definition.type
231 (type_definition declarator: (type_identifier) @name) @definition.type
232 "#,
233 import_query: r#"
234 (preproc_include) @import
235 "#,
236 call_query: r#"
237 (call_expression function: (identifier) @name) @call
238 (message_expression receiver: (_) @receiver method: (identifier) @name) @call
239 "#,
240};
241
/// Elixir.
///
/// Definitions are ordinary `call` nodes, so they are filtered by the called keyword. The
/// underscore-prefixed `@_keyword` capture exists only for the `#any-of?` predicate.
/// `def`/`defp`/`defmacro` heads may be a bare name, a call, or a `when`-guarded call; all
/// three shapes yield the function name. `defmodule` yields its alias as a class.
/// `import`/`alias`/`use`/`require` calls are imports.
const ELIXIR: LanguageSpec = LanguageSpec {
    extensions: &[".ex", ".exs"],
    symbol_query: r#"
    (call
      target: (identifier) @_keyword
      (#any-of? @_keyword "def" "defp" "defmacro")
      (arguments [
        (identifier) @name
        (call target: (identifier) @name)
        (binary_operator
          left: (call target: (identifier) @name)
          operator: "when")
      ])) @definition.function
    (call target: (identifier) @_keyword (#any-of? @_keyword "defmodule") (arguments (alias) @name)) @definition.class
    "#,
    import_query: r#"
    (call target: (identifier) @_keyword (#any-of? @_keyword "import" "alias" "use" "require")) @import
    "#,
    call_query: r#"
    (call target: (identifier) @name) @call
    (call target: (dot right: (identifier) @name)) @call
    "#,
};
265
/// Ruby sources, Rakefiles-as-`.rake`, and gemspecs.
///
/// Modules are reported as classes. Imports are method calls named
/// `require`/`require_relative`/`load`, plus the mixin calls `include`/`extend`/`prepend`.
/// `@_m` is only a predicate helper.
const RUBY: LanguageSpec = LanguageSpec {
    extensions: &[".rb", ".rake", ".gemspec"],
    symbol_query: r#"
    (method name: (identifier) @name) @definition.function
    (singleton_method name: (identifier) @name) @definition.function
    (class name: (constant) @name) @definition.class
    (module name: (constant) @name) @definition.class
    "#,
    import_query: r#"
    (call method: (identifier) @_m (#any-of? @_m "require" "require_relative" "load" "include" "extend" "prepend")) @import
    "#,
    call_query: r#"
    (call method: (identifier) @name) @call
    "#,
};
281
/// Kotlin sources and scripts, written against the `tree_sitter_kotlin_ng` grammar used by
/// [`get_ts_language`].
///
/// `object` declarations are reported as classes. For navigation calls (`a.b()`) the second
/// identifier, i.e. the member name, is captured.
const KOTLIN: LanguageSpec = LanguageSpec {
    extensions: &[".kt", ".kts"],
    symbol_query: r#"
    (function_declaration name: (identifier) @name) @definition.function
    (class_declaration name: (identifier) @name) @definition.class
    (object_declaration name: (identifier) @name) @definition.class
    "#,
    import_query: r#"
    (import) @import
    "#,
    call_query: r#"
    (call_expression (identifier) @name) @call
    (call_expression (navigation_expression (identifier) (identifier) @name)) @call
    "#,
};
297
/// Scala sources and scripts.
///
/// Names may be ordinary or operator identifiers (e.g. `def +`). The call patterns are:
/// - plain calls and member calls;
/// - calls with explicit type arguments (`generic_function`);
/// - `new T` / `new T[A]` instance creation, attributed to the type name.
const SCALA: LanguageSpec = LanguageSpec {
    extensions: &[".scala", ".sc"],
    symbol_query: r#"
    (class_definition name: [(identifier) (operator_identifier)] @name) @definition.class
    (object_definition name: [(identifier) (operator_identifier)] @name) @definition.class
    (trait_definition name: [(identifier) (operator_identifier)] @name) @definition.type
    (function_definition name: [(identifier) (operator_identifier)] @name) @definition.function
    "#,
    import_query: r#"
    (import_declaration) @import
    "#,
    call_query: r#"
    (call_expression function: (identifier) @name) @call
    (call_expression function: (field_expression field: (identifier) @name)) @call
    (call_expression function: (generic_function function: (identifier) @name)) @call
    (call_expression function: (generic_function function: (field_expression field: (identifier) @name))) @call
    (instance_expression (type_identifier) @name) @call
    (instance_expression (generic_type type: (type_identifier) @name)) @call
"#,
};
318
/// Lua.
///
/// Lua has no dedicated definition syntax beyond `function`, so symbols also include:
/// - assignments whose first target and first value are a name and a function literal;
/// - table-constructor fields holding function literals.
///
/// `function M:method()` is reported as a method. Imports are assignments whose value is a
/// `require(...)` call, either directly or indexed (`require("x").y`). `@_require` is only a
/// predicate helper.
const LUA: LanguageSpec = LanguageSpec {
    extensions: &[".lua"],
    symbol_query: r#"
(function_declaration
  name: [
    (identifier) @name
    (dot_index_expression field: (identifier) @name)
  ]) @definition.function
(function_declaration
  name: (method_index_expression method: (identifier) @name)) @definition.method
(assignment_statement
  (variable_list
    .
    name: [
      (identifier) @name
      (dot_index_expression field: (identifier) @name)
    ])
  (expression_list
    .
    value: (function_definition))) @definition.function
(table_constructor
  (field
    name: (identifier) @name
    value: (function_definition))) @definition.function
"#,
    import_query: r#"
(assignment_statement
  (expression_list
    (function_call name: (identifier) @_require (#eq? @_require "require")))) @import
(assignment_statement
  (expression_list
    (dot_index_expression
      table: (function_call name: (identifier) @_require (#eq? @_require "require"))))) @import
"#,
    call_query: r#"
(function_call
  name: [
    (identifier) @name
    (dot_index_expression field: (identifier) @name)
    (method_index_expression method: (identifier) @name)
  ]) @call
"#,
};
362
/// YAML, treated as a data language: block-mapping keys become `@definition.property`
/// symbols. Import and call queries are empty, which is what makes [`is_data_language`]
/// return `true` for it.
const YAML: LanguageSpec = LanguageSpec {
    extensions: &[".yaml", ".yml"],
    symbol_query: r#"
    (block_mapping_pair key: (_) @name) @definition.property
    "#,
    import_query: "",
    call_query: "",
};
371
/// JSON (and JSON-with-comments), treated as a data language. The text inside each
/// string-valued object key becomes a `@definition.property` symbol. Import and call
/// queries are empty, so [`is_data_language`] returns `true` for it.
const JSON_LANG: LanguageSpec = LanguageSpec {
    extensions: &[".json", ".jsonc"],
    symbol_query: r#"
    (pair key: (string (string_content) @name)) @definition.property
    "#,
    import_query: "",
    call_query: "",
};
380
/// Swift.
///
/// The grammar uses one `class_declaration` node for classes, actors, structs and enums. The
/// `declaration_kind` keyword decides the reported kind: classes and actors are classes,
/// structs and enums are types. For member calls the navigation suffix name is captured.
const SWIFT: LanguageSpec = LanguageSpec {
    extensions: &[".swift"],
    symbol_query: r#"
    (function_declaration name: (simple_identifier) @name) @definition.function
    (class_declaration declaration_kind: "class" name: (type_identifier) @name) @definition.class
    (class_declaration declaration_kind: "actor" name: (type_identifier) @name) @definition.class
    (protocol_declaration name: (type_identifier) @name) @definition.type
    (class_declaration declaration_kind: "struct" name: (type_identifier) @name) @definition.type
    (class_declaration declaration_kind: "enum" name: (type_identifier) @name) @definition.type
    "#,
    import_query: r#"
    (import_declaration) @import
    "#,
    call_query: r#"
    (call_expression (simple_identifier) @name) @call
    (call_expression (navigation_expression suffix: (navigation_suffix suffix: (simple_identifier) @name))) @call
    "#,
};
399
/// Bash / POSIX shell scripts.
///
/// `source file` and `. file` commands are imports; `@_cmd` is only a predicate helper.
/// Every other command invocation is reported as a call to the command name.
const BASH: LanguageSpec = LanguageSpec {
    extensions: &[".sh", ".bash"],
    symbol_query: r#"
    (function_definition name: (word) @name) @definition.function
    "#,
    import_query: r#"
    (command
      name: (command_name) @_cmd
      (#any-of? @_cmd "source" ".")) @import
    "#,
    call_query: r#"
    (command name: (command_name) @name) @call
    "#,
};
414
/// Registry of every supported language: `(language name, spec)`.
///
/// The name is the key used by [`get_spec`], [`get_ts_language`] and [`is_data_language`].
/// [`detect_language`] scans this list in order and returns the first spec whose extensions
/// contain the file's extension, so an extension should appear in only one spec.
const SPECS: &[(&str, &LanguageSpec)] = &[
    ("python", &PYTHON),
    ("javascript", &JAVASCRIPT),
    ("typescript", &TYPESCRIPT),
    ("go", &GO),
    ("rust", &RUST),
    ("java", &JAVA),
    ("php", &PHP),
    ("dart", &DART),
    ("csharp", &CSHARP),
    ("objc", &OBJC),
    ("c", &C_LANG),
    ("cpp", &CPP),
    ("elixir", &ELIXIR),
    ("ruby", &RUBY),
    ("kotlin", &KOTLIN),
    ("scala", &SCALA),
    ("lua", &LUA),
    ("swift", &SWIFT),
    ("bash", &BASH),
    ("yaml", &YAML),
    ("json", &JSON_LANG),
];
441
442pub fn detect_language(file_path: &str) -> Option<&'static str> {
444 let path = Path::new(file_path);
445 let ext = path
446 .extension()
447 .map(|e| format!(".{}", e.to_string_lossy().to_lowercase()))?;
448
449 if ext == ".h" {
450 return Some(detect_header_language(path));
451 }
452
453 for (name, spec) in SPECS {
454 if spec.extensions.contains(&ext.as_str()) {
455 return Some(name);
456 }
457 }
458 None
459}
460
461fn detect_header_language(path: &Path) -> &'static str {
462 if objc_header_has_sibling_implementation(path) {
463 return "objc";
464 }
465
466 let Some(source) = std::fs::read(path)
467 .ok()
468 .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())
469 else {
470 return "c";
471 };
472
473 if source_contains_objc_header_signal(&source) {
474 "objc"
475 } else if source_contains_cpp_header_signal(&source) {
476 "cpp"
477 } else {
478 "c"
479 }
480}
481
/// Returns `true` when a regular file with the header's stem and a `.m` or `.mm` extension
/// exists next to it (e.g. `Widget.h` beside `Widget.m`).
fn objc_header_has_sibling_implementation(path: &Path) -> bool {
    ["m", "mm"]
        .iter()
        .any(|impl_ext| path.with_extension(impl_ext).is_file())
}
485
486fn source_contains_objc_header_signal(source: &str) -> bool {
487 source_contains_header_signal(source, |bytes, idx| {
488 objc_directive_at(bytes, idx, b"@interface")
489 || objc_directive_at(bytes, idx, b"@protocol")
490 || objc_directive_at(bytes, idx, b"@import")
491 })
492}
493
494fn source_contains_cpp_header_signal(source: &str) -> bool {
495 source_contains_header_signal(source, |bytes, idx| {
496 c_like_keyword_at(bytes, idx, b"class")
497 || c_like_keyword_at(bytes, idx, b"namespace")
498 || c_like_keyword_at(bytes, idx, b"template")
499 })
500}
501
/// Scans C-family `source` and reports whether `signal_at` returns `true` at any byte
/// position outside `//` comments, `/* */` comments, and string/character literals.
///
/// `signal_at` receives the whole byte slice plus the candidate index, so it can inspect
/// neighbouring bytes (e.g. identifier boundaries) itself.
fn source_contains_header_signal<F>(source: &str, mut signal_at: F) -> bool
where
    F: FnMut(&[u8], usize) -> bool,
{
    let bytes = source.as_bytes();
    let mut idx = 0;
    while idx < bytes.len() {
        match bytes[idx] {
            // Line comment: advance to (not past) the newline.
            b'/' if bytes.get(idx + 1) == Some(&b'/') => {
                idx += 2;
                while idx < bytes.len() && bytes[idx] != b'\n' {
                    idx += 1;
                }
            }
            // Block comment: stop on the closing `*/`, then step past it. An unterminated
            // comment consumes the rest of the input.
            b'/' if bytes.get(idx + 1) == Some(&b'*') => {
                idx += 2;
                while idx + 1 < bytes.len() && !(bytes[idx] == b'*' && bytes[idx + 1] == b'/') {
                    idx += 1;
                }
                idx = (idx + 2).min(bytes.len());
            }
            // Literals may contain keyword-looking text (e.g. "class"), so skip them whole.
            b'"' | b'\'' => idx = skip_quoted(bytes, idx),
            _ => {
                if signal_at(bytes, idx) {
                    return true;
                }
                idx += 1;
            }
        }
    }
    false
}
534
/// Returns the index just past the string or character literal that opens at `start`.
///
/// `bytes[start]` must be the opening quote (`"` or `'`). Backslash escapes are honoured,
/// including backslash-newline line continuations.
///
/// An unescaped newline ends the scan, and the newline's index is returned. C-family
/// literals cannot span lines, so a stray quote only hides the rest of its own line instead
/// of the rest of the file. Examples of stray quotes: a C++14 digit separator (`1'000`) or an
/// apostrophe in `#error don't include this`.
///
/// A literal left unterminated at end of input yields `bytes.len()`.
fn skip_quoted(bytes: &[u8], start: usize) -> usize {
    let quote = bytes[start];
    let mut idx = start + 1;
    while idx < bytes.len() {
        match bytes[idx] {
            b'\\' => idx = (idx + 2).min(bytes.len()),
            // Leave the newline for the caller so scanning resumes on the next line.
            b'\n' => return idx,
            byte if byte == quote => return idx + 1,
            _ => idx += 1,
        }
    }
    bytes.len()
}
549
550fn objc_directive_at(bytes: &[u8], idx: usize, directive: &[u8]) -> bool {
551 literal_at(bytes, idx, directive)
552 && bytes
553 .get(idx + directive.len())
554 .is_none_or(|byte| !is_ascii_identifier_byte(*byte))
555}
556
557fn c_like_keyword_at(bytes: &[u8], idx: usize, keyword: &[u8]) -> bool {
558 literal_at(bytes, idx, keyword)
559 && idx
560 .checked_sub(1)
561 .and_then(|previous| bytes.get(previous))
562 .is_none_or(|byte| !is_ascii_identifier_byte(*byte) && *byte != b'@')
563 && bytes
564 .get(idx + keyword.len())
565 .is_none_or(|byte| !is_ascii_identifier_byte(*byte))
566}
567
/// Returns `true` when `bytes` contains `literal` starting exactly at `idx`. An out-of-range
/// `idx`, or a match that would run past the end, yields `false`.
fn literal_at(bytes: &[u8], idx: usize, literal: &[u8]) -> bool {
    bytes
        .get(idx..)
        .is_some_and(|tail| tail.starts_with(literal))
}
571
/// Returns `true` for bytes that may appear inside a C-family identifier: `[A-Za-z0-9_]`.
fn is_ascii_identifier_byte(byte: u8) -> bool {
    matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_')
}
575
576pub fn get_spec(lang: &str) -> Option<&'static LanguageSpec> {
578 SPECS
579 .iter()
580 .find(|(name, _)| *name == lang)
581 .map(|(_, s)| *s)
582}
583
584pub fn is_data_language(lang: &str) -> bool {
592 get_spec(lang)
593 .map(|spec| spec.import_query.is_empty() && spec.call_query.is_empty())
594 .unwrap_or(false)
595}
596
597pub fn get_ts_language(lang: &str) -> Option<Language> {
599 let lang_fn = match lang {
600 "python" => tree_sitter_python::LANGUAGE,
601 "javascript" => tree_sitter_javascript::LANGUAGE,
602 "typescript" => tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
603 "go" => tree_sitter_go::LANGUAGE,
604 "rust" => tree_sitter_rust::LANGUAGE,
605 "java" => tree_sitter_java::LANGUAGE,
606 "objc" => tree_sitter_objc::LANGUAGE,
607 "c" => tree_sitter_c::LANGUAGE,
608 "cpp" => tree_sitter_cpp::LANGUAGE,
609 "csharp" => tree_sitter_c_sharp::LANGUAGE,
610 "ruby" => tree_sitter_ruby::LANGUAGE,
611 "php" => tree_sitter_php::LANGUAGE_PHP,
612 "swift" => tree_sitter_swift::LANGUAGE,
613 "kotlin" => tree_sitter_kotlin_ng::LANGUAGE,
614 "scala" => tree_sitter_scala::LANGUAGE,
615 "lua" => tree_sitter_lua::LANGUAGE,
616 "dart" => tree_sitter_dart::LANGUAGE,
617 "elixir" => tree_sitter_elixir::LANGUAGE,
618 "bash" => tree_sitter_bash::LANGUAGE,
619 "json" => tree_sitter_json::LANGUAGE,
620 "yaml" => tree_sitter_yaml::LANGUAGE,
621 _ => return None,
622 };
623 Some(lang_fn.into())
624}
625
626pub fn get_ts_language_for_path(lang: &str, file_path: &str) -> Option<Language> {
628 if lang == "typescript"
629 && std::path::Path::new(file_path)
630 .extension()
631 .map(|ext| ext.to_string_lossy().eq_ignore_ascii_case("tsx"))
632 .unwrap_or(false)
633 {
634 return Some(tree_sitter_typescript::LANGUAGE_TSX.into());
635 }
636
637 get_ts_language(lang)
638}
639
#[cfg(test)]
mod tests {
    //! Unit tests for language detection, grammar selection and the spec registry.

    use super::*;

    #[test]
    fn markdown_extensions_are_not_detected() {
        assert_eq!(detect_language("README.md"), None);
        assert_eq!(detect_language("docs/guide.markdown"), None);
    }

    #[test]
    fn javascript_extensions_still_detect() {
        assert_eq!(detect_language("src/app.js"), Some("javascript"));
        assert_eq!(detect_language("src/app.jsx"), Some("javascript"));
        assert_eq!(detect_language("src/app.cjs"), Some("javascript"));
        assert_eq!(detect_language("src/generated.mjs"), Some("javascript"));
    }

    #[test]
    fn typescript_extensions_still_detect() {
        assert_eq!(detect_language("src/app.ts"), Some("typescript"));
        assert_eq!(detect_language("src/app.tsx"), Some("typescript"));
    }

    #[test]
    fn bash_extensions_detect() {
        assert_eq!(detect_language("scripts/deploy.sh"), Some("bash"));
        assert_eq!(detect_language("scripts/env.bash"), Some("bash"));
    }

    #[test]
    fn scala_extensions_detect() {
        assert_eq!(detect_language("src/main/scala/App.scala"), Some("scala"));
        assert_eq!(detect_language("scripts/build.sc"), Some("scala"));
    }

    #[test]
    fn lua_extensions_detect() {
        assert_eq!(detect_language("lua/app/init.lua"), Some("lua"));
    }

    #[test]
    fn objc_extensions_detect() {
        assert_eq!(detect_language("Sources/App/Widget.m"), Some("objc"));
        assert_eq!(detect_language("Sources/App/Widget.mm"), Some("objc"));
    }

    // Extensions are lowercased before matching against the registry.
    #[test]
    fn extension_matching_is_case_insensitive() {
        assert_eq!(detect_language("SRC/APP.PY"), Some("python"));
        assert_eq!(detect_language("lib/Main.RS"), Some("rust"));
    }

    // `Path::extension` yields nothing for these, so detection must bail out cleanly.
    #[test]
    fn paths_without_extension_are_not_detected() {
        assert_eq!(detect_language("Makefile"), None);
        assert_eq!(detect_language(".bashrc"), None);
    }

    // The path does not exist, so header sniffing falls back to C.
    #[test]
    fn c_header_detects_without_objc_or_cpp_signal() {
        assert_eq!(detect_language("Sources/App/Widget.h"), Some("c"));
    }

    #[test]
    fn objc_header_detects_declaration_signal() {
        let tempdir = tempfile::TempDir::new().expect("create tempdir");
        let header = tempdir.path().join("Widget.h");
        std::fs::write(
            &header,
            r#"
@interface Widget
- (void)render;
@end
"#,
        )
        .expect("write header");

        assert_eq!(detect_language(&header.to_string_lossy()), Some("objc"));
    }

    #[test]
    fn objc_header_detects_sibling_implementation_signal() {
        let tempdir = tempfile::TempDir::new().expect("create tempdir");
        let header = tempdir.path().join("Widget.h");
        std::fs::write(&header, "void WidgetRender(void);\n").expect("write header");
        std::fs::write(
            tempdir.path().join("Widget.m"),
            "void WidgetRender(void) {}\n",
        )
        .expect("write implementation");

        assert_eq!(detect_language(&header.to_string_lossy()), Some("objc"));
    }

    #[test]
    fn cpp_header_detects_cpp_signal() {
        let tempdir = tempfile::TempDir::new().expect("create tempdir");
        let header = tempdir.path().join("Widget.h");
        std::fs::write(
            &header,
            r#"
namespace app {
template <typename T>
class Widget {};
}
"#,
        )
        .expect("write header");

        assert_eq!(detect_language(&header.to_string_lossy()), Some("cpp"));
    }

    // Keywords and directives that only appear in comments or literals must not count.
    #[test]
    fn header_signals_in_comments_and_literals_are_ignored() {
        let tempdir = tempfile::TempDir::new().expect("create tempdir");
        let header = tempdir.path().join("Widget.h");
        std::fs::write(
            &header,
            r#"
// class Widget;
/* namespace app { template <typename T> } */
static const char *kName = "@interface Widget";
static const char kAt = '@';
void WidgetRender(void);
"#,
        )
        .expect("write header");

        assert_eq!(detect_language(&header.to_string_lossy()), Some("c"));
    }

    #[test]
    fn objcxx_paths_use_objc_grammar() {
        let language = get_ts_language_for_path("objc", "Sources/App/Widget.mm").unwrap();
        assert!(parses_without_error(
            language,
            r#"
@interface Widget
- (void)render;
@end

@implementation Widget
- (void)render { helper(); }
@end
"#,
        ));
    }

    #[test]
    fn tsx_paths_use_tsx_grammar() {
        let language = get_ts_language_for_path("typescript", "src/app.tsx").unwrap();
        assert!(parses_without_error(
            language,
            "export const View = () => <section data-id=\"x\" />;",
        ));
    }

    // JSX syntax is not valid in plain `.ts` files, so the TypeScript grammar must reject it.
    #[test]
    fn ts_paths_keep_typescript_grammar() {
        let language = get_ts_language_for_path("typescript", "src/app.ts").unwrap();
        assert!(parses_with_error(
            language,
            "export const View = () => <section />;"
        ));
    }

    // Guards every spec: an invalid node type or field in any query would otherwise only
    // surface when a file of that language is processed. TypeScript queries run against both
    // the `.ts` and `.tsx` grammars, so both path variants are checked.
    #[test]
    fn every_registered_query_compiles_against_its_grammar() {
        for (name, spec) in SPECS {
            for sample_path in ["sample", "sample.tsx"] {
                let language = get_ts_language_for_path(name, sample_path)
                    .unwrap_or_else(|| panic!("no grammar registered for {name}"));
                for query in [spec.symbol_query, spec.import_query, spec.call_query] {
                    if query.trim().is_empty() {
                        continue;
                    }
                    if let Err(err) = tree_sitter::Query::new(&language, query) {
                        panic!("invalid {name} query ({sample_path}): {err}");
                    }
                }
            }
        }
    }

    #[test]
    fn get_spec_finds_registered_languages_only() {
        let rust = get_spec("rust").expect("rust spec is registered");
        assert!(rust.extensions.contains(&".rs"));
        assert!(get_spec("markdown").is_none());
    }

    fn parses_without_error(language: Language, source: &str) -> bool {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(source, None).unwrap();
        !tree.root_node().has_error()
    }

    fn parses_with_error(language: Language, source: &str) -> bool {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(source, None).unwrap();
        tree.root_node().has_error()
    }

    #[test]
    fn is_data_language_matches_only_json_and_yaml() {
        assert!(is_data_language("json"));
        assert!(is_data_language("yaml"));
        assert!(!is_data_language("rust"));
        assert!(!is_data_language("python"));
        assert!(!is_data_language("dart"));
        assert!(!is_data_language("not_a_language"));
    }
}