1mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
21pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
24static LOADED_RULES: LazyLock<Vec<LanguageRules>> = LazyLock::new(|| {
26 crate::index::rule_loader::load_all_rules()
27 .expect("embedded YAML rule files must deserialize successfully")
28});
29
/// Extracts symbols and references from source code using the rule sets held
/// in `LOADED_RULES`.
pub struct AstGrepEngine {
    /// File extension -> index of the owning rule set inside `LOADED_RULES`.
    extension_index: HashMap<&'static str, usize>,
}
35
36impl AstGrepEngine {
37 pub fn new() -> Self {
39 let mut extension_index = HashMap::new();
40 for (i, lr) in LOADED_RULES.iter().enumerate() {
41 for &ext in lr.extensions {
42 extension_index.insert(ext, i);
43 }
44 }
45 Self { extension_index }
46 }
47
48 pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
50 self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
51 }
52
53 pub fn supports_extension(&self, ext: &str) -> bool {
55 self.extension_index.contains_key(ext)
56 }
57
58 pub fn extract_symbols(
59 &self,
60 lang: &LanguageRules,
61 source: &str,
62 file_path: &str,
63 ) -> Vec<Symbol> {
64 let root = lang.lang.ast_grep(source);
65 self.extract_symbols_from_tree(lang, &root, source, file_path)
66 }
67
68 pub fn extract_symbols_from_tree(
70 &self,
71 lang: &LanguageRules,
72 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
73 source: &str,
74 file_path: &str,
75 ) -> Vec<Symbol> {
76 let root_node = root.root();
77 let mut symbols = Vec::new();
78 let mut scope = Vec::new();
79 self.extract_symbols_recursive(
80 lang,
81 &root_node,
82 source,
83 file_path,
84 &mut scope,
85 false,
86 &mut symbols,
87 );
88 symbols
89 }
90
91 pub fn extract_references(
94 &self,
95 lang: &LanguageRules,
96 source: &str,
97 file_path: &str,
98 ) -> Vec<Reference> {
99 let root = lang.lang.ast_grep(source);
100 self.extract_references_from_tree(lang, &root, source, file_path)
101 }
102
103 pub fn extract_references_from_tree(
105 &self,
106 lang: &LanguageRules,
107 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
108 source: &str,
109 file_path: &str,
110 ) -> Vec<Reference> {
111 let root_node = root.root();
112 let mut references = Vec::new();
113 let mut scope = Vec::new();
114 self.extract_references_recursive(
115 lang,
116 &root_node,
117 source,
118 file_path,
119 &mut scope,
120 &mut references,
121 );
122
123 let mut seen = HashSet::new();
125 references.retain(|r| {
126 seen.insert((
127 r.source_qualified_name.clone(),
128 r.target_name.clone(),
129 r.kind,
130 ))
131 });
132
133 references.retain(|r| {
135 if !matches!(r.kind, ReferenceKind::Call | ReferenceKind::Callback) {
136 return true; }
138 let simple = r
139 .target_name
140 .rsplit(lang.scope_separator)
141 .next()
142 .unwrap_or(&r.target_name);
143 !crate::index::blocklist::is_blocked_call(lang.name, simple)
144 });
145
146 references
147 }
148
149 #[allow(clippy::too_many_arguments)]
152 fn extract_symbols_recursive<D: Doc>(
153 &self,
154 lang: &LanguageRules,
155 node: &Node<'_, D>,
156 source: &str,
157 file_path: &str,
158 scope: &mut Vec<String>,
159 in_method_scope: bool,
160 symbols: &mut Vec<Symbol>,
161 ) where
162 D::Lang: ast_grep_core::Language,
163 {
164 let kind: Cow<'_, str> = node.kind();
165 let kind_str = kind.as_ref();
166
167 if lang.symbol_unwrap_set.contains(kind_str) {
169 for child in node.children() {
170 self.extract_symbols_recursive(
171 lang,
172 &child,
173 source,
174 file_path,
175 scope,
176 in_method_scope,
177 symbols,
178 );
179 }
180 return;
181 }
182
183 let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
185 if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
186 let sc = &lang.symbol_scope_containers[sc_idx];
187 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
188 scope.push(scope_name);
189 let new_method_scope = sc.is_method_scope;
190
191 if let Some(body) = self.get_scope_body(sc, node) {
193 for child in body.children() {
194 self.extract_symbols_recursive(
195 lang,
196 &child,
197 source,
198 file_path,
199 scope,
200 new_method_scope,
201 symbols,
202 );
203 }
204 } else {
205 for child in node.children() {
207 self.extract_symbols_recursive(
208 lang,
209 &child,
210 source,
211 file_path,
212 scope,
213 new_method_scope,
214 symbols,
215 );
216 }
217 }
218 scope.pop();
219 }
222 }
223
224 if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
226 for &rule_idx in rule_indices {
227 let rule = &lang.symbol_rules[rule_idx];
228
229 if let Some(ref special) = rule.special {
231 let multi = self
232 .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
233 if !multi.is_empty() {
234 symbols.extend(multi);
235 return;
236 }
237 }
238
239 if let Some(sym) = self.extract_symbol_from_rule(
240 lang,
241 rule,
242 node,
243 source,
244 file_path,
245 scope,
246 in_method_scope,
247 ) {
248 let name = sym.name.clone();
249 let is_scope = rule.is_scope;
250 symbols.push(sym);
251
252 if is_scope && !handled_as_scope_container {
255 scope.push(name);
256 if let Some(body) = node.field("body") {
257 for child in body.children() {
258 self.extract_symbols_recursive(
259 lang,
260 &child,
261 source,
262 file_path,
263 scope,
264 in_method_scope,
265 symbols,
266 );
267 }
268 }
269 scope.pop();
270 }
271 return; }
273 }
274 }
275
276 if handled_as_scope_container {
280 return; }
282
283 for child in node.children() {
285 self.extract_symbols_recursive(
286 lang,
287 &child,
288 source,
289 file_path,
290 scope,
291 in_method_scope,
292 symbols,
293 );
294 }
295 }
296
297 #[allow(clippy::too_many_arguments)]
298 fn extract_symbol_from_rule<D: Doc>(
299 &self,
300 lang: &LanguageRules,
301 rule: &SymbolRule,
302 node: &Node<'_, D>,
303 source: &str,
304 file_path: &str,
305 scope: &[String],
306 in_method_scope: bool,
307 ) -> Option<Symbol>
308 where
309 D::Lang: ast_grep_core::Language,
310 {
311 if let Some(ref special) = rule.special {
313 let mut sym =
314 self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
315 self.enrich_symbol_metadata(lang.name, node, &mut sym);
316 return Some(sym);
317 }
318
319 let name = self.get_node_field_text(node, &rule.name_field)?;
321 if name.is_empty() {
322 return None;
323 }
324
325 let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
327 let is_test = self.detect_test(lang.name, node, source, file_path, &name);
328 let kind = if is_test {
329 SymbolKind::Test
330 } else if rule.method_when_scoped && in_method_scope {
331 SymbolKind::Method
332 } else {
333 base_kind
334 };
335
336 let visibility = self.detect_visibility(lang.name, node, source, &name);
337 let signature = self.extract_signature(lang.name, node, source);
338 let doc_comment = self.extract_doc_comment(lang.name, node, source);
339
340 let mut sym = build_symbol(
341 name,
342 kind,
343 signature,
344 visibility,
345 doc_comment,
346 file_path,
347 node.start_pos().line(),
348 node.end_pos().line(),
349 scope,
350 lang.scope_separator,
351 );
352 self.enrich_symbol_metadata(lang.name, node, &mut sym);
353 Some(sym)
354 }
355
356 fn extract_references_recursive<D: Doc>(
359 &self,
360 lang: &LanguageRules,
361 node: &Node<'_, D>,
362 source: &str,
363 file_path: &str,
364 scope: &mut Vec<String>,
365 references: &mut Vec<Reference>,
366 ) where
367 D::Lang: ast_grep_core::Language,
368 {
369 let kind: Cow<'_, str> = node.kind();
370 let kind_str = kind.as_ref();
371
372 if lang.reference_unwrap_set.contains(kind_str) {
374 for child in node.children() {
375 self.extract_references_recursive(
376 lang, &child, source, file_path, scope, references,
377 );
378 }
379 return;
380 }
381
382 if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
384 let sc = &lang.reference_scope_containers[sc_idx];
385 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
386 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
388 for &rule_idx in rule_indices {
389 let rule = &lang.reference_rules[rule_idx];
390 self.extract_reference_from_rule(
391 lang, rule, node, source, file_path, scope, references,
392 );
393 }
394 }
395
396 scope.push(scope_name);
397
398 if let Some(body) = self.get_scope_body(sc, node) {
399 for child in body.children() {
400 self.extract_references_recursive(
401 lang, &child, source, file_path, scope, references,
402 );
403 }
404 } else {
405 for child in node.children() {
406 self.extract_references_recursive(
407 lang, &child, source, file_path, scope, references,
408 );
409 }
410 }
411 scope.pop();
412 return;
413 }
414 }
415
416 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
418 for &rule_idx in rule_indices {
419 let rule = &lang.reference_rules[rule_idx];
420 self.extract_reference_from_rule(
421 lang, rule, node, source, file_path, scope, references,
422 );
423 }
424 }
425
426 for child in node.children() {
428 self.extract_references_recursive(lang, &child, source, file_path, scope, references);
429 }
430 }
431
432 #[allow(clippy::too_many_arguments)]
433 fn extract_reference_from_rule<D: Doc>(
434 &self,
435 lang: &LanguageRules,
436 rule: &ReferenceRule,
437 node: &Node<'_, D>,
438 source: &str,
439 file_path: &str,
440 scope: &[String],
441 references: &mut Vec<Reference>,
442 ) where
443 D::Lang: ast_grep_core::Language,
444 {
445 if let Some(ref special) = rule.special {
447 self.handle_special_reference(
448 lang, special, node, source, file_path, scope, references,
449 );
450 return;
451 }
452
453 let ref_kind = match parse_reference_kind(&rule.reference_kind) {
454 Some(k) => k,
455 None => return,
456 };
457
458 let target_name = if let Some(ref field) = rule.name_field {
459 match self.get_node_field_text(node, field) {
460 Some(name) => name,
461 None => return,
462 }
463 } else {
464 let text = node.text();
466 text.trim().to_string()
467 };
468
469 if target_name.is_empty() {
470 return;
471 }
472
473 let source_qn = if scope.is_empty() {
474 file_path.to_string()
475 } else {
476 scope.join(lang.scope_separator)
477 };
478
479 push_ref(
480 references,
481 &source_qn,
482 target_name,
483 ref_kind,
484 file_path,
485 node.start_pos().line(),
486 );
487 }
488
489 pub(crate) fn get_node_field_text<D: Doc>(
492 &self,
493 node: &Node<'_, D>,
494 field_name: &str,
495 ) -> Option<String>
496 where
497 D::Lang: ast_grep_core::Language,
498 {
499 node.field(field_name).map(|n| n.text().to_string())
500 }
501
502 fn get_scope_name<D: Doc>(
503 &self,
504 lang: &LanguageRules,
505 sc: &ScopeContainerRule,
506 node: &Node<'_, D>,
507 source: &str,
508 ) -> Option<String>
509 where
510 D::Lang: ast_grep_core::Language,
511 {
512 if let Some(ref special) = sc.special {
513 return self.get_special_scope_name(lang, special, node, source);
514 }
515 self.get_node_field_text(node, &sc.name_field)
516 }
517
518 fn get_scope_body<'a, D: Doc>(
519 &self,
520 sc: &ScopeContainerRule,
521 node: &Node<'a, D>,
522 ) -> Option<Node<'a, D>>
523 where
524 D::Lang: ast_grep_core::Language,
525 {
526 if let Some(ref special) = sc.special {
527 return self.get_special_scope_body(special, node);
528 }
529 node.field(&sc.body_field)
530 }
531
532 fn get_special_scope_name<D: Doc>(
533 &self,
534 _lang: &LanguageRules,
535 special: &str,
536 node: &Node<'_, D>,
537 _source: &str,
538 ) -> Option<String>
539 where
540 D::Lang: ast_grep_core::Language,
541 {
542 match special {
543 "go_method_scope" => {
544 self.get_go_receiver_type(node)
546 }
547 "hcl_block_scope" => {
548 let mut parts = Vec::new();
550 for child in node.children() {
551 let ck = child.kind();
552 if ck.as_ref() == "identifier" && parts.is_empty() {
553 parts.push(child.text().to_string());
554 } else if ck.as_ref() == "string_lit" {
555 parts.push(child.text().to_string().trim_matches('"').to_string());
556 }
557 }
558 if parts.is_empty() {
559 None
560 } else {
561 Some(parts.join("."))
562 }
563 }
564 "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
565 for child in node.children() {
566 let ck = child.kind();
567 if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
568 return Some(child.text().to_string());
569 }
570 }
571 None
572 }),
573 "swift_class_scope" => {
574 self.get_node_field_text(node, "name").or_else(|| {
576 node.children()
577 .find(|c| {
578 let ck = c.kind();
579 ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
580 })
581 .map(|c| c.text().to_string())
582 })
583 }
584 "cpp_namespace_scope" => {
585 self.get_node_field_text(node, "name").or_else(|| {
587 node.children()
588 .find(|c| {
589 let ck = c.kind();
590 ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
591 })
592 .map(|c| c.text().to_string())
593 })
594 }
595 _ => None,
596 }
597 }
598
599 fn get_special_scope_body<'a, D: Doc>(
600 &self,
601 _special: &str,
602 node: &Node<'a, D>,
603 ) -> Option<Node<'a, D>>
604 where
605 D::Lang: ast_grep_core::Language,
606 {
607 node.field("body")
608 }
609}
610
611impl Default for AstGrepEngine {
612 fn default() -> Self {
613 Self::new()
614 }
615}
616
617#[allow(clippy::too_many_arguments)]
621fn build_symbol(
622 name: String,
623 kind: SymbolKind,
624 signature: String,
625 visibility: Visibility,
626 doc_comment: Option<String>,
627 file_path: &str,
628 line_start: usize,
629 line_end: usize,
630 scope: &[String],
631 scope_separator: &str,
632) -> Symbol {
633 Symbol {
634 qualified_name: build_qualified_name(scope, &name, scope_separator),
635 name,
636 kind,
637 signature,
638 visibility,
639 file_path: file_path.to_string(),
640 line_start,
641 line_end,
642 doc_comment,
643 parent: if scope.is_empty() {
644 None
645 } else {
646 Some(scope.join(scope_separator))
647 },
648 parameters: Vec::new(),
649 return_type: None,
650 is_async: false,
651 attributes: Vec::new(),
652 throws: Vec::new(),
653 generic_params: None,
654 is_abstract: false,
655 }
656}
657
658fn push_ref(
660 refs: &mut Vec<Reference>,
661 source_qn: &str,
662 target: String,
663 kind: ReferenceKind,
664 file_path: &str,
665 line: usize,
666) {
667 refs.push(Reference {
668 source_qualified_name: source_qn.to_string(),
669 target_name: target,
670 kind,
671 file_path: file_path.to_string(),
672 line,
673 });
674}
675
/// Joins the enclosing `scope` segments and `name` with `separator`.
///
/// With an empty `scope` the result is just `name`.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    let segments: Vec<&str> = scope
        .iter()
        .map(String::as_str)
        .chain(std::iter::once(name))
        .collect();
    segments.join(separator)
}
683
684fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
685 match s {
686 "function" => Some(SymbolKind::Function),
687 "method" => Some(SymbolKind::Method),
688 "class" => Some(SymbolKind::Class),
689 "struct" => Some(SymbolKind::Struct),
690 "enum" => Some(SymbolKind::Enum),
691 "interface" => Some(SymbolKind::Interface),
692 "type" => Some(SymbolKind::Type),
693 "constant" => Some(SymbolKind::Constant),
694 "module" => Some(SymbolKind::Module),
695 "test" => Some(SymbolKind::Test),
696 "field" => Some(SymbolKind::Field),
697 "constructor" => Some(SymbolKind::Constructor),
698 _ => None,
699 }
700}
701
702fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
703 match s {
704 "import" => Some(ReferenceKind::Import),
705 "call" => Some(ReferenceKind::Call),
706 "callback" => Some(ReferenceKind::Callback),
707 "inherits" => Some(ReferenceKind::Inherits),
708 "implements" => Some(ReferenceKind::Implements),
709 "type_usage" => Some(ReferenceKind::TypeUsage),
710 _ => None,
711 }
712}
713
/// Strips block-comment decoration from a `/** ... */` doc comment.
///
/// Leading `/**` and trailing `*/` markers are removed, then each line is
/// trimmed and loses one leading `* ` (or bare `*`). The lines are rejoined
/// with `\n` and the result is trimmed as a whole.
fn clean_block_doc_comment(text: &str) -> String {
    let inner = text.trim_start_matches("/**").trim_end_matches("*/").trim();

    let cleaned: Vec<&str> = inner
        .lines()
        .map(|raw| {
            let raw = raw.trim();
            let unprefixed = raw
                .strip_prefix("* ")
                .or_else(|| raw.strip_prefix('*'))
                .unwrap_or(raw);
            unprefixed.trim_end()
        })
        .collect();

    cleaned.join("\n").trim().to_string()
}
730
731#[cfg(test)]
732#[path = "../tests/engine_symbols_tests.rs"]
733mod engine_symbols_tests;
734
735#[cfg(test)]
736#[path = "../tests/engine_references_tests.rs"]
737mod engine_references_tests;
738
739#[cfg(test)]
740#[path = "../tests/engine_cross_cutting_tests.rs"]
741mod engine_cross_cutting_tests;