1mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
/// Concrete ast-grep node type for trees parsed from an in-memory string
/// (`StrDoc`) in one of the languages enumerated by `SupportLang`.
pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
/// Every language rule set, produced by `rule_loader::load_all_rules` the first
/// time this static is dereferenced and then shared for the rest of the process.
/// `AstGrepEngine` stores indices into this vector rather than the rules themselves.
static LOADED_RULES: LazyLock<Vec<LanguageRules>> =
    LazyLock::new(crate::index::rule_loader::load_all_rules);
27
/// Rule-driven symbol and reference extractor built on ast-grep.
///
/// The only per-engine state is a lookup table from file extension to the
/// position of that language's rule set inside `LOADED_RULES`; the rules
/// themselves live in that shared static.
pub struct AstGrepEngine {
    /// File extension -> index into `LOADED_RULES`.
    extension_index: HashMap<&'static str, usize>,
}
33
34impl AstGrepEngine {
35 pub fn new() -> Self {
37 let mut extension_index = HashMap::new();
38 for (i, lr) in LOADED_RULES.iter().enumerate() {
39 for &ext in lr.extensions {
40 extension_index.insert(ext, i);
41 }
42 }
43 Self { extension_index }
44 }
45
46 pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
48 self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
49 }
50
51 pub fn supports_extension(&self, ext: &str) -> bool {
53 self.extension_index.contains_key(ext)
54 }
55
56 pub fn extract_symbols(
57 &self,
58 lang: &LanguageRules,
59 source: &str,
60 file_path: &str,
61 ) -> Vec<Symbol> {
62 let root = lang.lang.ast_grep(source);
63 self.extract_symbols_from_tree(lang, &root, source, file_path)
64 }
65
66 pub fn extract_symbols_from_tree(
68 &self,
69 lang: &LanguageRules,
70 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
71 source: &str,
72 file_path: &str,
73 ) -> Vec<Symbol> {
74 let root_node = root.root();
75 let mut symbols = Vec::new();
76 let mut scope = Vec::new();
77 self.extract_symbols_recursive(
78 lang,
79 &root_node,
80 source,
81 file_path,
82 &mut scope,
83 false,
84 &mut symbols,
85 );
86 symbols
87 }
88
89 pub fn extract_references(
92 &self,
93 lang: &LanguageRules,
94 source: &str,
95 file_path: &str,
96 ) -> Vec<Reference> {
97 let root = lang.lang.ast_grep(source);
98 self.extract_references_from_tree(lang, &root, source, file_path)
99 }
100
101 pub fn extract_references_from_tree(
103 &self,
104 lang: &LanguageRules,
105 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
106 source: &str,
107 file_path: &str,
108 ) -> Vec<Reference> {
109 let root_node = root.root();
110 let mut references = Vec::new();
111 let mut scope = Vec::new();
112 self.extract_references_recursive(
113 lang,
114 &root_node,
115 source,
116 file_path,
117 &mut scope,
118 &mut references,
119 );
120
121 let mut seen = HashSet::new();
123 references.retain(|r| {
124 seen.insert((
125 r.source_qualified_name.clone(),
126 r.target_name.clone(),
127 r.kind,
128 ))
129 });
130
131 references
132 }
133
134 #[allow(clippy::too_many_arguments)]
137 fn extract_symbols_recursive<D: Doc>(
138 &self,
139 lang: &LanguageRules,
140 node: &Node<'_, D>,
141 source: &str,
142 file_path: &str,
143 scope: &mut Vec<String>,
144 in_method_scope: bool,
145 symbols: &mut Vec<Symbol>,
146 ) where
147 D::Lang: ast_grep_core::Language,
148 {
149 let kind: Cow<'_, str> = node.kind();
150 let kind_str = kind.as_ref();
151
152 if lang.symbol_unwrap_set.contains(kind_str) {
154 for child in node.children() {
155 self.extract_symbols_recursive(
156 lang,
157 &child,
158 source,
159 file_path,
160 scope,
161 in_method_scope,
162 symbols,
163 );
164 }
165 return;
166 }
167
168 let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
170 if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
171 let sc = &lang.symbol_scope_containers[sc_idx];
172 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
173 scope.push(scope_name);
174 let new_method_scope = sc.is_method_scope;
175
176 if let Some(body) = self.get_scope_body(sc, node) {
178 for child in body.children() {
179 self.extract_symbols_recursive(
180 lang,
181 &child,
182 source,
183 file_path,
184 scope,
185 new_method_scope,
186 symbols,
187 );
188 }
189 } else {
190 for child in node.children() {
192 self.extract_symbols_recursive(
193 lang,
194 &child,
195 source,
196 file_path,
197 scope,
198 new_method_scope,
199 symbols,
200 );
201 }
202 }
203 scope.pop();
204 }
207 }
208
209 if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
211 for &rule_idx in rule_indices {
212 let rule = &lang.symbol_rules[rule_idx];
213
214 if let Some(ref special) = rule.special {
216 let multi = self
217 .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
218 if !multi.is_empty() {
219 symbols.extend(multi);
220 return;
221 }
222 }
223
224 if let Some(sym) = self.extract_symbol_from_rule(
225 lang,
226 rule,
227 node,
228 source,
229 file_path,
230 scope,
231 in_method_scope,
232 ) {
233 let name = sym.name.clone();
234 let is_scope = rule.is_scope;
235 symbols.push(sym);
236
237 if is_scope && !handled_as_scope_container {
240 scope.push(name);
241 if let Some(body) = node.field("body") {
242 for child in body.children() {
243 self.extract_symbols_recursive(
244 lang,
245 &child,
246 source,
247 file_path,
248 scope,
249 in_method_scope,
250 symbols,
251 );
252 }
253 }
254 scope.pop();
255 }
256 return; }
258 }
259 }
260
261 if handled_as_scope_container {
265 return; }
267
268 for child in node.children() {
270 self.extract_symbols_recursive(
271 lang,
272 &child,
273 source,
274 file_path,
275 scope,
276 in_method_scope,
277 symbols,
278 );
279 }
280 }
281
282 #[allow(clippy::too_many_arguments)]
283 fn extract_symbol_from_rule<D: Doc>(
284 &self,
285 lang: &LanguageRules,
286 rule: &SymbolRule,
287 node: &Node<'_, D>,
288 source: &str,
289 file_path: &str,
290 scope: &[String],
291 in_method_scope: bool,
292 ) -> Option<Symbol>
293 where
294 D::Lang: ast_grep_core::Language,
295 {
296 if let Some(ref special) = rule.special {
298 let mut sym =
299 self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
300 self.enrich_symbol_metadata(lang.name, node, &mut sym);
301 return Some(sym);
302 }
303
304 let name = self.get_node_field_text(node, &rule.name_field)?;
306 if name.is_empty() {
307 return None;
308 }
309
310 let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
312 let is_test = self.detect_test(lang.name, node, source, file_path, &name);
313 let kind = if is_test {
314 SymbolKind::Test
315 } else if rule.method_when_scoped && in_method_scope {
316 SymbolKind::Method
317 } else {
318 base_kind
319 };
320
321 let visibility = self.detect_visibility(lang.name, node, source, &name);
322 let signature = self.extract_signature(lang.name, node, source);
323 let doc_comment = self.extract_doc_comment(lang.name, node, source);
324
325 let mut sym = build_symbol(
326 name,
327 kind,
328 signature,
329 visibility,
330 doc_comment,
331 file_path,
332 node.start_pos().line(),
333 node.end_pos().line(),
334 scope,
335 lang.scope_separator,
336 );
337 self.enrich_symbol_metadata(lang.name, node, &mut sym);
338 Some(sym)
339 }
340
341 fn extract_references_recursive<D: Doc>(
344 &self,
345 lang: &LanguageRules,
346 node: &Node<'_, D>,
347 source: &str,
348 file_path: &str,
349 scope: &mut Vec<String>,
350 references: &mut Vec<Reference>,
351 ) where
352 D::Lang: ast_grep_core::Language,
353 {
354 let kind: Cow<'_, str> = node.kind();
355 let kind_str = kind.as_ref();
356
357 if lang.reference_unwrap_set.contains(kind_str) {
359 for child in node.children() {
360 self.extract_references_recursive(
361 lang, &child, source, file_path, scope, references,
362 );
363 }
364 return;
365 }
366
367 if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
369 let sc = &lang.reference_scope_containers[sc_idx];
370 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
371 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
373 for &rule_idx in rule_indices {
374 let rule = &lang.reference_rules[rule_idx];
375 self.extract_reference_from_rule(
376 lang, rule, node, source, file_path, scope, references,
377 );
378 }
379 }
380
381 scope.push(scope_name);
382
383 if let Some(body) = self.get_scope_body(sc, node) {
384 for child in body.children() {
385 self.extract_references_recursive(
386 lang, &child, source, file_path, scope, references,
387 );
388 }
389 } else {
390 for child in node.children() {
391 self.extract_references_recursive(
392 lang, &child, source, file_path, scope, references,
393 );
394 }
395 }
396 scope.pop();
397 return;
398 }
399 }
400
401 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
403 for &rule_idx in rule_indices {
404 let rule = &lang.reference_rules[rule_idx];
405 self.extract_reference_from_rule(
406 lang, rule, node, source, file_path, scope, references,
407 );
408 }
409 }
410
411 for child in node.children() {
413 self.extract_references_recursive(lang, &child, source, file_path, scope, references);
414 }
415 }
416
417 #[allow(clippy::too_many_arguments)]
418 fn extract_reference_from_rule<D: Doc>(
419 &self,
420 lang: &LanguageRules,
421 rule: &ReferenceRule,
422 node: &Node<'_, D>,
423 source: &str,
424 file_path: &str,
425 scope: &[String],
426 references: &mut Vec<Reference>,
427 ) where
428 D::Lang: ast_grep_core::Language,
429 {
430 if let Some(ref special) = rule.special {
432 self.handle_special_reference(
433 lang, special, node, source, file_path, scope, references,
434 );
435 return;
436 }
437
438 let ref_kind = match parse_reference_kind(&rule.reference_kind) {
439 Some(k) => k,
440 None => return,
441 };
442
443 let target_name = if let Some(ref field) = rule.name_field {
444 match self.get_node_field_text(node, field) {
445 Some(name) => name,
446 None => return,
447 }
448 } else {
449 let text = node.text();
451 text.trim().to_string()
452 };
453
454 if target_name.is_empty() {
455 return;
456 }
457
458 let source_qn = if scope.is_empty() {
459 file_path.to_string()
460 } else {
461 scope.join(lang.scope_separator)
462 };
463
464 push_ref(
465 references,
466 &source_qn,
467 target_name,
468 ref_kind,
469 file_path,
470 node.start_pos().line(),
471 );
472 }
473
474 pub(crate) fn get_node_field_text<D: Doc>(
477 &self,
478 node: &Node<'_, D>,
479 field_name: &str,
480 ) -> Option<String>
481 where
482 D::Lang: ast_grep_core::Language,
483 {
484 node.field(field_name).map(|n| n.text().to_string())
485 }
486
487 fn get_scope_name<D: Doc>(
488 &self,
489 lang: &LanguageRules,
490 sc: &ScopeContainerRule,
491 node: &Node<'_, D>,
492 source: &str,
493 ) -> Option<String>
494 where
495 D::Lang: ast_grep_core::Language,
496 {
497 if let Some(ref special) = sc.special {
498 return self.get_special_scope_name(lang, special, node, source);
499 }
500 self.get_node_field_text(node, &sc.name_field)
501 }
502
503 fn get_scope_body<'a, D: Doc>(
504 &self,
505 sc: &ScopeContainerRule,
506 node: &Node<'a, D>,
507 ) -> Option<Node<'a, D>>
508 where
509 D::Lang: ast_grep_core::Language,
510 {
511 if let Some(ref special) = sc.special {
512 return self.get_special_scope_body(special, node);
513 }
514 node.field(&sc.body_field)
515 }
516
517 fn get_special_scope_name<D: Doc>(
518 &self,
519 _lang: &LanguageRules,
520 special: &str,
521 node: &Node<'_, D>,
522 _source: &str,
523 ) -> Option<String>
524 where
525 D::Lang: ast_grep_core::Language,
526 {
527 match special {
528 "go_method_scope" => {
529 self.get_go_receiver_type(node)
531 }
532 "hcl_block_scope" => {
533 let mut parts = Vec::new();
535 for child in node.children() {
536 let ck = child.kind();
537 if ck.as_ref() == "identifier" && parts.is_empty() {
538 parts.push(child.text().to_string());
539 } else if ck.as_ref() == "string_lit" {
540 parts.push(child.text().to_string().trim_matches('"').to_string());
541 }
542 }
543 if parts.is_empty() {
544 None
545 } else {
546 Some(parts.join("."))
547 }
548 }
549 "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
550 for child in node.children() {
551 let ck = child.kind();
552 if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
553 return Some(child.text().to_string());
554 }
555 }
556 None
557 }),
558 "swift_class_scope" => {
559 self.get_node_field_text(node, "name").or_else(|| {
561 node.children()
562 .find(|c| {
563 let ck = c.kind();
564 ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
565 })
566 .map(|c| c.text().to_string())
567 })
568 }
569 "cpp_namespace_scope" => {
570 self.get_node_field_text(node, "name").or_else(|| {
572 node.children()
573 .find(|c| {
574 let ck = c.kind();
575 ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
576 })
577 .map(|c| c.text().to_string())
578 })
579 }
580 _ => None,
581 }
582 }
583
584 fn get_special_scope_body<'a, D: Doc>(
585 &self,
586 _special: &str,
587 node: &Node<'a, D>,
588 ) -> Option<Node<'a, D>>
589 where
590 D::Lang: ast_grep_core::Language,
591 {
592 node.field("body")
593 }
594}
595
596impl Default for AstGrepEngine {
597 fn default() -> Self {
598 Self::new()
599 }
600}
601
602#[allow(clippy::too_many_arguments)]
606fn build_symbol(
607 name: String,
608 kind: SymbolKind,
609 signature: String,
610 visibility: Visibility,
611 doc_comment: Option<String>,
612 file_path: &str,
613 line_start: usize,
614 line_end: usize,
615 scope: &[String],
616 scope_separator: &str,
617) -> Symbol {
618 Symbol {
619 qualified_name: build_qualified_name(scope, &name, scope_separator),
620 name,
621 kind,
622 signature,
623 visibility,
624 file_path: file_path.to_string(),
625 line_start,
626 line_end,
627 doc_comment,
628 parent: if scope.is_empty() {
629 None
630 } else {
631 Some(scope.join(scope_separator))
632 },
633 parameters: Vec::new(),
634 return_type: None,
635 is_async: false,
636 attributes: Vec::new(),
637 throws: Vec::new(),
638 generic_params: None,
639 is_abstract: false,
640 }
641}
642
643fn push_ref(
645 refs: &mut Vec<Reference>,
646 source_qn: &str,
647 target: String,
648 kind: ReferenceKind,
649 file_path: &str,
650 line: usize,
651) {
652 refs.push(Reference {
653 source_qualified_name: source_qn.to_string(),
654 target_name: target,
655 kind,
656 file_path: file_path.to_string(),
657 line,
658 });
659}
660
/// Join the enclosing `scope` path and `name` with `separator`.
///
/// With an empty scope the bare `name` is returned.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    match scope {
        [] => name.to_owned(),
        _ => {
            let mut qualified = scope.join(separator);
            qualified.push_str(separator);
            qualified.push_str(name);
            qualified
        }
    }
}
668
669fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
670 match s {
671 "function" => Some(SymbolKind::Function),
672 "method" => Some(SymbolKind::Method),
673 "class" => Some(SymbolKind::Class),
674 "struct" => Some(SymbolKind::Struct),
675 "enum" => Some(SymbolKind::Enum),
676 "interface" => Some(SymbolKind::Interface),
677 "type" => Some(SymbolKind::Type),
678 "constant" => Some(SymbolKind::Constant),
679 "module" => Some(SymbolKind::Module),
680 "test" => Some(SymbolKind::Test),
681 "field" => Some(SymbolKind::Field),
682 "constructor" => Some(SymbolKind::Constructor),
683 _ => None,
684 }
685}
686
687fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
688 match s {
689 "import" => Some(ReferenceKind::Import),
690 "call" => Some(ReferenceKind::Call),
691 "inherits" => Some(ReferenceKind::Inherits),
692 "implements" => Some(ReferenceKind::Implements),
693 "type_usage" => Some(ReferenceKind::TypeUsage),
694 _ => None,
695 }
696}
697
/// Strip the delimiters and the leading `*` gutter from a `/** ... */` comment.
///
/// The comment is first trimmed, then its opening `/**` and closing `*/` are
/// removed. Each remaining line is trimmed and loses one leading `"* "` (or a
/// bare `*`). Blank lines at either end of the result are dropped; interior
/// lines, including empty ones, are kept.
///
/// The degenerate empty comment `/**/`, whose opener and closer share the
/// middle `*`, yields an empty string.
fn clean_block_doc_comment(text: &str) -> String {
    let text = text.trim();

    // Strip `/*` and `*/` as a non-overlapping pair first, then the extra `*`
    // that marks a doc comment. Stripping `/**` directly would swallow the `*`
    // belonging to the closer in `/**/` and leave a stray `/` behind.
    let inner = match text.strip_prefix("/*").and_then(|t| t.strip_suffix("*/")) {
        Some(inner) => inner.strip_prefix('*').unwrap_or(inner),
        None => text.trim_start_matches("/**").trim_end_matches("*/"),
    };

    inner
        .trim()
        .lines()
        .map(|line| {
            let line = line.trim();
            line.strip_prefix("* ")
                .or_else(|| line.strip_prefix('*'))
                .unwrap_or(line)
                .trim_end()
        })
        .collect::<Vec<_>>()
        .join("\n")
        .trim()
        .to_string()
}
714
715#[cfg(test)]
716#[path = "../tests/engine_symbols_tests.rs"]
717mod engine_symbols_tests;
718
719#[cfg(test)]
720#[path = "../tests/engine_references_tests.rs"]
721mod engine_references_tests;
722
723#[cfg(test)]
724#[path = "../tests/engine_cross_cutting_tests.rs"]
725mod engine_cross_cutting_tests;