use anyhow::Result;
use std::collections::HashMap;
use tree_sitter::{Parser, Node};
use std::path::Path;
/// A named item found in a source file, along with the items nested inside it.
#[derive(Clone, Debug)]
pub struct Symbol {
    /// Name of the item, or `"(anonymous)"` when none could be extracted.
    pub name: String,
    /// The category of item this symbol represents.
    pub kind: SymbolKind,
    /// Span of the whole item in the source.
    pub range: Range,
    /// Symbols nested inside this item, in the order they were encountered.
    pub children: Vec<Symbol>,
}
/// The category of a [`Symbol`].
///
/// A small fieldless enum: it is `Copy` and hashable so callers can pass it by value
/// and use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// A function (`fn`) item.
    Function,
    /// A `struct` item.
    Struct,
    /// An `enum` item.
    Enum,
    /// An `impl` block.
    Impl,
    /// A `mod` item.
    Mod,
    /// A `const` item.
    Const,
    /// A `static` item.
    Static,
    /// A `trait` item.
    Trait,
    /// A `type` alias item.
    Type,
    /// NOTE(review): not produced by the extraction code in this file; presumably
    /// reserved for variable bindings.
    Variable,
}
/// A source span built from tree-sitter node positions.
///
/// Lines are 1-based (tree-sitter's 0-based row plus one). Columns are tree-sitter
/// columns, i.e. 0-based byte offsets within the line. The end point is the position
/// just past the node's last character.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Range {
    /// 1-based line on which the span starts.
    pub start_line: usize,
    /// 0-based column at which the span starts.
    pub start_col: usize,
    /// 1-based line on which the span ends.
    pub end_line: usize,
    /// 0-based column of the end point.
    pub end_col: usize,
}
/// Parses Rust source with tree-sitter and remembers the symbols extracted per file.
pub struct SemanticAnalyzer {
    /// Symbols from the latest analysis of each file, keyed by the file path converted
    /// (lossily) to a string.
    symbol_table: HashMap<String, Vec<Symbol>>,
    /// Reusable tree-sitter parser, configured for the Rust grammar at construction.
    parser: Parser,
}
impl SemanticAnalyzer {
    /// Creates an analyzer whose parser is configured for the Rust grammar.
    ///
    /// # Errors
    ///
    /// Returns an error if the parser rejects the `tree_sitter_rust` language.
    pub fn new() -> Result<Self> {
        let mut parser = Parser::new();
        let language = tree_sitter_rust::LANGUAGE.into();
        parser.set_language(&language)?;
        Ok(Self {
            parser,
            symbol_table: HashMap::new(),
        })
    }

    /// Parses `source` as Rust, extracts its symbol hierarchy and caches it for `file_path`.
    ///
    /// The result is also stored in the analyzer, replacing any earlier result for the
    /// same path. Later queries through [`Self::get_symbols`] and
    /// [`Self::find_symbol_at_position`] read from that store.
    ///
    /// # Errors
    ///
    /// Returns an error if parsing yields no tree, or a node's text cannot be read from
    /// `source`.
    pub fn analyze_file(&mut self, file_path: &Path, source: &str) -> Result<Vec<Symbol>> {
        let tree = self
            .parser
            .parse(source, None)
            .ok_or_else(|| anyhow::anyhow!("Failed to parse file"))?;
        let symbols = self.extract_symbols(tree.root_node(), source)?;
        self.symbol_table
            .insert(Self::path_key(file_path), symbols.clone());
        Ok(symbols)
    }

    /// Returns the outermost symbols found below `node`, in source order.
    ///
    /// A child that is itself a recognised item becomes a [`Symbol`] carrying its own
    /// nested symbols. Any other child is searched transparently, because items are
    /// often wrapped in an intermediate node (e.g. the body of a `mod`, `impl` or
    /// `trait`). Without that descent, they would be missed.
    fn extract_symbols(&self, node: Node, source: &str) -> Result<Vec<Symbol>> {
        let mut symbols = Vec::new();
        self.collect_symbols(node, source, &mut symbols)?;
        Ok(symbols)
    }

    /// Appends to `out` every outermost item found among the descendants of `node`.
    fn collect_symbols(&self, node: Node, source: &str, out: &mut Vec<Symbol>) -> Result<()> {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            match self.node_to_symbol(child, source)? {
                Some(symbol) => out.push(symbol),
                // Not an item itself: look through it for items nested deeper.
                None => self.collect_symbols(child, source, out)?,
            }
        }
        Ok(())
    }

    /// Converts `node` into a [`Symbol`] if it is one of the recognised item kinds.
    ///
    /// Returns `Ok(None)` for every other node kind.
    fn node_to_symbol(&self, node: Node, source: &str) -> Result<Option<Symbol>> {
        let kind = match node.kind() {
            // Signatures cover body-less `fn` declarations in traits and extern blocks.
            "function_item" | "function_signature_item" => SymbolKind::Function,
            "struct_item" => SymbolKind::Struct,
            "enum_item" => SymbolKind::Enum,
            "impl_item" => SymbolKind::Impl,
            "mod_item" => SymbolKind::Mod,
            "const_item" => SymbolKind::Const,
            "static_item" => SymbolKind::Static,
            "trait_item" => SymbolKind::Trait,
            "type_item" => SymbolKind::Type,
            _ => return Ok(None),
        };
        let name = self.extract_name(node, source)?;
        let (start, end) = (node.start_position(), node.end_position());
        let range = Range {
            // tree-sitter rows are 0-based; expose 1-based line numbers.
            start_line: start.row + 1,
            start_col: start.column,
            end_line: end.row + 1,
            end_col: end.column,
        };
        let children = self.extract_symbols(node, source)?;
        Ok(Some(Symbol {
            name,
            kind,
            range,
            children,
        }))
    }

    /// Returns a display name for an item node.
    ///
    /// The name is chosen in this order:
    /// - the grammar's `name` field, when the item has one;
    /// - for `impl` blocks, which have no name, the self type (`Type`), or
    ///   `Trait for Type` for trait impls;
    /// - the first identifier child;
    /// - finally `"(anonymous)"`.
    fn extract_name(&self, node: Node, source: &str) -> Result<String> {
        if let Some(name) = node.child_by_field_name("name") {
            return Ok(Self::node_text(name, source)?.to_string());
        }
        if node.kind() == "impl_item" {
            if let Some(self_ty) = node.child_by_field_name("type") {
                let self_ty = Self::node_text(self_ty, source)?;
                return Ok(match node.child_by_field_name("trait") {
                    Some(tr) => format!("{} for {}", Self::node_text(tr, source)?, self_ty),
                    None => self_ty.to_string(),
                });
            }
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            // Struct/enum/trait/type names are `type_identifier`, not `identifier`.
            if matches!(child.kind(), "identifier" | "type_identifier") {
                return Ok(Self::node_text(child, source)?.to_string());
            }
        }
        Ok("(anonymous)".to_string())
    }

    /// Returns the slice of `source` spanned by `node`.
    ///
    /// Unlike raw byte-range slicing, this reports an error instead of panicking if the
    /// node's byte range does not fall on UTF-8 boundaries of `source`.
    fn node_text<'s>(node: Node, source: &'s str) -> Result<&'s str> {
        Ok(node.utf8_text(source.as_bytes())?)
    }

    /// Returns the innermost cached symbol of `file_path` whose range contains the position.
    ///
    /// `line` is 1-based and `column` is 0-based, matching [`Range`]. Returns `None` if
    /// the file has not been analyzed or no symbol covers the position.
    pub fn find_symbol_at_position(
        &self,
        file_path: &Path,
        line: usize,
        column: usize,
    ) -> Option<&Symbol> {
        let symbols = self.get_symbols(file_path)?;
        self.find_symbol_in_tree(symbols, line, column)
    }

    /// Picks the first symbol in `symbols` that contains the position, then descends into
    /// its children so the deepest match wins.
    fn find_symbol_in_tree<'a>(
        &self,
        symbols: &'a [Symbol],
        line: usize,
        column: usize,
    ) -> Option<&'a Symbol> {
        let outer = symbols
            .iter()
            .find(|symbol| self.contains_position(&symbol.range, line, column))?;
        Some(
            self.find_symbol_in_tree(&outer.children, line, column)
                .unwrap_or(outer),
        )
    }

    /// Reports whether `(line, column)` lies within `range`.
    ///
    /// Both endpoints are inclusive, so a position exactly at the end point (just past
    /// the node's last character) still counts as inside. Tuples compare
    /// lexicographically, which gives the line-then-column ordering needed here.
    fn contains_position(&self, range: &Range, line: usize, column: usize) -> bool {
        let pos = (line, column);
        (range.start_line, range.start_col) <= pos && pos <= (range.end_line, range.end_col)
    }

    /// Returns the symbols cached for `file_path` by a previous [`Self::analyze_file`] call.
    pub fn get_symbols(&self, file_path: &Path) -> Option<&Vec<Symbol>> {
        // Borrow the converted path as `&str`; no owned key is needed for a lookup.
        self.symbol_table.get(&*file_path.to_string_lossy())
    }

    /// Key under which a file's symbols are stored in the symbol table.
    ///
    /// NOTE(review): the lossy conversion means distinct non-UTF-8 paths can collide.
    fn path_key(file_path: &Path) -> String {
        file_path.to_string_lossy().into_owned()
    }

    /// Parses `source` and collects every JSX element whose tag name starts with `dx`.
    ///
    /// NOTE(review): the parser is configured with the Rust grammar in [`Self::new`],
    /// which does not appear to produce `jsx_*` nodes. This therefore likely always
    /// returns an empty list — confirm whether a JSX/TSX grammar was intended here.
    ///
    /// # Errors
    ///
    /// Returns an error if parsing yields no tree, or a node's text cannot be read.
    pub fn detect_dx_patterns(&mut self, source: &str) -> Result<Vec<DxPattern>> {
        let tree = self
            .parser
            .parse(source, None)
            .ok_or_else(|| anyhow::anyhow!("Failed to parse source"))?;
        let mut patterns = Vec::new();
        self.find_dx_elements(tree.root_node(), source, &mut patterns)?;
        Ok(patterns)
    }

    /// Recursively walks `node`, pushing a [`DxPattern`] for each `dx`-named JSX element.
    ///
    /// Matched elements are still descended into, so nested `dx` elements are found too.
    /// The reported position is that of the element node (1-based line, 0-based column).
    fn find_dx_elements(
        &self,
        node: Node,
        source: &str,
        patterns: &mut Vec<DxPattern>,
    ) -> Result<()> {
        if matches!(node.kind(), "jsx_element" | "jsx_self_closing_element") {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if !matches!(child.kind(), "jsx_opening_element" | "identifier") {
                    continue;
                }
                // Tag name: drop leading '<', stop at whitespace or '>'.
                let component_name = Self::node_text(child, source)?
                    .trim_start_matches('<')
                    .split(|c: char| c.is_whitespace() || c == '>')
                    .next()
                    .unwrap_or("");
                if component_name.starts_with("dx") {
                    patterns.push(DxPattern {
                        component_name: component_name.to_string(),
                        line: node.start_position().row + 1,
                        col: node.start_position().column,
                    });
                }
            }
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.find_dx_elements(child, source, patterns)?;
        }
        Ok(())
    }
}
/// A `dx`-prefixed component occurrence reported by `SemanticAnalyzer::detect_dx_patterns`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DxPattern {
    /// Tag name of the component (always starts with `dx`).
    pub component_name: String,
    /// 1-based line of the element.
    pub line: usize,
    /// 0-based column of the element.
    pub col: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Building an analyzer succeeds once the Rust grammar is available.
    #[test]
    fn test_analyzer_creation() {
        assert!(SemanticAnalyzer::new().is_ok());
    }

    /// Top-level functions and structs are both reported.
    #[test]
    fn test_rust_parsing() {
        let source = r#"
fn main() {
println!("Hello");
}
struct MyStruct {
field: i32,
}
"#;
        let mut analyzer = SemanticAnalyzer::new().unwrap();
        let symbols = analyzer
            .analyze_file(Path::new("test.rs"), source)
            .unwrap();
        let has_kind = |wanted: SymbolKind| symbols.iter().any(|s| s.kind == wanted);

        assert!(!symbols.is_empty());
        assert!(has_kind(SymbolKind::Function));
        assert!(has_kind(SymbolKind::Struct));
    }

    /// Items declared inside a module appear as that module's children.
    #[test]
    fn test_nested_symbols() {
        let source = r#"
mod my_mod {
struct Inner {
x: i32
}
impl Inner {
fn new() -> Self { Self { x: 0 } }
}
}
"#;
        let mut analyzer = SemanticAnalyzer::new().unwrap();
        let symbols = analyzer
            .analyze_file(Path::new("nested.rs"), source)
            .unwrap();

        let module = symbols
            .iter()
            .find(|s| s.kind == SymbolKind::Mod)
            .expect("top-level mod should be reported");
        assert_eq!(module.name, "my_mod");
        assert!(!module.children.is_empty());

        let inner = module
            .children
            .iter()
            .find(|s| s.kind == SymbolKind::Struct)
            .expect("struct inside the mod should be reported");
        assert_eq!(inner.name, "Inner");
    }

    /// Positions inside a symbol resolve to it; positions past the file resolve to nothing.
    #[test]
    fn test_find_symbol_at_position() {
        let source = r#"
fn target_function() {
// code
}
"#;
        let path = Path::new("lookup.rs");
        let mut analyzer = SemanticAnalyzer::new().unwrap();
        analyzer.analyze_file(path, source).unwrap();

        // Line 1 is the empty line right after the opening `r#"`.
        let hit = analyzer
            .find_symbol_at_position(path, 2, 15)
            .expect("position inside the function should resolve");
        assert_eq!(hit.name, "target_function");

        assert!(analyzer.find_symbol_at_position(path, 10, 0).is_none());
    }
}