Skip to main content

ast_grep_config/rule/
selector.rs

1#![allow(clippy::doc_lazy_continuation)]
2/// a css selector parser for tree-sitter kind
3///
4/// Example selector
5/// * `call_expression > identifier`
/// is equivalent to
7/// ```yaml
8/// kind: identifier
9/// inside:
10///   kind: call_expression
11/// ```
12/// * `call_expression identifier`
/// is equivalent to
14/// ```yaml
15/// kind: identifier
16/// inside:
17///   kind: call_expression
18///   stopBy: end
19/// ```
20/** Grammar for selector
21
22<selector-list> = <complex-selector>#
23
24<complex-selector> = <compound-selector> [ <combinator> <compound-selector> ]*
25
26<compound-selector> = [ <type-selector>? <subclass-selector>* ]!
27
28<combinator> = '>' | '+' | '~' | ' '
29
30<type-selector> = <ident-token>
31
32<subclass-selector> = <class-selector> | <pseudo-class-selector>
33
34<class-selector> = '.' <ident-token>
35
36<pseudo-class-selector> = ':' <ident-token> [ '(' <selector-list> ')' ]?
37*/
38use super::{
39  nth_child::{NthChild, NthChildError},
40  relational_rule::{Follows, Has, Inside},
41  Rule,
42};
43use ast_grep_core::{
44  matcher::{KindMatcher, KindMatcherError},
45  ops, Language,
46};
47use thiserror::Error;
48
49// Inspired by CSS Selector, see
50// https://www.w3.org/TR/selectors-4/#grammar
/// Token types for the lexer
///
/// Tokens are produced by `Input::do_next`. `Identifier` borrows its text
/// directly from the selector source string, hence the `'a` lifetime.
#[derive(Debug, Clone, PartialEq)]
enum Token<'a> {
  /// An identifier: starts with an ASCII letter, `_` or `-` and continues
  /// with ASCII letters, digits, `_` or `-`.
  Identifier(&'a str),
  /// + ~ > or space ` `
  Combinator(char),
  /// .
  ClassDot,
  /// :
  PseudoColon,
  /// (
  LeftParen,
  /// )
  RightParen,
  /// ,
  Comma,
}
68
69pub fn parse_selector<L: Language>(source: &str, lang: L) -> Result<Rule, SelectorError> {
70  let mut input = Input::new(source, lang);
71  let ret = try_parse_selector(&mut input)?;
72  if !input.is_empty() {
73    return Err(SelectorError::UnexpectedToken);
74  }
75  Ok(ret)
76}
77
78/// <selector-list> = <complex-selector>#
79fn try_parse_selector<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
80  let mut rules = vec![];
81  while !input.is_empty() {
82    let complex_selector = parse_complex_selector(input)?;
83    rules.push(complex_selector);
84    if let Some(Token::Comma) = input.peek()? {
85      input.next()?; // consume the comma
86    } else if !input.is_empty() {
87      break;
88    }
89  }
90  Ok(Rule::Any(ops::Any::new(rules)))
91}
92
93/// <complex-selector> = <compound-selector> [ <combinator> <compound-selector> ]*
94fn parse_complex_selector<'a, L: Language>(
95  input: &mut Input<'a, L>,
96) -> Result<Rule, SelectorError> {
97  let mut rule = parse_compound_selector(input)?;
98  while let Some(combinator) = try_parse_combinator(input)? {
99    let next_rule = parse_compound_selector(input)?;
100    match combinator {
101      '>' => {
102        rule = Rule::All(ops::All::new([
103          next_rule,
104          Rule::Inside(Box::new(Inside::rule(rule))),
105        ]));
106      }
107      '+' => {
108        rule = Rule::All(ops::All::new([
109          next_rule,
110          Rule::Follows(Box::new(Follows::rule(rule))),
111        ]));
112      }
113      '~' => {
114        rule = Rule::All(ops::All::new([
115          next_rule,
116          Rule::Follows(Box::new(Follows::rule_descent(rule))),
117        ]));
118      }
119      ' ' => {
120        // space combinator means any descendant
121        rule = Rule::All(ops::All::new([
122          next_rule,
123          Rule::Inside(Box::new(Inside::rule_descent(rule))),
124        ]));
125      }
126      _ => {
127        return Err(SelectorError::IllegalCharacter(combinator));
128      }
129    }
130  }
131  Ok(rule)
132}
133
134/// <combinator> = '>' | '+' | '~' | ' '
135fn try_parse_combinator<'a, L: Language>(
136  input: &mut Input<'a, L>,
137) -> Result<Option<char>, SelectorError> {
138  let Some(Token::Combinator(c)) = input.peek()? else {
139    return Ok(None);
140  };
141  let c = *c;
142  input.next()?; // consume the combinator
143  Ok(Some(c))
144}
145
146/// <compound-selector> = [ <type-selector>? <subclass-selector>* ]!
147fn parse_compound_selector<'a, L: Language>(
148  input: &mut Input<'a, L>,
149) -> Result<Rule, SelectorError> {
150  let mut rules = vec![];
151  if let Some(rule) = try_parse_type_selector(input)? {
152    rules.push(rule);
153  }
154  while let Some(subclass_rule) = try_parse_subclass_selector(input)? {
155    rules.push(subclass_rule);
156  }
157  if rules.is_empty() {
158    return Err(SelectorError::MissingSelector);
159  }
160  Ok(Rule::All(ops::All::new(rules)))
161}
162
163fn try_parse_type_selector<'a, L: Language>(
164  input: &mut Input<'a, L>,
165) -> Result<Option<Rule>, SelectorError> {
166  let Some(Token::Identifier(ident)) = input.peek()? else {
167    return Ok(None);
168  };
169  let ident = *ident;
170  let lang = input.language.clone();
171  input.next()?;
172  let matcher = KindMatcher::try_new(ident, lang)?;
173  Ok(Some(Rule::Kind(matcher)))
174}
175
176/// <subclass-selector> = <class-selector> | <pseudo-class-selector>
177fn try_parse_subclass_selector<'a, L: Language>(
178  input: &mut Input<'a, L>,
179) -> Result<Option<Rule>, SelectorError> {
180  if let Some(Token::ClassDot) = input.peek()? {
181    return Err(SelectorError::Unsupported("class-selector"));
182  }
183  if let Some(Token::PseudoColon) = input.peek()? {
184    return try_parse_pseudo_class_selector(input).map(Some);
185  }
186  Ok(None)
187}
188
189/// <pseudo-class-selector> = ':' <ident-token> [ '(' <selector-list> ')' ]?
190fn try_parse_pseudo_class_selector<'a, L: Language>(
191  input: &mut Input<'a, L>,
192) -> Result<Rule, SelectorError> {
193  input.next()?; // consume ':'
194  let Some(Token::Identifier(name)) = input.next()? else {
195    return Err(SelectorError::UnexpectedToken);
196  };
197  // handle open left (
198  let Some(Token::LeftParen) = input.next()? else {
199    return Err(SelectorError::ExpectedLeftParen);
200  };
201  // prase inner argument according to the pseudo class name
202  let rule = match name {
203    "has" => parse_has_argument(input)?,
204    "not" => parse_not_argument(input)?,
205    // :is() accepts a list of selectors as `matches-any`, reuse try_parse_selector
206    "is" => try_parse_selector(input)?,
207    "nth-child" => parse_nth_child_argument(input, false)?,
208    "nth-last-child" => parse_nth_child_argument(input, true)?,
209    _ => return Err(SelectorError::UnknownPseudoClass(name.to_string())),
210  };
211  // handle closing )
212  let Some(Token::RightParen) = input.next()? else {
213    return Err(SelectorError::ExpectedRightParen);
214  };
215  Ok(rule)
216}
217
218/// [<combinator>]? <complex-selector>
219fn parse_has_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
220  // Leading '>' means direct child (stopBy: neighbor), otherwise descendant (stopBy: end)
221  let has_direct_child = if let Some(Token::Combinator('>')) = input.peek()? {
222    input.next()?; // consume '>'
223    true
224  } else {
225    false
226  };
227  let inner_rule = parse_complex_selector(input)?;
228  let has = if has_direct_child {
229    Has::rule(inner_rule)
230  } else {
231    Has::rule_descent(inner_rule)
232  };
233  Ok(Rule::Has(Box::new(has)))
234}
235
236/// <complex-selector>
237fn parse_not_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
238  let inner_rule = parse_complex_selector(input)?;
239  Ok(Rule::Not(Box::new(ops::Not::new(inner_rule))))
240}
241
242/// <an+b> ['of' <complex-selector>]?
243fn parse_nth_child_argument<'a, L: Language>(
244  input: &mut Input<'a, L>,
245  reverse: bool,
246) -> Result<Rule, SelectorError> {
247  let text = input.extract_an_plus_b();
248  let mut nth_child = NthChild::try_parse(text, reverse)?;
249  if let Some(Token::Identifier("of")) = input.peek()? {
250    input.next()?; // consume 'of'
251    input.consume_whitespace();
252    nth_child = nth_child.of_rule(parse_complex_selector(input)?);
253  }
254  Ok(Rule::NthChild(nth_child))
255}
256
/// Errors raised while lexing or parsing a selector string.
#[derive(Debug, Error)]
pub enum SelectorError {
  /// The lexer met a character that cannot start any token, or the parser
  /// received a combinator character it does not handle.
  #[error("Illegal character {0} encountered")]
  IllegalCharacter(char),
  /// Input was left over after the selector list, or a `:` was not followed
  /// by an identifier.
  #[error("Unexpected token")]
  UnexpectedToken,
  /// A compound selector contained neither a type selector nor a subclass selector.
  #[error("Missing Selector")]
  MissingSelector,
  /// A type selector was rejected by `KindMatcher::try_new`.
  #[error("Invalid Kind")]
  InvalidKind(#[from] KindMatcherError),
  /// The selector uses syntax that is recognized but not implemented
  /// (currently the class selector).
  #[error("{0} is not supported yet")]
  Unsupported(&'static str),
  /// A pseudo-class name was not followed by `(`.
  #[error("Expected '(' after pseudo-class")]
  ExpectedLeftParen,
  /// A pseudo-class argument was not followed by `)`.
  #[error("Expected ')' to close pseudo-class")]
  ExpectedRightParen,
  /// The pseudo-class name is not one of `has`, `not`, `is`, `nth-child`
  /// or `nth-last-child`.
  #[error("Unknown pseudo-class '{0}'")]
  UnknownPseudoClass(String),
  /// The An+B text of `:nth-child` / `:nth-last-child` was rejected by
  /// `NthChild::try_parse`.
  #[error("Invalid nth-child")]
  InvalidNthChild(#[from] NthChildError),
}
278
/// Lexer state over a selector string with a one-token lookahead.
struct Input<'a, L: Language> {
  /// The part of the selector text that has not been tokenized yet.
  source: &'a str,
  /// Token already read by `peek` but not yet consumed by `next`.
  lookahead: Option<Token<'a>>,
  /// Language used to build kind matchers for type selectors.
  language: L,
}
284
285impl<'a, L: Language> Input<'a, L> {
286  fn new(source: &'a str, language: L) -> Self {
287    Self {
288      source: source.trim(),
289      lookahead: None,
290      language,
291    }
292  }
293
294  fn is_empty(&self) -> bool {
295    self.source.is_empty() && self.lookahead.is_none()
296  }
297
298  fn consume_whitespace(&mut self) {
299    self.source = self.source.trim_start();
300  }
301
302  /// Extract raw An+B text from source, consuming `[0-9nN+- \t]`.
303  fn extract_an_plus_b(&mut self) -> &'a str {
304    debug_assert!(self.lookahead.is_none());
305    let len = self
306      .source
307      .find(|c: char| !matches!(c, '0'..='9' | 'n' | 'N' | '+' | '-' | ' '))
308      .unwrap_or(self.source.len());
309    let text = self.source[..len].trim();
310    self.source = &self.source[len..];
311    self.consume_whitespace();
312    text
313  }
314
315  fn do_next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
316    if self.source.is_empty() {
317      return Ok(None);
318    }
319    let (next_token, step, need_trim) = match self.source.as_bytes()[0] as char {
320      ' ' => {
321        let len = self
322          .source
323          .find(|c: char| !c.is_whitespace())
324          .unwrap_or(self.source.len());
325        if self.source.len() > len
326          && matches!(
327            self.source.as_bytes()[len] as char,
328            '+' | '~' | '>' | ')' | ','
329          )
330        {
331          self.consume_whitespace();
332          return self.do_next(); // skip whitespace
333        }
334        (Token::Combinator(' '), len, true)
335      }
336      c @ ('+' | '~' | '>') => (Token::Combinator(c), 1, true),
337      '.' => (Token::ClassDot, 1, false),
338      ':' => (Token::PseudoColon, 1, false),
339      '(' => (Token::LeftParen, 1, true),
340      ')' => (Token::RightParen, 1, false),
341      ',' => (Token::Comma, 1, true),
342      'a'..='z' | 'A'..='Z' | '_' | '-' => {
343        let len = self
344          .source
345          .find(|c| !matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '-' | '0'..='9'))
346          .unwrap_or(self.source.len());
347        let ident = &self.source[..len];
348        (Token::Identifier(ident), len, false)
349      }
350      c => {
351        return Err(SelectorError::IllegalCharacter(c));
352      }
353    };
354    self.source = &self.source[step..];
355    if need_trim {
356      self.consume_whitespace();
357    }
358    Ok(Some(next_token))
359  }
360
361  fn next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
362    if let Some(token) = self.lookahead.take() {
363      Ok(Some(token))
364    } else {
365      self.do_next()
366    }
367  }
368
369  fn peek(&mut self) -> Result<&Option<Token<'a>>, SelectorError> {
370    if self.lookahead.is_some() {
371      return Ok(&self.lookahead);
372    }
373    let next_token = self.do_next()?;
374    self.lookahead = next_token;
375    Ok(&self.lookahead)
376  }
377}
378
// Unit tests for the selector lexer (`Input`) and parser (`parse_selector`),
// run against the TSX grammar.
#[cfg(test)]
mod test {
  use super::*;
  use crate::test::TypeScript as TS;
  use ast_grep_core::tree_sitter::LanguageExt;

  /// Tokenizes the whole `input` and collects the tokens, stopping at the first lexer error.
  fn input_to_tokens(input: &str) -> Result<Vec<Token<'_>>, SelectorError> {
    let mut input = Input::new(input, TS::Tsx);
    let mut tokens = Vec::new();
    while let Some(token) = input.next()? {
      tokens.push(token);
    }
    Ok(tokens)
  }

  /// Every token type is recognized, and extra whitespace does not change the token stream.
  #[test]
  fn test_valid_tokens() -> Result<(), SelectorError> {
    let tokens = input_to_tokens("call_expression + statement > .body :has, identifier")?;
    let expected = vec![
      Token::Identifier("call_expression"),
      Token::Combinator('+'),
      Token::Identifier("statement"),
      Token::Combinator('>'),
      Token::ClassDot,
      Token::Identifier("body"),
      Token::Combinator(' '),
      Token::PseudoColon,
      Token::Identifier("has"),
      Token::Comma,
      Token::Identifier("identifier"),
    ];
    assert_eq!(tokens, expected);
    // Test with extra whitespace
    let tokens =
      input_to_tokens("  call_expression   +   statement  >   .body    :has,    identifier  ")?;
    assert_eq!(tokens, expected);
    Ok(())
  }

  /// A character that cannot start a token is reported after the valid tokens before it.
  #[test]
  fn test_illegal_character() {
    let mut input = Input::new("call_expression $ statement", TS::Tsx);

    assert_eq!(
      input.next().unwrap(),
      Some(Token::Identifier("call_expression"))
    );
    assert_eq!(input.next().unwrap(), Some(Token::Combinator(' ')));
    assert!(matches!(
      input.next(),
      Err(SelectorError::IllegalCharacter('$'))
    ));
  }

  /// Lexer edge cases: empty input, surrounding whitespace, and an illegal
  /// character directly after an identifier.
  #[test]
  fn test_edge_cases() -> Result<(), SelectorError> {
    // Empty string
    let mut input = Input::new("", TS::Tsx);
    assert_eq!(input.next()?, None);

    // Leading and trailing whitespaces
    let mut input = Input::new("   call_expression   ", TS::Tsx);
    assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
    assert_eq!(input.next()?, None);

    // Mixed valid and invalid characters
    let mut input = Input::new("call_expression$statement", TS::Tsx);
    assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
    assert!(matches!(
      input.next(),
      Err(SelectorError::IllegalCharacter('$'))
    ));

    // Long sequence of identifiers
    let mut input = Input::new("thisisaverylongidentifier", TS::Tsx);
    assert_eq!(
      input.next()?,
      Some(Token::Identifier("thisisaverylongidentifier"))
    );
    assert_eq!(input.next()?, None);
    Ok(())
  }

  /// `>` only matches a direct child, while the space combinator matches any descendant.
  #[test]
  fn test_parse_selector() -> Result<(), SelectorError> {
    let selector = "call_expression > identifier";
    let rule = parse_selector(selector, TS::Tsx)?;
    let root = TS::Tsx.ast_grep("test(123)");
    let ident = root.root().find(&rule).expect("Should find identifier");
    assert_eq!(ident.kind(), "identifier");
    assert_eq!(ident.text(), "test");
    let rule = parse_selector("call_expression > number", TS::Tsx)?;
    assert!(root.root().find(&rule).is_none());
    let rule = parse_selector("call_expression number", TS::Tsx)?;
    let number = root.root().find(&rule).expect("Should find number");
    assert_eq!(number.text(), "123");
    Ok(())
  }

  // see issue #2387
  /// Digits inside an identifier belong to the identifier token.
  #[test]
  fn test_identifier_with_number() -> Result<(), SelectorError> {
    let tokens = input_to_tokens("atx_h1_marker")?;
    let expected = vec![Token::Identifier("atx_h1_marker")];
    assert_eq!(tokens, expected);
    Ok(())
  }

  /// Parentheses and a leading combinator inside `:has(...)` are tokenized individually.
  #[test]
  fn test_has_tokens() -> Result<(), SelectorError> {
    let tokens = input_to_tokens("A:has(> B)")?;
    let expected = vec![
      Token::Identifier("A"),
      Token::PseudoColon,
      Token::Identifier("has"),
      Token::LeftParen,
      Token::Combinator('>'),
      Token::Identifier("B"),
      Token::RightParen,
    ];
    assert_eq!(tokens, expected);
    Ok(())
  }

  /// `:has(X)` without a leading combinator searches all descendants.
  #[test]
  fn test_has_selector() -> Result<(), SelectorError> {
    // function_declaration:has(return_statement) - descendant search
    let rule = parse_selector("function_declaration:has(return_statement)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("function foo() { return 1 }");
    let found = root.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "function_declaration");

    // Should not match when descendant is absent
    let root = TS::Tsx.ast_grep("function foo() { let x = 1 }");
    assert!(root.root().find(&rule).is_none());
    Ok(())
  }

  /// `:has(> X)` is accepted and matches a node whose direct child is `X`.
  #[test]
  fn test_has_direct_child_selector() -> Result<(), SelectorError> {
    // expression_statement:has(> call_expression) - direct child only
    let rule = parse_selector("expression_statement:has(> call_expression)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("foo()");
    let found = root.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "expression_statement");
    Ok(())
  }

  /// Whitespace padding inside the parentheses of `:has( ... )` is ignored.
  #[test]
  fn test_has_with_whitespace() -> Result<(), SelectorError> {
    // whitespace inside :has() should work
    let rule = parse_selector("function_declaration:has( return_statement )", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("function foo() { return 1 }");
    assert!(root.root().find(&rule).is_some());
    Ok(())
  }

  /// Malformed pseudo-class selectors produce the dedicated error variants.
  #[test]
  fn test_has_error_cases() {
    // Unknown pseudo-class
    let result = parse_selector("expression_statement:first-child(identifier)", TS::Tsx);
    assert!(matches!(result, Err(SelectorError::UnknownPseudoClass(_))));

    // Missing left paren
    let result = parse_selector("expression_statement:has identifier", TS::Tsx);
    assert!(matches!(result, Err(SelectorError::ExpectedLeftParen)));

    // Missing right paren
    let result = parse_selector("expression_statement:has(identifier", TS::Tsx);
    assert!(matches!(result, Err(SelectorError::ExpectedRightParen)));
  }

  /// `:not(X)` keeps nodes that do not match `X`.
  #[test]
  fn test_not_selector() -> Result<(), SelectorError> {
    // identifier:not(number) - match identifiers that are not numbers
    let rule = parse_selector("identifier:not(number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("test(123)");
    let found = root.root().find(&rule).expect("should find");
    assert_eq!(found.kind(), "identifier");
    assert_eq!(found.text(), "test");
    Ok(())
  }

  /// A selector that contradicts itself through `:not` matches nothing.
  #[test]
  fn test_not_selector_excludes() -> Result<(), SelectorError> {
    // number:not(number) - should match nothing
    let rule = parse_selector("number:not(number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("test(123)");
    assert!(root.root().find(&rule).is_none());
    Ok(())
  }

  /// `:is(A, B)` matches nodes of either kind.
  #[test]
  fn test_is_selector() -> Result<(), SelectorError> {
    // :is(identifier, number) - matches any of the listed kinds
    let rule = parse_selector(":is(identifier, number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("test(123)");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 2);
    assert_eq!(matches[0].text(), "test");
    assert_eq!(matches[1].text(), "123");
    Ok(())
  }

  /// `:is(...)` can be used as the right-hand side of a combinator.
  #[test]
  fn test_is_selector_in_combinator() -> Result<(), SelectorError> {
    // call_expression > :is(identifier, number) - composing :is deeper in tree
    let rule = parse_selector("call_expression > :is(identifier, number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("test(123)");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].text(), "test");
    Ok(())
  }

  /// `:nth-child(2n+1)` selects the elements at odd positions.
  #[test]
  fn test_nth_child_selector() -> Result<(), SelectorError> {
    // array > number:nth-child(2n+1) - match odd-positioned numbers in array
    let rule = parse_selector("array > number:nth-child(2n+1)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 3);
    assert_eq!(matches[0].text(), "1");
    assert_eq!(matches[1].text(), "3");
    assert_eq!(matches[2].text(), "5");
    Ok(())
  }

  /// Spaces inside the An+B formula and around it are accepted.
  #[test]
  fn test_nth_child_selector_with_whitespace() -> Result<(), SelectorError> {
    let rule = parse_selector("array > number:nth-child( 2n + 1 )", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 3);
    Ok(())
  }

  /// A negative step such as `-n + 3` selects the leading elements.
  #[test]
  fn test_nth_child_negative_an_plus_b() -> Result<(), SelectorError> {
    // :nth-child(-n + 3) - first 3 children
    let rule = parse_selector("array > number:nth-child(-n + 3)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 3);
    assert_eq!(matches[0].text(), "1");
    assert_eq!(matches[1].text(), "2");
    assert_eq!(matches[2].text(), "3");
    Ok(())
  }

  /// The `of <selector>` clause counts positions among matching siblings only.
  #[test]
  fn test_nth_child_of_selector() -> Result<(), SelectorError> {
    // :nth-child(1 of number) - first number child among siblings
    let rule = parse_selector("array > :nth-child(1 of number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].text(), "1");
    Ok(())
  }

  /// An An+B formula can be combined with the `of <selector>` clause.
  #[test]
  fn test_nth_child_of_complex_selector() -> Result<(), SelectorError> {
    // :nth-child(2n+1 of number) - odd-positioned numbers only
    let rule = parse_selector("array > :nth-child(2n+1 of number)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 2);
    assert_eq!(matches[0].text(), "1");
    assert_eq!(matches[1].text(), "3");
    Ok(())
  }

  /// `:nth-last-child(1)` selects the last element.
  #[test]
  fn test_nth_last_child_selector() -> Result<(), SelectorError> {
    let rule = parse_selector("array > number:nth-last-child(1)", TS::Tsx)?;
    let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
    let matches: Vec<_> = root.root().find_all(&rule).collect();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].text(), "5");
    Ok(())
  }
}