1use std::collections::HashMap;
6
7use super::ast::*;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct PureSymbol {
12 pub name: String,
14 pub kind: PureSymbolKind,
16 pub ref_count: usize,
18}
19
/// The category of construct that introduced a [`PureSymbol`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PureSymbolKind {
    /// A `fn` item.
    Function,
    /// A `struct` item.
    Struct,
    /// An `enum` item.
    Enum,
    /// A `trait` item.
    Trait,
    /// A `type` alias.
    TypeAlias,
    /// A `const` item.
    Const,
    /// A `static` item.
    Static,
    /// A name bound by a pattern inside a body (`let`, `match`, `for`, closures).
    LocalVar,
    /// A function parameter.
    Parameter,
    /// A `mod` item.
    Module,
    /// An `impl` block.
    Impl,
}
46
47#[derive(Debug, Clone, Default)]
49pub struct PureSymbolTable {
50 pub symbols: HashMap<String, PureSymbol>,
52}
53
54impl PureSymbolTable {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn by_kind(&self, kind: PureSymbolKind) -> Vec<&PureSymbol> {
62 self.symbols.values().filter(|s| s.kind == kind).collect()
63 }
64
65 pub fn functions(&self) -> Vec<&PureSymbol> {
67 self.by_kind(PureSymbolKind::Function)
68 }
69
70 pub fn structs(&self) -> Vec<&PureSymbol> {
72 self.by_kind(PureSymbolKind::Struct)
73 }
74
75 pub fn get(&self, name: &str) -> Option<&PureSymbol> {
77 self.symbols.get(name)
78 }
79}
80
81pub struct PureDefRefs;
83
84impl PureDefRefs {
85 pub fn analyze(file: &PureFile) -> PureSymbolTable {
87 let mut collector = SymbolCollector::new();
88 collector.visit_file(file);
89 collector.table
90 }
91
92 pub fn find_definition(file: &PureFile, name: &str) -> Option<PureSymbol> {
94 let table = Self::analyze(file);
95 table.symbols.get(name).cloned()
96 }
97
98 pub fn count_references(file: &PureFile, name: &str) -> usize {
100 let table = Self::analyze(file);
101 table.symbols.get(name).map(|s| s.ref_count).unwrap_or(0)
102 }
103
104 pub fn all_definitions(file: &PureFile) -> Vec<String> {
106 let table = Self::analyze(file);
107 table.symbols.keys().cloned().collect()
108 }
109}
110
111struct SymbolCollector {
113 table: PureSymbolTable,
114 refs: HashMap<String, usize>,
116}
117
118impl SymbolCollector {
119 fn new() -> Self {
120 Self {
121 table: PureSymbolTable::new(),
122 refs: HashMap::new(),
123 }
124 }
125
126 fn define(&mut self, name: &str, kind: PureSymbolKind) {
127 self.table.symbols.insert(
128 name.to_string(),
129 PureSymbol {
130 name: name.to_string(),
131 kind,
132 ref_count: 0,
133 },
134 );
135 }
136
137 fn add_ref(&mut self, name: &str) {
138 *self.refs.entry(name.to_string()).or_insert(0) += 1;
139 }
140
141 fn finalize_refs(&mut self) {
142 for (name, count) in &self.refs {
143 if let Some(symbol) = self.table.symbols.get_mut(name) {
144 symbol.ref_count = *count;
145 }
146 }
147 }
148
149 fn visit_file(&mut self, file: &PureFile) {
150 for item in &file.items {
151 self.visit_item(item);
152 }
153 self.finalize_refs();
154 }
155
156 fn visit_item(&mut self, item: &PureItem) {
157 match item {
158 PureItem::Fn(f) => self.visit_fn(f),
159 PureItem::Struct(s) => self.visit_struct(s),
160 PureItem::Enum(e) => self.visit_enum(e),
161 PureItem::Impl(i) => self.visit_impl(i),
162 PureItem::Trait(t) => self.visit_trait(t),
163 PureItem::Const(c) => self.visit_const(c),
164 PureItem::Static(s) => self.visit_static(s),
165 PureItem::Type(t) => self.visit_type_alias(t),
166 PureItem::Mod(m) => self.visit_mod(m),
167 PureItem::Use(_) | PureItem::Macro(_) | PureItem::Other(_) => {}
168 }
169 }
170
171 fn visit_fn(&mut self, f: &PureFn) {
172 self.define(&f.name, PureSymbolKind::Function);
173
174 for param in &f.params {
176 if let PureParam::Typed { name, .. } = param {
177 self.define(name, PureSymbolKind::Parameter);
178 }
179 }
180
181 self.visit_block(&f.body);
183 }
184
185 fn visit_struct(&mut self, s: &PureStruct) {
186 self.define(&s.name, PureSymbolKind::Struct);
187 }
188
189 fn visit_enum(&mut self, e: &PureEnum) {
190 self.define(&e.name, PureSymbolKind::Enum);
191 }
192
193 fn visit_impl(&mut self, i: &PureImpl) {
194 self.add_ref(&i.self_ty);
196
197 for item in &i.items {
198 if let PureImplItem::Fn(f) = item {
199 self.visit_fn(f);
200 }
201 }
202 }
203
204 fn visit_trait(&mut self, t: &PureTrait) {
205 self.define(&t.name, PureSymbolKind::Trait);
206 }
207
208 fn visit_const(&mut self, c: &PureConst) {
209 self.define(&c.name, PureSymbolKind::Const);
210 if let Some(v) = &c.value {
211 self.visit_expr(v);
212 }
213 }
214
215 fn visit_static(&mut self, s: &PureStatic) {
216 self.define(&s.name, PureSymbolKind::Static);
217 self.visit_expr(&s.value);
218 }
219
220 fn visit_type_alias(&mut self, t: &PureTypeAlias) {
221 self.define(&t.name, PureSymbolKind::TypeAlias);
222 }
223
224 fn visit_mod(&mut self, m: &PureMod) {
225 self.define(&m.name, PureSymbolKind::Module);
226 for item in &m.items {
227 self.visit_item(item);
228 }
229 }
230
231 fn visit_block(&mut self, block: &PureBlock) {
232 for stmt in &block.stmts {
233 self.visit_stmt(stmt);
234 }
235 }
236
237 fn visit_stmt(&mut self, stmt: &PureStmt) {
238 match stmt {
239 PureStmt::Local { pattern, init, .. } => {
240 if let Some(expr) = init {
242 self.visit_expr(expr);
243 }
244 self.define_from_pattern(pattern);
246 }
247 PureStmt::Expr(expr) | PureStmt::Semi(expr) => {
248 self.visit_expr(expr);
249 }
250 PureStmt::Item(item) => {
251 self.visit_item(item);
252 }
253 }
254 }
255
256 fn define_from_pattern(&mut self, pattern: &PurePattern) {
257 match pattern {
258 PurePattern::Ident { name, .. } => {
259 self.define(name, PureSymbolKind::LocalVar);
260 }
261 PurePattern::Tuple(pats) => {
262 for pat in pats {
263 self.define_from_pattern(pat);
264 }
265 }
266 PurePattern::Struct { fields, .. } => {
267 for (_, pat) in fields {
268 self.define_from_pattern(pat);
269 }
270 }
271 PurePattern::Ref { pattern, .. } => {
272 self.define_from_pattern(pattern);
273 }
274 PurePattern::Or(pats) => {
275 if let Some(first) = pats.first() {
277 self.define_from_pattern(first);
278 }
279 }
280 PurePattern::Slice(pats) => {
281 for pat in pats {
282 self.define_from_pattern(pat);
283 }
284 }
285 PurePattern::Wild
286 | PurePattern::Lit(_)
287 | PurePattern::Path(_)
288 | PurePattern::Range { .. }
289 | PurePattern::Rest
290 | PurePattern::Other(_) => {}
291 }
292 }
293
294 fn visit_expr(&mut self, expr: &PureExpr) {
295 match expr {
296 PureExpr::Path(path) if !path.contains("::") => {
297 self.add_ref(path);
299 }
300 PureExpr::Binary { left, right, .. } => {
301 self.visit_expr(left);
302 self.visit_expr(right);
303 }
304 PureExpr::Unary { expr, .. } => {
305 self.visit_expr(expr);
306 }
307 PureExpr::Call { func, args } => {
308 self.visit_expr(func);
309 for arg in args {
310 self.visit_expr(arg);
311 }
312 }
313 PureExpr::MethodCall { receiver, args, .. } => {
314 self.visit_expr(receiver);
315 for arg in args {
316 self.visit_expr(arg);
317 }
318 }
319 PureExpr::Field { expr, .. } => {
320 self.visit_expr(expr);
321 }
322 PureExpr::Index { expr, index } => {
323 self.visit_expr(expr);
324 self.visit_expr(index);
325 }
326 PureExpr::Block { block, .. } => {
327 self.visit_block(block);
328 }
329 PureExpr::If {
330 cond,
331 then_branch,
332 else_branch,
333 } => {
334 self.visit_expr(cond);
335 self.visit_block(then_branch);
336 if let Some(else_expr) = else_branch {
337 self.visit_expr(else_expr);
338 }
339 }
340 PureExpr::Match { expr, arms } => {
341 self.visit_expr(expr);
342 for arm in arms {
343 self.define_from_pattern(&arm.pattern);
344 if let Some(guard) = &arm.guard {
345 self.visit_expr(guard);
346 }
347 self.visit_expr(&arm.body);
348 }
349 }
350 PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
351 self.visit_block(block);
352 }
353 PureExpr::For {
354 pat, expr, body, ..
355 } => {
356 self.visit_expr(expr);
357 self.define_from_pattern(pat);
358 self.visit_block(body);
359 }
360 PureExpr::Return(Some(expr))
361 | PureExpr::Break {
362 expr: Some(expr), ..
363 } => {
364 self.visit_expr(expr);
365 }
366 PureExpr::Closure { params, body, .. } => {
367 for param in params {
368 self.define_from_pattern(¶m.pattern);
369 }
370 self.visit_expr(body);
371 }
372 PureExpr::Struct { fields, .. } => {
373 for (_, expr) in fields {
374 self.visit_expr(expr);
375 }
376 }
377 PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
378 for expr in exprs {
379 self.visit_expr(expr);
380 }
381 }
382 PureExpr::Ref { expr, .. } => {
383 self.visit_expr(expr);
384 }
385 PureExpr::Await(expr) | PureExpr::Try(expr) => {
386 self.visit_expr(expr);
387 }
388 _ => {}
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline fixture. Fixtures are fixed strings, so a parse
    /// failure means the test itself is broken.
    fn parse(source: &str) -> PureFile {
        PureFile::from_source(source).unwrap()
    }

    #[test]
    fn test_analyze_functions() {
        let file = parse(
            r#"
            fn foo() {}
            fn bar() {}
            fn baz() {}
            "#,
        );

        let table = PureDefRefs::analyze(&file);
        assert_eq!(table.functions().len(), 3);
        for name in &["foo", "bar", "baz"] {
            assert!(table.get(name).is_some(), "expected `{}` in the table", name);
        }
    }

    #[test]
    fn test_analyze_structs() {
        let file = parse(
            r#"
            struct Point { x: i32, y: i32 }
            struct Line { start: Point, end: Point }
            "#,
        );

        assert_eq!(PureDefRefs::analyze(&file).structs().len(), 2);
    }

    #[test]
    fn test_count_references() {
        let file = parse(
            r#"
            fn main() {
                let x = 1;
                let y = x + 1;
                let z = x + y;
            }
            "#,
        );

        // `x` is read on both later lines; `y` only on the last one.
        assert_eq!(PureDefRefs::count_references(&file, "x"), 2);
        assert_eq!(PureDefRefs::count_references(&file, "y"), 1);
    }

    #[test]
    fn test_find_definition() {
        let file = parse("fn my_function() {}");

        match PureDefRefs::find_definition(&file, "my_function") {
            Some(symbol) => assert_eq!(symbol.kind, PureSymbolKind::Function),
            None => panic!("`my_function` should be defined"),
        }
    }

    #[test]
    fn test_all_definitions() {
        let file = parse(
            r#"
            struct Foo {}
            enum Bar {}
            fn baz() {}
            const QUX: i32 = 1;
            "#,
        );

        let defs = PureDefRefs::all_definitions(&file);
        for expected in &["Foo", "Bar", "baz", "QUX"] {
            assert!(
                defs.iter().any(|d| d.as_str() == *expected),
                "missing definition `{}`",
                expected
            );
        }
    }

    #[test]
    fn test_parallel_analysis() {
        use std::sync::Arc;
        use std::thread;

        let shared = Arc::new(parse(
            r#"
            fn alpha() {}
            fn beta() {}
            fn gamma() {}
            "#,
        ));

        let mut workers = Vec::with_capacity(4);
        for _ in 0..4 {
            let file = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                PureDefRefs::analyze(&file).functions().len()
            }));
        }

        for worker in workers {
            assert_eq!(worker.join().unwrap(), 3);
        }
    }
}