Skip to main content

hematite/memory/
repo_map.rs

1use anyhow::Result;
2use ignore::WalkBuilder;
3use petgraph::graph::DiGraph;
4use std::collections::{HashMap, HashSet};
5use std::fmt::Write as _;
6use std::fs;
7use std::path::Path;
8use std::sync::OnceLock;
9use tree_sitter::{Language, Parser, Query, QueryCursor};
10
11// ── Cached query bundles ──────────────────────────────────────────────────────
12// Query::new compiles tree-sitter patterns against the grammar — non-trivial.
13// Cache once per process; reuse across every generate() call.
14
15fn rust_def_bundle() -> Option<&'static (Language, Query)> {
16    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
17    Q.get_or_init(|| {
18        let lang: Language = tree_sitter_rust::LANGUAGE.into();
19        let src = r#"
20            (function_item name: (identifier) @name)
21            (struct_item name: (type_identifier) @name)
22            (impl_item type: (type_identifier) @name)
23            (trait_item name: (type_identifier) @name)
24            (enum_item name: (type_identifier) @name)
25        "#;
26        Query::new(&lang, src).ok().map(|q| (lang, q))
27    })
28    .as_ref()
29}
30
31fn rust_ref_bundle() -> Option<&'static (Language, Query)> {
32    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
33    Q.get_or_init(|| {
34        let lang: Language = tree_sitter_rust::LANGUAGE.into();
35        let src = r#"(identifier) @ref (type_identifier) @ref (field_identifier) @ref"#;
36        Query::new(&lang, src).ok().map(|q| (lang, q))
37    })
38    .as_ref()
39}
40
41fn python_def_bundle() -> Option<&'static (Language, Query)> {
42    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
43    Q.get_or_init(|| {
44        let lang: Language = tree_sitter_python::LANGUAGE.into();
45        let src = r#"
46            (class_definition name: (identifier) @name)
47            (function_definition name: (identifier) @name)
48        "#;
49        Query::new(&lang, src).ok().map(|q| (lang, q))
50    })
51    .as_ref()
52}
53
54fn python_ref_bundle() -> Option<&'static (Language, Query)> {
55    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
56    Q.get_or_init(|| {
57        let lang: Language = tree_sitter_python::LANGUAGE.into();
58        Query::new(&lang, "(identifier) @ref")
59            .ok()
60            .map(|q| (lang, q))
61    })
62    .as_ref()
63}
64
65fn ts_def_bundle() -> Option<&'static (Language, Query)> {
66    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
67    Q.get_or_init(|| {
68        let lang: Language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
69        let src = r#"
70            (interface_declaration name: (type_identifier) @name)
71            (class_declaration name: (type_identifier) @name)
72            (function_declaration name: (identifier) @name)
73        "#;
74        Query::new(&lang, src).ok().map(|q| (lang, q))
75    })
76    .as_ref()
77}
78
79fn ts_ref_bundle() -> Option<&'static (Language, Query)> {
80    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
81    Q.get_or_init(|| {
82        let lang: Language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
83        let src = r#"(identifier) @ref (type_identifier) @ref"#;
84        Query::new(&lang, src).ok().map(|q| (lang, q))
85    })
86    .as_ref()
87}
88
89fn js_def_bundle() -> Option<&'static (Language, Query)> {
90    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
91    Q.get_or_init(|| {
92        let lang: Language = tree_sitter_javascript::LANGUAGE.into();
93        let src = r#"
94            (class_declaration name: (identifier) @name)
95            (function_declaration name: (identifier) @name)
96        "#;
97        Query::new(&lang, src).ok().map(|q| (lang, q))
98    })
99    .as_ref()
100}
101
102fn js_ref_bundle() -> Option<&'static (Language, Query)> {
103    static Q: OnceLock<Option<(Language, Query)>> = OnceLock::new();
104    Q.get_or_init(|| {
105        let lang: Language = tree_sitter_javascript::LANGUAGE.into();
106        Query::new(&lang, "(identifier) @ref")
107            .ok()
108            .map(|q| (lang, q))
109    })
110    .as_ref()
111}
112
113// ── RepoMapGenerator ──────────────────────────────────────────────────────────
114
/// Configuration for producing a ranked, textual map of the symbols defined
/// in a source tree.
#[derive(Debug, Clone)]
pub struct RepoMapGenerator {
    /// Directory whose contents are scanned.
    root: std::path::PathBuf,
    /// Hot files with normalized heat weights in [0.0, 1.0].
    hot_files: Vec<(String, f64)>,
    /// Maximum number of symbol lines written to the rendered map.
    max_symbols: usize,
}
121
122impl RepoMapGenerator {
123    pub fn new(root: impl AsRef<Path>) -> Self {
124        Self {
125            root: root.as_ref().to_path_buf(),
126            hot_files: Vec::new(),
127            max_symbols: 1500,
128        }
129    }
130
131    /// Bias PageRank toward files the user is actively editing.
132    /// Accepts paths with normalized heat weights [0.0, 1.0].
133    /// Hottest file gets 100x boost; others scale proportionally.
134    pub fn with_hot_files(mut self, files: &[(String, f64)]) -> Self {
135        self.hot_files = files.to_vec();
136        self
137    }
138
139    pub fn generate(&self) -> Result<String> {
140        // Map: symbol_name → set of files that define it
141        let mut defines: HashMap<String, HashSet<String>> = HashMap::new();
142        // Map: symbol_name → list of files that reference it
143        let mut references: HashMap<String, Vec<String>> = HashMap::new();
144        // Map: file → list of definition names for display
145        let mut definitions_display: HashMap<String, Vec<String>> = HashMap::new();
146        // All file paths that produced at least one def or ref tag.
147        let mut all_files: HashSet<String> = HashSet::new();
148
149        let walker = WalkBuilder::new(&self.root)
150            .hidden(true)
151            .ignore(true)
152            .git_ignore(true)
153            .add_custom_ignore_filename(".hematiteignore")
154            .filter_entry(|entry| {
155                if let Some(name) = entry.file_name().to_str() {
156                    if name == ".git"
157                        || name == "target"
158                        || name == "node_modules"
159                        || name.ends_with(".min.js")
160                    {
161                        return false;
162                    }
163                }
164                true
165            })
166            .build();
167
168        // One parser reused across all files; set_language() switches between grammars.
169        let mut parser = Parser::new();
170
171        for result in walker {
172            let entry = match result {
173                Ok(e) => e,
174                Err(_) => continue,
175            };
176            let path = entry.path();
177            if !path.is_file() {
178                continue;
179            }
180
181            let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
182            let (def_bundle, ref_bundle) = match ext {
183                "rs" => (rust_def_bundle(), rust_ref_bundle()),
184                "py" => (python_def_bundle(), python_ref_bundle()),
185                "ts" | "tsx" => (ts_def_bundle(), ts_ref_bundle()),
186                "js" | "jsx" => (js_def_bundle(), js_ref_bundle()),
187                _ => continue,
188            };
189
190            let Ok(source_code) = fs::read_to_string(path) else {
191                continue;
192            };
193
194            let rel_path = path
195                .strip_prefix(&self.root)
196                .unwrap_or(path)
197                .to_string_lossy()
198                .replace('\\', "/");
199
200            // Parse each file once and run both def and ref queries on the same tree.
201            match (def_bundle, ref_bundle) {
202                (Some((lang, def_q)), Some((_, ref_q))) => {
203                    if parser.set_language(lang).is_ok() {
204                        if let Some(tree) = parser.parse(&source_code, None) {
205                            // Def pass
206                            let mut cursor = QueryCursor::new();
207                            for m in cursor.matches(def_q, tree.root_node(), source_code.as_bytes())
208                            {
209                                for capture in m.captures {
210                                    if let Ok(text) = capture.node.utf8_text(source_code.as_bytes())
211                                    {
212                                        let name = text.to_string();
213                                        all_files.insert(rel_path.clone());
214                                        defines
215                                            .entry(name.clone())
216                                            .or_default()
217                                            .insert(rel_path.clone());
218                                        definitions_display
219                                            .entry(rel_path.clone())
220                                            .or_default()
221                                            .push(name);
222                                    }
223                                }
224                            }
225
226                            // Ref pass — seen_refs borrows source_code bytes to avoid String clones
227                            let mut cursor = QueryCursor::new();
228                            let mut seen_refs: HashSet<&str> = HashSet::new();
229                            for m in cursor.matches(ref_q, tree.root_node(), source_code.as_bytes())
230                            {
231                                for capture in m.captures {
232                                    if let Ok(text) = capture.node.utf8_text(source_code.as_bytes())
233                                    {
234                                        if seen_refs.insert(text) {
235                                            all_files.insert(rel_path.clone());
236                                            references
237                                                .entry(text.to_string())
238                                                .or_default()
239                                                .push(rel_path.clone());
240                                        }
241                                    }
242                                }
243                            }
244                        }
245                    }
246                }
247                (Some((lang, def_q)), None) => {
248                    if parser.set_language(lang).is_ok() {
249                        if let Some(tree) = parser.parse(&source_code, None) {
250                            let mut cursor = QueryCursor::new();
251                            for m in cursor.matches(def_q, tree.root_node(), source_code.as_bytes())
252                            {
253                                for capture in m.captures {
254                                    if let Ok(text) = capture.node.utf8_text(source_code.as_bytes())
255                                    {
256                                        let name = text.to_string();
257                                        all_files.insert(rel_path.clone());
258                                        defines
259                                            .entry(name.clone())
260                                            .or_default()
261                                            .insert(rel_path.clone());
262                                        definitions_display
263                                            .entry(rel_path.clone())
264                                            .or_default()
265                                            .push(name);
266                                    }
267                                }
268                            }
269                        }
270                    }
271                }
272                (None, Some((lang, ref_q))) => {
273                    if parser.set_language(lang).is_ok() {
274                        if let Some(tree) = parser.parse(&source_code, None) {
275                            let mut cursor = QueryCursor::new();
276                            let mut seen_refs: HashSet<&str> = HashSet::new();
277                            for m in cursor.matches(ref_q, tree.root_node(), source_code.as_bytes())
278                            {
279                                for capture in m.captures {
280                                    if let Ok(text) = capture.node.utf8_text(source_code.as_bytes())
281                                    {
282                                        if seen_refs.insert(text) {
283                                            all_files.insert(rel_path.clone());
284                                            references
285                                                .entry(text.to_string())
286                                                .or_default()
287                                                .push(rel_path.clone());
288                                        }
289                                    }
290                                }
291                            }
292                        }
293                    }
294                }
295                (None, None) => {}
296            }
297        }
298
299        // Deduplicate definition display lists
300        for defs in definitions_display.values_mut() {
301            defs.sort_unstable();
302            defs.dedup();
303        }
304
305        // If there are no references at all (e.g. tiny repo), treat defs as refs
306        if references.is_empty() {
307            for (name, files) in &defines {
308                references.insert(name.clone(), files.iter().cloned().collect());
309            }
310        }
311
312        // ── Build PageRank graph ──────────────────────────────────────────────
313        let defined_names: HashSet<&String> = defines.keys().collect();
314        let referenced_names: HashSet<&String> = references.keys().collect();
315        let shared_idents: HashSet<&&String> =
316            defined_names.intersection(&referenced_names).collect();
317
318        let mut graph = DiGraph::<String, f64>::new();
319        let mut node_map: HashMap<String, petgraph::graph::NodeIndex> =
320            HashMap::with_capacity(all_files.len());
321        for file in &all_files {
322            let idx = graph.add_node(file.clone());
323            node_map.insert(file.clone(), idx);
324        }
325
326        // Build edges: referencer → definer
327        for ident in &shared_idents {
328            let ident: &String = ident;
329            let definers = match defines.get(ident) {
330                Some(d) => d,
331                None => continue,
332            };
333            let referencers = match references.get(ident) {
334                Some(r) => r,
335                None => continue,
336            };
337
338            let mut mul: f64 = 1.0;
339            let is_snake = ident.contains('_') && ident.chars().any(|c| c.is_alphabetic());
340            let is_camel =
341                ident.chars().any(|c| c.is_uppercase()) && ident.chars().any(|c| c.is_lowercase());
342            if (is_snake || is_camel) && ident.len() >= 8 {
343                mul *= 10.0;
344            }
345            if ident.starts_with('_') {
346                mul *= 0.1;
347            }
348            if definers.len() > 5 {
349                mul *= 0.1;
350            }
351
352            for referencer in referencers {
353                let Some(&src) = node_map.get(referencer) else {
354                    continue;
355                };
356                for definer in definers {
357                    let Some(&dst) = node_map.get(definer) else {
358                        continue;
359                    };
360                    graph.add_edge(src, dst, mul);
361                }
362            }
363        }
364
365        // ── PageRank ──────────────────────────────────────────────────────────
366        let node_count = graph.node_count();
367        if node_count == 0 {
368            return Ok(
369                "=== Repository Map (Structural Overview) ===\n(no parseable source files found)\n"
370                    .to_string(),
371            );
372        }
373
374        let damping = 0.85_f64;
375        let iterations = 30_usize;
376        let base_score = 1.0 / node_count as f64;
377        let base_decay = (1.0 - damping) * base_score;
378
379        // Precompute total outgoing edge-weight per node once — constant across iterations.
380        // Avoids O(E × avg_degree) recomputation inside the 30-iteration loop.
381        let out_weights: Vec<f64> = graph
382            .node_indices()
383            .map(|idx| graph.edges(idx).map(|e| *e.weight()).sum::<f64>().max(1.0))
384            .collect();
385
386        // Personalization boosts for hot files, stored as a dense Vec.
387        let base_boost = 100.0 / node_count.max(1) as f64;
388        let mut pers_boosts: Vec<f64> = vec![0.0; node_count];
389        for (file, weight) in &self.hot_files {
390            if let Some(&idx) = node_map.get(file.as_str()) {
391                pers_boosts[idx.index()] = base_boost * weight;
392            }
393        }
394
395        // Vec-based PageRank — replaces 30 HashMap allocations with 2 pre-allocated Vecs.
396        let mut scores: Vec<f64> = vec![base_score; node_count];
397        let mut new_scores: Vec<f64> = vec![0.0; node_count];
398
399        for _ in 0..iterations {
400            new_scores.iter_mut().for_each(|s| *s = base_decay);
401
402            for edge in graph.edge_indices() {
403                let (src, dst) = graph.edge_endpoints(edge).unwrap();
404                let weight = graph[edge];
405                let contrib = damping * scores[src.index()] * (weight / out_weights[src.index()]);
406                new_scores[dst.index()] += contrib;
407            }
408
409            for (i, &pers) in pers_boosts.iter().enumerate() {
410                if pers > 0.0 {
411                    new_scores[i] += pers * base_score;
412                }
413            }
414
415            std::mem::swap(&mut scores, &mut new_scores);
416        }
417
418        // ── Render ranked output ──────────────────────────────────────────────
419        let mut ranked_files: Vec<(String, f64)> = graph
420            .node_indices()
421            .map(|idx| (graph[idx].clone(), scores[idx.index()]))
422            .collect();
423        ranked_files
424            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
425
426        let mut output = String::with_capacity(self.max_symbols * 40 + 64);
427        output.push_str("=== Repository Map (Structural Overview) ===\n");
428        let mut total_symbols = 0;
429
430        for (rel_path, _score) in &ranked_files {
431            if total_symbols >= self.max_symbols {
432                output.push_str("... (Repository Map Truncated — showing most important files)\n");
433                break;
434            }
435
436            if let Some(defs) = definitions_display.get(rel_path) {
437                let _ = writeln!(output, "{}:", rel_path);
438                for def in defs {
439                    let _ = writeln!(output, "  - {}", def);
440                    total_symbols += 1;
441                    if total_symbols >= self.max_symbols {
442                        break;
443                    }
444                }
445            }
446        }
447
448        Ok(output)
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `contents` to the file `name` directly inside `root`.
    fn write_source(root: &std::path::Path, name: &str, contents: &str) {
        fs::write(root.join(name), contents).unwrap();
    }

    /// A single Rust file is listed together with the names it defines.
    #[test]
    fn test_repo_map_generation() {
        let dir = tempdir().unwrap();

        let mock_code = r#"
        struct MyDatabase {
            id: String,
        }

        impl MyDatabase {
            fn save(&self) {}
        }

        fn launch_system() {}
        "#;

        write_source(dir.path(), "main.rs", mock_code);

        let gen = RepoMapGenerator::new(dir.path());
        let map = gen.generate().unwrap();

        assert!(map.contains("main.rs:"));
        assert!(map.contains("MyDatabase"));
        assert!(map.contains("launch_system"));
    }

    /// A file whose definitions are referenced from other files ranks above
    /// a file nobody references.
    #[test]
    fn test_pagerank_orders_central_files_first() {
        let dir = tempdir().unwrap();

        // "core.rs" defines a struct used everywhere
        write_source(
            dir.path(),
            "core.rs",
            "pub struct Engine {\n    pub id: u32,\n}\n\npub fn init_engine() -> Engine { Engine { id: 0 } }\n",
        );

        // "user.rs" references Engine
        write_source(
            dir.path(),
            "user.rs",
            "use crate::core::Engine;\n\nfn use_engine(e: Engine) {\n    let _ = e;\n}\n",
        );

        // "admin.rs" also references Engine
        write_source(
            dir.path(),
            "admin.rs",
            "use crate::core::Engine;\n\nfn admin_engine(e: Engine) {\n    let _ = e;\n}\n",
        );

        // "leaf.rs" defines something nobody uses
        write_source(
            dir.path(),
            "leaf.rs",
            "fn unused_leaf_function() {}\n\nstruct OrphanStruct {}\n",
        );

        let gen = RepoMapGenerator::new(dir.path());
        let map = gen.generate().unwrap();

        // Both files define symbols, so both must be listed; a missing entry
        // fails here instead of being hidden behind a sentinel position.
        let core_pos = map
            .find("core.rs:")
            .expect("core.rs should be listed in the map");
        let leaf_pos = map
            .find("leaf.rs:")
            .expect("leaf.rs should be listed in the map");

        // core.rs should appear before leaf.rs because it's referenced by 2 files
        assert!(
            core_pos < leaf_pos,
            "core.rs (referenced by 2 files) should rank before leaf.rs (referenced by 0). Map:\n{}",
            map
        );
    }
}