Skip to main content

graphify_extract/
lang_config.rs

//! Per-language configuration for tree-sitter–based extraction.
//!
//! Each [`LanguageConfig`] describes the AST node types that correspond to
//! classes, functions, imports, and calls for a given language. The actual
//! tree-sitter grammar crates are not yet wired in; this module is
//! **configuration only**.
7
8use std::collections::HashSet;
9
/// Metadata describing how to extract entities from a language's AST.
///
/// Every field is either a static string or a set of static strings naming
/// AST node types / field names; nothing here touches a parser.
///
/// `PartialEq`/`Eq` are derived so two configs can be compared by value
/// (set fields compare by content, independent of insertion order).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LanguageConfig {
    /// Human-readable language name.
    pub name: &'static str,
    /// Crate that provides the tree-sitter grammar (e.g. `"tree_sitter_python"`).
    pub ts_crate: &'static str,
    /// AST node types that represent class-like definitions.
    pub class_types: HashSet<&'static str>,
    /// AST node types that represent function/method definitions.
    pub function_types: HashSet<&'static str>,
    /// AST node types that represent import statements.
    pub import_types: HashSet<&'static str>,
    /// AST node types that represent function calls.
    pub call_types: HashSet<&'static str>,
    /// Field name for the entity name within the AST node.
    pub name_field: &'static str,
    /// Field name for the body of a class/function.
    pub body_field: &'static str,
    /// Field name for the function being called in a call expression.
    pub call_function_field: &'static str,
    /// AST node types that delimit function boundaries (for scope analysis).
    pub function_boundary_types: HashSet<&'static str>,
    /// Whether the language uses parenthesised labels (e.g. Go `func (r *Recv) Name()`).
    pub function_label_parens: bool,
}
36
/// Build a set of static string slices from `items`; duplicates collapse
/// into a single entry.
fn hs(items: &[&'static str]) -> HashSet<&'static str> {
    let mut set = HashSet::new();
    for &item in items {
        set.insert(item);
    }
    set
}
40
41// Language configs (matching the Python extract.py)
42
43pub fn python_config() -> LanguageConfig {
44    LanguageConfig {
45        name: "python",
46        ts_crate: "tree_sitter_python",
47        class_types: hs(&["class_definition"]),
48        function_types: hs(&["function_definition"]),
49        import_types: hs(&["import_statement", "import_from_statement"]),
50        call_types: hs(&["call"]),
51        name_field: "name",
52        body_field: "body",
53        call_function_field: "function",
54        function_boundary_types: hs(&["function_definition", "class_definition"]),
55        function_label_parens: false,
56    }
57}
58
59pub fn javascript_config() -> LanguageConfig {
60    LanguageConfig {
61        name: "javascript",
62        ts_crate: "tree_sitter_javascript",
63        class_types: hs(&["class_declaration", "class"]),
64        function_types: hs(&[
65            "function_declaration",
66            "method_definition",
67            "arrow_function",
68            "function",
69        ]),
70        import_types: hs(&["import_statement"]),
71        call_types: hs(&["call_expression"]),
72        name_field: "name",
73        body_field: "body",
74        call_function_field: "function",
75        function_boundary_types: hs(&[
76            "function_declaration",
77            "method_definition",
78            "arrow_function",
79            "class_declaration",
80        ]),
81        function_label_parens: false,
82    }
83}
84
85pub fn typescript_config() -> LanguageConfig {
86    let mut cfg = javascript_config();
87    cfg.name = "typescript";
88    cfg.ts_crate = "tree_sitter_typescript";
89    cfg.class_types.insert("abstract_class_declaration");
90    cfg.function_types.insert("function_signature");
91    cfg.import_types.insert("import_statement");
92    cfg
93}
94
95pub fn go_config() -> LanguageConfig {
96    LanguageConfig {
97        name: "go",
98        ts_crate: "tree_sitter_go",
99        class_types: hs(&["type_declaration", "type_spec"]),
100        function_types: hs(&["function_declaration", "method_declaration"]),
101        import_types: hs(&["import_declaration", "import_spec"]),
102        call_types: hs(&["call_expression"]),
103        name_field: "name",
104        body_field: "body",
105        call_function_field: "function",
106        function_boundary_types: hs(&["function_declaration", "method_declaration"]),
107        function_label_parens: true,
108    }
109}
110
111pub fn rust_config() -> LanguageConfig {
112    LanguageConfig {
113        name: "rust",
114        ts_crate: "tree_sitter_rust",
115        class_types: hs(&["struct_item", "enum_item", "trait_item", "impl_item"]),
116        function_types: hs(&["function_item"]),
117        import_types: hs(&["use_declaration"]),
118        call_types: hs(&["call_expression"]),
119        name_field: "name",
120        body_field: "body",
121        call_function_field: "function",
122        function_boundary_types: hs(&["function_item", "impl_item"]),
123        function_label_parens: false,
124    }
125}
126
127pub fn java_config() -> LanguageConfig {
128    LanguageConfig {
129        name: "java",
130        ts_crate: "tree_sitter_java",
131        class_types: hs(&[
132            "class_declaration",
133            "interface_declaration",
134            "enum_declaration",
135            "annotation_type_declaration",
136        ]),
137        function_types: hs(&["method_declaration", "constructor_declaration"]),
138        import_types: hs(&["import_declaration"]),
139        call_types: hs(&["method_invocation"]),
140        name_field: "name",
141        body_field: "body",
142        call_function_field: "name",
143        function_boundary_types: hs(&[
144            "method_declaration",
145            "constructor_declaration",
146            "class_declaration",
147        ]),
148        function_label_parens: false,
149    }
150}
151
152pub fn c_config() -> LanguageConfig {
153    LanguageConfig {
154        name: "c",
155        ts_crate: "tree_sitter_c",
156        class_types: hs(&["struct_specifier", "enum_specifier", "union_specifier"]),
157        function_types: hs(&["function_definition"]),
158        import_types: hs(&["preproc_include"]),
159        call_types: hs(&["call_expression"]),
160        name_field: "declarator",
161        body_field: "body",
162        call_function_field: "function",
163        function_boundary_types: hs(&["function_definition"]),
164        function_label_parens: false,
165    }
166}
167
168pub fn cpp_config() -> LanguageConfig {
169    let mut cfg = c_config();
170    cfg.name = "cpp";
171    cfg.ts_crate = "tree_sitter_cpp";
172    cfg.class_types.insert("class_specifier");
173    cfg.class_types.insert("namespace_definition");
174    cfg.function_types.insert("function_definition");
175    cfg
176}
177
178pub fn ruby_config() -> LanguageConfig {
179    LanguageConfig {
180        name: "ruby",
181        ts_crate: "tree_sitter_ruby",
182        class_types: hs(&["class", "module"]),
183        function_types: hs(&["method", "singleton_method"]),
184        import_types: hs(&["call"]), // require/include are calls in Ruby
185        call_types: hs(&["call", "command"]),
186        name_field: "name",
187        body_field: "body",
188        call_function_field: "method",
189        function_boundary_types: hs(&["method", "singleton_method", "class"]),
190        function_label_parens: false,
191    }
192}
193
194pub fn csharp_config() -> LanguageConfig {
195    LanguageConfig {
196        name: "csharp",
197        ts_crate: "tree_sitter_c_sharp",
198        class_types: hs(&[
199            "class_declaration",
200            "interface_declaration",
201            "struct_declaration",
202            "enum_declaration",
203        ]),
204        function_types: hs(&["method_declaration", "constructor_declaration"]),
205        import_types: hs(&["using_directive"]),
206        call_types: hs(&["invocation_expression"]),
207        name_field: "name",
208        body_field: "body",
209        call_function_field: "function",
210        function_boundary_types: hs(&[
211            "method_declaration",
212            "constructor_declaration",
213            "class_declaration",
214        ]),
215        function_label_parens: false,
216    }
217}
218
219pub fn kotlin_config() -> LanguageConfig {
220    LanguageConfig {
221        name: "kotlin",
222        ts_crate: "tree_sitter_kotlin",
223        class_types: hs(&[
224            "class_declaration",
225            "object_declaration",
226            "interface_declaration",
227        ]),
228        function_types: hs(&["function_declaration"]),
229        import_types: hs(&["import_header"]),
230        call_types: hs(&["call_expression"]),
231        name_field: "name",
232        body_field: "body",
233        call_function_field: "function",
234        function_boundary_types: hs(&["function_declaration", "class_declaration"]),
235        function_label_parens: false,
236    }
237}
238
239pub fn scala_config() -> LanguageConfig {
240    LanguageConfig {
241        name: "scala",
242        ts_crate: "tree_sitter_scala",
243        class_types: hs(&["class_definition", "object_definition", "trait_definition"]),
244        function_types: hs(&["function_definition", "val_definition"]),
245        import_types: hs(&["import_declaration"]),
246        call_types: hs(&["call_expression"]),
247        name_field: "name",
248        body_field: "body",
249        call_function_field: "function",
250        function_boundary_types: hs(&["function_definition", "class_definition"]),
251        function_label_parens: false,
252    }
253}
254
255pub fn php_config() -> LanguageConfig {
256    LanguageConfig {
257        name: "php",
258        ts_crate: "tree_sitter_php",
259        class_types: hs(&[
260            "class_declaration",
261            "interface_declaration",
262            "trait_declaration",
263        ]),
264        function_types: hs(&["function_definition", "method_declaration"]),
265        import_types: hs(&["namespace_use_declaration"]),
266        call_types: hs(&["function_call_expression", "method_call_expression"]),
267        name_field: "name",
268        body_field: "body",
269        call_function_field: "function",
270        function_boundary_types: hs(&[
271            "function_definition",
272            "method_declaration",
273            "class_declaration",
274        ]),
275        function_label_parens: false,
276    }
277}
278
279pub fn swift_config() -> LanguageConfig {
280    LanguageConfig {
281        name: "swift",
282        ts_crate: "tree_sitter_swift",
283        class_types: hs(&[
284            "class_declaration",
285            "struct_declaration",
286            "protocol_declaration",
287            "enum_declaration",
288        ]),
289        function_types: hs(&["function_declaration", "init_declaration"]),
290        import_types: hs(&["import_declaration"]),
291        call_types: hs(&["call_expression"]),
292        name_field: "name",
293        body_field: "body",
294        call_function_field: "function",
295        function_boundary_types: hs(&["function_declaration", "class_declaration"]),
296        function_label_parens: false,
297    }
298}
299
300pub fn lua_config() -> LanguageConfig {
301    LanguageConfig {
302        name: "lua",
303        ts_crate: "tree_sitter_lua",
304        class_types: hs(&[]),
305        function_types: hs(&[
306            "function_declaration",
307            "local_function_declaration",
308            "function_definition",
309        ]),
310        import_types: hs(&[]), // require() is a call
311        call_types: hs(&["function_call"]),
312        name_field: "name",
313        body_field: "body",
314        call_function_field: "name",
315        function_boundary_types: hs(&["function_declaration", "local_function_declaration"]),
316        function_label_parens: false,
317    }
318}
319
320pub fn zig_config() -> LanguageConfig {
321    LanguageConfig {
322        name: "zig",
323        ts_crate: "tree_sitter_zig",
324        class_types: hs(&["container_declaration"]),
325        function_types: hs(&["fn_proto"]),
326        import_types: hs(&[]),
327        call_types: hs(&["call_expr"]),
328        name_field: "name",
329        body_field: "body",
330        call_function_field: "function",
331        function_boundary_types: hs(&["fn_proto"]),
332        function_label_parens: false,
333    }
334}
335
336pub fn powershell_config() -> LanguageConfig {
337    LanguageConfig {
338        name: "powershell",
339        ts_crate: "tree_sitter_powershell",
340        class_types: hs(&["class_statement"]),
341        function_types: hs(&["function_statement"]),
342        import_types: hs(&["using_statement"]),
343        call_types: hs(&["command_expression"]),
344        name_field: "name",
345        body_field: "body",
346        call_function_field: "name",
347        function_boundary_types: hs(&["function_statement"]),
348        function_label_parens: false,
349    }
350}
351
352pub fn elixir_config() -> LanguageConfig {
353    LanguageConfig {
354        name: "elixir",
355        ts_crate: "tree_sitter_elixir",
356        class_types: hs(&["call"]),    // defmodule is a call in Elixir
357        function_types: hs(&["call"]), // def/defp are calls too
358        import_types: hs(&["call"]),   // import/use/require
359        call_types: hs(&["call"]),
360        name_field: "target",
361        body_field: "body",
362        call_function_field: "target",
363        function_boundary_types: hs(&["call"]),
364        function_label_parens: false,
365    }
366}
367
368pub fn objc_config() -> LanguageConfig {
369    LanguageConfig {
370        name: "objc",
371        ts_crate: "tree_sitter_objc",
372        class_types: hs(&[
373            "class_interface",
374            "class_implementation",
375            "protocol_declaration",
376        ]),
377        function_types: hs(&["method_definition", "function_definition"]),
378        import_types: hs(&["preproc_import", "preproc_include"]),
379        call_types: hs(&["message_expression", "call_expression"]),
380        name_field: "name",
381        body_field: "body",
382        call_function_field: "selector",
383        function_boundary_types: hs(&["method_definition", "function_definition"]),
384        function_label_parens: false,
385    }
386}
387
388pub fn julia_config() -> LanguageConfig {
389    LanguageConfig {
390        name: "julia",
391        ts_crate: "tree_sitter_julia",
392        class_types: hs(&["struct_definition", "abstract_definition"]),
393        function_types: hs(&["function_definition", "short_function_definition"]),
394        import_types: hs(&["import_statement", "using_statement"]),
395        call_types: hs(&["call_expression"]),
396        name_field: "name",
397        body_field: "body",
398        call_function_field: "function",
399        function_boundary_types: hs(&["function_definition"]),
400        function_label_parens: false,
401    }
402}
403
404pub fn dart_config() -> LanguageConfig {
405    LanguageConfig {
406        name: "dart",
407        ts_crate: "tree_sitter_dart",
408        class_types: hs(&[
409            "class_definition",
410            "enum_declaration",
411            "mixin_declaration",
412            "extension_declaration",
413        ]),
414        function_types: hs(&[
415            "function_signature",
416            "method_signature",
417            "function_body",
418            "method_definition",
419        ]),
420        import_types: hs(&["import_or_export"]),
421        call_types: hs(&["method_invocation", "function_expression_invocation"]),
422        name_field: "name",
423        body_field: "body",
424        call_function_field: "function",
425        function_boundary_types: hs(&[
426            "method_definition",
427            "function_signature",
428            "class_definition",
429        ]),
430        function_label_parens: false,
431    }
432}
433
434/// Return the [`LanguageConfig`] for the given language name.
435pub fn config_for_language(lang: &str) -> Option<LanguageConfig> {
436    match lang {
437        "python" => Some(python_config()),
438        "javascript" => Some(javascript_config()),
439        "typescript" => Some(typescript_config()),
440        "go" => Some(go_config()),
441        "rust" => Some(rust_config()),
442        "java" => Some(java_config()),
443        "c" => Some(c_config()),
444        "cpp" => Some(cpp_config()),
445        "ruby" => Some(ruby_config()),
446        "csharp" => Some(csharp_config()),
447        "kotlin" => Some(kotlin_config()),
448        "scala" => Some(scala_config()),
449        "php" => Some(php_config()),
450        "swift" => Some(swift_config()),
451        "lua" => Some(lua_config()),
452        "zig" => Some(zig_config()),
453        "powershell" => Some(powershell_config()),
454        "elixir" => Some(elixir_config()),
455        "objc" => Some(objc_config()),
456        "julia" => Some(julia_config()),
457        "dart" => Some(dart_config()),
458        _ => None,
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Names that `config_for_language` is expected to resolve.
    const LANGUAGES: [&str; 21] = [
        "python",
        "javascript",
        "typescript",
        "go",
        "rust",
        "java",
        "c",
        "cpp",
        "ruby",
        "csharp",
        "kotlin",
        "scala",
        "php",
        "swift",
        "lua",
        "zig",
        "powershell",
        "elixir",
        "objc",
        "julia",
        "dart",
    ];

    #[test]
    fn all_21_languages_have_configs() {
        for lang in LANGUAGES.iter().copied() {
            let found = config_for_language(lang);
            assert!(found.is_some(), "missing config for {lang}");
        }
        assert_eq!(LANGUAGES.len(), 21);
    }

    #[test]
    fn python_config_has_expected_types() {
        let LanguageConfig {
            class_types,
            function_types,
            import_types,
            ..
        } = python_config();
        assert!(class_types.contains("class_definition"));
        assert!(function_types.contains("function_definition"));
        assert!(import_types.contains("import_statement"));
        assert!(import_types.contains("import_from_statement"));
    }

    #[test]
    fn typescript_extends_javascript() {
        let class_types = typescript_config().class_types;
        // One TypeScript-only entry plus one inherited from the JS base.
        assert!(class_types.contains("abstract_class_declaration"));
        assert!(class_types.contains("class_declaration"));
    }

    #[test]
    fn unknown_language_returns_none() {
        let missing = config_for_language("brainfuck");
        assert!(missing.is_none());
    }
}