use tree_sitter::Node;
/// Tree-sitter query that captures C++ definitions for indexing.
///
/// Capture names:
/// - `@function`: the whole definition node;
/// - `@func_name`: the name of a free function;
/// - `@method_name`: the name of a method, whether defined out of line
///   (`Foo::bar`) or inline in a class/struct body;
/// - `@class` / `@class_name`: a `class` or `struct` definition and its name.
///
/// Inline methods use a `field_identifier` declarator, and functions returning
/// a pointer or reference (`int* f()`, `T& Foo::get()`) wrap their
/// `function_declarator` in a `pointer_declarator` / `reference_declarator`.
/// The patterns for those shapes come last so the indices of the earlier
/// patterns stay stable.
///
/// NOTE(review): query patterns match at any depth, so a function template is
/// reported twice. The plain `function_definition` pattern reports it once,
/// and the `template_declaration` pattern reports it again with `@function`
/// on the template node. Nested qualifiers such as `ns::Foo::bar` are not
/// matched by the `qualified_identifier` patterns.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @func_name)) @function
(function_definition
  declarator: (function_declarator
    declarator: (qualified_identifier
      name: (identifier) @method_name))) @function
(class_specifier
  name: (type_identifier) @class_name) @class
(struct_specifier
  name: (type_identifier) @class_name) @class
(template_declaration
  (function_definition
    declarator: (function_declarator
      declarator: (identifier) @func_name))) @function
; method defined inline in a class or struct body
(function_definition
  declarator: (function_declarator
    declarator: (field_identifier) @method_name)) @function
; free function returning a pointer or reference
(function_definition
  declarator: [
    (pointer_declarator
      declarator: (function_declarator
        declarator: (identifier) @func_name))
    (reference_declarator
      (function_declarator
        declarator: (identifier) @func_name))
  ]) @function
; method returning a pointer or reference, inline or out of line
(function_definition
  declarator: [
    (pointer_declarator
      declarator: (function_declarator
        declarator: [
          (field_identifier) @method_name
          (qualified_identifier
            name: (identifier) @method_name)
        ]))
    (reference_declarator
      (function_declarator
        declarator: [
          (field_identifier) @method_name
          (qualified_identifier
            name: (identifier) @method_name)
        ]))
  ]) @function
";
/// Tree-sitter query that captures the callee name of C++ call expressions
/// as `@call`.
///
/// It covers the following call forms:
/// - plain calls (`foo()`);
/// - member calls through `.` or `->` (`obj.foo()`, `ptr->foo()`), which both
///   parse as `field_expression`;
/// - qualified calls (`std::sort()`, `Foo::create()`);
/// - calls with explicit template arguments (`make<T>()`).
///
/// Only the final name segment is captured.
pub const CALL_QUERY: &str = r"
(call_expression
  function: (identifier) @call)
(call_expression
  function: (field_expression
    field: (field_identifier) @call))
(call_expression
  function: (qualified_identifier
    name: (identifier) @call))
(call_expression
  function: (template_function
    name: (identifier) @call))
";
/// Tree-sitter query that captures every `type_identifier` node as `@type_ref`.
///
/// This includes the names in type definitions themselves. For example, the
/// `Foo` in `class Foo {}` is also a `type_identifier`.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
/// Tree-sitter query that captures the path of every `#include` directive as
/// `@import_path`.
///
/// Quoted includes (`#include "myfile.h"`) match as `string_literal`, and
/// angle-bracket includes (`#include <stdio.h>`) match as `system_lib_string`.
/// NOTE(review): the captured node text likely still carries its quotes or
/// angle brackets; callers should strip them if they need a bare path.
pub const IMPORT_QUERY: &str = r"
(preproc_include
path: (string_literal) @import_path)
(preproc_include
path: (system_lib_string) @import_path)
";
/// Returns the name declared by `node`'s `declarator` field, typically the
/// name of a `function_definition`.
///
/// The declarator chain (function, pointer and reference declarators,
/// qualified names) is resolved by `extract_declarator_name`. Returns `None`
/// when `node` has no `declarator` field or no name can be resolved.
///
/// `_lang` is unused here; presumably it keeps the signature uniform across
/// language handlers (TODO confirm).
#[must_use]
pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
    let declarator = node.child_by_field_name("declarator")?;
    extract_declarator_name(declarator, source)
}
#[must_use]
pub fn find_method_for_receiver(
node: &Node,
source: &str,
_depth: Option<usize>,
) -> Option<String> {
if node.kind() != "function_definition" {
return None;
}
let mut parent = node.parent();
let mut in_class = false;
while let Some(p) = parent {
if p.kind() == "class_specifier" || p.kind() == "struct_specifier" {
in_class = true;
break;
}
parent = p.parent();
}
if !in_class {
return None;
}
if let Some(decl) = node.child_by_field_name("declarator") {
extract_declarator_name(decl, source)
} else {
None
}
}
#[must_use]
pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
let mut inherits = Vec::new();
if node.kind() != "class_specifier" && node.kind() != "struct_specifier" {
return inherits;
}
for i in 0..node.named_child_count() {
if let Some(child) = node.named_child(u32::try_from(i).unwrap_or(u32::MAX))
&& child.kind() == "base_class_clause"
{
for j in 0..child.named_child_count() {
if let Some(base) = child.named_child(u32::try_from(j).unwrap_or(u32::MAX))
&& base.kind() == "type_identifier"
{
let text = &source[base.start_byte()..base.end_byte()];
inherits.push(text.to_string());
}
}
}
}
inherits
}
fn extract_declarator_name(node: Node, source: &str) -> Option<String> {
match node.kind() {
"identifier" | "field_identifier" => {
let start = node.start_byte();
let end = node.end_byte();
if end <= source.len() {
Some(source[start..end].to_string())
} else {
None
}
}
"qualified_identifier" => node.child_by_field_name("name").and_then(|n| {
let start = n.start_byte();
let end = n.end_byte();
if end <= source.len() {
Some(source[start..end].to_string())
} else {
None
}
}),
"function_declarator" => node
.child_by_field_name("declarator")
.and_then(|n| extract_declarator_name(n, source)),
"pointer_declarator" => node
.child_by_field_name("declarator")
.and_then(|n| extract_declarator_name(n, source)),
"reference_declarator" => node
.child_by_field_name("declarator")
.and_then(|n| extract_declarator_name(n, source)),
_ => None,
}
}
#[cfg(all(test, feature = "lang-cpp"))]
mod tests {
use super::*;
use tree_sitter::Parser;
fn parse_cpp(source: &str) -> tree_sitter::Tree {
let mut parser = Parser::new();
parser
.set_language(&tree_sitter_cpp::LANGUAGE.into())
.expect("failed to set C++ language");
parser.parse(source, None).expect("failed to parse source")
}
#[test]
fn test_free_function() {
let source = "int add(int a, int b) { return a + b; }";
let tree = parse_cpp(source);
let root = tree.root_node();
let func_node = root.named_child(0).expect("expected function_definition");
let result = find_method_for_receiver(&func_node, source, None);
assert_eq!(result, None);
}
#[test]
fn test_class_with_method() {
let source = "class Foo { public: int getValue() { return 42; } };";
let tree = parse_cpp(source);
let root = tree.root_node();
let func_node = find_node_by_kind(root, "function_definition").expect("expected function");
let result = find_method_for_receiver(&func_node, source, None);
assert_eq!(result, Some("getValue".to_string()));
}
#[test]
fn test_struct() {
let source = "struct Point { int x; int y; };";
let tree = parse_cpp(source);
let root = tree.root_node();
let struct_node =
find_node_by_kind(root, "struct_specifier").expect("expected struct_specifier");
assert_eq!(struct_node.kind(), "struct_specifier");
let result = extract_inheritance(&struct_node, source);
assert!(
result.is_empty(),
"expected no inheritance, got: {result:?}"
);
}
#[test]
fn test_include_directive() {
use tree_sitter::StreamingIterator;
let source = "#include <stdio.h>\n#include \"myfile.h\"\n";
let tree = parse_cpp(source);
let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
let query = tree_sitter::Query::new(&lang, super::IMPORT_QUERY)
.expect("IMPORT_QUERY must be valid");
let mut cursor = tree_sitter::QueryCursor::new();
let mut iter = cursor.captures(&query, tree.root_node(), source.as_bytes());
let mut captures: Vec<String> = Vec::new();
while let Some((m, _)) = iter.next() {
for c in m.captures {
let text = c
.node
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string();
captures.push(text);
}
}
assert!(
captures.iter().any(|s| s.contains("stdio.h")),
"expected stdio.h in captures: {captures:?}"
);
assert!(
captures.iter().any(|s| s.contains("myfile.h")),
"expected myfile.h in captures: {captures:?}"
);
}
#[test]
fn test_template_function() {
use tree_sitter::StreamingIterator;
let source = "template<typename T> T max(T a, T b) { return a > b ? a : b; }";
let tree = parse_cpp(source);
let lang: tree_sitter::Language = tree_sitter_cpp::LANGUAGE.into();
let query = tree_sitter::Query::new(&lang, super::ELEMENT_QUERY)
.expect("ELEMENT_QUERY must be valid");
let mut cursor = tree_sitter::QueryCursor::new();
let mut iter = cursor.captures(&query, tree.root_node(), source.as_bytes());
let mut func_names: Vec<String> = Vec::new();
while let Some((m, _)) = iter.next() {
for c in m.captures {
let name = query.capture_names()[c.index as usize];
if name == "func_name" {
if let Ok(text) = c.node.utf8_text(source.as_bytes()) {
func_names.push(text.to_string());
}
}
}
}
assert!(
func_names.iter().any(|s| s == "max"),
"expected 'max' in func_names: {func_names:?}"
);
}
#[test]
fn test_class_with_inheritance() {
let source = "class Derived : public Base { };";
let tree = parse_cpp(source);
let root = tree.root_node();
let class_node = find_node_by_kind(root, "class_specifier").expect("expected class");
let result = extract_inheritance(&class_node, source);
assert!(!result.is_empty(), "expected inheritance information");
assert!(
result.iter().any(|s| s.contains("Base")),
"expected 'Base' in inheritance: {:?}",
result
);
}
fn find_node_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
if node.kind() == kind {
return Some(node);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(u32::try_from(i).unwrap_or(u32::MAX)) {
if let Some(found) = find_node_by_kind(child, kind) {
return Some(found);
}
}
}
None
}
}