// hematite/memory/repo_map.rs

1use anyhow::Result;
2use ignore::WalkBuilder;
3use petgraph::graph::DiGraph;
4use std::collections::{HashMap, HashSet};
5use std::fs;
6use std::path::Path;
7use tree_sitter::{Language, Parser, Query, QueryCursor};
8
9// ── Tag types ─────────────────────────────────────────────────────────────────
10
/// A single definition or reference occurrence found while scanning a file.
///
/// Only the owning file is recorded; the tag exists so that every file which
/// produced at least one occurrence can become a node in the ranking graph.
struct Tag {
    /// Path of the file the occurrence was found in, relative to the scan
    /// root and written with `/` separators.
    rel_path: String,
}
14
15// ── Tree-sitter query factories ───────────────────────────────────────────────
16
17fn get_rust_def_query() -> Result<(Language, Query)> {
18    let language = tree_sitter_rust::LANGUAGE.into();
19    let query_src = r#"
20        (function_item name: (identifier) @name)
21        (struct_item name: (type_identifier) @name)
22        (impl_item type: (type_identifier) @name)
23        (trait_item name: (type_identifier) @name)
24        (enum_item name: (type_identifier) @name)
25    "#;
26    let query = Query::new(&language, query_src)?;
27    Ok((language, query))
28}
29
30fn get_rust_ref_query() -> Result<(Language, Query)> {
31    let language = tree_sitter_rust::LANGUAGE.into();
32    let query_src = r#"
33        (identifier) @ref
34        (type_identifier) @ref
35        (field_identifier) @ref
36    "#;
37    let query = Query::new(&language, query_src)?;
38    Ok((language, query))
39}
40
41fn get_python_def_query() -> Result<(Language, Query)> {
42    let language = tree_sitter_python::LANGUAGE.into();
43    let query_src = r#"
44        (class_definition name: (identifier) @name)
45        (function_definition name: (identifier) @name)
46    "#;
47    let query = Query::new(&language, query_src)?;
48    Ok((language, query))
49}
50
51fn get_python_ref_query() -> Result<(Language, Query)> {
52    let language = tree_sitter_python::LANGUAGE.into();
53    let query_src = "(identifier) @ref";
54    let query = Query::new(&language, query_src)?;
55    Ok((language, query))
56}
57
58fn get_ts_def_query() -> Result<(Language, Query)> {
59    let language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
60    let query_src = r#"
61        (interface_declaration name: (type_identifier) @name)
62        (class_declaration name: (type_identifier) @name)
63        (function_declaration name: (identifier) @name)
64    "#;
65    let query = Query::new(&language, query_src)?;
66    Ok((language, query))
67}
68
69fn get_ts_ref_query() -> Result<(Language, Query)> {
70    let language = tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into();
71    let query_src = r#"
72        (identifier) @ref
73        (type_identifier) @ref
74    "#;
75    let query = Query::new(&language, query_src)?;
76    Ok((language, query))
77}
78
79fn get_js_def_query() -> Result<(Language, Query)> {
80    let language = tree_sitter_javascript::LANGUAGE.into();
81    let query_src = r#"
82        (class_declaration name: (identifier) @name)
83        (function_declaration name: (identifier) @name)
84    "#;
85    let query = Query::new(&language, query_src)?;
86    Ok((language, query))
87}
88
89fn get_js_ref_query() -> Result<(Language, Query)> {
90    let language = tree_sitter_javascript::LANGUAGE.into();
91    let query_src = "(identifier) @ref";
92    let query = Query::new(&language, query_src)?;
93    Ok((language, query))
94}
95
96// ── RepoMapGenerator ──────────────────────────────────────────────────────────
97
/// Produces a textual overview of a source tree: the files it contains and
/// the symbols each one defines, ordered by a PageRank-style importance score.
pub struct RepoMapGenerator {
    /// Directory the scan starts from; reported paths are relative to it.
    root: std::path::PathBuf,
    /// Files to favour in the ranking, each paired with a heat weight that
    /// callers are expected to normalize into [0.0, 1.0].
    hot_files: Vec<(String, f64)>,
    /// Upper bound on the number of symbol lines written to the map.
    max_symbols: usize,
}
104
105impl RepoMapGenerator {
106    pub fn new(root: impl AsRef<Path>) -> Self {
107        Self {
108            root: root.as_ref().to_path_buf(),
109            hot_files: Vec::new(),
110            max_symbols: 1500,
111        }
112    }
113
114    /// Bias PageRank toward files the user is actively editing.
115    /// Accepts paths with normalized heat weights [0.0, 1.0].
116    /// Hottest file gets 100x boost; others scale proportionally.
117    pub fn with_hot_files(mut self, files: &[(String, f64)]) -> Self {
118        self.hot_files = files.to_vec();
119        self
120    }
121
122    pub fn generate(&self) -> Result<String> {
123        // ── Pass 1: Collect defs + refs from every source file ────────────
124        let mut all_tags: Vec<Tag> = Vec::new();
125        // Map: symbol_name → set of files that define it
126        let mut defines: HashMap<String, HashSet<String>> = HashMap::new();
127        // Map: symbol_name → list of files that reference it
128        let mut references: HashMap<String, Vec<String>> = HashMap::new();
129        // Map: (file, symbol_name) → list of definition tag names for display
130        let mut definitions_display: HashMap<String, Vec<String>> = HashMap::new();
131
132        let walker = WalkBuilder::new(&self.root)
133            .hidden(true)
134            .ignore(true)
135            .git_ignore(true)
136            .add_custom_ignore_filename(".hematiteignore")
137            .filter_entry(|entry| {
138                if let Some(name) = entry.file_name().to_str() {
139                    if name == ".git"
140                        || name == "target"
141                        || name == "node_modules"
142                        || name.ends_with(".min.js")
143                    {
144                        return false;
145                    }
146                }
147                true
148            })
149            .build();
150
151        let rust_def = get_rust_def_query().ok();
152        let rust_ref = get_rust_ref_query().ok();
153        let python_def = get_python_def_query().ok();
154        let python_ref = get_python_ref_query().ok();
155        let ts_def = get_ts_def_query().ok();
156        let ts_ref = get_ts_ref_query().ok();
157        let js_def = get_js_def_query().ok();
158        let js_ref = get_js_ref_query().ok();
159
160        for result in walker {
161            let entry = match result {
162                Ok(e) => e,
163                Err(_) => continue,
164            };
165            let path = entry.path();
166            if !path.is_file() {
167                continue;
168            }
169
170            let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
171            let (def_bundle, ref_bundle) = match ext {
172                "rs" => (rust_def.as_ref(), rust_ref.as_ref()),
173                "py" => (python_def.as_ref(), python_ref.as_ref()),
174                "ts" | "tsx" => (ts_def.as_ref(), ts_ref.as_ref()),
175                "js" | "jsx" => (js_def.as_ref(), js_ref.as_ref()),
176                _ => continue,
177            };
178
179            let Ok(source_code) = fs::read_to_string(path) else {
180                continue;
181            };
182
183            let rel_path = path
184                .strip_prefix(&self.root)
185                .unwrap_or(path)
186                .to_string_lossy()
187                .replace('\\', "/");
188
189            // Extract definitions
190            if let Some((lang, query)) = def_bundle {
191                let mut parser = Parser::new();
192                if parser.set_language(lang).is_ok() {
193                    if let Some(tree) = parser.parse(&source_code, None) {
194                        let mut cursor = QueryCursor::new();
195                        let matches =
196                            cursor.matches(query, tree.root_node(), source_code.as_bytes());
197                        for m in matches {
198                            for capture in m.captures {
199                                if let Ok(text) = capture.node.utf8_text(source_code.as_bytes()) {
200                                    let name = text.to_string();
201                                    all_tags.push(Tag {
202                                        rel_path: rel_path.clone(),
203                                    });
204                                    defines
205                                        .entry(name.clone())
206                                        .or_default()
207                                        .insert(rel_path.clone());
208                                    definitions_display
209                                        .entry(rel_path.clone())
210                                        .or_default()
211                                        .push(name);
212                                }
213                            }
214                        }
215                    }
216                }
217            }
218
219            // Extract references
220            if let Some((lang, query)) = ref_bundle {
221                let mut parser = Parser::new();
222                if parser.set_language(lang).is_ok() {
223                    if let Some(tree) = parser.parse(&source_code, None) {
224                        let mut cursor = QueryCursor::new();
225                        let matches =
226                            cursor.matches(query, tree.root_node(), source_code.as_bytes());
227                        let mut seen_refs: HashSet<String> = HashSet::new();
228                        for m in matches {
229                            for capture in m.captures {
230                                if let Ok(text) = capture.node.utf8_text(source_code.as_bytes()) {
231                                    let name = text.to_string();
232                                    // Only count each unique identifier once per file
233                                    if seen_refs.insert(name.clone()) {
234                                        all_tags.push(Tag {
235                                            rel_path: rel_path.clone(),
236                                        });
237                                        references.entry(name).or_default().push(rel_path.clone());
238                                    }
239                                }
240                            }
241                        }
242                    }
243                }
244            }
245        }
246
247        // Deduplicate definition display lists
248        for defs in definitions_display.values_mut() {
249            defs.sort();
250            defs.dedup();
251        }
252
253        // If there are no references at all (e.g. tiny repo), treat defs as refs
254        if references.is_empty() {
255            for (name, files) in &defines {
256                references.insert(name.clone(), files.iter().cloned().collect());
257            }
258        }
259
260        // ── Pass 2: Build the PageRank graph ──────────────────────────────
261        let defined_names: HashSet<&String> = defines.keys().collect();
262        let referenced_names: HashSet<&String> = references.keys().collect();
263        let shared_idents: HashSet<&&String> =
264            defined_names.intersection(&referenced_names).collect();
265
266        // Collect all file paths that appear as nodes
267        let mut all_files: HashSet<String> = HashSet::new();
268        for tag in &all_tags {
269            all_files.insert(tag.rel_path.clone());
270        }
271
272        // Node index map
273        let mut graph = DiGraph::<String, f64>::new();
274        let mut node_map: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
275        for file in &all_files {
276            let idx = graph.add_node(file.clone());
277            node_map.insert(file.clone(), idx);
278        }
279
280        // Build edges: referencer → definer
281        for ident in &shared_idents {
282            let ident: &String = ident;
283            let definers = match defines.get(ident) {
284                Some(d) => d,
285                None => continue,
286            };
287            let referencers = match references.get(ident) {
288                Some(r) => r,
289                None => continue,
290            };
291
292            // Weight multiplier based on identifier quality
293            let mut mul: f64 = 1.0;
294            let is_snake = ident.contains('_') && ident.chars().any(|c| c.is_alphabetic());
295            let is_camel =
296                ident.chars().any(|c| c.is_uppercase()) && ident.chars().any(|c| c.is_lowercase());
297            if (is_snake || is_camel) && ident.len() >= 8 {
298                mul *= 10.0;
299            }
300            if ident.starts_with('_') {
301                mul *= 0.1;
302            }
303            // Overly generic names defined in 5+ files get downweighted
304            if definers.len() > 5 {
305                mul *= 0.1;
306            }
307
308            for referencer in referencers {
309                let Some(&src) = node_map.get(referencer) else {
310                    continue;
311                };
312                for definer in definers {
313                    let Some(&dst) = node_map.get(definer) else {
314                        continue;
315                    };
316                    // Accumulate weight on the edge
317                    graph.add_edge(src, dst, mul);
318                }
319            }
320        }
321
322        // ── Pass 3: PageRank ──────────────────────────────────────────────
323        let node_count = graph.node_count();
324        if node_count == 0 {
325            return Ok(
326                "=== Repository Map (Structural Overview) ===\n(no parseable source files found)\n"
327                    .to_string(),
328            );
329        }
330
331        let damping = 0.85;
332        let iterations = 30;
333        let base_score = 1.0 / node_count as f64;
334
335        // Personalization: boost hot files proportional to heat weight.
336        // Hottest file (weight 1.0) gets 100x boost; others scale down linearly.
337        let mut personalization: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
338        let base_boost = 100.0 / node_count.max(1) as f64;
339        for (file, weight) in &self.hot_files {
340            if let Some(&idx) = node_map.get(file.as_str()) {
341                personalization.insert(idx, base_boost * weight);
342            }
343        }
344
345        // Initialize scores
346        let mut scores: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
347        for idx in graph.node_indices() {
348            scores.insert(idx, base_score);
349        }
350
351        // Iterate PageRank
352        for _ in 0..iterations {
353            let mut new_scores: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
354            for idx in graph.node_indices() {
355                new_scores.insert(idx, (1.0 - damping) * base_score);
356            }
357
358            for edge in graph.edge_indices() {
359                let (src, dst) = graph.edge_endpoints(edge).unwrap();
360                let weight = graph[edge];
361                // Total outgoing weight from src
362                let out_weight: f64 = graph.edges(src).map(|e| *e.weight()).sum::<f64>().max(1.0);
363                let contrib = damping * scores[&src] * (weight / out_weight);
364                *new_scores.entry(dst).or_default() += contrib;
365            }
366
367            // Apply personalization
368            for (&idx, &pers) in &personalization {
369                *new_scores.entry(idx).or_default() += pers * base_score;
370            }
371
372            scores = new_scores;
373        }
374
375        // ── Pass 4: Render ranked output ──────────────────────────────────
376        let mut ranked_files: Vec<(String, f64)> = scores
377            .iter()
378            .map(|(&idx, &score)| (graph[idx].clone(), score))
379            .collect();
380        ranked_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
381
382        let mut output = String::new();
383        output.push_str("=== Repository Map (Structural Overview) ===\n");
384        let mut total_symbols = 0;
385
386        for (rel_path, _score) in &ranked_files {
387            if total_symbols >= self.max_symbols {
388                output.push_str("... (Repository Map Truncated — showing most important files)\n");
389                break;
390            }
391
392            if let Some(defs) = definitions_display.get(rel_path) {
393                output.push_str(&format!("{}:\n", rel_path));
394                for def in defs {
395                    output.push_str(&format!("  - {}\n", def));
396                    total_symbols += 1;
397                    if total_symbols >= self.max_symbols {
398                        break;
399                    }
400                }
401            }
402        }
403
404        Ok(output)
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `contents` to the file `name` inside `dir`, panicking on failure.
    fn write_source(dir: &std::path::Path, name: &str, contents: &str) {
        fs::write(dir.join(name), contents).expect("failed to write test fixture");
    }

    /// A single Rust file shows up in the map together with its definitions.
    #[test]
    fn test_repo_map_generation() {
        let dir = tempdir().unwrap();

        let mock_code = r#"
        struct MyDatabase {
            id: String,
        }

        impl MyDatabase {
            fn save(&self) {}
        }

        fn launch_system() {}
        "#;
        write_source(dir.path(), "main.rs", mock_code);

        let map = RepoMapGenerator::new(dir.path()).generate().unwrap();

        assert!(map.contains("main.rs:"));
        assert!(map.contains("MyDatabase"));
        assert!(map.contains("launch_system"));
    }

    /// A file whose definitions are used by other files ranks above a file
    /// nobody refers to.
    #[test]
    fn test_pagerank_orders_central_files_first() {
        let dir = tempdir().unwrap();

        // "core.rs" defines a struct used everywhere
        write_source(
            dir.path(),
            "core.rs",
            "pub struct Engine {\n    pub id: u32,\n}\n\npub fn init_engine() -> Engine { Engine { id: 0 } }\n",
        );

        // "user.rs" references Engine
        write_source(
            dir.path(),
            "user.rs",
            "use crate::core::Engine;\n\nfn use_engine(e: Engine) {\n    let _ = e;\n}\n",
        );

        // "admin.rs" also references Engine
        write_source(
            dir.path(),
            "admin.rs",
            "use crate::core::Engine;\n\nfn admin_engine(e: Engine) {\n    let _ = e;\n}\n",
        );

        // "leaf.rs" defines something nobody uses
        write_source(
            dir.path(),
            "leaf.rs",
            "fn unused_leaf_function() {}\n\nstruct OrphanStruct {}\n",
        );

        let map = RepoMapGenerator::new(dir.path()).generate().unwrap();

        // Both files must be listed; otherwise the ordering check below would
        // pass without comparing anything.
        let core_pos = map
            .find("core.rs:")
            .unwrap_or_else(|| panic!("core.rs missing from map:\n{}", map));
        let leaf_pos = map
            .find("leaf.rs:")
            .unwrap_or_else(|| panic!("leaf.rs missing from map:\n{}", map));

        // core.rs should appear before leaf.rs because it's referenced by 2 files
        assert!(
            core_pos < leaf_pos,
            "core.rs (referenced by 2 files) should rank before leaf.rs (referenced by 0). Map:\n{}",
            map
        );
    }
}
485}