1use crate::code::File;
2use crate::code::{child_source, node_source};
3use crate::lang::Language;
4use tree_sitter::Node;
5
/// Tree-sitter query for type definitions that are reported as classes:
/// structs, enums and unions.
///
/// Captures: `@name` — the type identifier; `@definition.class` — the whole item.
const CLASS_QUERY: &str = r#"
[
  (struct_item
    name: (type_identifier) @name)

  (enum_item
    name: (type_identifier) @name)

  (union_item
    name: (type_identifier) @name)
] @definition.class
"#;

/// Tree-sitter query for `impl` blocks and the type they implement.
///
/// The implemented type may be written bare (`impl Foo`), with generic
/// arguments (`impl Foo<T>`), as a path (`impl a::Foo`), or as a generic path
/// (`impl a::Foo<T>`). In every case only the final type identifier is
/// captured as `@name`, i.e. the same kind of name `CLASS_QUERY` captures for
/// type definitions. The path forms previously produced no match at all.
///
/// Captures: `@name` — the type identifier; `@reference.implementation` — the
/// whole `impl` item.
const IMPLEMENTATION_QUERY: &str = r#"
[
  (impl_item
    type: (generic_type
      type: (type_identifier) @name))

  (impl_item
    type: (type_identifier) @name)

  (impl_item
    type: (scoped_type_identifier
      name: (type_identifier) @name))

  (impl_item
    type: (generic_type
      type: (scoped_type_identifier
        name: (type_identifier) @name)))
] @reference.implementation
"#;

/// Tree-sitter query for function definitions (`function_item`), both free
/// functions and methods inside `impl` blocks.
///
/// Captures: `@name` — the function name; `@parameters` — the parameter list;
/// `@definition.function` — the whole item.
const FUNCTION_DECLARATION_QUERY: &str = r#"
(function_item
  name: (identifier) @name
  parameters: (_) @parameters) @definition.function
"#;

/// Tree-sitter query for fields: named field declarations in structs and
/// field accesses whose receiver is `self` (e.g. `self.count`).
///
/// Captures: `@name` — the field name; `@receiver` — the `self` receiver (only
/// for accesses); `@field` — the whole declaration or access.
const FIELD_QUERY: &str = r#"
[
  (field_declaration
    name: (field_identifier) @name) @field
  (field_expression
    value: (self) @receiver
    field: (_) @name) @field
]
"#;
45
46pub struct Rust {
47 pub class_query: tree_sitter::Query,
48 pub function_declaration_query: tree_sitter::Query,
49 pub field_query: tree_sitter::Query,
50 pub implementation_query: tree_sitter::Query,
51}
52
53impl Rust {
54 pub const SELF: &'static str = "self";
55
56 pub const IDENTIFIER: &'static str = "identifier";
57 pub const BINARY: &'static str = "binary_expression";
58 pub const BREAK: &'static str = "break_expression";
59 pub const CALL: &'static str = "call_expression";
60 pub const CLOSURE: &'static str = "closure_expression";
61 pub const CONTINUE: &'static str = "continue_expression";
62 pub const ELSE: &'static str = "else_clause";
63 pub const FOR: &'static str = "for_expression";
64 pub const FUNCTION: &'static str = "function_item";
65 pub const IF: &'static str = "if_expression";
66 pub const LOOP: &'static str = "loop_expression";
67 pub const MATCH: &'static str = "match_expression";
68 pub const WHILE: &'static str = "while_expression";
69 pub const SOURCE_FILE: &'static str = "source_file";
70 pub const LINE_COMMENT: &'static str = "line_comment";
71 pub const BLOCK_COMMENT: &'static str = "block_comment";
72 pub const STRING: &'static str = "string_literal";
73 pub const RAW_STRING: &'static str = "raw_string_literal";
74 pub const RETURN: &'static str = "return_expression";
75 pub const FIELD_EXPRESSION: &'static str = "field_expression";
76 pub const SCOPED_EXPRESSION: &'static str = "scoped_identifier";
77 pub const SELF_PARAMETER: &'static str = "self_parameter";
78
79 pub const AND: &'static str = "&&";
80 pub const OR: &'static str = "||";
81}
82
83impl Default for Rust {
84 fn default() -> Self {
85 let language = tree_sitter_rust::language();
86
87 Self {
88 class_query: tree_sitter::Query::new(&language, CLASS_QUERY).unwrap(),
89 function_declaration_query: tree_sitter::Query::new(
90 &language,
91 FUNCTION_DECLARATION_QUERY,
92 )
93 .unwrap(),
94 field_query: tree_sitter::Query::new(&language, FIELD_QUERY).unwrap(),
95 implementation_query: tree_sitter::Query::new(&language, IMPLEMENTATION_QUERY).unwrap(),
96 }
97 }
98}
99
100impl Language for Rust {
101 fn name(&self) -> &str {
102 "rust"
103 }
104
105 fn self_keyword(&self) -> Option<&str> {
106 Some(Self::SELF)
107 }
108
109 fn class_query(&self) -> &tree_sitter::Query {
110 &self.class_query
111 }
112
113 fn function_declaration_query(&self) -> &tree_sitter::Query {
114 &self.function_declaration_query
115 }
116
117 fn field_query(&self) -> &tree_sitter::Query {
118 &self.field_query
119 }
120
121 fn implementation_query(&self) -> Option<&tree_sitter::Query> {
122 Some(&self.implementation_query)
123 }
124
125 fn constructor_names(&self) -> Vec<&str> {
126 vec!["new", "default"]
127 }
128
129 fn destructor_names(&self) -> Vec<&str> {
130 vec!["drop"]
131 }
132
133 fn is_instance_method(&self, _file: &File, node: &Node) -> bool {
134 let parameters = node.child_by_field_name("parameters").unwrap();
135
136 if let Some(first_parameter) = parameters.named_child(0) {
137 first_parameter.kind() == Self::SELF_PARAMETER
138 } else {
139 false
140 }
141 }
142
143 fn if_nodes(&self) -> Vec<&str> {
144 vec![Self::IF]
145 }
146
147 fn else_nodes(&self) -> Vec<&str> {
148 vec![Self::ELSE]
149 }
150
151 fn conditional_assignment_nodes(&self) -> Vec<&str> {
152 vec![]
153 }
154
155 fn invisible_container_nodes(&self) -> Vec<&str> {
156 vec![Self::SOURCE_FILE]
157 }
158
159 fn switch_nodes(&self) -> Vec<&str> {
160 vec![Self::MATCH]
161 }
162
163 fn case_nodes(&self) -> Vec<&str> {
164 vec!["match_arm"]
165 }
166
167 fn ternary_nodes(&self) -> Vec<&str> {
168 vec![]
169 }
170
171 fn loop_nodes(&self) -> Vec<&str> {
172 vec![Self::FOR, Self::WHILE, Self::LOOP]
173 }
174
175 fn except_nodes(&self) -> Vec<&str> {
176 vec![]
177 }
178
179 fn try_expression_nodes(&self) -> Vec<&str> {
180 vec!["try_expression"]
181 }
182
183 fn jump_nodes(&self) -> Vec<&str> {
184 vec![Self::BREAK, Self::CONTINUE]
185 }
186
187 fn return_nodes(&self) -> Vec<&str> {
188 vec![Self::RETURN]
189 }
190
191 fn binary_nodes(&self) -> Vec<&str> {
192 vec![Self::BINARY]
193 }
194
195 fn field_nodes(&self) -> Vec<&str> {
196 vec![Self::FIELD_EXPRESSION]
197 }
198
199 fn call_nodes(&self) -> Vec<&str> {
200 vec![Self::CALL]
201 }
202
203 fn function_nodes(&self) -> Vec<&str> {
204 vec![Self::FUNCTION]
205 }
206
207 fn closure_nodes(&self) -> Vec<&str> {
208 vec![Self::CLOSURE]
209 }
210
211 fn comment_nodes(&self) -> Vec<&str> {
212 vec![Self::LINE_COMMENT, Self::BLOCK_COMMENT]
213 }
214
215 fn string_nodes(&self) -> Vec<&str> {
216 vec![Self::STRING, Self::RAW_STRING]
217 }
218
219 fn boolean_operator_nodes(&self) -> Vec<&str> {
220 vec![Self::AND, Self::OR]
221 }
222
223 fn iterator_method_identifiers(&self) -> Vec<&str> {
224 vec![
225 "filter",
226 "map",
227 "any",
228 "all",
229 "find",
230 "position",
231 "fold",
232 "scan",
233 "for_each",
234 "filter_map",
235 "flat_map",
236 "inspect",
237 "partition",
238 "max_by",
239 "min_by",
240 "take_while",
241 "skip_while",
242 "try_fold",
243 "try_for_each",
244 ]
245 }
246
247 fn call_identifiers(&self, source_file: &File, node: &Node) -> (Option<String>, String) {
248 let function_node = node.child_by_field_name("function").unwrap();
249 let function_kind = function_node.kind();
250
251 match function_kind {
252 Self::IDENTIFIER => (
253 Some("".to_string()),
254 node_source(&function_node, source_file),
255 ),
256 Self::FIELD_EXPRESSION => {
257 let (receiver, object) = self.field_identifiers(source_file, &function_node);
258
259 (Some(receiver), object)
260 }
261 Self::SCOPED_EXPRESSION => {
262 let receiver =
263 if let Some(receiver_node) = function_node.child_by_field_name("path") {
264 node_source(&receiver_node, source_file)
265 } else {
266 Self::SELF.to_string()
267 };
268
269 (
270 Some(receiver),
271 child_source(&function_node, "name", source_file),
272 )
273 }
274 _ => (Some("<UNKNOWN>".to_string()), "<UNKNOWN>".to_string()),
275 }
276 }
277
278 fn field_identifiers(&self, source_file: &File, node: &Node) -> (String, String) {
279 (
280 child_source(node, "value", source_file),
281 child_source(node, "field", source_file),
282 )
283 }
284
285 fn tree_sitter_language(&self) -> tree_sitter::Language {
286 tree_sitter_rust::language()
287 }
288}
289
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;
    use tree_sitter::Tree;

    /// Each node kind may belong to at most one syntactic category.
    #[test]
    fn mutually_exclusive() {
        let lang = Rust::default();
        let categories = [
            lang.if_nodes(),
            lang.else_nodes(),
            lang.conditional_assignment_nodes(),
            lang.switch_nodes(),
            lang.case_nodes(),
            lang.ternary_nodes(),
            lang.loop_nodes(),
            lang.except_nodes(),
            lang.try_expression_nodes(),
            lang.jump_nodes(),
            lang.return_nodes(),
            lang.binary_nodes(),
            lang.field_nodes(),
            lang.call_nodes(),
            lang.function_nodes(),
            lang.closure_nodes(),
            lang.comment_nodes(),
            lang.string_nodes(),
            lang.boolean_operator_nodes(),
        ];

        let kinds: Vec<&str> = categories.iter().flatten().copied().collect();
        let unique: HashSet<&str> = kinds.iter().copied().collect();

        assert_eq!(unique.len(), kinds.len());
    }

    #[test]
    fn field_identifier_read() {
        let source_file = File::from_string("rust", "self.foo;");
        let tree = source_file.parse();
        let root_node = tree.root_node();
        let expression = root_node.named_child(0).unwrap();
        let field = expression.named_child(0).unwrap();
        let language = Rust::default();

        assert_eq!(
            language.field_identifiers(&source_file, &field),
            ("self".to_string(), "foo".to_string())
        );
    }

    #[test]
    fn field_identifier_write() {
        let source_file = File::from_string("rust", "self.foo = 1;");
        let tree = source_file.parse();
        let root_node = tree.root_node();
        let expression = root_node.named_child(0).unwrap();
        let assignment = expression.named_child(0).unwrap();
        let field = assignment.named_child(0).unwrap();
        let language = Rust::default();

        assert_eq!(
            language.field_identifiers(&source_file, &field),
            ("self".to_string(), "foo".to_string())
        );
    }

    #[test]
    fn field_identifier_collaborator() {
        let source_file = File::from_string("rust", "other.foo;");
        let tree = source_file.parse();
        let root_node = tree.root_node();
        let expression = root_node.named_child(0).unwrap();
        let field = expression.named_child(0).unwrap();
        let language = Rust::default();

        assert_eq!(
            language.field_identifiers(&source_file, &field),
            ("other".to_string(), "foo".to_string())
        );
    }

    #[test]
    fn call_identifier() {
        let source_file = File::from_string("rust", "foo();");
        let tree = source_file.parse();
        let call = call_node(&tree);
        let language = Rust::default();

        assert_eq!(
            language.call_identifiers(&source_file, &call),
            (Some("".to_string()), "foo".to_string())
        );
    }

    #[test]
    fn call_self() {
        let source_file = File::from_string("rust", "self.foo();");
        let tree = source_file.parse();
        let call = call_node(&tree);
        let language = Rust::default();

        assert_eq!(
            language.call_identifiers(&source_file, &call),
            (Some("self".to_string()), "foo".to_string())
        );
    }

    #[test]
    fn call_field() {
        let source_file = File::from_string("rust", "foo.bar();");
        let tree = source_file.parse();
        let call = call_node(&tree);
        let language = Rust::default();

        assert_eq!(
            language.call_identifiers(&source_file, &call),
            (Some("foo".to_string()), "bar".to_string())
        );
    }

    #[test]
    fn call_scoped() {
        let source_file = File::from_string("rust", "Foo::bar();");
        let tree = source_file.parse();
        let call = call_node(&tree);
        let language = Rust::default();

        assert_eq!(
            language.call_identifiers(&source_file, &call),
            (Some("Foo".to_string()), "bar".to_string())
        );
    }

    /// A callee that is neither an identifier, a field access nor a path
    /// falls back to the `<UNKNOWN>` marker.
    #[test]
    fn call_unknown_callee() {
        let source_file = File::from_string("rust", "(foo)();");
        let tree = source_file.parse();
        let call = call_node(&tree);
        let language = Rust::default();

        assert_eq!(
            language.call_identifiers(&source_file, &call),
            (Some("<UNKNOWN>".to_string()), "<UNKNOWN>".to_string())
        );
    }

    #[test]
    fn instance_method_borrowed_self() {
        assert!(first_function_is_instance_method("fn foo(&self) {}"));
    }

    #[test]
    fn instance_method_owned_self() {
        assert!(first_function_is_instance_method("fn foo(self, other: i32) {}"));
    }

    #[test]
    fn associated_function_is_not_instance_method() {
        assert!(!first_function_is_instance_method("fn foo(other: i32) {}"));
    }

    #[test]
    fn function_without_parameters_is_not_instance_method() {
        assert!(!first_function_is_instance_method("fn foo() {}"));
    }

    /// Returns the node of the first call in a single-statement source:
    /// `source_file > expression_statement > call_expression`.
    fn call_node(tree: &Tree) -> Node {
        let root_node = tree.root_node();
        let expression = root_node.named_child(0).unwrap();
        expression.named_child(0).unwrap()
    }

    /// Parses `source` and asks `is_instance_method` about its first item,
    /// which the caller guarantees is a function.
    fn first_function_is_instance_method(source: &'static str) -> bool {
        let source_file = File::from_string("rust", source);
        let tree = source_file.parse();
        let function = tree.root_node().named_child(0).unwrap();
        Rust::default().is_instance_method(&source_file, &function)
    }
}