// codemem_engine/index/engine/mod.rs

1//! Unified AST extraction engine using ast-grep.
2//!
3//! Replaces per-language tree-sitter extractors with a single engine driven
4//! by YAML rules. Language-specific behavior (visibility, doc comments, etc.)
5//! is handled by shared helpers keyed on language name.
6
7mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
/// Type alias for ast-grep nodes parameterized on [`SupportLang`].
///
/// `'r` is the lifetime of the parsed tree the node borrows from.
pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
/// One-time deserialized language rules, shared across all [`AstGrepEngine`] instances.
///
/// The rules are produced by `rule_loader::load_all_rules` the first time this
/// static is dereferenced and are reused for every later access.
static LOADED_RULES: LazyLock<Vec<LanguageRules>> =
    LazyLock::new(crate::index::rule_loader::load_all_rules);
27
/// The unified extraction engine.
///
/// The engine itself only stores a lookup table; the language rules live in the
/// process-wide `LOADED_RULES` static.
pub struct AstGrepEngine {
    /// Maps a file extension to the index of its language in `LOADED_RULES`,
    /// giving O(1) extension lookup instead of a linear scan over all languages.
    extension_index: HashMap<&'static str, usize>,
}
33
34impl AstGrepEngine {
35    /// Create a new engine referencing the globally cached language rules.
36    pub fn new() -> Self {
37        let mut extension_index = HashMap::new();
38        for (i, lr) in LOADED_RULES.iter().enumerate() {
39            for &ext in lr.extensions {
40                extension_index.insert(ext, i);
41            }
42        }
43        Self { extension_index }
44    }
45
46    /// Look up the rules for a given file extension.
47    pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
48        self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
49    }
50
51    /// List all file extensions we can handle.
52    pub fn supported_extensions(&self) -> Vec<&str> {
53        LOADED_RULES
54            .iter()
55            .flat_map(|lr| lr.extensions.iter().copied())
56            .collect()
57    }
58
59    /// Check if a given extension is supported.
60    pub fn supports_extension(&self, ext: &str) -> bool {
61        self.extension_index.contains_key(ext)
62    }
63
64    /// Get the language name for a given extension.
65    pub fn language_name(&self, ext: &str) -> Option<&str> {
66        self.find_language(ext).map(|lr| lr.name)
67    }
68
69    pub fn extract_symbols(
70        &self,
71        lang: &LanguageRules,
72        source: &str,
73        file_path: &str,
74    ) -> Vec<Symbol> {
75        let root = lang.lang.ast_grep(source);
76        self.extract_symbols_from_tree(lang, &root, source, file_path)
77    }
78
79    /// C1: Extract symbols from a pre-parsed AST tree, avoiding re-parsing.
80    pub fn extract_symbols_from_tree(
81        &self,
82        lang: &LanguageRules,
83        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
84        source: &str,
85        file_path: &str,
86    ) -> Vec<Symbol> {
87        let root_node = root.root();
88        let mut symbols = Vec::new();
89        let mut scope = Vec::new();
90        self.extract_symbols_recursive(
91            lang,
92            &root_node,
93            source,
94            file_path,
95            &mut scope,
96            false,
97            &mut symbols,
98        );
99        symbols
100    }
101
102    /// Parse source code and extract references, deduplicating by
103    /// (source_qualified_name, target_name, reference_kind).
104    pub fn extract_references(
105        &self,
106        lang: &LanguageRules,
107        source: &str,
108        file_path: &str,
109    ) -> Vec<Reference> {
110        let root = lang.lang.ast_grep(source);
111        self.extract_references_from_tree(lang, &root, source, file_path)
112    }
113
114    /// C1: Extract references from a pre-parsed AST tree, avoiding re-parsing.
115    pub fn extract_references_from_tree(
116        &self,
117        lang: &LanguageRules,
118        root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
119        source: &str,
120        file_path: &str,
121    ) -> Vec<Reference> {
122        let root_node = root.root();
123        let mut references = Vec::new();
124        let mut scope = Vec::new();
125        self.extract_references_recursive(
126            lang,
127            &root_node,
128            source,
129            file_path,
130            &mut scope,
131            &mut references,
132        );
133
134        // R3: Dedup references by (source_qualified_name, target_name, kind)
135        let mut seen = HashSet::new();
136        references.retain(|r| {
137            seen.insert((
138                r.source_qualified_name.clone(),
139                r.target_name.clone(),
140                r.kind,
141            ))
142        });
143
144        references
145    }
146
147    // ── Symbol Extraction ─────────────────────────────────────────────
148
149    #[allow(clippy::too_many_arguments)]
150    fn extract_symbols_recursive<D: Doc>(
151        &self,
152        lang: &LanguageRules,
153        node: &Node<'_, D>,
154        source: &str,
155        file_path: &str,
156        scope: &mut Vec<String>,
157        in_method_scope: bool,
158        symbols: &mut Vec<Symbol>,
159    ) where
160        D::Lang: ast_grep_core::Language,
161    {
162        let kind: Cow<'_, str> = node.kind();
163        let kind_str = kind.as_ref();
164
165        // Check if this is an unwrap node (e.g., decorated_definition, export_statement)
166        if lang.symbol_unwrap_set.contains(kind_str) {
167            for child in node.children() {
168                self.extract_symbols_recursive(
169                    lang,
170                    &child,
171                    source,
172                    file_path,
173                    scope,
174                    in_method_scope,
175                    symbols,
176                );
177            }
178            return;
179        }
180
181        // Check if this is a scope container
182        let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
183        if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
184            let sc = &lang.symbol_scope_containers[sc_idx];
185            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
186                scope.push(scope_name);
187                let new_method_scope = sc.is_method_scope;
188
189                // Recurse into the body
190                if let Some(body) = self.get_scope_body(sc, node) {
191                    for child in body.children() {
192                        self.extract_symbols_recursive(
193                            lang,
194                            &child,
195                            source,
196                            file_path,
197                            scope,
198                            new_method_scope,
199                            symbols,
200                        );
201                    }
202                } else {
203                    // No body field found, recurse into all children
204                    for child in node.children() {
205                        self.extract_symbols_recursive(
206                            lang,
207                            &child,
208                            source,
209                            file_path,
210                            scope,
211                            new_method_scope,
212                            symbols,
213                        );
214                    }
215                }
216                scope.pop();
217                // The scope container itself might also be a symbol
218                // (e.g., trait_item is both a scope container and an interface symbol)
219            }
220        }
221
222        // Check if this matches any symbol rules
223        if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
224            for &rule_idx in rule_indices {
225                let rule = &lang.symbol_rules[rule_idx];
226
227                // Handle multi-symbol special cases (e.g. Go type/const/var declarations)
228                if let Some(ref special) = rule.special {
229                    let multi = self
230                        .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
231                    if !multi.is_empty() {
232                        symbols.extend(multi);
233                        return;
234                    }
235                }
236
237                if let Some(sym) = self.extract_symbol_from_rule(
238                    lang,
239                    rule,
240                    node,
241                    source,
242                    file_path,
243                    scope,
244                    in_method_scope,
245                ) {
246                    let name = sym.name.clone();
247                    let is_scope = rule.is_scope;
248                    symbols.push(sym);
249
250                    // If this symbol creates a scope, recurse into it
251                    // Skip if already handled as a scope container above
252                    if is_scope && !handled_as_scope_container {
253                        scope.push(name);
254                        if let Some(body) = node.field("body") {
255                            for child in body.children() {
256                                self.extract_symbols_recursive(
257                                    lang,
258                                    &child,
259                                    source,
260                                    file_path,
261                                    scope,
262                                    in_method_scope,
263                                    symbols,
264                                );
265                            }
266                        }
267                        scope.pop();
268                    }
269                    return; // handled this node
270                }
271            }
272        }
273
274        // If this node was already handled as a scope container (which recursed),
275        // and it was also a symbol (handled above), we've already returned.
276        // If it was only a scope container (not a symbol), we already recursed.
277        if handled_as_scope_container {
278            return; // already recursed above
279        }
280
281        // Default: recurse into children
282        for child in node.children() {
283            self.extract_symbols_recursive(
284                lang,
285                &child,
286                source,
287                file_path,
288                scope,
289                in_method_scope,
290                symbols,
291            );
292        }
293    }
294
295    #[allow(clippy::too_many_arguments)]
296    fn extract_symbol_from_rule<D: Doc>(
297        &self,
298        lang: &LanguageRules,
299        rule: &SymbolRule,
300        node: &Node<'_, D>,
301        source: &str,
302        file_path: &str,
303        scope: &[String],
304        in_method_scope: bool,
305    ) -> Option<Symbol>
306    where
307        D::Lang: ast_grep_core::Language,
308    {
309        // Handle special cases first
310        if let Some(ref special) = rule.special {
311            let mut sym =
312                self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
313            self.enrich_symbol_metadata(lang.name, node, &mut sym);
314            return Some(sym);
315        }
316
317        // Extract name
318        let name = self.get_node_field_text(node, &rule.name_field)?;
319        if name.is_empty() {
320            return None;
321        }
322
323        // Determine symbol kind
324        let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
325        let is_test = self.detect_test(lang.name, node, source, file_path, &name);
326        let kind = if is_test {
327            SymbolKind::Test
328        } else if rule.method_when_scoped && in_method_scope {
329            SymbolKind::Method
330        } else {
331            base_kind
332        };
333
334        let visibility = self.detect_visibility(lang.name, node, source, &name);
335        let signature = self.extract_signature(lang.name, node, source);
336        let doc_comment = self.extract_doc_comment(lang.name, node, source);
337
338        let mut sym = build_symbol(
339            name,
340            kind,
341            signature,
342            visibility,
343            doc_comment,
344            file_path,
345            node.start_pos().line(),
346            node.end_pos().line(),
347            scope,
348            lang.scope_separator,
349        );
350        self.enrich_symbol_metadata(lang.name, node, &mut sym);
351        Some(sym)
352    }
353
354    // ── Reference Extraction ──────────────────────────────────────────
355
356    fn extract_references_recursive<D: Doc>(
357        &self,
358        lang: &LanguageRules,
359        node: &Node<'_, D>,
360        source: &str,
361        file_path: &str,
362        scope: &mut Vec<String>,
363        references: &mut Vec<Reference>,
364    ) where
365        D::Lang: ast_grep_core::Language,
366    {
367        let kind: Cow<'_, str> = node.kind();
368        let kind_str = kind.as_ref();
369
370        // Check unwrap nodes
371        if lang.reference_unwrap_set.contains(kind_str) {
372            for child in node.children() {
373                self.extract_references_recursive(
374                    lang, &child, source, file_path, scope, references,
375                );
376            }
377            return;
378        }
379
380        // Check scope containers
381        if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
382            let sc = &lang.reference_scope_containers[sc_idx];
383            if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
384                // Still extract references from this node itself before recursing
385                if let Some(rule_indices) = lang.reference_index.get(kind_str) {
386                    for &rule_idx in rule_indices {
387                        let rule = &lang.reference_rules[rule_idx];
388                        self.extract_reference_from_rule(
389                            lang, rule, node, source, file_path, scope, references,
390                        );
391                    }
392                }
393
394                scope.push(scope_name);
395
396                if let Some(body) = self.get_scope_body(sc, node) {
397                    for child in body.children() {
398                        self.extract_references_recursive(
399                            lang, &child, source, file_path, scope, references,
400                        );
401                    }
402                } else {
403                    for child in node.children() {
404                        self.extract_references_recursive(
405                            lang, &child, source, file_path, scope, references,
406                        );
407                    }
408                }
409                scope.pop();
410                return;
411            }
412        }
413
414        // Check reference rules
415        if let Some(rule_indices) = lang.reference_index.get(kind_str) {
416            for &rule_idx in rule_indices {
417                let rule = &lang.reference_rules[rule_idx];
418                self.extract_reference_from_rule(
419                    lang, rule, node, source, file_path, scope, references,
420                );
421            }
422        }
423
424        // Default recursion
425        for child in node.children() {
426            self.extract_references_recursive(lang, &child, source, file_path, scope, references);
427        }
428    }
429
430    #[allow(clippy::too_many_arguments)]
431    fn extract_reference_from_rule<D: Doc>(
432        &self,
433        lang: &LanguageRules,
434        rule: &ReferenceRule,
435        node: &Node<'_, D>,
436        source: &str,
437        file_path: &str,
438        scope: &[String],
439        references: &mut Vec<Reference>,
440    ) where
441        D::Lang: ast_grep_core::Language,
442    {
443        // Handle special cases
444        if let Some(ref special) = rule.special {
445            self.handle_special_reference(
446                lang, special, node, source, file_path, scope, references,
447            );
448            return;
449        }
450
451        let ref_kind = match parse_reference_kind(&rule.reference_kind) {
452            Some(k) => k,
453            None => return,
454        };
455
456        let target_name = if let Some(ref field) = rule.name_field {
457            match self.get_node_field_text(node, field) {
458                Some(name) => name,
459                None => return,
460            }
461        } else {
462            // Use full node text, trimmed
463            let text = node.text();
464            text.trim().to_string()
465        };
466
467        if target_name.is_empty() {
468            return;
469        }
470
471        let source_qn = if scope.is_empty() {
472            file_path.to_string()
473        } else {
474            scope.join(lang.scope_separator)
475        };
476
477        push_ref(
478            references,
479            &source_qn,
480            target_name,
481            ref_kind,
482            file_path,
483            node.start_pos().line(),
484        );
485    }
486
487    // ── Node Utility Helpers ──────────────────────────────────────────
488
489    pub(crate) fn get_node_field_text<D: Doc>(
490        &self,
491        node: &Node<'_, D>,
492        field_name: &str,
493    ) -> Option<String>
494    where
495        D::Lang: ast_grep_core::Language,
496    {
497        node.field(field_name).map(|n| n.text().to_string())
498    }
499
500    fn get_scope_name<D: Doc>(
501        &self,
502        lang: &LanguageRules,
503        sc: &ScopeContainerRule,
504        node: &Node<'_, D>,
505        source: &str,
506    ) -> Option<String>
507    where
508        D::Lang: ast_grep_core::Language,
509    {
510        if let Some(ref special) = sc.special {
511            return self.get_special_scope_name(lang, special, node, source);
512        }
513        self.get_node_field_text(node, &sc.name_field)
514    }
515
516    fn get_scope_body<'a, D: Doc>(
517        &self,
518        sc: &ScopeContainerRule,
519        node: &Node<'a, D>,
520    ) -> Option<Node<'a, D>>
521    where
522        D::Lang: ast_grep_core::Language,
523    {
524        if let Some(ref special) = sc.special {
525            return self.get_special_scope_body(special, node);
526        }
527        node.field(&sc.body_field)
528    }
529
530    fn get_special_scope_name<D: Doc>(
531        &self,
532        _lang: &LanguageRules,
533        special: &str,
534        node: &Node<'_, D>,
535        _source: &str,
536    ) -> Option<String>
537    where
538        D::Lang: ast_grep_core::Language,
539    {
540        match special {
541            "go_method_scope" => {
542                // For Go method declarations, the scope is Receiver.MethodName
543                self.get_go_receiver_type(node)
544            }
545            "hcl_block_scope" => {
546                // HCL block: combine block type and labels
547                let mut parts = Vec::new();
548                for child in node.children() {
549                    let ck = child.kind();
550                    if ck.as_ref() == "identifier" && parts.is_empty() {
551                        parts.push(child.text().to_string());
552                    } else if ck.as_ref() == "string_lit" {
553                        parts.push(child.text().to_string().trim_matches('"').to_string());
554                    }
555                }
556                if parts.is_empty() {
557                    None
558                } else {
559                    Some(parts.join("."))
560                }
561            }
562            "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
563                for child in node.children() {
564                    let ck = child.kind();
565                    if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
566                        return Some(child.text().to_string());
567                    }
568                }
569                None
570            }),
571            "swift_class_scope" => {
572                // Swift class/struct: find name via field or first type_identifier/identifier child
573                self.get_node_field_text(node, "name").or_else(|| {
574                    node.children()
575                        .find(|c| {
576                            let ck = c.kind();
577                            ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
578                        })
579                        .map(|c| c.text().to_string())
580                })
581            }
582            "cpp_namespace_scope" => {
583                // C++ namespace: may use qualified_identifier or name field
584                self.get_node_field_text(node, "name").or_else(|| {
585                    node.children()
586                        .find(|c| {
587                            let ck = c.kind();
588                            ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
589                        })
590                        .map(|c| c.text().to_string())
591                })
592            }
593            _ => None,
594        }
595    }
596
597    fn get_special_scope_body<'a, D: Doc>(
598        &self,
599        _special: &str,
600        node: &Node<'a, D>,
601    ) -> Option<Node<'a, D>>
602    where
603        D::Lang: ast_grep_core::Language,
604    {
605        node.field("body")
606    }
607}
608
609impl Default for AstGrepEngine {
610    fn default() -> Self {
611        Self::new()
612    }
613}
614
615// ── Free Functions ─────────────────────────────────────────────────────
616
617/// Build a Symbol struct with the common scope→parent logic.
618#[allow(clippy::too_many_arguments)]
619fn build_symbol(
620    name: String,
621    kind: SymbolKind,
622    signature: String,
623    visibility: Visibility,
624    doc_comment: Option<String>,
625    file_path: &str,
626    line_start: usize,
627    line_end: usize,
628    scope: &[String],
629    scope_separator: &str,
630) -> Symbol {
631    Symbol {
632        qualified_name: build_qualified_name(scope, &name, scope_separator),
633        name,
634        kind,
635        signature,
636        visibility,
637        file_path: file_path.to_string(),
638        line_start,
639        line_end,
640        doc_comment,
641        parent: if scope.is_empty() {
642            None
643        } else {
644            Some(scope.join(scope_separator))
645        },
646        parameters: Vec::new(),
647        return_type: None,
648        is_async: false,
649        attributes: Vec::new(),
650        throws: Vec::new(),
651        generic_params: None,
652        is_abstract: false,
653    }
654}
655
656/// Push a Reference onto a collection with the standard fields.
657fn push_ref(
658    refs: &mut Vec<Reference>,
659    source_qn: &str,
660    target: String,
661    kind: ReferenceKind,
662    file_path: &str,
663    line: usize,
664) {
665    refs.push(Reference {
666        source_qualified_name: source_qn.to_string(),
667        target_name: target,
668        kind,
669        file_path: file_path.to_string(),
670        line,
671    });
672}
673
/// Join `scope` and `name` into a qualified name using `separator`.
///
/// With an empty scope the result is just `name`; otherwise it is the scope
/// segments joined by `separator`, followed by `separator` and `name`.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    match scope {
        [] => name.to_string(),
        segments => {
            let mut qualified = segments.join(separator);
            qualified.push_str(separator);
            qualified.push_str(name);
            qualified
        }
    }
}
681
682fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
683    match s {
684        "function" => Some(SymbolKind::Function),
685        "method" => Some(SymbolKind::Method),
686        "class" => Some(SymbolKind::Class),
687        "struct" => Some(SymbolKind::Struct),
688        "enum" => Some(SymbolKind::Enum),
689        "interface" => Some(SymbolKind::Interface),
690        "type" => Some(SymbolKind::Type),
691        "constant" => Some(SymbolKind::Constant),
692        "module" => Some(SymbolKind::Module),
693        "test" => Some(SymbolKind::Test),
694        "field" => Some(SymbolKind::Field),
695        "property" => Some(SymbolKind::Property),
696        "constructor" => Some(SymbolKind::Constructor),
697        "enum_variant" => Some(SymbolKind::EnumVariant),
698        "macro" => Some(SymbolKind::Macro),
699        "decorator" => Some(SymbolKind::Decorator),
700        _ => None,
701    }
702}
703
704fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
705    match s {
706        "import" => Some(ReferenceKind::Import),
707        "call" => Some(ReferenceKind::Call),
708        "inherits" => Some(ReferenceKind::Inherits),
709        "implements" => Some(ReferenceKind::Implements),
710        "type_usage" => Some(ReferenceKind::TypeUsage),
711        _ => None,
712    }
713}
714
715fn clean_block_doc_comment(text: &str) -> String {
716    let trimmed = text.trim_start_matches("/**").trim_end_matches("*/").trim();
717
718    let mut doc_lines = Vec::new();
719    for line in trimmed.lines() {
720        let line = line.trim();
721        let line = line
722            .strip_prefix("* ")
723            .or_else(|| line.strip_prefix('*'))
724            .unwrap_or(line);
725        let line = line.trim_end();
726        // Preserve internal blank lines (paragraph breaks)
727        doc_lines.push(line.to_string());
728    }
729    doc_lines.join("\n").trim().to_string()
730}
731
732#[cfg(test)]
733#[path = "../tests/engine_tests.rs"]
734mod tests;