Skip to main content

hematite/memory/
repo_map.rs

1use anyhow::Result;
2use ignore::WalkBuilder;
3use petgraph::graph::DiGraph;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tree_sitter::{Language, Parser, Query, QueryCursor};
8
9// ── Tag types ─────────────────────────────────────────────────────────────────
10
/// Marker recording that a source file contributed at least one definition or
/// reference; the graph-building pass uses these to decide which files become nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Tag {
    /// Path of the file relative to the repository root, using `/` separators.
    rel_path: String,
}
14
15// ── Tree-sitter query factories ───────────────────────────────────────────────
16
17fn get_rust_def_query() -> Result<(Language, Query)> {
18    let language = tree_sitter_rust::LANGUAGE.into();
19    let query_src = r#"
20        (function_item name: (identifier) @name)
21        (struct_item name: (type_identifier) @name)
22        (impl_item type: (type_identifier) @name)
23        (trait_item name: (type_identifier) @name)
24        (enum_item name: (type_identifier) @name)
25    "#;
26    let query = Query::new(&language, query_src)?;
27    Ok((language, query))
28}
29
30fn get_rust_ref_query() -> Result<(Language, Query)> {
31    let language = tree_sitter_rust::LANGUAGE.into();
32    let query_src = r#"
33        (identifier) @ref
34        (type_identifier) @ref
35        (field_identifier) @ref
36    "#;
37    let query = Query::new(&language, query_src)?;
38    Ok((language, query))
39}
40
41fn get_python_def_query() -> Result<(Language, Query)> {
42    let language = tree_sitter_python::LANGUAGE.into();
43    let query_src = r#"
44        (class_definition name: (identifier) @name)
45        (function_definition name: (identifier) @name)
46    "#;
47    let query = Query::new(&language, query_src)?;
48    Ok((language, query))
49}
50
51fn get_python_ref_query() -> Result<(Language, Query)> {
52    let language = tree_sitter_python::LANGUAGE.into();
53    let query_src = "(identifier) @ref";
54    let query = Query::new(&language, query_src)?;
55    Ok((language, query))
56}
57
58fn get_ts_def_query() -> Result<(Language, Query)> {
59    let language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
60    let query_src = r#"
61        (interface_declaration name: (type_identifier) @name)
62        (class_declaration name: (type_identifier) @name)
63        (function_declaration name: (identifier) @name)
64    "#;
65    let query = Query::new(&language, query_src)?;
66    Ok((language, query))
67}
68
69fn get_ts_ref_query() -> Result<(Language, Query)> {
70    let language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
71    let query_src = r#"
72        (identifier) @ref
73        (type_identifier) @ref
74    "#;
75    let query = Query::new(&language, query_src)?;
76    Ok((language, query))
77}
78
79fn get_js_def_query() -> Result<(Language, Query)> {
80    let language = tree_sitter_javascript::LANGUAGE.into();
81    let query_src = r#"
82        (class_declaration name: (identifier) @name)
83        (function_declaration name: (identifier) @name)
84    "#;
85    let query = Query::new(&language, query_src)?;
86    Ok((language, query))
87}
88
89fn get_js_ref_query() -> Result<(Language, Query)> {
90    let language = tree_sitter_javascript::LANGUAGE.into();
91    let query_src = "(identifier) @ref";
92    let query = Query::new(&language, query_src)?;
93    Ok((language, query))
94}
95
96// ── RepoMapGenerator ──────────────────────────────────────────────────────────
97
/// Produces a ranked, textual overview of the symbols defined in a source tree.
#[derive(Debug, Clone)]
pub struct RepoMapGenerator {
    /// Directory that is walked; reported paths are made relative to it.
    root: std::path::PathBuf,
    /// Hot files with normalized heat weights in [0.0, 1.0].
    hot_files: Vec<(String, f64)>,
    /// Upper bound on the number of symbol lines written to the rendered map.
    max_symbols: usize,
}
104
105impl RepoMapGenerator {
106    pub fn new(root: impl AsRef<Path>) -> Self {
107        Self {
108            root: root.as_ref().to_path_buf(),
109            hot_files: Vec::new(),
110            max_symbols: 1500,
111        }
112    }
113
114    /// Bias PageRank toward files the user is actively editing.
115    /// Accepts paths with normalized heat weights [0.0, 1.0].
116    /// Hottest file gets 100x boost; others scale proportionally.
117    pub fn with_hot_files(mut self, files: &[(String, f64)]) -> Self {
118        self.hot_files = files.to_vec();
119        self
120    }
121
122    pub fn generate(&self) -> Result<String> {
123        // ── Pass 1: Collect defs + refs from every source file ────────────
124        let mut all_tags: Vec<Tag> = Vec::new();
125        // Map: symbol_name → set of files that define it
126        let mut defines: HashMap<String, HashSet<String>> = HashMap::new();
127        // Map: symbol_name → list of files that reference it
128        let mut references: HashMap<String, Vec<String>> = HashMap::new();
129        // Map: (file, symbol_name) → list of definition tag names for display
130        let mut definitions_display: HashMap<String, Vec<String>> = HashMap::new();
131
132        let walker = WalkBuilder::new(&self.root)
133            .hidden(true)
134            .ignore(true)
135            .git_ignore(true)
136            .add_custom_ignore_filename(".hematiteignore")
137            .filter_entry(|entry| {
138                if let Some(name) = entry.file_name().to_str() {
139                    if name == ".git"
140                        || name == "target"
141                        || name == "node_modules"
142                        || name.ends_with(".min.js")
143                    {
144                        return false;
145                    }
146                }
147                true
148            })
149            .build();
150
151        let rust_def = get_rust_def_query().ok();
152        let rust_ref = get_rust_ref_query().ok();
153        let python_def = get_python_def_query().ok();
154        let python_ref = get_python_ref_query().ok();
155        let ts_def = get_ts_def_query().ok();
156        let ts_ref = get_ts_ref_query().ok();
157        let js_def = get_js_def_query().ok();
158        let js_ref = get_js_ref_query().ok();
159
160        for result in walker {
161            let entry = match result {
162                Ok(e) => e,
163                Err(_) => continue,
164            };
165            let path = entry.path();
166            if !path.is_file() {
167                continue;
168            }
169
170            let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
171            let (def_bundle, ref_bundle) = match ext {
172                "rs" => (rust_def.as_ref(), rust_ref.as_ref()),
173                "py" => (python_def.as_ref(), python_ref.as_ref()),
174                "ts" | "tsx" => (ts_def.as_ref(), ts_ref.as_ref()),
175                "js" | "jsx" => (js_def.as_ref(), js_ref.as_ref()),
176                _ => continue,
177            };
178
179            let Ok(source_code) = fs::read_to_string(path) else {
180                continue;
181            };
182
183            let rel_path = path
184                .strip_prefix(&self.root)
185                .unwrap_or(path)
186                .to_string_lossy()
187                .replace('\\', "/");
188
189            // Extract definitions
190            if let Some((lang, query)) = def_bundle {
191                let mut parser = Parser::new();
192                if parser.set_language(lang).is_ok() {
193                    if let Some(tree) = parser.parse(&source_code, None) {
194                        let mut cursor = QueryCursor::new();
195                        let matches =
196                            cursor.matches(query, tree.root_node(), source_code.as_bytes());
197                        for m in matches {
198                            for capture in m.captures {
199                                if let Ok(text) = capture.node.utf8_text(source_code.as_bytes()) {
200                                    let name = text.to_string();
201                                    all_tags.push(Tag {
202                                        rel_path: rel_path.clone(),
203                                    });
204                                    defines
205                                        .entry(name.clone())
206                                        .or_default()
207                                        .insert(rel_path.clone());
208                                    definitions_display
209                                        .entry(rel_path.clone())
210                                        .or_default()
211                                        .push(name);
212                                }
213                            }
214                        }
215                    }
216                }
217            }
218
219            // Extract references
220            if let Some((lang, query)) = ref_bundle {
221                let mut parser = Parser::new();
222                if parser.set_language(lang).is_ok() {
223                    if let Some(tree) = parser.parse(&source_code, None) {
224                        let mut cursor = QueryCursor::new();
225                        let matches =
226                            cursor.matches(query, tree.root_node(), source_code.as_bytes());
227                        let mut seen_refs: HashSet<String> = HashSet::new();
228                        for m in matches {
229                            for capture in m.captures {
230                                if let Ok(text) = capture.node.utf8_text(source_code.as_bytes()) {
231                                    let name = text.to_string();
232                                    // Only count each unique identifier once per file
233                                    if seen_refs.insert(name.clone()) {
234                                        all_tags.push(Tag {
235                                            rel_path: rel_path.clone(),
236                                        });
237                                        references
238                                            .entry(name)
239                                            .or_default()
240                                            .push(rel_path.clone());
241                                    }
242                                }
243                            }
244                        }
245                    }
246                }
247            }
248        }
249
250        // Deduplicate definition display lists
251        for defs in definitions_display.values_mut() {
252            defs.sort();
253            defs.dedup();
254        }
255
256        // If there are no references at all (e.g. tiny repo), treat defs as refs
257        if references.is_empty() {
258            for (name, files) in &defines {
259                references.insert(name.clone(), files.iter().cloned().collect());
260            }
261        }
262
263        // ── Pass 2: Build the PageRank graph ──────────────────────────────
264        let defined_names: HashSet<&String> = defines.keys().collect();
265        let referenced_names: HashSet<&String> = references.keys().collect();
266        let shared_idents: HashSet<&&String> = defined_names.intersection(&referenced_names).collect();
267
268        // Collect all file paths that appear as nodes
269        let mut all_files: HashSet<String> = HashSet::new();
270        for tag in &all_tags {
271            all_files.insert(tag.rel_path.clone());
272        }
273
274        // Node index map
275        let mut graph = DiGraph::<String, f64>::new();
276        let mut node_map: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
277        for file in &all_files {
278            let idx = graph.add_node(file.clone());
279            node_map.insert(file.clone(), idx);
280        }
281
282        // Build edges: referencer → definer
283        for ident in &shared_idents {
284            let ident: &String = ident;
285            let definers = match defines.get(ident) {
286                Some(d) => d,
287                None => continue,
288            };
289            let referencers = match references.get(ident) {
290                Some(r) => r,
291                None => continue,
292            };
293
294            // Weight multiplier based on identifier quality
295            let mut mul: f64 = 1.0;
296            let is_snake = ident.contains('_') && ident.chars().any(|c| c.is_alphabetic());
297            let is_camel = ident.chars().any(|c| c.is_uppercase())
298                && ident.chars().any(|c| c.is_lowercase());
299            if (is_snake || is_camel) && ident.len() >= 8 {
300                mul *= 10.0;
301            }
302            if ident.starts_with('_') {
303                mul *= 0.1;
304            }
305            // Overly generic names defined in 5+ files get downweighted
306            if definers.len() > 5 {
307                mul *= 0.1;
308            }
309
310            for referencer in referencers {
311                let Some(&src) = node_map.get(referencer) else {
312                    continue;
313                };
314                for definer in definers {
315                    let Some(&dst) = node_map.get(definer) else {
316                        continue;
317                    };
318                    // Accumulate weight on the edge
319                    graph.add_edge(src, dst, mul);
320                }
321            }
322        }
323
324        // ── Pass 3: PageRank ──────────────────────────────────────────────
325        let node_count = graph.node_count();
326        if node_count == 0 {
327            return Ok("=== Repository Map (Structural Overview) ===\n(no parseable source files found)\n".to_string());
328        }
329
330        let damping = 0.85;
331        let iterations = 30;
332        let base_score = 1.0 / node_count as f64;
333
334        // Personalization: boost hot files proportional to heat weight.
335        // Hottest file (weight 1.0) gets 100x boost; others scale down linearly.
336        let mut personalization: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
337        let base_boost = 100.0 / node_count.max(1) as f64;
338        for (file, weight) in &self.hot_files {
339            if let Some(&idx) = node_map.get(file.as_str()) {
340                personalization.insert(idx, base_boost * weight);
341            }
342        }
343
344        // Initialize scores
345        let mut scores: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
346        for idx in graph.node_indices() {
347            scores.insert(idx, base_score);
348        }
349
350        // Iterate PageRank
351        for _ in 0..iterations {
352            let mut new_scores: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
353            for idx in graph.node_indices() {
354                new_scores.insert(idx, (1.0 - damping) * base_score);
355            }
356
357            for edge in graph.edge_indices() {
358                let (src, dst) = graph.edge_endpoints(edge).unwrap();
359                let weight = graph[edge];
360                // Total outgoing weight from src
361                let out_weight: f64 = graph
362                    .edges(src)
363                    .map(|e| *e.weight())
364                    .sum::<f64>()
365                    .max(1.0);
366                let contrib = damping * scores[&src] * (weight / out_weight);
367                *new_scores.entry(dst).or_default() += contrib;
368            }
369
370            // Apply personalization
371            for (&idx, &pers) in &personalization {
372                *new_scores.entry(idx).or_default() += pers * base_score;
373            }
374
375            scores = new_scores;
376        }
377
378        // ── Pass 4: Render ranked output ──────────────────────────────────
379        let mut ranked_files: Vec<(String, f64)> = scores
380            .iter()
381            .map(|(&idx, &score)| (graph[idx].clone(), score))
382            .collect();
383        ranked_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
384
385        let mut output = String::new();
386        output.push_str("=== Repository Map (Structural Overview) ===\n");
387        let mut total_symbols = 0;
388
389        for (rel_path, _score) in &ranked_files {
390            if total_symbols >= self.max_symbols {
391                output.push_str("... (Repository Map Truncated — showing most important files)\n");
392                break;
393            }
394
395            if let Some(defs) = definitions_display.get(rel_path) {
396                output.push_str(&format!("{}:\n", rel_path));
397                for def in defs {
398                    output.push_str(&format!("  - {}\n", def));
399                    total_symbols += 1;
400                    if total_symbols >= self.max_symbols {
401                        break;
402                    }
403                }
404            }
405        }
406
407        Ok(output)
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A single Rust file shows up in the map with its struct and function names.
    #[test]
    fn test_repo_map_generation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("main.rs");

        let mock_code = r#"
        struct MyDatabase {
            id: String,
        }

        impl MyDatabase {
            fn save(&self) {}
        }

        fn launch_system() {}
        "#;

        fs::write(&file_path, mock_code).unwrap();

        let generator = RepoMapGenerator::new(dir.path());
        let map = generator.generate().unwrap();

        assert!(map.contains("main.rs:"));
        assert!(map.contains("MyDatabase"));
        assert!(map.contains("launch_system"));
    }

    /// A file whose symbol is referenced from other files ranks above an
    /// isolated file.
    #[test]
    fn test_pagerank_orders_central_files_first() {
        let dir = tempdir().unwrap();

        // "core.rs" defines a struct used everywhere
        fs::write(
            dir.path().join("core.rs"),
            "pub struct Engine {\n    pub id: u32,\n}\n\npub fn init_engine() -> Engine { Engine { id: 0 } }\n",
        )
        .unwrap();

        // "user.rs" references Engine
        fs::write(
            dir.path().join("user.rs"),
            "use crate::core::Engine;\n\nfn use_engine(e: Engine) {\n    let _ = e;\n}\n",
        )
        .unwrap();

        // "admin.rs" also references Engine
        fs::write(
            dir.path().join("admin.rs"),
            "use crate::core::Engine;\n\nfn admin_engine(e: Engine) {\n    let _ = e;\n}\n",
        )
        .unwrap();

        // "leaf.rs" defines something nobody uses
        fs::write(
            dir.path().join("leaf.rs"),
            "fn unused_leaf_function() {}\n\nstruct OrphanStruct {}\n",
        )
        .unwrap();

        let generator = RepoMapGenerator::new(dir.path());
        let map = generator.generate().unwrap();

        // Both files must be present; otherwise the ordering check below
        // could pass without comparing anything.
        let core_pos = map
            .find("core.rs:")
            .unwrap_or_else(|| panic!("core.rs missing from map:\n{}", map));
        let leaf_pos = map
            .find("leaf.rs:")
            .unwrap_or_else(|| panic!("leaf.rs missing from map:\n{}", map));

        // core.rs should appear before leaf.rs because it's referenced by 2 files
        assert!(
            core_pos < leaf_pos,
            "core.rs (referenced by 2 files) should rank before leaf.rs (referenced by 0). Map:\n{}",
            map
        );
    }

    /// Marking an otherwise unreferenced file as hot lifts it above a file
    /// that other files reference.
    #[test]
    fn test_hot_files_bias_ranking() {
        let dir = tempdir().unwrap();

        fs::write(
            dir.path().join("core.rs"),
            "pub struct Engine {\n    pub id: u32,\n}\n\npub fn init_engine() -> Engine { Engine { id: 0 } }\n",
        )
        .unwrap();
        fs::write(
            dir.path().join("user.rs"),
            "use crate::core::Engine;\n\nfn use_engine(e: Engine) {\n    let _ = e;\n}\n",
        )
        .unwrap();
        fs::write(
            dir.path().join("leaf.rs"),
            "fn unused_leaf_function() {}\n\nstruct OrphanStruct {}\n",
        )
        .unwrap();

        let generator =
            RepoMapGenerator::new(dir.path()).with_hot_files(&[("leaf.rs".to_string(), 1.0)]);
        let map = generator.generate().unwrap();

        let core_pos = map
            .find("core.rs:")
            .unwrap_or_else(|| panic!("core.rs missing from map:\n{}", map));
        let leaf_pos = map
            .find("leaf.rs:")
            .unwrap_or_else(|| panic!("leaf.rs missing from map:\n{}", map));

        assert!(
            leaf_pos < core_pos,
            "leaf.rs (hot) should rank before core.rs. Map:\n{}",
            map
        );
    }
}