code_analyze_core/languages/
cpp.rs1use tree_sitter::Node;
4
/// Tree-sitter query that locates the structural elements of a C/C++ file.
///
/// Patterns, in order:
/// - free functions: the declarator identifier is captured as `@func_name`;
/// - functions whose declarator is a qualified identifier (`Scope::name`): the
///   trailing identifier is captured as `@method_name`;
/// - `class` and `struct` specifiers: the type identifier is captured as
///   `@class_name` and the whole specifier as `@class`;
/// - function templates: the function identifier is captured as `@func_name`.
///
/// Every function pattern also tags a node as `@function`. For the first two
/// patterns that is the `function_definition`; for the template pattern it is
/// the enclosing `template_declaration`.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @func_name)) @function
(function_definition
  declarator: (function_declarator
    declarator: (qualified_identifier
      name: (identifier) @method_name))) @function
(class_specifier
  name: (type_identifier) @class_name) @class
(struct_specifier
  name: (type_identifier) @class_name) @class
(template_declaration
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
";
23
/// Tree-sitter query that captures the callee name of call expressions.
///
/// Two call shapes are matched, both tagged `@call`:
/// - calls whose callee is a bare identifier, e.g. `foo()`;
/// - calls through a field expression, e.g. `obj.method()`, where the field
///   identifier is captured.
pub const CALL_QUERY: &str = r"
(call_expression
  function: (identifier) @call)
(call_expression
  function: (field_expression field: (field_identifier) @call))
";
31
/// Tree-sitter query that captures every `type_identifier` node as `@type_ref`.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
36
/// Tree-sitter query that captures the path of `#include` directives.
///
/// Both path forms are tagged `@import_path`: a `string_literal` path
/// (`#include "file.h"`) and a `system_lib_string` path (`#include <file.h>`).
pub const IMPORT_QUERY: &str = r"
(preproc_include
  path: (string_literal) @import_path)
(preproc_include
  path: (system_lib_string) @import_path)
";
44
/// Tree-sitter query that classifies identifier occurrences for def-use analysis.
///
/// Captures:
/// - `@write.decl`: the identifier declared by an `init_declarator`;
/// - `@write.assign`: an identifier on the left of an `assignment_expression`;
/// - `@writeread.update`: an identifier that is the argument of an
///   `update_expression`;
/// - `@read.usage`: every `identifier` node. This pattern is unconditional, so
///   it also matches the identifiers captured by the patterns above.
pub const DEFUSE_QUERY: &str = r"
(init_declarator declarator: (identifier) @write.decl)
(assignment_expression left: (identifier) @write.assign)
(update_expression argument: (identifier) @writeread.update)
(identifier) @read.usage
";
52
53pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
56 node.child_by_field_name("declarator")
57 .and_then(|decl| extract_declarator_name(decl, source))
58}
59
60#[must_use]
62pub fn find_method_for_receiver(
63 node: &Node,
64 source: &str,
65 _depth: Option<usize>,
66) -> Option<String> {
67 if node.kind() != "function_definition" {
68 return None;
69 }
70
71 let mut parent = node.parent();
73 let mut in_class = false;
74 while let Some(p) = parent {
75 if p.kind() == "class_specifier" || p.kind() == "struct_specifier" {
76 in_class = true;
77 break;
78 }
79 parent = p.parent();
80 }
81
82 if !in_class {
83 return None;
84 }
85
86 if let Some(decl) = node.child_by_field_name("declarator") {
88 extract_declarator_name(decl, source)
89 } else {
90 None
91 }
92}
93
94#[must_use]
96pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
97 let mut inherits = Vec::new();
98
99 if node.kind() != "class_specifier" && node.kind() != "struct_specifier" {
100 return inherits;
101 }
102
103 for i in 0..node.named_child_count() {
105 if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX))
106 && child.kind() == "base_class_clause"
107 {
108 for j in 0..child.named_child_count() {
110 if let Some(base) = child.named_child(u32::try_from(j).unwrap_or(u32::MAX))
111 && base.kind() == "type_identifier"
112 {
113 let text = &source[base.start_byte()..base.end_byte()];
114 inherits.push(text.to_string());
115 }
116 }
117 }
118 }
119
120 inherits
121}
122
123fn extract_declarator_name(node: Node, source: &str) -> Option<String> {
125 match node.kind() {
126 "identifier" | "field_identifier" => {
127 let start = node.start_byte();
128 let end = node.end_byte();
129 if end <= source.len() {
130 Some(source[start..end].to_string())
131 } else {
132 None
133 }
134 }
135 "qualified_identifier" => node.child_by_field_name("name").and_then(|n| {
136 let start = n.start_byte();
137 let end = n.end_byte();
138 if end <= source.len() {
139 Some(source[start..end].to_string())
140 } else {
141 None
142 }
143 }),
144 "function_declarator" => node
145 .child_by_field_name("declarator")
146 .and_then(|n| extract_declarator_name(n, source)),
147 "pointer_declarator" => node
148 .child_by_field_name("declarator")
149 .and_then(|n| extract_declarator_name(n, source)),
150 "reference_declarator" => node
151 .child_by_field_name("declarator")
152 .and_then(|n| extract_declarator_name(n, source)),
153 _ => None,
154 }
155}
156
#[cfg(all(test, feature = "lang-cpp"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::Parser;

    /// Parses `source` with the tree-sitter C++ grammar, panicking on failure.
    fn parse_cpp(source: &str) -> tree_sitter::Tree {
        let language: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("failed to set C++ language");
        parser.parse(source, None).expect("failed to parse source")
    }

    /// Depth-first, pre-order search for the first node of the given `kind`.
    fn find_node_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
        if node.kind() == kind {
            return Some(node);
        }
        (0..node.child_count())
            .filter_map(|i| node.child(u32::try_from(i).unwrap_or(u32::MAX)))
            .find_map(|child| find_node_by_kind(child, kind))
    }

    /// A function outside any class or struct is not reported as a method.
    #[test]
    fn test_free_function() {
        let source = "int add(int a, int b) { return a + b; }";
        let tree = parse_cpp(source);
        let func_node = tree
            .root_node()
            .named_child(0)
            .expect("expected function_definition");

        assert_eq!(find_method_for_receiver(&func_node, source, None), None);
    }

    /// A function defined inside a class body is reported by its name.
    #[test]
    fn test_class_with_method() {
        let source = "class Foo { public: int getValue() { return 42; } };";
        let tree = parse_cpp(source);
        let func_node =
            find_node_by_kind(tree.root_node(), "function_definition").expect("expected function");

        assert_eq!(
            find_method_for_receiver(&func_node, source, None),
            Some("getValue".to_string())
        );
    }

    /// A struct without a base-class clause yields no inheritance entries.
    #[test]
    fn test_struct() {
        let source = "struct Point { int x; int y; };";
        let tree = parse_cpp(source);
        let struct_node = find_node_by_kind(tree.root_node(), "struct_specifier")
            .expect("expected struct_specifier");
        assert_eq!(struct_node.kind(), "struct_specifier");

        let result = extract_inheritance(&struct_node, source);
        assert!(
            result.is_empty(),
            "expected no inheritance, got: {result:?}"
        );
    }

    /// Both system and quoted `#include` paths are captured by `IMPORT_QUERY`.
    #[test]
    fn test_include_directive() {
        use tree_sitter::StreamingIterator;

        let source = "#include <stdio.h>\n#include \"myfile.h\"\n";
        let tree = parse_cpp(source);
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = tree_sitter::Query::new(&lang, super::IMPORT_QUERY)
            .expect("IMPORT_QUERY must be valid");

        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.captures(&query, tree.root_node(), source.as_bytes());
        let mut captures: Vec<String> = Vec::new();
        while let Some((found, _)) = matches.next() {
            captures.extend(found.captures.iter().map(|capture| {
                capture
                    .node
                    .utf8_text(source.as_bytes())
                    .unwrap_or("")
                    .to_string()
            }));
        }

        for header in ["stdio.h", "myfile.h"] {
            assert!(
                captures.iter().any(|s| s.contains(header)),
                "expected {header} in captures: {captures:?}"
            );
        }
    }

    /// A function template's name is captured as `func_name` by `ELEMENT_QUERY`.
    #[test]
    fn test_template_function() {
        use tree_sitter::StreamingIterator;

        let source = "template<typename T> T max(T a, T b) { return a > b ? a : b; }";
        let tree = parse_cpp(source);
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = tree_sitter::Query::new(&lang, super::ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");

        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.captures(&query, tree.root_node(), source.as_bytes());
        let mut func_names: Vec<String> = Vec::new();
        while let Some((found, _)) = matches.next() {
            func_names.extend(
                found
                    .captures
                    .iter()
                    .filter(|capture| {
                        query.capture_names()[capture.index as usize] == "func_name"
                    })
                    .filter_map(|capture| capture.node.utf8_text(source.as_bytes()).ok())
                    .map(str::to_string),
            );
        }

        assert!(
            func_names.iter().any(|s| s == "max"),
            "expected 'max' in func_names: {func_names:?}"
        );
    }

    /// A class with a base-class clause reports the base name.
    #[test]
    fn test_class_with_inheritance() {
        let source = "class Derived : public Base { };";
        let tree = parse_cpp(source);
        let class_node =
            find_node_by_kind(tree.root_node(), "class_specifier").expect("expected class");

        let result = extract_inheritance(&class_node, source);
        assert!(!result.is_empty(), "expected inheritance information");
        assert!(
            result.iter().any(|s| s.contains("Base")),
            "expected 'Base' in inheritance: {result:?}"
        );
    }

    /// An initialised declaration of the tracked symbol produces a write site.
    #[test]
    fn test_defuse_query_write_site() {
        let src = "void f() { int a = 7; }\n";
        let sites = SemanticExtractor::extract_def_use_for_file(src, "cpp", "a", "test.cpp", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");

        let has_write = sites
            .iter()
            .any(|site| matches!(site.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}