#![allow(clippy::doc_lazy_continuation)]
use super::{
nth_child::{NthChild, NthChildError},
relational_rule::{Follows, Has, Inside},
Rule,
};
use ast_grep_core::{
matcher::{KindMatcher, KindMatcherError},
ops, Language,
};
use thiserror::Error;
/// A lexical token produced by the selector tokenizer (`Input`).
#[derive(Debug, Clone, PartialEq)]
enum Token<'a> {
/// A name borrowed from the input: an ASCII letter, `_` or `-` followed by
/// ASCII letters, digits, `_` or `-`. Used for node kinds, pseudo-class
/// names and the `of` keyword of `:nth-child`.
Identifier(&'a str),
/// A combinator between compound selectors: `'>'`, `'+'`, `'~'`, or `' '`
/// for the descendant combinator spelled as whitespace.
Combinator(char),
/// The `.` that would start a class selector (rejected by the parser as unsupported).
ClassDot,
/// The `:` that starts a pseudo-class.
PseudoColon,
/// `(` opening a pseudo-class argument list.
LeftParen,
/// `)` closing a pseudo-class argument list.
RightParen,
/// `,` separating the alternatives of a selector list.
Comma,
}
/// Parses a CSS-like selector string into a [`Rule`].
///
/// `source` is the selector text; `lang` is the language handed to the kind
/// matcher when node kind names are resolved. The entire input must form a
/// selector list.
///
/// # Errors
///
/// Returns [`SelectorError::UnexpectedToken`] when input is left over after
/// the selector list has been parsed, and forwards any error raised while
/// tokenizing or parsing.
pub fn parse_selector<L: Language>(source: &str, lang: L) -> Result<Rule, SelectorError> {
    let mut tokens = Input::new(source, lang);
    let rule = try_parse_selector(&mut tokens)?;
    if tokens.is_empty() {
        Ok(rule)
    } else {
        // Something (e.g. a stray `)`) follows the last selector.
        Err(SelectorError::UnexpectedToken)
    }
}
/// Parses a comma-separated selector list into a `Rule::Any` of its alternatives.
///
/// Parsing stops, without consuming it, at the first token that is not a
/// comma (for example the `)` that closes `:is(...)`); the caller decides
/// whether that token is legal. A completely empty input yields an `Any`
/// rule with no alternatives.
///
/// # Errors
///
/// Forwards tokenizer errors and errors from `parse_complex_selector`. In
/// particular every comma must be followed by another selector, so a
/// trailing comma such as `"a,"` fails with `SelectorError::MissingSelector`
/// instead of being silently dropped.
fn try_parse_selector<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
    let mut rules = vec![];
    if !input.is_empty() {
        loop {
            rules.push(parse_complex_selector(input)?);
            if !matches!(input.peek()?, Some(Token::Comma)) {
                break;
            }
            // Consume the comma; the next iteration requires a selector after it.
            input.next()?;
        }
    }
    Ok(Rule::Any(ops::Any::new(rules)))
}
/// Parses one complex selector: compound selectors joined by combinators.
///
/// The selector is folded left to right. For `lhs <combinator> rhs` the
/// result is an `All` rule requiring `rhs` together with a relational rule
/// built from everything parsed so far:
///
/// * `>` uses `Inside::rule`
/// * `' '` (whitespace) uses `Inside::rule_descent`
/// * `+` uses `Follows::rule`
/// * `~` uses `Follows::rule_descent`
///
/// # Errors
///
/// Forwards errors from the tokenizer and from `parse_compound_selector`;
/// returns `SelectorError::IllegalCharacter` for any other combinator char.
fn parse_complex_selector<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Rule, SelectorError> {
    let mut rule = parse_compound_selector(input)?;
    while let Some(combinator) = try_parse_combinator(input)? {
        let next_rule = parse_compound_selector(input)?;
        // The selector accumulated so far becomes the relational context of the new one.
        let context = match combinator {
            '>' => Rule::Inside(Box::new(Inside::rule(rule))),
            '+' => Rule::Follows(Box::new(Follows::rule(rule))),
            '~' => Rule::Follows(Box::new(Follows::rule_descent(rule))),
            ' ' => Rule::Inside(Box::new(Inside::rule_descent(rule))),
            other => return Err(SelectorError::IllegalCharacter(other)),
        };
        rule = Rule::All(ops::All::new([next_rule, context]));
    }
    Ok(rule)
}
/// Consumes the next token if it is a combinator and returns its character.
///
/// Returns `Ok(None)` and leaves the input untouched when the next token is
/// anything else or the input is exhausted.
fn try_parse_combinator<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Option<char>, SelectorError> {
    let combinator = match input.peek()? {
        Some(Token::Combinator(c)) => *c,
        _ => return Ok(None),
    };
    // Drop the token that was just peeked.
    input.next()?;
    Ok(Some(combinator))
}
/// Parses a compound selector: an optional type selector followed by any
/// number of subclass selectors, combined into a single `All` rule.
///
/// # Errors
///
/// Returns `SelectorError::MissingSelector` when neither a type selector nor
/// a subclass selector is present; forwards errors from the sub-parsers.
fn parse_compound_selector<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Rule, SelectorError> {
    let mut parts = Vec::new();
    // Zero or one leading type selector.
    parts.extend(try_parse_type_selector(input)?);
    while let Some(subclass) = try_parse_subclass_selector(input)? {
        parts.push(subclass);
    }
    if parts.is_empty() {
        Err(SelectorError::MissingSelector)
    } else {
        Ok(Rule::All(ops::All::new(parts)))
    }
}
/// Parses a type selector, i.e. an identifier naming a node kind.
///
/// Returns `Ok(None)` without consuming anything when the next token is not
/// an identifier.
///
/// # Errors
///
/// Forwards tokenizer errors, and the `KindMatcherError` (converted into
/// `SelectorError::InvalidKind`) when `KindMatcher::try_new` rejects the name.
fn try_parse_type_selector<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Option<Rule>, SelectorError> {
    let kind_name = match input.peek()? {
        Some(Token::Identifier(name)) => *name,
        _ => return Ok(None),
    };
    // Drop the identifier that was just peeked.
    input.next()?;
    let matcher = KindMatcher::try_new(kind_name, input.language.clone())?;
    Ok(Some(Rule::Kind(matcher)))
}
/// Parses one subclass selector if the next token starts one.
///
/// Only pseudo-classes (`:name(...)`) are implemented. Returns `Ok(None)`
/// without consuming anything when the next token is neither `.` nor `:`.
///
/// # Errors
///
/// Returns `SelectorError::Unsupported("class-selector")` on a `.`, and
/// forwards errors from the tokenizer and the pseudo-class parser.
fn try_parse_subclass_selector<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Option<Rule>, SelectorError> {
    match input.peek()? {
        Some(Token::ClassDot) => Err(SelectorError::Unsupported("class-selector")),
        Some(Token::PseudoColon) => try_parse_pseudo_class_selector(input).map(Some),
        _ => Ok(None),
    }
}
/// Parses a functional pseudo-class of the form `:name(argument)`.
///
/// The caller must have peeked the leading `:`; it is consumed here.
/// Recognised names are `has`, `not`, `is`, `nth-child` and `nth-last-child`.
///
/// # Errors
///
/// * `SelectorError::UnexpectedToken` if no identifier follows the `:`.
/// * `SelectorError::ExpectedLeftParen` if the name is not followed by `(`.
/// * `SelectorError::UnknownPseudoClass` for any other name.
/// * `SelectorError::ExpectedRightParen` if the argument is not closed by `)`.
/// * Any error raised while parsing the argument.
fn try_parse_pseudo_class_selector<'a, L: Language>(
    input: &mut Input<'a, L>,
) -> Result<Rule, SelectorError> {
    // Discard the ':' token.
    input.next()?;
    let name = match input.next()? {
        Some(Token::Identifier(name)) => name,
        _ => return Err(SelectorError::UnexpectedToken),
    };
    if !matches!(input.next()?, Some(Token::LeftParen)) {
        return Err(SelectorError::ExpectedLeftParen);
    }
    let rule = match name {
        "has" => parse_has_argument(input)?,
        "not" => parse_not_argument(input)?,
        "is" => try_parse_selector(input)?,
        "nth-child" => parse_nth_child_argument(input, false)?,
        "nth-last-child" => parse_nth_child_argument(input, true)?,
        unknown => return Err(SelectorError::UnknownPseudoClass(unknown.to_string())),
    };
    match input.next()? {
        Some(Token::RightParen) => Ok(rule),
        _ => Err(SelectorError::ExpectedRightParen),
    }
}
/// Parses the argument of `:has(...)` into a `Rule::Has`.
///
/// A leading `>` selects `Has::rule`; without it `Has::rule_descent` is
/// used. The rest of the argument is a single complex selector.
fn parse_has_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
    let direct_child = matches!(input.peek()?, Some(Token::Combinator('>')));
    if direct_child {
        // Drop the leading '>' before parsing the selector itself.
        input.next()?;
    }
    let inner = parse_complex_selector(input)?;
    let relation = if direct_child {
        Has::rule(inner)
    } else {
        Has::rule_descent(inner)
    };
    Ok(Rule::Has(Box::new(relation)))
}
/// Parses the argument of `:not(...)`: a single complex selector, wrapped
/// in a `Rule::Not`.
fn parse_not_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
    parse_complex_selector(input).map(|negated| Rule::Not(Box::new(ops::Not::new(negated))))
}
/// Parses the argument of `:nth-child(...)` / `:nth-last-child(...)`.
///
/// The argument is the position text pulled out by
/// `Input::extract_an_plus_b`, optionally followed by `of <complex selector>`.
/// `reverse` is forwarded unchanged to `NthChild::try_parse`; the caller
/// passes `true` for `nth-last-child`.
///
/// # Errors
///
/// Forwards the `NthChildError` (as `SelectorError::InvalidNthChild`) and
/// any error from the tokenizer or from parsing the `of` selector.
fn parse_nth_child_argument<'a, L: Language>(
    input: &mut Input<'a, L>,
    reverse: bool,
) -> Result<Rule, SelectorError> {
    let formula = input.extract_an_plus_b();
    let nth_child = NthChild::try_parse(formula, reverse)?;
    let nth_child = if matches!(input.peek()?, Some(Token::Identifier("of"))) {
        // Drop the `of` keyword and the whitespace separating it from the selector.
        input.next()?;
        input.consume_whitespace();
        nth_child.of_rule(parse_complex_selector(input)?)
    } else {
        nth_child
    };
    Ok(Rule::NthChild(nth_child))
}
/// Errors produced while tokenizing or parsing a selector string.
#[derive(Debug, Error)]
pub enum SelectorError {
/// The tokenizer met a character that cannot start any token, or the parser
/// received a combinator character it does not handle.
#[error("Illegal character {0} encountered")]
IllegalCharacter(char),
/// A token appeared where it is not allowed: input left over after the
/// selector list, or a `:` that is not followed by an identifier.
#[error("Unexpected token")]
UnexpectedToken,
/// A compound selector was expected but neither a type selector nor a
/// subclass selector was found.
#[error("Missing Selector")]
MissingSelector,
/// The kind matcher rejected the node kind named by a type selector.
#[error("Invalid Kind")]
InvalidKind(#[from] KindMatcherError),
/// The selector uses syntax that is recognised but not implemented
/// (currently raised for class selectors).
#[error("{0} is not supported yet")]
Unsupported(&'static str),
/// A pseudo-class name was not followed by `(`.
#[error("Expected '(' after pseudo-class")]
ExpectedLeftParen,
/// A pseudo-class argument was not closed by `)`.
#[error("Expected ')' to close pseudo-class")]
ExpectedRightParen,
/// The pseudo-class name is not one the parser knows; carries the name.
#[error("Unknown pseudo-class '{0}'")]
UnknownPseudoClass(String),
/// The argument of `:nth-child` / `:nth-last-child` was rejected by `NthChild::try_parse`.
#[error("Invalid nth-child")]
InvalidNthChild(#[from] NthChildError),
}
/// Tokenizer state over a selector string, with one token of lookahead.
struct Input<'a, L: Language> {
/// The part of the selector text that has not been tokenized yet.
source: &'a str,
/// A token already read by `peek` but not yet consumed by `next`.
lookahead: Option<Token<'a>>,
/// Language cloned into `KindMatcher::try_new` when a type selector is parsed.
language: L,
}
impl<'a, L: Language> Input<'a, L> {
    /// Creates a tokenizer over `source`, ignoring leading and trailing whitespace.
    fn new(source: &'a str, language: L) -> Self {
        Self {
            source: source.trim(),
            lookahead: None,
            language,
        }
    }

    /// Returns `true` when no text remains and no peeked token is pending.
    fn is_empty(&self) -> bool {
        self.source.is_empty() && self.lookahead.is_none()
    }

    /// Drops leading whitespace from the remaining text.
    fn consume_whitespace(&mut self) {
        self.source = self.source.trim_start();
    }

    /// Splits off the leading run of digits, `n`/`N`, `+`, `-` and spaces —
    /// the position argument of `:nth-child` — and returns it trimmed.
    /// Whitespace following the run is skipped as well.
    ///
    /// Must be called while no token is buffered in `lookahead`: a buffered
    /// token would already have taken text that belongs to the run.
    fn extract_an_plus_b(&mut self) -> &'a str {
        debug_assert!(self.lookahead.is_none());
        let len = self
            .source
            .find(|c: char| !matches!(c, '0'..='9' | 'n' | 'N' | '+' | '-' | ' '))
            .unwrap_or(self.source.len());
        let text = self.source[..len].trim();
        self.source = &self.source[len..];
        self.consume_whitespace();
        text
    }

    /// Reads the next token straight from the text, bypassing `lookahead`.
    ///
    /// Returns `Ok(None)` at the end of the input.
    ///
    /// # Errors
    ///
    /// Returns `SelectorError::IllegalCharacter` carrying the offending
    /// character when it cannot start any token.
    fn do_next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
        // Decode a whole `char` rather than a single byte so that non-ASCII
        // input is reported intact in `IllegalCharacter`.
        let Some(first) = self.source.chars().next() else {
            return Ok(None);
        };
        let (next_token, step, need_trim) = match first {
            // Any whitespace (space, tab, newline, ...) starts a separator run,
            // matching what `consume_whitespace` and the scan below skip.
            c if c.is_whitespace() => {
                let len = self
                    .source
                    .find(|c: char| !c.is_whitespace())
                    .unwrap_or(self.source.len());
                // Whitespace in front of an explicit combinator, `)` or `,` is
                // insignificant: skip it and tokenize what follows instead.
                if matches!(
                    self.source.as_bytes().get(len).copied(),
                    Some(b'+' | b'~' | b'>' | b')' | b',')
                ) {
                    self.consume_whitespace();
                    return self.do_next();
                }
                // Otherwise the run itself is the descendant combinator.
                (Token::Combinator(' '), len, true)
            }
            c @ ('+' | '~' | '>') => (Token::Combinator(c), 1, true),
            '.' => (Token::ClassDot, 1, false),
            ':' => (Token::PseudoColon, 1, false),
            '(' => (Token::LeftParen, 1, true),
            ')' => (Token::RightParen, 1, false),
            ',' => (Token::Comma, 1, true),
            'a'..='z' | 'A'..='Z' | '_' | '-' => {
                let len = self
                    .source
                    .find(|c: char| !matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '-' | '0'..='9'))
                    .unwrap_or(self.source.len());
                (Token::Identifier(&self.source[..len]), len, false)
            }
            c => return Err(SelectorError::IllegalCharacter(c)),
        };
        self.source = &self.source[step..];
        // Whitespace after a combinator, `(` or `,` carries no meaning; whitespace
        // after any other token may be a descendant combinator and is kept.
        if need_trim {
            self.consume_whitespace();
        }
        Ok(Some(next_token))
    }

    /// Consumes and returns the next token, taking the peeked one first if present.
    fn next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
        match self.lookahead.take() {
            Some(token) => Ok(Some(token)),
            None => self.do_next(),
        }
    }

    /// Returns the next token without consuming it, reading it into
    /// `lookahead` on first use.
    fn peek(&mut self) -> Result<&Option<Token<'a>>, SelectorError> {
        if self.lookahead.is_none() {
            self.lookahead = self.do_next()?;
        }
        Ok(&self.lookahead)
    }
}
// Unit tests for the selector tokenizer and parser, run against the TSX grammar.
#[cfg(test)]
mod test {
use super::*;
use crate::test::TypeScript as TS;
use ast_grep_core::tree_sitter::LanguageExt;
// Tokenizes `input` completely, collecting every token until the end or the first error.
fn input_to_tokens(input: &str) -> Result<Vec<Token<'_>>, SelectorError> {
let mut input = Input::new(input, TS::Tsx);
let mut tokens = Vec::new();
while let Some(token) = input.next()? {
tokens.push(token);
}
Ok(tokens)
}
// Every token kind except parentheses is recognised, and whitespace around
// the selector does not change the token stream.
#[test]
fn test_valid_tokens() -> Result<(), SelectorError> {
let tokens = input_to_tokens("call_expression + statement > .body :has, identifier")?;
let expected = vec![
Token::Identifier("call_expression"),
Token::Combinator('+'),
Token::Identifier("statement"),
Token::Combinator('>'),
Token::ClassDot,
Token::Identifier("body"),
Token::Combinator(' '),
Token::PseudoColon,
Token::Identifier("has"),
Token::Comma,
Token::Identifier("identifier"),
];
assert_eq!(tokens, expected);
let tokens =
input_to_tokens(" call_expression + statement > .body :has, identifier ")?;
assert_eq!(tokens, expected);
Ok(())
}
// A character that cannot start a token is reported only when it is reached.
#[test]
fn test_illegal_character() {
let mut input = Input::new("call_expression $ statement", TS::Tsx);
assert_eq!(
input.next().unwrap(),
Some(Token::Identifier("call_expression"))
);
assert_eq!(input.next().unwrap(), Some(Token::Combinator(' ')));
assert!(matches!(
input.next(),
Err(SelectorError::IllegalCharacter('$'))
));
}
// Empty input, padded input, an illegal character right after an identifier,
// and a single long identifier.
#[test]
fn test_edge_cases() -> Result<(), SelectorError> {
let mut input = Input::new("", TS::Tsx);
assert_eq!(input.next()?, None);
let mut input = Input::new(" call_expression ", TS::Tsx);
assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
assert_eq!(input.next()?, None);
let mut input = Input::new("call_expression$statement", TS::Tsx);
assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
assert!(matches!(
input.next(),
Err(SelectorError::IllegalCharacter('$'))
));
let mut input = Input::new("thisisaverylongidentifier", TS::Tsx);
assert_eq!(
input.next()?,
Some(Token::Identifier("thisisaverylongidentifier"))
);
assert_eq!(input.next()?, None);
Ok(())
}
// `>` matches direct children only, while whitespace matches any descendant.
#[test]
fn test_parse_selector() -> Result<(), SelectorError> {
let selector = "call_expression > identifier";
let rule = parse_selector(selector, TS::Tsx)?;
let root = TS::Tsx.ast_grep("test(123)");
let ident = root.root().find(&rule).expect("Should find identifier");
assert_eq!(ident.kind(), "identifier");
assert_eq!(ident.text(), "test");
let rule = parse_selector("call_expression > number", TS::Tsx)?;
assert!(root.root().find(&rule).is_none());
let rule = parse_selector("call_expression number", TS::Tsx)?;
let number = root.root().find(&rule).expect("Should find number");
assert_eq!(number.text(), "123");
Ok(())
}
// Digits are allowed inside an identifier.
#[test]
fn test_identifier_with_number() -> Result<(), SelectorError> {
let tokens = input_to_tokens("atx_h1_marker")?;
let expected = vec![Token::Identifier("atx_h1_marker")];
assert_eq!(tokens, expected);
Ok(())
}
// Parentheses and a leading combinator inside `:has(...)` are tokenized.
#[test]
fn test_has_tokens() -> Result<(), SelectorError> {
let tokens = input_to_tokens("A:has(> B)")?;
let expected = vec![
Token::Identifier("A"),
Token::PseudoColon,
Token::Identifier("has"),
Token::LeftParen,
Token::Combinator('>'),
Token::Identifier("B"),
Token::RightParen,
];
assert_eq!(tokens, expected);
Ok(())
}
// `:has(x)` matches only when a descendant of kind `x` exists.
#[test]
fn test_has_selector() -> Result<(), SelectorError> {
let rule = parse_selector("function_declaration:has(return_statement)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("function foo() { return 1 }");
let found = root.root().find(&rule).expect("should find");
assert_eq!(found.kind(), "function_declaration");
let root = TS::Tsx.ast_grep("function foo() { let x = 1 }");
assert!(root.root().find(&rule).is_none());
Ok(())
}
// `:has(> x)` matches a node with a direct child of kind `x`.
#[test]
fn test_has_direct_child_selector() -> Result<(), SelectorError> {
let rule = parse_selector("expression_statement:has(> call_expression)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("foo()");
let found = root.root().find(&rule).expect("should find");
assert_eq!(found.kind(), "expression_statement");
Ok(())
}
// Whitespace just inside the parentheses of `:has( ... )` is ignored.
#[test]
fn test_has_with_whitespace() -> Result<(), SelectorError> {
let rule = parse_selector("function_declaration:has( return_statement )", TS::Tsx)?;
let root = TS::Tsx.ast_grep("function foo() { return 1 }");
assert!(root.root().find(&rule).is_some());
Ok(())
}
// Unknown pseudo-class, missing `(` and missing `)` each yield their own error.
#[test]
fn test_has_error_cases() {
let result = parse_selector("expression_statement:first-child(identifier)", TS::Tsx);
assert!(matches!(result, Err(SelectorError::UnknownPseudoClass(_))));
let result = parse_selector("expression_statement:has identifier", TS::Tsx);
assert!(matches!(result, Err(SelectorError::ExpectedLeftParen)));
let result = parse_selector("expression_statement:has(identifier", TS::Tsx);
assert!(matches!(result, Err(SelectorError::ExpectedRightParen)));
}
// `:not(x)` keeps nodes that do not match `x`.
#[test]
fn test_not_selector() -> Result<(), SelectorError> {
let rule = parse_selector("identifier:not(number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("test(123)");
let found = root.root().find(&rule).expect("should find");
assert_eq!(found.kind(), "identifier");
assert_eq!(found.text(), "test");
Ok(())
}
// A selector negating itself can never match.
#[test]
fn test_not_selector_excludes() -> Result<(), SelectorError> {
let rule = parse_selector("number:not(number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("test(123)");
assert!(root.root().find(&rule).is_none());
Ok(())
}
// `:is(a, b)` matches nodes of either kind.
#[test]
fn test_is_selector() -> Result<(), SelectorError> {
let rule = parse_selector(":is(identifier, number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("test(123)");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].text(), "test");
assert_eq!(matches[1].text(), "123");
Ok(())
}
// `:is(...)` on the right of `>` is still limited to direct children.
#[test]
fn test_is_selector_in_combinator() -> Result<(), SelectorError> {
let rule = parse_selector("call_expression > :is(identifier, number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("test(123)");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].text(), "test");
Ok(())
}
// `:nth-child(2n+1)` selects the 1st, 3rd and 5th array elements.
#[test]
fn test_nth_child_selector() -> Result<(), SelectorError> {
let rule = parse_selector("array > number:nth-child(2n+1)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 3);
assert_eq!(matches[0].text(), "1");
assert_eq!(matches[1].text(), "3");
assert_eq!(matches[2].text(), "5");
Ok(())
}
// Spaces inside the position argument are accepted.
#[test]
fn test_nth_child_selector_with_whitespace() -> Result<(), SelectorError> {
let rule = parse_selector("array > number:nth-child( 2n + 1 )", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 3);
Ok(())
}
// A negative step, `-n + 3`, selects the first three elements.
#[test]
fn test_nth_child_negative_an_plus_b() -> Result<(), SelectorError> {
let rule = parse_selector("array > number:nth-child(-n + 3)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 3);
assert_eq!(matches[0].text(), "1");
assert_eq!(matches[1].text(), "2");
assert_eq!(matches[2].text(), "3");
Ok(())
}
// `1 of number` counts positions among `number` siblings only, skipping `a`.
#[test]
fn test_nth_child_of_selector() -> Result<(), SelectorError> {
let rule = parse_selector("array > :nth-child(1 of number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].text(), "1");
Ok(())
}
// A step formula combined with `of`: odd positions among `number` siblings.
#[test]
fn test_nth_child_of_complex_selector() -> Result<(), SelectorError> {
let rule = parse_selector("array > :nth-child(2n+1 of number)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 2);
assert_eq!(matches[0].text(), "1");
assert_eq!(matches[1].text(), "3");
Ok(())
}
// `:nth-last-child(1)` selects the last element.
#[test]
fn test_nth_last_child_selector() -> Result<(), SelectorError> {
let rule = parse_selector("array > number:nth-last-child(1)", TS::Tsx)?;
let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
let matches: Vec<_> = root.root().find_all(&rule).collect();
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].text(), "5");
Ok(())
}
}