code_graph/
lang.rs

1use tree_sitter::{Language, Node};
2use uuid::Uuid;
3
4use crate::{CodeBlockType, CodeNode};
5
6pub trait SymbolQuery {
7    fn get_call(&self, code: &str, node: &Node) -> Option<CodeNode>;
8    fn get_lang(&self) -> Language;
9    fn get_definition(&self, code: &str, node: &Node) -> Option<CodeNode>;
10}
/// [`SymbolQuery`] implementation for Rust sources (tree-sitter-rust grammar).
#[derive(Debug, Clone, Copy, Default)]
pub struct RustQuery;
/// [`SymbolQuery`] implementation for C sources (tree-sitter-c grammar).
#[derive(Debug, Clone, Copy, Default)]
pub struct CQuery;
/// [`SymbolQuery`] implementation for Java sources (tree-sitter-java grammar).
#[derive(Debug, Clone, Copy, Default)]
pub struct JavaQuery;
/// [`SymbolQuery`] implementation for JavaScript sources (tree-sitter-javascript grammar).
#[derive(Debug, Clone, Copy, Default)]
pub struct JsQuery;
15
16impl SymbolQuery for JsQuery {
17    fn get_call(&self, code: &str, node: &Node) -> Option<CodeNode> {
18        let node_type = node.kind();
19
20        if node_type == "call_expression" {
21            let block_text = &code[node.byte_range()];
22            let fe = node.child_by_field_name("function");
23            if let Some(fe) = fe {
24                let fi = fe.child_by_field_name("property");
25                if let Some(fi) = fi {
26                    let label = &code[fi.byte_range()];
27                    return Some(CodeNode::new(
28                        format!("{}", Uuid::new_v4()).as_str(),
29                        label,
30                        block_text,
31                        fi.start_position().row + 1,
32                        CodeBlockType::CALL,
33                        0,
34                    ));
35                } else {
36                    let label = &code[fe.byte_range()];
37                    return Some(CodeNode::new(
38                        format!("{}", Uuid::new_v4()).as_str(),
39                        label,
40                        block_text,
41                        fe.start_position().row + 1,
42                        CodeBlockType::CALL,
43                        0,
44                    ));
45                }
46            }
47        }
48        None
49    }
50
51    fn get_lang(&self) -> Language {
52        tree_sitter_javascript::language()
53    }
54
55    fn get_definition(&self, code: &str, node: &Node) -> Option<CodeNode> {
56        let node_type = node.kind();
57        let definition_list = [
58            ("function_declaration", "formal_parameters"),
59            ("class_declaration", "class_body"),
60            ("method_definition", "formal_parameters"),
61        ];
62        for (root_type, end_type) in definition_list {
63            if node_type == root_type {
64                let mut output = String::new();
65                for child in node.children(&mut node.walk()) {
66                    if child.kind() == end_type {
67                        break;
68                    } else {
69                        let node_text = &code[child.byte_range()];
70                        output.push_str(node_text);
71                        output.push(' ');
72                    }
73                }
74                let block_type = match root_type {
75                    "function_declaration" => CodeBlockType::FUNCTION,
76                    "method_definition" => CodeBlockType::FUNCTION,
77                    "class_declaration" => CodeBlockType::CLASS,
78                    _ => CodeBlockType::NORMAL,
79                };
80                let block_text = &code[node.byte_range()];
81                return Some(CodeNode::new(
82                    format!("{}", Uuid::new_v4()).as_str(),
83                    output.as_str(),
84                    block_text,
85                    node.start_position().row + 1,
86                    block_type,
87                    0,
88                ));
89            }
90        }
91        if node_type == "lexical_declaration" {
92            if node.parent().is_some() && node.parent().unwrap().grammar_name() == "program" {
93                let mut output = String::new();
94                let kind_node = node.child_by_field_name("kind");
95                if let Some(kind_node) = kind_node {
96                    output.push_str(&code[kind_node.byte_range()]);
97                }
98                for child in node.children(&mut node.walk()) {
99                    if "variable_declarator" == child.kind() {
100                        let name = child.child_by_field_name("name");
101                        if let Some(name) = name {
102                            output.push_str(" ");
103                            output.push_str(&code[name.byte_range()]);
104                        }
105                    }
106                }
107                let block_type = CodeBlockType::CONST;
108                let block_text = &code[node.byte_range()];
109                return Some(CodeNode::new(
110                    format!("{}", Uuid::new_v4()).as_str(),
111                    output.as_str(),
112                    block_text,
113                    node.start_position().row + 1,
114                    block_type,
115                    0,
116                ));
117            }
118        }
119        None
120    }
121}
122
123impl SymbolQuery for CQuery {
124    fn get_call(&self, code: &str, node: &Node) -> Option<CodeNode> {
125        let node_type = node.kind();
126
127        if node_type == "call_expression" {
128            let block_text = &code[node.byte_range()];
129            let fe = node.child_by_field_name("function");
130            if let Some(fe) = fe {
131                let fi = fe.child_by_field_name("field");
132                if let Some(fi) = fi {
133                    let label = &code[fi.byte_range()];
134                    return Some(CodeNode::new(
135                        format!("{}", Uuid::new_v4()).as_str(),
136                        label,
137                        block_text,
138                        fi.start_position().row + 1,
139                        CodeBlockType::CALL,
140                        0,
141                    ));
142                } else {
143                    let label = &code[fe.byte_range()];
144                    return Some(CodeNode::new(
145                        format!("{}", Uuid::new_v4()).as_str(),
146                        label,
147                        block_text,
148                        fe.start_position().row + 1,
149                        CodeBlockType::CALL,
150                        0,
151                    ));
152                }
153            }
154        }
155        None
156    }
157
158    fn get_lang(&self) -> Language {
159        tree_sitter_c::language()
160    }
161
162    fn get_definition(&self, code: &str, node: &Node) -> Option<CodeNode> {
163        let node_type = node.kind();
164        let definition_list = [("function_definition", "compound_statement")];
165        for (root_type, end_type) in definition_list {
166            if node_type == root_type {
167                let mut output = String::new();
168                for child in node.children(&mut node.walk()) {
169                    if child.kind() == end_type {
170                        break;
171                    } else {
172                        let node_text = &code[child.byte_range()];
173                        output.push_str(node_text);
174                        output.push(' ');
175                    }
176                }
177                let block_type = match root_type {
178                    "function_definition" => CodeBlockType::FUNCTION,
179                    "struct_item" => CodeBlockType::STRUCT,
180                    _ => CodeBlockType::NORMAL,
181                };
182                let block_text = &code[node.byte_range()];
183                return Some(CodeNode::new(
184                    format!("{}", Uuid::new_v4()).as_str(),
185                    output.as_str().split("(").next().unwrap_or("bad symbol"),
186                    block_text,
187                    node.start_position().row + 1,
188                    block_type,
189                    0,
190                ));
191            }
192        }
193
194        None
195    }
196}
197
198impl SymbolQuery for JavaQuery {
199    fn get_call(&self, code: &str, node: &Node) -> Option<CodeNode> {
200        let node_type = node.kind();
201
202        if node_type == "method_invocation" {
203            let block_text = &code[node.byte_range()];
204            let fe = node.child_by_field_name("name");
205            if let Some(fe) = fe {
206                let label = &code[fe.byte_range()];
207                return Some(CodeNode::new(
208                    format!("{}", Uuid::new_v4()).as_str(),
209                    label,
210                    block_text,
211                    fe.start_position().row + 1,
212                    CodeBlockType::CALL,
213                    0,
214                ));
215            }
216        }
217        None
218    }
219
220    fn get_lang(&self) -> Language {
221        tree_sitter_java::language()
222    }
223
224    fn get_definition(&self, code: &str, node: &Node) -> Option<CodeNode> {
225        let node_type = node.kind();
226        let definition_list = [
227            ("class_declaration", "class_body"),
228            ("method_declaration", "formal_parameters"),
229            ("interface_declaration", "interface_body"),
230        ];
231        for (root_type, end_type) in definition_list {
232            if node_type == root_type {
233                let mut output = String::new();
234                for child in node.children(&mut node.walk()) {
235                    if child.kind() == end_type {
236                        break;
237                    } else {
238                        let node_text = &code[child.byte_range()];
239
240                        output.push_str(node_text);
241
242                        output.push(' ');
243                    }
244                }
245                let block_type = match root_type {
246                    "method_declaration" => CodeBlockType::FUNCTION,
247                    "class_declaration" => CodeBlockType::CLASS,
248                    "interface_declaration" => CodeBlockType::CLASS,
249                    _ => CodeBlockType::NORMAL,
250                };
251                let block_text = &code[node.byte_range()];
252                return Some(CodeNode::new(
253                    format!("{}", Uuid::new_v4()).as_str(),
254                    output.as_str(),
255                    block_text,
256                    node.start_position().row + 1,
257                    block_type,
258                    0,
259                ));
260            }
261        }
262
263        None
264    }
265}
266
267impl SymbolQuery for RustQuery {
268    fn get_lang(&self) -> Language {
269        tree_sitter_rust::language()
270    }
271
272    // call_expression 下 identifier 和 field_identifier
273    fn get_call(&self, code: &str, node: &Node) -> Option<CodeNode> {
274        let node_type = node.kind();
275
276        if node_type == "call_expression" {
277            let block_text = &code[node.byte_range()];
278            let fe = node.child_by_field_name("function");
279            if let Some(fe) = fe {
280                let fi = fe.child_by_field_name("field");
281                if let Some(fi) = fi {
282                    let label = &code[fi.byte_range()];
283                    return Some(CodeNode::new(
284                        format!("{}", Uuid::new_v4()).as_str(),
285                        label,
286                        block_text,
287                        fi.start_position().row + 1,
288                        CodeBlockType::CALL,
289                        0,
290                    ));
291                } else {
292                    let label = &code[fe.byte_range()];
293                    return Some(CodeNode::new(
294                        format!("{}", Uuid::new_v4()).as_str(),
295                        label,
296                        block_text,
297                        fe.start_position().row + 1,
298                        CodeBlockType::CALL,
299                        0,
300                    ));
301                }
302            }
303        }
304        None
305    }
306
307    fn get_definition(&self, code: &str, node: &Node) -> Option<CodeNode> {
308        let node_type = node.kind();
309        let definition_list = [
310            ("function_item", "parameters"),
311            ("impl_item", "declaration_list"),
312            ("struct_item", "field_declaration_list"),
313            ("trait_item", "declaration_list"),
314            ("function_signature_item", "parameters"),
315        ];
316        for (root_type, end_type) in definition_list {
317            if node_type == root_type {
318                let mut output = String::new();
319                for child in node.children(&mut node.walk()) {
320                    if child.kind() == end_type {
321                        break;
322                    } else {
323                        let node_text = &code[child.byte_range()];
324
325                        output.push_str(node_text);
326
327                        output.push(' ');
328                    }
329                }
330                let block_type = match root_type {
331                    "function_item" => CodeBlockType::FUNCTION,
332                    "struct_item" => CodeBlockType::STRUCT,
333                    "function_signature_item" => CodeBlockType::FUNCTION,
334                    "trait_item" => CodeBlockType::CLASS,
335                    "impl_item" => CodeBlockType::CLASS,
336                    _ => CodeBlockType::NORMAL,
337                };
338                let block_text = &code[node.byte_range()];
339                return Some(CodeNode::new(
340                    format!("{}", Uuid::new_v4()).as_str(),
341                    output.as_str(),
342                    block_text,
343                    node.start_position().row + 1,
344                    block_type,
345                    0,
346                ));
347            }
348        }
349
350        None
351    }
352}