use crate::languages::LanguageParser;
use crate::node::{CodeNode, NodeKind, Visibility};
use tree_sitter::{Language, Node, Tree};
pub struct RustParser;
impl LanguageParser for RustParser {
fn language(&self) -> Language {
tree_sitter_rust::language()
}
fn extensions(&self) -> &[&str] {
&["rs"]
}
fn extract_nodes(&self, tree: &Tree, source: &str, file_path: &str) -> Vec<CodeNode> {
let mut nodes = Vec::new();
let root = tree.root_node();
extract_from_node(&root, source, file_path, &mut nodes, None);
nodes
}
}
fn extract_from_node(
node: &Node,
source: &str,
file_path: &str,
nodes: &mut Vec<CodeNode>,
context: Option<&str>,
) {
let kind = node.kind();
match kind {
"function_item" => {
if let Some(code_node) = extract_function(node, source, file_path, context) {
nodes.push(code_node);
}
}
"struct_item" => {
if let Some(code_node) = extract_struct(node, source, file_path) {
nodes.push(code_node);
}
}
"enum_item" => {
if let Some(code_node) = extract_enum(node, source, file_path) {
nodes.push(code_node);
}
}
"trait_item" => {
if let Some(code_node) = extract_trait(node, source, file_path) {
let trait_name = code_node.name.clone();
nodes.push(code_node);
if let Some(body) = find_child_by_kind(node, "declaration_list") {
for i in 0..body.child_count() {
if let Some(child) = body.child(i) {
extract_from_node(&child, source, file_path, nodes, Some(&trait_name));
}
}
}
return;
}
}
"impl_item" => {
let impl_target = get_impl_target(node, source);
if let Some(body) = find_child_by_kind(node, "declaration_list") {
for i in 0..body.child_count() {
if let Some(child) = body.child(i) {
extract_from_node(&child, source, file_path, nodes, impl_target.as_deref());
}
}
}
return;
}
"mod_item" => {
if let Some(code_node) = extract_module(node, source, file_path) {
nodes.push(code_node);
}
}
"use_declaration" => {
if let Some(code_node) = extract_use(node, source, file_path) {
nodes.push(code_node);
}
}
"const_item" | "static_item" => {
if let Some(code_node) = extract_const(node, source, file_path) {
nodes.push(code_node);
}
}
"type_item" => {
if let Some(code_node) = extract_type_alias(node, source, file_path) {
nodes.push(code_node);
}
}
_ => {}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
extract_from_node(&child, source, file_path, nodes, context);
}
}
}
fn extract_function(
node: &Node,
source: &str,
file_path: &str,
context: Option<&str>,
) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let kind = if context.is_some() {
NodeKind::Method
} else {
NodeKind::Function
};
let qualified_name = match context {
Some(ctx) => format!("{}.{}", ctx, name),
None => name.clone(),
};
let visibility = detect_visibility(node, source);
let is_async = has_modifier(node, "async");
let signature = build_function_signature(node, source, &name);
let references = extract_call_references(node, source);
Some(
CodeNode::new(&name, &qualified_name, kind, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_signature(signature)
.with_visibility(visibility)
.with_references(references)
.with_async_if(is_async),
)
}
fn extract_struct(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::Struct, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn extract_enum(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::Enum, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn extract_trait(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::Interface, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn extract_module(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::Module, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn extract_use(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
if let Some(arg) = node.child_by_field_name("argument") {
let path = get_text(&arg, source);
return Some(
CodeNode::new(&path, &path, NodeKind::Import, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32),
);
}
None
}
fn extract_const(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::Constant, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn extract_type_alias(node: &Node, source: &str, file_path: &str) -> Option<CodeNode> {
let name_node = node.child_by_field_name("name")?;
let name = get_text(&name_node, source);
let visibility = detect_visibility(node, source);
Some(
CodeNode::new(&name, &name, NodeKind::TypeAlias, file_path)
.with_lines(
node.start_position().row as u32 + 1,
node.end_position().row as u32 + 1,
)
.with_bytes(node.start_byte() as u32, node.end_byte() as u32)
.with_column(name_node.start_position().column as u32)
.with_visibility(visibility),
)
}
fn get_text(node: &Node, source: &str) -> String {
source[node.byte_range()].to_string()
}
fn find_child_by_kind<'a>(node: &'a Node, kind: &str) -> Option<Node<'a>> {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == kind {
return Some(child);
}
}
}
None
}
fn get_impl_target(node: &Node, source: &str) -> Option<String> {
if let Some(type_node) = node.child_by_field_name("type") {
return Some(get_text(&type_node, source));
}
None
}
fn detect_visibility(node: &Node, source: &str) -> Visibility {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "visibility_modifier" {
let text = get_text(&child, source);
if text == "pub" {
return Visibility::Public;
} else if text.contains("crate") || text.contains("super") {
return Visibility::Internal;
}
}
}
}
Visibility::Private
}
fn has_modifier(node: &Node, modifier: &str) -> bool {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == modifier {
return true;
}
}
}
false
}
fn build_function_signature(node: &Node, source: &str, name: &str) -> String {
let params = node
.child_by_field_name("parameters")
.map(|n| get_text(&n, source))
.unwrap_or_else(|| "()".to_string());
let return_type = node
.child_by_field_name("return_type")
.map(|n| get_text(&n, source))
.unwrap_or_default();
if return_type.is_empty() {
format!("fn {}{}", name, params)
} else {
format!("fn {}{} {}", name, params, return_type)
}
}
fn extract_call_references(node: &Node, source: &str) -> Vec<String> {
let mut refs = Vec::new();
collect_calls(node, source, &mut refs);
refs.sort();
refs.dedup();
refs
}
fn collect_calls(node: &Node, source: &str, refs: &mut Vec<String>) {
if node.kind() == "call_expression" {
if let Some(func_node) = node.child_by_field_name("function") {
let call_name = get_text(&func_node, source);
refs.push(call_name);
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
collect_calls(&child, source, refs);
}
}
}
trait CodeNodeExt {
fn with_async_if(self, cond: bool) -> Self;
}
impl CodeNodeExt for CodeNode {
fn with_async_if(self, cond: bool) -> Self {
if cond {
self.as_async()
} else {
self
}
}
}