Skip to main content

codemem_engine/index/engine/
mod.rs

1//! Unified AST extraction engine using ast-grep.
2//!
3//! Replaces per-language tree-sitter extractors with a single engine driven
4//! by YAML rules. Language-specific behavior (visibility, doc comments, etc.)
5//! is handled by shared helpers keyed on language name.
6
7mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
21/// Type alias for ast-grep nodes parameterized on SupportLang.
22pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
/// One-time deserialized language rules, shared across all AstGrepEngine instances.
///
/// The list is built lazily, the first time the static is dereferenced, by
/// calling `rule_loader::load_all_rules()`.
///
/// # Panics
///
/// The first access panics if `load_all_rules()` fails, with the message
/// given to `expect` below.
static LOADED_RULES: LazyLock<Vec<LanguageRules>> = LazyLock::new(|| {
    crate::index::rule_loader::load_all_rules()
        .expect("embedded YAML rule files must deserialize successfully")
});
29
/// The unified extraction engine.
///
/// The engine stores only an extension lookup table; the language rules
/// themselves live in the shared `LOADED_RULES` static.
pub struct AstGrepEngine {
    /// Maps a file extension to the index of its language in `LOADED_RULES`,
    /// giving an O(1) hash lookup instead of a linear scan over all languages.
    extension_index: HashMap<&'static str, usize>,
}
35
36impl AstGrepEngine {
37    /// Create a new engine referencing the globally cached language rules.
38    pub fn new() -> Self {
39        let mut extension_index = HashMap::new();
40        for (i, lr) in LOADED_RULES.iter().enumerate() {
41            for &ext in lr.extensions {
42                extension_index.insert(ext, i);
43            }
44        }
45        Self { extension_index }
46    }
47
48    /// Look up the rules for a given file extension.
49    pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
50        self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
51    }
52
53    /// Check if a given extension is supported.
54    pub fn supports_extension(&self, ext: &str) -> bool {
55        self.extension_index.contains_key(ext)
56    }
57
58    pub fn extract_symbols(
59        &self,
60        lang: &LanguageRules,
61        source: &str,
62        file_path: &str,
63    ) -> Vec<Symbol> {
64        let root = lang.lang.ast_grep(source);
65        self.extract_symbols_from_tree(lang, &root, source, file_path)
66    }
67
68    /// C1: Extract symbols from a pre-parsed AST tree, avoiding re-parsing.
69    pub fn extract_symbols_from_tree(
70        &self,
71        lang: &LanguageRules,
72        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
73        source: &str,
74        file_path: &str,
75    ) -> Vec<Symbol> {
76        let root_node = root.root();
77        let mut symbols = Vec::new();
78        let mut scope = Vec::new();
79        self.extract_symbols_recursive(
80            lang,
81            &root_node,
82            source,
83            file_path,
84            &mut scope,
85            false,
86            &mut symbols,
87        );
88        symbols
89    }
90
91    /// Parse source code and extract references, deduplicating by
92    /// (source_qualified_name, target_name, reference_kind).
93    pub fn extract_references(
94        &self,
95        lang: &LanguageRules,
96        source: &str,
97        file_path: &str,
98    ) -> Vec<Reference> {
99        let root = lang.lang.ast_grep(source);
100        self.extract_references_from_tree(lang, &root, source, file_path)
101    }
102
103    /// C1: Extract references from a pre-parsed AST tree, avoiding re-parsing.
104    pub fn extract_references_from_tree(
105        &self,
106        lang: &LanguageRules,
107        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
108        source: &str,
109        file_path: &str,
110    ) -> Vec<Reference> {
111        let root_node = root.root();
112        let mut references = Vec::new();
113        let mut scope = Vec::new();
114        self.extract_references_recursive(
115            lang,
116            &root_node,
117            source,
118            file_path,
119            &mut scope,
120            &mut references,
121        );
122
123        // R3: Dedup references by (source_qualified_name, target_name, kind)
124        let mut seen = HashSet::new();
125        references.retain(|r| {
126            seen.insert((
127                r.source_qualified_name.clone(),
128                r.target_name.clone(),
129                r.kind,
130            ))
131        });
132
133        // R5: Filter out noise calls (builtins, stdlib methods)
134        references.retain(|r| {
135            if !matches!(r.kind, ReferenceKind::Call | ReferenceKind::Callback) {
136                return true; // only filter calls/callbacks, not imports/inherits
137            }
138            let simple = r
139                .target_name
140                .rsplit(lang.scope_separator)
141                .next()
142                .unwrap_or(&r.target_name);
143            !crate::index::blocklist::is_blocked_call(lang.name, simple)
144        });
145
146        references
147    }
148
149    // ── Symbol Extraction ─────────────────────────────────────────────
150
151    #[allow(clippy::too_many_arguments)]
152    fn extract_symbols_recursive<D: Doc>(
153        &self,
154        lang: &LanguageRules,
155        node: &Node<'_, D>,
156        source: &str,
157        file_path: &str,
158        scope: &mut Vec<String>,
159        in_method_scope: bool,
160        symbols: &mut Vec<Symbol>,
161    ) where
162        D::Lang: ast_grep_core::Language,
163    {
164        let kind: Cow<'_, str> = node.kind();
165        let kind_str = kind.as_ref();
166
167        // Check if this is an unwrap node (e.g., decorated_definition, export_statement)
168        if lang.symbol_unwrap_set.contains(kind_str) {
169            for child in node.children() {
170                self.extract_symbols_recursive(
171                    lang,
172                    &child,
173                    source,
174                    file_path,
175                    scope,
176                    in_method_scope,
177                    symbols,
178                );
179            }
180            return;
181        }
182
183        // Check if this is a scope container
184        let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
185        if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
186            let sc = &lang.symbol_scope_containers[sc_idx];
187            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
188                scope.push(scope_name);
189                let new_method_scope = sc.is_method_scope;
190
191                // Recurse into the body
192                if let Some(body) = self.get_scope_body(sc, node) {
193                    for child in body.children() {
194                        self.extract_symbols_recursive(
195                            lang,
196                            &child,
197                            source,
198                            file_path,
199                            scope,
200                            new_method_scope,
201                            symbols,
202                        );
203                    }
204                } else {
205                    // No body field found, recurse into all children
206                    for child in node.children() {
207                        self.extract_symbols_recursive(
208                            lang,
209                            &child,
210                            source,
211                            file_path,
212                            scope,
213                            new_method_scope,
214                            symbols,
215                        );
216                    }
217                }
218                scope.pop();
219                // The scope container itself might also be a symbol
220                // (e.g., trait_item is both a scope container and an interface symbol)
221            }
222        }
223
224        // Check if this matches any symbol rules
225        if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
226            for &rule_idx in rule_indices {
227                let rule = &lang.symbol_rules[rule_idx];
228
229                // Handle multi-symbol special cases (e.g. Go type/const/var declarations)
230                if let Some(ref special) = rule.special {
231                    let multi = self
232                        .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
233                    if !multi.is_empty() {
234                        symbols.extend(multi);
235                        return;
236                    }
237                }
238
239                if let Some(sym) = self.extract_symbol_from_rule(
240                    lang,
241                    rule,
242                    node,
243                    source,
244                    file_path,
245                    scope,
246                    in_method_scope,
247                ) {
248                    let name = sym.name.clone();
249                    let is_scope = rule.is_scope;
250                    symbols.push(sym);
251
252                    // If this symbol creates a scope, recurse into it
253                    // Skip if already handled as a scope container above
254                    if is_scope && !handled_as_scope_container {
255                        scope.push(name);
256                        if let Some(body) = node.field("body") {
257                            for child in body.children() {
258                                self.extract_symbols_recursive(
259                                    lang,
260                                    &child,
261                                    source,
262                                    file_path,
263                                    scope,
264                                    in_method_scope,
265                                    symbols,
266                                );
267                            }
268                        }
269                        scope.pop();
270                    }
271                    return; // handled this node
272                }
273            }
274        }
275
276        // If this node was already handled as a scope container (which recursed),
277        // and it was also a symbol (handled above), we've already returned.
278        // If it was only a scope container (not a symbol), we already recursed.
279        if handled_as_scope_container {
280            return; // already recursed above
281        }
282
283        // Default: recurse into children
284        for child in node.children() {
285            self.extract_symbols_recursive(
286                lang,
287                &child,
288                source,
289                file_path,
290                scope,
291                in_method_scope,
292                symbols,
293            );
294        }
295    }
296
297    #[allow(clippy::too_many_arguments)]
298    fn extract_symbol_from_rule<D: Doc>(
299        &self,
300        lang: &LanguageRules,
301        rule: &SymbolRule,
302        node: &Node<'_, D>,
303        source: &str,
304        file_path: &str,
305        scope: &[String],
306        in_method_scope: bool,
307    ) -> Option<Symbol>
308    where
309        D::Lang: ast_grep_core::Language,
310    {
311        // Handle special cases first
312        if let Some(ref special) = rule.special {
313            let mut sym =
314                self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
315            self.enrich_symbol_metadata(lang.name, node, &mut sym);
316            return Some(sym);
317        }
318
319        // Extract name
320        let name = self.get_node_field_text(node, &rule.name_field)?;
321        if name.is_empty() {
322            return None;
323        }
324
325        // Determine symbol kind
326        let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
327        let is_test = self.detect_test(lang.name, node, source, file_path, &name);
328        let kind = if is_test {
329            SymbolKind::Test
330        } else if rule.method_when_scoped && in_method_scope {
331            SymbolKind::Method
332        } else {
333            base_kind
334        };
335
336        let visibility = self.detect_visibility(lang.name, node, source, &name);
337        let signature = self.extract_signature(lang.name, node, source);
338        let doc_comment = self.extract_doc_comment(lang.name, node, source);
339
340        let mut sym = build_symbol(
341            name,
342            kind,
343            signature,
344            visibility,
345            doc_comment,
346            file_path,
347            node.start_pos().line(),
348            node.end_pos().line(),
349            scope,
350            lang.scope_separator,
351        );
352        self.enrich_symbol_metadata(lang.name, node, &mut sym);
353        Some(sym)
354    }
355
356    // ── Reference Extraction ──────────────────────────────────────────
357
358    fn extract_references_recursive<D: Doc>(
359        &self,
360        lang: &LanguageRules,
361        node: &Node<'_, D>,
362        source: &str,
363        file_path: &str,
364        scope: &mut Vec<String>,
365        references: &mut Vec<Reference>,
366    ) where
367        D::Lang: ast_grep_core::Language,
368    {
369        let kind: Cow<'_, str> = node.kind();
370        let kind_str = kind.as_ref();
371
372        // Check unwrap nodes
373        if lang.reference_unwrap_set.contains(kind_str) {
374            for child in node.children() {
375                self.extract_references_recursive(
376                    lang, &child, source, file_path, scope, references,
377                );
378            }
379            return;
380        }
381
382        // Check scope containers
383        if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
384            let sc = &lang.reference_scope_containers[sc_idx];
385            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
386                // Still extract references from this node itself before recursing
387                if let Some(rule_indices) = lang.reference_index.get(kind_str) {
388                    for &rule_idx in rule_indices {
389                        let rule = &lang.reference_rules[rule_idx];
390                        self.extract_reference_from_rule(
391                            lang, rule, node, source, file_path, scope, references,
392                        );
393                    }
394                }
395
396                scope.push(scope_name);
397
398                if let Some(body) = self.get_scope_body(sc, node) {
399                    for child in body.children() {
400                        self.extract_references_recursive(
401                            lang, &child, source, file_path, scope, references,
402                        );
403                    }
404                } else {
405                    for child in node.children() {
406                        self.extract_references_recursive(
407                            lang, &child, source, file_path, scope, references,
408                        );
409                    }
410                }
411                scope.pop();
412                return;
413            }
414        }
415
416        // Check reference rules
417        if let Some(rule_indices) = lang.reference_index.get(kind_str) {
418            for &rule_idx in rule_indices {
419                let rule = &lang.reference_rules[rule_idx];
420                self.extract_reference_from_rule(
421                    lang, rule, node, source, file_path, scope, references,
422                );
423            }
424        }
425
426        // Default recursion
427        for child in node.children() {
428            self.extract_references_recursive(lang, &child, source, file_path, scope, references);
429        }
430    }
431
432    #[allow(clippy::too_many_arguments)]
433    fn extract_reference_from_rule<D: Doc>(
434        &self,
435        lang: &LanguageRules,
436        rule: &ReferenceRule,
437        node: &Node<'_, D>,
438        source: &str,
439        file_path: &str,
440        scope: &[String],
441        references: &mut Vec<Reference>,
442    ) where
443        D::Lang: ast_grep_core::Language,
444    {
445        // Handle special cases
446        if let Some(ref special) = rule.special {
447            self.handle_special_reference(
448                lang, special, node, source, file_path, scope, references,
449            );
450            return;
451        }
452
453        let ref_kind = match parse_reference_kind(&rule.reference_kind) {
454            Some(k) => k,
455            None => return,
456        };
457
458        let target_name = if let Some(ref field) = rule.name_field {
459            match self.get_node_field_text(node, field) {
460                Some(name) => name,
461                None => return,
462            }
463        } else {
464            // Use full node text, trimmed
465            let text = node.text();
466            text.trim().to_string()
467        };
468
469        if target_name.is_empty() {
470            return;
471        }
472
473        let source_qn = if scope.is_empty() {
474            file_path.to_string()
475        } else {
476            scope.join(lang.scope_separator)
477        };
478
479        push_ref(
480            references,
481            &source_qn,
482            target_name,
483            ref_kind,
484            file_path,
485            node.start_pos().line(),
486        );
487    }
488
489    // ── Node Utility Helpers ──────────────────────────────────────────
490
491    pub(crate) fn get_node_field_text<D: Doc>(
492        &self,
493        node: &Node<'_, D>,
494        field_name: &str,
495    ) -> Option<String>
496    where
497        D::Lang: ast_grep_core::Language,
498    {
499        node.field(field_name).map(|n| n.text().to_string())
500    }
501
502    fn get_scope_name<D: Doc>(
503        &self,
504        lang: &LanguageRules,
505        sc: &ScopeContainerRule,
506        node: &Node<'_, D>,
507        source: &str,
508    ) -> Option<String>
509    where
510        D::Lang: ast_grep_core::Language,
511    {
512        if let Some(ref special) = sc.special {
513            return self.get_special_scope_name(lang, special, node, source);
514        }
515        self.get_node_field_text(node, &sc.name_field)
516    }
517
518    fn get_scope_body<'a, D: Doc>(
519        &self,
520        sc: &ScopeContainerRule,
521        node: &Node<'a, D>,
522    ) -> Option<Node<'a, D>>
523    where
524        D::Lang: ast_grep_core::Language,
525    {
526        if let Some(ref special) = sc.special {
527            return self.get_special_scope_body(special, node);
528        }
529        node.field(&sc.body_field)
530    }
531
532    fn get_special_scope_name<D: Doc>(
533        &self,
534        _lang: &LanguageRules,
535        special: &str,
536        node: &Node<'_, D>,
537        _source: &str,
538    ) -> Option<String>
539    where
540        D::Lang: ast_grep_core::Language,
541    {
542        match special {
543            "go_method_scope" => {
544                // For Go method declarations, the scope is Receiver.MethodName
545                self.get_go_receiver_type(node)
546            }
547            "hcl_block_scope" => {
548                // HCL block: combine block type and labels
549                let mut parts = Vec::new();
550                for child in node.children() {
551                    let ck = child.kind();
552                    if ck.as_ref() == "identifier" && parts.is_empty() {
553                        parts.push(child.text().to_string());
554                    } else if ck.as_ref() == "string_lit" {
555                        parts.push(child.text().to_string().trim_matches('"').to_string());
556                    }
557                }
558                if parts.is_empty() {
559                    None
560                } else {
561                    Some(parts.join("."))
562                }
563            }
564            "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
565                for child in node.children() {
566                    let ck = child.kind();
567                    if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
568                        return Some(child.text().to_string());
569                    }
570                }
571                None
572            }),
573            "swift_class_scope" => {
574                // Swift class/struct: find name via field or first type_identifier/identifier child
575                self.get_node_field_text(node, "name").or_else(|| {
576                    node.children()
577                        .find(|c| {
578                            let ck = c.kind();
579                            ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
580                        })
581                        .map(|c| c.text().to_string())
582                })
583            }
584            "cpp_namespace_scope" => {
585                // C++ namespace: may use qualified_identifier or name field
586                self.get_node_field_text(node, "name").or_else(|| {
587                    node.children()
588                        .find(|c| {
589                            let ck = c.kind();
590                            ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
591                        })
592                        .map(|c| c.text().to_string())
593                })
594            }
595            _ => None,
596        }
597    }
598
599    fn get_special_scope_body<'a, D: Doc>(
600        &self,
601        _special: &str,
602        node: &Node<'a, D>,
603    ) -> Option<Node<'a, D>>
604    where
605        D::Lang: ast_grep_core::Language,
606    {
607        node.field("body")
608    }
609}
610
611impl Default for AstGrepEngine {
612    fn default() -> Self {
613        Self::new()
614    }
615}
616
617// ── Free Functions ─────────────────────────────────────────────────────
618
619/// Build a Symbol struct with the common scope→parent logic.
620#[allow(clippy::too_many_arguments)]
621fn build_symbol(
622    name: String,
623    kind: SymbolKind,
624    signature: String,
625    visibility: Visibility,
626    doc_comment: Option<String>,
627    file_path: &str,
628    line_start: usize,
629    line_end: usize,
630    scope: &[String],
631    scope_separator: &str,
632) -> Symbol {
633    Symbol {
634        qualified_name: build_qualified_name(scope, &name, scope_separator),
635        name,
636        kind,
637        signature,
638        visibility,
639        file_path: file_path.to_string(),
640        line_start,
641        line_end,
642        doc_comment,
643        parent: if scope.is_empty() {
644            None
645        } else {
646            Some(scope.join(scope_separator))
647        },
648        parameters: Vec::new(),
649        return_type: None,
650        is_async: false,
651        attributes: Vec::new(),
652        throws: Vec::new(),
653        generic_params: None,
654        is_abstract: false,
655    }
656}
657
658/// Push a Reference onto a collection with the standard fields.
659fn push_ref(
660    refs: &mut Vec<Reference>,
661    source_qn: &str,
662    target: String,
663    kind: ReferenceKind,
664    file_path: &str,
665    line: usize,
666) {
667    refs.push(Reference {
668        source_qualified_name: source_qn.to_string(),
669        target_name: target,
670        kind,
671        file_path: file_path.to_string(),
672        line,
673    });
674}
675
/// Join the enclosing `scope` segments and `name` with `separator`.
///
/// With an empty scope the bare `name` is returned unchanged.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    match scope {
        [] => name.to_owned(),
        segments => {
            let mut qualified = segments.join(separator);
            qualified.push_str(separator);
            qualified.push_str(name);
            qualified
        }
    }
}
683
684fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
685    match s {
686        "function" => Some(SymbolKind::Function),
687        "method" => Some(SymbolKind::Method),
688        "class" => Some(SymbolKind::Class),
689        "struct" => Some(SymbolKind::Struct),
690        "enum" => Some(SymbolKind::Enum),
691        "interface" => Some(SymbolKind::Interface),
692        "type" => Some(SymbolKind::Type),
693        "constant" => Some(SymbolKind::Constant),
694        "module" => Some(SymbolKind::Module),
695        "test" => Some(SymbolKind::Test),
696        "field" => Some(SymbolKind::Field),
697        "constructor" => Some(SymbolKind::Constructor),
698        _ => None,
699    }
700}
701
702fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
703    match s {
704        "import" => Some(ReferenceKind::Import),
705        "call" => Some(ReferenceKind::Call),
706        "callback" => Some(ReferenceKind::Callback),
707        "inherits" => Some(ReferenceKind::Inherits),
708        "implements" => Some(ReferenceKind::Implements),
709        "type_usage" => Some(ReferenceKind::TypeUsage),
710        _ => None,
711    }
712}
713
/// Strip the comment markers from a `/** ... */` block doc comment.
///
/// The opening `/**` and closing `*/` are removed, then each line is trimmed
/// and loses its leading `* ` (or bare `*`) gutter. Blank lines inside the
/// comment are kept as paragraph breaks; leading and trailing blank lines of
/// the result are trimmed away.
fn clean_block_doc_comment(text: &str) -> String {
    let inner = text.trim_start_matches("/**").trim_end_matches("*/").trim();

    let cleaned: Vec<&str> = inner
        .lines()
        .map(|raw| {
            let raw = raw.trim();
            // Prefer removing the gutter together with its following space.
            let without_gutter = match raw.strip_prefix("* ") {
                Some(rest) => rest,
                None => raw.strip_prefix('*').unwrap_or(raw),
            };
            without_gutter.trim_end()
        })
        .collect();

    cleaned.join("\n").trim().to_string()
}
730
731#[cfg(test)]
732#[path = "../tests/engine_symbols_tests.rs"]
733mod engine_symbols_tests;
734
735#[cfg(test)]
736#[path = "../tests/engine_references_tests.rs"]
737mod engine_references_tests;
738
739#[cfg(test)]
740#[path = "../tests/engine_cross_cutting_tests.rs"]
741mod engine_cross_cutting_tests;