Skip to main content

code_analyze_core/languages/cpp.rs

1// SPDX-FileCopyrightText: 2026 code-analyze-mcp contributors
2// SPDX-License-Identifier: Apache-2.0
3use tree_sitter::Node;
4
/// Tree-sitter query for extracting C/C++ elements (functions, methods,
/// classes, and structs).
///
/// Capture names:
/// - `@function` — the whole definition node.
///   - `@func_name` names a free function.
///   - `@method_name` names a member function. It is either defined inline in
///     a class body (its declarator is a `field_identifier`) or out-of-line
///     through a single-level qualified name (`Foo::bar`).
/// - `@class` / `@class_name` — a `class_specifier` or `struct_specifier`
///   and its name.
///
/// Limitations:
/// - Only definitions whose `function_declarator` sits directly under the
///   `function_definition` are matched. Functions returning a pointer or a
///   reference (`int* f()`, `T& g()`) are therefore not captured.
/// - A template function is matched both by the plain `function_definition`
///   pattern and by the `template_declaration` pattern, so it may be reported
///   twice.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @func_name)) @function
(function_definition
  declarator: (function_declarator
    declarator: (qualified_identifier
      name: (identifier) @method_name))) @function
(function_definition
  declarator: (function_declarator
    declarator: (field_identifier) @method_name)) @function
(class_specifier
  name: (type_identifier) @class_name) @class
(struct_specifier
  name: (type_identifier) @class_name) @class
(template_declaration
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
";
23
/// Tree-sitter query for extracting function calls.
///
/// Every pattern captures only the bare callee name as `@call`:
/// - `f(x)` → `f`
/// - `obj.m(x)` / `ptr->m(x)` → `m` (the callee is a `field_expression`)
/// - `ns::f(x)` / `Type::make(x)` → `f` / `make`. Only single-level qualified
///   names are matched; `a::b::f(x)` is not.
/// - `f<T>(x)` → `f`
pub const CALL_QUERY: &str = r"
(call_expression
  function: (identifier) @call)
(call_expression
  function: (field_expression field: (field_identifier) @call))
(call_expression
  function: (qualified_identifier
    name: (identifier) @call))
(call_expression
  function: (template_function
    name: (identifier) @call))
";
31
/// Tree-sitter query for extracting type references.
///
/// Captures every `type_identifier` node as `@type_ref`. This includes names
/// at declaration sites (the `name` of a `class_specifier`/`struct_specifier`
/// is also a `type_identifier`) as well as uses. Any filtering is left to the
/// consumer. Built-in types such as `int` are presumably a different node
/// kind in tree-sitter-cpp and therefore not captured — verify.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
36
/// Tree-sitter query for extracting C/C++ `#include` directives.
///
/// Captures the `path` of each `preproc_include` as `@import_path`, for both
/// the quoted form (`#include "foo.h"`, a `string_literal`) and the
/// angle-bracket form (`#include <foo.h>`, a `system_lib_string`).
///
/// The captured text is the raw path token. It presumably still carries its
/// `"…"` / `<…>` delimiters, so callers that need the bare path should strip
/// them (verify against the consumer).
pub const IMPORT_QUERY: &str = r"
(preproc_include
  path: (string_literal) @import_path)
(preproc_include
  path: (system_lib_string) @import_path)
";
44
/// Tree-sitter query for extracting definition and use sites of identifiers.
///
/// The capture name encodes the access kind:
/// - `@write.decl` — a plain identifier declared through an `init_declarator`
///   (e.g. `int a = 7;`). Pointer/array declarators are not matched by this
///   pattern.
/// - `@write.assign` — a plain identifier on the left of an
///   `assignment_expression`.
/// - `@writeread.update` — the identifier operand of an `update_expression`
///   (`++` / `--`).
/// - `@read.usage` — every `identifier` node.
///
/// Because `@read.usage` matches every `identifier`, nodes already captured
/// by a write pattern are captured again here. The consumer presumably
/// resolves that overlap (verify).
pub const DEFUSE_QUERY: &str = r"
(init_declarator declarator: (identifier) @write.decl)
(assignment_expression left: (identifier) @write.assign)
(update_expression argument: (identifier) @writeread.update)
(identifier) @read.usage
";
52
53/// Extract the function name from a C/C++ `function_definition` node by
54/// walking the declarator chain: declarator -> function_declarator -> declarator -> identifier.
55pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
56    node.child_by_field_name("declarator")
57        .and_then(|decl| extract_declarator_name(decl, source))
58}
59
60/// Find method name for a receiver type (class/struct context).
61#[must_use]
62pub fn find_method_for_receiver(
63    node: &Node,
64    source: &str,
65    _depth: Option<usize>,
66) -> Option<String> {
67    if node.kind() != "function_definition" {
68        return None;
69    }
70
71    // Walk up to find if we're in a class_specifier or struct_specifier
72    let mut parent = node.parent();
73    let mut in_class = false;
74    while let Some(p) = parent {
75        if p.kind() == "class_specifier" || p.kind() == "struct_specifier" {
76            in_class = true;
77            break;
78        }
79        parent = p.parent();
80    }
81
82    if !in_class {
83        return None;
84    }
85
86    // Extract the method name from function_declarator
87    if let Some(decl) = node.child_by_field_name("declarator") {
88        extract_declarator_name(decl, source)
89    } else {
90        None
91    }
92}
93
94/// Extract inheritance information from a class_specifier or struct_specifier node.
95#[must_use]
96pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
97    let mut inherits = Vec::new();
98
99    if node.kind() != "class_specifier" && node.kind() != "struct_specifier" {
100        return inherits;
101    }
102
103    // Look for base_class_clause child
104    for i in 0..node.named_child_count() {
105        if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX))
106            && child.kind() == "base_class_clause"
107        {
108            // Walk base_class_clause for type_identifier nodes
109            for j in 0..child.named_child_count() {
110                if let Some(base) = child.named_child(u32::try_from(j).unwrap_or(u32::MAX))
111                    && base.kind() == "type_identifier"
112                {
113                    let text = &source[base.start_byte()..base.end_byte()];
114                    inherits.push(text.to_string());
115                }
116            }
117        }
118    }
119
120    inherits
121}
122
123/// Helper: extract name from a declarator node (handles identifiers and qualified identifiers).
124fn extract_declarator_name(node: Node, source: &str) -> Option<String> {
125    match node.kind() {
126        "identifier" | "field_identifier" => {
127            let start = node.start_byte();
128            let end = node.end_byte();
129            if end <= source.len() {
130                Some(source[start..end].to_string())
131            } else {
132                None
133            }
134        }
135        "qualified_identifier" => node.child_by_field_name("name").and_then(|n| {
136            let start = n.start_byte();
137            let end = n.end_byte();
138            if end <= source.len() {
139                Some(source[start..end].to_string())
140            } else {
141                None
142            }
143        }),
144        "function_declarator" => node
145            .child_by_field_name("declarator")
146            .and_then(|n| extract_declarator_name(n, source)),
147        "pointer_declarator" => node
148            .child_by_field_name("declarator")
149            .and_then(|n| extract_declarator_name(n, source)),
150        "reference_declarator" => node
151            .child_by_field_name("declarator")
152            .and_then(|n| extract_declarator_name(n, source)),
153        _ => None,
154    }
155}
156
#[cfg(all(test, feature = "lang-cpp"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `source` with the tree-sitter C++ grammar.
    fn parse_cpp(source: &str) -> tree_sitter::Tree {
        let language: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("failed to set C++ language");
        parser.parse(source, None).expect("failed to parse source")
    }

    /// Depth-first, pre-order search for the first node (named or anonymous)
    /// of `kind`, starting with `node` itself.
    fn find_node_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
        if node.kind() == kind {
            return Some(node);
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if let Some(found) = find_node_by_kind(child, kind) {
                return Some(found);
            }
        }
        None
    }

    /// Run `query` over a fresh parse of `source` and return every capture as
    /// `(capture name, captured text)`, in the order tree-sitter yields them.
    fn collect_captures(query: &tree_sitter::Query, source: &str) -> Vec<(String, String)> {
        let tree = parse_cpp(source);
        let capture_names = query.capture_names();
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.captures(query, tree.root_node(), source.as_bytes());
        let mut collected = Vec::new();
        while let Some((query_match, _)) = matches.next() {
            for capture in query_match.captures {
                let name = capture_names[capture.index as usize].to_string();
                let text = capture
                    .node
                    .utf8_text(source.as_bytes())
                    .unwrap_or("")
                    .to_string();
                collected.push((name, text));
            }
        }
        collected
    }

    #[test]
    fn test_free_function() {
        // A top-level function has no enclosing class/struct, so it is not a method.
        let source = "int add(int a, int b) { return a + b; }";
        let tree = parse_cpp(source);
        let func_node = tree
            .root_node()
            .named_child(0)
            .expect("expected function_definition");
        assert_eq!(find_method_for_receiver(&func_node, source, None), None);
    }

    #[test]
    fn test_class_with_method() {
        // A function defined inside a class body is reported by name.
        let source = "class Foo { public: int getValue() { return 42; } };";
        let tree = parse_cpp(source);
        let func_node = find_node_by_kind(tree.root_node(), "function_definition")
            .expect("expected function");
        assert_eq!(
            find_method_for_receiver(&func_node, source, None),
            Some("getValue".to_string())
        );
    }

    #[test]
    fn test_struct() {
        // A struct without a base clause has no inheritance.
        let source = "struct Point { int x; int y; };";
        let tree = parse_cpp(source);
        let struct_node = find_node_by_kind(tree.root_node(), "struct_specifier")
            .expect("expected struct_specifier");
        assert_eq!(struct_node.kind(), "struct_specifier");
        let result = extract_inheritance(&struct_node, source);
        assert!(result.is_empty(), "expected no inheritance, got: {result:?}");
    }

    #[test]
    fn test_include_directive() {
        // Both the angle-bracket and the quoted include forms are captured.
        let source = "#include <stdio.h>\n#include \"myfile.h\"\n";
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = tree_sitter::Query::new(&lang, super::IMPORT_QUERY)
            .expect("IMPORT_QUERY must be valid");
        let captures: Vec<String> = collect_captures(&query, source)
            .into_iter()
            .map(|(_, text)| text)
            .collect();
        for expected in ["stdio.h", "myfile.h"] {
            assert!(
                captures.iter().any(|s| s.contains(expected)),
                "expected {expected} in captures: {captures:?}"
            );
        }
    }

    #[test]
    fn test_template_function() {
        // A function template's name is captured as `func_name`.
        let source = "template<typename T> T max(T a, T b) { return a > b ? a : b; }";
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = tree_sitter::Query::new(&lang, super::ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let func_names: Vec<String> = collect_captures(&query, source)
            .into_iter()
            .filter(|(name, _)| name == "func_name")
            .map(|(_, text)| text)
            .collect();
        assert!(
            func_names.iter().any(|s| s == "max"),
            "expected 'max' in func_names: {func_names:?}"
        );
    }

    #[test]
    fn test_class_with_inheritance() {
        // A public base class is reported by name.
        let source = "class Derived : public Base { };";
        let tree = parse_cpp(source);
        let class_node =
            find_node_by_kind(tree.root_node(), "class_specifier").expect("expected class");
        let result = extract_inheritance(&class_node, source);
        assert!(!result.is_empty(), "expected inheritance information");
        assert!(
            result.iter().any(|s| s.contains("Base")),
            "expected 'Base' in inheritance: {result:?}"
        );
    }

    #[test]
    fn test_defuse_query_write_site() {
        // An initialized local declaration yields at least one Write site.
        let src = "void f() { int a = 7; }\n";
        let sites = SemanticExtractor::extract_def_use_for_file(src, "cpp", "a", "test.cpp", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}