llmcc_core/
context.rs

1use rayon::prelude::*;
2use std::cell::{Cell, RefCell};
3use std::collections::{BTreeMap, HashMap};
4use std::ops::Deref;
5use std::path::Path;
6use tree_sitter::Tree;
7
8use crate::block::{Arena as BlockArena, BasicBlock, BlockId, BlockKind};
9use crate::block_rel::BlockRelationMap;
10use crate::file::File;
11use crate::interner::{InternPool, InternedStr};
12use crate::ir::{Arena, HirId, HirNode};
13use crate::lang_def::LanguageTrait;
14use crate::symbol::{Scope, SymId, Symbol};
15
16#[derive(Debug, Copy, Clone)]
17pub struct CompileUnit<'tcx> {
18    pub cc: &'tcx CompileCtxt<'tcx>,
19    pub index: usize,
20}
21
22impl<'tcx> CompileUnit<'tcx> {
23    pub fn file(&self) -> &'tcx File {
24        &self.cc.files[self.index]
25    }
26
27    pub fn tree(&self) -> &'tcx Tree {
28        self.cc.trees[self.index].as_ref().unwrap()
29    }
30
31    /// Access the shared string interner.
32    pub fn interner(&self) -> &InternPool {
33        &self.cc.interner
34    }
35
36    /// Intern a string and return its symbol.
37    pub fn intern_str<S>(&self, value: S) -> InternedStr
38    where
39        S: AsRef<str>,
40    {
41        self.cc.interner.intern(value)
42    }
43
44    /// Resolve an interned symbol into an owned string.
45    pub fn resolve_interned_owned(&self, symbol: InternedStr) -> Option<String> {
46        self.cc.interner.resolve_owned(symbol)
47    }
48
49    pub fn reserve_hir_id(&self) -> HirId {
50        self.cc.reserve_hir_id()
51    }
52
53    pub fn reserve_block_id(&self) -> BlockId {
54        self.cc.reserve_block_id()
55    }
56
57    pub fn register_file_start(&self) -> HirId {
58        let start = self.cc.current_hir_id();
59        self.cc.set_file_start(self.index, start);
60        start
61    }
62
63    pub fn file_start_hir_id(&self) -> Option<HirId> {
64        self.cc.file_start(self.index)
65    }
66
67    pub fn file_path(&self) -> Option<&str> {
68        self.cc.file_path(self.index)
69    }
70
71    /// Get text from the file between start and end byte positions
72    pub fn get_text(&self, start: usize, end: usize) -> String {
73        self.file().get_text(start, end)
74    }
75
76    /// Get a HIR node by ID, returning None if not found
77    pub fn opt_hir_node(self, id: HirId) -> Option<HirNode<'tcx>> {
78        self.cc
79            .hir_map
80            .borrow()
81            .get(&id)
82            .map(|parented| parented.node)
83    }
84
85    /// Get a HIR node by ID, panicking if not found
86    pub fn hir_node(self, id: HirId) -> HirNode<'tcx> {
87        self.opt_hir_node(id)
88            .unwrap_or_else(|| panic!("hir node not found {}", id))
89    }
90
91    /// Get a HIR node by ID, returning None if not found
92    pub fn opt_bb(self, id: BlockId) -> Option<BasicBlock<'tcx>> {
93        self.cc
94            .block_map
95            .borrow()
96            .get(&id)
97            .map(|parented| parented.block.clone())
98    }
99
100    /// Get a HIR node by ID, panicking if not found
101    pub fn bb(self, id: BlockId) -> BasicBlock<'tcx> {
102        self.opt_bb(id)
103            .unwrap_or_else(|| panic!("basic block not found: {}", id))
104    }
105
106    /// Get the parent of a HIR node
107    pub fn parent_node(self, id: HirId) -> Option<HirId> {
108        self.cc
109            .hir_map
110            .borrow()
111            .get(&id)
112            .and_then(|parented| parented.parent())
113    }
114
115    /// Get an existing scope or None if it doesn't exist
116    pub fn opt_get_scope(self, owner: HirId) -> Option<&'tcx Scope<'tcx>> {
117        self.cc.scope_map.borrow().get(&owner).copied()
118    }
119
120    pub fn opt_get_symbol(self, owner: SymId) -> Option<&'tcx Symbol> {
121        self.cc.symbol_map.borrow().get(&owner).copied()
122    }
123
124    /// Get an existing scope or None if it doesn't exist
125    pub fn get_scope(self, owner: HirId) -> &'tcx Scope<'tcx> {
126        self.cc.scope_map.borrow().get(&owner).copied().unwrap()
127    }
128
129    /// Find an existing scope or create a new one
130    pub fn alloc_scope(self, owner: HirId) -> &'tcx Scope<'tcx> {
131        self.cc.alloc_scope(owner)
132    }
133
134    /// Add a HIR node to the map
135    pub fn insert_hir_node(self, id: HirId, node: HirNode<'tcx>) {
136        let parented = ParentedNode::new(node);
137        self.cc.hir_map.borrow_mut().insert(id, parented);
138    }
139
140    /// Get all child nodes of a given parent
141    pub fn children_of(self, parent: HirId) -> Vec<(HirId, HirNode<'tcx>)> {
142        let Some(parent_node) = self.opt_hir_node(parent) else {
143            return Vec::new();
144        };
145        parent_node
146            .children()
147            .iter()
148            .map(|&child_id| (child_id, self.hir_node(child_id)))
149            .collect()
150    }
151
152    /// Walk up the parent chain to find an ancestor of a specific type
153    pub fn find_ancestor<F>(self, mut current: HirId, predicate: F) -> Option<HirId>
154    where
155        F: Fn(&HirNode<'tcx>) -> bool,
156    {
157        while let Some(parent_id) = self.parent_node(current) {
158            if let Some(parent_node) = self.opt_hir_node(parent_id) {
159                if predicate(&parent_node) {
160                    return Some(parent_id);
161                }
162                current = parent_id;
163            } else {
164                break;
165            }
166        }
167        None
168    }
169
170    pub fn add_unresolved_symbol(&self, symbol: &'tcx Symbol) {
171        self.cc.unresolve_symbols.borrow_mut().push(symbol);
172    }
173
174    pub fn insert_block(&self, id: BlockId, block: BasicBlock<'tcx>, parent: BlockId) {
175        let parented = ParentedBlock::new(parent, block.clone());
176        self.cc.block_map.borrow_mut().insert(id, parented);
177
178        // Register the block in the index maps
179        let block_kind = block.kind();
180        let block_name = block
181            .base()
182            .and_then(|base| base.opt_get_name())
183            .map(|s| s.to_string());
184
185        self.cc
186            .block_indexes
187            .borrow_mut()
188            .insert_block(id, block_name, block_kind, self.index);
189    }
190}
191
192impl<'tcx> Deref for CompileUnit<'tcx> {
193    type Target = CompileCtxt<'tcx>;
194
195    #[inline(always)]
196    fn deref(&self) -> &Self::Target {
197        self.cc
198    }
199}
200
201#[derive(Debug, Clone)]
202pub struct ParentedNode<'tcx> {
203    pub node: HirNode<'tcx>,
204}
205
206impl<'tcx> ParentedNode<'tcx> {
207    pub fn new(node: HirNode<'tcx>) -> Self {
208        Self { node }
209    }
210
211    /// Get a reference to the wrapped node
212    pub fn node(&self) -> &HirNode<'tcx> {
213        &self.node
214    }
215
216    /// Get the parent ID
217    pub fn parent(&self) -> Option<HirId> {
218        self.node.parent()
219    }
220}
221
222#[derive(Debug, Clone)]
223pub struct ParentedBlock<'tcx> {
224    pub parent: BlockId,
225    pub block: BasicBlock<'tcx>,
226}
227
228impl<'tcx> ParentedBlock<'tcx> {
229    pub fn new(parent: BlockId, block: BasicBlock<'tcx>) -> Self {
230        Self { parent, block }
231    }
232
233    /// Get a reference to the wrapped node
234    pub fn block(&self) -> &BasicBlock<'tcx> {
235        &self.block
236    }
237
238    /// Get the parent ID
239    pub fn parent(&self) -> BlockId {
240        self.parent
241    }
242}
243
244/// BlockIndexMaps provides efficient lookup of blocks by various indices.
245///
246/// Best practices for usage:
247/// - block_name_index: Use when you want to find blocks by name (multiple blocks can share the same name)
248/// - unit_index_index: Use when you want all blocks in a specific unit
249/// - block_kind_index: Use when you want all blocks of a specific kind (e.g., all functions)
250/// - block_id_index: Use for O(1) lookup of block metadata by BlockId
251///
252/// Important: The "name" field is optional since Root blocks and some other blocks may not have names.
253///
254/// Rationale for data structure choices:
255/// - BTreeMap is used for name and unit indexes for better iteration and range queries
256/// - HashMap is used for kind index since BlockKind doesn't implement Ord
257/// - HashMap is used for block_id_index (direct lookup by BlockId) for O(1) access
258/// - Vec is used for values to handle multiple blocks with the same index (same name/kind/unit)
259#[derive(Debug, Default, Clone)]
260pub struct BlockIndexMaps {
261    /// block_name -> Vec<(unit_index, block_kind, block_id)>
262    /// Multiple blocks can share the same name across units or within the same unit
263    pub block_name_index: BTreeMap<String, Vec<(usize, BlockKind, BlockId)>>,
264
265    /// unit_index -> Vec<(block_name, block_kind, block_id)>
266    /// Allows retrieval of all blocks in a specific compilation unit
267    pub unit_index_map: BTreeMap<usize, Vec<(Option<String>, BlockKind, BlockId)>>,
268
269    /// block_kind -> Vec<(unit_index, block_name, block_id)>
270    /// Allows retrieval of all blocks of a specific kind across all units
271    pub block_kind_index: HashMap<BlockKind, Vec<(usize, Option<String>, BlockId)>>,
272
273    /// block_id -> (unit_index, block_name, block_kind)
274    /// Direct O(1) lookup of block metadata by ID
275    pub block_id_index: HashMap<BlockId, (usize, Option<String>, BlockKind)>,
276}
277
278impl BlockIndexMaps {
279    /// Create a new empty BlockIndexMaps
280    pub fn new() -> Self {
281        Self::default()
282    }
283
284    /// Register a new block in all indexes
285    ///
286    /// # Arguments
287    /// - `block_id`: The unique block identifier
288    /// - `block_name`: Optional name of the block (None for unnamed blocks)
289    /// - `block_kind`: The kind of block (Func, Class, Stmt, etc.)
290    /// - `unit_index`: The compilation unit index this block belongs to
291    pub fn insert_block(
292        &mut self,
293        block_id: BlockId,
294        block_name: Option<String>,
295        block_kind: BlockKind,
296        unit_index: usize,
297    ) {
298        // Insert into block_id_index for O(1) lookups
299        self.block_id_index
300            .insert(block_id, (unit_index, block_name.clone(), block_kind));
301
302        // Insert into block_name_index (if name exists)
303        if let Some(ref name) = block_name {
304            self.block_name_index
305                .entry(name.clone())
306                .or_default()
307                .push((unit_index, block_kind, block_id));
308        }
309
310        // Insert into unit_index_map
311        self.unit_index_map.entry(unit_index).or_default().push((
312            block_name.clone(),
313            block_kind,
314            block_id,
315        ));
316
317        // Insert into block_kind_index
318        self.block_kind_index
319            .entry(block_kind)
320            .or_default()
321            .push((unit_index, block_name, block_id));
322    }
323
324    /// Find all blocks with a given name (may return multiple blocks)
325    ///
326    /// Returns a vector of (unit_index, block_kind, block_id) tuples
327    pub fn find_by_name(&self, name: &str) -> Vec<(usize, BlockKind, BlockId)> {
328        self.block_name_index.get(name).cloned().unwrap_or_default()
329    }
330
331    /// Find all blocks in a specific unit
332    ///
333    /// Returns a vector of (block_name, block_kind, block_id) tuples
334    pub fn find_by_unit(&self, unit_index: usize) -> Vec<(Option<String>, BlockKind, BlockId)> {
335        self.unit_index_map
336            .get(&unit_index)
337            .cloned()
338            .unwrap_or_default()
339    }
340
341    /// Find all blocks of a specific kind across all units
342    ///
343    /// Returns a vector of (unit_index, block_name, block_id) tuples
344    pub fn find_by_kind(&self, block_kind: BlockKind) -> Vec<(usize, Option<String>, BlockId)> {
345        self.block_kind_index
346            .get(&block_kind)
347            .cloned()
348            .unwrap_or_default()
349    }
350
351    /// Find all blocks of a specific kind in a specific unit
352    ///
353    /// Returns a vector of block_ids
354    pub fn find_by_kind_and_unit(&self, block_kind: BlockKind, unit_index: usize) -> Vec<BlockId> {
355        let by_kind = self.find_by_kind(block_kind);
356        by_kind
357            .into_iter()
358            .filter(|(unit, _, _)| *unit == unit_index)
359            .map(|(_, _, block_id)| block_id)
360            .collect()
361    }
362
363    /// Look up block metadata by BlockId for O(1) access
364    ///
365    /// Returns (unit_index, block_name, block_kind) if found
366    pub fn get_block_info(&self, block_id: BlockId) -> Option<(usize, Option<String>, BlockKind)> {
367        self.block_id_index.get(&block_id).cloned()
368    }
369
370    /// Get total number of blocks indexed
371    pub fn block_count(&self) -> usize {
372        self.block_id_index.len()
373    }
374
375    /// Get the number of unique block names
376    pub fn unique_names_count(&self) -> usize {
377        self.block_name_index.len()
378    }
379
380    /// Check if a block with the given ID exists
381    pub fn contains_block(&self, block_id: BlockId) -> bool {
382        self.block_id_index.contains_key(&block_id)
383    }
384
385    /// Clear all indexes
386    pub fn clear(&mut self) {
387        self.block_name_index.clear();
388        self.unit_index_map.clear();
389        self.block_kind_index.clear();
390        self.block_id_index.clear();
391    }
392}
393
394#[derive(Debug, Default)]
395pub struct CompileCtxt<'tcx> {
396    pub arena: Arena<'tcx>,
397    pub interner: InternPool,
398    pub files: Vec<File>,
399    pub trees: Vec<Option<Tree>>,
400    pub hir_next_id: Cell<u32>,
401    pub hir_start_ids: RefCell<Vec<Option<HirId>>>,
402
403    // HirId -> ParentedNode
404    pub hir_map: RefCell<HashMap<HirId, ParentedNode<'tcx>>>,
405    // HirId -> &Scope (scopes owned by this HIR node)
406    pub scope_map: RefCell<HashMap<HirId, &'tcx Scope<'tcx>>>,
407    // SymId -> &Symbol
408    pub symbol_map: RefCell<HashMap<SymId, &'tcx Symbol>>,
409
410    pub block_arena: BlockArena<'tcx>,
411    pub block_next_id: Cell<u32>,
412    // BlockId -> ParentedBlock
413    pub block_map: RefCell<HashMap<BlockId, ParentedBlock<'tcx>>>,
414    pub unresolve_symbols: RefCell<Vec<&'tcx Symbol>>,
415    pub related_map: BlockRelationMap,
416
417    /// Index maps for efficient block lookups by name, kind, unit, and id
418    pub block_indexes: RefCell<BlockIndexMaps>,
419}
420
421impl<'tcx> CompileCtxt<'tcx> {
422    /// Create a new CompileCtxt from source code
423    pub fn from_sources<L: LanguageTrait>(sources: &[Vec<u8>]) -> Self {
424        let files: Vec<File> = sources
425            .iter()
426            .map(|src| File::new_source(src.clone()))
427            .collect();
428        let trees = sources.par_iter().map(|src| L::parse(src)).collect();
429        let count = files.len();
430        Self {
431            arena: Arena::default(),
432            interner: InternPool::default(),
433            files,
434            trees,
435            hir_next_id: Cell::new(0),
436            hir_start_ids: RefCell::new(vec![None; count]),
437            hir_map: RefCell::new(HashMap::new()),
438            scope_map: RefCell::new(HashMap::new()),
439            symbol_map: RefCell::new(HashMap::new()),
440            block_arena: BlockArena::default(),
441            block_next_id: Cell::new(0),
442            block_map: RefCell::new(HashMap::new()),
443            unresolve_symbols: RefCell::new(Vec::new()),
444            related_map: BlockRelationMap::default(),
445            block_indexes: RefCell::new(BlockIndexMaps::new()),
446        }
447    }
448
449    /// Create a new CompileCtxt from files
450    pub fn from_files<L: LanguageTrait>(paths: &[String]) -> std::io::Result<Self> {
451        let mut files = Vec::new();
452        for path in paths {
453            files.push(File::new_file(path.clone())?);
454        }
455
456        let trees: Vec<_> = files
457            .par_iter()
458            .map(|file| L::parse(file.content()))
459            .collect();
460
461        let count = files.len();
462        Ok(Self {
463            arena: Arena::default(),
464            interner: InternPool::default(),
465            files,
466            trees,
467            hir_next_id: Cell::new(0),
468            hir_start_ids: RefCell::new(vec![None; count]),
469            hir_map: RefCell::new(HashMap::new()),
470            scope_map: RefCell::new(HashMap::new()),
471            symbol_map: RefCell::new(HashMap::new()),
472            block_arena: BlockArena::default(),
473            block_next_id: Cell::new(0),
474            block_map: RefCell::new(HashMap::new()),
475            unresolve_symbols: RefCell::new(Vec::new()),
476            related_map: BlockRelationMap::default(),
477            block_indexes: RefCell::new(BlockIndexMaps::new()),
478        })
479    }
480
481    /// Create a new CompileCtxt from a directory, recursively finding all *.rs files
482    pub fn from_dir<P: AsRef<Path>, L: LanguageTrait>(dir: P) -> std::io::Result<Self> {
483        let mut files = Vec::new();
484
485        let walker = ignore::WalkBuilder::new(dir.as_ref())
486            .standard_filters(true)
487            .build();
488
489        for entry in walker {
490            let entry: ignore::DirEntry = entry
491                .map_err(|e| std::io::Error::other(format!("Failed to walk directory: {}", e)))?;
492            let path = entry.path();
493
494            if path.extension().and_then(|ext| ext.to_str()) == Some("rs") {
495                if let Ok(file) = File::new_file(path.to_string_lossy().to_string()) {
496                    files.push(file);
497                }
498            } else if path.extension().and_then(|ext| ext.to_str()) == Some("py") {
499                if let Ok(file) = File::new_file(path.to_string_lossy().to_string()) {
500                    files.push(file);
501                }
502            }
503        }
504
505        let trees: Vec<_> = files
506            .par_iter()
507            .map(|file| L::parse(file.content()))
508            .collect();
509
510        let count = files.len();
511        Ok(Self {
512            arena: Arena::default(),
513            interner: InternPool::default(),
514            files,
515            trees,
516            hir_next_id: Cell::new(0),
517            hir_start_ids: RefCell::new(vec![None; count]),
518            hir_map: RefCell::new(HashMap::new()),
519            scope_map: RefCell::new(HashMap::new()),
520            symbol_map: RefCell::new(HashMap::new()),
521            block_arena: BlockArena::default(),
522            block_next_id: Cell::new(0),
523            block_map: RefCell::new(HashMap::new()),
524            unresolve_symbols: RefCell::new(Vec::new()),
525            related_map: BlockRelationMap::default(),
526            block_indexes: RefCell::new(BlockIndexMaps::new()),
527        })
528    }
529
530    /// Create a context that references this CompileCtxt for a specific file index
531    pub fn compile_unit(&'tcx self, index: usize) -> CompileUnit<'tcx> {
532        CompileUnit { cc: self, index }
533    }
534
535    pub fn create_globals(&'tcx self) -> &'tcx Scope<'tcx> {
536        self.alloc_scope(HirId(0))
537    }
538
539    pub fn get_scope(&'tcx self, owner: HirId) -> &'tcx Scope<'tcx> {
540        self.scope_map.borrow().get(&owner).unwrap()
541    }
542
543    pub fn opt_get_symbol(&'tcx self, owner: SymId) -> Option<&'tcx Symbol> {
544        self.symbol_map.borrow().get(&owner).cloned()
545    }
546
547    /// Find the primary symbol associated with a block ID
548    pub fn find_symbol_by_block_id(&'tcx self, block_id: BlockId) -> Option<&'tcx Symbol> {
549        self.symbol_map
550            .borrow()
551            .values()
552            .find(|symbol| symbol.block_id() == Some(block_id))
553            .copied()
554    }
555
556    pub fn alloc_scope(&'tcx self, owner: HirId) -> &'tcx Scope<'tcx> {
557        if let Some(existing) = self.scope_map.borrow().get(&owner) {
558            return existing;
559        }
560
561        let scope = self.arena.alloc(Scope::new(owner));
562        self.scope_map.borrow_mut().insert(owner, scope);
563        scope
564    }
565
566    pub fn reserve_hir_id(&self) -> HirId {
567        let id = self.hir_next_id.get();
568        self.hir_next_id.set(id + 1);
569        HirId(id)
570    }
571
572    pub fn reserve_block_id(&self) -> BlockId {
573        let id = self.block_next_id.get();
574        self.block_next_id.set(id + 1);
575        BlockId::new(id)
576    }
577
578    pub fn current_hir_id(&self) -> HirId {
579        HirId(self.hir_next_id.get())
580    }
581
582    pub fn set_file_start(&self, index: usize, start: HirId) {
583        let mut starts = self.hir_start_ids.borrow_mut();
584        if index < starts.len() && starts[index].is_none() {
585            starts[index] = Some(start);
586        }
587    }
588
589    pub fn file_start(&self, index: usize) -> Option<HirId> {
590        self.hir_start_ids.borrow().get(index).and_then(|opt| *opt)
591    }
592
593    pub fn file_path(&self, index: usize) -> Option<&str> {
594        self.files.get(index).and_then(|file| file.path())
595    }
596
597    /// Get all file paths from the compilation context
598    pub fn get_files(&self) -> Vec<String> {
599        self.files
600            .iter()
601            .filter_map(|f| f.path().map(|p| p.to_string()))
602            .collect()
603    }
604
605    /// Clear all maps (useful for testing)
606    #[cfg(test)]
607    pub fn clear(&self) {
608        self.hir_map.borrow_mut().clear();
609        self.scope_map.borrow_mut().clear();
610    }
611}