use super::{extract_doc_comment, extract_signature, field_text, node_text};
use crate::types::*;
/// Collects symbols and imports from a parsed syntax tree.
///
/// Each direct child of the tree's root node is handed to `extract_node`
/// with no parent name. The returned tuple is `(symbols, imports)`, filled
/// in the order the root's children are visited.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let mut found_symbols = Vec::new();
    let mut found_imports = Vec::new();

    let root = tree.root_node();
    let mut walker = root.walk();
    root.children(&mut walker).for_each(|top_level| {
        extract_node(
            &top_level,
            source,
            &mut found_symbols,
            &mut found_imports,
            None,
        );
    });

    (found_symbols, found_imports)
}
/// Returns the first direct child of `node` whose kind equals `kind`.
///
/// Yields `None` when no direct child has that kind. Only immediate children
/// are inspected; the search does not descend into grandchildren.
fn find_child_by_kind<'a>(
    node: &tree_sitter::Node<'a>,
    kind: &str,
) -> Option<tree_sitter::Node<'a>> {
    let mut cursor = node.walk();
    // Stop at the first hit instead of collecting every match into a Vec.
    // The result is bound to a local so the iterator borrowing `cursor` is
    // dropped before `cursor` itself goes out of scope.
    let first_match = node
        .children(&mut cursor)
        .find(|child| child.kind() == kind);
    first_match
}
/// Returns the first direct child of `node` whose kind is listed in `kinds`.
///
/// Children are examined in document order, so the result is the earliest
/// child matching *any* of the kinds, not the child matching the earliest
/// entry of `kinds`. Yields `None` when nothing matches.
fn find_child_by_kinds<'a>(
    node: &tree_sitter::Node<'a>,
    kinds: &[&str],
) -> Option<tree_sitter::Node<'a>> {
    let mut cursor = node.walk();
    // Stop at the first hit instead of collecting every match into a Vec.
    // The result is bound to a local so the iterator borrowing `cursor` is
    // dropped before `cursor` itself goes out of scope.
    let first_match = node
        .children(&mut cursor)
        .find(|child| kinds.contains(&child.kind()));
    first_match
}
/// Dispatches on the kind of `node` and records whatever it declares.
///
/// Kinds that map to a single optional symbol are evaluated into `extracted`
/// and pushed onto `symbols` at the end. Declarations and templates append
/// to the output vectors themselves, and `preproc_include` nodes append an
/// entry to `imports`. Any other node kind is ignored.
///
/// `parent_name` is forwarded to the function, declaration and template
/// extractors only.
fn extract_node(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
    parent_name: Option<&str>,
) {
    let extracted = match node.kind() {
        "function_definition" => extract_function(node, source, parent_name),
        "class_specifier" => extract_class(node, source),
        "struct_specifier" => extract_struct(node, source),
        "enum_specifier" => extract_enum(node, source),
        "namespace_definition" => extract_namespace(node, source),
        "type_definition" => extract_typedef(node, source),
        "preproc_def" => extract_define(node, source),
        "declaration" => {
            extract_declaration(node, source, symbols, parent_name);
            None
        }
        "template_declaration" => {
            extract_template(node, source, symbols, imports, parent_name);
            None
        }
        "preproc_include" => {
            // Prefer the `path` field; fall back to the whole directive text.
            let include_path = match field_text(node, "path", source) {
                Some(text) => text,
                None => node_text(node, source),
            };
            imports.push(Import {
                path: include_path.to_string(),
                alias: None,
                span: Span::from_node(node),
            });
            None
        }
        _ => None,
    };

    if let Some(symbol) = extracted {
        symbols.push(symbol);
    }
}
/// Builds a symbol for a function definition node.
///
/// Returns `None` when the node has no `declarator` field or when no name
/// can be resolved from that declarator. The symbol is a `Method` when
/// `parent_name` is present and a `Function` otherwise; `parent_name` is
/// also copied into the symbol's `parent`.
fn extract_function(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let fn_name = node
        .child_by_field_name("declarator")
        .and_then(|decl| find_declarator_name(&decl, source))?;

    let kind = match parent_name {
        Some(_) => SymbolKind::Method,
        None => SymbolKind::Function,
    };

    Some(Symbol {
        name: fn_name.to_string(),
        kind,
        span: Span::from_node(node),
        signature: extract_signature(node, "body", source),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_string),
        children: Vec::new(),
    })
}
/// Resolves the declared name carried by a declarator node.
///
/// * Leaf name nodes (`identifier`, `field_identifier`, `destructor_name`,
///   `qualified_identifier`, `operator_name`) yield their own source text.
/// * Wrapper declarators (function, pointer, parenthesized, reference) are
///   unwrapped: first through the `declarator` field, then the `name` field,
///   and finally by scanning their named children.
/// * Any other node kind tries its `declarator` field recursively and then
///   its `name` field.
///
/// Returns `None` when none of these lookups produces a name.
fn find_declarator_name<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    match node.kind() {
        "identifier"
        | "field_identifier"
        | "destructor_name"
        | "qualified_identifier"
        | "operator_name" => Some(node_text(node, source)),
        "function_declarator"
        | "pointer_declarator"
        | "parenthesized_declarator"
        | "reference_declarator" => {
            if let Some(inner) = node.child_by_field_name("declarator") {
                return find_declarator_name(&inner, source);
            }
            if let Some(name_node) = node.child_by_field_name("name") {
                return Some(node_text(&name_node, source));
            }
            // Reference and parenthesized declarators may hold their inner
            // declarator as a plain named child rather than under a field, so
            // e.g. `int& foo() {}` would otherwise resolve to no name at all.
            // The result is bound to a local so the iterator borrowing
            // `cursor` is dropped before `cursor` itself.
            let mut cursor = node.walk();
            let from_children = node
                .named_children(&mut cursor)
                .find_map(|child| find_declarator_name(&child, source));
            from_children
        }
        _ => node
            .child_by_field_name("declarator")
            .and_then(|d| find_declarator_name(&d, source))
            .or_else(|| {
                node.child_by_field_name("name")
                    .map(|n| node_text(&n, source))
            }),
    }
}
/// Builds a `Class` symbol, including its members, from a class specifier.
///
/// Returns `None` when the node has no `name` field or the name text is
/// empty. Members are gathered by `extract_class_body` and stored in the
/// symbol's `children`; the class symbol itself has no parent.
fn extract_class(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let class_name = field_text(node, "name", source).filter(|text| !text.is_empty())?;

    let mut members = Vec::new();
    extract_class_body(node, source, class_name, &mut members);

    Some(Symbol {
        name: class_name.to_string(),
        kind: SymbolKind::Class,
        span: Span::from_node(node),
        signature: extract_signature(node, "body", source),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: members,
    })
}
/// Builds a `Struct` symbol from a struct specifier.
///
/// Returns `None` when the node has no `name` field or the name text is
/// empty. The signature is the node's full source text, and no members are
/// recorded as children.
fn extract_struct(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    match field_text(node, "name", source) {
        Some(struct_name) if !struct_name.is_empty() => Some(Symbol {
            name: struct_name.to_string(),
            kind: SymbolKind::Struct,
            span: Span::from_node(node),
            signature: node_text(node, source).to_string(),
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        }),
        _ => None,
    }
}
/// Builds an `Enum` symbol from an enum specifier.
///
/// Returns `None` when the node has no `name` field or the name text is
/// empty. The signature is the node's full source text, and enumerators are
/// not recorded as children.
fn extract_enum(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let enum_name = field_text(node, "name", source)?;
    let is_named = !enum_name.is_empty();

    is_named.then(|| Symbol {
        name: enum_name.to_string(),
        kind: SymbolKind::Enum,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}
/// Builds a `Module` symbol for a namespace definition.
///
/// A namespace without a `name` field is reported as `(anonymous)`. The body
/// is taken from the `body` field, or failing that from the first
/// `declaration_list` child, and each of its children goes through
/// `extract_node` with no parent name. Always returns `Some`.
fn extract_namespace(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let ns_name = field_text(node, "name", source).unwrap_or("(anonymous)");

    let declarations = node
        .child_by_field_name("body")
        .or_else(|| find_child_by_kind(node, "declaration_list"));

    let mut members = Vec::new();
    if let Some(list) = declarations {
        // Imports found inside the namespace are collected here and never
        // read again, so they do not reach the caller.
        let mut discarded_imports = Vec::new();
        let mut walker = list.walk();
        for item in list.children(&mut walker) {
            extract_node(&item, source, &mut members, &mut discarded_imports, None);
        }
    }

    Some(Symbol {
        name: ns_name.to_string(),
        kind: SymbolKind::Module,
        span: Span::from_node(node),
        signature: format!("namespace {}", ns_name),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: members,
    })
}
/// Extracts whatever a template declaration wraps.
///
/// Every direct child except the `template_parameter_list` is forwarded to
/// `extract_node` with the same output vectors and `parent_name`.
fn extract_template(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
    parent_name: Option<&str>,
) {
    let mut walker = node.walk();
    for part in node.children(&mut walker) {
        if part.kind() != "template_parameter_list" {
            extract_node(&part, source, symbols, imports, parent_name);
        }
    }
}
/// Appends the members found in a class body to `children`.
///
/// Does nothing when `node` has no `body` field. For each direct child of
/// the body:
///
/// * `function_definition` goes through `extract_function` with
///   `class_name` as parent;
/// * `declaration` / `field_declaration` becomes a `Method` when its
///   declarator declares a function and a `Const` otherwise, with the
///   member's text (trailing `;` and surrounding whitespace removed) as
///   signature;
/// * `template_declaration` is unwrapped via `extract_template`; imports it
///   reports are discarded;
/// * anything else, including access specifiers, is skipped.
fn extract_class_body(
    node: &tree_sitter::Node,
    source: &[u8],
    class_name: &str,
    children: &mut Vec<Symbol>,
) {
    // True when `declarator` is a function declarator, either directly or
    // wrapped in pointer/reference declarators (e.g. `int* foo();`,
    // `const T& bar() const;`). Checking only the outermost kind would label
    // such member functions as `Const`.
    fn declares_function(declarator: &tree_sitter::Node) -> bool {
        match declarator.kind() {
            "function_declarator" => true,
            "pointer_declarator" | "reference_declarator" => {
                let mut cursor = declarator.walk();
                // Bound to a local so the iterator borrowing `cursor` is
                // dropped before `cursor` itself.
                let wraps_function = declarator
                    .named_children(&mut cursor)
                    .any(|inner| declares_function(&inner));
                wraps_function
            }
            _ => false,
        }
    }

    let body = match node.child_by_field_name("body") {
        Some(b) => b,
        None => return,
    };

    let mut cursor = body.walk();
    for member in body.children(&mut cursor) {
        match member.kind() {
            "function_definition" => {
                if let Some(sym) = extract_function(&member, source, Some(class_name)) {
                    children.push(sym);
                }
            }
            "declaration" | "field_declaration" => {
                // Prefer the `declarator` field; otherwise take the first
                // child that looks like a declarator or a bare name.
                let declarator = member.child_by_field_name("declarator").or_else(|| {
                    find_child_by_kinds(
                        &member,
                        &["function_declarator", "identifier", "field_identifier"],
                    )
                });
                if let Some(declarator) = declarator {
                    if let Some(name) = find_declarator_name(&declarator, source) {
                        let kind = if declares_function(&declarator) {
                            SymbolKind::Method
                        } else {
                            SymbolKind::Const
                        };
                        children.push(Symbol {
                            name: name.to_string(),
                            kind,
                            span: Span::from_node(&member),
                            signature: node_text(&member, source)
                                .trim_end_matches(';')
                                .trim()
                                .to_string(),
                            doc_comment: extract_doc_comment(&member, source),
                            parent: Some(class_name.to_string()),
                            children: Vec::new(),
                        });
                    }
                }
            }
            "template_declaration" => {
                // Imports reported from inside a member template are never
                // read again.
                let mut inner_imports = Vec::new();
                extract_template(
                    &member,
                    source,
                    children,
                    &mut inner_imports,
                    Some(class_name),
                );
            }
            // Access specifiers and every other node kind carry no symbol.
            _ => {}
        }
    }
}
/// Extracts symbols from a `declaration` node.
///
/// Two passes are made over the node:
///
/// 1. every direct `class_specifier`, `struct_specifier` or `enum_specifier`
///    child is extracted and pushed onto `symbols`;
/// 2. if the `declarator` field declares a function, a symbol for that
///    prototype is pushed: a `Method` when `parent_name` is present, a
///    `Function` otherwise. Its signature is the declaration text with the
///    trailing `;` and surrounding whitespace removed.
///
/// Declarations whose declarator is not a function produce no symbol in the
/// second pass.
fn extract_declaration(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    parent_name: Option<&str>,
) {
    // True when `declarator` is a function declarator, either directly or
    // wrapped in pointer/reference declarators (e.g. `int* foo();`,
    // `std::string& name();`). Checking only the outermost kind would drop
    // such prototypes.
    fn declares_function(declarator: &tree_sitter::Node) -> bool {
        match declarator.kind() {
            "function_declarator" => true,
            "pointer_declarator" | "reference_declarator" => {
                let mut cursor = declarator.walk();
                // Bound to a local so the iterator borrowing `cursor` is
                // dropped before `cursor` itself.
                let wraps_function = declarator
                    .named_children(&mut cursor)
                    .any(|inner| declares_function(&inner));
                wraps_function
            }
            _ => false,
        }
    }

    // Pass 1: type definitions embedded in the declaration.
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let embedded = match child.kind() {
            "class_specifier" => extract_class(&child, source),
            "struct_specifier" => extract_struct(&child, source),
            "enum_specifier" => extract_enum(&child, source),
            _ => None,
        };
        if let Some(sym) = embedded {
            symbols.push(sym);
        }
    }

    // Pass 2: a function prototype, if that is what is being declared.
    let declarator = match node.child_by_field_name("declarator") {
        Some(d) => d,
        None => return,
    };
    if !declares_function(&declarator) {
        return;
    }
    if let Some(name) = find_declarator_name(&declarator, source) {
        let kind = if parent_name.is_some() {
            SymbolKind::Method
        } else {
            SymbolKind::Function
        };
        symbols.push(Symbol {
            name: name.to_string(),
            kind,
            span: Span::from_node(node),
            signature: node_text(node, source)
                .trim_end_matches(';')
                .trim()
                .to_string(),
            doc_comment: extract_doc_comment(node, source),
            parent: parent_name.map(|s| s.to_string()),
            children: Vec::new(),
        });
    }
}
/// Builds a `TypeAlias` symbol from a type definition node.
///
/// Returns `None` when the node has no `declarator` field or when no name
/// can be resolved from it. The signature is the node's full source text.
fn extract_typedef(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let alias_name = match node.child_by_field_name("declarator") {
        Some(decl) => find_declarator_name(&decl, source)?,
        None => return None,
    };

    Some(Symbol {
        name: alias_name.to_string(),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}
/// Builds a `Const` symbol from a `#define` node.
///
/// Returns `None` when the node has no `name` field, or when any direct
/// child is a `preproc_params` node. The signature is the node's full source
/// text.
fn extract_define(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let macro_name = field_text(node, "name", source)?;

    let mut walker = node.walk();
    let takes_params = node
        .children(&mut walker)
        .any(|part| part.kind() == "preproc_params");
    if takes_params {
        return None;
    }

    Some(Symbol {
        name: macro_name.to_string(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}
#[cfg(all(test, feature = "cpp"))]
mod tests {
    use super::*;

    /// Parses `source` with the tree-sitter C++ grammar and runs `extract`.
    fn parse_cpp(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_cpp::LANGUAGE.into())
            .expect("failed to set cpp language");
        let tree = parser.parse(source, None).expect("failed to parse");
        extract(&tree, source.as_bytes())
    }

    /// Collects references to the symbols in `pool` whose kind is `wanted`.
    fn of_kind(pool: &[Symbol], wanted: SymbolKind) -> Vec<&Symbol> {
        pool.iter().filter(|sym| sym.kind == wanted).collect()
    }

    #[test]
    fn test_cpp_extraction() {
        let src = r#"
#include <iostream>
#include "utils.h"
#define VERSION 42
namespace math {
class Calculator {
public:
int add(int a, int b) {
return a + b;
}
virtual int subtract(int a, int b);
};
struct Point {
double x;
double y;
};
enum class Color { Red, Green, Blue };
int multiply(int a, int b) {
return a * b;
}
} // namespace math
template<typename T>
T identity(T val) {
return val;
}
"#;
        let (symbols, imports) = parse_cpp(src);

        // Both include directives are reported, in source order.
        assert_eq!(imports.len(), 2);
        assert!(imports[0].path.contains("iostream"));
        assert!(imports[1].path.contains("utils.h"));

        // The object-like macro is the only top-level constant.
        let consts = of_kind(&symbols, SymbolKind::Const);
        assert_eq!(consts.len(), 1);
        assert_eq!(consts[0].name, "VERSION");

        // The namespace is reported as a module and owns its declarations.
        let namespaces = of_kind(&symbols, SymbolKind::Module);
        assert_eq!(namespaces.len(), 1);
        assert_eq!(namespaces[0].name, "math");
        let math_members = &namespaces[0].children;

        let classes = of_kind(math_members, SymbolKind::Class);
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].name, "Calculator");

        // Inline definitions and bare prototypes both count as methods.
        let methods = of_kind(&classes[0].children, SymbolKind::Method);
        assert!(methods.len() >= 2);
        let method_names: Vec<_> = methods.iter().map(|m| m.name.as_str()).collect();
        assert!(method_names.contains(&"add"));
        assert!(method_names.contains(&"subtract"));
        assert_eq!(methods[0].parent.as_deref(), Some("Calculator"));

        let structs = of_kind(math_members, SymbolKind::Struct);
        assert_eq!(structs.len(), 1);
        assert_eq!(structs[0].name, "Point");

        let enums = of_kind(math_members, SymbolKind::Enum);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].name, "Color");

        let ns_funcs = of_kind(math_members, SymbolKind::Function);
        assert!(!ns_funcs.is_empty());
        assert!(ns_funcs.iter().any(|f| f.name == "multiply"));

        // The function template at file scope is reported as a function.
        let top_funcs = of_kind(&symbols, SymbolKind::Function);
        assert!(top_funcs.iter().any(|f| f.name == "identity"));
    }
}