ryo-source 0.1.0

High-speed Rust AST manipulation engine
Documentation
//! Definition and Reference analysis.
//!
//! Query operation to find where symbols are defined and referenced.

use std::collections::HashMap;
use syn::visit::Visit;

use crate::ast::RustAST;

/// Where a symbol appears in source code.
///
/// This is deliberately minimal: only the symbol's name is stored, with no
/// span or file position.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Location {
    /// Name of the symbol at this location.
    pub name: String,
}

impl Location {
    /// Build a `Location` for the symbol called `name`.
    pub fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}

/// A named program entity (variable, parameter, function, type, ...) together
/// with its definition site and every place it is referenced.
#[derive(Clone, Debug)]
pub struct Symbol {
    /// The symbol's identifier.
    pub name: String,
    /// What sort of entity this symbol is.
    pub kind: SymbolKind,
    /// Where the symbol was defined.
    pub definition: Location,
    /// Every recorded use of the symbol, in the order they were visited.
    pub references: Vec<Location>,
}

/// Category of a [`Symbol`].
///
/// `Hash` is derived (alongside `Eq`) so kinds can be used as keys in hash
/// maps/sets, e.g. to group symbols by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SymbolKind {
    /// Variable bound by a `let` statement.
    LocalVar,
    /// Variable bound by a function parameter.
    Parameter,
    /// Function definition (`fn` item).
    Function,
    /// Struct definition.
    Struct,
    /// Enum definition.
    Enum,
    /// `const` or `static` item; both are reported with this one kind.
    Const,
    /// `type` alias.
    TypeAlias,
    /// Impl block (for a type).
    ///
    /// NOTE: the symbol collector in this module does not currently emit it.
    Impl,
}

/// Entry point for definition/reference queries over a parsed file.
///
/// Every query below re-runs the full analysis. Callers issuing many queries
/// against the same AST should call [`DefRefs::analyze`] once and consult the
/// returned [`SymbolTable`] directly.
pub struct DefRefs;

impl DefRefs {
    /// Collect every symbol definition in `ast` together with its references.
    pub fn analyze(ast: &RustAST) -> SymbolTable {
        let mut collector = SymbolCollector::new();
        collector.visit_file(ast.file());
        collector.table
    }

    /// Look up the definition of the symbol named `name`.
    ///
    /// Returns `None` when no definition with that name was recorded. The
    /// table is keyed by name, so if `name` is defined more than once, only
    /// the last definition visited is returned.
    pub fn find_definition(ast: &RustAST, name: &str) -> Option<Symbol> {
        // The table is a temporary we own; move the entry out rather than cloning it.
        Self::analyze(ast).symbols.remove(name)
    }

    /// Return every recorded reference to the symbol named `name`.
    ///
    /// Returns an empty vector when `name` has no recorded definition.
    pub fn find_references(ast: &RustAST, name: &str) -> Vec<Location> {
        // Same as above: take ownership of the references instead of cloning them.
        Self::analyze(ast)
            .symbols
            .remove(name)
            .map(|symbol| symbol.references)
            .unwrap_or_default()
    }
}

/// Every symbol discovered in one file, keyed by name.
#[derive(Debug, Default)]
pub struct SymbolTable {
    /// Symbols keyed by their name.
    pub symbols: HashMap<String, Symbol>,
}

impl SymbolTable {
    /// Return every symbol whose kind equals `kind`.
    ///
    /// The order follows the underlying hash map's iteration order and is
    /// therefore unspecified.
    pub fn by_kind(&self, kind: SymbolKind) -> Vec<&Symbol> {
        let mut matching = Vec::new();
        for symbol in self.symbols.values() {
            if symbol.kind == kind {
                matching.push(symbol);
            }
        }
        matching
    }

    /// Shorthand for `by_kind(SymbolKind::Function)`.
    pub fn functions(&self) -> Vec<&Symbol> {
        Self::by_kind(self, SymbolKind::Function)
    }

    /// Shorthand for `by_kind(SymbolKind::LocalVar)`.
    pub fn local_vars(&self) -> Vec<&Symbol> {
        Self::by_kind(self, SymbolKind::LocalVar)
    }
}

/// Visitor state that records symbol definitions and references.
struct SymbolCollector {
    /// Symbols found so far, keyed by name.
    table: SymbolTable,
    /// Stack of lexical scopes, innermost last. Only local variables and
    /// parameters are registered here.
    scopes: Vec<HashMap<String, Location>>,
}

impl SymbolCollector {
    /// Create a collector with an empty table and a single global scope.
    fn new() -> Self {
        Self {
            table: SymbolTable::default(),
            scopes: vec![HashMap::new()], // Global scope
        }
    }

    /// Push a fresh, empty scope.
    fn enter_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    /// Pop the innermost scope (a no-op if the stack is already empty).
    fn exit_scope(&mut self) {
        self.scopes.pop();
    }

    /// Record a definition of `name` with the given `kind`.
    ///
    /// Locals and parameters are also registered in the innermost scope.
    /// The table is keyed by name, so redefining `name` (shadowing, or the
    /// same name in another function) replaces the earlier entry and discards
    /// the references accumulated for it.
    fn define_symbol(&mut self, name: &str, kind: SymbolKind) {
        let loc = Location::new(name);

        // Add to current scope for locals
        if matches!(kind, SymbolKind::LocalVar | SymbolKind::Parameter) {
            if let Some(scope) = self.scopes.last_mut() {
                scope.insert(name.to_string(), loc.clone());
            }
        }

        // Add to symbol table
        self.table.symbols.insert(
            name.to_string(),
            Symbol {
                name: name.to_string(),
                kind,
                definition: loc,
                references: vec![],
            },
        );
    }

    /// Record a reference to `name` if it has a table entry; otherwise do nothing.
    fn add_reference(&mut self, name: &str) {
        let loc = Location::new(name);
        if let Some(symbol) = self.table.symbols.get_mut(name) {
            symbol.references.push(loc);
        }
    }

    /// Whether `name` is visible in an open scope or has a table entry.
    ///
    /// NOTE(review): table entries are never removed when a scope closes, so
    /// the table check subsumes the scope lookup. Any name defined earlier,
    /// anywhere in the file, counts as defined.
    fn is_defined(&self, name: &str) -> bool {
        self.scopes.iter().rev().any(|s| s.contains_key(name))
            || self.table.symbols.contains_key(name)
    }

    /// Define every variable bound by `pat` with the given `kind`.
    ///
    /// Recurses through composite patterns. Pattern kinds not listed below
    /// (literals, wildcards, ranges, paths, rest, macros, ...) are treated as
    /// binding nothing.
    fn define_from_pat(&mut self, pat: &syn::Pat, kind: SymbolKind) {
        match pat {
            syn::Pat::Ident(pat_ident) => {
                self.define_symbol(&pat_ident.ident.to_string(), kind);
                // `name @ subpattern` binds `name` *and* everything inside the subpattern.
                if let Some((_, subpat)) = &pat_ident.subpat {
                    self.define_from_pat(subpat, kind);
                }
            }
            syn::Pat::Tuple(pat_tuple) => {
                for elem in &pat_tuple.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            syn::Pat::TupleStruct(pat_tuple_struct) => {
                for elem in &pat_tuple_struct.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            syn::Pat::Struct(pat_struct) => {
                for field in &pat_struct.fields {
                    self.define_from_pat(&field.pat, kind);
                }
            }
            syn::Pat::Reference(pat_ref) => {
                self.define_from_pat(&pat_ref.pat, kind);
            }
            syn::Pat::Type(pat_type) => {
                self.define_from_pat(&pat_type.pat, kind);
            }
            syn::Pat::Paren(pat_paren) => {
                // `(pattern)`: grouping only, so recurse into the inner pattern.
                self.define_from_pat(&pat_paren.pat, kind);
            }
            syn::Pat::Or(pat_or) => {
                for case in &pat_or.cases {
                    self.define_from_pat(case, kind);
                }
            }
            syn::Pat::Slice(pat_slice) => {
                for elem in &pat_slice.elems {
                    self.define_from_pat(elem, kind);
                }
            }
            _ => {}
        }
    }
}

impl<'ast> Visit<'ast> for SymbolCollector {
    /// Define a function item, then its typed parameters, then walk its body.
    ///
    /// The body is walked with `syn::visit::visit_block` rather than
    /// `self.visit_block`, so it shares the parameter scope instead of pushing
    /// a second one.
    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
        // Define the function
        self.define_symbol(&node.sig.ident.to_string(), SymbolKind::Function);

        // Enter function scope
        self.enter_scope();

        // Define parameters
        for param in &node.sig.inputs {
            if let syn::FnArg::Typed(pat_type) = param {
                self.define_from_pat(&pat_type.pat, SymbolKind::Parameter);
            }
        }

        // Visit function body
        syn::visit::visit_block(self, &node.block);

        self.exit_scope();
    }

    /// Handle a `let` statement.
    ///
    /// The initializer and, for `let ... else`, the diverging `else` block
    /// are visited before the pattern's bindings are defined. Neither of them
    /// can refer to the variables the pattern introduces.
    fn visit_local(&mut self, node: &'ast syn::Local) {
        if let Some(init) = &node.init {
            self.visit_expr(&init.expr);
            // References inside a `let-else` block must be collected too.
            if let Some((_, diverge)) = &init.diverge {
                self.visit_expr(diverge);
            }
        }

        // Define local variable(s) from pattern
        self.define_from_pat(&node.pat, SymbolKind::LocalVar);
    }

    /// Count a single-segment path expression as a reference when its name is
    /// already known, then continue walking the path.
    fn visit_expr_path(&mut self, node: &'ast syn::ExprPath) {
        // Multi-segment paths (`a::b`) are not resolved.
        if node.path.segments.len() == 1 {
            let name = node.path.segments[0].ident.to_string();
            if self.is_defined(&name) {
                self.add_reference(&name);
            }
        }
        syn::visit::visit_expr_path(self, node);
    }

    /// Define a struct item, then walk its contents.
    fn visit_item_struct(&mut self, node: &'ast syn::ItemStruct) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Struct);
        syn::visit::visit_item_struct(self, node);
    }

    /// Define an enum item, then walk its contents.
    fn visit_item_enum(&mut self, node: &'ast syn::ItemEnum) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Enum);
        syn::visit::visit_item_enum(self, node);
    }

    /// Define a `const` item, then walk its type and value.
    fn visit_item_const(&mut self, node: &'ast syn::ItemConst) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
        syn::visit::visit_item_const(self, node);
    }

    /// Define a `static` item (reported as `SymbolKind::Const`), then walk it.
    fn visit_item_static(&mut self, node: &'ast syn::ItemStatic) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::Const);
        syn::visit::visit_item_static(self, node);
    }

    /// Define a `type` alias, then walk it.
    fn visit_item_type(&mut self, node: &'ast syn::ItemType) {
        self.define_symbol(&node.ident.to_string(), SymbolKind::TypeAlias);
        syn::visit::visit_item_type(self, node);
    }

    /// Walk any nested block inside its own scope.
    fn visit_block(&mut self, node: &'ast syn::Block) {
        self.enter_scope();
        syn::visit::visit_block(self, node);
        self.exit_scope();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_function_def() {
        let ast = RustAST::parse(
            r#"
            fn hello() {}
            fn world() {}
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        assert!(table.symbols.contains_key("hello"));
        assert!(table.symbols.contains_key("world"));
        assert_eq!(table.functions().len(), 2);
    }

    #[test]
    fn test_find_local_var() {
        let ast = RustAST::parse(
            r#"
            fn main() {
                let x = 1;
                let y = 2;
            }
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        assert!(table.symbols.contains_key("x"));
        assert!(table.symbols.contains_key("y"));
        assert_eq!(table.local_vars().len(), 2);
    }

    #[test]
    fn test_find_references() {
        let ast = RustAST::parse(
            r#"
            fn main() {
                let x = 1;
                let y = x + 1;
                let z = x + y;
            }
            "#,
        )
        .unwrap();

        let refs = DefRefs::find_references(&ast, "x");
        assert_eq!(refs.len(), 2); // x is used twice
    }

    #[test]
    fn test_struct_definition() {
        let ast = RustAST::parse(
            r#"
            struct Point {
                x: i32,
                y: i32,
            }
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        assert!(table.symbols.contains_key("Point"));
        assert_eq!(table.symbols["Point"].kind, SymbolKind::Struct);
    }

    #[test]
    fn test_symbol_table_by_kind() {
        let ast = RustAST::parse(
            r#"
            struct Foo {}
            enum Bar {}
            fn baz() {
                let x = 1;
            }
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        assert_eq!(table.by_kind(SymbolKind::Struct).len(), 1);
        assert_eq!(table.by_kind(SymbolKind::Enum).len(), 1);
        assert_eq!(table.by_kind(SymbolKind::Function).len(), 1);
    }

    #[test]
    fn test_find_definition() {
        let ast = RustAST::parse(
            r#"
            const MAX: u32 = 10;
            static NAME: &str = "n";
            type Alias = u32;
            "#,
        )
        .unwrap();

        let alias = DefRefs::find_definition(&ast, "Alias").expect("Alias should be defined");
        assert_eq!(alias.kind, SymbolKind::TypeAlias);
        assert_eq!(alias.definition, Location::new("Alias"));

        // Both `const` and `static` items are reported as `SymbolKind::Const`.
        assert_eq!(
            DefRefs::find_definition(&ast, "MAX").map(|s| s.kind),
            Some(SymbolKind::Const)
        );
        assert_eq!(
            DefRefs::find_definition(&ast, "NAME").map(|s| s.kind),
            Some(SymbolKind::Const)
        );
        assert!(DefRefs::find_definition(&ast, "missing").is_none());
    }

    #[test]
    fn test_parameters_and_references() {
        let ast = RustAST::parse(
            r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b + a
            }
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        assert_eq!(table.by_kind(SymbolKind::Parameter).len(), 2);
        assert_eq!(DefRefs::find_references(&ast, "a").len(), 2);
        assert_eq!(DefRefs::find_references(&ast, "b").len(), 1);
        assert!(DefRefs::find_references(&ast, "missing").is_empty());
    }

    #[test]
    fn test_destructuring_patterns() {
        let ast = RustAST::parse(
            r#"
            fn main() {
                let (a, Point { x, .. }) = (1, p);
                let [first, ..] = arr;
            }
            "#,
        )
        .unwrap();

        let table = DefRefs::analyze(&ast);
        for name in ["a", "x", "first"] {
            assert_eq!(
                table.symbols.get(name).map(|s| s.kind),
                Some(SymbolKind::LocalVar),
                "{name} should be a local variable"
            );
        }
        assert_eq!(table.local_vars().len(), 3);
    }
}