1#![allow(clippy::doc_lazy_continuation)]
2use super::{
39 nth_child::{NthChild, NthChildError},
40 relational_rule::{Follows, Has, Inside},
41 Rule,
42};
43use ast_grep_core::{
44 matcher::{KindMatcher, KindMatcherError},
45 ops, Language,
46};
47use thiserror::Error;
48
/// A lexical token of the selector language, borrowing its text from the
/// selector source string.
///
/// All payloads are `Copy` (`&str` / `char`), so the token itself is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Token<'a> {
    /// A name starting with an ASCII letter, `_` or `-`, followed by ASCII
    /// letters, digits, `_` or `-` (node kinds, pseudo-class names, `of`).
    Identifier(&'a str),
    /// One of `+`, `~`, `>`, or `' '` standing for the descendant
    /// (whitespace) combinator.
    Combinator(char),
    /// `.` — starts a class selector (rejected as unsupported by the parser).
    ClassDot,
    /// `:` — starts a pseudo-class.
    PseudoColon,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `,` — separates selectors in a selector list.
    Comma,
}
68
69pub fn parse_selector<L: Language>(source: &str, lang: L) -> Result<Rule, SelectorError> {
70 let mut input = Input::new(source, lang);
71 let ret = try_parse_selector(&mut input)?;
72 if !input.is_empty() {
73 return Err(SelectorError::UnexpectedToken);
74 }
75 Ok(ret)
76}
77
78fn try_parse_selector<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
80 let mut rules = vec![];
81 while !input.is_empty() {
82 let complex_selector = parse_complex_selector(input)?;
83 rules.push(complex_selector);
84 if let Some(Token::Comma) = input.peek()? {
85 input.next()?; } else if !input.is_empty() {
87 break;
88 }
89 }
90 Ok(Rule::Any(ops::Any::new(rules)))
91}
92
93fn parse_complex_selector<'a, L: Language>(
95 input: &mut Input<'a, L>,
96) -> Result<Rule, SelectorError> {
97 let mut rule = parse_compound_selector(input)?;
98 while let Some(combinator) = try_parse_combinator(input)? {
99 let next_rule = parse_compound_selector(input)?;
100 match combinator {
101 '>' => {
102 rule = Rule::All(ops::All::new([
103 next_rule,
104 Rule::Inside(Box::new(Inside::rule(rule))),
105 ]));
106 }
107 '+' => {
108 rule = Rule::All(ops::All::new([
109 next_rule,
110 Rule::Follows(Box::new(Follows::rule(rule))),
111 ]));
112 }
113 '~' => {
114 rule = Rule::All(ops::All::new([
115 next_rule,
116 Rule::Follows(Box::new(Follows::rule_descent(rule))),
117 ]));
118 }
119 ' ' => {
120 rule = Rule::All(ops::All::new([
122 next_rule,
123 Rule::Inside(Box::new(Inside::rule_descent(rule))),
124 ]));
125 }
126 _ => {
127 return Err(SelectorError::IllegalCharacter(combinator));
128 }
129 }
130 }
131 Ok(rule)
132}
133
134fn try_parse_combinator<'a, L: Language>(
136 input: &mut Input<'a, L>,
137) -> Result<Option<char>, SelectorError> {
138 let Some(Token::Combinator(c)) = input.peek()? else {
139 return Ok(None);
140 };
141 let c = *c;
142 input.next()?; Ok(Some(c))
144}
145
146fn parse_compound_selector<'a, L: Language>(
148 input: &mut Input<'a, L>,
149) -> Result<Rule, SelectorError> {
150 let mut rules = vec![];
151 if let Some(rule) = try_parse_type_selector(input)? {
152 rules.push(rule);
153 }
154 while let Some(subclass_rule) = try_parse_subclass_selector(input)? {
155 rules.push(subclass_rule);
156 }
157 if rules.is_empty() {
158 return Err(SelectorError::MissingSelector);
159 }
160 Ok(Rule::All(ops::All::new(rules)))
161}
162
163fn try_parse_type_selector<'a, L: Language>(
164 input: &mut Input<'a, L>,
165) -> Result<Option<Rule>, SelectorError> {
166 let Some(Token::Identifier(ident)) = input.peek()? else {
167 return Ok(None);
168 };
169 let ident = *ident;
170 let lang = input.language.clone();
171 input.next()?;
172 let matcher = KindMatcher::try_new(ident, lang)?;
173 Ok(Some(Rule::Kind(matcher)))
174}
175
176fn try_parse_subclass_selector<'a, L: Language>(
178 input: &mut Input<'a, L>,
179) -> Result<Option<Rule>, SelectorError> {
180 if let Some(Token::ClassDot) = input.peek()? {
181 return Err(SelectorError::Unsupported("class-selector"));
182 }
183 if let Some(Token::PseudoColon) = input.peek()? {
184 return try_parse_pseudo_class_selector(input).map(Some);
185 }
186 Ok(None)
187}
188
189fn try_parse_pseudo_class_selector<'a, L: Language>(
191 input: &mut Input<'a, L>,
192) -> Result<Rule, SelectorError> {
193 input.next()?; let Some(Token::Identifier(name)) = input.next()? else {
195 return Err(SelectorError::UnexpectedToken);
196 };
197 let Some(Token::LeftParen) = input.next()? else {
199 return Err(SelectorError::ExpectedLeftParen);
200 };
201 let rule = match name {
203 "has" => parse_has_argument(input)?,
204 "not" => parse_not_argument(input)?,
205 "is" => try_parse_selector(input)?,
207 "nth-child" => parse_nth_child_argument(input, false)?,
208 "nth-last-child" => parse_nth_child_argument(input, true)?,
209 _ => return Err(SelectorError::UnknownPseudoClass(name.to_string())),
210 };
211 let Some(Token::RightParen) = input.next()? else {
213 return Err(SelectorError::ExpectedRightParen);
214 };
215 Ok(rule)
216}
217
218fn parse_has_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
220 let has_direct_child = if let Some(Token::Combinator('>')) = input.peek()? {
222 input.next()?; true
224 } else {
225 false
226 };
227 let inner_rule = parse_complex_selector(input)?;
228 let has = if has_direct_child {
229 Has::rule(inner_rule)
230 } else {
231 Has::rule_descent(inner_rule)
232 };
233 Ok(Rule::Has(Box::new(has)))
234}
235
236fn parse_not_argument<'a, L: Language>(input: &mut Input<'a, L>) -> Result<Rule, SelectorError> {
238 let inner_rule = parse_complex_selector(input)?;
239 Ok(Rule::Not(Box::new(ops::Not::new(inner_rule))))
240}
241
242fn parse_nth_child_argument<'a, L: Language>(
244 input: &mut Input<'a, L>,
245 reverse: bool,
246) -> Result<Rule, SelectorError> {
247 let text = input.extract_an_plus_b();
248 let mut nth_child = NthChild::try_parse(text, reverse)?;
249 if let Some(Token::Identifier("of")) = input.peek()? {
250 input.next()?; input.consume_whitespace();
252 nth_child = nth_child.of_rule(parse_complex_selector(input)?);
253 }
254 Ok(Rule::NthChild(nth_child))
255}
256
257#[derive(Debug, Error)]
258pub enum SelectorError {
259 #[error("Illegal character {0} encountered")]
260 IllegalCharacter(char),
261 #[error("Unexpected token")]
262 UnexpectedToken,
263 #[error("Missing Selector")]
264 MissingSelector,
265 #[error("Invalid Kind")]
266 InvalidKind(#[from] KindMatcherError),
267 #[error("{0} is not supported yet")]
268 Unsupported(&'static str),
269 #[error("Expected '(' after pseudo-class")]
270 ExpectedLeftParen,
271 #[error("Expected ')' to close pseudo-class")]
272 ExpectedRightParen,
273 #[error("Unknown pseudo-class '{0}'")]
274 UnknownPseudoClass(String),
275 #[error("Invalid nth-child")]
276 InvalidNthChild(#[from] NthChildError),
277}
278
279struct Input<'a, L: Language> {
280 source: &'a str,
281 lookahead: Option<Token<'a>>,
282 language: L,
283}
284
285impl<'a, L: Language> Input<'a, L> {
286 fn new(source: &'a str, language: L) -> Self {
287 Self {
288 source: source.trim(),
289 lookahead: None,
290 language,
291 }
292 }
293
294 fn is_empty(&self) -> bool {
295 self.source.is_empty() && self.lookahead.is_none()
296 }
297
298 fn consume_whitespace(&mut self) {
299 self.source = self.source.trim_start();
300 }
301
302 fn extract_an_plus_b(&mut self) -> &'a str {
304 debug_assert!(self.lookahead.is_none());
305 let len = self
306 .source
307 .find(|c: char| !matches!(c, '0'..='9' | 'n' | 'N' | '+' | '-' | ' '))
308 .unwrap_or(self.source.len());
309 let text = self.source[..len].trim();
310 self.source = &self.source[len..];
311 self.consume_whitespace();
312 text
313 }
314
315 fn do_next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
316 if self.source.is_empty() {
317 return Ok(None);
318 }
319 let (next_token, step, need_trim) = match self.source.as_bytes()[0] as char {
320 ' ' => {
321 let len = self
322 .source
323 .find(|c: char| !c.is_whitespace())
324 .unwrap_or(self.source.len());
325 if self.source.len() > len
326 && matches!(
327 self.source.as_bytes()[len] as char,
328 '+' | '~' | '>' | ')' | ','
329 )
330 {
331 self.consume_whitespace();
332 return self.do_next(); }
334 (Token::Combinator(' '), len, true)
335 }
336 c @ ('+' | '~' | '>') => (Token::Combinator(c), 1, true),
337 '.' => (Token::ClassDot, 1, false),
338 ':' => (Token::PseudoColon, 1, false),
339 '(' => (Token::LeftParen, 1, true),
340 ')' => (Token::RightParen, 1, false),
341 ',' => (Token::Comma, 1, true),
342 'a'..='z' | 'A'..='Z' | '_' | '-' => {
343 let len = self
344 .source
345 .find(|c| !matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '-' | '0'..='9'))
346 .unwrap_or(self.source.len());
347 let ident = &self.source[..len];
348 (Token::Identifier(ident), len, false)
349 }
350 c => {
351 return Err(SelectorError::IllegalCharacter(c));
352 }
353 };
354 self.source = &self.source[step..];
355 if need_trim {
356 self.consume_whitespace();
357 }
358 Ok(Some(next_token))
359 }
360
361 fn next(&mut self) -> Result<Option<Token<'a>>, SelectorError> {
362 if let Some(token) = self.lookahead.take() {
363 Ok(Some(token))
364 } else {
365 self.do_next()
366 }
367 }
368
369 fn peek(&mut self) -> Result<&Option<Token<'a>>, SelectorError> {
370 if self.lookahead.is_some() {
371 return Ok(&self.lookahead);
372 }
373 let next_token = self.do_next()?;
374 self.lookahead = next_token;
375 Ok(&self.lookahead)
376 }
377}
378
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::TypeScript as TS;
    use ast_grep_core::tree_sitter::LanguageExt;

    /// Runs the tokenizer over `input` to exhaustion, failing on the first lexing error.
    fn input_to_tokens(input: &str) -> Result<Vec<Token<'_>>, SelectorError> {
        let mut lexer = Input::new(input, TS::Tsx);
        std::iter::from_fn(|| lexer.next().transpose()).collect()
    }

    /// Parses `selector`, runs it over the TSX snippet `code`, and returns the
    /// text of every match in the order `find_all` yields them.
    fn matched_texts(selector: &str, code: &str) -> Result<Vec<String>, SelectorError> {
        let rule = parse_selector(selector, TS::Tsx)?;
        let grep = TS::Tsx.ast_grep(code);
        let texts = grep
            .root()
            .find_all(&rule)
            .map(|node| node.text().to_string())
            .collect();
        Ok(texts)
    }

    #[test]
    fn test_valid_tokens() -> Result<(), SelectorError> {
        let expected = [
            Token::Identifier("call_expression"),
            Token::Combinator('+'),
            Token::Identifier("statement"),
            Token::Combinator('>'),
            Token::ClassDot,
            Token::Identifier("body"),
            Token::Combinator(' '),
            Token::PseudoColon,
            Token::Identifier("has"),
            Token::Comma,
            Token::Identifier("identifier"),
        ];
        // Padding around the selector must not change the token stream.
        for selector in [
            "call_expression + statement > .body :has, identifier",
            " call_expression + statement > .body :has, identifier ",
        ] {
            assert_eq!(input_to_tokens(selector)?, expected);
        }
        Ok(())
    }

    #[test]
    fn test_illegal_character() {
        let mut lexer = Input::new("call_expression $ statement", TS::Tsx);
        assert_eq!(
            lexer.next().unwrap(),
            Some(Token::Identifier("call_expression"))
        );
        assert_eq!(lexer.next().unwrap(), Some(Token::Combinator(' ')));
        assert!(matches!(
            lexer.next(),
            Err(SelectorError::IllegalCharacter('$'))
        ));
    }

    #[test]
    fn test_edge_cases() -> Result<(), SelectorError> {
        // Empty input yields no tokens.
        let mut lexer = Input::new("", TS::Tsx);
        assert_eq!(lexer.next()?, None);

        // Surrounding whitespace is ignored.
        let mut lexer = Input::new(" call_expression ", TS::Tsx);
        assert_eq!(lexer.next()?, Some(Token::Identifier("call_expression")));
        assert_eq!(lexer.next()?, None);

        // An illegal character right after an identifier is still reported.
        let mut lexer = Input::new("call_expression$statement", TS::Tsx);
        assert_eq!(lexer.next()?, Some(Token::Identifier("call_expression")));
        assert!(matches!(
            lexer.next(),
            Err(SelectorError::IllegalCharacter('$'))
        ));

        // A long identifier is a single token.
        let mut lexer = Input::new("thisisaverylongidentifier", TS::Tsx);
        assert_eq!(
            lexer.next()?,
            Some(Token::Identifier("thisisaverylongidentifier"))
        );
        assert_eq!(lexer.next()?, None);
        Ok(())
    }

    #[test]
    fn test_parse_selector() -> Result<(), SelectorError> {
        let grep = TS::Tsx.ast_grep("test(123)");

        let child = parse_selector("call_expression > identifier", TS::Tsx)?;
        let ident = grep.root().find(&child).expect("Should find identifier");
        assert_eq!(ident.kind(), "identifier");
        assert_eq!(ident.text(), "test");

        // `123` is not a direct child of the call...
        let direct = parse_selector("call_expression > number", TS::Tsx)?;
        assert!(grep.root().find(&direct).is_none());

        // ...but it is a descendant.
        let descendant = parse_selector("call_expression number", TS::Tsx)?;
        let number = grep.root().find(&descendant).expect("Should find number");
        assert_eq!(number.text(), "123");
        Ok(())
    }

    #[test]
    fn test_identifier_with_number() -> Result<(), SelectorError> {
        assert_eq!(
            input_to_tokens("atx_h1_marker")?,
            [Token::Identifier("atx_h1_marker")]
        );
        Ok(())
    }

    #[test]
    fn test_has_tokens() -> Result<(), SelectorError> {
        assert_eq!(
            input_to_tokens("A:has(> B)")?,
            [
                Token::Identifier("A"),
                Token::PseudoColon,
                Token::Identifier("has"),
                Token::LeftParen,
                Token::Combinator('>'),
                Token::Identifier("B"),
                Token::RightParen,
            ]
        );
        Ok(())
    }

    #[test]
    fn test_has_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("function_declaration:has(return_statement)", TS::Tsx)?;

        let with_return = TS::Tsx.ast_grep("function foo() { return 1 }");
        let found = with_return.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "function_declaration");

        let without_return = TS::Tsx.ast_grep("function foo() { let x = 1 }");
        assert!(without_return.root().find(&rule).is_none());
        Ok(())
    }

    #[test]
    fn test_has_direct_child_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("expression_statement:has(> call_expression)", TS::Tsx)?;
        let grep = TS::Tsx.ast_grep("foo()");
        let found = grep.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "expression_statement");
        Ok(())
    }

    #[test]
    fn test_has_with_whitespace() -> Result<(), SelectorError> {
        let rule = parse_selector("function_declaration:has( return_statement )", TS::Tsx)?;
        let grep = TS::Tsx.ast_grep("function foo() { return 1 }");
        assert!(grep.root().find(&rule).is_some());
        Ok(())
    }

    #[test]
    fn test_has_error_cases() {
        let unknown = parse_selector("expression_statement:first-child(identifier)", TS::Tsx);
        assert!(matches!(unknown, Err(SelectorError::UnknownPseudoClass(_))));

        let missing_open = parse_selector("expression_statement:has identifier", TS::Tsx);
        assert!(matches!(missing_open, Err(SelectorError::ExpectedLeftParen)));

        let missing_close = parse_selector("expression_statement:has(identifier", TS::Tsx);
        assert!(matches!(missing_close, Err(SelectorError::ExpectedRightParen)));
    }

    #[test]
    fn test_not_selector() -> Result<(), SelectorError> {
        let rule = parse_selector("identifier:not(number)", TS::Tsx)?;
        let grep = TS::Tsx.ast_grep("test(123)");
        let found = grep.root().find(&rule).expect("should find");
        assert_eq!(found.kind(), "identifier");
        assert_eq!(found.text(), "test");
        Ok(())
    }

    #[test]
    fn test_not_selector_excludes() -> Result<(), SelectorError> {
        let rule = parse_selector("number:not(number)", TS::Tsx)?;
        let grep = TS::Tsx.ast_grep("test(123)");
        assert!(grep.root().find(&rule).is_none());
        Ok(())
    }

    #[test]
    fn test_is_selector() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts(":is(identifier, number)", "test(123)")?,
            ["test", "123"]
        );
        Ok(())
    }

    #[test]
    fn test_is_selector_in_combinator() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("call_expression > :is(identifier, number)", "test(123)")?,
            ["test"]
        );
        Ok(())
    }

    #[test]
    fn test_nth_child_selector() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("array > number:nth-child(2n+1)", "[1, 2, 3, 4, 5]")?,
            ["1", "3", "5"]
        );
        Ok(())
    }

    #[test]
    fn test_nth_child_selector_with_whitespace() -> Result<(), SelectorError> {
        let texts = matched_texts("array > number:nth-child( 2n + 1 )", "[1, 2, 3, 4, 5]")?;
        assert_eq!(texts.len(), 3);
        Ok(())
    }

    #[test]
    fn test_nth_child_negative_an_plus_b() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("array > number:nth-child(-n + 3)", "[1, 2, 3, 4, 5]")?,
            ["1", "2", "3"]
        );
        Ok(())
    }

    #[test]
    fn test_nth_child_of_selector() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("array > :nth-child(1 of number)", "[a, 1, 2, 3]")?,
            ["1"]
        );
        Ok(())
    }

    #[test]
    fn test_nth_child_of_complex_selector() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("array > :nth-child(2n+1 of number)", "[a, 1, 2, 3]")?,
            ["1", "3"]
        );
        Ok(())
    }

    #[test]
    fn test_nth_last_child_selector() -> Result<(), SelectorError> {
        assert_eq!(
            matched_texts("array > number:nth-last-child(1)", "[1, 2, 3, 4, 5]")?,
            ["5"]
        );
        Ok(())
    }
}