Skip to main content

graphify_extract/treesitter/
treesitter_config.rs

1//! Tree-sitter language configurations.
2//!
3//! Each supported language has a `TsConfig` describing which tree-sitter node
4//! kinds correspond to classes, functions, imports, and calls.
5
6use std::collections::HashSet;
7
8use tree_sitter::Language;
9
/// Per-language table of tree-sitter node kinds and field names.
///
/// Each supported language provides one of these to say which node kinds
/// correspond to classes, functions, imports and calls, and which field
/// names are used on those nodes.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so configs can be
/// logged, duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TsConfig {
    /// Node kinds treated as class-like definitions.
    pub class_types: HashSet<&'static str>,
    /// Node kinds treated as function or method definitions.
    pub function_types: HashSet<&'static str>,
    /// Node kinds treated as import statements.
    pub import_types: HashSet<&'static str>,
    /// Node kinds treated as call expressions.
    pub call_types: HashSet<&'static str>,
    /// Field name holding a definition's name.
    /// NOTE(review): presumably looked up with `child_by_field_name` by the
    /// extractor — confirm against the caller.
    pub name_field: &'static str,
    /// Optional separate name field for class-like nodes; only the C and C++
    /// configs set it.
    /// NOTE(review): presumably `name_field` is used when this is `None` —
    /// confirm against the caller.
    pub class_name_field: Option<&'static str>,
    /// Field name holding a definition's body.
    pub body_field: &'static str,
    /// Field name on a call node that holds the callee.
    pub call_function_field: &'static str,
}
22
23/// Resolve a language identifier to its tree-sitter `Language` and `TsConfig`.
24/// Returns `None` for unsupported languages.
25pub fn resolve_language(lang: &str) -> Option<(Language, TsConfig)> {
26    match lang {
27        "python" => Some((tree_sitter_python::LANGUAGE.into(), python_config())),
28        "javascript" => Some((tree_sitter_javascript::LANGUAGE.into(), js_config())),
29        "typescript" => Some((
30            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
31            js_config(),
32        )),
33        "tsx" => Some((tree_sitter_typescript::LANGUAGE_TSX.into(), js_config())),
34        "rust" => Some((tree_sitter_rust::LANGUAGE.into(), rust_config())),
35        "go" => Some((tree_sitter_go::LANGUAGE.into(), go_config())),
36        "java" => Some((tree_sitter_java::LANGUAGE.into(), java_config())),
37        "c" => Some((tree_sitter_c::LANGUAGE.into(), c_config())),
38        "cpp" => Some((tree_sitter_cpp::LANGUAGE.into(), cpp_config())),
39        "ruby" => Some((tree_sitter_ruby::LANGUAGE.into(), ruby_config())),
40        "csharp" => Some((tree_sitter_c_sharp::LANGUAGE.into(), csharp_config())),
41        "dart" => Some((tree_sitter_dart::LANGUAGE.into(), dart_config())),
42        _ => None,
43    }
44}
45
46fn python_config() -> TsConfig {
47    TsConfig {
48        class_types: ["class_definition"].into_iter().collect(),
49        function_types: ["function_definition"].into_iter().collect(),
50        import_types: ["import_statement", "import_from_statement"]
51            .into_iter()
52            .collect(),
53        call_types: ["call"].into_iter().collect(),
54        name_field: "name",
55        class_name_field: None,
56        body_field: "body",
57        call_function_field: "function",
58    }
59}
60
61fn js_config() -> TsConfig {
62    TsConfig {
63        class_types: ["class_declaration", "class"].into_iter().collect(),
64        function_types: [
65            "function_declaration",
66            "method_definition",
67            "arrow_function",
68            "generator_function_declaration",
69            "generator_function",
70            "async_function_declaration",
71        ]
72        .into_iter()
73        .collect(),
74        import_types: ["import_statement"].into_iter().collect(),
75        call_types: ["call_expression"].into_iter().collect(),
76        name_field: "name",
77        class_name_field: None,
78        body_field: "body",
79        call_function_field: "function",
80    }
81}
82
83fn rust_config() -> TsConfig {
84    TsConfig {
85        class_types: ["struct_item", "enum_item", "trait_item", "impl_item"]
86            .into_iter()
87            .collect(),
88        function_types: ["function_item"].into_iter().collect(),
89        import_types: ["use_declaration"].into_iter().collect(),
90        call_types: ["call_expression"].into_iter().collect(),
91        name_field: "name",
92        class_name_field: None,
93        body_field: "body",
94        call_function_field: "function",
95    }
96}
97
98fn go_config() -> TsConfig {
99    TsConfig {
100        class_types: ["type_declaration"].into_iter().collect(),
101        function_types: ["function_declaration", "method_declaration"]
102            .into_iter()
103            .collect(),
104        import_types: ["import_declaration"].into_iter().collect(),
105        call_types: ["call_expression"].into_iter().collect(),
106        name_field: "name",
107        class_name_field: None,
108        body_field: "body",
109        call_function_field: "function",
110    }
111}
112
113fn java_config() -> TsConfig {
114    TsConfig {
115        class_types: [
116            "class_declaration",
117            "interface_declaration",
118            "enum_declaration",
119        ]
120        .into_iter()
121        .collect(),
122        function_types: ["method_declaration", "constructor_declaration"]
123            .into_iter()
124            .collect(),
125        import_types: ["import_declaration"].into_iter().collect(),
126        call_types: ["method_invocation"].into_iter().collect(),
127        name_field: "name",
128        class_name_field: None,
129        body_field: "body",
130        call_function_field: "name",
131    }
132}
133
134fn c_config() -> TsConfig {
135    TsConfig {
136        class_types: ["struct_specifier", "enum_specifier", "type_definition"]
137            .into_iter()
138            .collect(),
139        function_types: ["function_definition"].into_iter().collect(),
140        import_types: ["preproc_include"].into_iter().collect(),
141        call_types: ["call_expression"].into_iter().collect(),
142        name_field: "declarator",
143        class_name_field: Some("name"),
144        body_field: "body",
145        call_function_field: "function",
146    }
147}
148
149fn cpp_config() -> TsConfig {
150    TsConfig {
151        class_types: [
152            "class_specifier",
153            "struct_specifier",
154            "enum_specifier",
155            "namespace_definition",
156        ]
157        .into_iter()
158        .collect(),
159        function_types: ["function_definition"].into_iter().collect(),
160        import_types: ["preproc_include"].into_iter().collect(),
161        call_types: ["call_expression"].into_iter().collect(),
162        name_field: "declarator",
163        class_name_field: Some("name"),
164        body_field: "body",
165        call_function_field: "function",
166    }
167}
168
169fn ruby_config() -> TsConfig {
170    TsConfig {
171        class_types: ["class", "module"].into_iter().collect(),
172        function_types: ["method", "singleton_method"].into_iter().collect(),
173        import_types: ["call"].into_iter().collect(),
174        call_types: ["call"].into_iter().collect(),
175        name_field: "name",
176        class_name_field: None,
177        body_field: "body",
178        call_function_field: "method",
179    }
180}
181
182fn csharp_config() -> TsConfig {
183    TsConfig {
184        class_types: [
185            "class_declaration",
186            "interface_declaration",
187            "struct_declaration",
188            "enum_declaration",
189        ]
190        .into_iter()
191        .collect(),
192        function_types: ["method_declaration", "constructor_declaration"]
193            .into_iter()
194            .collect(),
195        import_types: ["using_directive"].into_iter().collect(),
196        call_types: ["invocation_expression"].into_iter().collect(),
197        name_field: "name",
198        class_name_field: None,
199        body_field: "body",
200        call_function_field: "function",
201    }
202}
203
204fn dart_config() -> TsConfig {
205    TsConfig {
206        class_types: [
207            "class_definition",
208            "enum_declaration",
209            "mixin_declaration",
210            "extension_declaration",
211        ]
212        .into_iter()
213        .collect(),
214        function_types: [
215            "function_signature",
216            "method_signature",
217            "function_body",
218            "function_declaration",
219            "method_definition",
220        ]
221        .into_iter()
222        .collect(),
223        import_types: ["import_or_export", "part_directive", "part_of_directive"]
224            .into_iter()
225            .collect(),
226        call_types: ["method_invocation", "function_expression_invocation"]
227            .into_iter()
228            .collect(),
229        name_field: "name",
230        class_name_field: None,
231        body_field: "body",
232        call_function_field: "function",
233    }
234}