Skip to main content

ast_grep_config/rule/
selector.rs

1#![allow(clippy::doc_lazy_continuation)]
/// A CSS selector parser for tree-sitter kinds.
///
/// Example selector
/// * `call_expression > identifier`
/// is equivalent to
/// ```yaml
/// kind: identifier
/// inside:
///   kind: call_expression
/// ```
/// * `call_expression identifier`
/// is equivalent to
/// ```yaml
/// kind: identifier
/// inside:
///   kind: call_expression
///   stopBy: end
/// ```
/** Grammar for selector

<selector-list> = <complex-selector>#

<complex-selector> = <compound-selector> [ <combinator> <compound-selector> ]*

<compound-selector> = [ <type-selector>? <subclass-selector>* ]!

<combinator> = '>' | '+' | '~' | ' '

<type-selector> = <ident-token>

<subclass-selector> = <class-selector> | <pseudo-class-selector>

<class-selector> = '.' <ident-token>

<pseudo-class-selector> = ':' <ident-token> [ '(' <selector-list> ')' ]?
*/
38use super::{
39  Rule,
40  nth_child::{NthChild, NthChildError},
41  relational_rule::{Follows, Has, Inside},
42};
43use ast_grep_core::{
44  Language,
45  matcher::{KindMatcher, KindMatcherError},
46  ops,
47};
48use thiserror::Error;
49
50// Inspired by CSS Selector, see
51// https://www.w3.org/TR/selectors-4/#grammar
/// A lexical token produced by `Input` while scanning a selector string.
#[derive(Clone, Debug, PartialEq)]
enum Token<'a> {
  /// A run of identifier characters borrowed from the selector source,
  /// e.g. a kind name such as `call_expression` or a keyword such as `has`.
  Identifier(&'a str),
  /// One of the combinators `+`, `~`, `>`, or the descendant combinator,
  /// which is carried as a single space `' '`.
  Combinator(char),
  /// `.` — introduces a class selector.
  ClassDot,
  /// `:` — introduces a pseudo-class selector.
  PseudoColon,
  /// `(` — opens a pseudo-class argument list.
  LeftParen,
  /// `)` — closes a pseudo-class argument list.
  RightParen,
  /// `,` — separates the entries of a selector list.
  Comma,
}
69
70pub fn parse_selector<L: Language>(source: &str, lang: L) -> Result<Rule, SelectorError> {
71  let mut input = Input::new(source, lang);
72  let ret = try_parse_selector(&mut input)?;
73  if !input.is_empty() {
74    return Err(SelectorError::UnexpectedToken);
75  }
76  Ok(ret)
77}
78
79/// <selector-list> = <complex-selector>#
80fn try_parse_selector<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
81  let mut rules = vec![];
82  while !input.is_empty() {
83    let complex_selector = parse_complex_selector(input)?;
84    rules.push(complex_selector);
85    if let Some(Token::Comma) = input.peek()? {
86      input.next()?; // consume the comma
87    } else if !input.is_empty() {
88      break;
89    }
90  }
91  Ok(Rule::Any(ops::Any::new(rules)))
92}
93
94/// <complex-selector> = <compound-selector> [ <combinator> <compound-selector> ]*
95fn parse_complex_selector<'a, L: Language>(
96  input: &mut Input<'a, L>,
97) -> Result<Rule, SelectorError> {
98  let mut rule = parse_compound_selector(input)?;
99  while let Some(combinator) = try_parse_combinator(input)? {
100    let next_rule = parse_compound_selector(input)?;
101    match combinator {
102      '>' => {
103        rule = Rule::All(ops::All::new([
104          next_rule,
105          Rule::Inside(Box::new(Inside::rule(rule))),
106        ]));
107      }
108      '+' => {
109        rule = Rule::All(ops::All::new([
110          next_rule,
111          Rule::Follows(Box::new(Follows::rule(rule))),
112        ]));
113      }
114      '~' => {
115        rule = Rule::All(ops::All::new([
116          next_rule,
117          Rule::Follows(Box::new(Follows::rule_descent(rule))),
118        ]));
119      }
120      ' ' => {
121        // space combinator means any descendant
122        rule = Rule::All(ops::All::new([
123          next_rule,
124          Rule::Inside(Box::new(Inside::rule_descent(rule))),
125        ]));
126      }
127      _ => {
128        return Err(SelectorError::IllegalCharacter(combinator));
129      }
130    }
131  }
132  Ok(rule)
133}
134
135/// <combinator> = '>' | '+' | '~' | ' '
136fn try_parse_combinator<'a, L: Language>(
137  input: &mut Input<'a, L>,
138) -> Result<Option<char>, SelectorError> {
139  let Some(Token::Combinator(c)) = input.peek()? else {
140    return Ok(None);
141  };
142  let c = *c;
143  input.next()?; // consume the combinator
144  Ok(Some(c))
145}
146
147/// <compound-selector> = [ <type-selector>? <subclass-selector>* ]!
148fn parse_compound_selector<'a, L: Language>(
149  input: &mut Input<'a, L>,
150) -> Result<Rule, SelectorError> {
151  let mut rules = vec![];
152  if let Some(rule) = try_parse_type_selector(input)? {
153    rules.push(rule);
154  }
155  while let Some(subclass_rule) = try_parse_subclass_selector(input)? {
156    rules.push(subclass_rule);
157  }
158  if rules.is_empty() {
159    return Err(SelectorError::MissingSelector);
160  }
161  Ok(Rule::All(ops::All::new(rules)))
162}
163
164fn try_parse_type_selector<'a, L: Language>(
165  input: &mut Input<'a, L>,
166) -> Result<Option<Rule>, SelectorError> {
167  let Some(Token::Identifier(ident)) = input.peek()? else {
168    return Ok(None);
169  };
170  let ident = *ident;
171  let lang = input.language.clone();
172  input.next()?;
173  let matcher = KindMatcher::try_new(ident, lang)?;
174  Ok(Some(Rule::Kind(matcher)))
175}
176
177/// <subclass-selector> = <class-selector> | <pseudo-class-selector>
178fn try_parse_subclass_selector<'a, L: Language>(
179  input: &mut Input<'a, L>,
180) -> Result<Option<Rule>, SelectorError> {
181  if let Some(Token::ClassDot) = input.peek()? {
182    return Err(SelectorError::Unsupported("class-selector"));
183  }
184  if let Some(Token::PseudoColon) = input.peek()? {
185    return try_parse_pseudo_class_selector(input).map(Some);
186  }
187  Ok(None)
188}
189
190/// <pseudo-class-selector> = ':' <ident-token> [ '(' <selector-list> ')' ]?
191fn try_parse_pseudo_class_selector<'a, L: Language>(
192  input: &mut Input<'a, L>,
193) -> Result<Rule, SelectorError> {
194  input.next()?; // consume ':'
195  let Some(Token::Identifier(name)) = input.next()? else {
196    return Err(SelectorError::UnexpectedToken);
197  };
198  // handle open left (
199  let Some(Token::LeftParen) = input.next()? else {
200    return Err(SelectorError::ExpectedLeftParen);
201  };
202  // prase inner argument according to the pseudo class name
203  let rule = match name {
204    "has" => parse_has_argument(input)?,
205    "not" => parse_not_argument(input)?,
206    // :is() accepts a list of selectors as `matches-any`, reuse try_parse_selector
207    "is" => try_parse_selector(input)?,
208    "nth-child" => parse_nth_child_argument(input, false)?,
209    "nth-last-child" => parse_nth_child_argument(input, true)?,
210    _ => return Err(SelectorError::UnknownPseudoClass(name.to_string())),
211  };
212  // handle closing )
213  let Some(Token::RightParen) = input.next()? else {
214    return Err(SelectorError::ExpectedRightParen);
215  };
216  Ok(rule)
217}
218
219/// [<combinator>]? <complex-selector>
220fn parse_has_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
221  // Leading '>' means direct child (stopBy: neighbor), otherwise descendant (stopBy: end)
222  let has_direct_child = if let Some(Token::Combinator('>')) = input.peek()? {
223    input.next()?; // consume '>'
224    true
225  } else {
226    false
227  };
228  let inner_rule = parse_complex_selector(input)?;
229  let has = if has_direct_child {
230    Has::rule(inner_rule)
231  } else {
232    Has::rule_descent(inner_rule)
233  };
234  Ok(Rule::Has(Box::new(has)))
235}
236
237/// <complex-selector>
238fn parse_not_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
239  let inner_rule = parse_complex_selector(input)?;
240  Ok(Rule::Not(Box::new(ops::Not::new(inner_rule))))
241}
242
243/// <an+b> ['of' <complex-selector>]?
244fn parse_nth_child_argument<'a, L: Language>(
245  input: &mut Input<'a, L>,
246  reverse: bool,
247) -> Result<Rule, SelectorError> {
248  let text = input.extract_an_plus_b();
249  let mut nth_child = NthChild::try_parse(text, reverse)?;
250  if let Some(Token::Identifier("of")) = input.peek()? {
251    input.next()?; // consume 'of'
252    input.consume_whitespace();
253    nth_child = nth_child.of_rule(parse_complex_selector(input)?);
254  }
255  Ok(Rule::NthChild(nth_child))
256}
257
258#[derive(Debug, Error)]
259pub enum SelectorError {
260  #[error("Illegal character {0} encountered")]
261  IllegalCharacter(char),
262  #[error("Unexpected token")]
263  UnexpectedToken,
264  #[error("Missing Selector")]
265  MissingSelector,
266  #[error("Invalid Kind")]
267  InvalidKind(#[from] KindMatcherError),
268  #[error("{0} is not supported yet")]
269  Unsupported(&'static str),
270  #[error("Expected '(' after pseudo-class")]
271  ExpectedLeftParen,
272  #[error("Expected ')' to close pseudo-class")]
273  ExpectedRightParen,
274  #[error("Unknown pseudo-class '{0}'")]
275  UnknownPseudoClass(String),
276  #[error("Invalid nth-child")]
277  InvalidNthChild(#[from] NthChildError),
278}
279
280struct Input<'a, L: Language> {
281  source: &'a str,
282  lookahead: Option<Token<'a>>,
283  language: L,
284}
285
286impl<'a, L: Language> Input<'a, L> {
287  fn new(source: &'a str, language: L) -> Self {
288    Self {
289      source: source.trim(),
290      lookahead: None,
291      language,
292    }
293  }
294
295  fn is_empty(&self) -> bool {
296    self.source.is_empty() && self.lookahead.is_none()
297  }
298
299  fn consume_whitespace(&mut self) {
300    self.source = self.source.trim_start();
301  }
302
303  /// Extract raw An+B text from source, consuming `[0-9nN+- \t]`.
304  fn extract_an_plus_b(&mut self) -> &'a str {
305    debug_assert!(self.lookahead.is_none());
306    let len = self
307      .source
308      .find(|c: char| !matches!(c, '0'..='9' | 'n' | 'N' | '+' | '-' | ' '))
309      .unwrap_or(self.source.len());
310    let text = self.source[..len].trim();
311    self.source = &self.source[len..];
312    self.consume_whitespace();
313    text
314  }
315
316  fn do_next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
317    if self.source.is_empty() {
318      return Ok(None);
319    }
320    let (next_token, step, need_trim) = match self.source.as_bytes()[0] as char {
321      ' ' => {
322        let len = self
323          .source
324          .find(|c: char| !c.is_whitespace())
325          .unwrap_or(self.source.len());
326        if self.source.len() > len
327          && matches!(
328            self.source.as_bytes()[len] as char,
329            '+' | '~' | '>' | ')' | ','
330          )
331        {
332          self.consume_whitespace();
333          return self.do_next(); // skip whitespace
334        }
335        (Token::Combinator(' '), len, true)
336      }
337      c @ ('+' | '~' | '>') => (Token::Combinator(c), 1, true),
338      '.' => (Token::ClassDot, 1, false),
339      ':' => (Token::PseudoColon, 1, false),
340      '(' => (Token::LeftParen, 1, true),
341      ')' => (Token::RightParen, 1, false),
342      ',' => (Token::Comma, 1, true),
343      'a'..='z' | 'A'..='Z' | '_' | '-' => {
344        let len = self
345          .source
346          .find(|c| !matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '-' | '0'..='9'))
347          .unwrap_or(self.source.len());
348        let ident = &self.source[..len];
349        (Token::Identifier(ident), len, false)
350      }
351      c => {
352        return Err(SelectorError::IllegalCharacter(c));
353      }
354    };
355    self.source = &self.source[step..];
356    if need_trim {
357      self.consume_whitespace();
358    }
359    Ok(Some(next_token))
360  }
361
362  fn next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
363    if let Some(token) = self.lookahead.take() {
364      Ok(Some(token))
365    } else {
366      self.do_next()
367    }
368  }
369
370  fn peek(&mut self) -> Result<&Option<Token<'a>>, SelectorError> {
371    if self.lookahead.is_some() {
372      return Ok(&self.lookahead);
373    }
374    let next_token = self.do_next()?;
375    self.lookahead = next_token;
376    Ok(&self.lookahead)
377  }
378}
379
#[cfg(test)]
mod test {
  use super::*;
  use crate::test::TypeScript as TS;
  use ast_grep_core::tree_sitter::LanguageExt;

  /// Runs the lexer over `input` until it is exhausted and returns all tokens.
  fn input_to_tokens(input: &str) -> Result<Vec<Token<'_>>, SelectorError> {
    let mut lexer = Input::new(input, TS::Tsx);
    let mut tokens = vec![];
    loop {
      match lexer.next()? {
        Some(token) => tokens.push(token),
        None => return Ok(tokens),
      }
    }
  }

  /// Parses `selector`, runs it over `code` as TSX and returns the text of
  /// every match in traversal order.
  fn matched_texts(selector: &str, code: &str) -> Result<Vec<String>, SelectorError> {
    let rule = parse_selector(selector, TS::Tsx)?;
    let root = TS::Tsx.ast_grep(code);
    let texts: Vec<String> = root
      .root()
      .find_all(&rule)
      .map(|m| m.text().to_string())
      .collect();
    Ok(texts)
  }

  #[test]
  fn test_valid_tokens() -> Result<(), SelectorError> {
    let expected = vec![
      Token::Identifier("call_expression"),
      Token::Combinator('+'),
      Token::Identifier("statement"),
      Token::Combinator('>'),
      Token::ClassDot,
      Token::Identifier("body"),
      Token::Combinator(' '),
      Token::PseudoColon,
      Token::Identifier("has"),
      Token::Comma,
      Token::Identifier("identifier"),
    ];
    let compact = input_to_tokens("call_expression + statement > .body :has, identifier")?;
    assert_eq!(compact, expected);
    // Extra whitespace must not change the token stream
    let padded =
      input_to_tokens("  call_expression   +   statement  >   .body    :has,    identifier  ")?;
    assert_eq!(padded, expected);
    Ok(())
  }

  #[test]
  fn test_illegal_character() {
    let mut lexer = Input::new("call_expression $ statement", TS::Tsx);

    let first = lexer.next().unwrap();
    assert_eq!(first, Some(Token::Identifier("call_expression")));
    let second = lexer.next().unwrap();
    assert_eq!(second, Some(Token::Combinator(' ')));
    let third = lexer.next();
    assert!(matches!(third, Err(SelectorError::IllegalCharacter('$'))));
  }

  #[test]
  fn test_edge_cases() -> Result<(), SelectorError> {
    // Empty string
    let mut empty = Input::new("", TS::Tsx);
    assert_eq!(empty.next()?, None);

    // Leading and trailing whitespaces
    let mut padded = Input::new("   call_expression   ", TS::Tsx);
    assert_eq!(padded.next()?, Some(Token::Identifier("call_expression")));
    assert_eq!(padded.next()?, None);

    // Mixed valid and invalid characters
    let mut mixed = Input::new("call_expression$statement", TS::Tsx);
    assert_eq!(mixed.next()?, Some(Token::Identifier("call_expression")));
    let failure = mixed.next();
    assert!(matches!(failure, Err(SelectorError::IllegalCharacter('$'))));

    // Long sequence of identifiers
    let mut long = Input::new("thisisaverylongidentifier", TS::Tsx);
    assert_eq!(
      long.next()?,
      Some(Token::Identifier("thisisaverylongidentifier"))
    );
    assert_eq!(long.next()?, None);
    Ok(())
  }

  #[test]
  fn test_parse_selector() -> Result<(), SelectorError> {
    let root = TS::Tsx.ast_grep("test(123)");

    // direct child: the callee identifier
    let child_ident = parse_selector("call_expression > identifier", TS::Tsx)?;
    let ident = root
      .root()
      .find(&child_ident)
      .expect("Should find identifier");
    assert_eq!(ident.kind(), "identifier");
    assert_eq!(ident.text(), "test");

    // the number is not a direct child of the call expression ...
    let child_number = parse_selector("call_expression > number", TS::Tsx)?;
    assert!(root.root().find(&child_number).is_none());

    // ... but it is a descendant
    let descendant_number = parse_selector("call_expression number", TS::Tsx)?;
    let number = root
      .root()
      .find(&descendant_number)
      .expect("Should find number");
    assert_eq!(number.text(), "123");
    Ok(())
  }

  // see issue #2387
  #[test]
  fn test_identifier_with_number() -> Result<(), SelectorError> {
    assert_eq!(
      input_to_tokens("atx_h1_marker")?,
      vec![Token::Identifier("atx_h1_marker")]
    );
    Ok(())
  }

  #[test]
  fn test_has_tokens() -> Result<(), SelectorError> {
    let expected = vec![
      Token::Identifier("A"),
      Token::PseudoColon,
      Token::Identifier("has"),
      Token::LeftParen,
      Token::Combinator('>'),
      Token::Identifier("B"),
      Token::RightParen,
    ];
    assert_eq!(input_to_tokens("A:has(> B)")?, expected);
    Ok(())
  }

  #[test]
  fn test_has_selector() -> Result<(), SelectorError> {
    // function_declaration:has(return_statement) - descendant search
    let rule = parse_selector("function_declaration:has(return_statement)", TS::Tsx)?;
    let with_return = TS::Tsx.ast_grep("function foo() { return 1 }");
    let found = with_return.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "function_declaration");

    // Should not match when descendant is absent
    let without_return = TS::Tsx.ast_grep("function foo() { let x = 1 }");
    assert!(without_return.root().find(&rule).is_none());
    Ok(())
  }

  #[test]
  fn test_has_direct_child_selector() -> Result<(), SelectorError> {
    // expression_statement:has(> call_expression) - direct child only
    let rule = parse_selector("expression_statement:has(> call_expression)", TS::Tsx)?;
    let tree = TS::Tsx.ast_grep("foo()");
    let found = tree.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "expression_statement");
    Ok(())
  }

  #[test]
  fn test_has_with_whitespace() -> Result<(), SelectorError> {
    // whitespace inside :has() should work
    let rule = parse_selector("function_declaration:has( return_statement )", TS::Tsx)?;
    let tree = TS::Tsx.ast_grep("function foo() { return 1 }");
    assert!(tree.root().find(&rule).is_some());
    Ok(())
  }

  #[test]
  fn test_has_error_cases() {
    // Unknown pseudo-class
    let unknown = parse_selector("expression_statement:first-child(identifier)", TS::Tsx);
    assert!(matches!(unknown, Err(SelectorError::UnknownPseudoClass(_))));

    // Missing left paren
    let no_open = parse_selector("expression_statement:has identifier", TS::Tsx);
    assert!(matches!(no_open, Err(SelectorError::ExpectedLeftParen)));

    // Missing right paren
    let no_close = parse_selector("expression_statement:has(identifier", TS::Tsx);
    assert!(matches!(no_close, Err(SelectorError::ExpectedRightParen)));
  }

  #[test]
  fn test_not_selector() -> Result<(), SelectorError> {
    // identifier:not(number) - match identifiers that are not numbers
    let rule = parse_selector("identifier:not(number)", TS::Tsx)?;
    let tree = TS::Tsx.ast_grep("test(123)");
    let found = tree.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "identifier");
    assert_eq!(found.text(), "test");
    Ok(())
  }

  #[test]
  fn test_not_selector_excludes() -> Result<(), SelectorError> {
    // number:not(number) - should match nothing
    let rule = parse_selector("number:not(number)", TS::Tsx)?;
    let tree = TS::Tsx.ast_grep("test(123)");
    assert!(tree.root().find(&rule).is_none());
    Ok(())
  }

  #[test]
  fn test_is_selector() -> Result<(), SelectorError> {
    // :is(identifier, number) - matches any of the listed kinds
    let texts = matched_texts(":is(identifier, number)", "test(123)")?;
    assert_eq!(texts, ["test", "123"]);
    Ok(())
  }

  #[test]
  fn test_is_selector_in_combinator() -> Result<(), SelectorError> {
    // call_expression > :is(identifier, number) - composing :is deeper in tree
    let texts = matched_texts("call_expression > :is(identifier, number)", "test(123)")?;
    assert_eq!(texts, ["test"]);
    Ok(())
  }

  #[test]
  fn test_nth_child_selector() -> Result<(), SelectorError> {
    // array > number:nth-child(2n+1) - match odd-positioned numbers in array
    let texts = matched_texts("array > number:nth-child(2n+1)", "[1, 2, 3, 4, 5]")?;
    assert_eq!(texts, ["1", "3", "5"]);
    Ok(())
  }

  #[test]
  fn test_nth_child_selector_with_whitespace() -> Result<(), SelectorError> {
    let texts = matched_texts("array > number:nth-child( 2n + 1 )", "[1, 2, 3, 4, 5]")?;
    assert_eq!(texts.len(), 3);
    Ok(())
  }

  #[test]
  fn test_nth_child_negative_an_plus_b() -> Result<(), SelectorError> {
    // :nth-child(-n + 3) - first 3 children
    let texts = matched_texts("array > number:nth-child(-n + 3)", "[1, 2, 3, 4, 5]")?;
    assert_eq!(texts, ["1", "2", "3"]);
    Ok(())
  }

  #[test]
  fn test_nth_child_of_selector() -> Result<(), SelectorError> {
    // :nth-child(1 of number) - first number child among siblings
    let texts = matched_texts("array > :nth-child(1 of number)", "[a, 1, 2, 3]")?;
    assert_eq!(texts, ["1"]);
    Ok(())
  }

  #[test]
  fn test_nth_child_of_complex_selector() -> Result<(), SelectorError> {
    // :nth-child(2n+1 of number) - odd-positioned numbers only
    let texts = matched_texts("array > :nth-child(2n+1 of number)", "[a, 1, 2, 3]")?;
    assert_eq!(texts, ["1", "3"]);
    Ok(())
  }

  #[test]
  fn test_nth_last_child_selector() -> Result<(), SelectorError> {
    let texts = matched_texts("array > number:nth-last-child(1)", "[1, 2, 3, 4, 5]")?;
    assert_eq!(texts, ["5"]);
    Ok(())
  }
}