Skip to main content

graphify_extract/
lang_config.rs

1//! Per-language configuration for tree-sitter–based extraction.
2//!
3//! Each [`LanguageConfig`] describes the AST node types that correspond to
4//! classes, functions, imports, and calls for a given language. The actual
5//! tree-sitter grammar crates are not yet wired in; this module is
6//! **configuration only**.
7
8use std::collections::HashSet;
9
/// Metadata describing how to extract entities from a language's AST.
///
/// Every node-type and field name stored here is an identifier from the
/// tree-sitter grammar named by [`ts_crate`](Self::ts_crate). A config is plain
/// data: it can be cloned, debug-printed, and compared for equality (useful for
/// checking that a config derived from another one still agrees with its base
/// where expected).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LanguageConfig {
    /// Human-readable language name; the constructors in this module use the
    /// same string as their lookup key.
    pub name: &'static str,
    /// Crate that provides the tree-sitter grammar (e.g. `"tree_sitter_python"`).
    pub ts_crate: &'static str,
    /// AST node types that represent class-like definitions.
    pub class_types: HashSet<&'static str>,
    /// AST node types that represent function/method definitions.
    pub function_types: HashSet<&'static str>,
    /// AST node types that represent import statements.
    pub import_types: HashSet<&'static str>,
    /// AST node types that represent function calls.
    pub call_types: HashSet<&'static str>,
    /// Field name for the entity name within the AST node (usually `"name"`;
    /// the C and C++ configs use `"declarator"`).
    pub name_field: &'static str,
    /// Field name for the body of a class/function.
    pub body_field: &'static str,
    /// Field name for the function being called in a call expression.
    pub call_function_field: &'static str,
    /// AST node types that delimit function boundaries (for scope analysis).
    pub function_boundary_types: HashSet<&'static str>,
    /// Whether function definitions may carry a parenthesised prefix before the
    /// name, as with a Go method receiver (`func (r *Recv) Name()`).
    pub function_label_parens: bool,
}
36
37// ---------------------------------------------------------------------------
38// Helper
39// ---------------------------------------------------------------------------
40
/// Build a set of node-type names from a slice of string literals.
///
/// Duplicate entries in `items` collapse into a single element.
fn hs(items: &[&'static str]) -> HashSet<&'static str> {
    let mut set = HashSet::new();
    for &item in items {
        set.insert(item);
    }
    set
}
44
45// ---------------------------------------------------------------------------
46// Language configs (matching the Python extract.py)
47// ---------------------------------------------------------------------------
48
49pub fn python_config() -> LanguageConfig {
50    LanguageConfig {
51        name: "python",
52        ts_crate: "tree_sitter_python",
53        class_types: hs(&["class_definition"]),
54        function_types: hs(&["function_definition"]),
55        import_types: hs(&["import_statement", "import_from_statement"]),
56        call_types: hs(&["call"]),
57        name_field: "name",
58        body_field: "body",
59        call_function_field: "function",
60        function_boundary_types: hs(&["function_definition", "class_definition"]),
61        function_label_parens: false,
62    }
63}
64
65pub fn javascript_config() -> LanguageConfig {
66    LanguageConfig {
67        name: "javascript",
68        ts_crate: "tree_sitter_javascript",
69        class_types: hs(&["class_declaration", "class"]),
70        function_types: hs(&[
71            "function_declaration",
72            "method_definition",
73            "arrow_function",
74            "function",
75        ]),
76        import_types: hs(&["import_statement"]),
77        call_types: hs(&["call_expression"]),
78        name_field: "name",
79        body_field: "body",
80        call_function_field: "function",
81        function_boundary_types: hs(&[
82            "function_declaration",
83            "method_definition",
84            "arrow_function",
85            "class_declaration",
86        ]),
87        function_label_parens: false,
88    }
89}
90
91pub fn typescript_config() -> LanguageConfig {
92    let mut cfg = javascript_config();
93    cfg.name = "typescript";
94    cfg.ts_crate = "tree_sitter_typescript";
95    cfg.class_types.insert("abstract_class_declaration");
96    cfg.function_types.insert("function_signature");
97    cfg.import_types.insert("import_statement");
98    cfg
99}
100
101pub fn go_config() -> LanguageConfig {
102    LanguageConfig {
103        name: "go",
104        ts_crate: "tree_sitter_go",
105        class_types: hs(&["type_declaration", "type_spec"]),
106        function_types: hs(&["function_declaration", "method_declaration"]),
107        import_types: hs(&["import_declaration", "import_spec"]),
108        call_types: hs(&["call_expression"]),
109        name_field: "name",
110        body_field: "body",
111        call_function_field: "function",
112        function_boundary_types: hs(&["function_declaration", "method_declaration"]),
113        function_label_parens: true,
114    }
115}
116
117pub fn rust_config() -> LanguageConfig {
118    LanguageConfig {
119        name: "rust",
120        ts_crate: "tree_sitter_rust",
121        class_types: hs(&["struct_item", "enum_item", "trait_item", "impl_item"]),
122        function_types: hs(&["function_item"]),
123        import_types: hs(&["use_declaration"]),
124        call_types: hs(&["call_expression"]),
125        name_field: "name",
126        body_field: "body",
127        call_function_field: "function",
128        function_boundary_types: hs(&["function_item", "impl_item"]),
129        function_label_parens: false,
130    }
131}
132
133pub fn java_config() -> LanguageConfig {
134    LanguageConfig {
135        name: "java",
136        ts_crate: "tree_sitter_java",
137        class_types: hs(&[
138            "class_declaration",
139            "interface_declaration",
140            "enum_declaration",
141            "annotation_type_declaration",
142        ]),
143        function_types: hs(&["method_declaration", "constructor_declaration"]),
144        import_types: hs(&["import_declaration"]),
145        call_types: hs(&["method_invocation"]),
146        name_field: "name",
147        body_field: "body",
148        call_function_field: "name",
149        function_boundary_types: hs(&[
150            "method_declaration",
151            "constructor_declaration",
152            "class_declaration",
153        ]),
154        function_label_parens: false,
155    }
156}
157
158pub fn c_config() -> LanguageConfig {
159    LanguageConfig {
160        name: "c",
161        ts_crate: "tree_sitter_c",
162        class_types: hs(&["struct_specifier", "enum_specifier", "union_specifier"]),
163        function_types: hs(&["function_definition"]),
164        import_types: hs(&["preproc_include"]),
165        call_types: hs(&["call_expression"]),
166        name_field: "declarator",
167        body_field: "body",
168        call_function_field: "function",
169        function_boundary_types: hs(&["function_definition"]),
170        function_label_parens: false,
171    }
172}
173
174pub fn cpp_config() -> LanguageConfig {
175    let mut cfg = c_config();
176    cfg.name = "cpp";
177    cfg.ts_crate = "tree_sitter_cpp";
178    cfg.class_types.insert("class_specifier");
179    cfg.class_types.insert("namespace_definition");
180    cfg.function_types.insert("function_definition");
181    cfg
182}
183
184pub fn ruby_config() -> LanguageConfig {
185    LanguageConfig {
186        name: "ruby",
187        ts_crate: "tree_sitter_ruby",
188        class_types: hs(&["class", "module"]),
189        function_types: hs(&["method", "singleton_method"]),
190        import_types: hs(&["call"]), // require/include are calls in Ruby
191        call_types: hs(&["call", "command"]),
192        name_field: "name",
193        body_field: "body",
194        call_function_field: "method",
195        function_boundary_types: hs(&["method", "singleton_method", "class"]),
196        function_label_parens: false,
197    }
198}
199
200pub fn csharp_config() -> LanguageConfig {
201    LanguageConfig {
202        name: "csharp",
203        ts_crate: "tree_sitter_c_sharp",
204        class_types: hs(&[
205            "class_declaration",
206            "interface_declaration",
207            "struct_declaration",
208            "enum_declaration",
209        ]),
210        function_types: hs(&["method_declaration", "constructor_declaration"]),
211        import_types: hs(&["using_directive"]),
212        call_types: hs(&["invocation_expression"]),
213        name_field: "name",
214        body_field: "body",
215        call_function_field: "function",
216        function_boundary_types: hs(&[
217            "method_declaration",
218            "constructor_declaration",
219            "class_declaration",
220        ]),
221        function_label_parens: false,
222    }
223}
224
225pub fn kotlin_config() -> LanguageConfig {
226    LanguageConfig {
227        name: "kotlin",
228        ts_crate: "tree_sitter_kotlin",
229        class_types: hs(&[
230            "class_declaration",
231            "object_declaration",
232            "interface_declaration",
233        ]),
234        function_types: hs(&["function_declaration"]),
235        import_types: hs(&["import_header"]),
236        call_types: hs(&["call_expression"]),
237        name_field: "name",
238        body_field: "body",
239        call_function_field: "function",
240        function_boundary_types: hs(&["function_declaration", "class_declaration"]),
241        function_label_parens: false,
242    }
243}
244
245pub fn scala_config() -> LanguageConfig {
246    LanguageConfig {
247        name: "scala",
248        ts_crate: "tree_sitter_scala",
249        class_types: hs(&["class_definition", "object_definition", "trait_definition"]),
250        function_types: hs(&["function_definition", "val_definition"]),
251        import_types: hs(&["import_declaration"]),
252        call_types: hs(&["call_expression"]),
253        name_field: "name",
254        body_field: "body",
255        call_function_field: "function",
256        function_boundary_types: hs(&["function_definition", "class_definition"]),
257        function_label_parens: false,
258    }
259}
260
261pub fn php_config() -> LanguageConfig {
262    LanguageConfig {
263        name: "php",
264        ts_crate: "tree_sitter_php",
265        class_types: hs(&[
266            "class_declaration",
267            "interface_declaration",
268            "trait_declaration",
269        ]),
270        function_types: hs(&["function_definition", "method_declaration"]),
271        import_types: hs(&["namespace_use_declaration"]),
272        call_types: hs(&["function_call_expression", "method_call_expression"]),
273        name_field: "name",
274        body_field: "body",
275        call_function_field: "function",
276        function_boundary_types: hs(&[
277            "function_definition",
278            "method_declaration",
279            "class_declaration",
280        ]),
281        function_label_parens: false,
282    }
283}
284
285pub fn swift_config() -> LanguageConfig {
286    LanguageConfig {
287        name: "swift",
288        ts_crate: "tree_sitter_swift",
289        class_types: hs(&[
290            "class_declaration",
291            "struct_declaration",
292            "protocol_declaration",
293            "enum_declaration",
294        ]),
295        function_types: hs(&["function_declaration", "init_declaration"]),
296        import_types: hs(&["import_declaration"]),
297        call_types: hs(&["call_expression"]),
298        name_field: "name",
299        body_field: "body",
300        call_function_field: "function",
301        function_boundary_types: hs(&["function_declaration", "class_declaration"]),
302        function_label_parens: false,
303    }
304}
305
306pub fn lua_config() -> LanguageConfig {
307    LanguageConfig {
308        name: "lua",
309        ts_crate: "tree_sitter_lua",
310        class_types: hs(&[]),
311        function_types: hs(&[
312            "function_declaration",
313            "local_function_declaration",
314            "function_definition",
315        ]),
316        import_types: hs(&[]), // require() is a call
317        call_types: hs(&["function_call"]),
318        name_field: "name",
319        body_field: "body",
320        call_function_field: "name",
321        function_boundary_types: hs(&["function_declaration", "local_function_declaration"]),
322        function_label_parens: false,
323    }
324}
325
326pub fn zig_config() -> LanguageConfig {
327    LanguageConfig {
328        name: "zig",
329        ts_crate: "tree_sitter_zig",
330        class_types: hs(&["container_declaration"]),
331        function_types: hs(&["fn_proto"]),
332        import_types: hs(&[]),
333        call_types: hs(&["call_expr"]),
334        name_field: "name",
335        body_field: "body",
336        call_function_field: "function",
337        function_boundary_types: hs(&["fn_proto"]),
338        function_label_parens: false,
339    }
340}
341
342pub fn powershell_config() -> LanguageConfig {
343    LanguageConfig {
344        name: "powershell",
345        ts_crate: "tree_sitter_powershell",
346        class_types: hs(&["class_statement"]),
347        function_types: hs(&["function_statement"]),
348        import_types: hs(&["using_statement"]),
349        call_types: hs(&["command_expression"]),
350        name_field: "name",
351        body_field: "body",
352        call_function_field: "name",
353        function_boundary_types: hs(&["function_statement"]),
354        function_label_parens: false,
355    }
356}
357
358pub fn elixir_config() -> LanguageConfig {
359    LanguageConfig {
360        name: "elixir",
361        ts_crate: "tree_sitter_elixir",
362        class_types: hs(&["call"]),    // defmodule is a call in Elixir
363        function_types: hs(&["call"]), // def/defp are calls too
364        import_types: hs(&["call"]),   // import/use/require
365        call_types: hs(&["call"]),
366        name_field: "target",
367        body_field: "body",
368        call_function_field: "target",
369        function_boundary_types: hs(&["call"]),
370        function_label_parens: false,
371    }
372}
373
374pub fn objc_config() -> LanguageConfig {
375    LanguageConfig {
376        name: "objc",
377        ts_crate: "tree_sitter_objc",
378        class_types: hs(&[
379            "class_interface",
380            "class_implementation",
381            "protocol_declaration",
382        ]),
383        function_types: hs(&["method_definition", "function_definition"]),
384        import_types: hs(&["preproc_import", "preproc_include"]),
385        call_types: hs(&["message_expression", "call_expression"]),
386        name_field: "name",
387        body_field: "body",
388        call_function_field: "selector",
389        function_boundary_types: hs(&["method_definition", "function_definition"]),
390        function_label_parens: false,
391    }
392}
393
394pub fn julia_config() -> LanguageConfig {
395    LanguageConfig {
396        name: "julia",
397        ts_crate: "tree_sitter_julia",
398        class_types: hs(&["struct_definition", "abstract_definition"]),
399        function_types: hs(&["function_definition", "short_function_definition"]),
400        import_types: hs(&["import_statement", "using_statement"]),
401        call_types: hs(&["call_expression"]),
402        name_field: "name",
403        body_field: "body",
404        call_function_field: "function",
405        function_boundary_types: hs(&["function_definition"]),
406        function_label_parens: false,
407    }
408}
409
410// ---------------------------------------------------------------------------
411// Lookup
412// ---------------------------------------------------------------------------
413
414/// Return the [`LanguageConfig`] for the given language name.
415pub fn config_for_language(lang: &str) -> Option<LanguageConfig> {
416    match lang {
417        "python" => Some(python_config()),
418        "javascript" => Some(javascript_config()),
419        "typescript" => Some(typescript_config()),
420        "go" => Some(go_config()),
421        "rust" => Some(rust_config()),
422        "java" => Some(java_config()),
423        "c" => Some(c_config()),
424        "cpp" => Some(cpp_config()),
425        "ruby" => Some(ruby_config()),
426        "csharp" => Some(csharp_config()),
427        "kotlin" => Some(kotlin_config()),
428        "scala" => Some(scala_config()),
429        "php" => Some(php_config()),
430        "swift" => Some(swift_config()),
431        "lua" => Some(lua_config()),
432        "zig" => Some(zig_config()),
433        "powershell" => Some(powershell_config()),
434        "elixir" => Some(elixir_config()),
435        "objc" => Some(objc_config()),
436        "julia" => Some(julia_config()),
437        _ => None,
438    }
439}
440
441// ---------------------------------------------------------------------------
442// Tests
443// ---------------------------------------------------------------------------
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Every language `config_for_language` is expected to recognise. The
    /// array type pins the count at compile time.
    const LANGUAGES: [&str; 20] = [
        "python",
        "javascript",
        "typescript",
        "go",
        "rust",
        "java",
        "c",
        "cpp",
        "ruby",
        "csharp",
        "kotlin",
        "scala",
        "php",
        "swift",
        "lua",
        "zig",
        "powershell",
        "elixir",
        "objc",
        "julia",
    ];

    #[test]
    fn all_20_languages_have_configs() {
        for lang in LANGUAGES {
            let cfg = config_for_language(lang)
                .unwrap_or_else(|| panic!("missing config for {lang}"));
            // A constructor reporting a name different from its lookup key is
            // almost certainly a copy-paste slip.
            assert_eq!(cfg.name, lang, "config name does not match lookup key");
            assert!(
                cfg.ts_crate.starts_with("tree_sitter_"),
                "unexpected grammar crate {} for {lang}",
                cfg.ts_crate
            );
            assert!(!cfg.function_types.is_empty(), "no function types for {lang}");
            assert!(!cfg.call_types.is_empty(), "no call types for {lang}");
            assert!(!cfg.name_field.is_empty(), "empty name_field for {lang}");
            assert!(!cfg.body_field.is_empty(), "empty body_field for {lang}");
            assert!(
                !cfg.call_function_field.is_empty(),
                "empty call_function_field for {lang}"
            );
        }
    }

    #[test]
    fn language_keys_are_unique() {
        assert_eq!(hs(&LANGUAGES).len(), LANGUAGES.len());
    }

    #[test]
    fn python_config_has_expected_types() {
        let cfg = python_config();
        assert!(cfg.class_types.contains("class_definition"));
        assert!(cfg.function_types.contains("function_definition"));
        assert!(cfg.import_types.contains("import_statement"));
        assert!(cfg.import_types.contains("import_from_statement"));
    }

    #[test]
    fn typescript_extends_javascript() {
        let ts = typescript_config();
        assert!(ts.class_types.contains("abstract_class_declaration"));
        assert!(ts.function_types.contains("function_signature"));
        // Inherited from JS
        assert!(ts.class_types.contains("class_declaration"));
        assert!(ts.function_types.contains("arrow_function"));
        assert!(ts.import_types.contains("import_statement"));
        assert!(ts.call_types.contains("call_expression"));
    }

    #[test]
    fn cpp_extends_c() {
        let cpp = cpp_config();
        assert!(cpp.class_types.contains("class_specifier"));
        assert!(cpp.class_types.contains("namespace_definition"));
        // Inherited from C
        assert!(cpp.class_types.contains("struct_specifier"));
        assert!(cpp.function_types.contains("function_definition"));
        assert!(cpp.import_types.contains("preproc_include"));
        assert_eq!(cpp.name_field, "declarator");
    }

    #[test]
    fn unknown_language_returns_none() {
        assert!(config_for_language("brainfuck").is_none());
    }
}