use crate::models::{Class, Function};
use crate::parsers::{ImportInfo, ParseResult};
use anyhow::{Context, Result};
use std::cell::RefCell;
use std::path::Path;
use std::sync::OnceLock;
use tree_sitter::{Node, Parser, Query, QueryCursor, StreamingIterator};
thread_local! {
    /// Per-thread tree-sitter parser configured for C++.
    ///
    /// Parsing needs `&mut Parser`, so each thread owns its own instance behind a
    /// `RefCell` rather than sharing one.
    static CPP_PARSER: RefCell<Parser> = {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_cpp::LANGUAGE.into())
            .expect("C++ language");
        RefCell::new(parser)
    };
}
/// Tree-sitter query for function definitions (used by `extract_functions`).
///
/// Captures the whole definition (`@func`), its return type (`@return_type`), the name
/// part of the declarator (`@func_name`, which can be a qualified name such as
/// `Foo::bar`) and the parameter list (`@params`). Only definitions that have a `type`
/// field and whose declarator is directly a `function_declarator` match.
/// NOTE(review): this presumably excludes constructors/destructors (no return type) and
/// functions returning pointers/references (declarator wrapped in a pointer/reference
/// declarator) — confirm against the tree-sitter-cpp grammar.
const CPP_FUNC_QUERY_STR: &str = r#"
(function_definition
type: (_) @return_type
declarator: (function_declarator
declarator: (_) @func_name
parameters: (parameter_list) @params
)
) @func
"#;
/// Tree-sitter query for `class` definitions (used by `extract_classes`).
///
/// Both the name (a plain `type_identifier`) and the member list body are required, so
/// bodiless specifiers such as forward declarations do not match.
const CPP_CLASS_QUERY_STR: &str = r#"
(class_specifier
name: (type_identifier) @class_name
body: (field_declaration_list) @body
) @class
"#;
/// Tree-sitter query for member function definitions (used by `extract_class_methods`).
///
/// Same pattern as `CPP_FUNC_QUERY_STR`, with method-oriented capture names. It is run
/// against a single class/struct body node, so it matches every qualifying definition in
/// that body's subtree.
const CPP_METHOD_QUERY_STR: &str = r#"
(function_definition
type: (_) @return_type
declarator: (function_declarator
declarator: (_) @method_name
parameters: (parameter_list) @params
)
) @method
"#;
/// Tree-sitter query for `struct` specifiers (used by `extract_structs`).
///
/// Unlike the class query, the `@body` capture is optional (`?`): name-only struct
/// specifiers also match, so users of this query must handle a missing `@body`.
const CPP_STRUCT_QUERY_STR: &str = r#"
(struct_specifier
name: (type_identifier) @struct_name
body: (field_declaration_list)? @body
) @struct
"#;
/// Tree-sitter query for `#include` directives (used by `extract_includes`).
///
/// Captures either a quoted path (`@path`, `#include "x.h"`) or a system path
/// (`@system_path`, `#include <x>`). The captured text still contains the delimiters;
/// `extract_includes` trims them.
const CPP_INCLUDE_QUERY_STR: &str = r#"
(preproc_include
path: [
(string_literal) @path
(system_lib_string) @system_path
]
)
"#;
/// Tree-sitter query for call expressions (used by `extract_calls`).
///
/// Matches three callee shapes: a plain identifier (`@func_name`), a member access
/// where only the member name is captured (`@method_name`), and a qualified identifier
/// (`@qualified_name`). Calls whose callee has any other shape do not match.
const CPP_CALL_QUERY_STR: &str = r#"
(call_expression
function: [
(identifier) @func_name
(field_expression
field: (field_identifier) @method_name
)
(qualified_identifier) @qualified_name
]
) @call
"#;
// Compiled queries, one per `*_QUERY_STR` constant. Each is compiled on first use by
// the corresponding `extract_*` pass (`OnceLock::get_or_init`) and then reused for the
// rest of the process.
static CPP_FUNC_QUERY: OnceLock<Query> = OnceLock::new();
static CPP_CLASS_QUERY: OnceLock<Query> = OnceLock::new();
static CPP_METHOD_QUERY: OnceLock<Query> = OnceLock::new();
static CPP_STRUCT_QUERY: OnceLock<Query> = OnceLock::new();
static CPP_INCLUDE_QUERY: OnceLock<Query> = OnceLock::new();
static CPP_CALL_QUERY: OnceLock<Query> = OnceLock::new();
#[allow(dead_code)]
pub fn parse(path: &Path) -> Result<ParseResult> {
let source = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read file: {}", path.display()))?;
parse_source(&source, path)
}
/// Extracts symbols from in-memory C++ `source`, discarding the syntax tree.
///
/// `path` is used only as metadata on the extracted items (qualified names and
/// `file_path`); nothing is read from disk.
///
/// # Errors
/// Same as [`parse_source_with_tree`].
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let (result, _tree) = parse_source_with_tree(source, path)?;
    Ok(result)
}
/// Parses C++ `source` (originating from `path`) and returns the extraction result
/// together with the tree-sitter syntax tree it was computed from.
///
/// Parsing reuses this thread's cached C++ parser. `path` is used only as metadata on
/// the extracted items (qualified names and `file_path`).
///
/// # Errors
/// Fails if tree-sitter produces no tree, or if an extraction pass fails.
pub fn parse_source_with_tree(
    source: &str,
    path: &Path,
) -> Result<(ParseResult, tree_sitter::Tree)> {
    let parsed = CPP_PARSER.with(|parser| parser.borrow_mut().parse(source, None));
    let tree = parsed.context("Failed to parse C++ source")?;
    let result = CppExtractor::new(source.as_bytes(), path).run(tree.root_node())?;
    Ok((result, tree))
}
/// Extraction state for one parsed C++ file: the passes in `run` walk the syntax tree
/// and append what they find to `result`.
struct CppExtractor<'a> {
    /// Bytes of the source text the tree was parsed from; used to read node text.
    source: &'a [u8],
    /// Path of the file being processed; prefixes every qualified name.
    path: &'a Path,
    /// Accumulated output, returned by `run`.
    result: ParseResult,
}
impl<'a> CppExtractor<'a> {
fn new(source: &'a [u8], path: &'a Path) -> Self {
Self {
source,
path,
result: ParseResult::default(),
}
}
fn run(mut self, root: Node<'_>) -> Result<ParseResult> {
self.extract_functions(&root)?;
self.extract_classes(&root)?;
self.extract_structs(&root)?;
self.extract_includes(&root)?;
self.extract_calls(&root)?;
Ok(self.result)
}
}
/// Returns `true` if `func_node` has a direct `storage_class_specifier` child whose text
/// is exactly `specifier` (for example `"extern"`).
fn has_storage_class(func_node: &Node, source: &[u8], specifier: &str) -> bool {
    func_node
        .children(&mut func_node.walk())
        .filter(|child| child.kind() == "storage_class_specifier")
        .any(|child| matches!(child.utf8_text(source), Ok(text) if text == specifier))
}
fn is_inside_class_body(node: &Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
if parent.kind() == "field_declaration_list" {
return true;
}
current = parent.parent();
}
false
}
impl<'a> CppExtractor<'a> {
    /// Records every function definition that is not nested in a member list as a
    /// [`Function`].
    ///
    /// Definitions anywhere inside a `field_declaration_list` are skipped here. The class
    /// and struct passes extract those as methods (for named classes/structs).
    ///
    /// Out-of-line member definitions such as `int Foo::bar() { ... }` are *kept*. Their
    /// name is the declarator text as written (`Foo::bar`), and their qualified name
    /// `<path>::<name>:<line>` has the same shape as the caller ids built by
    /// `find_enclosing_function`, so call edges can be joined back to them.
    ///
    /// Functions with the `extern` storage class are annotated `"exported"`.
    ///
    /// # Errors
    /// Never fails at present; returns `Result` to fit `run`'s `?` chain.
    fn extract_functions(&mut self, root: &Node) -> Result<()> {
        let query = CPP_FUNC_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_FUNC_QUERY_STR)
                .expect("valid C++ function query")
        });
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(query, *root, self.source);
        while let Some(m) = matches.next() {
            let mut func_node = None;
            let mut name = String::new();
            let mut params_node = None;
            let mut return_type_node = None;
            for capture in m.captures.iter() {
                match query.capture_names()[capture.index as usize] {
                    "func" => func_node = Some(capture.node),
                    "func_name" => {
                        // Defensive strip of a leading `*`. Qualified names (`Foo::bar`)
                        // are intentionally kept as-is (see the doc comment above).
                        let name_text = capture.node.utf8_text(self.source).unwrap_or("");
                        name = name_text.trim_start_matches('*').to_string();
                    }
                    "params" => params_node = Some(capture.node),
                    "return_type" => return_type_node = Some(capture.node),
                    _ => {}
                }
            }
            if name.is_empty() {
                continue;
            }
            let Some(node) = func_node else {
                continue;
            };
            // Member functions are handled by the class/struct passes.
            if is_inside_class_body(&node) {
                continue;
            }
            let parameters = extract_parameters(params_node, self.source);
            let return_type =
                return_type_node.map(|n| n.utf8_text(self.source).unwrap_or("").to_string());
            let line_start = node.start_position().row as u32 + 1;
            let line_end = node.end_position().row as u32 + 1;
            let qualified_name = format!("{}::{}:{}", self.path.display(), name, line_start);
            let annotations = if has_storage_class(&node, self.source, "extern") {
                vec!["exported".to_string()]
            } else {
                vec![]
            };
            self.result.functions.push(Function {
                name,
                qualified_name,
                file_path: self.path.to_path_buf(),
                line_start,
                line_end,
                parameters,
                return_type,
                is_async: false,
                complexity: Some(calculate_complexity(&node, self.source)),
                max_nesting: None,
                doc_comment: None,
                annotations,
            });
        }
        Ok(())
    }
}
impl<'a> CppExtractor<'a> {
    /// Records every `class` definition as a [`Class`] and each of its member function
    /// definitions as a [`Function`].
    ///
    /// The class query only matches named classes with a body. Members default to
    /// `private` access. The class's `methods` list holds the method names in extraction
    /// order. `field_count` is not computed and is always `0`.
    ///
    /// # Errors
    /// Propagates errors from `extract_class_methods`.
    fn extract_classes(&mut self, root: &Node) -> Result<()> {
        let query = CPP_CLASS_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_CLASS_QUERY_STR)
                .expect("valid C++ class query")
        });
        let capture_names = query.capture_names();
        let mut query_cursor = QueryCursor::new();
        let mut found = query_cursor.matches(query, *root, self.source);
        while let Some(hit) = found.next() {
            let mut class_node = None;
            let mut body_node = None;
            let mut class_name = String::new();
            for cap in hit.captures {
                match capture_names[cap.index as usize] {
                    "class" => class_node = Some(cap.node),
                    "class_name" => {
                        class_name = cap.node.utf8_text(self.source).unwrap_or("").to_string();
                    }
                    "body" => body_node = Some(cap.node),
                    _ => {}
                }
            }
            let Some(class_node) = class_node else {
                continue;
            };
            let methods = match body_node {
                Some(body) => self.extract_class_methods(&body, &class_name, "private")?,
                None => Vec::new(),
            };
            let method_names = methods.iter().map(|method| method.name.clone()).collect();
            self.result.functions.extend(methods);
            self.result.classes.push(Class {
                qualified_name: format!("{}::{}", self.path.display(), class_name),
                name: class_name,
                file_path: self.path.to_path_buf(),
                line_start: class_node.start_position().row as u32 + 1,
                line_end: class_node.end_position().row as u32 + 1,
                bases: extract_base_classes(&class_node, self.source),
                methods: method_names,
                field_count: 0,
                doc_comment: None,
                annotations: vec![],
            });
        }
        Ok(())
    }
}
/// Collects the base-class names listed in a class/struct specifier's base clause.
///
/// Returns the source text of each base as written, excluding access specifiers and
/// `virtual`. Handled forms are plain names (`Base`), qualified names (`ns::Base`) and
/// unqualified template instantiations (`Base<T>`). Returns an empty vector when there
/// is no base clause.
fn extract_base_classes(class_node: &Node, source: &[u8]) -> Vec<String> {
    let mut bases = vec![];
    for clause in class_node.children(&mut class_node.walk()) {
        if clause.kind() != "base_class_clause" {
            continue;
        }
        for base in clause.children(&mut clause.walk()) {
            // `template_type` is required for unqualified templated bases such as
            // `Base<T>`; without it those bases would be silently dropped.
            if matches!(
                base.kind(),
                "type_identifier" | "qualified_identifier" | "template_type"
            ) {
                if let Ok(text) = base.utf8_text(source) {
                    bases.push(text.to_string());
                }
            }
        }
    }
    bases
}
impl<'a> CppExtractor<'a> {
    /// Maps the start byte of each direct member of a class/struct body to its access
    /// level.
    ///
    /// Walks the direct children of `body` (a `field_declaration_list`) in order. It
    /// starts from `default_access` and switches whenever an `access_specifier` is seen.
    ///
    /// Every other *named* direct child is recorded, not just plain
    /// `function_definition`/`declaration` nodes. This lets a member wrapped in another
    /// node (e.g. a member function template inside a `template_declaration`) be resolved
    /// through the start byte of that wrapper.
    fn build_access_map(
        &self,
        body: &Node,
        default_access: &str,
    ) -> std::collections::HashMap<usize, String> {
        let mut access_map = std::collections::HashMap::new();
        let mut current_access = default_access.to_string();
        for child in body.children(&mut body.walk()) {
            if child.kind() == "access_specifier" {
                if let Ok(text) = child.utf8_text(self.source) {
                    // Defensive: tolerate a trailing ':' if the node text includes it.
                    current_access = text.trim_end_matches(':').trim().to_string();
                }
            } else if child.is_named() {
                access_map.insert(child.start_byte(), current_access.clone());
            }
        }
        access_map
    }
}
impl<'a> CppExtractor<'a> {
    /// Extracts the member function definitions that belong directly to one class or
    /// struct body.
    ///
    /// * `body` — the type's `field_declaration_list` node.
    /// * `class_name` — used in each method's qualified name,
    ///   `<path>::<class_name>::<method>:<line>`.
    /// * `default_access` — access level before the first access specifier (the callers
    ///   pass `"private"` for classes and `"public"` for structs).
    ///
    /// Methods whose resolved access is `"public"` get the `"exported"` annotation.
    ///
    /// The method query matches the whole subtree of `body`, which includes the bodies of
    /// nested or local classes/structs. Definitions whose nearest enclosing member list
    /// is not `body` are skipped. Those nested types are visited by the class/struct
    /// passes on their own, so keeping them here would record the same method twice,
    /// once per enclosing type.
    ///
    /// Access is looked up by the body's direct child that contains the definition, so a
    /// definition wrapped in, e.g., a `template_declaration` is resolved via its wrapper.
    ///
    /// # Errors
    /// Never fails at present; returns `Result` to fit the callers' `?` chains.
    fn extract_class_methods(
        &self,
        body: &Node,
        class_name: &str,
        default_access: &str,
    ) -> Result<Vec<Function>> {
        // Returns the direct child of the body with id `body_id` that contains `node`,
        // or `None` when another member list (`field_declaration_list`) is nearer.
        fn member_of_body<'t>(node: Node<'t>, body_id: usize) -> Option<Node<'t>> {
            let mut current = node;
            while let Some(parent) = current.parent() {
                if parent.id() == body_id {
                    return Some(current);
                }
                if parent.kind() == "field_declaration_list" {
                    return None;
                }
                current = parent;
            }
            None
        }

        let mut methods = vec![];
        let access_map = self.build_access_map(body, default_access);
        let query = CPP_METHOD_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_METHOD_QUERY_STR)
                .expect("valid C++ method query")
        });
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(query, *body, self.source);
        while let Some(m) = matches.next() {
            let mut method_node = None;
            let mut name = String::new();
            let mut params_node = None;
            let mut return_type_node = None;
            for capture in m.captures.iter() {
                match query.capture_names()[capture.index as usize] {
                    "method" => method_node = Some(capture.node),
                    "method_name" => {
                        name = capture.node.utf8_text(self.source).unwrap_or("").to_string();
                    }
                    "params" => params_node = Some(capture.node),
                    "return_type" => return_type_node = Some(capture.node),
                    _ => {}
                }
            }
            let Some(node) = method_node else {
                continue;
            };
            // Skip definitions owned by a nested/local type's body.
            let Some(member) = member_of_body(node, body.id()) else {
                continue;
            };
            let parameters = extract_parameters(params_node, self.source);
            let return_type =
                return_type_node.map(|n| n.utf8_text(self.source).unwrap_or("").to_string());
            let line_start = node.start_position().row as u32 + 1;
            let line_end = node.end_position().row as u32 + 1;
            let qualified_name = format!(
                "{}::{}::{}:{}",
                self.path.display(),
                class_name,
                name,
                line_start
            );
            let access = access_map
                .get(&member.start_byte())
                .map(String::as_str)
                .unwrap_or(default_access);
            let annotations = if access == "public" {
                vec!["exported".to_string()]
            } else {
                vec![]
            };
            methods.push(Function {
                name,
                qualified_name,
                file_path: self.path.to_path_buf(),
                line_start,
                line_end,
                parameters,
                return_type,
                is_async: false,
                complexity: Some(calculate_complexity(&node, self.source)),
                max_nesting: None,
                doc_comment: None,
                annotations,
            });
        }
        Ok(methods)
    }
}
impl<'a> CppExtractor<'a> {
    /// Records every `struct` *definition* as a [`Class`] and its member function
    /// definitions as [`Function`]s.
    ///
    /// The struct query's `@body` capture is optional, so it also matches name-only
    /// struct specifiers: forward declarations (`struct Foo;`) and elaborated type uses
    /// (`struct Foo *p;`). Those are not definitions and are skipped. Otherwise every
    /// such mention would add a spurious, method-less duplicate class entry. This also
    /// matches `extract_classes`, whose query requires a body.
    ///
    /// Struct members default to `public` access. `field_count` is not computed and is
    /// always `0`.
    ///
    /// # Errors
    /// Propagates errors from `extract_class_methods`.
    fn extract_structs(&mut self, root: &Node) -> Result<()> {
        let query = CPP_STRUCT_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_STRUCT_QUERY_STR)
                .expect("valid C++ struct query")
        });
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(query, *root, self.source);
        while let Some(m) = matches.next() {
            let mut struct_node = None;
            let mut name = String::new();
            let mut body_node = None;
            for capture in m.captures.iter() {
                match query.capture_names()[capture.index as usize] {
                    "struct" => struct_node = Some(capture.node),
                    "struct_name" => {
                        name = capture.node.utf8_text(self.source).unwrap_or("").to_string();
                    }
                    "body" => body_node = Some(capture.node),
                    _ => {}
                }
            }
            // Name-only specifiers (no body) are declarations/uses, not definitions.
            let (Some(node), Some(body)) = (struct_node, body_node) else {
                continue;
            };
            let line_start = node.start_position().row as u32 + 1;
            let line_end = node.end_position().row as u32 + 1;
            let qualified_name = format!("{}::{}", self.path.display(), name);
            let bases = extract_base_classes(&node, self.source);
            let methods = self.extract_class_methods(&body, &name, "public")?;
            self.result.functions.extend(methods.iter().cloned());
            self.result.classes.push(Class {
                name,
                qualified_name,
                file_path: self.path.to_path_buf(),
                line_start,
                line_end,
                bases,
                methods: methods.iter().map(|m| m.name.clone()).collect(),
                field_count: 0,
                doc_comment: None,
                annotations: vec![],
            });
        }
        Ok(())
    }
}
impl<'a> CppExtractor<'a> {
    /// Records every `#include` directive as a runtime import.
    ///
    /// `#include "local.h"` yields `local.h` and `#include <vector>` yields `vector`:
    /// surrounding double quotes are trimmed first, then `<` and `>`.
    ///
    /// # Errors
    /// Never fails at present; returns `Result` to fit `run`'s `?` chain.
    fn extract_includes(&mut self, root: &Node) -> Result<()> {
        let query = CPP_INCLUDE_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_INCLUDE_QUERY_STR)
                .expect("valid C++ include query")
        });
        let mut query_cursor = QueryCursor::new();
        let mut found = query_cursor.matches(query, *root, self.source);
        while let Some(hit) = found.next() {
            for cap in hit.captures {
                let raw = cap.node.utf8_text(self.source).unwrap_or("");
                let without_quotes = raw.trim_matches('"');
                let bare = without_quotes.trim_matches('<').trim_matches('>');
                self.result
                    .imports
                    .push(ImportInfo::runtime(bare.to_string()));
            }
        }
        Ok(())
    }
}
impl<'a> CppExtractor<'a> {
    /// Records one `(caller, callee)` pair per matched call expression.
    ///
    /// The callee is taken as written. It is the identifier for `foo()`, only the
    /// member name for member calls (`obj.foo()`), or the full qualified text for
    /// `ns::foo()`. The caller id comes from `find_enclosing_function`. Calls whose callee
    /// has another shape are not matched by the call query.
    ///
    /// # Errors
    /// Never fails at present; returns `Result` to fit `run`'s `?` chain.
    fn extract_calls(&mut self, root: &Node) -> Result<()> {
        let query = CPP_CALL_QUERY.get_or_init(|| {
            Query::new(&tree_sitter_cpp::LANGUAGE.into(), CPP_CALL_QUERY_STR)
                .expect("valid C++ call query")
        });
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(query, *root, self.source);
        while let Some(m) = matches.next() {
            let mut call_node = None;
            let mut callee_name = String::new();
            for capture in m.captures.iter() {
                match query.capture_names()[capture.index as usize] {
                    "call" => call_node = Some(capture.node),
                    "func_name" | "method_name" | "qualified_name" => {
                        callee_name = capture
                            .node
                            .utf8_text(self.source)
                            .unwrap_or("")
                            .to_string();
                    }
                    _ => {}
                }
            }
            if let Some(node) = call_node {
                let caller = self.find_enclosing_function(&node);
                self.result.calls.push((caller, callee_name));
            }
        }
        Ok(())
    }
}
impl<'a> CppExtractor<'a> {
fn find_enclosing_function(&self, node: &Node) -> String {
let mut current = node.parent();
while let Some(parent) = current {
if parent.kind() == "function_definition" {
if let Some(declarator) = parent.child_by_field_name("declarator") {
if let Some(name_node) = declarator.child_by_field_name("declarator") {
let name = name_node.utf8_text(self.source).unwrap_or("unknown");
let line = parent.start_position().row as u32 + 1;
return format!("{}::{}:{}", self.path.display(), name, line);
}
}
}
current = parent.parent();
}
format!("{}::<global>", self.path.display())
}
}
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
let Some(params) = params_node else {
return vec![];
};
let mut parameters = vec![];
let mut cursor = params.walk();
for child in params.children(&mut cursor) {
match child.kind() {
"parameter_declaration" | "optional_parameter_declaration" => {
if let Some(declarator) = child.child_by_field_name("declarator") {
let name = declarator.utf8_text(source).unwrap_or("");
let name = name.trim_start_matches('*').trim_start_matches('&');
if !name.is_empty() {
parameters.push(name.to_string());
}
}
}
_ => {}
}
}
parameters
}
/// Computes a McCabe-style cyclomatic complexity for the syntax subtree rooted at `node`.
///
/// The score starts at 1 and adds one per decision point anywhere in the subtree:
/// - `if`, classic `for`, range-based `for` (`for (x : xs)`), `while`, `do`/`while`;
/// - each `case`/`default` label, each ternary `?:` and each `catch` clause;
/// - each `&&` / `||` operator of a binary expression.
///
/// The subtree is traversed with a single tree cursor instead of recursion, so very
/// deep syntax trees (e.g. long generated operator chains) cannot overflow the stack.
///
/// `_source` is unused; it is kept so the signature stays stable for callers.
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    // Decision points contributed by one node on its own (descendants excluded).
    fn branch_points(node: &Node) -> u32 {
        match node.kind() {
            "if_statement" | "for_statement" | "for_range_loop" | "while_statement"
            | "do_statement" => 1,
            // NOTE(review): tree-sitter-cpp appears to model both `case` and `default`
            // as `case_statement`; the other kinds are kept defensively.
            "case_statement" | "default_statement" | "case_label" | "default_label" => 1,
            "conditional_expression" | "catch_clause" => 1,
            "binary_expression" => {
                let mut operators = 0;
                for child in node.children(&mut node.walk()) {
                    if child.kind() == "&&" || child.kind() == "||" {
                        operators += 1;
                    }
                }
                operators
            }
            _ => 0,
        }
    }

    let mut complexity = 1u32;
    let mut cursor = node.walk();
    // Depth below `node`; guards the walk so it never leaves the starting subtree.
    let mut depth = 0usize;
    loop {
        complexity += branch_points(&cursor.node());
        if cursor.goto_first_child() {
            depth += 1;
            continue;
        }
        // Leaf reached: move to the next sibling, climbing back up as needed.
        loop {
            if depth == 0 {
                return complexity;
            }
            if cursor.goto_next_sibling() {
                break;
            }
            cursor.goto_parent();
            depth -= 1;
        }
    }
}
// Unit tests for the C++ extractor. Each test parses an inline C++ snippet with
// `parse_source` under the path `test.cpp` and inspects the returned `ParseResult`.
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// A single free function: its name and parameter names are extracted.
#[test]
fn test_parse_simple_function() {
let source = r#"
int add(int a, int b) {
return a + b;
}
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
assert_eq!(result.functions.len(), 1);
assert_eq!(result.functions[0].name, "add");
assert_eq!(result.functions[0].parameters, vec!["a", "b"]);
}
// A class with two inline methods yields exactly one class entry.
#[test]
fn test_parse_class() {
let source = r#"
class Calculator {
public:
int add(int a, int b) {
return a + b;
}
int subtract(int a, int b) {
return a - b;
}
};
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
assert_eq!(result.classes.len(), 1);
assert_eq!(result.classes[0].name, "Calculator");
}
// Both `<system>` and `"local"` includes are recorded with their delimiters stripped.
#[test]
fn test_parse_includes() {
let source = r#"
#include <iostream>
#include <vector>
#include "myheader.h"
int main() {
return 0;
}
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
assert!(result.imports.iter().any(|i| i.path == "iostream"));
assert!(result.imports.iter().any(|i| i.path == "vector"));
assert!(result.imports.iter().any(|i| i.path == "myheader.h"));
}
// Two `case` labels plus `default`: asserts complexity is at least 3.
#[test]
fn test_complexity_switch_counts_cases_not_switch() {
let source = r#"
int classify(int x) {
switch (x) {
case 1: return 1;
case 2: return 2;
default: return 0;
}
}
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let c = result.functions[0].complexity.unwrap_or(0);
assert!(
c >= 3,
"expected switch branches to increase complexity, got {c}"
);
}
// if / else-if / for / nested if / while: asserts a lower bound of 5.
#[test]
fn test_complexity() {
let source = r#"
int complex(int x) {
if (x > 0) {
for (int i = 0; i < x; i++) {
if (i % 2 == 0) {
x++;
}
}
} else if (x < 0) {
while (x < 0) {
x++;
}
}
return x;
}
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
assert_eq!(result.functions.len(), 1);
assert!(result.functions[0].complexity.unwrap_or(0) >= 5); }
// Only methods under `public:` get the "exported" annotation.
#[test]
fn test_public_methods_exported() {
let source = r#"
class MyClass {
public:
int public_method(int x) {
return x;
}
private:
int private_method(int x) {
return x;
}
protected:
int protected_method(int x) {
return x;
}
};
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let public_fn = result
.functions
.iter()
.find(|f| f.name == "public_method")
.expect("should find public_method");
assert!(
public_fn.annotations.contains(&"exported".to_string()),
"public method should be exported"
);
let private_fn = result
.functions
.iter()
.find(|f| f.name == "private_method")
.expect("should find private_method");
assert!(
private_fn.annotations.is_empty(),
"private method should not be exported"
);
let protected_fn = result
.functions
.iter()
.find(|f| f.name == "protected_method")
.expect("should find protected_method");
assert!(
protected_fn.annotations.is_empty(),
"protected method should not be exported"
);
}
// `class` members before any access specifier are private, hence not exported.
#[test]
fn test_class_default_private() {
let source = r#"
class Foo {
int implicit_private(int x) {
return x;
}
};
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let func = result
.functions
.iter()
.find(|f| f.name == "implicit_private")
.expect("should find implicit_private");
assert!(
func.annotations.is_empty(),
"class methods without access specifier should default to private (not exported)"
);
}
// `struct` members default to public; an explicit `private:` section still applies.
#[test]
fn test_struct_methods_default_public() {
let source = r#"
struct Bar {
int implicit_public(int x) {
return x;
}
private:
int explicit_private(int x) {
return x;
}
};
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let public_fn = result
.functions
.iter()
.find(|f| f.name == "implicit_public")
.expect("should find implicit_public");
assert!(
public_fn.annotations.contains(&"exported".to_string()),
"struct methods without access specifier should default to public (exported)"
);
let private_fn = result
.functions
.iter()
.find(|f| f.name == "explicit_private")
.expect("should find explicit_private");
assert!(
private_fn.annotations.is_empty(),
"struct method after private: should not be exported"
);
}
// The base clause `: public Base` yields the base name without the access keyword.
#[test]
fn test_base_class_extraction() {
let source = r#"
class Base {};
class Derived : public Base {
public:
int method() {
return 0;
}
};
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let derived = result
.classes
.iter()
.find(|c| c.name == "Derived")
.expect("should find Derived class");
assert_eq!(derived.bases, vec!["Base"]);
}
// Free functions are exported only when declared with `extern`.
#[test]
fn test_extern_free_function_exported() {
let source = r#"
extern int api_func(int x) {
return x;
}
int internal_func(int x) {
return x;
}
"#;
let path = PathBuf::from("test.cpp");
let result = parse_source(source, &path).expect("should parse C++ source");
let api = result
.functions
.iter()
.find(|f| f.name == "api_func")
.expect("should find api_func");
assert!(
api.annotations.contains(&"exported".to_string()),
"extern free function should be exported"
);
let internal = result
.functions
.iter()
.find(|f| f.name == "internal_func")
.expect("should find internal_func");
assert!(
internal.annotations.is_empty(),
"plain free function should not be exported"
);
}
}