use crate::models::{Class, Function};
use crate::parsers::{ImportInfo, ParseResult};
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use tree_sitter::{Node, Parser, Query, QueryCursor};
/// Reads the file at `path` and parses its contents as C++ source.
///
/// # Errors
///
/// Returns an error (annotated with the file path) if the file cannot be read
/// into a string, or whatever error [`parse_source`] reports for its contents.
pub fn parse(path: &Path) -> Result<ParseResult> {
    std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read file: {}", path.display()))
        .and_then(|contents| parse_source(&contents, path))
}
/// Parses in-memory C++ `source` and collects its functions, classes,
/// structs, includes and calls into a [`ParseResult`].
///
/// `path` is not read from disk here; it is only handed to the extraction
/// passes, which embed it in the names and file paths they record.
///
/// # Errors
///
/// Returns an error if the C++ grammar cannot be installed on the parser, if
/// the parser yields no tree, or if any extraction pass fails.
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let mut cpp_parser = Parser::new();
    cpp_parser
        .set_language(&tree_sitter_cpp::LANGUAGE.into())
        .context("Failed to set C++ language")?;

    let syntax_tree = cpp_parser
        .parse(source, None)
        .context("Failed to parse C++ source")?;
    let root_node = syntax_tree.root_node();
    let bytes = source.as_bytes();

    // The passes append to the same result, so their order determines the
    // order of entries in `functions` and `classes`.
    let mut parsed = ParseResult::default();
    extract_functions(&root_node, bytes, path, &mut parsed)?;
    extract_classes(&root_node, bytes, path, &mut parsed)?;
    extract_structs(&root_node, bytes, path, &mut parsed)?;
    extract_includes(&root_node, bytes, &mut parsed)?;
    extract_calls(&root_node, bytes, path, &mut parsed)?;
    Ok(parsed)
}
/// Appends one [`Function`] to `result.functions` for every function
/// definition under `root` that has a return type and a function declarator.
///
/// * The recorded name is the declarator text with leading `*` removed.
///   Names containing `::` (such as `Foo::bar`) are recorded unchanged.
/// * Matches whose name text is empty are skipped.
/// * `qualified_name` is `<path>::<name>:<line_start>`; line numbers are the
///   node's start/end row plus one.
/// * `is_async` is always `false`.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled.
fn extract_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let query_str = r#"
(function_definition
type: (_) @return_type
declarator: (function_declarator
declarator: (_) @func_name
parameters: (parameter_list) @params
)
) @func
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create function query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();

    // NOTE(review): the query is run from `root` without restricting depth, so
    // definitions written inline in a class body presumably match here as well
    // as in `extract_classes` — confirm whether that double entry is intended.
    for m in cursor.matches(&query, *root, source) {
        let mut func_node = None;
        let mut name = "";
        let mut params_node = None;
        let mut return_type_node = None;

        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "func" => func_node = Some(capture.node),
                "func_name" => {
                    // Text that is not valid UTF-8 degrades to an empty name,
                    // which the guard below turns into a skipped match.
                    name = capture
                        .node
                        .utf8_text(source)
                        .unwrap_or("")
                        .trim_start_matches('*');
                }
                "params" => params_node = Some(capture.node),
                "return_type" => return_type_node = Some(capture.node),
                _ => {}
            }
        }

        if name.is_empty() {
            continue;
        }
        let Some(node) = func_node else {
            continue;
        };

        let line_start = node.start_position().row as u32 + 1;
        let line_end = node.end_position().row as u32 + 1;
        result.functions.push(Function {
            name: name.to_string(),
            qualified_name: format!("{}::{}:{}", path.display(), name, line_start),
            file_path: path.to_path_buf(),
            line_start,
            line_end,
            parameters: extract_parameters(params_node, source),
            return_type: return_type_node.map(|n| n.utf8_text(source).unwrap_or("").to_string()),
            is_async: false,
            complexity: Some(calculate_complexity(&node, source)),
        });
    }
    Ok(())
}
/// Appends one [`Class`] to `result.classes` for every `class` specifier under
/// `root` that has both a name and a body.
///
/// The methods found in the class body (via `extract_class_methods`) are
/// appended to `result.functions`, and their names are listed in the class's
/// `methods`. `bases` holds the text of each type named in the class's
/// base-class clause (access specifiers such as `public` are not included).
/// `qualified_name` is `<path>::<name>`; line numbers are the node's
/// start/end row plus one.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled or if method
/// extraction fails.
fn extract_classes(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    // Collects the names in `class Foo : public A, B<T>`-style clauses.
    // Only direct `base_class_clause` children are inspected; if the grammar
    // produces none, the list is simply empty.
    fn base_class_names(class_node: &Node, source: &[u8]) -> Vec<String> {
        let mut bases = Vec::new();
        let mut outer = class_node.walk();
        for clause in class_node.children(&mut outer) {
            if clause.kind() != "base_class_clause" {
                continue;
            }
            let mut inner = clause.walk();
            for base in clause.named_children(&mut inner) {
                let is_type_name = matches!(
                    base.kind(),
                    "type_identifier" | "qualified_identifier" | "template_type"
                );
                if !is_type_name {
                    continue;
                }
                if let Ok(text) = base.utf8_text(source) {
                    bases.push(text.to_string());
                }
            }
        }
        bases
    }

    let query_str = r#"
(class_specifier
name: (type_identifier) @class_name
body: (field_declaration_list) @body
) @class
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create class query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();

    for m in cursor.matches(&query, *root, source) {
        let mut class_node = None;
        let mut name = "";
        let mut body_node = None;

        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "class" => class_node = Some(capture.node),
                "class_name" => name = capture.node.utf8_text(source).unwrap_or(""),
                "body" => body_node = Some(capture.node),
                _ => {}
            }
        }

        let Some(node) = class_node else {
            continue;
        };

        let methods = match body_node {
            Some(body) => extract_class_methods(&body, source, path, name)?,
            None => Vec::new(),
        };
        // Take the names first so the methods themselves can be moved into
        // the result rather than cloned.
        let method_names = methods.iter().map(|method| method.name.clone()).collect();
        result.functions.extend(methods);

        result.classes.push(Class {
            name: name.to_string(),
            qualified_name: format!("{}::{}", path.display(), name),
            file_path: path.to_path_buf(),
            line_start: node.start_position().row as u32 + 1,
            line_end: node.end_position().row as u32 + 1,
            bases: base_class_names(&node, source),
            methods: method_names,
        });
    }
    Ok(())
}
/// Builds a [`Function`] for every function definition found under a class
/// `body` that has a return type and a function declarator.
///
/// Each method's `qualified_name` is
/// `<path>::<class_name>::<method_name>:<line_start>`, with line numbers taken
/// as the node's start/end row plus one. `is_async` is always `false`.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled.
fn extract_class_methods(
    body: &Node,
    source: &[u8],
    path: &Path,
    class_name: &str,
) -> Result<Vec<Function>> {
    let query_str = r#"
(function_definition
type: (_) @return_type
declarator: (function_declarator
declarator: (_) @method_name
parameters: (parameter_list) @params
)
) @method
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create method query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();

    let mut collected = Vec::new();
    for m in cursor.matches(&query, *body, source) {
        let (mut definition, mut params, mut return_type) = (None, None, None);
        let mut method_name = "";

        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "method" => definition = Some(capture.node),
                "method_name" => method_name = capture.node.utf8_text(source).unwrap_or(""),
                "params" => params = Some(capture.node),
                "return_type" => return_type = Some(capture.node),
                _ => {}
            }
        }

        let Some(definition) = definition else {
            continue;
        };

        let line_start = definition.start_position().row as u32 + 1;
        collected.push(Function {
            name: method_name.to_string(),
            qualified_name: format!(
                "{}::{}::{}:{}",
                path.display(),
                class_name,
                method_name,
                line_start
            ),
            file_path: path.to_path_buf(),
            line_start,
            line_end: definition.end_position().row as u32 + 1,
            parameters: extract_parameters(params, source),
            return_type: return_type.map(|n| n.utf8_text(source).unwrap_or("").to_string()),
            is_async: false,
            complexity: Some(calculate_complexity(&definition, source)),
        });
    }
    Ok(collected)
}
/// Appends one [`Class`] to `result.classes` for every named `struct`
/// specifier under `root`.
///
/// Structs are recorded with empty `bases` and `methods`; `qualified_name` is
/// `<path>::<name>` and line numbers are the node's start/end row plus one.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled.
fn extract_structs(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    // NOTE(review): the body is optional in this pattern (and `@body` is never
    // read), so body-less uses such as forward declarations presumably match
    // too — confirm that is wanted.
    let query_str = r#"
(struct_specifier
name: (type_identifier) @struct_name
body: (field_declaration_list)? @body
) @struct
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create struct query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();

    for m in cursor.matches(&query, *root, source) {
        let mut specifier = None;
        let mut struct_name = "";

        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "struct" => specifier = Some(capture.node),
                "struct_name" => struct_name = capture.node.utf8_text(source).unwrap_or(""),
                _ => {}
            }
        }

        let Some(specifier) = specifier else {
            continue;
        };

        result.classes.push(Class {
            name: struct_name.to_string(),
            qualified_name: format!("{}::{}", path.display(), struct_name),
            file_path: path.to_path_buf(),
            line_start: specifier.start_position().row as u32 + 1,
            line_end: specifier.end_position().row as u32 + 1,
            bases: vec![],
            methods: vec![],
        });
    }
    Ok(())
}
/// Records every `#include` target under `root` as a runtime import.
///
/// The captured path text has surrounding `"` characters removed first, then
/// surrounding `<`, then surrounding `>`, so both `"myheader.h"` and
/// `<vector>` are stored as the bare path.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled.
fn extract_includes(root: &Node, source: &[u8], result: &mut ParseResult) -> Result<()> {
    let query_str = r#"
(preproc_include
path: [
(string_literal) @path
(system_lib_string) @system_path
]
)
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create include query")?;
    let mut cursor = QueryCursor::new();

    for m in cursor.matches(&query, *root, source) {
        // Both capture names denote an include path, so they are handled alike.
        for capture in m.captures.iter() {
            let raw = capture.node.utf8_text(source).unwrap_or("");
            let bare = raw.trim_matches('"').trim_matches('<').trim_matches('>');
            result.imports.push(ImportInfo::runtime(bare.to_string()));
        }
    }
    Ok(())
}
/// Appends a `(caller, callee)` pair to `result.calls` for every call
/// expression under `root` whose callee is a plain identifier, a field access
/// or a qualified identifier.
///
/// The caller string comes from `find_enclosing_function`. The callee string
/// is `<path>::<callee text>:<line>`, where `<line>` is the call expression's
/// start row plus one. For field accesses only the field name is used as the
/// callee text.
///
/// # Errors
///
/// Returns an error if the tree-sitter query cannot be compiled.
fn extract_calls(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let query_str = r#"
(call_expression
function: [
(identifier) @func_name
(field_expression
field: (field_identifier) @method_name
)
(qualified_identifier) @qualified_name
]
) @call
"#;
    let query = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str)
        .context("Failed to create call query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();

    for m in cursor.matches(&query, *root, source) {
        let mut call_site = None;
        let mut callee = "";

        for capture in m.captures.iter() {
            match capture_names[capture.index as usize] {
                "call" => call_site = Some(capture.node),
                "func_name" | "method_name" | "qualified_name" => {
                    callee = capture.node.utf8_text(source).unwrap_or("");
                }
                _ => {}
            }
        }

        let Some(call_site) = call_site else {
            continue;
        };

        let caller = find_enclosing_function(&call_site, source, path);
        let call_line = call_site.start_position().row as u32 + 1;
        let callee_id = format!("{}::{}:{}", path.display(), callee, call_line);
        result.calls.push((caller, callee_id));
    }
    Ok(())
}
fn find_enclosing_function(node: &Node, source: &[u8], path: &Path) -> String {
let mut current = node.parent();
while let Some(parent) = current {
if parent.kind() == "function_definition" {
if let Some(declarator) = parent.child_by_field_name("declarator") {
if let Some(name_node) = declarator.child_by_field_name("declarator") {
let name = name_node.utf8_text(source).unwrap_or("unknown");
let line = parent.start_position().row as u32 + 1;
return format!("{}::{}:{}", path.display(), name, line);
}
}
}
current = parent.parent();
}
format!("{}::<global>", path.display())
}
/// Returns the parameter names declared in a parameter list node.
///
/// Only `parameter_declaration` and `optional_parameter_declaration` children
/// that have a `declarator` field are considered. The declarator text has its
/// leading pointer/reference sigils (`*`, `&`) and any whitespace between them
/// and the name removed, so `& s` and `* p` yield `s` and `p`. Declarators
/// that are empty after stripping are skipped.
///
/// Returns an empty vector when `params_node` is `None`.
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
    let Some(params) = params_node else {
        return Vec::new();
    };

    let mut parameters = Vec::new();
    let mut cursor = params.walk();
    for child in params.children(&mut cursor) {
        if !matches!(
            child.kind(),
            "parameter_declaration" | "optional_parameter_declaration"
        ) {
            continue;
        }
        let Some(declarator) = child.child_by_field_name("declarator") else {
            continue;
        };

        // Sigils and whitespace are stripped in one pass: stripping only `*`
        // and `&` would leave the space of `& s` in front of the name.
        let name = declarator
            .utf8_text(source)
            .unwrap_or("")
            .trim_start_matches(|c: char| matches!(c, '*' | '&') || c.is_whitespace());
        if !name.is_empty() {
            parameters.push(name.to_string());
        }
    }
    parameters
}
/// Computes a complexity score for `node`: one, plus one for every query
/// match found within it.
///
/// The query counts `if`, `for`, `while`, `do`, `switch`, `case`, conditional
/// (`?:`) expressions, `catch` clauses, and the `&&` / `||` tokens. If the
/// query itself cannot be compiled, the score is `1`.
fn calculate_complexity(node: &Node, source: &[u8]) -> u32 {
    let query_str = r#"
(if_statement) @if
(for_statement) @for
(while_statement) @while
(do_statement) @do
(switch_statement) @switch
(case_statement) @case
(conditional_expression) @ternary
(catch_clause) @catch
("&&") @and
("||") @or
"#;
    let Ok(query) = Query::new(&tree_sitter_cpp::LANGUAGE.into(), query_str) else {
        // No query means nothing can be counted; report the base score.
        return 1;
    };

    let mut cursor = QueryCursor::new();
    let branch_points = cursor.matches(&query, *node, source).count() as u32;
    1 + branch_points
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` as if it were the contents of `test.cpp`, panicking on
    /// any parse error.
    fn parse_cpp(source: &str) -> ParseResult {
        let path = PathBuf::from("test.cpp");
        parse_source(source, &path).unwrap()
    }

    /// A single free function is reported with its name and parameter names.
    #[test]
    fn test_parse_simple_function() {
        let parsed = parse_cpp(
            r#"
int add(int a, int b) {
return a + b;
}
"#,
        );
        assert_eq!(parsed.functions.len(), 1);
        let func = &parsed.functions[0];
        assert_eq!(func.name, "add");
        assert_eq!(func.parameters, vec!["a", "b"]);
    }

    /// A class with a body is reported once, under its own name.
    #[test]
    fn test_parse_class() {
        let parsed = parse_cpp(
            r#"
class Calculator {
public:
int add(int a, int b) {
return a + b;
}
int subtract(int a, int b) {
return a - b;
}
};
"#,
        );
        assert_eq!(parsed.classes.len(), 1);
        assert_eq!(parsed.classes[0].name, "Calculator");
    }

    /// Both `<...>` and `"..."` includes are stored without their delimiters.
    #[test]
    fn test_parse_includes() {
        let parsed = parse_cpp(
            r#"
#include <iostream>
#include <vector>
#include "myheader.h"
int main() {
return 0;
}
"#,
        );
        let has_import = |wanted: &str| parsed.imports.iter().any(|i| i.path == wanted);
        assert!(has_import("iostream"));
        assert!(has_import("vector"));
        assert!(has_import("myheader.h"));
    }

    /// Nested branches and loops raise the complexity score to at least 5.
    #[test]
    fn test_complexity() {
        let parsed = parse_cpp(
            r#"
int complex(int x) {
if (x > 0) {
for (int i = 0; i < x; i++) {
if (i % 2 == 0) {
x++;
}
}
} else if (x < 0) {
while (x < 0) {
x++;
}
}
return x;
}
"#,
        );
        assert_eq!(parsed.functions.len(), 1);
        let score = parsed.functions[0].complexity.unwrap_or(0);
        assert!(score >= 5);
    }
}