use std::collections::HashMap;
use syn::visit::Visit;
use crate::ast::RustAST;
/// Where a symbol is defined or referenced.
///
/// Only the symbol's name is recorded; no file, line, or column information is
/// tracked by this type.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Location {
    /// Name of the symbol this location refers to.
    pub name: String,
}

impl Location {
    /// Builds a location for the symbol called `name` (the string is copied).
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Location { name }
    }
}
/// A named symbol found in the analyzed file, together with where it was defined
/// and the uses recorded for it.
#[derive(Debug, Clone)]
pub struct Symbol {
    /// Identifier text; the collector also uses it as the symbol's key in
    /// `SymbolTable::symbols`.
    pub name: String,
    /// What construct introduced the name (function, local variable, struct, ...).
    pub kind: SymbolKind,
    /// Location recorded for the most recent definition of the name.
    pub definition: Location,
    /// Recorded uses of the name, in the order the visitor encountered them.
    pub references: Vec<Location>,
}
/// The kind of construct that introduced a symbol's name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymbolKind {
    /// A binding introduced by a local pattern (e.g. `let`).
    LocalVar,
    /// A binding introduced by a parameter pattern (e.g. a function parameter).
    Parameter,
    /// A free function item (`fn`).
    Function,
    /// A `struct` item.
    Struct,
    /// An `enum` item.
    Enum,
    /// A `const` item; `static` items are recorded with this kind as well.
    Const,
    /// A `type` alias item.
    TypeAlias,
    /// An `impl` block.
    /// NOTE(review): nothing in the visible collector assigns this kind — confirm
    /// whether it is still needed.
    Impl,
}
/// Namespace for the definition/reference analysis entry points; holds no state.
///
/// Every method performs a complete walk of the file. When looking up several names,
/// call [`DefRefs::analyze`] once and query the returned table instead.
pub struct DefRefs;

impl DefRefs {
    /// Walks the whole file in `ast` once and returns every symbol found, keyed by name.
    pub fn analyze(ast: &RustAST) -> SymbolTable {
        let mut collector = SymbolCollector::new();
        collector.visit_file(ast.file());
        collector.table
    }

    /// Returns the symbol recorded under `name`, or `None` if the file never defines it.
    pub fn find_definition(ast: &RustAST, name: &str) -> Option<Symbol> {
        // The table is a temporary, so move the entry out rather than cloning it
        // (a clone would also deep-copy its references `Vec`).
        Self::analyze(ast).symbols.remove(name)
    }

    /// Returns the references recorded for `name`; empty if the name is never defined.
    pub fn find_references(ast: &RustAST, name: &str) -> Vec<Location> {
        Self::find_definition(ast, name)
            .map(|symbol| symbol.references)
            .unwrap_or_default()
    }
}
/// Name-keyed collection of every symbol found during analysis.
#[derive(Debug, Default)]
pub struct SymbolTable {
    /// All symbols, keyed by name; a name maps to at most one entry.
    pub symbols: HashMap<String, Symbol>,
}

impl SymbolTable {
    /// Returns every symbol whose kind equals `kind`.
    ///
    /// The order follows the map's iteration order, which is unspecified.
    pub fn by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
        let mut matching = Vec::new();
        for symbol in self.symbols.values() {
            if symbol.kind == kind {
                matching.push(symbol);
            }
        }
        matching
    }

    /// Shorthand for `by_kind(SymbolKind::Function)`.
    pub fn functions(&self) -> Vec<&Symbol> {
        self.by_kind(SymbolKind::Function)
    }

    /// Shorthand for `by_kind(SymbolKind::LocalVar)`.
    pub fn local_vars(&self) -> Vec<&Symbol> {
        self.by_kind(SymbolKind::LocalVar)
    }
}
/// AST visitor that fills a [`SymbolTable`] while tracking nested scopes.
struct SymbolCollector {
    /// Every symbol defined so far, keyed by name.
    table: SymbolTable,
    /// Stack of open scopes, innermost last; each maps the local variables and
    /// parameters declared in that scope to their definition location.
    scopes: Vec<HashMap<String, Location>>,
}
impl SymbolCollector {
    /// Creates a collector with an empty symbol table and one outermost scope.
    fn new() -> Self {
        Self {
            table: SymbolTable::default(),
            scopes: vec![HashMap::new()],
        }
    }

    /// Opens a nested scope for local variables and parameters.
    fn enter_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    /// Closes the innermost scope, forgetting the locals/parameters it declared.
    fn exit_scope(&mut self) {
        self.scopes.pop();
    }

    /// Records a definition of `name` with the given `kind`.
    ///
    /// Local variables and parameters are also registered in the innermost scope, so
    /// [`Self::is_defined`] only resolves them while that scope is open. The table is
    /// keyed by name alone. Redefining a name (e.g. shadowing with a second `let`)
    /// updates its kind and definition but keeps the references already recorded.
    fn define_symbol(&mut self, name: &str, kind: SymbolKind) {
        let loc = Location::new(name);
        if matches!(kind, SymbolKind::LocalVar | SymbolKind::Parameter) {
            if let Some(scope) = self.scopes.last_mut() {
                scope.insert(name.to_string(), loc.clone());
            }
        }
        if let Some(existing) = self.table.symbols.get_mut(name) {
            // Do not replace the entry wholesale. For `let x = 1; f(x); let x = x + 1;`,
            // that would erase both earlier uses of `x`.
            existing.kind = kind;
            existing.definition = loc;
        } else {
            self.table.symbols.insert(
                name.to_string(),
                Symbol {
                    name: name.to_string(),
                    kind,
                    definition: loc,
                    references: Vec::new(),
                },
            );
        }
    }

    /// Appends a reference to `name` if it has a table entry; unknown names are ignored.
    fn add_reference(&mut self, name: &str) {
        if let Some(symbol) = self.table.symbols.get_mut(name) {
            symbol.references.push(Location::new(name));
        }
    }

    /// Returns whether `name` currently resolves to a known symbol.
    ///
    /// Locals and parameters resolve only while a scope that declared them is open.
    /// Item-like symbols (functions, types, consts, ...) resolve from anywhere once
    /// they have been defined.
    fn is_defined(&self, name: &str) -> bool {
        if self.scopes.iter().rev().any(|scope| scope.contains_key(name)) {
            return true;
        }
        // Every defined name is also in the table. Falling back to it unconditionally
        // would make the scope stack meaningless and let a local from one function
        // match a use in another.
        match self.table.symbols.get(name) {
            Some(symbol) => !matches!(symbol.kind, SymbolKind::LocalVar | SymbolKind::Parameter),
            None => false,
        }
    }

    /// Defines every identifier bound by `pat` as a symbol of `kind`, recursing through
    /// nested patterns.
    fn define_from_pat(&mut self, pat: &syn::Pat, kind: SymbolKind) {
        match pat {
            syn::Pat::Ident(pat_ident) => {
                self.define_symbol(&pat_ident.ident.to_string(), kind);
                // `name @ SUBPATTERN` can bind further names inside the subpattern.
                if let Some((_at, subpat)) = &pat_ident.subpat {
                    self.define_from_pat(subpat, kind);
                }
            }
            syn::Pat::Tuple(pat_tuple) => {
                for elem in &pat_tuple.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            syn::Pat::TupleStruct(pat_tuple_struct) => {
                for elem in &pat_tuple_struct.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            syn::Pat::Struct(pat_struct) => {
                for field in &pat_struct.fields {
                    self.define_from_pat(&field.pat, kind);
                }
            }
            syn::Pat::Reference(pat_ref) => {
                self.define_from_pat(&pat_ref.pat, kind);
            }
            syn::Pat::Type(pat_type) => {
                self.define_from_pat(&pat_type.pat, kind);
            }
            syn::Pat::Paren(pat_paren) => {
                self.define_from_pat(&pat_paren.pat, kind);
            }
            syn::Pat::Or(pat_or) => {
                for case in &pat_or.cases {
                    self.define_from_pat(case, kind);
                }
            }
            syn::Pat::Slice(pat_slice) => {
                for elem in &pat_slice.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            // Remaining forms (literals, ranges, paths, wildcards, `..`, macros, ...)
            // are treated as binding nothing.
            _ => {}
        }
    }
}
impl<'ast> Visit<'ast> for SymbolCollector {
    /// Defines a free function as [`SymbolKind::Function`], then walks its signature
    /// and body with its parameters in scope.
    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
        self.define_symbol(&node.sig.ident.to_string(), SymbolKind::Function);
        self.walk_fn(&node.sig, &node.block);
    }

    /// Walks a method with its parameters defined, so uses of them in the body are
    /// recorded. The method name itself is not added to the table.
    fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
        self.walk_fn(&node.sig, &node.block);
    }

    /// Walks a trait method like an impl method. A method without a default body only
    /// has its signature walked.
    fn visit_trait_item_fn(&mut self, node: &'ast syn::TraitItemFn) {
        match &node.default {
            Some(body) => self.walk_fn(&node.sig, body),
            None => self.visit_signature(&node.sig),
        }
    }

    /// Handles `let PAT (: TYPE)? (= INIT (else DIVERGE)?)?;`.
    ///
    /// The initializer and the `else` branch are walked before the pattern's bindings
    /// are defined, because neither can see those bindings (`let x = x + 1;` reads the
    /// previous `x`).
    fn visit_local(&mut self, node: &'ast syn::Local) {
        if let Some(init) = &node.init {
            self.visit_expr(&init.expr);
            // let-else: references inside the diverging block must be walked too.
            if let Some((_else_token, diverge)) = &init.diverge {
                self.visit_expr(diverge);
            }
        }
        // Type ascriptions and constant sub-patterns can reference outer symbols
        // (e.g. `let buf: [u8; LEN]`).
        self.visit_pat(&node.pat);
        self.define_from_pat(&node.pat, SymbolKind::LocalVar);
    }

    /// Defines closure parameters in a scope that covers only the closure body.
    fn visit_expr_closure(&mut self, node: &'ast syn::ExprClosure) {
        self.enter_scope();
        for input in &node.inputs {
            self.visit_pat(input);
            self.define_from_pat(input, SymbolKind::Parameter);
        }
        self.visit_return_type(&node.output);
        self.visit_expr(&node.body);
        self.exit_scope();
    }

    /// Walks the iterated expression outside the loop, then defines the loop pattern's
    /// bindings in a scope shared with the loop body.
    fn visit_expr_for_loop(&mut self, node: &'ast syn::ExprForLoop) {
        self.visit_expr(&node.expr);
        self.enter_scope();
        self.visit_pat(&node.pat);
        self.define_from_pat(&node.pat, SymbolKind::LocalVar);
        // Default block walk so the body shares the loop-binding scope instead of
        // opening another one through the overridden `visit_block`.
        syn::visit::visit_block(self, &node.body);
        self.exit_scope();
    }

    /// Records a reference for an unqualified single-segment path (`x`, `foo`,
    /// `foo::<T>`) that currently resolves to a known symbol.
    fn visit_expr_path(&mut self, node: &'ast syn::ExprPath) {
        // `<T>::name` also has a single segment, but it resolves through `T`, not to a
        // symbol defined in this file.
        if node.qself.is_none() && node.path.segments.len() == 1 {
            let name = node.path.segments[0].ident.to_string();
            if self.is_defined(&name) {
                self.add_reference(&name);
            }
        }
        syn::visit::visit_expr_path(self, node);
    }

    /// Defines a `struct` item, then walks its contents.
    fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Struct);
        syn::visit::visit_item_struct(self, node);
    }

    /// Defines an `enum` item, then walks its contents.
    fn visit_item_enum(&mut self, node: &'ast syn::ItemEnum) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Enum);
        syn::visit::visit_item_enum(self, node);
    }

    /// Defines a `const` item, then walks its type and value.
    fn visit_item_const(&mut self, node: &'ast syn::ItemConst) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
        syn::visit::visit_item_const(self, node);
    }

    /// Defines a `static` item. There is no dedicated kind, so it is recorded as
    /// [`SymbolKind::Const`].
    fn visit_item_static(&mut self, node: &'ast syn::ItemStatic) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
        syn::visit::visit_item_static(self, node);
    }

    /// Defines a `type` alias, then walks the aliased type.
    fn visit_item_type(&mut self, node: &'ast syn::ItemType) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::TypeAlias);
        syn::visit::visit_item_type(self, node);
    }

    /// Every `{ ... }` block opens its own scope for the locals declared inside it.
    fn visit_block(&mut self, node: &'ast syn::Block) {
        self.enter_scope();
        syn::visit::visit_block(self, node);
        self.exit_scope();
    }
}

impl SymbolCollector {
    /// Shared walk for function-like items. The signature is visited in the enclosing
    /// scope; the parameters are then defined in a new scope that the body shares.
    fn walk_fn<'ast>(&mut self, sig: &'ast syn::Signature, body: &'ast syn::Block) {
        // Signature types can reference consts, e.g. `fn f(buf: [u8; LEN])`.
        self.visit_signature(sig);
        self.enter_scope();
        for param in &sig.inputs {
            if let syn::FnArg::Typed(pat_type) = param {
                self.define_from_pat(&pat_type.pat, SymbolKind::Parameter);
            }
        }
        // Call the default block walk (not the overridden `visit_block`) so the body
        // does not open a second scope separate from the parameters'.
        syn::visit::visit_block(self, body);
        self.exit_scope();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_function_def() {
        let src = r#"
fn hello() {}
fn world() {}
"#;
        let table = DefRefs::analyze(&RustAST::parse(src).unwrap());
        for expected in ["hello", "world"] {
            assert!(table.symbols.contains_key(expected));
        }
        assert_eq!(table.functions().len(), 2);
    }

    #[test]
    fn test_find_local_var() {
        let src = r#"
fn main() {
let x = 1;
let y = 2;
}
"#;
        let table = DefRefs::analyze(&RustAST::parse(src).unwrap());
        for expected in ["x", "y"] {
            assert!(table.symbols.contains_key(expected));
        }
    }

    #[test]
    fn test_find_references() {
        let src = r#"
fn main() {
let x = 1;
let y = x + 1;
let z = x + y;
}
"#;
        let ast = RustAST::parse(src).unwrap();
        assert_eq!(DefRefs::find_references(&ast, "x").len(), 2);
    }

    #[test]
    fn test_struct_definition() {
        let src = r#"
struct Point {
x: i32,
y: i32,
}
"#;
        let table = DefRefs::analyze(&RustAST::parse(src).unwrap());
        let point = table.symbols.get("Point");
        assert!(point.is_some());
        assert_eq!(point.map(|symbol| symbol.kind), Some(SymbolKind::Struct));
    }

    #[test]
    fn test_symbol_table_by_kind() {
        let src = r#"
struct Foo {}
enum Bar {}
fn baz() {
let x = 1;
}
"#;
        let table = DefRefs::analyze(&RustAST::parse(src).unwrap());
        let expectations: [(SymbolKind, usize); 3] = [
            (SymbolKind::Struct, 1),
            (SymbolKind::Enum, 1),
            (SymbolKind::Function, 1),
        ];
        for (kind, count) in expectations {
            assert_eq!(table.by_kind(kind).len(), count);
        }
    }
}