infiniloom_engine/
ranking.rs

1//! Symbol importance ranking
2
3#[cfg(test)]
4use crate::types::RepoFile;
5use crate::types::{Repository, Symbol, SymbolKind};
6use std::collections::HashMap;
7
8/// Count references to symbols across all files
9/// This is a second pass that updates Symbol.references based on how many
10/// times each symbol name appears in other files' calls lists
11pub fn count_symbol_references(repo: &mut Repository) {
12    // Build a map of symbol names to their (file_index, symbol_index) locations
13    let mut symbol_locations: HashMap<String, Vec<(usize, usize)>> = HashMap::new();
14
15    for (file_idx, file) in repo.files.iter().enumerate() {
16        for (sym_idx, symbol) in file.symbols.iter().enumerate() {
17            symbol_locations
18                .entry(symbol.name.clone())
19                .or_default()
20                .push((file_idx, sym_idx));
21        }
22    }
23
24    // Count references: for each symbol's calls, increment the referenced symbol's count
25    let mut reference_counts: HashMap<String, u32> = HashMap::new();
26
27    for file in &repo.files {
28        for symbol in &file.symbols {
29            for called_name in &symbol.calls {
30                *reference_counts.entry(called_name.clone()).or_insert(0) += 1;
31            }
32        }
33    }
34
35    // Also count imports as references
36    for file in &repo.files {
37        for symbol in &file.symbols {
38            if symbol.kind == SymbolKind::Import {
39                // The import name might reference another symbol
40                *reference_counts.entry(symbol.name.clone()).or_insert(0) += 1;
41            }
42        }
43    }
44
45    // Update symbol references
46    for file in &mut repo.files {
47        for symbol in &mut file.symbols {
48            if let Some(&count) = reference_counts.get(&symbol.name) {
49                symbol.references = count;
50            }
51        }
52    }
53}
54
55/// Symbol ranker using multiple heuristics
56pub struct SymbolRanker {
57    /// Weight for reference count
58    reference_weight: f32,
59    /// Weight for symbol type
60    type_weight: f32,
61    /// Weight for file importance
62    file_weight: f32,
63    /// Weight for line count (larger = more important)
64    size_weight: f32,
65}
66
67impl Default for SymbolRanker {
68    fn default() -> Self {
69        Self { reference_weight: 0.4, type_weight: 0.25, file_weight: 0.2, size_weight: 0.15 }
70    }
71}
72
73impl SymbolRanker {
74    /// Create a new ranker with default weights
75    pub fn new() -> Self {
76        Self::default()
77    }
78
79    /// Set custom weights
80    pub fn with_weights(mut self, reference: f32, type_: f32, file: f32, size: f32) -> Self {
81        self.reference_weight = reference;
82        self.type_weight = type_;
83        self.file_weight = file;
84        self.size_weight = size;
85        self
86    }
87
88    /// Rank all symbols in a repository
89    pub fn rank(&self, repo: &mut Repository) {
90        // First pass: collect statistics
91        let stats = self.collect_stats(repo);
92
93        // Second pass: compute importance scores
94        for file in &mut repo.files {
95            let file_importance = file.importance;
96
97            for symbol in &mut file.symbols {
98                let score = self.compute_score(symbol, file_importance, &stats);
99                symbol.importance = score;
100            }
101
102            // Update file importance based on symbol importance
103            if !file.symbols.is_empty() {
104                let avg_symbol_importance: f32 =
105                    file.symbols.iter().map(|s| s.importance).sum::<f32>()
106                        / file.symbols.len() as f32;
107                file.importance = (file.importance + avg_symbol_importance) / 2.0;
108            }
109        }
110    }
111
112    fn collect_stats(&self, repo: &Repository) -> RankingStats {
113        let mut stats = RankingStats::default();
114
115        for file in &repo.files {
116            for symbol in &file.symbols {
117                stats.total_symbols += 1;
118                stats.max_references = stats.max_references.max(symbol.references);
119                stats.max_lines = stats.max_lines.max(symbol.line_count());
120
121                *stats.type_counts.entry(symbol.kind).or_insert(0) += 1;
122            }
123        }
124
125        stats
126    }
127
128    fn compute_score(&self, symbol: &Symbol, file_importance: f32, stats: &RankingStats) -> f32 {
129        // Reference score (normalized)
130        let ref_score = if stats.max_references > 0 {
131            symbol.references as f32 / stats.max_references as f32
132        } else {
133            0.0
134        };
135
136        // Type score (based on symbol kind)
137        let type_score = type_importance(symbol.kind);
138
139        // Size score (normalized)
140        let size_score = if stats.max_lines > 0 {
141            (symbol.line_count() as f32 / stats.max_lines as f32).min(1.0)
142        } else {
143            0.0
144        };
145
146        // Combine scores
147        let score = self.reference_weight * ref_score
148            + self.type_weight * type_score
149            + self.file_weight * file_importance
150            + self.size_weight * size_score;
151
152        // Clamp to [0, 1]
153        score.clamp(0.0, 1.0)
154    }
155}
156
157/// Statistics for ranking normalization
158#[derive(Default)]
159struct RankingStats {
160    total_symbols: usize,
161    max_references: u32,
162    max_lines: u32,
163    type_counts: HashMap<SymbolKind, usize>,
164}
165
166/// Get base importance for a symbol kind
167fn type_importance(kind: SymbolKind) -> f32 {
168    match kind {
169        // Entry points and main interfaces are most important
170        SymbolKind::Class | SymbolKind::Interface | SymbolKind::Trait => 1.0,
171        // Public API functions
172        SymbolKind::Function | SymbolKind::Method => 0.8,
173        // Types and structures
174        SymbolKind::Struct | SymbolKind::Enum | SymbolKind::TypeAlias => 0.7,
175        // Constants and exports
176        SymbolKind::Constant | SymbolKind::Export => 0.6,
177        // Modules
178        SymbolKind::Module => 0.5,
179        // Less important
180        SymbolKind::Variable | SymbolKind::Import | SymbolKind::Macro => 0.3,
181    }
182}
183
184/// Rank files by importance using heuristics
185/// Priority: Entry points > Core implementation > Libraries > Config > Tests
186pub fn rank_files(repo: &mut Repository) {
187    // Critical entry point patterns (highest priority)
188    let critical_entry_patterns = [
189        "__main__.py",
190        "main.rs",
191        "main.go",
192        "main.c",
193        "main.cpp",
194        "main.ts",
195        "main.js",
196        "index.ts",
197        "index.js",
198        "index.tsx",
199        "index.jsx",
200        "app.ts",
201        "app.js",
202        "app.py",
203        "app.go",
204        "app.rb",
205        "server.ts",
206        "server.js",
207        "server.py",
208        "server.go",
209        "cli.rs",
210        "cli.ts",
211        "cli.js",
212        "cli.py",
213        "lib.rs",
214        "mod.rs",
215    ];
216
217    // Important implementation directories
218    let core_dirs =
219        ["/src/", "/lib/", "/core/", "/pkg/", "/internal/", "/app/", "/cmd/", "/bin/", "/crates/"];
220
221    // Entry point file prefixes (less specific)
222    let entry_prefixes = [
223        "main.",
224        "index.",
225        "app.",
226        "server.",
227        "cli.",
228        "mod.",
229        "lib.",
230        "init.",
231        "__init__.",
232        "entry.",
233        "bootstrap.",
234    ];
235
236    // Documentation (medium-low importance but still useful)
237    let doc_patterns = ["readme.", "changelog.", "contributing.", "license.", "authors."];
238
239    // Config patterns (low importance - metadata not code)
240    let config_patterns = [
241        "config.",
242        "settings.",
243        ".config",
244        "package.json",
245        "cargo.toml",
246        "pyproject.toml",
247        "setup.py",
248        "setup.cfg",
249        "tsconfig.",
250        "webpack.",
251        ".eslint",
252        ".prettier",
253        "jest.config",
254        "vite.config",
255        ".env",
256        "makefile",
257        "dockerfile",
258        "docker-compose",
259        ".github/",
260        ".gitlab",
261    ];
262
263    // Test patterns (lowest priority for code understanding)
264    // Note: Use both with and without leading slash to match root dirs
265    let test_patterns = [
266        "test_",
267        "_test.",
268        ".test.",
269        ".fixture.",
270        "_fixture.",
271        "spec.",
272        "_spec.",
273        "/tests/",
274        "tests/",
275        "/test/",
276        "test/",
277        "/__tests__/",
278        "__tests__/",
279        "/testing/",
280        "testing/",
281        "/fixtures/",
282        "fixtures/",
283        "/__fixtures__/",
284        "__fixtures__/",
285        "/mocks/",
286        "mocks/",
287        "/__mocks__/",
288        "__mocks__/",
289        "mock_",
290        "_mock.",
291        "/e2e/",
292        "e2e/",
293        "/integration/",
294        "integration/",
295        "/unit/",
296        "unit/",
297        "/examples/",
298        "examples/",
299        "/example/",
300        "example/",
301        "/benchmark/",
302        "benchmark/",
303        "/cypress/",
304        "cypress/",
305        "/playwright/",
306        "playwright/",
307    ];
308
309    // Vendor/generated patterns (exclude or very low priority)
310    let vendor_patterns = [
311        "/vendor/",
312        "vendor/",
313        "/node_modules/",
314        "node_modules/",
315        "/dist/",
316        "dist/",
317        "/build/",
318        "build/",
319        "/target/",
320        "target/",
321        "/__pycache__/",
322        "__pycache__/",
323        "/.next/",
324        ".next/",
325        "/coverage/",
326        "coverage/",
327        "/.cache/",
328        ".cache/",
329        "/generated/",
330        "generated/",
331        "/.generated/",
332        ".generated/",
333        "/gen/",
334        "gen/",
335        ".min.js",
336        ".min.css",
337        ".bundle.",
338        "/benchmarks/",
339        "benchmarks/",
340    ];
341
342    for file in &mut repo.files {
343        let filename = file.filename().to_lowercase();
344        let path = file.relative_path.to_lowercase();
345
346        let mut importance: f32;
347
348        // Check vendor/generated first (exclude from ranking)
349        if vendor_patterns.iter().any(|p| path.contains(p)) {
350            importance = 0.05;
351        }
352        // Check test patterns (low priority)
353        else if test_patterns.iter().any(|p| path.contains(p)) {
354            importance = 0.15;
355        }
356        // Check config patterns (low priority)
357        else if config_patterns
358            .iter()
359            .any(|p| filename.contains(p) || path.contains(p))
360        {
361            importance = 0.25;
362        }
363        // Check doc patterns
364        else if doc_patterns.iter().any(|p| filename.starts_with(p)) {
365            importance = 0.35;
366        }
367        // Check critical entry points (highest priority)
368        else if critical_entry_patterns.iter().any(|p| filename == *p) {
369            importance = 1.0;
370        }
371        // Check entry point prefixes
372        else if entry_prefixes.iter().any(|p| filename.starts_with(p)) {
373            importance = 0.9;
374        }
375        // Check core directories
376        else if core_dirs.iter().any(|p| path.contains(p)) {
377            importance = 0.75;
378        }
379        // Default for other source files
380        else {
381            importance = 0.5;
382        }
383
384        // Only apply boosts if not in test/vendor directories
385        let is_test_or_vendor = vendor_patterns.iter().any(|p| path.contains(p))
386            || test_patterns.iter().any(|p| path.contains(p));
387
388        if !is_test_or_vendor {
389            // Boost based on symbol count (more symbols = more important code)
390            let symbol_boost = (file.symbols.len() as f32 / 50.0).min(0.15);
391            importance = (importance + symbol_boost).min(1.0);
392
393            // Slight boost for files with common implementation names
394            if filename.contains("handler")
395                || filename.contains("service")
396                || filename.contains("controller")
397                || filename.contains("model")
398                || filename.contains("util")
399                || filename.contains("helper")
400                || filename.contains("router")
401                || filename.contains("middleware")
402            {
403                importance = (importance + 0.1).min(1.0);
404            }
405        }
406
407        file.importance = importance;
408    }
409}
410
411/// Sort repository files by importance (highest first)
412pub fn sort_files_by_importance(repo: &mut Repository) {
413    repo.files.sort_by(|a, b| {
414        b.importance
415            .partial_cmp(&a.importance)
416            .unwrap_or(std::cmp::Ordering::Equal)
417    });
418}
419
#[cfg(test)]
#[allow(clippy::str_to_string)]
mod tests {
    use super::*;
    use crate::types::TokenCounts;

    /// Build a 100-byte in-memory file rooted at `/tmp/test` with a neutral
    /// starting importance of 0.5 and no content.
    fn file_at(relative_path: &str, language: &str, symbols: Vec<Symbol>) -> RepoFile {
        RepoFile {
            path: format!("/tmp/test/{}", relative_path).into(),
            relative_path: relative_path.to_string(),
            language: Some(language.to_string()),
            size_bytes: 100,
            token_count: TokenCounts::default(),
            symbols,
            importance: 0.5,
            content: None,
        }
    }

    /// Run `rank_files` over symbol-less files given as `(relative_path, language)`
    /// pairs and return the resulting importances in input order.
    fn ranked_importances(files: &[(&str, &str)]) -> Vec<f32> {
        let mut repo = Repository::new("test", "/tmp/test");
        for &(relative_path, language) in files {
            repo.files.push(file_at(relative_path, language, vec![]));
        }
        rank_files(&mut repo);
        repo.files.iter().map(|file| file.importance).collect()
    }

    #[test]
    fn test_type_importance() {
        let class = type_importance(SymbolKind::Class);
        let variable = type_importance(SymbolKind::Variable);
        assert!(class > variable);

        let function = type_importance(SymbolKind::Function);
        let import = type_importance(SymbolKind::Import);
        assert!(function > import);
    }

    #[test]
    fn test_ranker() {
        let heavily_referenced = Symbol {
            name: "main".to_string(),
            kind: SymbolKind::Function,
            references: 10,
            start_line: 1,
            end_line: 20,
            ..Symbol::new("main", SymbolKind::Function)
        };
        let rarely_referenced = Symbol {
            name: "helper".to_string(),
            kind: SymbolKind::Function,
            references: 2,
            start_line: 25,
            end_line: 30,
            ..Symbol::new("helper", SymbolKind::Function)
        };

        let mut repo = Repository::new("test", "/tmp/test");
        repo.files
            .push(file_at("main.py", "python", vec![heavily_referenced, rarely_referenced]));

        SymbolRanker::new().rank(&mut repo);

        // Main should have higher importance due to more references
        let symbols = &repo.files[0].symbols;
        let main_importance = symbols[0].importance;
        let helper_importance = symbols[1].importance;
        assert!(main_importance > helper_importance);
    }

    #[test]
    fn test_count_symbol_references() {
        // File 1: defines helper(), which calls process()
        let helper = Symbol {
            name: "helper".to_string(),
            kind: SymbolKind::Function,
            references: 0, // Not yet counted
            start_line: 1,
            end_line: 10,
            calls: vec!["process".to_string()],
            ..Symbol::new("helper", SymbolKind::Function)
        };

        // File 2: defines main(), which calls helper() twice, and process()
        let main = Symbol {
            name: "main".to_string(),
            kind: SymbolKind::Function,
            references: 0,
            start_line: 1,
            end_line: 10,
            calls: vec!["helper".to_string(), "helper".to_string()],
            ..Symbol::new("main", SymbolKind::Function)
        };
        let process = Symbol {
            name: "process".to_string(),
            kind: SymbolKind::Function,
            references: 0,
            start_line: 15,
            end_line: 25,
            calls: vec![],
            ..Symbol::new("process", SymbolKind::Function)
        };

        let mut repo = Repository::new("test", "/tmp/test");
        repo.files.push(file_at("utils.py", "python", vec![helper]));
        repo.files.push(file_at("main.py", "python", vec![main, process]));

        count_symbol_references(&mut repo);

        // helper is called twice (by main)
        assert_eq!(repo.files[0].symbols[0].references, 2, "helper should have 2 references");

        // process is called once (by helper)
        assert_eq!(repo.files[1].symbols[1].references, 1, "process should have 1 reference");

        // main is not called by anyone
        assert_eq!(repo.files[1].symbols[0].references, 0, "main should have 0 references");
    }

    #[test]
    fn test_fixture_files_low_importance() {
        // A regular source file, a fixture file and a test file, in that order
        let scores = ranked_importances(&[
            ("src/api.go", "go"),
            ("pkg/tools/ReadFile.fixture.go", "go"),
            ("src/api_test.go", "go"),
        ]);
        let (api_importance, fixture_importance, test_importance) =
            (scores[0], scores[1], scores[2]);

        // Source files should have higher importance than fixture/test files
        assert!(
            api_importance > fixture_importance,
            "api.go ({}) should have higher importance than ReadFile.fixture.go ({})",
            api_importance,
            fixture_importance
        );
        assert!(
            api_importance > test_importance,
            "api.go ({}) should have higher importance than api_test.go ({})",
            api_importance,
            test_importance
        );

        // Fixture and test files should have low importance (0.15 = test pattern match)
        assert!(
            fixture_importance <= 0.20,
            "fixture file importance ({}) should be <= 0.20",
            fixture_importance
        );
        assert!(
            test_importance <= 0.20,
            "test file importance ({}) should be <= 0.20",
            test_importance
        );
    }

    #[test]
    fn test_dist_files_low_importance() {
        // A regular source file, a dist artefact and a node_modules file, in that order
        let scores = ranked_importances(&[
            ("src/index.ts", "typescript"),
            ("dist/index.js", "javascript"),
            ("node_modules/pkg/index.js", "javascript"),
        ]);
        let (src_importance, dist_importance, node_modules_importance) =
            (scores[0], scores[1], scores[2]);

        // Source files should have higher importance than vendor/dist files
        assert!(
            src_importance > dist_importance,
            "src/index.ts ({}) should have higher importance than dist/index.js ({})",
            src_importance,
            dist_importance
        );
        assert!(
            src_importance > node_modules_importance,
            "src/index.ts ({}) should have higher importance than node_modules file ({})",
            src_importance,
            node_modules_importance
        );

        // dist and node_modules should have very low importance (0.05)
        assert!(
            dist_importance <= 0.10,
            "dist file importance ({}) should be <= 0.10",
            dist_importance
        );
        assert!(
            node_modules_importance <= 0.10,
            "node_modules file importance ({}) should be <= 0.10",
            node_modules_importance
        );
    }
}