llmcc_core/
graph_builder.rs

1use std::collections::HashSet;
2use std::marker::PhantomData;
3
4pub use crate::block::{BasicBlock, BlockId, BlockKind, BlockRelation};
5use crate::block::{
6    BlockCall, BlockClass, BlockConst, BlockEnum, BlockField, BlockFunc, BlockImpl, BlockRoot,
7    BlockStmt,
8};
9use crate::block_rel::BlockRelationMap;
10use crate::context::{CompileCtxt, CompileUnit};
11use crate::ir::HirNode;
12use crate::lang_def::LanguageTrait;
13use crate::symbol::{SymId, Symbol};
14use crate::visit::HirVisitor;
15
16#[derive(Debug, Clone)]
17pub struct UnitGraph {
18    /// Compile unit this graph belongs to
19    unit_index: usize,
20    /// Root block ID of this unit
21    root: BlockId,
22    /// Edges of this graph unit
23    edges: BlockRelationMap,
24}
25
26impl UnitGraph {
27    pub fn new(unit_index: usize, root: BlockId, edges: BlockRelationMap) -> Self {
28        Self {
29            unit_index,
30            root,
31            edges,
32        }
33    }
34
35    pub fn unit_index(&self) -> usize {
36        self.unit_index
37    }
38
39    pub fn root(&self) -> BlockId {
40        self.root
41    }
42
43    pub fn edges(&self) -> &BlockRelationMap {
44        &self.edges
45    }
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
49pub struct GraphNode {
50    pub unit_index: usize,
51    pub block_id: BlockId,
52}
53
54/// ProjectGraph represents a complete compilation project with all units and their inter-dependencies.
55///
56/// # Overview
57/// ProjectGraph maintains a collection of per-unit compilation graphs (UnitGraph) and facilitates
58/// cross-unit dependency resolution. It provides efficient multi-dimensional indexing for block
59/// lookups by name, kind, unit, and ID, enabling quick context retrieval for LLM consumption.
60///
61/// # Architecture
62/// The graph consists of:
63/// - **UnitGraphs**: One per compilation unit (file), containing blocks and intra-unit relations
64/// - **Block Indexes**: Multi-dimensional indexes via BlockIndexMaps for O(1) to O(log n) lookups
65/// - **Cross-unit Links**: Dependencies tracked between blocks across different units
66///
67/// # Primary Use Cases
68/// 1. **Symbol Resolution**: Find blocks by name across the entire project
69/// 2. **Context Gathering**: Collect all related blocks for code analysis
70/// 3. **LLM Serialization**: Export graph as text or JSON for LLM model consumption
71/// 4. **Dependency Analysis**: Traverse dependency graphs to understand block relationships
72///
73#[derive(Debug)]
74pub struct ProjectGraph<'tcx> {
75    /// Reference to the compilation context containing all symbols, HIR nodes, and blocks
76    pub cc: &'tcx CompileCtxt<'tcx>,
77    /// Per-unit graphs containing blocks and intra-unit relations
78    units: Vec<UnitGraph>,
79}
80
81impl<'tcx> ProjectGraph<'tcx> {
82    pub fn new(cc: &'tcx CompileCtxt<'tcx>) -> Self {
83        Self {
84            cc,
85            units: Vec::new(),
86        }
87    }
88
89    pub fn add_child(&mut self, graph: UnitGraph) {
90        self.units.push(graph);
91    }
92
93    pub fn link_units(&mut self) {
94        if self.units.is_empty() {
95            return;
96        }
97
98        let mut unresolved = self.cc.unresolve_symbols.borrow_mut();
99
100        unresolved.retain(|symbol_ref| {
101            let target = *symbol_ref;
102            let Some(target_block) = target.block_id() else {
103                return false;
104            };
105
106            let dependents: Vec<SymId> = target.depended.borrow().clone();
107            for dependent_id in dependents {
108                let Some(source_symbol) = self.cc.opt_get_symbol(dependent_id) else {
109                    continue;
110                };
111                let Some(from_block) = source_symbol.block_id() else {
112                    continue;
113                };
114                self.add_cross_edge(
115                    source_symbol.unit_index().unwrap(),
116                    target.unit_index().unwrap(),
117                    from_block,
118                    target_block,
119                );
120            }
121
122            false
123        });
124    }
125
126    pub fn units(&self) -> &[UnitGraph] {
127        &self.units
128    }
129
130    pub fn block_by_name(&self, name: &str) -> Option<GraphNode> {
131        let block_indexes = self.cc.block_indexes.borrow();
132        let matches = block_indexes.find_by_name(name);
133
134        matches.first().map(|(unit_index, _, block_id)| GraphNode {
135            unit_index: *unit_index,
136            block_id: *block_id,
137        })
138    }
139
140    pub fn blocks_by_name(&self, name: &str) -> Vec<GraphNode> {
141        let block_indexes = self.cc.block_indexes.borrow();
142        let matches = block_indexes.find_by_name(name);
143
144        matches
145            .into_iter()
146            .map(|(unit_index, _, block_id)| GraphNode {
147                unit_index,
148                block_id,
149            })
150            .collect()
151    }
152
153    pub fn block_by_name_in(&self, unit_index: usize, name: &str) -> Option<GraphNode> {
154        let block_indexes = self.cc.block_indexes.borrow();
155        let matches = block_indexes.find_by_name(name);
156
157        matches
158            .iter()
159            .find(|(u, _, _)| *u == unit_index)
160            .map(|(_, _, block_id)| GraphNode {
161                unit_index,
162                block_id: *block_id,
163            })
164    }
165
166    pub fn blocks_by_kind(&self, block_kind: BlockKind) -> Vec<GraphNode> {
167        let block_indexes = self.cc.block_indexes.borrow();
168        let matches = block_indexes.find_by_kind(block_kind);
169
170        matches
171            .into_iter()
172            .map(|(unit_index, _, block_id)| GraphNode {
173                unit_index,
174                block_id,
175            })
176            .collect()
177    }
178
179    pub fn blocks_by_kind_in(&self, block_kind: BlockKind, unit_index: usize) -> Vec<GraphNode> {
180        let block_indexes = self.cc.block_indexes.borrow();
181        let block_ids = block_indexes.find_by_kind_and_unit(block_kind, unit_index);
182
183        block_ids
184            .into_iter()
185            .map(|block_id| GraphNode {
186                unit_index,
187                block_id,
188            })
189            .collect()
190    }
191
192    pub fn blocks_in(&self, unit_index: usize) -> Vec<GraphNode> {
193        let block_indexes = self.cc.block_indexes.borrow();
194        let matches = block_indexes.find_by_unit(unit_index);
195
196        matches
197            .into_iter()
198            .map(|(_, _, block_id)| GraphNode {
199                unit_index,
200                block_id,
201            })
202            .collect()
203    }
204
205    pub fn block_info(&self, block_id: BlockId) -> Option<(usize, Option<String>, BlockKind)> {
206        let block_indexes = self.cc.block_indexes.borrow();
207        block_indexes.get_block_info(block_id)
208    }
209
210    pub fn find_related_blocks(
211        &self,
212        node: GraphNode,
213        relations: Vec<BlockRelation>,
214    ) -> Vec<GraphNode> {
215        if node.unit_index >= self.units.len() {
216            return Vec::new();
217        }
218
219        let unit = &self.units[node.unit_index];
220        let mut result = Vec::new();
221
222        for relation in relations {
223            match relation {
224                BlockRelation::DependsOn => {
225                    // Get all blocks that this block depends on
226                    let dependencies = unit
227                        .edges
228                        .get_related(node.block_id, BlockRelation::DependsOn);
229                    for dep_block_id in dependencies {
230                        result.push(GraphNode {
231                            unit_index: node.unit_index,
232                            block_id: dep_block_id,
233                        });
234                    }
235                }
236                BlockRelation::DependedBy => {
237                    // Get all blocks that depend on this block
238                    let dependents = unit
239                        .edges
240                        .find_reverse_relations(node.block_id, BlockRelation::DependsOn);
241                    for dep_block_id in dependents {
242                        result.push(GraphNode {
243                            unit_index: node.unit_index,
244                            block_id: dep_block_id,
245                        });
246                    }
247                }
248                BlockRelation::Unknown => {
249                    // Skip unknown relations
250                }
251            }
252        }
253
254        result
255    }
256
257    pub fn find_dpends_blocks_recursive(&self, node: GraphNode) -> HashSet<GraphNode> {
258        let mut visited = HashSet::new();
259        let mut stack = vec![node];
260        let relations = vec![BlockRelation::DependsOn];
261
262        while let Some(current) = stack.pop() {
263            if visited.contains(&current) {
264                continue;
265            }
266            visited.insert(current);
267
268            for related in self.find_related_blocks(current, relations.clone()) {
269                if !visited.contains(&related) {
270                    stack.push(related);
271                }
272            }
273        }
274
275        visited.remove(&node);
276        visited
277    }
278
279    pub fn traverse_bfs<F>(&self, start: GraphNode, mut callback: F)
280    where
281        F: FnMut(GraphNode),
282    {
283        let mut visited = HashSet::new();
284        let mut queue = vec![start];
285        let relations = vec![BlockRelation::DependsOn, BlockRelation::DependedBy];
286
287        while !queue.is_empty() {
288            let current = queue.remove(0);
289            if visited.contains(&current) {
290                continue;
291            }
292            visited.insert(current);
293            callback(current);
294
295            for related in self.find_related_blocks(current, relations.clone()) {
296                if !visited.contains(&related) {
297                    queue.push(related);
298                }
299            }
300        }
301    }
302
303    pub fn traverse_dfs<F>(&self, start: GraphNode, mut callback: F)
304    where
305        F: FnMut(GraphNode),
306    {
307        let mut visited = HashSet::new();
308        self.traverse_dfs_impl(start, &mut visited, &mut callback);
309    }
310
311    fn traverse_dfs_impl<F>(
312        &self,
313        node: GraphNode,
314        visited: &mut HashSet<GraphNode>,
315        callback: &mut F,
316    ) where
317        F: FnMut(GraphNode),
318    {
319        if visited.contains(&node) {
320            return;
321        }
322        visited.insert(node);
323        callback(node);
324
325        let relations = vec![BlockRelation::DependsOn, BlockRelation::DependedBy];
326        for related in self.find_related_blocks(node, relations) {
327            if !visited.contains(&related) {
328                self.traverse_dfs_impl(related, visited, callback);
329            }
330        }
331    }
332
333    pub fn get_block_depends(&self, node: GraphNode) -> HashSet<GraphNode> {
334        if node.unit_index >= self.units.len() {
335            return HashSet::new();
336        }
337
338        let unit = &self.units[node.unit_index];
339        let mut result = HashSet::new();
340        let mut visited = HashSet::new();
341        let mut stack = vec![node.block_id];
342
343        while let Some(current_block) = stack.pop() {
344            if visited.contains(&current_block) {
345                continue;
346            }
347            visited.insert(current_block);
348
349            let dependencies = unit
350                .edges
351                .get_related(current_block, BlockRelation::DependsOn);
352            for dep_block_id in dependencies {
353                if dep_block_id != node.block_id {
354                    result.insert(GraphNode {
355                        unit_index: node.unit_index,
356                        block_id: dep_block_id,
357                    });
358                    stack.push(dep_block_id);
359                }
360            }
361        }
362
363        result
364    }
365
366    pub fn get_block_depended(&self, node: GraphNode) -> HashSet<GraphNode> {
367        if node.unit_index >= self.units.len() {
368            return HashSet::new();
369        }
370
371        let unit = &self.units[node.unit_index];
372        let mut result = HashSet::new();
373        let mut visited = HashSet::new();
374        let mut stack = vec![node.block_id];
375
376        while let Some(current_block) = stack.pop() {
377            if visited.contains(&current_block) {
378                continue;
379            }
380            visited.insert(current_block);
381
382            let dependencies = unit
383                .edges
384                .get_related(current_block, BlockRelation::DependedBy);
385            for dep_block_id in dependencies {
386                if dep_block_id != node.block_id {
387                    result.insert(GraphNode {
388                        unit_index: node.unit_index,
389                        block_id: dep_block_id,
390                    });
391                    stack.push(dep_block_id);
392                }
393            }
394        }
395
396        result
397    }
398
399    fn add_cross_edge(
400        &self,
401        from_idx: usize,
402        to_idx: usize,
403        from_block: BlockId,
404        to_block: BlockId,
405    ) {
406        if from_idx == to_idx {
407            let unit = &self.units[from_idx];
408            if !unit
409                .edges
410                .has_relation(from_block, BlockRelation::DependsOn, to_block)
411            {
412                unit.edges.add_relation(from_block, to_block);
413            }
414            return;
415        }
416
417        let from_unit = &self.units[from_idx];
418        from_unit
419            .edges
420            .add_relation_if_not_exists(from_block, BlockRelation::DependsOn, to_block);
421
422        let to_unit = &self.units[to_idx];
423        to_unit
424            .edges
425            .add_relation_if_not_exists(to_block, BlockRelation::DependedBy, from_block);
426    }
427}
428
429#[derive(Debug)]
430struct GraphBuilder<'tcx, Language> {
431    unit: CompileUnit<'tcx>,
432    root: Option<BlockId>,
433    children_stack: Vec<Vec<BlockId>>,
434    _marker: PhantomData<Language>,
435}
436
437impl<'tcx, Language: LanguageTrait> GraphBuilder<'tcx, Language> {
438    fn new(unit: CompileUnit<'tcx>) -> Self {
439        Self {
440            unit,
441            root: None,
442            children_stack: Vec::new(),
443            _marker: PhantomData,
444        }
445    }
446
447    fn next_id(&self) -> BlockId {
448        self.unit.reserve_block_id()
449    }
450
451    fn create_block(
452        &self,
453        id: BlockId,
454        node: HirNode<'tcx>,
455        kind: BlockKind,
456        parent: Option<BlockId>,
457        children: Vec<BlockId>,
458    ) -> BasicBlock<'tcx> {
459        let arena = &self.unit.cc.block_arena;
460        match kind {
461            BlockKind::Root => {
462                // Extract file_name from HirFile node if available
463                let file_name = node.as_file().map(|file| file.file_path.clone());
464                let block = BlockRoot::from_hir(id, node, parent, children, file_name);
465                BasicBlock::Root(arena.alloc(block))
466            }
467            BlockKind::Func => {
468                let block = BlockFunc::from_hir(id, node, parent, children);
469                BasicBlock::Func(arena.alloc(block))
470            }
471            BlockKind::Class => {
472                let block = BlockClass::from_hir(id, node, parent, children);
473                BasicBlock::Class(arena.alloc(block))
474            }
475            BlockKind::Stmt => {
476                let stmt = BlockStmt::from_hir(id, node, parent, children);
477                BasicBlock::Stmt(arena.alloc(stmt))
478            }
479            BlockKind::Call => {
480                let stmt = BlockCall::from_hir(id, node, parent, children);
481                BasicBlock::Call(arena.alloc(stmt))
482            }
483            BlockKind::Enum => {
484                let enum_ty = BlockEnum::from_hir(id, node, parent, children);
485                BasicBlock::Enum(arena.alloc(enum_ty))
486            }
487            BlockKind::Const => {
488                let stmt = BlockConst::from_hir(id, node, parent, children);
489                BasicBlock::Const(arena.alloc(stmt))
490            }
491            BlockKind::Impl => {
492                let block = BlockImpl::from_hir(id, node, parent, children);
493                BasicBlock::Impl(arena.alloc(block))
494            }
495            BlockKind::Field => {
496                let block = BlockField::from_hir(id, node, parent, children);
497                BasicBlock::Field(arena.alloc(block))
498            }
499            _ => {
500                panic!("unknown block kind: {}", kind)
501            }
502        }
503    }
504
505    fn build_edges(&self, node: HirNode<'tcx>) -> BlockRelationMap {
506        let edges = BlockRelationMap::default();
507        let mut visited = HashSet::new();
508        let mut unresolved = HashSet::new();
509        self.collect_edges(node, &edges, &mut visited, &mut unresolved);
510        edges
511    }
512
513    fn collect_edges(
514        &self,
515        node: HirNode<'tcx>,
516        edges: &BlockRelationMap,
517        visited: &mut HashSet<SymId>,
518        unresolved: &mut HashSet<SymId>,
519    ) {
520        // Try to process symbol dependencies for this node
521        if let Some(scope) = self.unit.opt_get_scope(node.hir_id()) {
522            if let Some(symbol) = scope.symbol() {
523                self.process_symbol(symbol, edges, visited, unresolved);
524            }
525        }
526
527        // Recurse into children
528        for &child_id in node.children() {
529            let child = self.unit.hir_node(child_id);
530            self.collect_edges(child, edges, visited, unresolved);
531        }
532    }
533
534    fn process_symbol(
535        &self,
536        symbol: &'tcx Symbol,
537        edges: &BlockRelationMap,
538        visited: &mut HashSet<SymId>,
539        unresolved: &mut HashSet<SymId>,
540    ) {
541        let symbol_id = symbol.id;
542
543        // Avoid processing the same symbol twice
544        if !visited.insert(symbol_id) {
545            return;
546        }
547
548        let Some(from_block) = symbol.block_id() else {
549            return;
550        };
551
552        for &dep_id in symbol.depends.borrow().iter() {
553            self.link_dependency(dep_id, from_block, edges, unresolved);
554        }
555    }
556    fn link_dependency(
557        &self,
558        dep_id: SymId,
559        from_block: BlockId,
560        edges: &BlockRelationMap,
561        unresolved: &mut HashSet<SymId>,
562    ) {
563        // If target symbol exists and has a block, add the dependency edge
564        if let Some(target_symbol) = self.unit.opt_get_symbol(dep_id) {
565            if let Some(to_block) = target_symbol.block_id() {
566                if !edges.has_relation(from_block, BlockRelation::DependsOn, to_block) {
567                    edges.add_relation(from_block, to_block);
568                }
569                let target_unit = target_symbol.unit_index();
570                if target_unit.is_some()
571                    && target_unit != Some(self.unit.index)
572                    && unresolved.insert(dep_id)
573                {
574                    self.unit.add_unresolved_symbol(target_symbol);
575                }
576                return;
577            }
578
579            // Target symbol exists but block not yet known
580            if unresolved.insert(dep_id) {
581                self.unit.add_unresolved_symbol(target_symbol);
582            }
583            return;
584        }
585
586        // Target symbol not found at all
587        unresolved.insert(dep_id);
588    }
589
590    fn build_block(&mut self, node: HirNode<'tcx>, parent: BlockId, recursive: bool) {
591        let id = self.next_id();
592        let block_kind = Language::block_kind(node.kind_id());
593        assert_ne!(block_kind, BlockKind::Undefined);
594
595        if self.root.is_none() {
596            self.root = Some(id);
597        }
598
599        let children = if recursive {
600            self.children_stack.push(Vec::new());
601            self.visit_children(node, id);
602
603            self.children_stack.pop().unwrap()
604        } else {
605            Vec::new()
606        };
607
608        let block = self.create_block(id, node, block_kind, Some(parent), children);
609        if let Some(scope) = self.unit.opt_get_scope(node.hir_id()) {
610            if let Some(symbol) = scope.symbol() {
611                // Only set the block ID if it hasn't been set before
612                // This prevents impl blocks from overwriting struct block IDs
613                if symbol.block_id().is_none() {
614                    symbol.set_block_id(Some(id));
615                }
616            }
617        }
618        self.unit.insert_block(id, block, parent);
619
620        if let Some(children) = self.children_stack.last_mut() {
621            children.push(id);
622        }
623    }
624}
625
626impl<'tcx, Language: LanguageTrait> HirVisitor<'tcx> for GraphBuilder<'tcx, Language> {
627    fn unit(&self) -> CompileUnit<'tcx> {
628        self.unit
629    }
630
631    fn visit_file(&mut self, node: HirNode<'tcx>, parent: BlockId) {
632        self.children_stack.push(Vec::new());
633        self.build_block(node, parent, true);
634    }
635
636    fn visit_internal(&mut self, node: HirNode<'tcx>, parent: BlockId) {
637        if Language::block_kind(node.kind_id()) != BlockKind::Undefined {
638            self.build_block(node, parent, false);
639        } else {
640            self.visit_children(node, parent);
641        }
642    }
643
644    fn visit_scope(&mut self, node: HirNode<'tcx>, parent: BlockId) {
645        match Language::block_kind(node.kind_id()) {
646            BlockKind::Func
647            | BlockKind::Class
648            | BlockKind::Enum
649            | BlockKind::Const
650            | BlockKind::Impl
651            | BlockKind::Field => self.build_block(node, parent, true),
652            _ => self.visit_children(node, parent),
653        }
654    }
655}
656
657pub fn build_llmcc_graph<'tcx, L: LanguageTrait>(
658    unit: CompileUnit<'tcx>,
659    unit_index: usize,
660) -> Result<UnitGraph, Box<dyn std::error::Error>> {
661    let root_hir = unit
662        .file_start_hir_id()
663        .ok_or("missing file start HIR id")?;
664    let mut builder = GraphBuilder::<L>::new(unit);
665    let root_node = unit.hir_node(root_hir);
666    builder.visit_node(root_node, BlockId::ROOT_PARENT);
667
668    let root_block = builder.root;
669    let root_block = root_block.ok_or("graph builder produced no root")?;
670    let edges = builder.build_edges(root_node);
671    Ok(UnitGraph::new(unit_index, root_block, edges))
672}