use std::path::Path;
use tree_sitter::Node;
use crate::ast::parser::ParserPool;
use crate::types::{InheritanceNode, Language};
use crate::TldrResult;
/// Parses `source` as C++ and returns every class/struct discovered in it.
///
/// Each returned [`InheritanceNode`] records the class name, `file_path`, the
/// 1-based line of the declaration, and the names of its direct bases.
///
/// # Errors
///
/// Propagates any error returned by `parser_pool.parse`.
pub fn extract_classes(
    source: &str,
    file_path: &Path,
    parser_pool: &ParserPool,
) -> TldrResult<Vec<InheritanceNode>> {
    let mut found: Vec<InheritanceNode> = Vec::new();
    let syntax_tree = parser_pool.parse(source, Language::Cpp)?;
    visit_node(&syntax_tree.root_node(), source, file_path, &mut found);
    Ok(found)
}
/// Parses `source` as C and returns every struct (or class-like specifier)
/// discovered in it, tagged with [`Language::C`].
///
/// # Errors
///
/// Propagates any error returned by `parser_pool.parse`.
pub fn extract_classes_c(
    source: &str,
    file_path: &Path,
    parser_pool: &ParserPool,
) -> TldrResult<Vec<InheritanceNode>> {
    let c_lang = Language::C;
    let mut found: Vec<InheritanceNode> = Vec::new();
    let syntax_tree = parser_pool.parse(source, c_lang)?;
    visit_node_with_lang(&syntax_tree.root_node(), source, file_path, &mut found, c_lang);
    Ok(found)
}
/// Recursively collects classes under `node`, treating the source as C++.
///
/// Shorthand for [`visit_node_with_lang`] with [`Language::Cpp`].
fn visit_node(node: &Node, source: &str, file_path: &Path, classes: &mut Vec<InheritanceNode>) {
    let cpp = Language::Cpp;
    visit_node_with_lang(node, source, file_path, classes, cpp);
}
/// Recursively walks `node`, appending every class/struct found to `classes`.
///
/// - `class_specifier` / `struct_specifier` nodes go through
///   [`extract_class_specifier`].
/// - `function_definition` / `declaration` nodes are first checked for the
///   macro-prefixed pattern (`class MACRO Name ...`) via
///   [`extract_macro_prefixed_class`].
///
/// When the macro pattern matches, only the node's `type` child is excluded
/// from the recursive walk. That specifier's `name` is the macro, so visiting
/// it would record a bogus class named after the macro. Every other child is
/// still visited, which presumably includes the misparsed class body. As a
/// result, nested classes inside a macro-prefixed class are found just as they
/// are for ordinary classes, whose bodies are always walked.
fn visit_node_with_lang(
    node: &Node,
    source: &str,
    file_path: &Path,
    classes: &mut Vec<InheritanceNode>,
    lang: Language,
) {
    // Id of the single child (if any) that must not be descended into.
    let mut skipped_child: Option<usize> = None;

    match node.kind() {
        "class_specifier" => {
            if let Some(class) = extract_class_specifier(node, source, file_path, lang, false) {
                classes.push(class);
            }
        }
        "struct_specifier" => {
            if let Some(class) = extract_class_specifier(node, source, file_path, lang, true) {
                classes.push(class);
            }
        }
        "function_definition" | "declaration" => {
            if let Some(class) = extract_macro_prefixed_class(node, source, file_path, lang) {
                classes.push(class);
                // Skip only the `class MACRO` specifier; keep walking the rest.
                skipped_child = node.child_by_field_name("type").map(|t| t.id());
            }
        }
        _ => {}
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if Some(child.id()) == skipped_child {
            continue;
        }
        visit_node_with_lang(&child, source, file_path, classes, lang);
    }
}
/// Recovers a class whose header was misparsed because a macro sits between
/// the `class`/`struct` keyword and the class name, e.g.
/// `class EXPORT_MACRO XMLText : public XMLNode { ... };`.
///
/// The shape looked for is a `function_definition` or `declaration` node with
/// these properties:
/// - its `type` child is a `class_specifier`/`struct_specifier` (named after
///   the macro) that has no `body` of its own;
/// - its `declarator` child is a plain `identifier`, which is the real class
///   name.
///
/// Base classes are recovered from `ERROR` children by
/// [`extract_macro_error_base_clause`].
///
/// Returns `None` (so the caller walks the node normally) when:
/// - the specifier has its own body (`struct P { int x; } origin;`). Here the
///   declarator names a variable and the specifier is itself a real class.
/// - the node is a `declaration` with no `ERROR` child (`struct stat st;`,
///   `class Foo foo;`). Such a node is indistinguishable from an ordinary
///   variable declaration. A bare `class MACRO Name;` forward declaration is
///   therefore not recovered.
/// - the declarator is not an identifier, or its text is empty or not UTF-8.
fn extract_macro_prefixed_class(
    node: &Node,
    source: &str,
    file_path: &Path,
    lang: Language,
) -> Option<InheritanceNode> {
    let type_node = node.child_by_field_name("type")?;
    if type_node.kind() != "class_specifier" && type_node.kind() != "struct_specifier" {
        return None;
    }
    // A specifier with a body is a complete definition; the declarator is a
    // variable of that type, not the class name.
    if type_node.child_by_field_name("body").is_some() {
        return None;
    }
    // For plain declarations, demand evidence of the misparsed base clause;
    // otherwise `struct stat st;` would be reported as a class named `st`.
    if node.kind() == "declaration" {
        let mut cursor = node.walk();
        let has_error_fragment = node.children(&mut cursor).any(|child| child.kind() == "ERROR");
        if !has_error_fragment {
            return None;
        }
    }

    let declarator = node.child_by_field_name("declarator")?;
    if declarator.kind() != "identifier" {
        return None;
    }
    let class_name = declarator
        .utf8_text(source.as_bytes())
        .ok()?
        .trim()
        .to_string();
    if class_name.is_empty() {
        return None;
    }

    let line = node.start_position().row as u32 + 1;
    let mut class_node = InheritanceNode::new(class_name, file_path.to_path_buf(), line, lang);
    class_node.bases = extract_macro_error_base_clause(node, source);
    Some(class_node)
}
/// Scrapes base-class names out of the `ERROR` children of `node`.
///
/// Only direct `ERROR` children are examined, and within each only its direct
/// children:
/// - `identifier` / `type_identifier` leaves contribute their trimmed text.
///   Empty text and the words `public`, `protected`, `private` and `virtual`
///   are skipped.
/// - `qualified_identifier` / `template_type` children are resolved with
///   [`extract_base_name`].
///
/// Names are returned in source order.
fn extract_macro_error_base_clause(node: &Node, source: &str) -> Vec<String> {
    const IGNORED_WORDS: [&str; 4] = ["public", "protected", "private", "virtual"];
    let bytes = source.as_bytes();
    let mut found = Vec::new();

    let mut outer = node.walk();
    for error_node in node.children(&mut outer).filter(|c| c.kind() == "ERROR") {
        let mut inner = error_node.walk();
        for piece in error_node.children(&mut inner) {
            match piece.kind() {
                "identifier" | "type_identifier" => {
                    if let Ok(raw) = piece.utf8_text(bytes) {
                        let word = raw.trim();
                        if !word.is_empty() && !IGNORED_WORDS.contains(&word) {
                            found.push(word.to_string());
                        }
                    }
                }
                "qualified_identifier" | "template_type" => {
                    found.extend(extract_base_name(&piece, source));
                }
                _ => {}
            }
        }
    }
    found
}
/// Builds an [`InheritanceNode`] for a `class_specifier` / `struct_specifier`.
///
/// - The class name is the raw text of the specifier's `name` field. `None` is
///   returned when that field is missing, not valid UTF-8, or empty.
/// - The reported line is the specifier's 1-based start row.
/// - Bases come from [`extract_base_class_clause`].
///
/// No `body` check is made, so a specifier without a body (presumably forward
/// declarations and elaborated type references) also yields a node.
/// `is_struct` is passed by callers but is not used yet.
fn extract_class_specifier(
    node: &Node,
    source: &str,
    file_path: &Path,
    lang: Language,
    is_struct: bool,
) -> Option<InheritanceNode> {
    // Reserved for struct-specific handling; intentionally unused for now.
    let _ = is_struct;
    let name = node
        .child_by_field_name("name")
        .and_then(|name_node| name_node.utf8_text(source.as_bytes()).ok())
        .filter(|text| !text.is_empty())?
        .to_string();
    let start_line = node.start_position().row as u32 + 1;
    let mut class_node = InheritanceNode::new(name, file_path.to_path_buf(), start_line, lang);
    class_node.bases = extract_base_class_clause(node, source);
    Some(class_node)
}
/// Returns the simple names of all bases listed in `node`'s `base_class_clause`
/// children, in source order.
///
/// Each entry of a clause is resolved with [`extract_base_name`]. Entries that
/// yield no name (access specifiers, `virtual`, punctuation) or an empty name
/// are dropped.
fn extract_base_class_clause(node: &Node, source: &str) -> Vec<String> {
    let mut top = node.walk();
    let clauses: Vec<Node<'_>> = node
        .children(&mut top)
        .filter(|child| child.kind() == "base_class_clause")
        .collect();

    let mut names = Vec::new();
    for clause in &clauses {
        let mut inner = clause.walk();
        names.extend(
            clause
                .children(&mut inner)
                .filter_map(|entry| extract_base_name(&entry, source))
                .filter(|name| !name.is_empty()),
        );
    }
    names
}
fn extract_base_name(node: &Node, source: &str) -> Option<String> {
match node.kind() {
"type_identifier" => node
.utf8_text(source.as_bytes())
.ok()
.map(|s| s.trim().to_string()),
"qualified_identifier" => {
let mut cursor = node.walk();
let mut last_simple: Option<String> = None;
for child in node.children(&mut cursor) {
if child.kind() == "type_identifier" {
if let Ok(text) = child.utf8_text(source.as_bytes()) {
last_simple = Some(text.trim().to_string());
}
} else if child.kind() == "qualified_identifier"
|| child.kind() == "template_type"
{
if let Some(nested) = extract_base_name(&child, source) {
last_simple = Some(nested);
}
}
}
last_simple
}
"template_type" => {
if let Some(name) = node.child_by_field_name("name") {
return name
.utf8_text(source.as_bytes())
.ok()
.map(|s| s.trim().to_string());
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "type_identifier" {
return child
.utf8_text(source.as_bytes())
.ok()
.map(|s| s.trim().to_string());
}
}
None
}
_ => None,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as C++ and returns every extracted class node.
    fn parse_and_extract(source: &str) -> Vec<InheritanceNode> {
        let pool = ParserPool::new();
        extract_classes(source, Path::new("test.cpp"), &pool).unwrap()
    }

    /// Bases recorded for the class called `name`; panics if it was not found.
    fn bases_of(nodes: &[InheritanceNode], name: &str) -> Vec<String> {
        match nodes.iter().find(|n| n.name == name) {
            Some(node) => node.bases.clone(),
            None => panic!("class {} not found", name),
        }
    }

    #[test]
    fn test_simple_class() {
        let nodes = parse_and_extract("class Foo { public: int x; };");
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].name, "Foo");
        assert!(nodes[0].bases.is_empty());
    }

    #[test]
    fn test_single_inheritance() {
        let nodes = parse_and_extract("class Base {}; class Derived : public Base {};");
        assert_eq!(bases_of(&nodes, "Derived"), vec!["Base"]);
    }

    #[test]
    fn test_multiple_inheritance() {
        let nodes = parse_and_extract("class A {}; class B {}; class C : public A, public B {};");
        assert_eq!(bases_of(&nodes, "C"), vec!["A", "B"]);
    }

    #[test]
    fn test_virtual_inheritance() {
        let nodes = parse_and_extract("class Base {}; class Derived : virtual public Base {};");
        assert_eq!(bases_of(&nodes, "Derived"), vec!["Base"]);
    }

    #[test]
    fn test_namespace_class() {
        let nodes =
            parse_and_extract("namespace foo { class Base {}; class D : public Base {}; }");
        assert_eq!(bases_of(&nodes, "D"), vec!["Base"]);
    }

    #[test]
    fn test_template_base() {
        let nodes = parse_and_extract(
            "template<typename T> class Vec {}; class IntVec : public Vec<int> {};",
        );
        assert_eq!(bases_of(&nodes, "IntVec"), vec!["Vec"]);
    }

    #[test]
    fn test_macro_prefixed_class_inheritance() {
        let nodes =
            parse_and_extract("class MACRO XMLText : public XMLNode {\npublic:\n int x;\n};");
        let recovered = nodes.iter().find(|n| n.name == "XMLText");
        assert!(
            recovered.is_some(),
            "Expected to recover XMLText class from macro-prefixed declaration; got {:?}",
            nodes.iter().map(|n| &n.name).collect::<Vec<_>>()
        );
        assert_eq!(recovered.unwrap().bases, vec!["XMLNode"]);
    }
}