1use crate::code::node_source;
2use crate::code::File;
3use crate::lang::Language;
4use tree_sitter::Node;
5
/// Tree-sitter query locating Go type declarations.
///
/// Matches a `type_declaration` node that has some child (the `_` wildcard)
/// carrying a `name:` field of kind `type_identifier`. The type name is
/// captured as `@name` and the whole `type_declaration` as
/// `@definition.class`.
const CLASS_QUERY: &str = r#"
(type_declaration
 (_
 name: (type_identifier) @name)) @definition.class
"#;
11
/// Tree-sitter query locating Go functions and methods.
///
/// The alternation (`[...]`) matches either a `method_declaration`, whose
/// name is a `field_identifier`, or a `function_declaration`, whose name is
/// a plain `identifier`. Both branches capture the name as `@name`, the
/// parameter list as `@parameters`, and the whole declaration as
/// `@definition.function`.
const FUNCTION_DECLARATION_QUERY: &str = r#"
[
 (method_declaration
 name: (field_identifier) @name
 parameters: (_) @parameters)
 (function_declaration
 name: (identifier) @name
 parameters: (_) @parameters)
] @definition.function
"#;
22
/// Tree-sitter query locating named fields of Go struct types.
///
/// Each `field_declaration` inside a `struct_type`'s `field_declaration_list`
/// that has a `name:` field contributes its `field_identifier` as `@name`.
/// Note that the `@field` capture is attached to the enclosing `struct_type`
/// node, not to the individual `field_declaration`.
const FIELD_QUERY: &str = r#"
(struct_type
 (field_declaration_list
 (field_declaration
 name: (field_identifier) @name))) @field
"#;
29
30pub struct Go {
31 pub class_query: tree_sitter::Query,
32 pub function_declaration_query: tree_sitter::Query,
33 pub field_query: tree_sitter::Query,
34}
35
36impl Go {
37 pub const BINARY: &'static str = "binary_expression";
38 pub const BLOCK: &'static str = "block";
39 pub const BREAK: &'static str = "break_statement";
40 pub const CALL_EXPRESSION: &'static str = "call_expression";
41 pub const COMMENT: &'static str = "comment";
42 pub const COMMUNICATION_CASE: &'static str = "communication_case";
43 pub const CONTINUE: &'static str = "continue_statement";
44 pub const EXPRESSION_CASE: &'static str = "expression_case";
45 pub const EXPRESSION_SWITCH: &'static str = "expression_switch_statement";
46 pub const FALLTHROUGH: &'static str = "fallthrough_statement";
47 pub const FIELD_DECLARATION: &'static str = "field_declaration";
48 pub const FOR: &'static str = "for_statement";
49 pub const FUNCTION_DECLARATION: &'static str = "function_declaration";
50 pub const GOTO: &'static str = "goto_statement";
51 pub const IDENTIFIER: &'static str = "identifier";
52 pub const IF: &'static str = "if_statement";
53 pub const INTERPRETED_STRING: &'static str = "interpreted_string_literal";
54 pub const LABEL_NAME: &'static str = "label_name";
55 pub const LAMBDA: &'static str = "func_literal";
56 pub const METHOD_DECLARATION: &'static str = "method_declaration";
57 pub const RAW_STRING: &'static str = "raw_string_literal";
58 pub const RETURN: &'static str = "return_statement";
59 pub const SELECT: &'static str = "select_statement";
60 pub const SELECTOR_EXPRESSION: &'static str = "selector_expression";
61 pub const SELF: &'static str = "this";
62 pub const SOURCE_FILE: &'static str = "source_file";
63 pub const TYPE_CASE: &'static str = "type_case";
64 pub const TYPE_SWITCH: &'static str = "type_switch_statement";
65
66 pub const AND: &'static str = "&&";
67 pub const OR: &'static str = "||";
68}
69
70impl Default for Go {
71 fn default() -> Self {
72 let language = tree_sitter_go::language();
73
74 Self {
75 class_query: tree_sitter::Query::new(&language, CLASS_QUERY).unwrap(),
76 field_query: tree_sitter::Query::new(&language, FIELD_QUERY).unwrap(),
77 function_declaration_query: tree_sitter::Query::new(
78 &language,
79 FUNCTION_DECLARATION_QUERY,
80 )
81 .unwrap(),
82 }
83 }
84}
85
86impl Language for Go {
87 fn name(&self) -> &str {
88 "go"
89 }
90
91 fn self_keyword(&self) -> Option<&str> {
92 None
93 }
94
95 fn class_query(&self) -> &tree_sitter::Query {
96 &self.class_query
97 }
98
99 fn function_declaration_query(&self) -> &tree_sitter::Query {
100 &self.function_declaration_query
101 }
102
103 fn field_query(&self) -> &tree_sitter::Query {
104 &self.field_query
105 }
106
107 fn if_nodes(&self) -> Vec<&str> {
108 vec![Self::IF]
109 }
110
111 fn block_nodes(&self) -> Vec<&str> {
112 vec![Self::BLOCK]
113 }
114
115 fn invisible_container_nodes(&self) -> Vec<&str> {
116 vec![Self::SOURCE_FILE]
117 }
118
119 fn switch_nodes(&self) -> Vec<&str> {
120 vec![Self::EXPRESSION_SWITCH, Self::TYPE_SWITCH, Self::SELECT]
121 }
122
123 fn case_nodes(&self) -> Vec<&str> {
124 vec![
125 Self::EXPRESSION_CASE,
126 Self::TYPE_CASE,
127 Self::COMMUNICATION_CASE,
128 ]
129 }
130
131 fn loop_nodes(&self) -> Vec<&str> {
132 vec![Self::FOR]
133 }
134
135 fn jump_nodes(&self) -> Vec<&str> {
136 vec![Self::BREAK, Self::CONTINUE, Self::GOTO, Self::FALLTHROUGH]
137 }
138
139 fn return_nodes(&self) -> Vec<&str> {
140 vec![Self::RETURN]
141 }
142
143 fn binary_nodes(&self) -> Vec<&str> {
144 vec![Self::BINARY]
145 }
146
147 fn boolean_operator_nodes(&self) -> Vec<&str> {
148 vec![Self::AND, Self::OR]
149 }
150
151 fn field_nodes(&self) -> Vec<&str> {
152 vec![Self::FIELD_DECLARATION]
153 }
154
155 fn call_nodes(&self) -> Vec<&str> {
156 vec![Self::CALL_EXPRESSION]
157 }
158
159 fn function_nodes(&self) -> Vec<&str> {
160 vec![Self::FUNCTION_DECLARATION, Self::METHOD_DECLARATION]
161 }
162
163 fn closure_nodes(&self) -> Vec<&str> {
164 vec![Self::LAMBDA]
165 }
166
167 fn comment_nodes(&self) -> Vec<&str> {
168 vec![Self::COMMENT]
169 }
170
171 fn string_nodes(&self) -> Vec<&str> {
172 vec![Self::RAW_STRING, Self::INTERPRETED_STRING]
173 }
174
175 fn is_jump_label(&self, node: &Node) -> bool {
176 node.kind() == Self::LABEL_NAME
177 }
178
179 fn has_labeled_jumps(&self) -> bool {
180 true
181 }
182
183 fn call_identifiers(&self, source_file: &File, node: &Node) -> (Option<String>, String) {
184 let function_node = node.child_by_field_name("function");
185 match function_node.as_ref().map(|n| n.kind()) {
186 Some(Self::IDENTIFIER) => {
187 (None, get_node_source_or_default(function_node, source_file))
188 }
189 Some(Self::SELECTOR_EXPRESSION) => {
190 let (receiver, object) =
191 self.field_identifiers(source_file, &function_node.unwrap());
192
193 (Some(receiver), object)
194 }
195 _ => (Some("<UNKNOWN>".to_string()), "<UNKNOWN>".to_string()),
196 }
197 }
198
199 fn field_identifiers(&self, source_file: &File, node: &Node) -> (String, String) {
200 let object_node = node.child_by_field_name("operand");
201 let property_node = node.child_by_field_name("field");
202
203 match (&object_node, &property_node) {
204 (Some(obj), Some(prop)) if obj.kind() == Self::SELECTOR_EXPRESSION => {
205 let object_source =
206 get_node_source_or_default(obj.child_by_field_name("field"), source_file);
207 let property_source = get_node_source_or_default(Some(*prop), source_file);
208 (object_source, property_source)
209 }
210 (Some(obj), Some(prop)) => (
211 get_node_source_or_default(Some(*obj), source_file),
212 get_node_source_or_default(Some(*prop), source_file),
213 ),
214 _ => ("<UNKNOWN>".to_string(), "<UNKNOWN>".to_string()),
215 }
216 }
217
218 fn tree_sitter_language(&self) -> tree_sitter::Language {
219 tree_sitter_go::language()
220 }
221}
222
223fn get_node_source_or_default(node: Option<Node>, source_file: &File) -> String {
224 node.as_ref()
225 .map(|n| node_source(n, source_file))
226 .unwrap_or("<UNKNOWN>".to_string())
227}
228
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;
    use tree_sitter::Tree;

    /// Returns the first named child of the root's first named child. In
    /// these single-statement snippets, that is the expression of interest.
    fn first_expression(tree: &Tree) -> Node {
        let statement = tree.root_node().named_child(0).unwrap();
        statement.named_child(0).unwrap()
    }

    /// Parses `code` and checks what `call_identifiers` reports for its call.
    fn assert_call(code: &str, receiver: Option<&str>, name: &str) {
        let source_file = File::from_string("go", code);
        let tree = source_file.parse();
        let call = first_expression(&tree);
        let expected = (receiver.map(|r| r.to_string()), name.to_string());

        assert_eq!(Go::default().call_identifiers(&source_file, &call), expected);
    }

    /// Checks what `field_identifiers` reports for an already-located node.
    fn assert_field(source_file: &File, field: &Node, object: &str, property: &str) {
        let expected = (object.to_string(), property.to_string());

        assert_eq!(Go::default().field_identifiers(source_file, field), expected);
    }

    /// No node kind may appear in more than one category.
    #[test]
    fn mutually_exclusive() {
        let lang = Go::default();
        let categories: Vec<Vec<&str>> = vec![
            lang.if_nodes(),
            lang.switch_nodes(),
            lang.case_nodes(),
            lang.loop_nodes(),
            lang.except_nodes(),
            lang.try_expression_nodes(),
            lang.jump_nodes(),
            lang.return_nodes(),
            lang.binary_nodes(),
            lang.field_nodes(),
            lang.call_nodes(),
            lang.function_nodes(),
            lang.closure_nodes(),
            lang.comment_nodes(),
            lang.string_nodes(),
            lang.boolean_operator_nodes(),
            lang.block_nodes(),
        ];

        let all_kinds: Vec<&str> = categories.into_iter().flatten().collect();
        let distinct: HashSet<&str> = all_kinds.iter().copied().collect();
        assert_eq!(distinct.len(), all_kinds.len());
    }

    #[test]
    fn field_identifier_read() {
        let source_file = File::from_string("go", "object.foo\n");
        let tree = source_file.parse();

        assert_field(&source_file, &first_expression(&tree), "object", "foo");
    }

    #[test]
    fn field_identifier_write() {
        let source_file = File::from_string("go", "object.foo = 1");
        let tree = source_file.parse();
        let assignment = first_expression(&tree);
        let target = assignment.named_child(0).unwrap();

        assert_field(&source_file, &target, "object", "foo");
    }

    #[test]
    fn call_identifier() {
        assert_call("foo()", None, "foo");
    }

    #[test]
    fn call_member() {
        assert_call("foo.bar()", Some("foo"), "bar");
    }

    #[test]
    fn call_with_custom_context() {
        assert_call("foo.call(context)\n", Some("foo"), "call");
    }

    #[test]
    fn method_call_on_nested_object() {
        assert_call("obj.nestedObj.foo()\n", Some("nestedObj"), "foo");
    }

    #[test]
    fn nested_field_access() {
        let source_file = File::from_string("go", "obj.nestedObj.oneMoreObj\n");
        let tree = source_file.parse();

        assert_field(
            &source_file,
            &first_expression(&tree),
            "nestedObj",
            "oneMoreObj",
        );
    }

    #[test]
    fn call_function_property() {
        assert_call("foo.bar()\n", Some("foo"), "bar");
    }
}