1mod references;
8mod symbols;
9mod visibility;
10
11use crate::index::rule_loader::{LanguageRules, ReferenceRule, ScopeContainerRule, SymbolRule};
12use crate::index::symbol::{Reference, ReferenceKind, Symbol, SymbolKind, Visibility};
13use ast_grep_core::tree_sitter::LanguageExt;
14use ast_grep_core::tree_sitter::StrDoc;
15use ast_grep_core::{Doc, Node};
16use ast_grep_language::SupportLang;
17use std::borrow::Cow;
18use std::collections::{HashMap, HashSet};
19use std::sync::LazyLock;
20
21pub type SgNode<'r> = Node<'r, StrDoc<SupportLang>>;
23
24static LOADED_RULES: LazyLock<Vec<LanguageRules>> =
26 LazyLock::new(crate::index::rule_loader::load_all_rules);
27
/// Rule-driven extractor of symbols and references from parsed source trees.
pub struct AstGrepEngine {
    /// File extension -> index of the matching entry in `LOADED_RULES`.
    extension_index: HashMap<&'static str, usize>,
}
33
34impl AstGrepEngine {
35 pub fn new() -> Self {
37 let mut extension_index = HashMap::new();
38 for (i, lr) in LOADED_RULES.iter().enumerate() {
39 for &ext in lr.extensions {
40 extension_index.insert(ext, i);
41 }
42 }
43 Self { extension_index }
44 }
45
46 pub fn find_language(&self, ext: &str) -> Option<&LanguageRules> {
48 self.extension_index.get(ext).map(|&idx| &LOADED_RULES[idx])
49 }
50
51 pub fn supported_extensions(&self) -> Vec<&str> {
53 LOADED_RULES
54 .iter()
55 .flat_map(|lr| lr.extensions.iter().copied())
56 .collect()
57 }
58
59 pub fn supports_extension(&self, ext: &str) -> bool {
61 self.extension_index.contains_key(ext)
62 }
63
64 pub fn language_name(&self, ext: &str) -> Option<&str> {
66 self.find_language(ext).map(|lr| lr.name)
67 }
68
69 pub fn extract_symbols(
70 &self,
71 lang: &LanguageRules,
72 source: &str,
73 file_path: &str,
74 ) -> Vec<Symbol> {
75 let root = lang.lang.ast_grep(source);
76 self.extract_symbols_from_tree(lang, &root, source, file_path)
77 }
78
79 pub fn extract_symbols_from_tree(
81 &self,
82 lang: &LanguageRules,
83 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
84 source: &str,
85 file_path: &str,
86 ) -> Vec<Symbol> {
87 let root_node = root.root();
88 let mut symbols = Vec::new();
89 let mut scope = Vec::new();
90 self.extract_symbols_recursive(
91 lang,
92 &root_node,
93 source,
94 file_path,
95 &mut scope,
96 false,
97 &mut symbols,
98 );
99 symbols
100 }
101
102 pub fn extract_references(
105 &self,
106 lang: &LanguageRules,
107 source: &str,
108 file_path: &str,
109 ) -> Vec<Reference> {
110 let root = lang.lang.ast_grep(source);
111 self.extract_references_from_tree(lang, &root, source, file_path)
112 }
113
114 pub fn extract_references_from_tree(
116 &self,
117 lang: &LanguageRules,
118 root: &ast_grep_core::AstGrep<StrDoc<SupportLang>>,
119 source: &str,
120 file_path: &str,
121 ) -> Vec<Reference> {
122 let root_node = root.root();
123 let mut references = Vec::new();
124 let mut scope = Vec::new();
125 self.extract_references_recursive(
126 lang,
127 &root_node,
128 source,
129 file_path,
130 &mut scope,
131 &mut references,
132 );
133
134 let mut seen = HashSet::new();
136 references.retain(|r| {
137 seen.insert((
138 r.source_qualified_name.clone(),
139 r.target_name.clone(),
140 r.kind,
141 ))
142 });
143
144 references
145 }
146
147 #[allow(clippy::too_many_arguments)]
150 fn extract_symbols_recursive<D: Doc>(
151 &self,
152 lang: &LanguageRules,
153 node: &Node<'_, D>,
154 source: &str,
155 file_path: &str,
156 scope: &mut Vec<String>,
157 in_method_scope: bool,
158 symbols: &mut Vec<Symbol>,
159 ) where
160 D::Lang: ast_grep_core::Language,
161 {
162 let kind: Cow<'_, str> = node.kind();
163 let kind_str = kind.as_ref();
164
165 if lang.symbol_unwrap_set.contains(kind_str) {
167 for child in node.children() {
168 self.extract_symbols_recursive(
169 lang,
170 &child,
171 source,
172 file_path,
173 scope,
174 in_method_scope,
175 symbols,
176 );
177 }
178 return;
179 }
180
181 let handled_as_scope_container = lang.symbol_scope_index.contains_key(kind_str);
183 if let Some(&sc_idx) = lang.symbol_scope_index.get(kind_str) {
184 let sc = &lang.symbol_scope_containers[sc_idx];
185 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
186 scope.push(scope_name);
187 let new_method_scope = sc.is_method_scope;
188
189 if let Some(body) = self.get_scope_body(sc, node) {
191 for child in body.children() {
192 self.extract_symbols_recursive(
193 lang,
194 &child,
195 source,
196 file_path,
197 scope,
198 new_method_scope,
199 symbols,
200 );
201 }
202 } else {
203 for child in node.children() {
205 self.extract_symbols_recursive(
206 lang,
207 &child,
208 source,
209 file_path,
210 scope,
211 new_method_scope,
212 symbols,
213 );
214 }
215 }
216 scope.pop();
217 }
220 }
221
222 if let Some(rule_indices) = lang.symbol_index.get(kind_str) {
224 for &rule_idx in rule_indices {
225 let rule = &lang.symbol_rules[rule_idx];
226
227 if let Some(ref special) = rule.special {
229 let multi = self
230 .handle_special_symbol_multi(lang, special, node, source, file_path, scope);
231 if !multi.is_empty() {
232 symbols.extend(multi);
233 return;
234 }
235 }
236
237 if let Some(sym) = self.extract_symbol_from_rule(
238 lang,
239 rule,
240 node,
241 source,
242 file_path,
243 scope,
244 in_method_scope,
245 ) {
246 let name = sym.name.clone();
247 let is_scope = rule.is_scope;
248 symbols.push(sym);
249
250 if is_scope && !handled_as_scope_container {
253 scope.push(name);
254 if let Some(body) = node.field("body") {
255 for child in body.children() {
256 self.extract_symbols_recursive(
257 lang,
258 &child,
259 source,
260 file_path,
261 scope,
262 in_method_scope,
263 symbols,
264 );
265 }
266 }
267 scope.pop();
268 }
269 return; }
271 }
272 }
273
274 if handled_as_scope_container {
278 return; }
280
281 for child in node.children() {
283 self.extract_symbols_recursive(
284 lang,
285 &child,
286 source,
287 file_path,
288 scope,
289 in_method_scope,
290 symbols,
291 );
292 }
293 }
294
295 #[allow(clippy::too_many_arguments)]
296 fn extract_symbol_from_rule<D: Doc>(
297 &self,
298 lang: &LanguageRules,
299 rule: &SymbolRule,
300 node: &Node<'_, D>,
301 source: &str,
302 file_path: &str,
303 scope: &[String],
304 in_method_scope: bool,
305 ) -> Option<Symbol>
306 where
307 D::Lang: ast_grep_core::Language,
308 {
309 if let Some(ref special) = rule.special {
311 let mut sym =
312 self.handle_special_symbol(lang, special, node, source, file_path, scope)?;
313 self.enrich_symbol_metadata(lang.name, node, &mut sym);
314 return Some(sym);
315 }
316
317 let name = self.get_node_field_text(node, &rule.name_field)?;
319 if name.is_empty() {
320 return None;
321 }
322
323 let base_kind = parse_symbol_kind(&rule.symbol_kind)?;
325 let is_test = self.detect_test(lang.name, node, source, file_path, &name);
326 let kind = if is_test {
327 SymbolKind::Test
328 } else if rule.method_when_scoped && in_method_scope {
329 SymbolKind::Method
330 } else {
331 base_kind
332 };
333
334 let visibility = self.detect_visibility(lang.name, node, source, &name);
335 let signature = self.extract_signature(lang.name, node, source);
336 let doc_comment = self.extract_doc_comment(lang.name, node, source);
337
338 let mut sym = build_symbol(
339 name,
340 kind,
341 signature,
342 visibility,
343 doc_comment,
344 file_path,
345 node.start_pos().line(),
346 node.end_pos().line(),
347 scope,
348 lang.scope_separator,
349 );
350 self.enrich_symbol_metadata(lang.name, node, &mut sym);
351 Some(sym)
352 }
353
354 fn extract_references_recursive<D: Doc>(
357 &self,
358 lang: &LanguageRules,
359 node: &Node<'_, D>,
360 source: &str,
361 file_path: &str,
362 scope: &mut Vec<String>,
363 references: &mut Vec<Reference>,
364 ) where
365 D::Lang: ast_grep_core::Language,
366 {
367 let kind: Cow<'_, str> = node.kind();
368 let kind_str = kind.as_ref();
369
370 if lang.reference_unwrap_set.contains(kind_str) {
372 for child in node.children() {
373 self.extract_references_recursive(
374 lang, &child, source, file_path, scope, references,
375 );
376 }
377 return;
378 }
379
380 if let Some(&sc_idx) = lang.reference_scope_index.get(kind_str) {
382 let sc = &lang.reference_scope_containers[sc_idx];
383 if let Some(scope_name) = self.get_scope_name(lang, sc, node, source) {
384 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
386 for &rule_idx in rule_indices {
387 let rule = &lang.reference_rules[rule_idx];
388 self.extract_reference_from_rule(
389 lang, rule, node, source, file_path, scope, references,
390 );
391 }
392 }
393
394 scope.push(scope_name);
395
396 if let Some(body) = self.get_scope_body(sc, node) {
397 for child in body.children() {
398 self.extract_references_recursive(
399 lang, &child, source, file_path, scope, references,
400 );
401 }
402 } else {
403 for child in node.children() {
404 self.extract_references_recursive(
405 lang, &child, source, file_path, scope, references,
406 );
407 }
408 }
409 scope.pop();
410 return;
411 }
412 }
413
414 if let Some(rule_indices) = lang.reference_index.get(kind_str) {
416 for &rule_idx in rule_indices {
417 let rule = &lang.reference_rules[rule_idx];
418 self.extract_reference_from_rule(
419 lang, rule, node, source, file_path, scope, references,
420 );
421 }
422 }
423
424 for child in node.children() {
426 self.extract_references_recursive(lang, &child, source, file_path, scope, references);
427 }
428 }
429
430 #[allow(clippy::too_many_arguments)]
431 fn extract_reference_from_rule<D: Doc>(
432 &self,
433 lang: &LanguageRules,
434 rule: &ReferenceRule,
435 node: &Node<'_, D>,
436 source: &str,
437 file_path: &str,
438 scope: &[String],
439 references: &mut Vec<Reference>,
440 ) where
441 D::Lang: ast_grep_core::Language,
442 {
443 if let Some(ref special) = rule.special {
445 self.handle_special_reference(
446 lang, special, node, source, file_path, scope, references,
447 );
448 return;
449 }
450
451 let ref_kind = match parse_reference_kind(&rule.reference_kind) {
452 Some(k) => k,
453 None => return,
454 };
455
456 let target_name = if let Some(ref field) = rule.name_field {
457 match self.get_node_field_text(node, field) {
458 Some(name) => name,
459 None => return,
460 }
461 } else {
462 let text = node.text();
464 text.trim().to_string()
465 };
466
467 if target_name.is_empty() {
468 return;
469 }
470
471 let source_qn = if scope.is_empty() {
472 file_path.to_string()
473 } else {
474 scope.join(lang.scope_separator)
475 };
476
477 push_ref(
478 references,
479 &source_qn,
480 target_name,
481 ref_kind,
482 file_path,
483 node.start_pos().line(),
484 );
485 }
486
487 pub(crate) fn get_node_field_text<D: Doc>(
490 &self,
491 node: &Node<'_, D>,
492 field_name: &str,
493 ) -> Option<String>
494 where
495 D::Lang: ast_grep_core::Language,
496 {
497 node.field(field_name).map(|n| n.text().to_string())
498 }
499
500 fn get_scope_name<D: Doc>(
501 &self,
502 lang: &LanguageRules,
503 sc: &ScopeContainerRule,
504 node: &Node<'_, D>,
505 source: &str,
506 ) -> Option<String>
507 where
508 D::Lang: ast_grep_core::Language,
509 {
510 if let Some(ref special) = sc.special {
511 return self.get_special_scope_name(lang, special, node, source);
512 }
513 self.get_node_field_text(node, &sc.name_field)
514 }
515
516 fn get_scope_body<'a, D: Doc>(
517 &self,
518 sc: &ScopeContainerRule,
519 node: &Node<'a, D>,
520 ) -> Option<Node<'a, D>>
521 where
522 D::Lang: ast_grep_core::Language,
523 {
524 if let Some(ref special) = sc.special {
525 return self.get_special_scope_body(special, node);
526 }
527 node.field(&sc.body_field)
528 }
529
530 fn get_special_scope_name<D: Doc>(
531 &self,
532 _lang: &LanguageRules,
533 special: &str,
534 node: &Node<'_, D>,
535 _source: &str,
536 ) -> Option<String>
537 where
538 D::Lang: ast_grep_core::Language,
539 {
540 match special {
541 "go_method_scope" => {
542 self.get_go_receiver_type(node)
544 }
545 "hcl_block_scope" => {
546 let mut parts = Vec::new();
548 for child in node.children() {
549 let ck = child.kind();
550 if ck.as_ref() == "identifier" && parts.is_empty() {
551 parts.push(child.text().to_string());
552 } else if ck.as_ref() == "string_lit" {
553 parts.push(child.text().to_string().trim_matches('"').to_string());
554 }
555 }
556 if parts.is_empty() {
557 None
558 } else {
559 Some(parts.join("."))
560 }
561 }
562 "kotlin_scope" => self.get_node_field_text(node, "name").or_else(|| {
563 for child in node.children() {
564 let ck = child.kind();
565 if ck.as_ref() == "type_identifier" || ck.as_ref() == "simple_identifier" {
566 return Some(child.text().to_string());
567 }
568 }
569 None
570 }),
571 "swift_class_scope" => {
572 self.get_node_field_text(node, "name").or_else(|| {
574 node.children()
575 .find(|c| {
576 let ck = c.kind();
577 ck.as_ref() == "type_identifier" || ck.as_ref() == "identifier"
578 })
579 .map(|c| c.text().to_string())
580 })
581 }
582 "cpp_namespace_scope" => {
583 self.get_node_field_text(node, "name").or_else(|| {
585 node.children()
586 .find(|c| {
587 let ck = c.kind();
588 ck.as_ref() == "namespace_identifier" || ck.as_ref() == "identifier"
589 })
590 .map(|c| c.text().to_string())
591 })
592 }
593 _ => None,
594 }
595 }
596
597 fn get_special_scope_body<'a, D: Doc>(
598 &self,
599 _special: &str,
600 node: &Node<'a, D>,
601 ) -> Option<Node<'a, D>>
602 where
603 D::Lang: ast_grep_core::Language,
604 {
605 node.field("body")
606 }
607}
608
609impl Default for AstGrepEngine {
610 fn default() -> Self {
611 Self::new()
612 }
613}
614
615#[allow(clippy::too_many_arguments)]
619fn build_symbol(
620 name: String,
621 kind: SymbolKind,
622 signature: String,
623 visibility: Visibility,
624 doc_comment: Option<String>,
625 file_path: &str,
626 line_start: usize,
627 line_end: usize,
628 scope: &[String],
629 scope_separator: &str,
630) -> Symbol {
631 Symbol {
632 qualified_name: build_qualified_name(scope, &name, scope_separator),
633 name,
634 kind,
635 signature,
636 visibility,
637 file_path: file_path.to_string(),
638 line_start,
639 line_end,
640 doc_comment,
641 parent: if scope.is_empty() {
642 None
643 } else {
644 Some(scope.join(scope_separator))
645 },
646 parameters: Vec::new(),
647 return_type: None,
648 is_async: false,
649 attributes: Vec::new(),
650 throws: Vec::new(),
651 generic_params: None,
652 is_abstract: false,
653 }
654}
655
656fn push_ref(
658 refs: &mut Vec<Reference>,
659 source_qn: &str,
660 target: String,
661 kind: ReferenceKind,
662 file_path: &str,
663 line: usize,
664) {
665 refs.push(Reference {
666 source_qualified_name: source_qn.to_string(),
667 target_name: target,
668 kind,
669 file_path: file_path.to_string(),
670 line,
671 });
672}
673
/// Prefixes `name` with the enclosing `scope` segments, joined by `separator`.
///
/// With an empty scope the bare `name` is returned unchanged.
fn build_qualified_name(scope: &[String], name: &str, separator: &str) -> String {
    if scope.is_empty() {
        return name.to_owned();
    }

    let mut qualified = scope.join(separator);
    qualified.push_str(separator);
    qualified.push_str(name);
    qualified
}
681
682fn parse_symbol_kind(s: &str) -> Option<SymbolKind> {
683 match s {
684 "function" => Some(SymbolKind::Function),
685 "method" => Some(SymbolKind::Method),
686 "class" => Some(SymbolKind::Class),
687 "struct" => Some(SymbolKind::Struct),
688 "enum" => Some(SymbolKind::Enum),
689 "interface" => Some(SymbolKind::Interface),
690 "type" => Some(SymbolKind::Type),
691 "constant" => Some(SymbolKind::Constant),
692 "module" => Some(SymbolKind::Module),
693 "test" => Some(SymbolKind::Test),
694 "field" => Some(SymbolKind::Field),
695 "property" => Some(SymbolKind::Property),
696 "constructor" => Some(SymbolKind::Constructor),
697 "enum_variant" => Some(SymbolKind::EnumVariant),
698 "macro" => Some(SymbolKind::Macro),
699 "decorator" => Some(SymbolKind::Decorator),
700 _ => None,
701 }
702}
703
704fn parse_reference_kind(s: &str) -> Option<ReferenceKind> {
705 match s {
706 "import" => Some(ReferenceKind::Import),
707 "call" => Some(ReferenceKind::Call),
708 "inherits" => Some(ReferenceKind::Inherits),
709 "implements" => Some(ReferenceKind::Implements),
710 "type_usage" => Some(ReferenceKind::TypeUsage),
711 _ => None,
712 }
713}
714
/// Reduces a `/** ... */` block comment to its plain documentation text.
///
/// The opening `/**` and closing `*/` are removed, then each line is trimmed
/// and stripped of one leading `* ` (or bare `*`) gutter marker. Lines are
/// re-joined with `\n` and the overall result is trimmed, so blank lines in
/// the middle of the comment are preserved while leading and trailing ones
/// are dropped.
///
/// Whitespace surrounding the comment (indentation before `/**`, a newline
/// after `*/`) is tolerated.
fn clean_block_doc_comment(text: &str) -> String {
    // Trim before stripping: the delimiter matchers only fire when the
    // delimiter sits at the very start or end of the string.
    let inner = text
        .trim()
        .trim_start_matches("/**")
        .trim_end_matches("*/")
        .trim();

    let cleaned = inner
        .lines()
        .map(|raw| {
            let line = raw.trim();
            line.strip_prefix("* ")
                .or_else(|| line.strip_prefix('*'))
                .unwrap_or(line)
                .trim_end()
        })
        .collect::<Vec<_>>()
        .join("\n");

    cleaned.trim().to_string()
}
731
732#[cfg(test)]
733#[path = "../tests/engine_tests.rs"]
734mod tests;