1use nom::branch::alt;
105use nom::combinator::map;
106use nom::combinator::opt;
107use nom::error::make_error;
108use nom::error::ErrorKind;
109use nom::multi::many0;
110use nom::Parser;
111use nom::{IResult, Needed};
112use pratt::Affix;
113use pratt::Associativity;
114use pratt::PrattError;
115use pratt::PrattParser;
116use pratt::Precedence;
117use proc_macro2::Group;
118use proc_macro2::Ident;
119use proc_macro2::Literal;
120use proc_macro2::Punct;
121use proc_macro2::Spacing;
122use proc_macro2::Span;
123use proc_macro2::TokenStream;
124use proc_macro2::TokenTree;
125use quote::quote;
126use quote::ToTokens;
127use quote::TokenStreamExt;
128use std::iter::{Cloned, Enumerate};
129use std::ops::Deref;
130use std::slice::Iter;
131use syn::punctuated::Punctuated;
132use syn::Token;
133
134#[proc_macro]
139pub fn rule(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
140 match expand_rule(tokens.into()) {
141 Ok(tokens) => tokens.into(),
142 Err(error) => error.to_compile_error().into(),
143 }
144}
145
146fn expand_rule(tokens: TokenStream) -> syn::Result<TokenStream> {
147 let i: Vec<TokenTree> = tokens.into_iter().collect();
148
149 let (i, terminals) = if let Ok((rest, (match_text, _, match_token, _))) =
151 (path, match_punct(','), path, match_punct(',')).parse(Input(&i))
152 {
153 (
154 rest,
155 Some(CustomTerminal {
156 match_text: match_text.1,
157 match_token: match_token.1,
158 }),
159 )
160 } else {
161 (Input(i.as_slice()), None)
162 };
163
164 let terminal = terminals.unwrap_or_else(|| CustomTerminal {
165 match_text: Path {
166 segments: vec![Ident::new("match_text", Span::call_site())],
167 },
168 match_token: Path {
169 segments: vec![Ident::new("match_token", Span::call_site())],
170 },
171 });
172
173 let rule = parse_rule(i.iter().cloned().collect())?;
174 rule.check_return_type()?;
175 Ok(rule.to_token_stream(&terminal))
176}
177
178#[derive(Debug, Clone)]
179struct Path {
180 segments: Vec<Ident>,
181}
182
183#[derive(Debug, Clone)]
184enum Rule {
185 MatchText(Span, Literal),
186 MatchToken(Span, Path),
187 ExternalFunction(Span, Path, Option<Group>),
188 Context(Span, Literal, Box<Rule>),
189 Peek(Span, Box<Rule>),
190 Not(Span, Box<Rule>),
191 Optional(Span, Box<Rule>),
192 Cut(Span, Box<Rule>),
193 Many0(Span, Box<Rule>),
194 Many1(Span, Box<Rule>),
195 Sequence(Span, Vec<Rule>),
196 Choice(Span, Vec<Rule>),
197}
198
199#[derive(Debug, Clone)]
200enum RuleElement {
201 MatchText(Literal),
202 MatchToken(Path),
203 ExternalFunction(Path, Option<Group>),
204 Context(Literal),
205 Peek,
206 Not,
207 Optional,
208 Cut,
209 Many0,
210 Many1,
211 Sequence,
212 Choice,
213 SubRule(Group),
214}
215
216#[derive(Debug, Clone)]
217struct WithSpan {
218 elem: RuleElement,
219 span: Span,
220}
221
/// Coarse model of a rule's output type, used to validate the branches of a choice.
#[derive(Clone, Debug)]
enum ReturnType {
    /// `Option<T>` of the wrapped type.
    Option(Box<ReturnType>),
    /// `Vec<T>` of the wrapped type.
    Vec(Box<ReturnType>),
    /// The unit type `()`.
    Unit,
    /// A type the macro cannot determine.
    Unknown,
}
229
230struct CustomTerminal {
231 match_text: Path,
232 match_token: Path,
233}
234
235#[derive(Debug, Clone)]
236struct Input<'a>(&'a [TokenTree]);
237
238impl Deref for Input<'_> {
239 type Target = [TokenTree];
240
241 fn deref(&self) -> &Self::Target {
242 self.0
243 }
244}
245
246impl<'a> nom::Input for Input<'a> {
247 type Item = TokenTree;
248 type Iter = Cloned<Iter<'a, TokenTree>>;
249 type IterIndices = Enumerate<Self::Iter>;
250
251 fn input_len(&self) -> usize {
252 self.0.len()
253 }
254
255 fn take(&self, index: usize) -> Self {
256 Input(&self.0[0..index])
257 }
258
259 fn take_from(&self, index: usize) -> Self {
260 Input(&self.0[index..])
261 }
262
263 fn take_split(&self, index: usize) -> (Self, Self) {
264 let (prefix, suffix) = self.0.split_at(index);
265 (Input(suffix), Input(prefix))
266 }
267
268 fn position<P>(&self, predicate: P) -> Option<usize>
269 where
270 P: Fn(Self::Item) -> bool,
271 {
272 self.iter().position(|b| predicate(b.clone()))
273 }
274
275 fn iter_elements(&self) -> Self::Iter {
276 self.0.iter().cloned()
277 }
278
279 fn iter_indices(&self) -> Self::IterIndices {
280 self.iter_elements().enumerate()
281 }
282
283 fn slice_index(&self, count: usize) -> Result<usize, Needed> {
284 if self.len() >= count {
285 Ok(count)
286 } else {
287 Err(Needed::new(count - self.len()))
288 }
289 }
290}
291
292fn match_punct<'a>(punct: char) -> impl FnMut(Input<'a>) -> IResult<Input<'a>, TokenTree> {
293 move |i| match i.first().and_then(|token| match token {
294 TokenTree::Punct(p) if p.as_char() == punct => Some(token.clone()),
295 _ => None,
296 }) {
297 Some(token) => Ok((Input(&i.0[1..]), token)),
298 _ => Err(nom::Err::Error(make_error(i, ErrorKind::Satisfy))),
299 }
300}
301
302fn group(i: Input) -> IResult<Input, Group> {
303 match i.first().and_then(|token| match token {
304 TokenTree::Group(group) => Some(group.clone()),
305 _ => None,
306 }) {
307 Some(group) => Ok((Input(&i.0[1..]), group)),
308 _ => Err(nom::Err::Error(make_error(i, ErrorKind::Satisfy))),
309 }
310}
311
312fn literal(i: Input) -> IResult<Input, Literal> {
313 match i.first().and_then(|token| match token {
314 TokenTree::Literal(lit) => Some(lit.clone()),
315 _ => None,
316 }) {
317 Some(lit) => Ok((Input(&i.0[1..]), lit)),
318 _ => Err(nom::Err::Error(make_error(i, ErrorKind::Satisfy))),
319 }
320}
321
322fn ident(i: Input) -> IResult<Input, Ident> {
323 match i.first().and_then(|token| match token {
324 TokenTree::Ident(ident) => Some(ident.clone()),
325 _ => None,
326 }) {
327 Some(ident) => Ok((Input(&i.0[1..]), ident)),
328 _ => Err(nom::Err::Error(make_error(i, ErrorKind::Satisfy))),
329 }
330}
331
332fn path(i: Input) -> IResult<Input, (Span, Path)> {
333 map(
334 (ident, many0((match_punct(':'), match_punct(':'), ident))),
335 |(head, tail)| {
336 let mut segments = vec![head.clone()];
337 segments.extend(tail.into_iter().map(|(_, _, segment)| segment));
338 let span = segments
339 .iter()
340 .try_fold(head.span(), |span, seg| span.join(seg.span()))
341 .unwrap_or(Span::call_site());
342 (span, Path { segments })
343 },
344 )
345 .parse(i)
346}
347
348fn parse_rule(tokens: TokenStream) -> syn::Result<Rule> {
349 let i: Vec<TokenTree> = tokens.into_iter().collect();
350
351 let (i, elems) = many0(parse_rule_element)
352 .parse(Input(&i))
353 .map_err(nom_error_to_syn)?;
354 if !i.is_empty() {
355 let rest: TokenStream = i.iter().cloned().collect();
356 return Err(syn::Error::new_spanned(
357 rest,
358 "unable to parse the following rules",
359 ));
360 }
361
362 let mut iter = elems.into_iter().peekable();
363 let rule = RuleParser.parse(&mut iter).map_err(pratt_error_to_syn)?;
364 if iter.peek().is_some() {
365 let rest: Vec<_> = iter.collect();
366 return Err(syn::Error::new(
367 rest[0].span,
368 format!("unable to parse the following rules: {rest:?}"),
369 ));
370 }
371
372 Ok(rule)
373}
374
375fn nom_error_to_syn(error: nom::Err<nom::error::Error<Input<'_>>>) -> syn::Error {
376 match error {
377 nom::Err::Error(error) | nom::Err::Failure(error) => {
378 let tokens: TokenStream = error.input.iter().cloned().collect();
379 if tokens.is_empty() {
380 syn::Error::new(Span::call_site(), "unable to parse rule")
381 } else {
382 syn::Error::new_spanned(tokens, "unable to parse rule")
383 }
384 }
385 nom::Err::Incomplete(_) => syn::Error::new(Span::call_site(), "incomplete rule"),
386 }
387}
388
389fn pratt_error_to_syn(error: PrattError<WithSpan, syn::Error>) -> syn::Error {
390 match error {
391 PrattError::UserError(error) => error,
392 PrattError::EmptyInput => {
393 syn::Error::new(Span::call_site(), "expected more tokens for rule")
394 }
395 PrattError::UnexpectedNilfix(input) => {
396 syn::Error::new(input.span, "unable to parse the value")
397 }
398 PrattError::UnexpectedPrefix(input) => {
399 syn::Error::new(input.span, "unable to parse the prefix operator")
400 }
401 PrattError::UnexpectedInfix(input) => {
402 syn::Error::new(input.span, "unable to parse the binary operator")
403 }
404 PrattError::UnexpectedPostfix(input) => {
405 syn::Error::new(input.span, "unable to parse the postfix operator")
406 }
407 }
408}
409
410fn parse_rule_element(i: Input) -> IResult<Input, WithSpan> {
411 let function_call = |i| {
412 let (i, hashtag) = match_punct('#')(i)?;
413 let (i, (path_span, fn_path)) = path(i)?;
414 let (i, args) = opt(group).parse(i)?;
415 let span = hashtag.span().join(path_span).unwrap_or(Span::call_site());
416 let span = args
417 .as_ref()
418 .and_then(|args| args.span().join(span))
419 .unwrap_or(span);
420
421 Ok((
422 i,
423 WithSpan {
424 elem: RuleElement::ExternalFunction(fn_path, args),
425 span,
426 },
427 ))
428 };
429 let context = map((match_punct(':'), literal), |(colon, msg)| {
430 let span = colon.span().join(msg.span()).unwrap_or(Span::call_site());
431 WithSpan {
432 elem: RuleElement::Context(msg),
433 span,
434 }
435 });
436 alt((
437 map(match_punct('|'), |token| WithSpan {
438 span: token.span(),
439 elem: RuleElement::Choice,
440 }),
441 map(match_punct('*'), |token| WithSpan {
442 span: token.span(),
443 elem: RuleElement::Many0,
444 }),
445 map(match_punct('+'), |token| WithSpan {
446 span: token.span(),
447 elem: RuleElement::Many1,
448 }),
449 map(match_punct('?'), |token| WithSpan {
450 span: token.span(),
451 elem: RuleElement::Optional,
452 }),
453 map(match_punct('^'), |token| WithSpan {
454 span: token.span(),
455 elem: RuleElement::Cut,
456 }),
457 map(match_punct('&'), |token| WithSpan {
458 span: token.span(),
459 elem: RuleElement::Peek,
460 }),
461 map(match_punct('!'), |token| WithSpan {
462 span: token.span(),
463 elem: RuleElement::Not,
464 }),
465 map(match_punct('~'), |token| WithSpan {
466 span: token.span(),
467 elem: RuleElement::Sequence,
468 }),
469 map(literal, |lit| WithSpan {
470 span: lit.span(),
471 elem: RuleElement::MatchText(lit),
472 }),
473 map(path, |(span, p)| WithSpan {
474 span,
475 elem: RuleElement::MatchToken(p),
476 }),
477 map(group, |group| WithSpan {
478 span: group.span(),
479 elem: RuleElement::SubRule(group),
480 }),
481 function_call,
482 context,
483 ))
484 .parse(i)
485}
486
487struct RuleParser;
488
489impl<I: Iterator<Item = WithSpan>> PrattParser<I> for RuleParser {
490 type Error = syn::Error;
491 type Input = WithSpan;
492 type Output = Rule;
493
494 fn query(&mut self, elem: &WithSpan) -> Result<Affix, syn::Error> {
495 let affix = match elem.elem {
496 RuleElement::Choice => Affix::Infix(Precedence(1), Associativity::Left),
497 RuleElement::Context(_) => Affix::Postfix(Precedence(2)),
498 RuleElement::Sequence => Affix::Infix(Precedence(3), Associativity::Left),
499 RuleElement::Optional => Affix::Postfix(Precedence(4)),
500 RuleElement::Many1 => Affix::Postfix(Precedence(4)),
501 RuleElement::Many0 => Affix::Postfix(Precedence(4)),
502 RuleElement::Cut => Affix::Prefix(Precedence(5)),
503 RuleElement::Peek => Affix::Prefix(Precedence(5)),
504 RuleElement::Not => Affix::Prefix(Precedence(5)),
505 _ => Affix::Nilfix,
506 };
507 Ok(affix)
508 }
509
510 fn primary(&mut self, elem: WithSpan) -> Result<Rule, syn::Error> {
511 let rule = match elem.elem {
512 RuleElement::SubRule(group) => {
513 if group.stream().is_empty() {
514 return Err(syn::Error::new(
515 group.span(),
516 "expected more tokens for rule",
517 ));
518 }
519 parse_rule(group.stream())?
520 }
521 RuleElement::MatchText(text) => Rule::MatchText(elem.span, text),
522 RuleElement::MatchToken(token) => Rule::MatchToken(elem.span, token),
523 RuleElement::ExternalFunction(func, args) => {
524 Rule::ExternalFunction(elem.span, func, args)
525 }
526 _ => unreachable!(),
527 };
528 Ok(rule)
529 }
530
531 fn infix(&mut self, lhs: Rule, elem: WithSpan, rhs: Rule) -> Result<Rule, syn::Error> {
532 let rule = match elem.elem {
533 RuleElement::Sequence => match lhs {
534 Rule::Sequence(span, mut seq) => {
535 let span = span
536 .join(elem.span)
537 .unwrap_or(Span::call_site())
538 .join(rhs.span())
539 .unwrap_or(Span::call_site());
540 seq.push(rhs);
541 Rule::Sequence(span, seq)
542 }
543 lhs => {
544 let span = lhs.span().join(rhs.span()).unwrap_or(Span::call_site());
545 Rule::Sequence(span, vec![lhs, rhs])
546 }
547 },
548 RuleElement::Choice => match lhs {
549 Rule::Choice(span, mut choices) => {
550 let span = span
551 .join(elem.span)
552 .unwrap_or(Span::call_site())
553 .join(rhs.span())
554 .unwrap_or(Span::call_site());
555 choices.push(rhs);
556 Rule::Choice(span, choices)
557 }
558 lhs => {
559 let span = lhs.span().join(rhs.span()).unwrap_or(Span::call_site());
560 Rule::Choice(span, vec![lhs, rhs])
561 }
562 },
563 _ => unreachable!(),
564 };
565 Ok(rule)
566 }
567
568 fn prefix(&mut self, elem: WithSpan, rhs: Rule) -> Result<Rule, syn::Error> {
569 let rule = match elem.elem {
570 RuleElement::Cut => {
571 let span = elem.span.join(rhs.span()).unwrap_or(Span::call_site());
572 Rule::Cut(span, Box::new(rhs))
573 }
574 RuleElement::Peek => {
575 let span = elem.span.join(rhs.span()).unwrap_or(Span::call_site());
576 Rule::Peek(span, Box::new(rhs))
577 }
578 RuleElement::Not => {
579 let span = elem.span.join(rhs.span()).unwrap_or(Span::call_site());
580 Rule::Not(span, Box::new(rhs))
581 }
582 _ => unreachable!(),
583 };
584 Ok(rule)
585 }
586
587 fn postfix(&mut self, lhs: Rule, elem: WithSpan) -> Result<Rule, syn::Error> {
588 let rule = match elem.elem {
589 RuleElement::Optional => {
590 let span = lhs.span().join(elem.span).unwrap_or(Span::call_site());
591 Rule::Optional(span, Box::new(lhs))
592 }
593 RuleElement::Many0 => {
594 let span = lhs.span().join(elem.span).unwrap_or(Span::call_site());
595 Rule::Many0(span, Box::new(lhs))
596 }
597 RuleElement::Many1 => {
598 let span = lhs.span().join(elem.span).unwrap_or(Span::call_site());
599 Rule::Many1(span, Box::new(lhs))
600 }
601 RuleElement::Context(msg) => {
602 let span = lhs.span().join(elem.span).unwrap_or(Span::call_site());
603 Rule::Context(span, msg, Box::new(lhs))
604 }
605 _ => unreachable!(),
606 };
607 Ok(rule)
608 }
609}
610
611impl std::fmt::Display for ReturnType {
612 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
613 match self {
614 ReturnType::Option(ty) => write!(f, "Option<{}>", ty),
615 ReturnType::Vec(ty) => write!(f, "Vec<{}>", ty),
616 ReturnType::Unit => write!(f, "()"),
617 ReturnType::Unknown => write!(f, "_"),
618 }
619 }
620}
621
622impl PartialEq for ReturnType {
623 fn eq(&self, other: &ReturnType) -> bool {
624 match (self, other) {
625 (ReturnType::Option(lhs), ReturnType::Option(rhs)) => lhs == rhs,
626 (ReturnType::Vec(lhs), ReturnType::Vec(rhs)) => lhs == rhs,
627 (ReturnType::Unit, ReturnType::Unit) => true,
628 (ReturnType::Unknown, _) => true,
629 (_, ReturnType::Unknown) => true,
630 _ => false,
631 }
632 }
633}
634
635impl Rule {
636 fn check_return_type(&self) -> syn::Result<ReturnType> {
637 let ty = match self {
638 Rule::MatchText(_, _) | Rule::MatchToken(_, _) | Rule::ExternalFunction(_, _, _) => {
639 ReturnType::Unknown
640 }
641 Rule::Context(_, _, rule) | Rule::Peek(_, rule) => rule.check_return_type()?,
642 Rule::Not(_, _) => ReturnType::Unit,
643 Rule::Optional(_, rule) => ReturnType::Option(Box::new(rule.check_return_type()?)),
644 Rule::Cut(_, rule) => rule.check_return_type()?,
645 Rule::Many0(_, rule) | Rule::Many1(_, rule) => {
646 ReturnType::Vec(Box::new(rule.check_return_type()?))
647 }
648 Rule::Sequence(_, rules) => {
649 for rule in rules {
650 rule.check_return_type()?;
651 }
652 ReturnType::Vec(Box::new(ReturnType::Unknown))
653 }
654 Rule::Choice(_, rules) => {
655 for slice in rules.windows(2) {
656 match (slice[0].check_return_type()?, slice[1].check_return_type()?) {
657 (ReturnType::Option(_), _) => {
658 return Err(syn::Error::new(
659 slice[0].span(),
660 "optional shouldn't be in a choice because it will shortcut the following branches",
661 ));
662 }
663 (a, b) if a != b => {
664 return Err(syn::Error::new(
665 slice[0]
666 .span()
667 .join(slice[1].span())
668 .unwrap_or(Span::call_site()),
669 format!("type mismatched between {a:} and {b:}"),
670 ));
671 }
672 _ => (),
673 }
674 }
675 ReturnType::Vec(Box::new(rules[0].check_return_type()?))
676 }
677 };
678
679 Ok(ty)
680 }
681
682 fn span(&self) -> Span {
683 match self {
684 Rule::MatchText(span, _)
685 | Rule::MatchToken(span, _)
686 | Rule::ExternalFunction(span, _, _)
687 | Rule::Context(span, _, _)
688 | Rule::Peek(span, _)
689 | Rule::Not(span, _)
690 | Rule::Optional(span, _)
691 | Rule::Cut(span, _)
692 | Rule::Many0(span, _)
693 | Rule::Many1(span, _)
694 | Rule::Sequence(span, _)
695 | Rule::Choice(span, _) => *span,
696 }
697 }
698
699 fn to_tokens(&self, terminal: &CustomTerminal, tokens: &mut TokenStream) {
700 let token = match self {
701 Rule::MatchText(_, text) => {
702 let match_text = &terminal.match_text;
703 quote! { #match_text (#text) }
704 }
705 Rule::MatchToken(_, token) => {
706 let match_token = &terminal.match_token;
707 quote! { #match_token (#token) }
708 }
709 Rule::ExternalFunction(_, name, arg) => {
710 quote! { #name #arg }
711 }
712 Rule::Context(_, msg, rule) => {
713 let rule = rule.to_token_stream(terminal);
714 quote! { nom::error::context(#msg, #rule) }
715 }
716 Rule::Peek(_, rule) => {
717 let rule = rule.to_token_stream(terminal);
718 quote! { nom::combinator::peek(#rule) }
719 }
720 Rule::Not(_, rule) => {
721 let rule = rule.to_token_stream(terminal);
722 quote! { nom::combinator::not(#rule) }
723 }
724 Rule::Optional(_, rule) => {
725 let rule = rule.to_token_stream(terminal);
726 quote! { nom::combinator::opt(#rule) }
727 }
728 Rule::Cut(_, rule) => {
729 let rule = rule.to_token_stream(terminal);
730 quote! { nom::combinator::cut(#rule) }
731 }
732 Rule::Many0(_, rule) => {
733 let rule = rule.to_token_stream(terminal);
734 quote! { nom::multi::many0(#rule) }
735 }
736 Rule::Many1(_, rule) => {
737 let rule = rule.to_token_stream(terminal);
738 quote! { nom::multi::many1(#rule) }
739 }
740 Rule::Sequence(_, rules) => {
741 let list: Punctuated<TokenStream, Token![,]> = rules
742 .iter()
743 .map(|rule| rule.to_token_stream(terminal))
744 .collect();
745 quote! { (#list) }
746 }
747 Rule::Choice(_, rules) => {
748 let list: Punctuated<TokenStream, Token![,]> = rules
749 .iter()
750 .map(|rule| rule.to_token_stream(terminal))
751 .collect();
752 quote! { nom::branch::alt((#list)) }
753 }
754 };
755
756 tokens.extend(token);
757 }
758
759 fn to_token_stream(&self, terminal: &CustomTerminal) -> TokenStream {
760 let mut tokens = TokenStream::new();
761 self.to_tokens(terminal, &mut tokens);
762 tokens
763 }
764}
765
766impl ToTokens for Path {
767 fn to_tokens(&self, tokens: &mut TokenStream) {
768 for (i, segment) in self.segments.iter().enumerate() {
769 if i > 0 {
770 tokens.append(Punct::new(':', Spacing::Joint));
772 tokens.append(Punct::new(':', Spacing::Alone));
773 }
774 segment.to_tokens(tokens);
775 }
776 }
777}