1mod deserialize_env;
2mod nth_child;
3mod range;
4pub mod referent_rule;
5mod relational_rule;
6mod selector;
7mod stop_by;
8
9pub use deserialize_env::DeserializeEnv;
10pub use relational_rule::Relation;
11use selector::{parse_selector, SelectorError};
12pub use stop_by::StopBy;
13
14use crate::maybe::Maybe;
15use nth_child::{NthChild, NthChildError, SerializableNthChild};
16use range::{RangeMatcher, RangeMatcherError, SerializableRange};
17use referent_rule::{ReferentRule, ReferentRuleError};
18use relational_rule::{Follows, Has, Inside, Precedes};
19
20use ast_grep_core::language::Language;
21use ast_grep_core::matcher::{KindMatcher, RegexMatcher, RegexMatcherError};
22use ast_grep_core::meta_var::MetaVarEnv;
23use ast_grep_core::{ops as o, Doc, Node};
24use ast_grep_core::{MatchStrictness, Matcher, Pattern, PatternError};
25
26use bit_set::BitSet;
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::borrow::Cow;
30use std::collections::HashSet;
31use thiserror::Error;
32
/// Serializable (YAML/JSON-facing) form of a single rule object.
///
/// Fields fall into three groups, which `categorized` splits apart:
///
/// * atomic: `pattern`, `kind`, `regex`, `nthChild`, `range`
/// * relational: `inside`, `has`, `precedes`, `follows`
/// * composite: `all`, `any`, `not`, `matches`
///
/// Every field is optional: absent fields are skipped when serializing, and
/// unknown keys are rejected when deserializing. [`deserialize_rule`] turns
/// each present field into its own matcher and, when more than one field is
/// present, combines them with `all` (logical AND).
#[derive(Serialize, Deserialize, Clone, Default, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct SerializableRule {
    /// A pattern string or a contextual pattern object; see [`PatternStyle`].
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub pattern: Maybe<PatternStyle>,
    /// Node kind selector, parsed by `parse_selector` with the rule's language.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub kind: Maybe<String>,
    /// Regular expression, compiled into a `RegexMatcher`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub regex: Maybe<String>,
    /// Position constraint, written as `nthChild` in the serialized form.
    #[serde(default, skip_serializing_if = "Maybe::is_absent", rename = "nthChild")]
    pub nth_child: Maybe<SerializableNthChild>,
    /// Range constraint, built into a `RangeMatcher` from its `start`/`end`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub range: Maybe<SerializableRange>,

    /// Relation deserialized into an `Inside` matcher.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub inside: Maybe<Box<Relation>>,
    /// Relation deserialized into a `Has` matcher.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub has: Maybe<Box<Relation>>,
    /// Relation deserialized into a `Precedes` matcher.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub precedes: Maybe<Box<Relation>>,
    /// Relation deserialized into a `Follows` matcher.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub follows: Maybe<Box<Relation>>,
    /// Sub-rules that are each deserialized recursively and combined with `All`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub all: Maybe<Vec<SerializableRule>>,
    /// Sub-rules that are each deserialized recursively and combined with `Any`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub any: Maybe<Vec<SerializableRule>>,
    /// A sub-rule that is deserialized recursively and wrapped in `Not`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub not: Maybe<Box<SerializableRule>>,
    /// Id of a registered utility rule, referenced through a `ReferentRule`.
    #[serde(default, skip_serializing_if = "Maybe::is_absent")]
    pub matches: Maybe<String>,
}
99
/// A [`SerializableRule`] split into its atomic, relational and composite
/// field groups so that each group can be deserialized separately.
struct Categorized {
    /// `pattern`, `kind`, `regex`, `nthChild` and `range`.
    pub atomic: AtomicRule,
    /// `inside`, `has`, `precedes` and `follows`.
    pub relational: RelationalRule,
    /// `all`, `any`, `not` and `matches`.
    pub composite: CompositeRule,
}
105
106impl SerializableRule {
107 fn categorized(self) -> Categorized {
108 Categorized {
109 atomic: AtomicRule {
110 pattern: self.pattern.into(),
111 kind: self.kind.into(),
112 regex: self.regex.into(),
113 nth_child: self.nth_child.into(),
114 range: self.range.into(),
115 },
116 relational: RelationalRule {
117 inside: self.inside.into(),
118 has: self.has.into(),
119 precedes: self.precedes.into(),
120 follows: self.follows.into(),
121 },
122 composite: CompositeRule {
123 all: self.all.into(),
124 any: self.any.into(),
125 not: self.not.into(),
126 matches: self.matches.into(),
127 },
128 }
129 }
130}
131
/// The atomic fields of a [`SerializableRule`], with each `Maybe` converted
/// into an `Option`.
pub struct AtomicRule {
    /// String or contextual pattern.
    pub pattern: Option<PatternStyle>,
    /// Kind selector text.
    pub kind: Option<String>,
    /// Regular expression source.
    pub regex: Option<String>,
    /// Serialized `nthChild` constraint.
    pub nth_child: Option<SerializableNthChild>,
    /// Serialized range constraint.
    pub range: Option<SerializableRange>,
}
/// Serializable counterpart of `MatchStrictness`, chosen per pattern through
/// the `strictness` key of a contextual pattern.
///
/// Variants are written in camelCase (`cst`, `smart`, `ast`, `relaxed`,
/// `signature`, `template`). Each one converts to and from the
/// `MatchStrictness` variant of the same name.
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub enum Strictness {
    // NOTE(review): the matching semantics of each level are defined by
    // `ast_grep_core::MatchStrictness`; consult it before documenting them here.
    Cst,
    Smart,
    Ast,
    Relaxed,
    Signature,
    Template,
}
155
156impl From<MatchStrictness> for Strictness {
157 fn from(value: MatchStrictness) -> Self {
158 use MatchStrictness as M;
159 use Strictness as S;
160 match value {
161 M::Cst => S::Cst,
162 M::Smart => S::Smart,
163 M::Ast => S::Ast,
164 M::Relaxed => S::Relaxed,
165 M::Signature => S::Signature,
166 M::Template => S::Template,
167 }
168 }
169}
170
171impl From<Strictness> for MatchStrictness {
172 fn from(value: Strictness) -> Self {
173 use MatchStrictness as M;
174 use Strictness as S;
175 match value {
176 S::Cst => M::Cst,
177 S::Smart => M::Smart,
178 S::Ast => M::Ast,
179 S::Relaxed => M::Relaxed,
180 S::Signature => M::Signature,
181 S::Template => M::Template,
182 }
183 }
184}
185
/// The value of a rule's `pattern` key.
///
/// Deserialized untagged: a bare string becomes `Str`, and a mapping with a
/// `context` key becomes `Contextual`.
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
#[serde(untagged)]
pub enum PatternStyle {
    /// Pattern source text, parsed directly with `Pattern::try_new`.
    Str(String),
    /// Pattern given together with its surrounding source code.
    Contextual {
        /// Source text containing the pattern.
        context: String,
        /// When present, passed with `context` to `Pattern::contextual`;
        /// when absent, `context` alone is parsed as the pattern.
        selector: Option<String>,
        /// When present, applied to the parsed pattern via `with_strictness`.
        strictness: Option<Strictness>,
    },
}
201
/// The relational fields of a [`SerializableRule`], with each `Maybe`
/// converted into an `Option`.
pub struct RelationalRule {
    /// Relation for the `Inside` matcher.
    pub inside: Option<Box<Relation>>,
    /// Relation for the `Has` matcher.
    pub has: Option<Box<Relation>>,
    /// Relation for the `Precedes` matcher.
    pub precedes: Option<Box<Relation>>,
    /// Relation for the `Follows` matcher.
    pub follows: Option<Box<Relation>>,
}
208
/// The composite fields of a [`SerializableRule`], with each `Maybe`
/// converted into an `Option`.
pub struct CompositeRule {
    /// Sub-rules combined with `All`.
    pub all: Option<Vec<SerializableRule>>,
    /// Sub-rules combined with `Any`.
    pub any: Option<Vec<SerializableRule>>,
    /// Sub-rule wrapped in `Not`.
    pub not: Option<Box<SerializableRule>>,
    /// Id of the utility rule referenced by `matches`.
    pub matches: Option<String>,
}
215
/// An executable rule produced by [`deserialize_rule`].
///
/// Each variant wraps the matcher built from one field of a
/// [`SerializableRule`]; `All`, `Any` and `Not` hold nested rules.
pub enum Rule {
    /// Built from `pattern`.
    Pattern(Pattern),
    /// Kind matcher; `kind` text is turned into a rule by `parse_selector`.
    Kind(KindMatcher),
    /// Built from `regex`.
    Regex(RegexMatcher),
    /// Built from `nthChild`.
    NthChild(NthChild),
    /// Built from `range`.
    Range(RangeMatcher),
    /// Built from `inside`; a successful match is labelled "secondary".
    Inside(Box<Inside>),
    /// Built from `has`; a successful match is labelled "secondary".
    Has(Box<Has>),
    /// Built from `precedes`; a successful match is labelled "secondary".
    Precedes(Box<Precedes>),
    /// Built from `follows`; a successful match is labelled "secondary".
    Follows(Box<Follows>),
    /// Built from `all`. It is also used to combine several fields of one rule object.
    All(o::All<Rule>),
    /// Built from `any`. An empty `Any` is the `Default` rule.
    Any(o::Any<Rule>),
    /// Built from `not`.
    Not(Box<o::Not<Rule>>),
    /// Reference to a registered utility rule by id (`matches`).
    Matches(ReferentRule),
}
234impl Rule {
235 pub(crate) fn check_cyclic(&self, id: &str) -> bool {
237 match self {
238 Rule::All(all) => all.inner().iter().any(|r| r.check_cyclic(id)),
239 Rule::Any(any) => any.inner().iter().any(|r| r.check_cyclic(id)),
240 Rule::Not(not) => not.inner().check_cyclic(id),
241 Rule::Matches(m) => m.rule_id == id,
242 _ => false,
243 }
244 }
245
246 pub fn defined_vars(&self) -> HashSet<&str> {
247 match self {
248 Rule::Pattern(p) => p.defined_vars(),
249 Rule::Kind(_) => HashSet::new(),
250 Rule::Regex(_) => HashSet::new(),
251 Rule::NthChild(n) => n.defined_vars(),
252 Rule::Range(_) => HashSet::new(),
253 Rule::Has(c) => c.defined_vars(),
254 Rule::Inside(p) => p.defined_vars(),
255 Rule::Precedes(f) => f.defined_vars(),
256 Rule::Follows(f) => f.defined_vars(),
257 Rule::All(sub) => sub.inner().iter().flat_map(|r| r.defined_vars()).collect(),
258 Rule::Any(sub) => sub.inner().iter().flat_map(|r| r.defined_vars()).collect(),
259 Rule::Not(sub) => sub.inner().defined_vars(),
260 Rule::Matches(_r) => HashSet::new(),
262 }
263 }
264
265 pub fn verify_util(&self) -> Result<(), RuleSerializeError> {
267 match self {
268 Rule::Pattern(_) => Ok(()),
269 Rule::Kind(_) => Ok(()),
270 Rule::Regex(_) => Ok(()),
271 Rule::NthChild(n) => n.verify_util(),
272 Rule::Range(_) => Ok(()),
273 Rule::Has(c) => c.verify_util(),
274 Rule::Inside(p) => p.verify_util(),
275 Rule::Precedes(f) => f.verify_util(),
276 Rule::Follows(f) => f.verify_util(),
277 Rule::All(sub) => sub.inner().iter().try_for_each(|r| r.verify_util()),
278 Rule::Any(sub) => sub.inner().iter().try_for_each(|r| r.verify_util()),
279 Rule::Not(sub) => sub.inner().verify_util(),
280 Rule::Matches(r) => Ok(r.verify_util()?),
281 }
282 }
283}
284
285impl Matcher for Rule {
286 fn match_node_with_env<'tree, D: Doc>(
287 &self,
288 node: Node<'tree, D>,
289 env: &mut Cow<MetaVarEnv<'tree, D>>,
290 ) -> Option<Node<'tree, D>> {
291 use Rule::*;
292 match self {
293 Pattern(pattern) => pattern.match_node_with_env(node, env),
295 Kind(kind) => kind.match_node_with_env(node, env),
296 Regex(regex) => regex.match_node_with_env(node, env),
297 NthChild(nth_child) => nth_child.match_node_with_env(node, env),
298 Range(range) => range.match_node_with_env(node, env),
299 Inside(parent) => match_and_add_label(&**parent, node, env),
301 Has(child) => match_and_add_label(&**child, node, env),
302 Precedes(latter) => match_and_add_label(&**latter, node, env),
303 Follows(former) => match_and_add_label(&**former, node, env),
304 All(all) => all.match_node_with_env(node, env),
306 Any(any) => any.match_node_with_env(node, env),
307 Not(not) => not.match_node_with_env(node, env),
308 Matches(rule) => rule.match_node_with_env(node, env),
309 }
310 }
311
312 fn potential_kinds(&self) -> Option<BitSet> {
313 use Rule::*;
314 match self {
315 Pattern(pattern) => pattern.potential_kinds(),
317 Kind(kind) => kind.potential_kinds(),
318 Regex(regex) => regex.potential_kinds(),
319 NthChild(nth_child) => nth_child.potential_kinds(),
320 Range(range) => range.potential_kinds(),
321 Inside(parent) => parent.potential_kinds(),
323 Has(child) => child.potential_kinds(),
324 Precedes(latter) => latter.potential_kinds(),
325 Follows(former) => former.potential_kinds(),
326 All(all) => all.potential_kinds(),
328 Any(any) => any.potential_kinds(),
329 Not(not) => not.potential_kinds(),
330 Matches(rule) => rule.potential_kinds(),
331 }
332 }
333}
334
335impl Default for Rule {
338 fn default() -> Self {
339 Self::Any(o::Any::new(std::iter::empty()))
340 }
341}
342
343fn match_and_add_label<'tree, D: Doc, M: Matcher>(
344 inner: &M,
345 node: Node<'tree, D>,
346 env: &mut Cow<MetaVarEnv<'tree, D>>,
347) -> Option<Node<'tree, D>> {
348 let matched = inner.match_node_with_env(node, env)?;
349 env.to_mut().add_label("secondary", matched.clone());
350 Some(matched)
351}
352
/// Errors produced while turning a [`SerializableRule`] into a [`Rule`].
#[derive(Debug, Error)]
pub enum RuleSerializeError {
    /// No field of the rule object produced a matcher
    /// (returned by `deserialize_rule` when nothing was collected).
    #[error("Rule must have one positive matcher.")]
    MissPositiveMatcher,
    /// The `kind` text failed in `parse_selector`.
    #[error("Rule contains invalid kind matcher.")]
    InvalidKind(#[from] SelectorError),
    /// A string or contextual `pattern` failed to parse.
    #[error("Rule contains invalid pattern matcher.")]
    InvalidPattern(#[from] PatternError),
    /// The `nthChild` field could not be built into an `NthChild` matcher.
    #[error("Rule contains invalid nthChild.")]
    NthChild(#[from] NthChildError),
    /// The `regex` field could not be built into a `RegexMatcher`.
    #[error("Rule contains invalid regex matcher.")]
    WrongRegex(#[from] RegexMatcherError),
    /// Building (`ReferentRule::try_new`) or verifying a `matches` reference failed.
    #[error("Rule contains invalid matches reference.")]
    MatchesReference(#[from] ReferentRuleError),
    /// The `range` field could not be built into a `RangeMatcher`.
    #[error("Rule contains invalid range matcher.")]
    InvalidRange(#[from] RangeMatcherError),
    // NOTE(review): not constructed in this file; presumably raised by the
    // relational rule deserializer when `field` is used outside has/inside.
    #[error("field is only supported in has/inside.")]
    FieldNotSupported,
    // NOTE(review): not constructed in this file; the payload appears to be
    // the offending field name. Confirm in `relational_rule`.
    #[error("Relational rule contains invalid field {0}.")]
    InvalidField(String),
}
374
375pub fn deserialize_rule<L: Language>(
377 serialized: SerializableRule,
378 env: &DeserializeEnv<L>,
379) -> Result<Rule, RuleSerializeError> {
380 let mut rules = Vec::with_capacity(1);
381 use Rule as R;
382 let categorized = serialized.categorized();
383 deserialze_atomic_rule(categorized.atomic, &mut rules, env)?;
386 deserialze_composite_rule(categorized.composite, &mut rules, env)?;
387 deserialize_relational_rule(categorized.relational, &mut rules, env)?;
388
389 if rules.is_empty() {
390 Err(RuleSerializeError::MissPositiveMatcher)
391 } else if rules.len() == 1 {
392 Ok(rules.pop().expect("should not be empty"))
393 } else {
394 Ok(R::All(o::All::new(rules)))
395 }
396}
397
398fn deserialze_composite_rule<L: Language>(
399 composite: CompositeRule,
400 rules: &mut Vec<Rule>,
401 env: &DeserializeEnv<L>,
402) -> Result<(), RuleSerializeError> {
403 use Rule as R;
404 let convert_rules = |rules: Vec<SerializableRule>| -> Result<_, RuleSerializeError> {
405 let mut inner = Vec::with_capacity(rules.len());
406 for rule in rules {
407 inner.push(deserialize_rule(rule, env)?);
408 }
409 Ok(inner)
410 };
411 if let Some(all) = composite.all {
412 rules.push(R::All(o::All::new(convert_rules(all)?)));
413 }
414 if let Some(any) = composite.any {
415 rules.push(R::Any(o::Any::new(convert_rules(any)?)));
416 }
417 if let Some(not) = composite.not {
418 let not = o::Not::new(deserialize_rule(*not, env)?);
419 rules.push(R::Not(Box::new(not)));
420 }
421 if let Some(id) = composite.matches {
422 let matches = ReferentRule::try_new(id, &env.registration)?;
423 rules.push(R::Matches(matches));
424 }
425 Ok(())
426}
427
428fn deserialize_relational_rule<L: Language>(
429 relational: RelationalRule,
430 rules: &mut Vec<Rule>,
431 env: &DeserializeEnv<L>,
432) -> Result<(), RuleSerializeError> {
433 use Rule as R;
434 if let Some(inside) = relational.inside {
436 rules.push(R::Inside(Box::new(Inside::try_new(*inside, env)?)));
437 }
438 if let Some(has) = relational.has {
439 rules.push(R::Has(Box::new(Has::try_new(*has, env)?)));
440 }
441 if let Some(precedes) = relational.precedes {
442 rules.push(R::Precedes(Box::new(Precedes::try_new(*precedes, env)?)));
443 }
444 if let Some(follows) = relational.follows {
445 rules.push(R::Follows(Box::new(Follows::try_new(*follows, env)?)));
446 }
447 Ok(())
448}
449
450fn deserialze_atomic_rule<L: Language>(
451 atomic: AtomicRule,
452 rules: &mut Vec<Rule>,
453 env: &DeserializeEnv<L>,
454) -> Result<(), RuleSerializeError> {
455 use Rule as R;
456 if let Some(pattern) = atomic.pattern {
457 rules.push(match pattern {
458 PatternStyle::Str(pat) => R::Pattern(Pattern::try_new(&pat, env.lang.clone())?),
459 PatternStyle::Contextual {
460 context,
461 selector,
462 strictness,
463 } => {
464 let pattern = if let Some(selector) = selector {
465 Pattern::contextual(&context, &selector, env.lang.clone())?
466 } else {
467 Pattern::try_new(&context, env.lang.clone())?
468 };
469 let pattern = if let Some(strictness) = strictness {
470 pattern.with_strictness(strictness.into())
471 } else {
472 pattern
473 };
474 R::Pattern(pattern)
475 }
476 });
477 }
478 if let Some(kind) = atomic.kind {
479 let rule = parse_selector(&kind, env.lang.clone())?;
480 rules.push(rule);
481 }
482 if let Some(regex) = atomic.regex {
483 rules.push(R::Regex(RegexMatcher::try_new(®ex)?));
484 }
485 if let Some(nth_child) = atomic.nth_child {
486 rules.push(R::NthChild(NthChild::try_new(nth_child, env)?));
487 }
488 if let Some(range) = atomic.range {
489 rules.push(R::Range(RangeMatcher::try_new(range.start, range.end)?));
490 }
491 Ok(())
492}
493
#[cfg(test)]
mod test {
    // Unit tests for rule (de)serialization and for matching with
    // deserialized rules.
    use super::*;
    use crate::from_str;
    use crate::test::TypeScript;
    use ast_grep_core::tree_sitter::LanguageExt;
    use PatternStyle::*;

    #[test]
    fn test_pattern() {
        let src = r"
pattern: Test
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(matches!(rule.pattern, Maybe::Present(Str(_))));
        let src = r"
pattern:
  context: class $C { set $B() {} }
  selector: method_definition
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(matches!(rule.pattern, Maybe::Present(Contextual { .. })));
    }

    #[test]
    fn test_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        assert!(rule.pattern.is_present());
    }

    #[test]
    fn test_multi_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
has:
  pattern: Some()
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        assert!(rule.has.is_present());
        assert!(rule.follows.is_absent());
        assert!(rule.precedes.is_absent());
        assert!(rule.pattern.is_present());
    }

    #[test]
    fn test_maybe_not() {
        let src = "not: 123";
        let ret: Result<SerializableRule, _> = from_str(src);
        assert!(ret.is_err());
        let src = "not:";
        let ret: Result<SerializableRule, _> = from_str(src);
        assert!(ret.is_err());
    }

    #[test]
    fn test_nested_augmentation() {
        let src = r"
pattern: class A {}
inside:
  pattern: function() {}
  inside:
    pattern:
      context: Some()
      selector: ss
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.inside.is_present());
        let inside = rule.inside.unwrap();
        assert!(inside.rule.pattern.is_present());
        assert!(inside.rule.inside.unwrap().rule.pattern.is_present());
    }

    #[test]
    fn test_precedes_follows() {
        let src = r"
pattern: class A {}
precedes:
  pattern: function() {}
follows:
  pattern:
    context: Some()
    selector: ss
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        assert!(rule.precedes.is_present());
        assert!(rule.follows.is_present());
        // Check both relations; previously `follows` was asserted twice and
        // `precedes` was never inspected.
        let precedes = rule.precedes.unwrap();
        let follows = rule.follows.unwrap();
        assert!(matches!(precedes.rule.pattern, Maybe::Present(Str(_))));
        assert!(matches!(
            follows.rule.pattern,
            Maybe::Present(Contextual { .. })
        ));
    }

    #[test]
    fn test_deserialize_rule() {
        let src = r"
pattern: class A {}
kind: class_declaration
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep("class A {}");
        assert!(root.root().find(rule).is_some());
    }

    #[test]
    fn test_deserialize_order() {
        let src = r"
pattern: class A {}
inside:
  kind: class
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        assert!(matches!(rule, Rule::All(_)));
    }

    #[test]
    fn test_defined_vars() {
        let src = r"
pattern: var $A = 123
inside:
  pattern: var $B = 456
";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        assert_eq!(rule.defined_vars(), ["A", "B"].into_iter().collect());
    }

    #[test]
    fn test_issue_1164() {
        let src = r"
    kind: statement_block
    has:
      pattern: this.$A = promise()
      stopBy: end";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep(
            "if (a) {
      this.a = b;
      this.d = promise()
    }",
        );
        assert!(root.root().find(rule).is_some());
    }

    #[test]
    fn test_issue_1225() {
        let src = r"
    kind: statement_block
    has:
      pattern: $A
      regex: const";
        let rule: SerializableRule = from_str(src).expect("cannot parse rule");
        let env = DeserializeEnv::new(TypeScript::Tsx);
        let rule = deserialize_rule(rule, &env).expect("should deserialize");
        let root = TypeScript::Tsx.ast_grep(
            "{
      let x = 1;
      const z = 9;
    }",
        );
        assert!(root.root().find(rule).is_some());
    }
}