Skip to main content

code_analyze_core/languages/
cpp.rs

1// SPDX-FileCopyrightText: 2026 code-analyze-mcp contributors
2// SPDX-License-Identifier: Apache-2.0
3use tree_sitter::Node;
4
/// Tree-sitter query for extracting C/C++ elements (functions, classes, and structures).
///
/// Captures:
/// - `@function` + `@func_name`: free function definitions, including ones whose
///   `function_declarator` is wrapped in another declarator (e.g. pointer- or
///   reference-returning `int* f()` / `int& f()`), and template functions
///   (captured on the enclosing `template_declaration`).
/// - `@function` + `@method_name`: out-of-line method definitions (`Foo::bar`,
///   single-level qualifier) and methods defined inline in a class/struct body,
///   whose name is a `field_identifier`; both also when wrapped in another declarator.
/// - `@class` + `@class_name`: named class and struct specifiers.
///
/// NOTE(review): a template function matches twice — once via the plain
/// `function_definition` pattern on the inner node and once via the
/// `template_declaration` pattern — confirm callers de-duplicate.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @func_name)) @function
(function_definition
  declarator: (function_declarator
    declarator: (qualified_identifier
      name: (identifier) @method_name))) @function
(function_definition
  declarator: (function_declarator
    declarator: (field_identifier) @method_name)) @function
(function_definition
  declarator: (_
    (function_declarator
      declarator: (identifier) @func_name))) @function
(function_definition
  declarator: (_
    (function_declarator
      declarator: (qualified_identifier
        name: (identifier) @method_name)))) @function
(function_definition
  declarator: (_
    (function_declarator
      declarator: (field_identifier) @method_name))) @function
(class_specifier
  name: (type_identifier) @class_name) @class
(struct_specifier
  name: (type_identifier) @class_name) @class
(template_declaration
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
";
23
/// Tree-sitter query for extracting function calls.
///
/// Each pattern captures the callee's simple name as `@call`:
/// - plain calls: `foo()`
/// - member calls through a `field_expression`: `obj.foo()` / `ptr->foo()`
/// - qualified calls with a single-level qualifier: `ns::foo()` / `Foo::bar()`
///   (deeper qualifiers such as `a::b::c()` are not matched)
/// - calls with explicit template arguments: `make<T>()`
pub const CALL_QUERY: &str = r"
(call_expression
  function: (identifier) @call)
(call_expression
  function: (field_expression field: (field_identifier) @call))
(call_expression
  function: (qualified_identifier name: (identifier) @call))
(call_expression
  function: (template_function name: (identifier) @call))
";
31
/// Tree-sitter query for extracting type references.
///
/// Captures every `type_identifier` node as `@type_ref`. The pattern has no
/// parent context, so it also matches names at definition sites (the
/// `type_identifier` that names a class/struct specifier) and base-class names.
/// NOTE(review): presumably callers separate definitions from uses — confirm.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
36
/// Tree-sitter query for extracting C/C++ preprocessor `#include` directives.
///
/// Captures the include path as `@import_path` for both quoted includes
/// (`#include "file.h"`, a `string_literal`) and angle-bracket includes
/// (`#include <file.h>`, a `system_lib_string`). Other preprocessor
/// directives are not matched.
/// NOTE(review): the captured node text likely still contains its delimiters
/// (`"..."` / `<...>`) — confirm callers strip them.
pub const IMPORT_QUERY: &str = r"
(preproc_include
  path: (string_literal) @import_path)
(preproc_include
  path: (system_lib_string) @import_path)
";
44
45/// Extract the function name from a C/C++ `function_definition` node by
46/// walking the declarator chain: declarator -> function_declarator -> declarator -> identifier.
47pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
48    node.child_by_field_name("declarator")
49        .and_then(|decl| extract_declarator_name(decl, source))
50}
51
52/// Find method name for a receiver type (class/struct context).
53#[must_use]
54pub fn find_method_for_receiver(
55    node: &Node,
56    source: &str,
57    _depth: Option<usize>,
58) -> Option<String> {
59    if node.kind() != "function_definition" {
60        return None;
61    }
62
63    // Walk up to find if we're in a class_specifier or struct_specifier
64    let mut parent = node.parent();
65    let mut in_class = false;
66    while let Some(p) = parent {
67        if p.kind() == "class_specifier" || p.kind() == "struct_specifier" {
68            in_class = true;
69            break;
70        }
71        parent = p.parent();
72    }
73
74    if !in_class {
75        return None;
76    }
77
78    // Extract the method name from function_declarator
79    if let Some(decl) = node.child_by_field_name("declarator") {
80        extract_declarator_name(decl, source)
81    } else {
82        None
83    }
84}
85
86/// Extract inheritance information from a class_specifier or struct_specifier node.
87#[must_use]
88pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
89    let mut inherits = Vec::new();
90
91    if node.kind() != "class_specifier" && node.kind() != "struct_specifier" {
92        return inherits;
93    }
94
95    // Look for base_class_clause child
96    for i in 0..node.named_child_count() {
97        if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX))
98            && child.kind() == "base_class_clause"
99        {
100            // Walk base_class_clause for type_identifier nodes
101            for j in 0..child.named_child_count() {
102                if let Some(base) = child.named_child(u32::try_from(j).unwrap_or(u32::MAX))
103                    && base.kind() == "type_identifier"
104                {
105                    let text = &source[base.start_byte()..base.end_byte()];
106                    inherits.push(text.to_string());
107                }
108            }
109        }
110    }
111
112    inherits
113}
114
115/// Helper: extract name from a declarator node (handles identifiers and qualified identifiers).
116fn extract_declarator_name(node: Node, source: &str) -> Option<String> {
117    match node.kind() {
118        "identifier" | "field_identifier" => {
119            let start = node.start_byte();
120            let end = node.end_byte();
121            if end <= source.len() {
122                Some(source[start..end].to_string())
123            } else {
124                None
125            }
126        }
127        "qualified_identifier" => node.child_by_field_name("name").and_then(|n| {
128            let start = n.start_byte();
129            let end = n.end_byte();
130            if end <= source.len() {
131                Some(source[start..end].to_string())
132            } else {
133                None
134            }
135        }),
136        "function_declarator" => node
137            .child_by_field_name("declarator")
138            .and_then(|n| extract_declarator_name(n, source)),
139        "pointer_declarator" => node
140            .child_by_field_name("declarator")
141            .and_then(|n| extract_declarator_name(n, source)),
142        "reference_declarator" => node
143            .child_by_field_name("declarator")
144            .and_then(|n| extract_declarator_name(n, source)),
145        _ => None,
146    }
147}
148
#[cfg(all(test, feature = "lang-cpp"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};

    /// Parse `source` with the tree-sitter C++ grammar.
    fn parse_cpp(source: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_cpp::LANGUAGE.into())
            .expect("failed to set C++ language");
        parser.parse(source, None).expect("failed to parse source")
    }

    /// Run `query_src` against `source` and collect the text of every capture,
    /// keeping only captures named `only` when it is `Some`.
    fn capture_texts(query_src: &str, source: &str, only: Option<&str>) -> Vec<String> {
        let tree = parse_cpp(source);
        let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
        let query = Query::new(&lang, query_src).expect("query must be valid");
        let names = query.capture_names();
        let mut cursor = QueryCursor::new();
        let mut iter = cursor.captures(&query, tree.root_node(), source.as_bytes());
        let mut texts = Vec::new();
        while let Some((m, _)) = iter.next() {
            for c in m.captures {
                if only.is_some_and(|want| names[c.index as usize] != want) {
                    continue;
                }
                let text = c.node.utf8_text(source.as_bytes()).unwrap_or("");
                texts.push(text.to_string());
            }
        }
        texts
    }

    /// Depth-first (pre-order) search for the first node of the given kind.
    fn find_node_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
        if node.kind() == kind {
            return Some(node);
        }
        let mut cursor = node.walk();
        node.children(&mut cursor)
            .find_map(|child| find_node_by_kind(child, kind))
    }

    #[test]
    fn test_free_function() {
        // Arrange: free function definition
        let source = "int add(int a, int b) { return a + b; }";
        let tree = parse_cpp(source);
        let func_node = tree
            .root_node()
            .named_child(0)
            .expect("expected function_definition");
        // Act + Assert: not a method, but its name is still extractable
        assert_eq!(find_method_for_receiver(&func_node, source, None), None);
        assert_eq!(
            extract_function_name(&func_node, source, "cpp"),
            Some("add".to_string())
        );
    }

    #[test]
    fn test_out_of_line_method() {
        // Arrange: method defined outside its class body
        let source = "int Foo::bar() { return 1; }";
        let tree = parse_cpp(source);
        let func_node = find_node_by_kind(tree.root_node(), "function_definition")
            .expect("expected function_definition");
        // Act + Assert: qualified name resolves to its last component; no
        // enclosing class node exists, so it is not reported as a receiver method
        assert_eq!(
            extract_function_name(&func_node, source, "cpp"),
            Some("bar".to_string())
        );
        assert_eq!(find_method_for_receiver(&func_node, source, None), None);
    }

    #[test]
    fn test_class_with_method() {
        // Arrange: class with an inline method
        let source = "class Foo { public: int getValue() { return 42; } };";
        let tree = parse_cpp(source);
        let func_node = find_node_by_kind(tree.root_node(), "function_definition")
            .expect("expected function");
        // Act
        let result = find_method_for_receiver(&func_node, source, None);
        // Assert: method inside class should be recognized
        assert_eq!(result, Some("getValue".to_string()));
    }

    #[test]
    fn test_struct() {
        // Arrange: struct with no base class
        let source = "struct Point { int x; int y; };";
        let tree = parse_cpp(source);
        let struct_node = find_node_by_kind(tree.root_node(), "struct_specifier")
            .expect("expected struct_specifier");
        // Act
        let result = extract_inheritance(&struct_node, source);
        // Assert: struct with no inheritance returns empty
        assert!(result.is_empty(), "expected no inheritance, got: {result:?}");
    }

    #[test]
    fn test_include_directive() {
        // Arrange
        let source = "#include <stdio.h>\n#include \"myfile.h\"\n";
        // Act: run IMPORT_QUERY
        let captures = capture_texts(IMPORT_QUERY, source, None);
        // Assert: both includes captured
        assert!(
            captures.iter().any(|s| s.contains("stdio.h")),
            "expected stdio.h in captures: {captures:?}"
        );
        assert!(
            captures.iter().any(|s| s.contains("myfile.h")),
            "expected myfile.h in captures: {captures:?}"
        );
    }

    #[test]
    fn test_template_function() {
        // Arrange: template function definition
        let source = "template<typename T> T max(T a, T b) { return a > b ? a : b; }";
        // Act: run ELEMENT_QUERY, keeping only @func_name captures
        let func_names = capture_texts(ELEMENT_QUERY, source, Some("func_name"));
        // Assert: "max" captured as func_name
        assert!(
            func_names.iter().any(|s| s == "max"),
            "expected 'max' in func_names: {func_names:?}"
        );
    }

    #[test]
    fn test_class_with_inheritance() {
        // Arrange: class with base class
        let source = "class Derived : public Base { };";
        let tree = parse_cpp(source);
        let class_node =
            find_node_by_kind(tree.root_node(), "class_specifier").expect("expected class");
        // Act
        let result = extract_inheritance(&class_node, source);
        // Assert: should have "Base" as inheritance
        assert!(!result.is_empty(), "expected inheritance information");
        assert!(
            result.iter().any(|s| s.contains("Base")),
            "expected 'Base' in inheritance: {result:?}"
        );
    }
}