use tree_sitter::{Node, Tree};
use crate::graph::ast_node::{is_structural_kind, kinds, AstNode};
/// Extracts the structural nodes of `tree` as a flat list, parents before children.
///
/// `source` must be the bytes `tree` was parsed from. Structural nodes whose byte
/// span does not fit inside `source` are skipped together with their descendants.
pub fn extract_ast_nodes(tree: &Tree, source: &[u8]) -> Vec<AstNode> {
    AstExtractor::new(source).extract(tree)
}
pub fn normalize_node_kind<'a>(kind: &'a str, _language: &str) -> &'a str {
match kind {
"if_expression" | "if_statement" => kinds::IF,
"match_expression" | "match_statement" => kinds::MATCH,
"while_expression" | "while_statement" => kinds::WHILE,
"for_expression" | "for_statement" => kinds::FOR,
"loop_expression" => kinds::LOOP,
"return_expression" | "return_statement" => kinds::RETURN,
"break_expression" | "break_statement" => kinds::BREAK,
"continue_expression" | "continue_statement" => kinds::CONTINUE,
"function_item" | "function_definition" => kinds::FUNCTION,
"method_definition" => kinds::FUNCTION,
"struct_item" | "struct_definition" => kinds::STRUCT,
"enum_item" | "enum_definition" => kinds::ENUM,
"trait_item" | "trait_definition" => kinds::TRAIT,
"impl_item" => kinds::IMPL,
"mod_item" => kinds::MODULE,
"class_definition" => kinds::CLASS,
"interface_definition" => kinds::INTERFACE,
"block" | "block_expression" | "statement_block" => kinds::BLOCK,
"let_statement" => kinds::LET,
"expression_statement" => "Expression", "assignment_expression" => kinds::ASSIGN,
"call_expression" => kinds::CALL,
"attribute_item" | "decorated_definition" => kinds::ATTRIBUTE,
"const_item" => kinds::CONST,
"static_item" => kinds::STATIC,
_ => kind,
}
}
/// Guesses the tree-sitter language name for a file from its extension.
///
/// Returns `None` in two cases:
/// - the path's final component has no extension, e.g. a bare name such as
///   `Makefile` or `rs`, or a dotfile such as `.bashrc`;
/// - the extension is not one of the recognised ones.
///
/// Matching is case-sensitive.
pub fn language_from_path(path: &str) -> Option<&'static str> {
    // `Path::extension` looks only at the final component and ignores a leading
    // dot. Splitting on the last '.' would instead treat a dot-less name as its
    // own extension, so a file named `c` or `rs` would be misdetected.
    let ext = std::path::Path::new(path).extension()?.to_str()?;
    match ext {
        "rs" => Some("rust"),
        "py" | "pyi" => Some("python"),
        "c" | "h" => Some("c"),
        "cpp" | "cc" | "cxx" | "hpp" | "hh" | "hxx" => Some("cpp"),
        "java" => Some("java"),
        "js" | "mjs" | "cjs" | "jsx" => Some("javascript"),
        "ts" | "mts" | "cts" => Some("typescript"),
        "tsx" => Some("tsx"),
        _ => None,
    }
}
/// Collects the structural nodes of a tree-sitter tree into a flat, pre-order list.
struct AstExtractor<'a> {
    /// Bytes the tree was parsed from; used only to bounds-check node spans.
    source: &'a [u8],
    /// Extracted nodes in pre-order, so a parent always precedes its descendants.
    nodes: Vec<AstNode>,
    /// Recorded structural nodes that may still be ancestors of the node being
    /// visited, as `(cursor depth, index into nodes)` pairs, innermost last.
    parent_stack: Vec<(usize, usize)>,
}

impl<'a> AstExtractor<'a> {
    fn new(source: &'a [u8]) -> Self {
        Self {
            source,
            nodes: Vec::new(),
            parent_stack: Vec::new(),
        }
    }

    /// Walks `tree` in pre-order and returns the structural nodes found.
    ///
    /// The walk is iterative (a single `TreeCursor`) rather than recursive, so
    /// deeply nested syntax cannot exhaust the call stack.
    fn extract(mut self, tree: &Tree) -> Vec<AstNode> {
        let mut cursor = tree.walk();
        // Depth of the cursor's current node below the root (root = 0).
        let mut depth = 0usize;
        loop {
            let descend = self.visit(&cursor.node(), depth);
            if descend && cursor.goto_first_child() {
                depth += 1;
                continue;
            }
            // This subtree is done. Move to the next sibling, climbing out of
            // exhausted parents; failing to climb above the root ends the walk.
            while !cursor.goto_next_sibling() {
                if !cursor.goto_parent() {
                    return self.nodes;
                }
                depth -= 1;
            }
        }
    }

    /// Records `node` (at cursor `depth`) if it is structural.
    ///
    /// Returns whether the walk should descend into the node's children. It
    /// returns `false` only for structural nodes whose span does not fit the
    /// source.
    fn visit(&mut self, node: &Node, depth: usize) -> bool {
        // In pre-order, any open entry at `depth` or deeper belongs to a subtree
        // that has already been left, so it cannot be an ancestor of `node`.
        while matches!(self.parent_stack.last(), Some(&(open_depth, _)) if open_depth >= depth) {
            self.parent_stack.pop();
        }

        let kind = node.kind();
        if !is_structural_kind(kind) {
            return true;
        }

        let byte_start = node.start_byte();
        let byte_end = node.end_byte();
        if byte_end > self.source.len() || byte_start > byte_end {
            // The span does not fit the supplied source, e.g. the tree and the
            // bytes are out of sync. Skip this node and its whole subtree.
            return false;
        }

        // The parent is referenced by its position in the returned Vec, encoded
        // as `-(index) - 1` so it can never collide with a real (non-negative) id.
        // NOTE(review): presumably resolved to a persisted id by the caller — confirm.
        // `usize -> i64` is lossless here: a Vec never holds more than isize::MAX items.
        let parent_id = self
            .parent_stack
            .last()
            .map(|&(_, parent_idx)| -(parent_idx as i64) - 1);

        let node_index = self.nodes.len();
        self.nodes.push(AstNode {
            id: None,
            parent_id,
            kind: kind.to_string(),
            byte_start,
            byte_end,
        });
        self.parent_stack.push((depth, node_index));
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `source` with the Rust grammar; panics if setup or parsing fails.
    fn parse_rust(source: &[u8]) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser.set_language(&tree_sitter_rust::language()).unwrap();
        parser.parse(source, None).unwrap()
    }

    /// Parses `source` as Rust and runs the extractor over the resulting tree.
    fn extract_rust(source: &[u8]) -> Vec<AstNode> {
        let tree = parse_rust(source);
        extract_ast_nodes(&tree, source)
    }

    #[test]
    fn test_extract_simple_function() {
        let extracted = extract_rust(b"fn main() { }");
        assert!(!extracted.is_empty());
        assert!(extracted.iter().any(|n| n.kind == "function_item"));
    }

    #[test]
    fn test_extract_if_expression() {
        let extracted = extract_rust(b"fn test() { if x { y } else { z } }");
        assert!(extracted.iter().any(|n| n.kind == "if_expression"));
    }

    #[test]
    fn test_extract_ignores_identifiers() {
        let extracted = extract_rust(b"fn main() { let x = 42; }");
        let has_kind = |wanted: &str| extracted.iter().any(|n| n.kind == wanted);
        assert!(!has_kind("identifier"));
        assert!(has_kind("function_item"));
        assert!(has_kind("let_declaration"));
    }

    #[test]
    fn test_parent_child_relationships() {
        let extracted = extract_rust(b"fn main() { if x { y } }");
        let if_node = extracted.iter().find(|n| n.kind == "if_expression");
        assert!(if_node.is_some());
        let parent = if_node.and_then(|n| n.parent_id);
        assert!(parent.is_some());
    }

    #[test]
    fn test_normalize_if_expression() {
        for language in ["rust", "python"] {
            let raw = if language == "rust" { "if_expression" } else { "if_statement" };
            assert_eq!(normalize_node_kind(raw, language), "If");
        }
    }

    #[test]
    fn test_normalize_function() {
        let cases = [("function_item", "rust"), ("function_definition", "python")];
        for &(raw, language) in cases.iter() {
            assert_eq!(normalize_node_kind(raw, language), "Function");
        }
    }

    #[test]
    fn test_normalize_unknown() {
        let raw = "unknown_kind";
        assert_eq!(normalize_node_kind(raw, "rust"), raw);
    }

    #[test]
    fn test_language_from_path() {
        let cases = [
            ("src/main.rs", Some("rust")),
            ("script.py", Some("python")),
            ("header.h", Some("c")),
            ("file.cpp", Some("cpp")),
            ("Main.java", Some("java")),
            ("app.js", Some("javascript")),
            ("app.ts", Some("typescript")),
            ("unknown.xyz", None),
        ];
        for &(path, expected) in cases.iter() {
            assert_eq!(language_from_path(path), expected);
        }
    }
}