Skip to main content

codemem_engine/index/engine/
mod.rs

1//! Unified AST extraction engine using ast-grep.
2//!
3//! Replaces per-language tree-sitter extractors with a single engine driven
4//! by YAML rules. Language-specific behavior (visibility, doc comments, etc.)
5//! is handled by shared helpers keyed on language name.
6
7mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
/// Shorthand for an ast-grep [`Node`] whose document type is
/// `StrDoc<SupportLang>`, borrowed for the lifetime `'r`.
pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
24/// One-time deserialized language rules, shared across all AstGrepEngine instances.
25static LOADED_RULES: LazyLock<Vec<LanguageRules>> = LazyLock::new(|| {
26    crate::index::rule_loader::load_all_rules()
27        .expect("embedded YAML rule files must deserialize successfully")
28});
29
/// The unified extraction engine.
///
/// The engine itself stores only an extension lookup table; the language
/// rules it points into live in the global `LOADED_RULES` cache.
pub struct AstGrepEngine {
    /// Maps a file extension to the index of its language in `LOADED_RULES`,
    /// giving O(1) extension lookup instead of a linear scan.
    extension_index: HashMap<&'static str, usize>,
}
35
36impl AstGrepEngine {
37    /// Create a new engine referencing the globally cached language rules.
38    pub fn new() -> Self {
39        let mut extension_index = HashMap::new();
40        for (i, lr) in LOADED_RULES.iter().enumerate() {
41            for &ext in lr.extensions {
42                extension_index.insert(ext, i);
43            }
44        }
45        Self { extension_index }
46    }
47
48    /// Look up the rules for a given file extension.
49    pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
50        self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
51    }
52
    /// Report whether any loaded language registered `ext` as one of its
    /// extensions.
    pub fn supports_extension(&self, ext: &str) -> bool {
        self.extension_index.contains_key(ext)
    }
57
58    pub fn extract_symbols(
59        &self,
60        lang: &LanguageRules,
61        source: &str,
62        file_path: &str,
63    ) -> Vec<Symbol> {
64        let root = lang.lang.ast_grep(source);
65        self.extract_symbols_from_tree(lang, &root, source, file_path)
66    }
67
68    /// C1: Extract symbols from a pre-parsed AST tree, avoiding re-parsing.
69    pub fn extract_symbols_from_tree(
70        &self,
71        lang: &LanguageRules,
72        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
73        source: &str,
74        file_path: &str,
75    ) -> Vec<Symbol> {
76        let root_node = root.root();
77        let mut symbols = Vec::new();
78        let mut scope = Vec::new();
79        self.extract_symbols_recursive(
80            lang,
81            &root_node,
82            source,
83            file_path,
84            &mut scope,
85            false,
86            &mut symbols,
87        );
88        symbols
89    }
90
91    /// Parse source code and extract references, deduplicating by
92    /// (source_qualified_name, target_name, reference_kind).
93    pub fn extract_references(
94        &self,
95        lang: &LanguageRules,
96        source: &str,
97        file_path: &str,
98    ) -> Vec<Reference> {
99        let root = lang.lang.ast_grep(source);
100        self.extract_references_from_tree(lang, &root, source, file_path)
101    }
102
103    /// C1: Extract references from a pre-parsed AST tree, avoiding re-parsing.
104    pub fn extract_references_from_tree(
105        &self,
106        lang: &LanguageRules,
107        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
108        source: &str,
109        file_path: &str,
110    ) -> Vec<Reference> {
111        let root_node = root.root();
112        let mut references = Vec::new();
113        let mut scope = Vec::new();
114        self.extract_references_recursive(
115            lang,
116            &root_node,
117            source,
118            file_path,
119            &mut scope,
120            &mut references,
121        );
122
123        // R3: Dedup references by (source_qualified_name, target_name, kind)
124        let mut seen = HashSet::new();
125        references.retain(|r| {
126            seen.insert((
127                r.source_qualified_name.clone(),
128                r.target_name.clone(),
129                r.kind,
130            ))
131        });
132
133        references
134    }
135
136    // ── Symbol Extraction ─────────────────────────────────────────────
137
138    #[allow(clippy::too_many_arguments)]
139    fn extract_symbols_recursive<D: Doc>(
140        &self,
141        lang: &LanguageRules,
142        node: &Node<'_, D>,
143        source: &str,
144        file_path: &str,
145        scope: &mut Vec<String>,
146        in_method_scope: bool,
147        symbols: &mut Vec<Symbol>,
148    ) where
149        D::Lang: ast_grep_core::Language,
150    {
151        let kind: Cow<'_, str> = node.kind();
152        let kind_str = kind.as_ref();
153
154        // Check if this is an unwrap node (e.g., decorated_definition, export_statement)
155        if lang.symbol_unwrap_set.contains(kind_str) {
156            for child in node.children() {
157                self.extract_symbols_recursive(
158                    lang,
159                    &child,
160                    source,
161                    file_path,
162                    scope,
163                    in_method_scope,
164                    symbols,
165                );
166            }
167            return;
168        }
169
170        // Check if this is a scope container
171        let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
172        if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
173            let sc = &lang.symbol_scope_containers[sc_idx];
174            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
175                scope.push(scope_name);
176                let new_method_scope = sc.is_method_scope;
177
178                // Recurse into the body
179                if let Some(body) = self.get_scope_body(sc, node) {
180                    for child in body.children() {
181                        self.extract_symbols_recursive(
182                            lang,
183                            &child,
184                            source,
185                            file_path,
186                            scope,
187                            new_method_scope,
188                            symbols,
189                        );
190                    }
191                } else {
192                    // No body field found, recurse into all children
193                    for child in node.children() {
194                        self.extract_symbols_recursive(
195                            lang,
196                            &child,
197                            source,
198                            file_path,
199                            scope,
200                            new_method_scope,
201                            symbols,
202                        );
203                    }
204                }
205                scope.pop();
206                // The scope container itself might also be a symbol
207                // (e.g., trait_item is both a scope container and an interface symbol)
208            }
209        }
210
211        // Check if this matches any symbol rules
212        if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
213            for &rule_idx in rule_indices {
214                let rule = &lang.symbol_rules[rule_idx];
215
216                // Handle multi-symbol special cases (e.g. Go type/const/var declarations)
217                if let Some(ref special) = rule.special {
218                    let multi = self
219                        .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
220                    if !multi.is_empty() {
221                        symbols.extend(multi);
222                        return;
223                    }
224                }
225
226                if let Some(sym) = self.extract_symbol_from_rule(
227                    lang,
228                    rule,
229                    node,
230                    source,
231                    file_path,
232                    scope,
233                    in_method_scope,
234                ) {
235                    let name = sym.name.clone();
236                    let is_scope = rule.is_scope;
237                    symbols.push(sym);
238
239                    // If this symbol creates a scope, recurse into it
240                    // Skip if already handled as a scope container above
241                    if is_scope && !handled_as_scope_container {
242                        scope.push(name);
243                        if let Some(body) = node.field("body") {
244                            for child in body.children() {
245                                self.extract_symbols_recursive(
246                                    lang,
247                                    &child,
248                                    source,
249                                    file_path,
250                                    scope,
251                                    in_method_scope,
252                                    symbols,
253                                );
254                            }
255                        }
256                        scope.pop();
257                    }
258                    return; // handled this node
259                }
260            }
261        }
262
263        // If this node was already handled as a scope container (which recursed),
264        // and it was also a symbol (handled above), we've already returned.
265        // If it was only a scope container (not a symbol), we already recursed.
266        if handled_as_scope_container {
267            return; // already recursed above
268        }
269
270        // Default: recurse into children
271        for child in node.children() {
272            self.extract_symbols_recursive(
273                lang,
274                &child,
275                source,
276                file_path,
277                scope,
278                in_method_scope,
279                symbols,
280            );
281        }
282    }
283
284    #[allow(clippy::too_many_arguments)]
285    fn extract_symbol_from_rule<D: Doc>(
286        &self,
287        lang: &LanguageRules,
288        rule: &SymbolRule,
289        node: &Node<'_, D>,
290        source: &str,
291        file_path: &str,
292        scope: &[String],
293        in_method_scope: bool,
294    ) -> Option<Symbol>
295    where
296        D::Lang: ast_grep_core::Language,
297    {
298        // Handle special cases first
299        if let Some(ref special) = rule.special {
300            let mut sym =
301                self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
302            self.enrich_symbol_metadata(lang.name, node, &mut sym);
303            return Some(sym);
304        }
305
306        // Extract name
307        let name = self.get_node_field_text(node, &rule.name_field)?;
308        if name.is_empty() {
309            return None;
310        }
311
312        // Determine symbol kind
313        let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
314        let is_test = self.detect_test(lang.name, node, source, file_path, &name);
315        let kind = if is_test {
316            SymbolKind::Test
317        } else if rule.method_when_scoped && in_method_scope {
318            SymbolKind::Method
319        } else {
320            base_kind
321        };
322
323        let visibility = self.detect_visibility(lang.name, node, source, &name);
324        let signature = self.extract_signature(lang.name, node, source);
325        let doc_comment = self.extract_doc_comment(lang.name, node, source);
326
327        let mut sym = build_symbol(
328            name,
329            kind,
330            signature,
331            visibility,
332            doc_comment,
333            file_path,
334            node.start_pos().line(),
335            node.end_pos().line(),
336            scope,
337            lang.scope_separator,
338        );
339        self.enrich_symbol_metadata(lang.name, node, &mut sym);
340        Some(sym)
341    }
342
343    // ── Reference Extraction ──────────────────────────────────────────
344
345    fn extract_references_recursive<D: Doc>(
346        &self,
347        lang: &LanguageRules,
348        node: &Node<'_, D>,
349        source: &str,
350        file_path: &str,
351        scope: &mut Vec<String>,
352        references: &mut Vec<Reference>,
353    ) where
354        D::Lang: ast_grep_core::Language,
355    {
356        let kind: Cow<'_, str> = node.kind();
357        let kind_str = kind.as_ref();
358
359        // Check unwrap nodes
360        if lang.reference_unwrap_set.contains(kind_str) {
361            for child in node.children() {
362                self.extract_references_recursive(
363                    lang, &child, source, file_path, scope, references,
364                );
365            }
366            return;
367        }
368
369        // Check scope containers
370        if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
371            let sc = &lang.reference_scope_containers[sc_idx];
372            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
373                // Still extract references from this node itself before recursing
374                if let Some(rule_indices) = lang.reference_index.get(kind_str) {
375                    for &rule_idx in rule_indices {
376                        let rule = &lang.reference_rules[rule_idx];
377                        self.extract_reference_from_rule(
378                            lang, rule, node, source, file_path, scope, references,
379                        );
380                    }
381                }
382
383                scope.push(scope_name);
384
385                if let Some(body) = self.get_scope_body(sc, node) {
386                    for child in body.children() {
387                        self.extract_references_recursive(
388                            lang, &child, source, file_path, scope, references,
389                        );
390                    }
391                } else {
392                    for child in node.children() {
393                        self.extract_references_recursive(
394                            lang, &child, source, file_path, scope, references,
395                        );
396                    }
397                }
398                scope.pop();
399                return;
400            }
401        }
402
403        // Check reference rules
404        if let Some(rule_indices) = lang.reference_index.get(kind_str) {
405            for &rule_idx in rule_indices {
406                let rule = &lang.reference_rules[rule_idx];
407                self.extract_reference_from_rule(
408                    lang, rule, node, source, file_path, scope, references,
409                );
410            }
411        }
412
413        // Default recursion
414        for child in node.children() {
415            self.extract_references_recursive(lang, &child, source, file_path, scope, references);
416        }
417    }
418
419    #[allow(clippy::too_many_arguments)]
420    fn extract_reference_from_rule<D: Doc>(
421        &self,
422        lang: &LanguageRules,
423        rule: &ReferenceRule,
424        node: &Node<'_, D>,
425        source: &str,
426        file_path: &str,
427        scope: &[String],
428        references: &mut Vec<Reference>,
429    ) where
430        D::Lang: ast_grep_core::Language,
431    {
432        // Handle special cases
433        if let Some(ref special) = rule.special {
434            self.handle_special_reference(
435                lang, special, node, source, file_path, scope, references,
436            );
437            return;
438        }
439
440        let ref_kind = match parse_reference_kind(&rule.reference_kind) {
441            Some(k) => k,
442            None => return,
443        };
444
445        let target_name = if let Some(ref field) = rule.name_field {
446            match self.get_node_field_text(node, field) {
447                Some(name) => name,
448                None => return,
449            }
450        } else {
451            // Use full node text, trimmed
452            let text = node.text();
453            text.trim().to_string()
454        };
455
456        if target_name.is_empty() {
457            return;
458        }
459
460        let source_qn = if scope.is_empty() {
461            file_path.to_string()
462        } else {
463            scope.join(lang.scope_separator)
464        };
465
466        push_ref(
467            references,
468            &source_qn,
469            target_name,
470            ref_kind,
471            file_path,
472            node.start_pos().line(),
473        );
474    }
475
476    // ── Node Utility Helpers ──────────────────────────────────────────
477
478    pub(crate) fn get_node_field_text<D: Doc>(
479        &self,
480        node: &Node<'_, D>,
481        field_name: &str,
482    ) -> Option<String>
483    where
484        D::Lang: ast_grep_core::Language,
485    {
486        node.field(field_name).map(|n| n.text().to_string())
487    }
488
489    fn get_scope_name<D: Doc>(
490        &self,
491        lang: &LanguageRules,
492        sc: &ScopeContainerRule,
493        node: &Node<'_, D>,
494        source: &str,
495    ) -> Option<String>
496    where
497        D::Lang: ast_grep_core::Language,
498    {
499        if let Some(ref special) = sc.special {
500            return self.get_special_scope_name(lang, special, node, source);
501        }
502        self.get_node_field_text(node, &sc.name_field)
503    }
504
505    fn get_scope_body<'a, D: Doc>(
506        &self,
507        sc: &ScopeContainerRule,
508        node: &Node<'a, D>,
509    ) -> Option<Node<'a, D>>
510    where
511        D::Lang: ast_grep_core::Language,
512    {
513        if let Some(ref special) = sc.special {
514            return self.get_special_scope_body(special, node);
515        }
516        node.field(&sc.body_field)
517    }
518
519    fn get_special_scope_name<D: Doc>(
520        &self,
521        _lang: &LanguageRules,
522        special: &str,
523        node: &Node<'_, D>,
524        _source: &str,
525    ) -> Option<String>
526    where
527        D::Lang: ast_grep_core::Language,
528    {
529        match special {
530            "go_method_scope" => {
531                // For Go method declarations, the scope is Receiver.MethodName
532                self.get_go_receiver_type(node)
533            }
534            "hcl_block_scope" => {
535                // HCL block: combine block type and labels
536                let mut parts = Vec::new();
537                for child in node.children() {
538                    let ck = child.kind();
539                    if ck.as_ref() == "identifier" && parts.is_empty() {
540                        parts.push(child.text().to_string());
541                    } else if ck.as_ref() == "string_lit" {
542                        parts.push(child.text().to_string().trim_matches('"').to_string());
543                    }
544                }
545                if parts.is_empty() {
546                    None
547                } else {
548                    Some(parts.join("."))
549                }
550            }
551            "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
552                for child in node.children() {
553                    let ck = child.kind();
554                    if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
555                        return Some(child.text().to_string());
556                    }
557                }
558                None
559            }),
560            "swift_class_scope" => {
561                // Swift class/struct: find name via field or first type_identifier/identifier child
562                self.get_node_field_text(node, "name").or_else(|| {
563                    node.children()
564                        .find(|c| {
565                            let ck = c.kind();
566                            ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
567                        })
568                        .map(|c| c.text().to_string())
569                })
570            }
571            "cpp_namespace_scope" => {
572                // C++ namespace: may use qualified_identifier or name field
573                self.get_node_field_text(node, "name").or_else(|| {
574                    node.children()
575                        .find(|c| {
576                            let ck = c.kind();
577                            ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
578                        })
579                        .map(|c| c.text().to_string())
580                })
581            }
582            _ => None,
583        }
584    }
585
    /// Locate the body of a scope container that uses a `special` handler.
    ///
    /// The `_special` tag is currently ignored: every special container is
    /// resolved through the node's `body` field.
    fn get_special_scope_body<'a, D: Doc>(
        &self,
        _special: &str,
        node: &Node<'a, D>,
    ) -> Option<Node<'a, D>>
    where
        D::Lang: ast_grep_core::Language,
    {
        node.field("body")
    }
596}
597
598impl Default for AstGrepEngine {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
604// ── Free Functions ─────────────────────────────────────────────────────
605
606/// Build a Symbol struct with the common scope→parent logic.
607#[allow(clippy::too_many_arguments)]
608fn build_symbol(
609    name: String,
610    kind: SymbolKind,
611    signature: String,
612    visibility: Visibility,
613    doc_comment: Option<String>,
614    file_path: &str,
615    line_start: usize,
616    line_end: usize,
617    scope: &[String],
618    scope_separator: &str,
619) -> Symbol {
620    Symbol {
621        qualified_name: build_qualified_name(scope, &name, scope_separator),
622        name,
623        kind,
624        signature,
625        visibility,
626        file_path: file_path.to_string(),
627        line_start,
628        line_end,
629        doc_comment,
630        parent: if scope.is_empty() {
631            None
632        } else {
633            Some(scope.join(scope_separator))
634        },
635        parameters: Vec::new(),
636        return_type: None,
637        is_async: false,
638        attributes: Vec::new(),
639        throws: Vec::new(),
640        generic_params: None,
641        is_abstract: false,
642    }
643}
644
645/// Push a Reference onto a collection with the standard fields.
646fn push_ref(
647    refs: &mut Vec<Reference>,
648    source_qn: &str,
649    target: String,
650    kind: ReferenceKind,
651    file_path: &str,
652    line: usize,
653) {
654    refs.push(Reference {
655        source_qualified_name: source_qn.to_string(),
656        target_name: target,
657        kind,
658        file_path: file_path.to_string(),
659        line,
660    });
661}
662
/// Join `scope` and `name` with `separator` into a qualified name.
///
/// With an empty `scope` the result is just `name`.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    match scope {
        [] => name.to_owned(),
        enclosing => {
            let mut qualified = enclosing.join(separator);
            qualified.push_str(separator);
            qualified.push_str(name);
            qualified
        }
    }
}
670
671fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
672    match s {
673        "function" => Some(SymbolKind::Function),
674        "method" => Some(SymbolKind::Method),
675        "class" => Some(SymbolKind::Class),
676        "struct" => Some(SymbolKind::Struct),
677        "enum" => Some(SymbolKind::Enum),
678        "interface" => Some(SymbolKind::Interface),
679        "type" => Some(SymbolKind::Type),
680        "constant" => Some(SymbolKind::Constant),
681        "module" => Some(SymbolKind::Module),
682        "test" => Some(SymbolKind::Test),
683        "field" => Some(SymbolKind::Field),
684        "constructor" => Some(SymbolKind::Constructor),
685        _ => None,
686    }
687}
688
689fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
690    match s {
691        "import" => Some(ReferenceKind::Import),
692        "call" => Some(ReferenceKind::Call),
693        "inherits" => Some(ReferenceKind::Inherits),
694        "implements" => Some(ReferenceKind::Implements),
695        "type_usage" => Some(ReferenceKind::TypeUsage),
696        _ => None,
697    }
698}
699
/// Strip the `/** … */` delimiters and the per-line `*` gutter from a block
/// doc comment, returning the bare text.
///
/// Every line is trimmed, then loses one leading `* ` (or a bare `*`).
/// Blank lines inside the comment are kept as paragraph breaks; leading and
/// trailing blank lines are removed. An empty comment, including the
/// degenerate `/**/`, yields an empty string.
fn clean_block_doc_comment(text: &str) -> String {
    // Remove the closer before the opener: in `/**/` the two delimiters
    // share a `*`, so stripping `/**` first would leave a stray `/` that
    // then survives as the "documentation".
    let without_closer = text.trim_end_matches("*/");
    let inner = if without_closer == "/*" {
        ""
    } else {
        without_closer.trim_start_matches("/**")
    };

    let cleaned: Vec<&str> = inner
        .trim()
        .lines()
        .map(|raw| {
            let line = raw.trim();
            line.strip_prefix("* ")
                .or_else(|| line.strip_prefix('*'))
                .unwrap_or(line)
                .trim_end()
        })
        .collect();

    // Internal blank lines (paragraph breaks) are preserved by the join;
    // only the outer whitespace is trimmed.
    cleaned.join("\n").trim().to_string()
}
716
717#[cfg(test)]
718#[path = "../tests/engine_symbols_tests.rs"]
719mod engine_symbols_tests;
720
721#[cfg(test)]
722#[path = "../tests/engine_references_tests.rs"]
723mod engine_references_tests;
724
725#[cfg(test)]
726#[path = "../tests/engine_cross_cutting_tests.rs"]
727mod engine_cross_cutting_tests;