use super::{AstNode, Position, LanguageConfig, NodeMetadata};
use tree_sitter::{Node, Tree, TreeCursor};
/// Walks a tree-sitter syntax tree and flattens selected nodes into a list of
/// [`AstNode`]s whose parent/child links are expressed as indices into that list.
///
/// Which node kinds are collected is controlled by an optional
/// [`LanguageConfig`]; without one, a built-in list of common
/// function/method/class/struct-like kinds is used.
pub struct AstWalker {
/// Per-language node-kind configuration; `None` selects the built-in defaults.
config: Option<LanguageConfig>,
}
impl AstWalker {
    /// Creates a walker with no language configuration; node selection and
    /// metadata fall back to the built-in default kind lists.
    pub fn new() -> Self {
        Self { config: None }
    }

    /// Sets the language configuration to `LanguageConfig::for_language(lang)`,
    /// replacing any previously configured one.
    pub fn with_language(mut self, lang: &str) -> Self {
        self.config = Some(LanguageConfig::for_language(lang));
        self
    }

    /// Walks the whole `tree` and appends every collected node to `nodes` in
    /// pre-order.
    ///
    /// `source` should be the text the tree was parsed from; node names are
    /// sliced out of it by byte range. Indices stored in `parent`/`children`
    /// are absolute positions in `nodes` (taken from `nodes.len()`), so
    /// appending to a non-empty vector keeps them consistent. Top-level
    /// collected nodes get depth 0 and no parent.
    ///
    /// Takes `&mut self` only for API compatibility; no walker state is modified.
    pub fn walk(&mut self, tree: &Tree, source: &str, nodes: &mut Vec<AstNode>) {
        let mut cursor = tree.walk();
        self.walk_cursor(&mut cursor, source, nodes, 0, None);
    }

    /// Visits the subtree rooted at the cursor's current node in pre-order and
    /// appends every collected node to `nodes`.
    ///
    /// `depth`/`parent` are what a collected node at the starting position
    /// receives. Uncollected nodes are transparent: their descendants inherit
    /// the same depth and parent. Siblings of the starting node are never
    /// visited, and on return the cursor is back on the starting node.
    ///
    /// The traversal is iterative (explicit stack of per-level contexts) rather
    /// than recursive, so very deep syntax trees cannot overflow the stack.
    fn walk_cursor(
        &self,
        cursor: &mut TreeCursor,
        source: &str,
        nodes: &mut Vec<AstNode>,
        depth: usize,
        parent: Option<usize>,
    ) {
        // (depth, parent) context that applies to every node at the cursor's
        // current level (i.e. to the current node and its siblings).
        let mut depth = depth;
        let mut parent = parent;
        // One saved context per level descended below the starting node; its
        // length is exactly how many levels the cursor is below the start.
        let mut saved: Vec<(usize, Option<usize>)> = Vec::new();

        loop {
            let node = cursor.node();
            // Context handed down to this node's children.
            let (child_depth, child_parent) = if self.should_collect_node(&node) {
                let node_index = self.push_node(&node, source, nodes, depth, parent);
                (depth + 1, Some(node_index))
            } else {
                (depth, parent)
            };

            if cursor.goto_first_child() {
                saved.push((depth, parent));
                depth = child_depth;
                parent = child_parent;
                continue;
            }

            // Leaf: advance to the next pre-order node, climbing back up until
            // a level with an unvisited sibling is found.
            loop {
                if saved.is_empty() {
                    // Back on (or never left) the starting node: done.
                    return;
                }
                if cursor.goto_next_sibling() {
                    break;
                }
                cursor.goto_parent();
                if let Some((saved_depth, saved_parent)) = saved.pop() {
                    depth = saved_depth;
                    parent = saved_parent;
                }
            }
        }
    }

    /// Builds the [`AstNode`] for a collected `node`, appends it to `nodes`, and
    /// records its index in the parent's `children` list (if the parent index
    /// is valid). Returns the new node's index.
    fn push_node(
        &self,
        node: &Node,
        source: &str,
        nodes: &mut Vec<AstNode>,
        depth: usize,
        parent: Option<usize>,
    ) -> usize {
        let node_index = nodes.len();
        let start = node.start_position();
        let end = node.end_position();
        nodes.push(AstNode {
            kind: node.kind().to_string(),
            name: self.extract_node_name(node, source),
            start: Position {
                line: start.row,
                column: start.column,
                offset: node.start_byte(),
            },
            end: Position {
                line: end.row,
                column: end.column,
                offset: node.end_byte(),
            },
            range: node.byte_range(),
            depth,
            parent,
            children: Vec::new(),
            metadata: self.extract_node_metadata(node, source),
        });
        if let Some(parent_idx) = parent {
            if let Some(parent_node) = nodes.get_mut(parent_idx) {
                parent_node.children.push(node_index);
            }
        }
        node_index
    }

    /// Returns `true` if `node` should become an entry in the output list.
    ///
    /// Only named nodes are considered. With a config, the kind must appear in
    /// `scope_types` or `context_types`; otherwise the default scope kinds
    /// (see `is_default_scope_node`) are used.
    fn should_collect_node(&self, node: &Node) -> bool {
        if !node.is_named() {
            return false;
        }
        let kind = node.kind();
        match &self.config {
            Some(config) => {
                config.scope_types.iter().any(|t| t == kind)
                    || config.context_types.iter().any(|t| t == kind)
            }
            None => self.is_default_scope_node(node),
        }
    }

    /// Extracts a display name for `node`.
    ///
    /// First tries each configured name field (default: `name`, then
    /// `identifier`) via `child_by_field_name`; otherwise falls back to the
    /// text of the first direct child whose kind is `identifier`. Byte ranges
    /// that do not slice cleanly out of `source` (e.g. `source` is not the
    /// parsed text) are skipped instead of panicking.
    fn extract_node_name(&self, node: &Node, source: &str) -> Option<String> {
        let from_field = match &self.config {
            Some(config) => config
                .name_fields
                .iter()
                .find_map(|field| Self::field_text(node, field, source)),
            None => ["name", "identifier"]
                .iter()
                .find_map(|field| Self::field_text(node, field, source)),
        };
        if let Some(name) = from_field {
            return Some(name.to_string());
        }

        // Fallback: first direct child of kind `identifier`.
        let mut cursor = node.walk();
        if cursor.goto_first_child() {
            loop {
                let child = cursor.node();
                if child.kind() == "identifier" {
                    if let Some(name) = source.get(child.byte_range()) {
                        return Some(name.to_string());
                    }
                }
                if !cursor.goto_next_sibling() {
                    break;
                }
            }
        }
        None
    }

    /// Returns the source text of `node`'s child in field `field`, or `None`
    /// when there is no such child or its byte range is not a valid slice of
    /// `source`.
    fn field_text<'s>(node: &Node, field: &str, source: &'s str) -> Option<&'s str> {
        let child = node.child_by_field_name(field)?;
        source.get(child.byte_range())
    }

    /// Classifies `node` (scope / definition / declaration) using the config's
    /// kind lists, or the built-in defaults when no config is set, and fills in
    /// visibility and documentation.
    fn extract_node_metadata(&self, node: &Node, source: &str) -> NodeMetadata {
        let mut metadata = NodeMetadata::default();
        let kind = node.kind();
        if let Some(config) = &self.config {
            metadata.is_scope = config.scope_types.iter().any(|t| t == kind);
            metadata.is_definition = config.definition_types.iter().any(|t| t == kind);
            metadata.is_declaration = config.declaration_types.iter().any(|t| t == kind);
        } else {
            metadata.is_scope = self.is_default_scope_node(node);
            metadata.is_definition = self.is_default_definition_node(node);
            metadata.is_declaration = self.is_default_declaration_node(node);
        }
        metadata.visibility = self.extract_visibility(node, source);
        metadata.documentation = self.extract_documentation(node, source);
        metadata
    }

    /// Built-in list of scope-like kinds used when no config is set; also the
    /// default set of collected kinds.
    fn is_default_scope_node(&self, node: &Node) -> bool {
        matches!(
            node.kind(),
            "function" | "method" | "class" | "struct" | "interface"
                | "function_declaration" | "function_definition"
                | "method_definition" | "class_declaration" | "class_definition"
                | "impl_item" | "function_item" | "struct_item" | "enum_item"
                | "trait_item" | "mod_item"
        )
    }

    /// Built-in list of definition kinds used when no config is set.
    fn is_default_definition_node(&self, node: &Node) -> bool {
        matches!(
            node.kind(),
            "function_definition" | "class_definition" | "struct_item"
                | "enum_item" | "trait_item" | "function_item"
        )
    }

    /// Built-in list of declaration kinds used when no config is set.
    fn is_default_declaration_node(&self, node: &Node) -> bool {
        matches!(
            node.kind(),
            "function_declaration" | "class_declaration" | "function_signature"
        )
    }

    /// Placeholder: visibility extraction is not implemented; always `None`.
    fn extract_visibility(&self, _node: &Node, _source: &str) -> Option<String> {
        None
    }

    /// Placeholder: documentation extraction is not implemented; always `None`.
    fn extract_documentation(&self, _node: &Node, _source: &str) -> Option<String> {
        None
    }
}
impl Default for AstWalker {
fn default() -> Self {
Self::new()
}
}