Skip to main content

codemem_engine/index/engine/
mod.rs

1//! Unified AST extraction engine using ast-grep.
2//!
3//! Replaces per-language tree-sitter extractors with a single engine driven
4//! by YAML rules. Language-specific behavior (visibility, doc comments, etc.)
5//! is handled by shared helpers keyed on language name.
6
7mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
21/// Type alias for ast-grep nodes parameterized on SupportLang.
22pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
24/// One-time deserialized language rules, shared across all AstGrepEngine instances.
25static LOADED_RULES: LazyLock<Vec<LanguageRules>> =
26    LazyLock::new(crate::index::rule_loader::load_all_rules);
27
/// The unified extraction engine.
///
/// The engine itself only stores a lookup table; the language rules live in
/// the process-wide `LOADED_RULES` cache.
pub struct AstGrepEngine {
    /// Maps a file extension to the index of its language in `LOADED_RULES`,
    /// so resolving an extension is a hash lookup rather than a linear scan.
    extension_index: HashMap<&'static str, usize>,
}
33
34impl AstGrepEngine {
35    /// Create a new engine referencing the globally cached language rules.
36    pub fn new() -> Self {
37        let mut extension_index = HashMap::new();
38        for (i, lr) in LOADED_RULES.iter().enumerate() {
39            for &ext in lr.extensions {
40                extension_index.insert(ext, i);
41            }
42        }
43        Self { extension_index }
44    }
45
46    /// Look up the rules for a given file extension.
47    pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
48        self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
49    }
50
51    /// Check if a given extension is supported.
52    pub fn supports_extension(&self, ext: &str) -> bool {
53        self.extension_index.contains_key(ext)
54    }
55
56    pub fn extract_symbols(
57        &self,
58        lang: &LanguageRules,
59        source: &str,
60        file_path: &str,
61    ) -> Vec<Symbol> {
62        let root = lang.lang.ast_grep(source);
63        self.extract_symbols_from_tree(lang, &root, source, file_path)
64    }
65
66    /// C1: Extract symbols from a pre-parsed AST tree, avoiding re-parsing.
67    pub fn extract_symbols_from_tree(
68        &self,
69        lang: &LanguageRules,
70        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
71        source: &str,
72        file_path: &str,
73    ) -> Vec<Symbol> {
74        let root_node = root.root();
75        let mut symbols = Vec::new();
76        let mut scope = Vec::new();
77        self.extract_symbols_recursive(
78            lang,
79            &root_node,
80            source,
81            file_path,
82            &mut scope,
83            false,
84            &mut symbols,
85        );
86        symbols
87    }
88
89    /// Parse source code and extract references, deduplicating by
90    /// (source_qualified_name, target_name, reference_kind).
91    pub fn extract_references(
92        &self,
93        lang: &LanguageRules,
94        source: &str,
95        file_path: &str,
96    ) -> Vec<Reference> {
97        let root = lang.lang.ast_grep(source);
98        self.extract_references_from_tree(lang, &root, source, file_path)
99    }
100
101    /// C1: Extract references from a pre-parsed AST tree, avoiding re-parsing.
102    pub fn extract_references_from_tree(
103        &self,
104        lang: &LanguageRules,
105        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
106        source: &str,
107        file_path: &str,
108    ) -> Vec<Reference> {
109        let root_node = root.root();
110        let mut references = Vec::new();
111        let mut scope = Vec::new();
112        self.extract_references_recursive(
113            lang,
114            &root_node,
115            source,
116            file_path,
117            &mut scope,
118            &mut references,
119        );
120
121        // R3: Dedup references by (source_qualified_name, target_name, kind)
122        let mut seen = HashSet::new();
123        references.retain(|r| {
124            seen.insert((
125                r.source_qualified_name.clone(),
126                r.target_name.clone(),
127                r.kind,
128            ))
129        });
130
131        references
132    }
133
134    // ── Symbol Extraction ─────────────────────────────────────────────
135
136    #[allow(clippy::too_many_arguments)]
137    fn extract_symbols_recursive<D: Doc>(
138        &self,
139        lang: &LanguageRules,
140        node: &Node<'_, D>,
141        source: &str,
142        file_path: &str,
143        scope: &mut Vec<String>,
144        in_method_scope: bool,
145        symbols: &mut Vec<Symbol>,
146    ) where
147        D::Lang: ast_grep_core::Language,
148    {
149        let kind: Cow<'_, str> = node.kind();
150        let kind_str = kind.as_ref();
151
152        // Check if this is an unwrap node (e.g., decorated_definition, export_statement)
153        if lang.symbol_unwrap_set.contains(kind_str) {
154            for child in node.children() {
155                self.extract_symbols_recursive(
156                    lang,
157                    &child,
158                    source,
159                    file_path,
160                    scope,
161                    in_method_scope,
162                    symbols,
163                );
164            }
165            return;
166        }
167
168        // Check if this is a scope container
169        let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
170        if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
171            let sc = &lang.symbol_scope_containers[sc_idx];
172            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
173                scope.push(scope_name);
174                let new_method_scope = sc.is_method_scope;
175
176                // Recurse into the body
177                if let Some(body) = self.get_scope_body(sc, node) {
178                    for child in body.children() {
179                        self.extract_symbols_recursive(
180                            lang,
181                            &child,
182                            source,
183                            file_path,
184                            scope,
185                            new_method_scope,
186                            symbols,
187                        );
188                    }
189                } else {
190                    // No body field found, recurse into all children
191                    for child in node.children() {
192                        self.extract_symbols_recursive(
193                            lang,
194                            &child,
195                            source,
196                            file_path,
197                            scope,
198                            new_method_scope,
199                            symbols,
200                        );
201                    }
202                }
203                scope.pop();
204                // The scope container itself might also be a symbol
205                // (e.g., trait_item is both a scope container and an interface symbol)
206            }
207        }
208
209        // Check if this matches any symbol rules
210        if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
211            for &rule_idx in rule_indices {
212                let rule = &lang.symbol_rules[rule_idx];
213
214                // Handle multi-symbol special cases (e.g. Go type/const/var declarations)
215                if let Some(ref special) = rule.special {
216                    let multi = self
217                        .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
218                    if !multi.is_empty() {
219                        symbols.extend(multi);
220                        return;
221                    }
222                }
223
224                if let Some(sym) = self.extract_symbol_from_rule(
225                    lang,
226                    rule,
227                    node,
228                    source,
229                    file_path,
230                    scope,
231                    in_method_scope,
232                ) {
233                    let name = sym.name.clone();
234                    let is_scope = rule.is_scope;
235                    symbols.push(sym);
236
237                    // If this symbol creates a scope, recurse into it
238                    // Skip if already handled as a scope container above
239                    if is_scope && !handled_as_scope_container {
240                        scope.push(name);
241                        if let Some(body) = node.field("body") {
242                            for child in body.children() {
243                                self.extract_symbols_recursive(
244                                    lang,
245                                    &child,
246                                    source,
247                                    file_path,
248                                    scope,
249                                    in_method_scope,
250                                    symbols,
251                                );
252                            }
253                        }
254                        scope.pop();
255                    }
256                    return; // handled this node
257                }
258            }
259        }
260
261        // If this node was already handled as a scope container (which recursed),
262        // and it was also a symbol (handled above), we've already returned.
263        // If it was only a scope container (not a symbol), we already recursed.
264        if handled_as_scope_container {
265            return; // already recursed above
266        }
267
268        // Default: recurse into children
269        for child in node.children() {
270            self.extract_symbols_recursive(
271                lang,
272                &child,
273                source,
274                file_path,
275                scope,
276                in_method_scope,
277                symbols,
278            );
279        }
280    }
281
282    #[allow(clippy::too_many_arguments)]
283    fn extract_symbol_from_rule<D: Doc>(
284        &self,
285        lang: &LanguageRules,
286        rule: &SymbolRule,
287        node: &Node<'_, D>,
288        source: &str,
289        file_path: &str,
290        scope: &[String],
291        in_method_scope: bool,
292    ) -> Option<Symbol>
293    where
294        D::Lang: ast_grep_core::Language,
295    {
296        // Handle special cases first
297        if let Some(ref special) = rule.special {
298            let mut sym =
299                self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
300            self.enrich_symbol_metadata(lang.name, node, &mut sym);
301            return Some(sym);
302        }
303
304        // Extract name
305        let name = self.get_node_field_text(node, &rule.name_field)?;
306        if name.is_empty() {
307            return None;
308        }
309
310        // Determine symbol kind
311        let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
312        let is_test = self.detect_test(lang.name, node, source, file_path, &name);
313        let kind = if is_test {
314            SymbolKind::Test
315        } else if rule.method_when_scoped && in_method_scope {
316            SymbolKind::Method
317        } else {
318            base_kind
319        };
320
321        let visibility = self.detect_visibility(lang.name, node, source, &name);
322        let signature = self.extract_signature(lang.name, node, source);
323        let doc_comment = self.extract_doc_comment(lang.name, node, source);
324
325        let mut sym = build_symbol(
326            name,
327            kind,
328            signature,
329            visibility,
330            doc_comment,
331            file_path,
332            node.start_pos().line(),
333            node.end_pos().line(),
334            scope,
335            lang.scope_separator,
336        );
337        self.enrich_symbol_metadata(lang.name, node, &mut sym);
338        Some(sym)
339    }
340
341    // ── Reference Extraction ──────────────────────────────────────────
342
343    fn extract_references_recursive<D: Doc>(
344        &self,
345        lang: &LanguageRules,
346        node: &Node<'_, D>,
347        source: &str,
348        file_path: &str,
349        scope: &mut Vec<String>,
350        references: &mut Vec<Reference>,
351    ) where
352        D::Lang: ast_grep_core::Language,
353    {
354        let kind: Cow<'_, str> = node.kind();
355        let kind_str = kind.as_ref();
356
357        // Check unwrap nodes
358        if lang.reference_unwrap_set.contains(kind_str) {
359            for child in node.children() {
360                self.extract_references_recursive(
361                    lang, &child, source, file_path, scope, references,
362                );
363            }
364            return;
365        }
366
367        // Check scope containers
368        if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
369            let sc = &lang.reference_scope_containers[sc_idx];
370            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
371                // Still extract references from this node itself before recursing
372                if let Some(rule_indices) = lang.reference_index.get(kind_str) {
373                    for &rule_idx in rule_indices {
374                        let rule = &lang.reference_rules[rule_idx];
375                        self.extract_reference_from_rule(
376                            lang, rule, node, source, file_path, scope, references,
377                        );
378                    }
379                }
380
381                scope.push(scope_name);
382
383                if let Some(body) = self.get_scope_body(sc, node) {
384                    for child in body.children() {
385                        self.extract_references_recursive(
386                            lang, &child, source, file_path, scope, references,
387                        );
388                    }
389                } else {
390                    for child in node.children() {
391                        self.extract_references_recursive(
392                            lang, &child, source, file_path, scope, references,
393                        );
394                    }
395                }
396                scope.pop();
397                return;
398            }
399        }
400
401        // Check reference rules
402        if let Some(rule_indices) = lang.reference_index.get(kind_str) {
403            for &rule_idx in rule_indices {
404                let rule = &lang.reference_rules[rule_idx];
405                self.extract_reference_from_rule(
406                    lang, rule, node, source, file_path, scope, references,
407                );
408            }
409        }
410
411        // Default recursion
412        for child in node.children() {
413            self.extract_references_recursive(lang, &child, source, file_path, scope, references);
414        }
415    }
416
417    #[allow(clippy::too_many_arguments)]
418    fn extract_reference_from_rule<D: Doc>(
419        &self,
420        lang: &LanguageRules,
421        rule: &ReferenceRule,
422        node: &Node<'_, D>,
423        source: &str,
424        file_path: &str,
425        scope: &[String],
426        references: &mut Vec<Reference>,
427    ) where
428        D::Lang: ast_grep_core::Language,
429    {
430        // Handle special cases
431        if let Some(ref special) = rule.special {
432            self.handle_special_reference(
433                lang, special, node, source, file_path, scope, references,
434            );
435            return;
436        }
437
438        let ref_kind = match parse_reference_kind(&rule.reference_kind) {
439            Some(k) => k,
440            None => return,
441        };
442
443        let target_name = if let Some(ref field) = rule.name_field {
444            match self.get_node_field_text(node, field) {
445                Some(name) => name,
446                None => return,
447            }
448        } else {
449            // Use full node text, trimmed
450            let text = node.text();
451            text.trim().to_string()
452        };
453
454        if target_name.is_empty() {
455            return;
456        }
457
458        let source_qn = if scope.is_empty() {
459            file_path.to_string()
460        } else {
461            scope.join(lang.scope_separator)
462        };
463
464        push_ref(
465            references,
466            &source_qn,
467            target_name,
468            ref_kind,
469            file_path,
470            node.start_pos().line(),
471        );
472    }
473
474    // ── Node Utility Helpers ──────────────────────────────────────────
475
476    pub(crate) fn get_node_field_text<D: Doc>(
477        &self,
478        node: &Node<'_, D>,
479        field_name: &str,
480    ) -> Option<String>
481    where
482        D::Lang: ast_grep_core::Language,
483    {
484        node.field(field_name).map(|n| n.text().to_string())
485    }
486
487    fn get_scope_name<D: Doc>(
488        &self,
489        lang: &LanguageRules,
490        sc: &ScopeContainerRule,
491        node: &Node<'_, D>,
492        source: &str,
493    ) -> Option<String>
494    where
495        D::Lang: ast_grep_core::Language,
496    {
497        if let Some(ref special) = sc.special {
498            return self.get_special_scope_name(lang, special, node, source);
499        }
500        self.get_node_field_text(node, &sc.name_field)
501    }
502
503    fn get_scope_body<'a, D: Doc>(
504        &self,
505        sc: &ScopeContainerRule,
506        node: &Node<'a, D>,
507    ) -> Option<Node<'a, D>>
508    where
509        D::Lang: ast_grep_core::Language,
510    {
511        if let Some(ref special) = sc.special {
512            return self.get_special_scope_body(special, node);
513        }
514        node.field(&sc.body_field)
515    }
516
517    fn get_special_scope_name<D: Doc>(
518        &self,
519        _lang: &LanguageRules,
520        special: &str,
521        node: &Node<'_, D>,
522        _source: &str,
523    ) -> Option<String>
524    where
525        D::Lang: ast_grep_core::Language,
526    {
527        match special {
528            "go_method_scope" => {
529                // For Go method declarations, the scope is Receiver.MethodName
530                self.get_go_receiver_type(node)
531            }
532            "hcl_block_scope" => {
533                // HCL block: combine block type and labels
534                let mut parts = Vec::new();
535                for child in node.children() {
536                    let ck = child.kind();
537                    if ck.as_ref() == "identifier" && parts.is_empty() {
538                        parts.push(child.text().to_string());
539                    } else if ck.as_ref() == "string_lit" {
540                        parts.push(child.text().to_string().trim_matches('"').to_string());
541                    }
542                }
543                if parts.is_empty() {
544                    None
545                } else {
546                    Some(parts.join("."))
547                }
548            }
549            "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
550                for child in node.children() {
551                    let ck = child.kind();
552                    if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
553                        return Some(child.text().to_string());
554                    }
555                }
556                None
557            }),
558            "swift_class_scope" => {
559                // Swift class/struct: find name via field or first type_identifier/identifier child
560                self.get_node_field_text(node, "name").or_else(|| {
561                    node.children()
562                        .find(|c| {
563                            let ck = c.kind();
564                            ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
565                        })
566                        .map(|c| c.text().to_string())
567                })
568            }
569            "cpp_namespace_scope" => {
570                // C++ namespace: may use qualified_identifier or name field
571                self.get_node_field_text(node, "name").or_else(|| {
572                    node.children()
573                        .find(|c| {
574                            let ck = c.kind();
575                            ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
576                        })
577                        .map(|c| c.text().to_string())
578                })
579            }
580            _ => None,
581        }
582    }
583
584    fn get_special_scope_body<'a, D: Doc>(
585        &self,
586        _special: &str,
587        node: &Node<'a, D>,
588    ) -> Option<Node<'a, D>>
589    where
590        D::Lang: ast_grep_core::Language,
591    {
592        node.field("body")
593    }
594}
595
596impl Default for AstGrepEngine {
597    fn default() -> Self {
598        Self::new()
599    }
600}
601
602// ── Free Functions ─────────────────────────────────────────────────────
603
604/// Build a Symbol struct with the common scope→parent logic.
605#[allow(clippy::too_many_arguments)]
606fn build_symbol(
607    name: String,
608    kind: SymbolKind,
609    signature: String,
610    visibility: Visibility,
611    doc_comment: Option<String>,
612    file_path: &str,
613    line_start: usize,
614    line_end: usize,
615    scope: &[String],
616    scope_separator: &str,
617) -> Symbol {
618    Symbol {
619        qualified_name: build_qualified_name(scope, &name, scope_separator),
620        name,
621        kind,
622        signature,
623        visibility,
624        file_path: file_path.to_string(),
625        line_start,
626        line_end,
627        doc_comment,
628        parent: if scope.is_empty() {
629            None
630        } else {
631            Some(scope.join(scope_separator))
632        },
633        parameters: Vec::new(),
634        return_type: None,
635        is_async: false,
636        attributes: Vec::new(),
637        throws: Vec::new(),
638        generic_params: None,
639        is_abstract: false,
640    }
641}
642
643/// Push a Reference onto a collection with the standard fields.
644fn push_ref(
645    refs: &mut Vec<Reference>,
646    source_qn: &str,
647    target: String,
648    kind: ReferenceKind,
649    file_path: &str,
650    line: usize,
651) {
652    refs.push(Reference {
653        source_qualified_name: source_qn.to_string(),
654        target_name: target,
655        kind,
656        file_path: file_path.to_string(),
657        line,
658    });
659}
660
/// Join `scope` and `name` into a qualified name using `separator`.
///
/// With an empty scope the result is just `name`; otherwise it is the scope
/// segments joined by `separator`, one more `separator`, then `name`.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    match scope {
        [] => name.to_string(),
        segments => {
            let mut qualified = segments.join(separator);
            qualified.push_str(separator);
            qualified.push_str(name);
            qualified
        }
    }
}
668
669fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
670    match s {
671        "function" => Some(SymbolKind::Function),
672        "method" => Some(SymbolKind::Method),
673        "class" => Some(SymbolKind::Class),
674        "struct" => Some(SymbolKind::Struct),
675        "enum" => Some(SymbolKind::Enum),
676        "interface" => Some(SymbolKind::Interface),
677        "type" => Some(SymbolKind::Type),
678        "constant" => Some(SymbolKind::Constant),
679        "module" => Some(SymbolKind::Module),
680        "test" => Some(SymbolKind::Test),
681        "field" => Some(SymbolKind::Field),
682        "constructor" => Some(SymbolKind::Constructor),
683        _ => None,
684    }
685}
686
687fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
688    match s {
689        "import" => Some(ReferenceKind::Import),
690        "call" => Some(ReferenceKind::Call),
691        "inherits" => Some(ReferenceKind::Inherits),
692        "implements" => Some(ReferenceKind::Implements),
693        "type_usage" => Some(ReferenceKind::TypeUsage),
694        _ => None,
695    }
696}
697
/// Strip the comment syntax from a `/** ... */` block doc comment.
///
/// The opening `/**` and closing `*/` are removed, then each line is trimmed
/// and loses its leading `* ` (or bare `*`) decoration. Blank lines inside
/// the comment are kept as paragraph breaks; leading and trailing whitespace
/// of the overall result is removed.
fn clean_block_doc_comment(text: &str) -> String {
    let inner = text.trim_start_matches("/**").trim_end_matches("*/").trim();

    let cleaned: Vec<&str> = inner
        .lines()
        .map(|raw| {
            let stripped = raw.trim();
            stripped
                .strip_prefix("* ")
                .or_else(|| stripped.strip_prefix('*'))
                .unwrap_or(stripped)
                .trim_end()
        })
        .collect();

    cleaned.join("\n").trim().to_string()
}
714
715#[cfg(test)]
716#[path = "../tests/engine_symbols_tests.rs"]
717mod engine_symbols_tests;
718
719#[cfg(test)]
720#[path = "../tests/engine_references_tests.rs"]
721mod engine_references_tests;
722
723#[cfg(test)]
724#[path = "../tests/engine_cross_cutting_tests.rs"]
725mod engine_cross_cutting_tests;