code_analyze_core/languages/
cpp.rs1use tree_sitter::Node;
4
/// Tree-sitter query locating C++ function, method, class and struct definitions.
///
/// Captures:
/// - `@function` / `@func_name`: free-function definitions and their name.
/// - `@function` / `@method_name`: method definitions, both out-of-line
///   (`int Foo::bar() {}`) and inline in a class body (`class Foo { int bar() {} };`).
/// - `@class` / `@class_name`: class and struct specifiers and their name.
pub const ELEMENT_QUERY: &str = r"
; Free functions, e.g. `int add(int a, int b) { ... }`
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @func_name)) @function
; Free functions returning a pointer, e.g. `char* dup(const char* s) { ... }`
(function_definition
  declarator: (pointer_declarator
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
; Free functions returning a reference, e.g. `int& at(int i) { ... }`.
; reference_declarator exposes its inner declarator as an unnamed child.
(function_definition
  declarator: (reference_declarator
    (function_declarator
      declarator: (identifier) @func_name))) @function
; Out-of-line method definitions, e.g. `int Foo::bar() { ... }`
(function_definition
  declarator: (function_declarator
    declarator: (qualified_identifier
      name: (identifier) @method_name))) @function
; Methods defined inline in a class or struct body: the name is a field_identifier
(function_definition
  declarator: (function_declarator
    declarator: (field_identifier) @method_name)) @function
(class_specifier
  name: (type_identifier) @class_name) @class
(struct_specifier
  name: (type_identifier) @class_name) @class
; Function templates. NOTE: the inner function_definition is also matched by the
; first pattern, so template functions are reported twice, once per pattern.
(template_declaration
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
";
23
/// Tree-sitter query capturing the callee name of C++ call expressions as `@call`.
///
/// Covers plain calls, member calls through `.`/`->`, scope-qualified calls
/// (one or two levels of `::`), and calls with explicit template arguments.
pub const CALL_QUERY: &str = r"
; Plain calls, e.g. `foo(x)`
(call_expression
  function: (identifier) @call)
; Member calls, e.g. `obj.foo(x)` or `ptr->foo(x)`
(call_expression
  function: (field_expression
    field: (field_identifier) @call))
; Scope-qualified calls, e.g. `std::sort(v)` or `Foo::create()`
(call_expression
  function: (qualified_identifier
    name: (identifier) @call))
; Doubly qualified calls, e.g. `a::b::run()`
(call_expression
  function: (qualified_identifier
    name: (qualified_identifier
      name: (identifier) @call)))
; Explicit template arguments, e.g. `make<int>()`
(call_expression
  function: (template_function
    name: (identifier) @call))
; Qualified calls with template arguments, e.g. `std::make_unique<T>()`
(call_expression
  function: (qualified_identifier
    name: (template_function
      name: (identifier) @call)))
";
31
/// Tree-sitter query tagging every `type_identifier` node as a type reference (`@type_ref`).
///
/// The net is intentionally wide. Declaration names are `type_identifier` nodes
/// too (e.g. the `name` of a `class_specifier`), so they are captured as well.
pub const REFERENCE_QUERY: &str = "\n(type_identifier) @type_ref\n";
36
/// Tree-sitter query capturing the target of every `#include` directive as `@import_path`.
///
/// Quoted includes (`#include "file.h"`, a `string_literal`) and angle-bracket
/// includes (`#include <file.h>`, a `system_lib_string`) are matched by separate
/// patterns. The capture is the whole path token, so delimiters are presumably
/// included. Confirm before comparing the text for equality.
pub const IMPORT_QUERY: &str = concat!(
    "\n",
    "(preproc_include\n",
    "  path: (string_literal) @import_path)\n",
    "(preproc_include\n",
    "  path: (system_lib_string) @import_path)\n",
);
44
45pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
48 node.child_by_field_name("declarator")
49 .and_then(|decl| extract_declarator_name(decl, source))
50}
51
52#[must_use]
54pub fn find_method_for_receiver(
55 node: &Node,
56 source: &str,
57 _depth: Option<usize>,
58) -> Option<String> {
59 if node.kind() != "function_definition" {
60 return None;
61 }
62
63 let mut parent = node.parent();
65 let mut in_class = false;
66 while let Some(p) = parent {
67 if p.kind() == "class_specifier" || p.kind() == "struct_specifier" {
68 in_class = true;
69 break;
70 }
71 parent = p.parent();
72 }
73
74 if !in_class {
75 return None;
76 }
77
78 if let Some(decl) = node.child_by_field_name("declarator") {
80 extract_declarator_name(decl, source)
81 } else {
82 None
83 }
84}
85
86#[must_use]
88pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
89 let mut inherits = Vec::new();
90
91 if node.kind() != "class_specifier" && node.kind() != "struct_specifier" {
92 return inherits;
93 }
94
95 for i in 0..node.named_child_count() {
97 if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX))
98 && child.kind() == "base_class_clause"
99 {
100 for j in 0..child.named_child_count() {
102 if let Some(base) = child.named_child(u32::try_from(j).unwrap_or(u32::MAX))
103 && base.kind() == "type_identifier"
104 {
105 let text = &source[base.start_byte()..base.end_byte()];
106 inherits.push(text.to_string());
107 }
108 }
109 }
110 }
111
112 inherits
113}
114
115fn extract_declarator_name(node: Node, source: &str) -> Option<String> {
117 match node.kind() {
118 "identifier" | "field_identifier" => {
119 let start = node.start_byte();
120 let end = node.end_byte();
121 if end <= source.len() {
122 Some(source[start..end].to_string())
123 } else {
124 None
125 }
126 }
127 "qualified_identifier" => node.child_by_field_name("name").and_then(|n| {
128 let start = n.start_byte();
129 let end = n.end_byte();
130 if end <= source.len() {
131 Some(source[start..end].to_string())
132 } else {
133 None
134 }
135 }),
136 "function_declarator" => node
137 .child_by_field_name("declarator")
138 .and_then(|n| extract_declarator_name(n, source)),
139 "pointer_declarator" => node
140 .child_by_field_name("declarator")
141 .and_then(|n| extract_declarator_name(n, source)),
142 "reference_declarator" => node
143 .child_by_field_name("declarator")
144 .and_then(|n| extract_declarator_name(n, source)),
145 _ => None,
146 }
147}
148
#[cfg(all(test, feature = "lang-cpp"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parses `source` with the tree-sitter C++ grammar.
    fn parse_cpp(source: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_cpp::LANGUAGE.into())
            .expect("failed to set C++ language");
        parser.parse(source, None).expect("failed to parse source")
    }

    /// Depth-first search for the first node of `kind` at or below `node`.
    fn find_node_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
        if node.kind() == kind {
            return Some(node);
        }
        (0..node.child_count())
            .filter_map(|i| node.child(u32::try_from(i).unwrap_or(u32::MAX)))
            .find_map(|child| find_node_by_kind(child, kind))
    }

    /// Runs `query_src` over `source` and returns the text of every capture named
    /// `capture`, or of every capture when `capture` is `None`.
    fn capture_texts(query_src: &str, source: &str, capture: Option<&str>) -> Vec<String> {
        let tree = parse_cpp(source);
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = tree_sitter::Query::new(&lang, query_src).expect("query must be valid");
        let names = query.capture_names();
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut iter = cursor.captures(&query, tree.root_node(), source.as_bytes());
        let mut texts = Vec::new();
        while let Some((m, _)) = iter.next() {
            for c in m.captures {
                if capture.is_some_and(|wanted| names[c.index as usize] != wanted) {
                    continue;
                }
                if let Ok(text) = c.node.utf8_text(source.as_bytes()) {
                    texts.push(text.to_string());
                }
            }
        }
        texts
    }

    #[test]
    fn test_free_function() {
        let source = "int add(int a, int b) { return a + b; }";
        let tree = parse_cpp(source);
        let func_node = tree
            .root_node()
            .named_child(0)
            .expect("expected function_definition");
        // A free function has no enclosing class, so it is not a method.
        assert_eq!(find_method_for_receiver(&func_node, source, None), None);
    }

    #[test]
    fn test_extract_function_name_free_function() {
        let source = "int add(int a, int b) { return a + b; }";
        let tree = parse_cpp(source);
        let func_node = tree
            .root_node()
            .named_child(0)
            .expect("expected function_definition");
        assert_eq!(
            extract_function_name(&func_node, source, "cpp"),
            Some("add".to_string())
        );
    }

    #[test]
    fn test_class_with_method() {
        let source = "class Foo { public: int getValue() { return 42; } };";
        let tree = parse_cpp(source);
        let func_node =
            find_node_by_kind(tree.root_node(), "function_definition").expect("expected function");
        assert_eq!(
            find_method_for_receiver(&func_node, source, None),
            Some("getValue".to_string())
        );
    }

    #[test]
    fn test_struct() {
        let source = "struct Point { int x; int y; };";
        let tree = parse_cpp(source);
        let struct_node = find_node_by_kind(tree.root_node(), "struct_specifier")
            .expect("expected struct_specifier");
        assert_eq!(struct_node.kind(), "struct_specifier");
        let result = extract_inheritance(&struct_node, source);
        assert!(
            result.is_empty(),
            "expected no inheritance, got: {result:?}"
        );
    }

    #[test]
    fn test_include_directive() {
        let source = "#include <stdio.h>\n#include \"myfile.h\"\n";
        let captures = capture_texts(IMPORT_QUERY, source, None);
        assert!(
            captures.iter().any(|s| s.contains("stdio.h")),
            "expected stdio.h in captures: {captures:?}"
        );
        assert!(
            captures.iter().any(|s| s.contains("myfile.h")),
            "expected myfile.h in captures: {captures:?}"
        );
    }

    #[test]
    fn test_template_function() {
        let source = "template<typename T> T max(T a, T b) { return a > b ? a : b; }";
        let func_names = capture_texts(ELEMENT_QUERY, source, Some("func_name"));
        assert!(
            func_names.iter().any(|s| s == "max"),
            "expected 'max' in func_names: {func_names:?}"
        );
    }

    #[test]
    fn test_call_expressions() {
        let source = "void f() { g(); obj.h(); }";
        let calls = capture_texts(CALL_QUERY, source, Some("call"));
        assert!(
            calls.iter().any(|s| s == "g"),
            "expected plain call 'g' in calls: {calls:?}"
        );
        assert!(
            calls.iter().any(|s| s == "h"),
            "expected member call 'h' in calls: {calls:?}"
        );
    }

    #[test]
    fn test_class_with_inheritance() {
        let source = "class Derived : public Base { };";
        let tree = parse_cpp(source);
        let class_node =
            find_node_by_kind(tree.root_node(), "class_specifier").expect("expected class");
        let result = extract_inheritance(&class_node, source);
        assert!(!result.is_empty(), "expected inheritance information");
        assert!(
            result.iter().any(|s| s.contains("Base")),
            "expected 'Base' in inheritance: {result:?}"
        );
    }
}