1#![allow(clippy::doc_lazy_continuation)]
2use super::{
39 Rule,
40 nth_child::{NthChild, NthChildError},
41 relational_rule::{Follows, Has, Inside},
42};
43use ast_grep_core::{
44 Language,
45 matcher::{KindMatcher, KindMatcherError},
46 ops,
47};
48use thiserror::Error;
49
/// A lexical token of the selector language, produced by `Input`.
#[derive(Clone, Debug, PartialEq)]
enum Token<'a> {
    /// `(` opening a pseudo-class argument.
    LeftParen,
    /// `)` closing a pseudo-class argument.
    RightParen,
    /// `,` separating the selectors of a list.
    Comma,
    /// `.` that would begin a class selector (the parser rejects it as unsupported).
    ClassDot,
    /// `:` that begins a pseudo-class such as `:has(...)`.
    PseudoColon,
    /// A combinator character: `' '` (descendant), `>`, `+` or `~`.
    Combinator(char),
    /// A node-kind or pseudo-class name, e.g. `call_expression` or `nth-child`.
    Identifier(&'a str),
}
69
70pub fn parse_selector<L: Language>(source: &str, lang: L) -> Result<Rule, SelectorError> {
71 let mut input = Input::new(source, lang);
72 let ret = try_parse_selector(&mut input)?;
73 if !input.is_empty() {
74 return Err(SelectorError::UnexpectedToken);
75 }
76 Ok(ret)
77}
78
79fn try_parse_selector<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
81 let mut rules = vec![];
82 while !input.is_empty() {
83 let complex_selector = parse_complex_selector(input)?;
84 rules.push(complex_selector);
85 if let Some(Token::Comma) = input.peek()? {
86 input.next()?; } else if !input.is_empty() {
88 break;
89 }
90 }
91 Ok(Rule::Any(ops::Any::new(rules)))
92}
93
94fn parse_complex_selector<'a, L: Language>(
96 input: &mut Input<'a, L>,
97) -> Result<Rule, SelectorError> {
98 let mut rule = parse_compound_selector(input)?;
99 while let Some(combinator) = try_parse_combinator(input)? {
100 let next_rule = parse_compound_selector(input)?;
101 match combinator {
102 '>' => {
103 rule = Rule::All(ops::All::new([
104 next_rule,
105 Rule::Inside(Box::new(Inside::rule(rule))),
106 ]));
107 }
108 '+' => {
109 rule = Rule::All(ops::All::new([
110 next_rule,
111 Rule::Follows(Box::new(Follows::rule(rule))),
112 ]));
113 }
114 '~' => {
115 rule = Rule::All(ops::All::new([
116 next_rule,
117 Rule::Follows(Box::new(Follows::rule_descent(rule))),
118 ]));
119 }
120 ' ' => {
121 rule = Rule::All(ops::All::new([
123 next_rule,
124 Rule::Inside(Box::new(Inside::rule_descent(rule))),
125 ]));
126 }
127 _ => {
128 return Err(SelectorError::IllegalCharacter(combinator));
129 }
130 }
131 }
132 Ok(rule)
133}
134
135fn try_parse_combinator<'a, L: Language>(
137 input: &mut Input<'a, L>,
138) -> Result<Option<char>, SelectorError> {
139 let Some(Token::Combinator(c)) = input.peek()? else {
140 return Ok(None);
141 };
142 let c = *c;
143 input.next()?; Ok(Some(c))
145}
146
147fn parse_compound_selector<'a, L: Language>(
149 input: &mut Input<'a, L>,
150) -> Result<Rule, SelectorError> {
151 let mut rules = vec![];
152 if let Some(rule) = try_parse_type_selector(input)? {
153 rules.push(rule);
154 }
155 while let Some(subclass_rule) = try_parse_subclass_selector(input)? {
156 rules.push(subclass_rule);
157 }
158 if rules.is_empty() {
159 return Err(SelectorError::MissingSelector);
160 }
161 Ok(Rule::All(ops::All::new(rules)))
162}
163
164fn try_parse_type_selector<'a, L: Language>(
165 input: &mut Input<'a, L>,
166) -> Result<Option<Rule>, SelectorError> {
167 let Some(Token::Identifier(ident)) = input.peek()? else {
168 return Ok(None);
169 };
170 let ident = *ident;
171 let lang = input.language.clone();
172 input.next()?;
173 let matcher = KindMatcher::try_new(ident, lang)?;
174 Ok(Some(Rule::Kind(matcher)))
175}
176
177fn try_parse_subclass_selector<'a, L: Language>(
179 input: &mut Input<'a, L>,
180) -> Result<Option<Rule>, SelectorError> {
181 if let Some(Token::ClassDot) = input.peek()? {
182 return Err(SelectorError::Unsupported("class-selector"));
183 }
184 if let Some(Token::PseudoColon) = input.peek()? {
185 return try_parse_pseudo_class_selector(input).map(Some);
186 }
187 Ok(None)
188}
189
190fn try_parse_pseudo_class_selector<'a, L: Language>(
192 input: &mut Input<'a, L>,
193) -> Result<Rule, SelectorError> {
194 input.next()?; let Some(Token::Identifier(name)) = input.next()? else {
196 return Err(SelectorError::UnexpectedToken);
197 };
198 let Some(Token::LeftParen) = input.next()? else {
200 return Err(SelectorError::ExpectedLeftParen);
201 };
202 let rule = match name {
204 "has" => parse_has_argument(input)?,
205 "not" => parse_not_argument(input)?,
206 "is" => try_parse_selector(input)?,
208 "nth-child" => parse_nth_child_argument(input, false)?,
209 "nth-last-child" => parse_nth_child_argument(input, true)?,
210 _ => return Err(SelectorError::UnknownPseudoClass(name.to_string())),
211 };
212 let Some(Token::RightParen) = input.next()? else {
214 return Err(SelectorError::ExpectedRightParen);
215 };
216 Ok(rule)
217}
218
219fn parse_has_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
221 let has_direct_child = if let Some(Token::Combinator('>')) = input.peek()? {
223 input.next()?; true
225 } else {
226 false
227 };
228 let inner_rule = parse_complex_selector(input)?;
229 let has = if has_direct_child {
230 Has::rule(inner_rule)
231 } else {
232 Has::rule_descent(inner_rule)
233 };
234 Ok(Rule::Has(Box::new(has)))
235}
236
237fn parse_not_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
239 let inner_rule = parse_complex_selector(input)?;
240 Ok(Rule::Not(Box::new(ops::Not::new(inner_rule))))
241}
242
243fn parse_nth_child_argument<'a, L: Language>(
245 input: &mut Input<'a, L>,
246 reverse: bool,
247) -> Result<Rule, SelectorError> {
248 let text = input.extract_an_plus_b();
249 let mut nth_child = NthChild::try_parse(text, reverse)?;
250 if let Some(Token::Identifier("of")) = input.peek()? {
251 input.next()?; input.consume_whitespace();
253 nth_child = nth_child.of_rule(parse_complex_selector(input)?);
254 }
255 Ok(Rule::NthChild(nth_child))
256}
257
258#[derive(Debug, Error)]
259pub enum SelectorError {
260 #[error("Illegal character {0} encountered")]
261 IllegalCharacter(char),
262 #[error("Unexpected token")]
263 UnexpectedToken,
264 #[error("Missing Selector")]
265 MissingSelector,
266 #[error("Invalid Kind")]
267 InvalidKind(#[from] KindMatcherError),
268 #[error("{0} is not supported yet")]
269 Unsupported(&'static str),
270 #[error("Expected '(' after pseudo-class")]
271 ExpectedLeftParen,
272 #[error("Expected ')' to close pseudo-class")]
273 ExpectedRightParen,
274 #[error("Unknown pseudo-class '{0}'")]
275 UnknownPseudoClass(String),
276 #[error("Invalid nth-child")]
277 InvalidNthChild(#[from] NthChildError),
278}
279
280struct Input<'a, L: Language> {
281 source: &'a str,
282 lookahead: Option<Token<'a>>,
283 language: L,
284}
285
286impl<'a, L: Language> Input<'a, L> {
287 fn new(source: &'a str, language: L) -> Self {
288 Self {
289 source: source.trim(),
290 lookahead: None,
291 language,
292 }
293 }
294
295 fn is_empty(&self) -> bool {
296 self.source.is_empty() && self.lookahead.is_none()
297 }
298
299 fn consume_whitespace(&mut self) {
300 self.source = self.source.trim_start();
301 }
302
303 fn extract_an_plus_b(&mut self) -> &'a str {
305 debug_assert!(self.lookahead.is_none());
306 let len = self
307 .source
308 .find(|c: char| !matches!(c, '0'..='9' | 'n' | 'N' | '+' | '-' | ' '))
309 .unwrap_or(self.source.len());
310 let text = self.source[..len].trim();
311 self.source = &self.source[len..];
312 self.consume_whitespace();
313 text
314 }
315
316 fn do_next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
317 if self.source.is_empty() {
318 return Ok(None);
319 }
320 let (next_token, step, need_trim) = match self.source.as_bytes()[0] as char {
321 ' ' => {
322 let len = self
323 .source
324 .find(|c: char| !c.is_whitespace())
325 .unwrap_or(self.source.len());
326 if self.source.len() > len
327 && matches!(
328 self.source.as_bytes()[len] as char,
329 '+' | '~' | '>' | ')' | ','
330 )
331 {
332 self.consume_whitespace();
333 return self.do_next(); }
335 (Token::Combinator(' '), len, true)
336 }
337 c @ ('+' | '~' | '>') => (Token::Combinator(c), 1, true),
338 '.' => (Token::ClassDot, 1, false),
339 ':' => (Token::PseudoColon, 1, false),
340 '(' => (Token::LeftParen, 1, true),
341 ')' => (Token::RightParen, 1, false),
342 ',' => (Token::Comma, 1, true),
343 'a'..='z' | 'A'..='Z' | '_' | '-' => {
344 let len = self
345 .source
346 .find(|c| !matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '-' | '0'..='9'))
347 .unwrap_or(self.source.len());
348 let ident = &self.source[..len];
349 (Token::Identifier(ident), len, false)
350 }
351 c => {
352 return Err(SelectorError::IllegalCharacter(c));
353 }
354 };
355 self.source = &self.source[step..];
356 if need_trim {
357 self.consume_whitespace();
358 }
359 Ok(Some(next_token))
360 }
361
362 fn next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
363 if let Some(token) = self.lookahead.take() {
364 Ok(Some(token))
365 } else {
366 self.do_next()
367 }
368 }
369
370 fn peek(&mut self) -> Result<&Option<Token<'a>>, SelectorError> {
371 if self.lookahead.is_some() {
372 return Ok(&self.lookahead);
373 }
374 let next_token = self.do_next()?;
375 self.lookahead = next_token;
376 Ok(&self.lookahead)
377 }
378}
379
// Tests for the selector tokenizer (`Input`) and parser (`parse_selector`),
// exercised against TSX via the crate's `TypeScript` test language.
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::TypeScript as TS;
    use ast_grep_core::tree_sitter::LanguageExt;

    // Drains the tokenizer over `input`, collecting tokens until end of input or
    // the first lexing error.
    fn input_to_tokens(input: &str) -> Result<Vec<Token<'_>>, SelectorError> {
        let mut input = Input::new(input, TS::Tsx);
        let mut tokens = Vec::new();
        while let Some(token) = input.next()? {
            tokens.push(token);
        }
        Ok(tokens)
    }

    // Covers every token kind. Whitespace around `+`, `>` and after `,` is dropped,
    // while the space before `:has` becomes a descendant combinator. The second call
    // checks that surrounding whitespace yields the same token stream.
    #[test]
    fn test_valid_tokens() -> Result<(), SelectorError> {
        let tokens = input_to_tokens("call_expression + statement > .body :has, identifier")?;
        let expected = vec![
            Token::Identifier("call_expression"),
            Token::Combinator('+'),
            Token::Identifier("statement"),
            Token::Combinator('>'),
            Token::ClassDot,
            Token::Identifier("body"),
            Token::Combinator(' '),
            Token::PseudoColon,
            Token::Identifier("has"),
            Token::Comma,
            Token::Identifier("identifier"),
        ];
        assert_eq!(tokens, expected);
        let tokens =
            input_to_tokens(" call_expression + statement > .body :has, identifier ")?;
        assert_eq!(tokens, expected);
        Ok(())
    }

    // Tokens before an illegal character are still produced; the error surfaces
    // only when the tokenizer reaches `$`.
    #[test]
    fn test_illegal_character() {
        let mut input = Input::new("call_expression $ statement", TS::Tsx);

        assert_eq!(
            input.next().unwrap(),
            Some(Token::Identifier("call_expression"))
        );
        assert_eq!(input.next().unwrap(), Some(Token::Combinator(' ')));
        assert!(matches!(
            input.next(),
            Err(SelectorError::IllegalCharacter('$'))
        ));
    }

    // Empty input, whitespace-padded input, an illegal character directly after an
    // identifier, and a long identifier.
    #[test]
    fn test_edge_cases() -> Result<(), SelectorError> {
        let mut input = Input::new("", TS::Tsx);
        assert_eq!(input.next()?, None);

        let mut input = Input::new(" call_expression ", TS::Tsx);
        assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
        assert_eq!(input.next()?, None);

        let mut input = Input::new("call_expression$statement", TS::Tsx);
        assert_eq!(input.next()?, Some(Token::Identifier("call_expression")));
        assert!(matches!(
            input.next(),
            Err(SelectorError::IllegalCharacter('$'))
        ));

        let mut input = Input::new("thisisaverylongidentifier", TS::Tsx);
        assert_eq!(
            input.next()?,
            Some(Token::Identifier("thisisaverylongidentifier"))
        );
        assert_eq!(input.next()?, None);
        Ok(())
    }

    // `>` requires a direct parent, while a space accepts any ancestor: `123` is an
    // argument of the call, not its direct child.
    #[test]
    fn test_parse_selector() -> Result<(), SelectorError> {
        let selector = "call_expression > identifier";
        let rule = parse_selector(selector, TS::Tsx)?;
        let root = TS::Tsx.ast_grep("test(123)");
        let ident = root.root().find(&rule).expect("Should find identifier");
        assert_eq!(ident.kind(), "identifier");
        assert_eq!(ident.text(), "test");
        let rule = parse_selector("call_expression > number", TS::Tsx)?;
        assert!(root.root().find(&rule).is_none());
        let rule = parse_selector("call_expression number", TS::Tsx)?;
        let number = root.root().find(&rule).expect("Should find number");
        assert_eq!(number.text(), "123");
        Ok(())
    }

    // Digits are allowed after the first character of an identifier.
    #[test]
    fn test_identifier_with_number() -> Result<(), SelectorError> {
        let tokens = input_to_tokens("atx_h1_marker")?;
        let expected = vec![Token::Identifier("atx_h1_marker")];
        assert_eq!(tokens, expected);
        Ok(())
    }

    // Token stream for a pseudo-class with a parenthesized relative selector.
    #[test]
    fn test_has_tokens() -> Result<(), SelectorError> {
        let tokens = input_to_tokens("A:has(> B)")?;
        let expected = vec![
            Token::Identifier("A"),
            Token::PseudoColon,
            Token::Identifier("has"),
            Token::LeftParen,
            Token::Combinator('>'),
            Token::Identifier("B"),
            Token::RightParen,
        ];
        assert_eq!(tokens, expected);
        Ok(())
    }

    // `:has(X)` without `>` matches when X is any descendant.
    #[test]
    fn test_has_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("function_declaration:has(return_statement)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("function foo() { return 1 }");
        let found = root.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "function_declaration");

        let root = TS::Tsx.ast_grep("function foo() { let x = 1 }");
        assert!(root.root().find(&rule).is_none());
        Ok(())
    }

    // `:has(> X)` restricts the match to a direct child.
    #[test]
    fn test_has_direct_child_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("expression_statement:has(> call_expression)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("foo()");
        let found = root.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "expression_statement");
        Ok(())
    }

    // Whitespace just inside the parentheses is insignificant.
    #[test]
    fn test_has_with_whitespace() -> Result<(), SelectorError> {
        let rule = parse_selector("function_declaration:has( return_statement )", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("function foo() { return 1 }");
        assert!(root.root().find(&rule).is_some());
        Ok(())
    }

    // Unknown pseudo-class name, missing `(`, and missing `)`.
    #[test]
    fn test_has_error_cases() {
        let result = parse_selector("expression_statement:first-child(identifier)", TS::Tsx);
        assert!(matches!(result, Err(SelectorError::UnknownPseudoClass(_))));

        let result = parse_selector("expression_statement:has identifier", TS::Tsx);
        assert!(matches!(result, Err(SelectorError::ExpectedLeftParen)));

        let result = parse_selector("expression_statement:has(identifier", TS::Tsx);
        assert!(matches!(result, Err(SelectorError::ExpectedRightParen)));
    }

    // `:not(X)` keeps nodes of the type selector that do not match X.
    #[test]
    fn test_not_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("identifier:not(number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("test(123)");
        let found = root.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "identifier");
        assert_eq!(found.text(), "test");
        Ok(())
    }

    // A selector negating itself can never match.
    #[test]
    fn test_not_selector_excludes() -> Result<(), SelectorError> {
        let rule = parse_selector("number:not(number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("test(123)");
        assert!(root.root().find(&rule).is_none());
        Ok(())
    }

    // `:is(A, B)` matches nodes satisfying any selector in its list.
    #[test]
    fn test_is_selector() -> Result<(), SelectorError> {
        let rule = parse_selector(":is(identifier, number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("test(123)");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].text(), "test");
        assert_eq!(matches[1].text(), "123");
        Ok(())
    }

    // `:is(...)` composes with combinators: only `test` is a direct child of the call.
    #[test]
    fn test_is_selector_in_combinator() -> Result<(), SelectorError> {
        let rule = parse_selector("call_expression > :is(identifier, number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("test(123)");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text(), "test");
        Ok(())
    }

    // `2n+1` selects the odd-positioned (1-based) elements.
    #[test]
    fn test_nth_child_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("array > number:nth-child(2n+1)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].text(), "1");
        assert_eq!(matches[1].text(), "3");
        assert_eq!(matches[2].text(), "5");
        Ok(())
    }

    // Whitespace inside the `An+B` formula is tolerated.
    #[test]
    fn test_nth_child_selector_with_whitespace() -> Result<(), SelectorError> {
        let rule = parse_selector("array > number:nth-child( 2n + 1 )", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 3);
        Ok(())
    }

    // `-n + 3` selects the first three elements.
    #[test]
    fn test_nth_child_negative_an_plus_b() -> Result<(), SelectorError> {
        let rule = parse_selector("array > number:nth-child(-n + 3)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].text(), "1");
        assert_eq!(matches[1].text(), "2");
        assert_eq!(matches[2].text(), "3");
        Ok(())
    }

    // `of <selector>` counts only siblings matching that selector, so the leading
    // identifier `a` is skipped.
    #[test]
    fn test_nth_child_of_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("array > :nth-child(1 of number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text(), "1");
        Ok(())
    }

    // A formula combined with `of`: odd positions among the `number` siblings only.
    #[test]
    fn test_nth_child_of_complex_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("array > :nth-child(2n+1 of number)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[a, 1, 2, 3]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].text(), "1");
        assert_eq!(matches[1].text(), "3");
        Ok(())
    }

    // `:nth-last-child` counts from the end of the sibling list.
    #[test]
    fn test_nth_last_child_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("array > number:nth-last-child(1)", TS::Tsx)?;
        let root = TS::Tsx.ast_grep("[1, 2, 3, 4, 5]");
        let matches: Vec<_> = root.root().find_all(&rule).collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].text(), "5");
        Ok(())
    }
}