1mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
/// Shorthand for an ast-grep [`Node`] whose document type is `StrDoc<SupportLang>`.
pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
24static LOADED_RULES: LazyLock<Vec<LanguageRules>> = LazyLock::new(|| {
26 crate::index::rule_loader::load_all_rules()
27 .expect("embedded YAML rule files must deserialize successfully")
28});
29
/// Rule-driven extractor of symbols and references from parsed source trees.
pub struct AstGrepEngine {
    /// Maps a file extension to the index of its rule set in `LOADED_RULES`.
    extension_index: HashMap<&'static str, usize>,
}
35
36impl AstGrepEngine {
37 pub fn new() -> Self {
39 let mut extension_index = HashMap::new();
40 for (i, lr) in LOADED_RULES.iter().enumerate() {
41 for &ext in lr.extensions {
42 extension_index.insert(ext, i);
43 }
44 }
45 Self { extension_index }
46 }
47
48 pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
50 self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
51 }
52
/// Returns `true` when some loaded rule set is registered for the file
/// extension `ext`.
pub fn supports_extension(&self, ext: &str) -> bool {
    self.extension_index.contains_key(ext)
}
57
58 pub fn extract_symbols(
59 &self,
60 lang: &LanguageRules,
61 source: &str,
62 file_path: &str,
63 ) -> Vec<Symbol> {
64 let root = lang.lang.ast_grep(source);
65 self.extract_symbols_from_tree(lang, &root, source, file_path)
66 }
67
68 pub fn extract_symbols_from_tree(
70 &self,
71 lang: &LanguageRules,
72 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
73 source: &str,
74 file_path: &str,
75 ) -> Vec<Symbol> {
76 let root_node = root.root();
77 let mut symbols = Vec::new();
78 let mut scope = Vec::new();
79 self.extract_symbols_recursive(
80 lang,
81 &root_node,
82 source,
83 file_path,
84 &mut scope,
85 false,
86 &mut symbols,
87 );
88 symbols
89 }
90
91 pub fn extract_references(
94 &self,
95 lang: &LanguageRules,
96 source: &str,
97 file_path: &str,
98 ) -> Vec<Reference> {
99 let root = lang.lang.ast_grep(source);
100 self.extract_references_from_tree(lang, &root, source, file_path)
101 }
102
103 pub fn extract_references_from_tree(
105 &self,
106 lang: &LanguageRules,
107 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
108 source: &str,
109 file_path: &str,
110 ) -> Vec<Reference> {
111 let root_node = root.root();
112 let mut references = Vec::new();
113 let mut scope = Vec::new();
114 self.extract_references_recursive(
115 lang,
116 &root_node,
117 source,
118 file_path,
119 &mut scope,
120 &mut references,
121 );
122
123 let mut seen = HashSet::new();
125 references.retain(|r| {
126 seen.insert((
127 r.source_qualified_name.clone(),
128 r.target_name.clone(),
129 r.kind,
130 ))
131 });
132
133 references
134 }
135
136 #[allow(clippy::too_many_arguments)]
139 fn extract_symbols_recursive<D: Doc>(
140 &self,
141 lang: &LanguageRules,
142 node: &Node<'_, D>,
143 source: &str,
144 file_path: &str,
145 scope: &mut Vec<String>,
146 in_method_scope: bool,
147 symbols: &mut Vec<Symbol>,
148 ) where
149 D::Lang: ast_grep_core::Language,
150 {
151 let kind: Cow<'_, str> = node.kind();
152 let kind_str = kind.as_ref();
153
154 if lang.symbol_unwrap_set.contains(kind_str) {
156 for child in node.children() {
157 self.extract_symbols_recursive(
158 lang,
159 &child,
160 source,
161 file_path,
162 scope,
163 in_method_scope,
164 symbols,
165 );
166 }
167 return;
168 }
169
170 let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
172 if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
173 let sc = &lang.symbol_scope_containers[sc_idx];
174 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
175 scope.push(scope_name);
176 let new_method_scope = sc.is_method_scope;
177
178 if let Some(body) = self.get_scope_body(sc, node) {
180 for child in body.children() {
181 self.extract_symbols_recursive(
182 lang,
183 &child,
184 source,
185 file_path,
186 scope,
187 new_method_scope,
188 symbols,
189 );
190 }
191 } else {
192 for child in node.children() {
194 self.extract_symbols_recursive(
195 lang,
196 &child,
197 source,
198 file_path,
199 scope,
200 new_method_scope,
201 symbols,
202 );
203 }
204 }
205 scope.pop();
206 }
209 }
210
211 if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
213 for &rule_idx in rule_indices {
214 let rule = &lang.symbol_rules[rule_idx];
215
216 if let Some(ref special) = rule.special {
218 let multi = self
219 .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
220 if !multi.is_empty() {
221 symbols.extend(multi);
222 return;
223 }
224 }
225
226 if let Some(sym) = self.extract_symbol_from_rule(
227 lang,
228 rule,
229 node,
230 source,
231 file_path,
232 scope,
233 in_method_scope,
234 ) {
235 let name = sym.name.clone();
236 let is_scope = rule.is_scope;
237 symbols.push(sym);
238
239 if is_scope && !handled_as_scope_container {
242 scope.push(name);
243 if let Some(body) = node.field("body") {
244 for child in body.children() {
245 self.extract_symbols_recursive(
246 lang,
247 &child,
248 source,
249 file_path,
250 scope,
251 in_method_scope,
252 symbols,
253 );
254 }
255 }
256 scope.pop();
257 }
258 return; }
260 }
261 }
262
263 if handled_as_scope_container {
267 return; }
269
270 for child in node.children() {
272 self.extract_symbols_recursive(
273 lang,
274 &child,
275 source,
276 file_path,
277 scope,
278 in_method_scope,
279 symbols,
280 );
281 }
282 }
283
284 #[allow(clippy::too_many_arguments)]
285 fn extract_symbol_from_rule<D: Doc>(
286 &self,
287 lang: &LanguageRules,
288 rule: &SymbolRule,
289 node: &Node<'_, D>,
290 source: &str,
291 file_path: &str,
292 scope: &[String],
293 in_method_scope: bool,
294 ) -> Option<Symbol>
295 where
296 D::Lang: ast_grep_core::Language,
297 {
298 if let Some(ref special) = rule.special {
300 let mut sym =
301 self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
302 self.enrich_symbol_metadata(lang.name, node, &mut sym);
303 return Some(sym);
304 }
305
306 let name = self.get_node_field_text(node, &rule.name_field)?;
308 if name.is_empty() {
309 return None;
310 }
311
312 let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
314 let is_test = self.detect_test(lang.name, node, source, file_path, &name);
315 let kind = if is_test {
316 SymbolKind::Test
317 } else if rule.method_when_scoped && in_method_scope {
318 SymbolKind::Method
319 } else {
320 base_kind
321 };
322
323 let visibility = self.detect_visibility(lang.name, node, source, &name);
324 let signature = self.extract_signature(lang.name, node, source);
325 let doc_comment = self.extract_doc_comment(lang.name, node, source);
326
327 let mut sym = build_symbol(
328 name,
329 kind,
330 signature,
331 visibility,
332 doc_comment,
333 file_path,
334 node.start_pos().line(),
335 node.end_pos().line(),
336 scope,
337 lang.scope_separator,
338 );
339 self.enrich_symbol_metadata(lang.name, node, &mut sym);
340 Some(sym)
341 }
342
343 fn extract_references_recursive<D: Doc>(
346 &self,
347 lang: &LanguageRules,
348 node: &Node<'_, D>,
349 source: &str,
350 file_path: &str,
351 scope: &mut Vec<String>,
352 references: &mut Vec<Reference>,
353 ) where
354 D::Lang: ast_grep_core::Language,
355 {
356 let kind: Cow<'_, str> = node.kind();
357 let kind_str = kind.as_ref();
358
359 if lang.reference_unwrap_set.contains(kind_str) {
361 for child in node.children() {
362 self.extract_references_recursive(
363 lang, &child, source, file_path, scope, references,
364 );
365 }
366 return;
367 }
368
369 if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
371 let sc = &lang.reference_scope_containers[sc_idx];
372 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
373 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
375 for &rule_idx in rule_indices {
376 let rule = &lang.reference_rules[rule_idx];
377 self.extract_reference_from_rule(
378 lang, rule, node, source, file_path, scope, references,
379 );
380 }
381 }
382
383 scope.push(scope_name);
384
385 if let Some(body) = self.get_scope_body(sc, node) {
386 for child in body.children() {
387 self.extract_references_recursive(
388 lang, &child, source, file_path, scope, references,
389 );
390 }
391 } else {
392 for child in node.children() {
393 self.extract_references_recursive(
394 lang, &child, source, file_path, scope, references,
395 );
396 }
397 }
398 scope.pop();
399 return;
400 }
401 }
402
403 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
405 for &rule_idx in rule_indices {
406 let rule = &lang.reference_rules[rule_idx];
407 self.extract_reference_from_rule(
408 lang, rule, node, source, file_path, scope, references,
409 );
410 }
411 }
412
413 for child in node.children() {
415 self.extract_references_recursive(lang, &child, source, file_path, scope, references);
416 }
417 }
418
419 #[allow(clippy::too_many_arguments)]
420 fn extract_reference_from_rule<D: Doc>(
421 &self,
422 lang: &LanguageRules,
423 rule: &ReferenceRule,
424 node: &Node<'_, D>,
425 source: &str,
426 file_path: &str,
427 scope: &[String],
428 references: &mut Vec<Reference>,
429 ) where
430 D::Lang: ast_grep_core::Language,
431 {
432 if let Some(ref special) = rule.special {
434 self.handle_special_reference(
435 lang, special, node, source, file_path, scope, references,
436 );
437 return;
438 }
439
440 let ref_kind = match parse_reference_kind(&rule.reference_kind) {
441 Some(k) => k,
442 None => return,
443 };
444
445 let target_name = if let Some(ref field) = rule.name_field {
446 match self.get_node_field_text(node, field) {
447 Some(name) => name,
448 None => return,
449 }
450 } else {
451 let text = node.text();
453 text.trim().to_string()
454 };
455
456 if target_name.is_empty() {
457 return;
458 }
459
460 let source_qn = if scope.is_empty() {
461 file_path.to_string()
462 } else {
463 scope.join(lang.scope_separator)
464 };
465
466 push_ref(
467 references,
468 &source_qn,
469 target_name,
470 ref_kind,
471 file_path,
472 node.start_pos().line(),
473 );
474 }
475
476 pub(crate) fn get_node_field_text<D: Doc>(
479 &self,
480 node: &Node<'_, D>,
481 field_name: &str,
482 ) -> Option<String>
483 where
484 D::Lang: ast_grep_core::Language,
485 {
486 node.field(field_name).map(|n| n.text().to_string())
487 }
488
489 fn get_scope_name<D: Doc>(
490 &self,
491 lang: &LanguageRules,
492 sc: &ScopeContainerRule,
493 node: &Node<'_, D>,
494 source: &str,
495 ) -> Option<String>
496 where
497 D::Lang: ast_grep_core::Language,
498 {
499 if let Some(ref special) = sc.special {
500 return self.get_special_scope_name(lang, special, node, source);
501 }
502 self.get_node_field_text(node, &sc.name_field)
503 }
504
505 fn get_scope_body<'a, D: Doc>(
506 &self,
507 sc: &ScopeContainerRule,
508 node: &Node<'a, D>,
509 ) -> Option<Node<'a, D>>
510 where
511 D::Lang: ast_grep_core::Language,
512 {
513 if let Some(ref special) = sc.special {
514 return self.get_special_scope_body(special, node);
515 }
516 node.field(&sc.body_field)
517 }
518
519 fn get_special_scope_name<D: Doc>(
520 &self,
521 _lang: &LanguageRules,
522 special: &str,
523 node: &Node<'_, D>,
524 _source: &str,
525 ) -> Option<String>
526 where
527 D::Lang: ast_grep_core::Language,
528 {
529 match special {
530 "go_method_scope" => {
531 self.get_go_receiver_type(node)
533 }
534 "hcl_block_scope" => {
535 let mut parts = Vec::new();
537 for child in node.children() {
538 let ck = child.kind();
539 if ck.as_ref() == "identifier" && parts.is_empty() {
540 parts.push(child.text().to_string());
541 } else if ck.as_ref() == "string_lit" {
542 parts.push(child.text().to_string().trim_matches('"').to_string());
543 }
544 }
545 if parts.is_empty() {
546 None
547 } else {
548 Some(parts.join("."))
549 }
550 }
551 "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
552 for child in node.children() {
553 let ck = child.kind();
554 if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
555 return Some(child.text().to_string());
556 }
557 }
558 None
559 }),
560 "swift_class_scope" => {
561 self.get_node_field_text(node, "name").or_else(|| {
563 node.children()
564 .find(|c| {
565 let ck = c.kind();
566 ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
567 })
568 .map(|c| c.text().to_string())
569 })
570 }
571 "cpp_namespace_scope" => {
572 self.get_node_field_text(node, "name").or_else(|| {
574 node.children()
575 .find(|c| {
576 let ck = c.kind();
577 ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
578 })
579 .map(|c| c.text().to_string())
580 })
581 }
582 _ => None,
583 }
584 }
585
/// Returns the body node for a scope container that uses a special handler.
///
/// The handler key is currently ignored: every special container reads its
/// body from the `body` field.
fn get_special_scope_body<'a, D: Doc>(
    &self,
    _special: &str,
    node: &Node<'a, D>,
) -> Option<Node<'a, D>>
where
    D::Lang: ast_grep_core::Language,
{
    node.field("body")
}
596}
597
/// The default engine is the one produced by [`AstGrepEngine::new`].
impl Default for AstGrepEngine {
    fn default() -> Self {
        Self::new()
    }
}
603
604#[allow(clippy::too_many_arguments)]
608fn build_symbol(
609 name: String,
610 kind: SymbolKind,
611 signature: String,
612 visibility: Visibility,
613 doc_comment: Option<String>,
614 file_path: &str,
615 line_start: usize,
616 line_end: usize,
617 scope: &[String],
618 scope_separator: &str,
619) -> Symbol {
620 Symbol {
621 qualified_name: build_qualified_name(scope, &name, scope_separator),
622 name,
623 kind,
624 signature,
625 visibility,
626 file_path: file_path.to_string(),
627 line_start,
628 line_end,
629 doc_comment,
630 parent: if scope.is_empty() {
631 None
632 } else {
633 Some(scope.join(scope_separator))
634 },
635 parameters: Vec::new(),
636 return_type: None,
637 is_async: false,
638 attributes: Vec::new(),
639 throws: Vec::new(),
640 generic_params: None,
641 is_abstract: false,
642 }
643}
644
645fn push_ref(
647 refs: &mut Vec<Reference>,
648 source_qn: &str,
649 target: String,
650 kind: ReferenceKind,
651 file_path: &str,
652 line: usize,
653) {
654 refs.push(Reference {
655 source_qualified_name: source_qn.to_string(),
656 target_name: target,
657 kind,
658 file_path: file_path.to_string(),
659 line,
660 });
661}
662
/// Joins the enclosing `scope` segments and `name` with `separator`.
///
/// With an empty scope the result is just `name`.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    let mut segments: Vec<&str> = scope.iter().map(String::as_str).collect();
    segments.push(name);
    segments.join(separator)
}
670
671fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
672 match s {
673 "function" => Some(SymbolKind::Function),
674 "method" => Some(SymbolKind::Method),
675 "class" => Some(SymbolKind::Class),
676 "struct" => Some(SymbolKind::Struct),
677 "enum" => Some(SymbolKind::Enum),
678 "interface" => Some(SymbolKind::Interface),
679 "type" => Some(SymbolKind::Type),
680 "constant" => Some(SymbolKind::Constant),
681 "module" => Some(SymbolKind::Module),
682 "test" => Some(SymbolKind::Test),
683 "field" => Some(SymbolKind::Field),
684 "constructor" => Some(SymbolKind::Constructor),
685 _ => None,
686 }
687}
688
689fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
690 match s {
691 "import" => Some(ReferenceKind::Import),
692 "call" => Some(ReferenceKind::Call),
693 "inherits" => Some(ReferenceKind::Inherits),
694 "implements" => Some(ReferenceKind::Implements),
695 "type_usage" => Some(ReferenceKind::TypeUsage),
696 _ => None,
697 }
698}
699
/// Strips the `/** … */` delimiters and per-line `*` gutter from a block doc
/// comment and returns the remaining text.
///
/// Surrounding whitespace is removed *before* the delimiters are stripped, so
/// a comment captured with leading indentation or a trailing newline is still
/// unwrapped. Each line then loses its leading whitespace and one gutter
/// marker (`"* "`, or a bare `"*"`); indentation after the gutter is kept so
/// embedded code samples retain their shape. Lines are re-joined with `\n`
/// and the result is trimmed.
fn clean_block_doc_comment(text: &str) -> String {
    let inner = text
        .trim()
        .trim_start_matches("/**")
        .trim_end_matches("*/")
        .trim();

    let cleaned: Vec<&str> = inner
        .lines()
        .map(|raw| {
            let line = raw.trim();
            line.strip_prefix("* ")
                .or_else(|| line.strip_prefix('*'))
                .unwrap_or(line)
                .trim_end()
        })
        .collect();

    cleaned.join("\n").trim().to_string()
}
716
717#[cfg(test)]
718#[path = "../tests/engine_symbols_tests.rs"]
719mod engine_symbols_tests;
720
721#[cfg(test)]
722#[path = "../tests/engine_references_tests.rs"]
723mod engine_references_tests;
724
725#[cfg(test)]
726#[path = "../tests/engine_cross_cutting_tests.rs"]
727mod engine_cross_cutting_tests;