use std::collections::HashMap;
use super::ast::*;
/// A named definition discovered by [`PureDefRefs::analyze`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PureSymbol {
    /// The identifier as written at the definition site; also the key it is stored
    /// under in [`PureSymbolTable::symbols`].
    pub name: String,
    /// What kind of item or binding introduced this name.
    pub kind: PureSymbolKind,
    /// Number of references to this name recorded during analysis
    /// (0 when the name is never referenced).
    pub ref_count: usize,
}
/// Category of a [`PureSymbol`] recorded by the definition/reference analysis.
///
/// `Hash` is derived alongside `Eq` so that kinds can be used as `HashMap` / `HashSet`
/// keys, for example to group or count symbols by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PureSymbolKind {
    /// A `fn` item, including functions declared inside `impl` blocks.
    Function,
    /// A `struct` item.
    Struct,
    /// An `enum` item.
    Enum,
    /// A `trait` item.
    Trait,
    /// A `type` alias item.
    TypeAlias,
    /// A `const` item.
    Const,
    /// A `static` item.
    Static,
    /// A binding introduced by a `let`, `match` arm, `for` loop, or closure-parameter pattern.
    LocalVar,
    /// A named, typed function parameter.
    Parameter,
    /// A `mod` item.
    Module,
    /// An `impl` block. `SymbolCollector` does not currently emit this kind; it only
    /// records a reference to the impl's self type.
    Impl,
}
/// Flat, name-keyed collection of the symbols found in one analyzed file.
#[derive(Debug, Clone, Default)]
pub struct PureSymbolTable {
    /// All symbols, keyed by name. There is no scoping, so each distinct name
    /// has at most one entry.
    pub symbols: HashMap<String, PureSymbol>,
}
impl PureSymbolTable {
    /// Creates an empty table; identical to `PureSymbolTable::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Collects references to every symbol whose kind equals `kind`.
    ///
    /// Results follow the internal `HashMap` iteration order, which is unspecified.
    pub fn by_kind(&self, kind: PureSymbolKind) -> Vec<&PureSymbol> {
        let mut matching = Vec::new();
        for symbol in self.symbols.values() {
            if symbol.kind == kind {
                matching.push(symbol);
            }
        }
        matching
    }

    /// Shorthand for `by_kind(PureSymbolKind::Function)`.
    pub fn functions(&self) -> Vec<&PureSymbol> {
        let wanted = PureSymbolKind::Function;
        self.by_kind(wanted)
    }

    /// Shorthand for `by_kind(PureSymbolKind::Struct)`.
    pub fn structs(&self) -> Vec<&PureSymbol> {
        let wanted = PureSymbolKind::Struct;
        self.by_kind(wanted)
    }

    /// Looks up the symbol stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&PureSymbol> {
        let symbols = &self.symbols;
        symbols.get(name)
    }
}
/// Entry points for whole-file definition/reference analysis.
///
/// Each helper below runs a complete [`PureDefRefs::analyze`] pass over the file.
/// A caller issuing several queries against the same file can call `analyze` once
/// and query the returned [`PureSymbolTable`] directly.
pub struct PureDefRefs;

impl PureDefRefs {
    /// Walks every item of `file` and returns the resulting symbol table, with
    /// reference counts already attached to each defined symbol.
    pub fn analyze(file: &PureFile) -> PureSymbolTable {
        let mut collector = SymbolCollector::new();
        collector.visit_file(file);
        collector.table
    }

    /// Returns the symbol recorded under `name`, or `None` if `file` defines no such name.
    pub fn find_definition(file: &PureFile, name: &str) -> Option<PureSymbol> {
        // The table is discarded afterwards, so move the symbol out instead of cloning it.
        let mut table = Self::analyze(file);
        table.symbols.remove(name)
    }

    /// Returns the reference count recorded for `name`.
    ///
    /// Returns 0 when `name` is never referenced, and also when `name` is not
    /// defined in `file`: references to undefined names are not attached to any symbol.
    pub fn count_references(file: &PureFile, name: &str) -> usize {
        Self::analyze(file)
            .symbols
            .get(name)
            .map_or(0, |symbol| symbol.ref_count)
    }

    /// Returns the names of all symbols defined in `file`, in unspecified order.
    pub fn all_definitions(file: &PureFile) -> Vec<String> {
        // Consume the temporary table so its keys are moved rather than cloned.
        Self::analyze(file)
            .symbols
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }
}
/// AST visitor that fills a [`PureSymbolTable`] with definitions and tallies references.
struct SymbolCollector {
    /// Definitions found so far, keyed by name.
    table: PureSymbolTable,
    /// Reference counts per name, gathered independently of definitions and copied
    /// onto matching symbols by `finalize_refs`.
    refs: HashMap<String, usize>,
}
impl SymbolCollector {
    /// Creates a collector with an empty symbol table and no recorded references.
    fn new() -> Self {
        Self {
            table: PureSymbolTable::new(),
            refs: HashMap::new(),
        }
    }

    /// Records a definition of `name` with the given `kind` and a zero reference count.
    ///
    /// The table is a single flat namespace keyed by name. A later definition with the
    /// same name (for example a local `x` in another function) replaces the earlier entry.
    fn define(&mut self, name: &str, kind: PureSymbolKind) {
        self.table.symbols.insert(
            name.to_string(),
            PureSymbol {
                name: name.to_string(),
                kind,
                ref_count: 0,
            },
        );
    }

    /// Counts one reference to `name`.
    ///
    /// References are tallied separately from definitions, so uses that appear before
    /// the definition (for example a call to a function declared later in the file) are
    /// still attributed once `finalize_refs` runs.
    fn add_ref(&mut self, name: &str) {
        *self.refs.entry(name.to_string()).or_default() += 1;
    }

    /// Copies the tallied reference counts onto the symbols that were defined.
    /// Counts for names with no definition are dropped.
    fn finalize_refs(&mut self) {
        for (name, count) in &self.refs {
            if let Some(symbol) = self.table.symbols.get_mut(name) {
                symbol.ref_count = *count;
            }
        }
    }

    /// Visits every top-level item, then attaches reference counts.
    fn visit_file(&mut self, file: &PureFile) {
        for item in &file.items {
            self.visit_item(item);
        }
        self.finalize_refs();
    }

    /// Dispatches on the item kind. `use`, macro, and unrecognized items contribute nothing.
    fn visit_item(&mut self, item: &PureItem) {
        match item {
            PureItem::Fn(f) => self.visit_fn(f),
            PureItem::Struct(s) => self.visit_struct(s),
            PureItem::Enum(e) => self.visit_enum(e),
            PureItem::Impl(i) => self.visit_impl(i),
            PureItem::Trait(t) => self.visit_trait(t),
            PureItem::Const(c) => self.visit_const(c),
            PureItem::Static(s) => self.visit_static(s),
            PureItem::Type(t) => self.visit_type_alias(t),
            PureItem::Mod(m) => self.visit_mod(m),
            PureItem::Use(_) | PureItem::Macro(_) | PureItem::Other(_) => {}
        }
    }

    /// Defines the function name and its named parameters, then walks the body.
    /// Only `PureParam::Typed` parameters are recorded; other parameter forms are skipped.
    fn visit_fn(&mut self, f: &PureFn) {
        self.define(&f.name, PureSymbolKind::Function);
        for param in &f.params {
            if let PureParam::Typed { name, .. } = param {
                self.define(name, PureSymbolKind::Parameter);
            }
        }
        self.visit_block(&f.body);
    }

    /// Defines the struct name; field types are not inspected.
    fn visit_struct(&mut self, s: &PureStruct) {
        self.define(&s.name, PureSymbolKind::Struct);
    }

    /// Defines the enum name; variants are not inspected.
    fn visit_enum(&mut self, e: &PureEnum) {
        self.define(&e.name, PureSymbolKind::Enum);
    }

    /// Counts a reference to the impl's self type and visits its function items.
    ///
    /// NOTE(review): `self_ty` is compared verbatim against definition names, so a
    /// generic self type such as `Point<T>` presumably will not match `Point`.
    /// Confirm against how `ast` renders `self_ty`.
    fn visit_impl(&mut self, i: &PureImpl) {
        self.add_ref(&i.self_ty);
        for item in &i.items {
            if let PureImplItem::Fn(f) = item {
                self.visit_fn(f);
            }
        }
    }

    /// Defines the trait name; trait items are not inspected.
    fn visit_trait(&mut self, t: &PureTrait) {
        self.define(&t.name, PureSymbolKind::Trait);
    }

    /// Defines the const name and walks its initializer, when one is present.
    fn visit_const(&mut self, c: &PureConst) {
        self.define(&c.name, PureSymbolKind::Const);
        if let Some(v) = &c.value {
            self.visit_expr(v);
        }
    }

    /// Defines the static name and walks its initializer.
    fn visit_static(&mut self, s: &PureStatic) {
        self.define(&s.name, PureSymbolKind::Static);
        self.visit_expr(&s.value);
    }

    /// Defines the alias name; the aliased type is not inspected.
    fn visit_type_alias(&mut self, t: &PureTypeAlias) {
        self.define(&t.name, PureSymbolKind::TypeAlias);
    }

    /// Defines the module name and visits its items into the same flat table.
    fn visit_mod(&mut self, m: &PureMod) {
        self.define(&m.name, PureSymbolKind::Module);
        for item in &m.items {
            self.visit_item(item);
        }
    }

    /// Visits each statement of `block` in order.
    fn visit_block(&mut self, block: &PureBlock) {
        for stmt in &block.stmts {
            self.visit_stmt(stmt);
        }
    }

    /// Visits one statement. For `let`, the initializer is walked before the pattern's
    /// bindings are defined, mirroring Rust's scoping of `let x = x + 1;`.
    fn visit_stmt(&mut self, stmt: &PureStmt) {
        match stmt {
            PureStmt::Local { pattern, init, .. } => {
                if let Some(expr) = init {
                    self.visit_expr(expr);
                }
                self.define_from_pattern(pattern);
            }
            PureStmt::Expr(expr) | PureStmt::Semi(expr) => {
                self.visit_expr(expr);
            }
            PureStmt::Item(item) => {
                self.visit_item(item);
            }
        }
    }

    /// Defines a `LocalVar` for every identifier bound by `pattern`, recursing through
    /// tuple, struct, reference, slice, and or-patterns.
    fn define_from_pattern(&mut self, pattern: &PurePattern) {
        match pattern {
            PurePattern::Ident { name, .. } => {
                self.define(name, PureSymbolKind::LocalVar);
            }
            PurePattern::Tuple(pats) => {
                for pat in pats {
                    self.define_from_pattern(pat);
                }
            }
            PurePattern::Struct { fields, .. } => {
                for (_, pat) in fields {
                    self.define_from_pattern(pat);
                }
            }
            PurePattern::Ref { pattern, .. } => {
                self.define_from_pattern(pattern);
            }
            PurePattern::Or(pats) => {
                // Every alternative of a Rust or-pattern must bind the same names,
                // so the first alternative is enough.
                if let Some(first) = pats.first() {
                    self.define_from_pattern(first);
                }
            }
            PurePattern::Slice(pats) => {
                for pat in pats {
                    self.define_from_pattern(pat);
                }
            }
            PurePattern::Wild
            | PurePattern::Lit(_)
            | PurePattern::Path(_)
            | PurePattern::Range { .. }
            | PurePattern::Rest
            | PurePattern::Other(_) => {}
        }
    }

    /// Walks an expression, counting a reference for every unqualified path
    /// (one without `::`) and defining bindings introduced by `match` arms,
    /// `for` loops, and closure parameters.
    fn visit_expr(&mut self, expr: &PureExpr) {
        match expr {
            PureExpr::Path(path) if !path.contains("::") => {
                self.add_ref(path);
            }
            PureExpr::Binary { left, right, .. } => {
                self.visit_expr(left);
                self.visit_expr(right);
            }
            PureExpr::Unary { expr, .. } => {
                self.visit_expr(expr);
            }
            PureExpr::Call { func, args } => {
                self.visit_expr(func);
                for arg in args {
                    self.visit_expr(arg);
                }
            }
            PureExpr::MethodCall { receiver, args, .. } => {
                self.visit_expr(receiver);
                for arg in args {
                    self.visit_expr(arg);
                }
            }
            PureExpr::Field { expr, .. } => {
                self.visit_expr(expr);
            }
            PureExpr::Index { expr, index } => {
                self.visit_expr(expr);
                self.visit_expr(index);
            }
            PureExpr::Block { block, .. } => {
                self.visit_block(block);
            }
            PureExpr::If {
                cond,
                then_branch,
                else_branch,
            } => {
                self.visit_expr(cond);
                self.visit_block(then_branch);
                if let Some(else_expr) = else_branch {
                    self.visit_expr(else_expr);
                }
            }
            PureExpr::Match { expr, arms } => {
                self.visit_expr(expr);
                for arm in arms {
                    self.define_from_pattern(&arm.pattern);
                    if let Some(guard) = &arm.guard {
                        self.visit_expr(guard);
                    }
                    self.visit_expr(&arm.body);
                }
            }
            // NOTE(review): only the body is walked. A `while` condition is never visited,
            // so names used solely in a loop condition are not counted. TODO: visit the
            // condition once its field name in `ast::PureExpr::While` is confirmed.
            PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
                self.visit_block(block);
            }
            PureExpr::For {
                pat, expr, body, ..
            } => {
                self.visit_expr(expr);
                self.define_from_pattern(pat);
                self.visit_block(body);
            }
            PureExpr::Return(Some(expr))
            | PureExpr::Break {
                expr: Some(expr), ..
            } => {
                self.visit_expr(expr);
            }
            PureExpr::Closure { params, body, .. } => {
                for param in params {
                    // Was the mis-encoded `¶m.pattern` (an HTML `&para` entity
                    // artifact), which does not compile.
                    self.define_from_pattern(&param.pattern);
                }
                self.visit_expr(body);
            }
            PureExpr::Struct { fields, .. } => {
                for (_, expr) in fields {
                    self.visit_expr(expr);
                }
            }
            PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
                for expr in exprs {
                    self.visit_expr(expr);
                }
            }
            PureExpr::Ref { expr, .. } => {
                self.visit_expr(expr);
            }
            PureExpr::Await(expr) | PureExpr::Try(expr) => {
                self.visit_expr(expr);
            }
            // Qualified paths, value-less `return`/`break`, and every other variant
            // contribute no references or definitions.
            _ => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_analyze_functions() {
        let file = PureFile::from_source(
            r#"
fn foo() {}
fn bar() {}
fn baz() {}
"#,
        )
        .unwrap();
        let table = PureDefRefs::analyze(&file);
        assert_eq!(table.functions().len(), 3);
        assert!(table.get("foo").is_some());
        assert!(table.get("bar").is_some());
        assert!(table.get("baz").is_some());
    }

    #[test]
    fn test_analyze_structs() {
        let file = PureFile::from_source(
            r#"
struct Point { x: i32, y: i32 }
struct Line { start: Point, end: Point }
"#,
        )
        .unwrap();
        let table = PureDefRefs::analyze(&file);
        assert_eq!(table.structs().len(), 2);
    }

    #[test]
    fn test_count_references() {
        let file = PureFile::from_source(
            r#"
fn main() {
let x = 1;
let y = x + 1;
let z = x + y;
}
"#,
        )
        .unwrap();
        let x_refs = PureDefRefs::count_references(&file, "x");
        assert_eq!(x_refs, 2);
        let y_refs = PureDefRefs::count_references(&file, "y");
        assert_eq!(y_refs, 1);
    }

    /// Regression test for closure parameters: their patterns must be defined as
    /// locals, and uses inside the closure body must be counted.
    #[test]
    fn test_closure_parameters_are_defined() {
        let file = PureFile::from_source(
            r#"
fn main() {
let add_one = |a| a + 1;
}
"#,
        )
        .unwrap();
        let symbol = PureDefRefs::find_definition(&file, "a");
        assert_eq!(symbol.map(|s| s.kind), Some(PureSymbolKind::LocalVar));
        assert_eq!(PureDefRefs::count_references(&file, "a"), 1);
    }

    #[test]
    fn test_find_definition() {
        let file = PureFile::from_source("fn my_function() {}").unwrap();
        let symbol = PureDefRefs::find_definition(&file, "my_function");
        assert!(symbol.is_some());
        assert_eq!(symbol.unwrap().kind, PureSymbolKind::Function);
    }

    #[test]
    fn test_all_definitions() {
        let file = PureFile::from_source(
            r#"
struct Foo {}
enum Bar {}
fn baz() {}
const QUX: i32 = 1;
"#,
        )
        .unwrap();
        let defs = PureDefRefs::all_definitions(&file);
        assert!(defs.contains(&"Foo".to_string()));
        assert!(defs.contains(&"Bar".to_string()));
        assert!(defs.contains(&"baz".to_string()));
        assert!(defs.contains(&"QUX".to_string()));
    }

    #[test]
    fn test_parallel_analysis() {
        use std::sync::Arc;
        use std::thread;
        let file = PureFile::from_source(
            r#"
fn alpha() {}
fn beta() {}
fn gamma() {}
"#,
        )
        .unwrap();
        let shared = Arc::new(file);
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let f = Arc::clone(&shared);
                thread::spawn(move || PureDefRefs::analyze(&f).functions().len())
            })
            .collect();
        for handle in handles {
            assert_eq!(handle.join().unwrap(), 3);
        }
    }
}