//! C++ parser using tree-sitter
//!
//! Extracts classes, structs, functions, methods, imports, and call relationships from C++ source code.
use crate::models::{Class, Function};
use crate::parsers::ParseResult;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Node, Parser, Query, QueryCursor};
/// Read the C++ file at `path` and extract every code entity it contains.
///
/// # Errors
///
/// Returns an error if the file cannot be read into a UTF-8 string, or if
/// [`parse_source`] fails on its contents.
pub fn parse(path: &Path) -> Result<ParseResult> {
    std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read file: {}", path.display()))
        .and_then(|contents| parse_source(&contents, path))
}
/// Parse C++ source text directly, without touching the filesystem (handy in tests).
///
/// `path` is used only to label the extracted entities. The passes run in a fixed
/// order: functions, classes, includes, then calls. Calls come last because they are
/// attributed to the functions collected by the earlier passes.
///
/// # Errors
///
/// Fails if the C++ grammar cannot be loaded, if tree-sitter returns no tree, or if
/// any extraction pass fails.
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_cpp::LANGUAGE.into())
        .context("Failed to set C++ language")?;
    let tree = parser
        .parse(source, None)
        .context("Failed to parse C++ source")?;

    let root = tree.root_node();
    let bytes = source.as_bytes();
    let mut parsed = ParseResult::default();

    extract_functions(&root, bytes, path, &mut parsed)?;
    extract_classes(&root, bytes, path, &mut parsed)?;
    extract_includes(&root, bytes, &mut parsed)?;
    extract_calls(&root, bytes, path, &mut parsed)?;

    Ok(parsed)
}
/// Populate `result.functions` with the function definitions found under `root`.
///
/// The walk starts outside any namespace. It never fails; the `Result` return type
/// only keeps this pass's signature uniform with the other `extract_*` passes.
fn extract_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let enclosing_namespace: Option<&str> = None;
    extract_functions_recursive(root, source, path, result, enclosing_namespace);
    Ok(())
}
/// Recursively collect function definitions beneath `node` into `result.functions`.
///
/// `namespace` is the `::`-joined path of the namespaces enclosing `node`, or `None`
/// at file scope. The walk handles:
/// - plain and out-of-line (`Class::method`) definitions, via `parse_function_node`;
/// - namespace bodies, which extend the namespace path (an anonymous namespace adds
///   no segment);
/// - function templates that are direct children of a `template_declaration`.
///
/// Every other node is descended into. A function definition written directly inside
/// the body of a class/struct that `extract_classes` records is skipped here, because
/// `extract_class_members` already emits it under a class-qualified name. Keeping it
/// here as well would list each such method twice.
fn extract_functions_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    namespace: Option<&str>,
) {
    for child in node.children(&mut node.walk()) {
        match child.kind() {
            "function_definition" => {
                if is_extracted_as_class_member(&child) {
                    continue;
                }
                if let Some(func) = parse_function_node(&child, source, path, namespace) {
                    result.functions.push(func);
                }
            }
            "namespace_definition" => {
                let ns_name = child
                    .child_by_field_name("name")
                    .and_then(|n| n.utf8_text(source).ok())
                    .unwrap_or("");
                let full_ns = if ns_name.is_empty() {
                    // `namespace { ... }` contributes no path segment. Formatting it would
                    // produce names with an empty component such as `::helper` or `outer::`.
                    namespace.map(str::to_string)
                } else if let Some(parent_ns) = namespace {
                    Some(format!("{}::{}", parent_ns, ns_name))
                } else {
                    Some(ns_name.to_string())
                };
                if let Some(body) = child.child_by_field_name("body") {
                    extract_functions_recursive(&body, source, path, result, full_ns.as_deref());
                }
            }
            "template_declaration" => {
                // Templated functions: only direct `function_definition` children are handled.
                for grandchild in child.children(&mut child.walk()) {
                    if grandchild.kind() == "function_definition" {
                        if let Some(func) =
                            parse_function_node(&grandchild, source, path, namespace)
                        {
                            result.functions.push(func);
                        }
                    }
                }
            }
            _ => extract_functions_recursive(&child, source, path, result, namespace),
        }
    }
}

/// Returns `true` when `func` is a function definition that `extract_class_members`
/// already turns into a `Function`.
///
/// That is the case when `func` sits directly in the body (`field_declaration_list`) of
/// a named class/struct that `extract_classes_recursive` reaches. That walk never
/// descends into a class/struct specifier, so nested classes are not member-extracted.
/// It also only inspects class/struct nodes that are *direct* children of a
/// `template_declaration`, never anything deeper inside one. This check mirrors both
/// of those rules.
fn is_extracted_as_class_member(func: &Node) -> bool {
    let Some(body) = func.parent() else {
        return false;
    };
    if body.kind() != "field_declaration_list" {
        return false;
    }
    let Some(class) = body.parent() else {
        return false;
    };
    if !matches!(class.kind(), "class_specifier" | "struct_specifier")
        || class.child_by_field_name("name").is_none()
    {
        return false;
    }
    let mut is_direct_parent = true;
    let mut ancestor = class.parent();
    while let Some(current) = ancestor {
        match current.kind() {
            "class_specifier" | "struct_specifier" => return false,
            "template_declaration" if !is_direct_parent => return false,
            _ => {}
        }
        is_direct_parent = false;
        ancestor = current.parent();
    }
    true
}
/// Build a `Function` from a `function_definition` node reached outside class bodies.
///
/// `namespace` is the enclosing namespace path; an empty string is treated as "no
/// namespace".
///
/// Out-of-line member definitions whose declarator is qualified (`Class::method`) are
/// delegated to `parse_method_definition`. The enclosing namespace is prefixed onto the
/// class scope, so `namespace ns { void Foo::bar() {} }` yields `ns::Foo::bar`. This
/// matches the `ns::Foo` names that `parse_class_node` gives classes declared in that
/// namespace.
///
/// Every other definition gets the qualified name `<path>::[<namespace>::]<name>:<line_start>`.
///
/// Returns `None` when the node has no declarator or no recognisable name.
fn parse_function_node(node: &Node, source: &[u8], path: &Path, namespace: Option<&str>) -> Option<Function> {
    let namespace = namespace.filter(|ns| !ns.is_empty());
    let declarator = node.child_by_field_name("declarator")?;
    let (name, class_scope) = extract_function_name(&declarator, source)?;
    // Qualified declarators are member definitions written outside their class body.
    if let Some(scope) = class_scope {
        // A scope starting with `::` is already anchored at the global namespace.
        let scope = match namespace {
            Some(ns) if !scope.starts_with("::") => format!("{}::{}", ns, scope),
            _ => scope,
        };
        return parse_method_definition(node, source, path, &name, Some(scope.as_str()));
    }
    let params_node = find_parameters(&declarator);
    let parameters = extract_parameters(params_node, source);
    let return_type = node
        .child_by_field_name("type")
        .and_then(|n| n.utf8_text(source).ok())
        .map(|s| s.to_string());
    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let full_name = match namespace {
        Some(ns) => format!("{}::{}", ns, name),
        None => name.clone(),
    };
    let qualified_name = format!("{}::{}:{}", path.display(), full_name, line_start);
    Some(Function {
        name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        parameters,
        return_type,
        is_async: false,
        complexity: Some(calculate_complexity(node, source)),
    })
}
/// Build a `Function` for a member function defined outside its class body.
///
/// `name` is the unqualified method name and `class_scope` the scope written before
/// it, if any. The qualified name is `<path>::[<class_scope>::]<name>:<line_start>`.
/// Returns `None` only when the definition has no declarator.
fn parse_method_definition(
    node: &Node,
    source: &[u8],
    path: &Path,
    name: &str,
    class_scope: Option<&str>,
) -> Option<Function> {
    let declarator = node.child_by_field_name("declarator")?;
    let scoped_name = match class_scope {
        Some(class) => format!("{}::{}", class, name),
        None => name.to_string(),
    };
    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;
    let return_type = node
        .child_by_field_name("type")
        .and_then(|ty| ty.utf8_text(source).ok())
        .map(String::from);

    Some(Function {
        name: name.to_string(),
        qualified_name: format!("{}::{}:{}", path.display(), scoped_name, first_line),
        file_path: path.to_path_buf(),
        line_start: first_line,
        line_end: last_line,
        parameters: extract_parameters(find_parameters(&declarator), source),
        return_type,
        is_async: false,
        complexity: Some(calculate_complexity(node, source)),
    })
}
/// Extract function name and optional class scope from declarator
fn extract_function_name(declarator: &Node, source: &[u8]) -> Option<(String, Option<String>)> {
match declarator.kind() {
"function_declarator" => {
let inner = declarator.child_by_field_name("declarator")?;
extract_function_name(&inner, source)
}
"qualified_identifier" => {
// Class::method
let scope_text = declarator.utf8_text(source).ok()?;
let parts: Vec<&str> = scope_text.rsplitn(2, "::").collect();
if parts.len() == 2 {
Some((parts[0].to_string(), Some(parts[1].to_string())))
} else {
Some((scope_text.to_string(), None))
}
}
"identifier" => {
let name = declarator.utf8_text(source).ok()?.to_string();
Some((name, None))
}
"pointer_declarator" | "reference_declarator" => {
for child in declarator.children(&mut declarator.walk()) {
if let Some(result) = extract_function_name(&child, source) {
return Some(result);
}
}
None
}
"destructor_name" => {
let name = declarator.utf8_text(source).ok()?.to_string();
Some((name, None))
}
"template_function" => {
if let Some(name_node) = declarator.child_by_field_name("name") {
let name = name_node.utf8_text(source).ok()?.to_string();
Some((name, None))
} else {
None
}
}
"operator_name" => {
let name = declarator.utf8_text(source).ok()?.to_string();
Some((name, None))
}
_ => None,
}
}
/// Locate the parameter list of the function named by `declarator`.
///
/// Descends through wrapper declarators until it reaches a `function_declarator`.
/// That declarator's own `declarator` field is searched first. A nested function
/// declarator there means the outer one describes a returned function-pointer type, as
/// in `int (*get_handler(int id))(double)`. In that case the inner (innermost) list
/// `(int id)` is the function's real parameter list.
///
/// The explicit `'tree` lifetime ties the returned node to the syntax tree. Without it,
/// the elided signature has two input lifetimes and the result's lifetime cannot be
/// inferred.
fn find_parameters<'tree>(declarator: &Node<'tree>) -> Option<Node<'tree>> {
    if declarator.kind() == "function_declarator" {
        let nested = declarator
            .child_by_field_name("declarator")
            .and_then(|inner| find_parameters(&inner));
        return nested.or_else(|| declarator.child_by_field_name("parameters"));
    }
    let mut cursor = declarator.walk();
    let found = declarator
        .children(&mut cursor)
        .find_map(|child| find_parameters(&child));
    found
}
/// Collect parameter names from a `parameter_list` node.
///
/// Plain and defaulted parameters contribute their declared name (unnamed ones are
/// skipped). A variadic parameter declaration contributes the literal `"..."`. A
/// missing list yields an empty vector.
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
    let Some(list) = params_node else {
        return Vec::new();
    };
    let mut cursor = list.walk();
    let names: Vec<String> = list
        .children(&mut cursor)
        .filter_map(|param| match param.kind() {
            "parameter_declaration" | "optional_parameter_declaration" => {
                find_parameter_name(&param, source)
            }
            "variadic_parameter_declaration" => Some("...".to_string()),
            _ => None,
        })
        .collect();
    names
}
/// Return the declared name of a single parameter declaration, if it has one.
///
/// Prefers the node's `declarator` field (resolved by `extract_declarator_name`).
/// Without one, it falls back to the first direct `identifier` child.
fn find_parameter_name(param_node: &Node, source: &[u8]) -> Option<String> {
    match param_node.child_by_field_name("declarator") {
        Some(decl) => extract_declarator_name(&decl, source),
        None => {
            let mut cursor = param_node.walk();
            let first_ident = param_node
                .children(&mut cursor)
                .find(|child| child.kind() == "identifier");
            first_ident.and_then(|ident| ident.utf8_text(source).ok().map(str::to_owned))
        }
    }
}
/// Extract the declared name from a (parameter) declarator node.
///
/// Unwraps pointer, reference, array and parenthesized declarators, and follows the
/// `declarator` field of a `function_declarator`. This makes function-pointer
/// parameters such as `void (*cb)(int)` resolve to `cb`; previously they were dropped.
/// Returns `None` for unnamed or unrecognised declarators.
fn extract_declarator_name(node: &Node, source: &[u8]) -> Option<String> {
    match node.kind() {
        "identifier" => node.utf8_text(source).ok().map(|s| s.to_string()),
        // The name lives in the nested declarator, never in the parameter list.
        "function_declarator" => {
            let inner = node.child_by_field_name("declarator")?;
            extract_declarator_name(&inner, source)
        }
        "pointer_declarator" | "reference_declarator" | "array_declarator"
        | "parenthesized_declarator" => {
            let mut cursor = node.walk();
            let name = node
                .children(&mut cursor)
                .find_map(|child| extract_declarator_name(&child, source));
            name
        }
        _ => None,
    }
}
/// Populate `result.classes` (and their in-body methods in `result.functions`) from
/// the class and struct definitions under `root`.
///
/// The walk starts outside any namespace. It never fails; the `Result` return type
/// only matches the other `extract_*` passes.
fn extract_classes(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let enclosing_namespace: Option<&str> = None;
    extract_classes_recursive(root, source, path, result, enclosing_namespace);
    Ok(())
}
/// Recursively record every named class/struct specifier beneath `node`.
///
/// `namespace` is the `::`-joined namespace path enclosing `node`, or `None` at file
/// scope. The walk:
/// - records class/struct specifiers without descending into their bodies;
/// - extends the namespace path for namespace bodies (an anonymous namespace adds no
///   segment);
/// - for a `template_declaration`, inspects only its direct class/struct children;
/// - descends into every other node.
fn extract_classes_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    namespace: Option<&str>,
) {
    for child in node.children(&mut node.walk()) {
        match child.kind() {
            "class_specifier" | "struct_specifier" => {
                record_class(&child, source, path, result, namespace);
            }
            "namespace_definition" => {
                let ns_name = child
                    .child_by_field_name("name")
                    .and_then(|n| n.utf8_text(source).ok())
                    .unwrap_or("");
                let full_ns = if ns_name.is_empty() {
                    // `namespace { ... }` contributes no path segment. Formatting it would
                    // produce class names like `::Foo` or `outer::::Foo`.
                    namespace.map(str::to_string)
                } else if let Some(parent_ns) = namespace {
                    Some(format!("{}::{}", parent_ns, ns_name))
                } else {
                    Some(ns_name.to_string())
                };
                if let Some(body) = child.child_by_field_name("body") {
                    extract_classes_recursive(&body, source, path, result, full_ns.as_deref());
                }
            }
            "template_declaration" => {
                // Class templates: only direct class/struct children are handled.
                for grandchild in child.children(&mut child.walk()) {
                    if matches!(grandchild.kind(), "class_specifier" | "struct_specifier") {
                        record_class(&grandchild, source, path, result, namespace);
                    }
                }
            }
            _ => extract_classes_recursive(&child, source, path, result, namespace),
        }
    }
}

/// Parse one class/struct specifier. If it has a name, append the methods defined in
/// its body to `result.functions`, then append the class itself to `result.classes`.
fn record_class(
    class_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    namespace: Option<&str>,
) {
    let is_struct = class_node.kind() == "struct_specifier";
    if let Some(class) = parse_class_node(class_node, source, path, namespace, is_struct) {
        extract_class_members(class_node, source, path, result, &class.name);
        result.classes.push(class);
    }
}
/// Build a `Class` from a `class_specifier` / `struct_specifier` node.
///
/// The class name is prefixed with `namespace` when one is given. The qualified name
/// is `<path>::<class|struct>::<name>:<line_start>`, using `is_struct` to pick the
/// keyword. Base classes and method names are read from the node itself. Returns
/// `None` for anonymous (unnamed) classes.
fn parse_class_node(
    node: &Node,
    source: &[u8],
    path: &Path,
    namespace: Option<&str>,
    is_struct: bool,
) -> Option<Class> {
    let short_name = node.child_by_field_name("name")?.utf8_text(source).ok()?;
    let display_name = match namespace {
        Some(ns) => format!("{}::{}", ns, short_name),
        None => short_name.to_string(),
    };
    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let keyword = if is_struct { "struct" } else { "class" };
    let qualified_name = format!(
        "{}::{}::{}:{}",
        path.display(),
        keyword,
        display_name,
        line_start
    );

    Some(Class {
        name: display_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods: extract_method_names(node, source),
        bases: extract_base_classes(node, source),
    })
}
/// Collect the base-class names listed in a class's `base_class_clause`.
///
/// tree-sitter-cpp places the base type nodes (`type_identifier`, `template_type`,
/// `qualified_identifier`, ...) directly inside `base_class_clause`, next to any
/// `access_specifier` nodes. Only looking for `base_specifier` children therefore
/// found nothing. Both shapes are accepted here: a `base_specifier` contributes its
/// `type` field, and a bare type node contributes its own text. Access specifiers,
/// `virtual` and punctuation are ignored.
fn extract_base_classes(class_node: &Node, source: &[u8]) -> Vec<String> {
    let mut bases = Vec::new();
    let mut cursor = class_node.walk();
    for clause in class_node
        .children(&mut cursor)
        .filter(|child| child.kind() == "base_class_clause")
    {
        let mut clause_cursor = clause.walk();
        for entry in clause.named_children(&mut clause_cursor) {
            let type_node = match entry.kind() {
                "base_specifier" => entry.child_by_field_name("type"),
                "type_identifier" | "template_type" | "qualified_identifier"
                | "qualified_type_identifier" => Some(entry),
                _ => None,
            };
            if let Some(text) = type_node.and_then(|n| n.utf8_text(source).ok()) {
                bases.push(text.to_string());
            }
        }
    }
    bases
}
/// List the names of methods declared or defined directly in a class/struct body.
///
/// - `function_definition` and `declaration` members contribute the name from their
///   first declarator.
/// - Each `field_declaration` contributes every declarator that declares a function.
///   This includes pointer- or reference-returning declarations such as `int* get();`,
///   whose `function_declarator` sits under a `pointer_declarator` rather than directly
///   under the field declaration.
/// - Data members, including function-pointer members like `int (*cb)(int);`, are
///   not reported.
fn extract_method_names(class_node: &Node, source: &[u8]) -> Vec<String> {
    /// Whether `decl` declares a function. After unwrapping pointer, reference and
    /// parenthesized declarators, the node directly wrapping the declared name must be
    /// a `function_declarator`.
    fn declares_function(decl: Node) -> bool {
        let mut current = decl;
        let mut wrapped_by_function = false;
        loop {
            let inner = match current.kind() {
                "function_declarator" => {
                    wrapped_by_function = true;
                    current.child_by_field_name("declarator")
                }
                "pointer_declarator" | "reference_declarator" | "parenthesized_declarator" => {
                    wrapped_by_function = false;
                    // Some wrappers have no `declarator` field; their inner declarator
                    // is then the last named child.
                    current.child_by_field_name("declarator").or_else(|| {
                        let mut cursor = current.walk();
                        let last = current.named_children(&mut cursor).last();
                        last
                    })
                }
                _ => return wrapped_by_function,
            };
            match inner {
                Some(next) => current = next,
                None => return false,
            }
        }
    }

    let Some(body) = class_node.child_by_field_name("body") else {
        return Vec::new();
    };
    let mut methods = Vec::new();
    let mut cursor = body.walk();
    for member in body.children(&mut cursor) {
        match member.kind() {
            "function_definition" | "declaration" => {
                if let Some((name, _)) = member
                    .child_by_field_name("declarator")
                    .and_then(|decl| extract_function_name(&decl, source))
                {
                    methods.push(name);
                }
            }
            "field_declaration" => {
                let mut decl_cursor = member.walk();
                for decl in member.children_by_field_name("declarator", &mut decl_cursor) {
                    if !declares_function(decl) {
                        continue;
                    }
                    if let Some((name, _)) = extract_function_name(&decl, source) {
                        methods.push(name);
                    }
                }
            }
            _ => {}
        }
    }
    methods
}
/// Append a `Function` for every method defined inline (with a body) directly in the
/// body of `class_node`, qualifying each with `class_name`.
///
/// Does nothing when the class has no body.
fn extract_class_members(
    class_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    class_name: &str,
) {
    let Some(body) = class_node.child_by_field_name("body") else {
        return;
    };
    let mut cursor = body.walk();
    let inline_methods = body
        .children(&mut cursor)
        .filter(|member| member.kind() == "function_definition")
        .filter_map(|member| parse_class_method(&member, source, path, class_name));
    result.functions.extend(inline_methods);
}
/// Build a `Function` for a method defined inline inside the body of `class_name`.
///
/// The qualified name has the form `<path>::<class_name>::<name>:<line_start>`.
/// Returns `None` if the definition has no declarator or no recognisable name.
fn parse_class_method(node: &Node, source: &[u8], path: &Path, class_name: &str) -> Option<Function> {
    let declarator = node.child_by_field_name("declarator")?;
    let (name, _scope) = extract_function_name(&declarator, source)?;
    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;
    let return_type = node
        .child_by_field_name("type")
        .and_then(|ty| ty.utf8_text(source).ok())
        .map(String::from);

    Some(Function {
        qualified_name: format!("{}::{}::{}:{}", path.display(), class_name, name, first_line),
        name,
        file_path: path.to_path_buf(),
        line_start: first_line,
        line_end: last_line,
        parameters: extract_parameters(find_parameters(&declarator), source),
        return_type,
        is_async: false,
        complexity: Some(calculate_complexity(node, source)),
    })
}
/// Collect the targets of `#include` directives into `result.imports`.
///
/// Runs a tree-sitter query over `preproc_include` paths and strips leading `"`/`<`
/// and trailing `"`/`>`, so both `#include <vector>` and `#include "a.h"` yield the
/// bare path. Captures that end up empty are skipped.
///
/// # Errors
///
/// Fails only if the include query cannot be compiled.
fn extract_includes(root: &Node, source: &[u8], result: &mut ParseResult) -> Result<()> {
    let query_str = r#"
(preproc_include
path: (_) @include_path
)
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create include query")?;
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);

    while let Some(found) = matches.next() {
        let raw_paths = found
            .captures
            .iter()
            .filter_map(|capture| capture.node.utf8_text(source).ok());
        for raw in raw_paths {
            let include = raw
                .trim_start_matches(|c: char| c == '"' || c == '<')
                .trim_end_matches(|c: char| c == '"' || c == '>');
            if !include.is_empty() {
                result.imports.push(include.to_owned());
            }
        }
    }
    Ok(())
}
/// Record caller → callee pairs for every call site under `root`.
///
/// The function list already in `result.functions` is first indexed by its
/// `(line_start, line_end)` span. When two functions share a span, the later one in
/// the list wins. Call sites are then attributed to those spans by
/// `extract_calls_recursive`. Never fails.
fn extract_calls(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let scope_map: HashMap<(u32, u32), String> = result
        .functions
        .iter()
        .map(|func| ((func.line_start, func.line_end), func.qualified_name.clone()))
        .collect();
    extract_calls_recursive(root, source, path, &scope_map, result);
    Ok(())
}
/// Walk the subtree rooted at `node`, appending `(caller, callee)` pairs to
/// `result.calls`.
///
/// - A `call_expression` contributes its callee as described by `extract_call_target`.
/// - A `new_expression` contributes `"new <Type>"`.
///
/// The caller is the scope that `find_containing_scope` picks for the expression's
/// first line. Expressions outside every known function are ignored. `path` is
/// currently unused.
fn extract_calls_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    scope_map: &HashMap<(u32, u32), String>,
    result: &mut ParseResult,
) {
    let line = node.start_position().row as u32 + 1;
    let callee = match node.kind() {
        "call_expression" => node
            .child_by_field_name("function")
            .and_then(|target| extract_call_target(&target, source)),
        "new_expression" => node
            .child_by_field_name("type")
            .and_then(|ty| ty.utf8_text(source).ok())
            .map(|ty| format!("new {}", ty)),
        _ => None,
    };
    if let Some(callee) = callee {
        if let Some(caller) = find_containing_scope(line, scope_map) {
            result.calls.push((caller, callee));
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(&child, source, path, scope_map, result);
    }
}
/// Find the innermost function scope whose line span contains `line`.
///
/// `scope_map` maps inclusive `(line_start, line_end)` spans to qualified names, and
/// the narrowest containing span wins. Two spans can be equally narrow, for example
/// when one function ends on the line where another begins. In that case the span
/// that starts later is chosen. Previously the winner depended on `HashMap` iteration
/// order, which is randomised per map.
///
/// Returns `None` when no span contains `line`.
fn find_containing_scope(line: u32, scope_map: &HashMap<(u32, u32), String>) -> Option<String> {
    scope_map
        .iter()
        .filter(|((start, end), _)| (*start..=*end).contains(&line))
        // `end >= start` holds here because the filter guarantees `start <= line <= end`.
        .min_by_key(|((start, end), _)| (end - start, std::cmp::Reverse(*start)))
        .map(|(_, name)| name.clone())
}
/// Describe the callee of a call expression as text.
///
/// A `template_function` callee such as `make<int>` is reported by its `name` field
/// alone (`make`), or `None` if it has none. Every other callee expression
/// (identifiers, `ns::f`, `obj.method`, ...) is reported verbatim.
fn extract_call_target(node: &Node, source: &[u8]) -> Option<String> {
    let target = if node.kind() == "template_function" {
        node.child_by_field_name("name")?
    } else {
        *node
    };
    target.utf8_text(source).ok().map(ToString::to_string)
}
/// Approximate the cyclomatic complexity of the subtree rooted at `node`.
///
/// Starts at 1 and adds 1 for each node of these kinds:
/// - branching and looping statements: `if`, `while`, `for`, `do`, range-`for`;
/// - `case_statement` / `default_statement` and `catch_clause`;
/// - `conditional_expression` (`?:`) and `lambda_expression`.
///
/// It also adds 1 for each short-circuit operator token inside a
/// `binary_expression`. Besides `&&`/`||`, C++ allows the alternative spellings
/// `and`/`or`, which were previously not counted. Matching those token kinds is
/// harmless if the grammar never produces them. `_source` is unused.
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    /// Add one to `complexity` for every decision point in the subtree at `node`.
    fn count_branches(node: &Node, complexity: &mut u32) {
        match node.kind() {
            "if_statement" | "while_statement" | "for_statement" | "do_statement"
            | "for_range_loop" | "case_statement" | "default_statement" | "catch_clause"
            | "conditional_expression" | "lambda_expression" => *complexity += 1,
            "binary_expression" => {
                let mut cursor = node.walk();
                let short_circuits = node
                    .children(&mut cursor)
                    .filter(|op| matches!(op.kind(), "&&" | "||" | "and" | "or"))
                    .count();
                *complexity += short_circuits as u32;
            }
            _ => {}
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            count_branches(&child, complexity);
        }
    }

    let mut complexity = 1;
    count_branches(node, &mut complexity);
    complexity
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parse `source` as though it had been read from `test.cpp`.
    fn parse_fixture(source: &str) -> ParseResult {
        let fixture_path = PathBuf::from("test.cpp");
        parse_source(source, &fixture_path).unwrap()
    }

    #[test]
    fn test_parse_simple_function() {
        let result = parse_fixture(r#"
int add(int a, int b) {
return a + b;
}
"#);
        assert_eq!(result.functions.len(), 1);
        assert_eq!(result.functions[0].name, "add");
    }

    #[test]
    fn test_parse_class() {
        let result = parse_fixture(r#"
class MyClass {
public:
void doSomething() {
// implementation
}
};
"#);
        assert_eq!(result.classes.len(), 1);
        assert_eq!(result.classes[0].name, "MyClass");
    }

    #[test]
    fn test_parse_class_with_inheritance() {
        let result = parse_fixture(r#"
class Derived : public Base, public Interface {
public:
void method() {}
};
"#);
        assert_eq!(result.classes.len(), 1);
        assert!(result.classes[0].bases.iter().any(|base| base == "Base"));
    }

    #[test]
    fn test_parse_namespace() {
        let result = parse_fixture(r#"
namespace myns {
void helper() {}
}
"#);
        assert_eq!(result.functions.len(), 1);
        assert!(result.functions[0].qualified_name.contains("myns"));
    }

    #[test]
    fn test_parse_includes() {
        let result = parse_fixture(r#"
#include <iostream>
#include <vector>
#include "myheader.h"
int main() {
return 0;
}
"#);
        for expected in ["iostream", "vector", "myheader.h"] {
            assert!(result.imports.iter().any(|import| import == expected));
        }
    }

    #[test]
    fn test_parse_method_definition() {
        let result = parse_fixture(r#"
class MyClass {
public:
void method();
};
void MyClass::method() {
// implementation
}
"#);
        // The out-of-class definition must be reported as a function.
        assert!(result.functions.iter().any(|func| func.name == "method"));
    }

    #[test]
    fn test_parse_template_function() {
        let result = parse_fixture(r#"
template<typename T>
T max(T a, T b) {
return a > b ? a : b;
}
"#);
        assert_eq!(result.functions.len(), 1);
        assert_eq!(result.functions[0].name, "max");
    }
}