syn_match/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::Expr;
7use syn::Ident;
8use syn::Result;
9use syn::Token;
10use syn::parse::Parse;
11use syn::parse::ParseStream;
12use syn::parse_macro_input;
13use syn::punctuated::Punctuated;
14
15enum PathPattern {
16  Path {
17    segments: Vec<SegmentPattern>,
18    fully_qualified: bool,
19  },
20  Wildcard,
21}
22
23enum SegmentPattern {
24  Required(SegmentMatcher),
25  Optional(SegmentMatcher),
26  Binding {
27    span: Span,
28    name: Ident,
29    multi: bool,
30  },
31}
32
33struct SegmentMatcher {
34  ident: Ident,
35  args: Option<ArgumentPattern>,
36}
37
38enum ArgumentPattern {
39  AngleBracketed(Vec<GenericArgumentPattern>),
40}
41
42enum GenericArgumentPattern {
43  Argument(Box<SegmentMatcher>),
44  Wildcard,
45  Binding(Span, Ident),
46}
47
48struct PathMatchArms {
49  path: Expr,
50  arms: Vec<MatchArm>,
51}
52
53struct MatchArm {
54  patterns: Punctuated<PathPattern, Token![|]>,
55  _fat_arrow: Token![=>],
56  body: Expr,
57  _comma: Option<Token![,]>,
58}
59
60impl Parse for PathMatchArms {
61  fn parse(input: ParseStream) -> Result<Self> {
62    let path = input.parse()?;
63    input.parse::<Token![,]>()?;
64
65    let mut arms = Vec::new();
66    while !input.is_empty() {
67      arms.push(input.parse()?);
68    }
69
70    Ok(PathMatchArms { path, arms })
71  }
72}
73
74impl Parse for MatchArm {
75  fn parse(input: ParseStream) -> Result<Self> {
76    let patterns = Punctuated::parse_separated_nonempty(input)?;
77    let _fat_arrow = input.parse()?;
78    let body = input.parse()?;
79    let _comma = input.parse().ok();
80
81    Ok(MatchArm {
82      patterns,
83      _fat_arrow,
84      body,
85      _comma,
86    })
87  }
88}
89
90impl Parse for PathPattern {
91  fn parse(input: ParseStream) -> Result<Self> {
92    if input.peek(Token![_]) {
93      input.parse::<Token![_]>()?;
94      return Ok(PathPattern::Wildcard);
95    }
96
97    let mut segments = Vec::new();
98    let mut fully_qualified = false;
99
100    if input.peek(Token![::]) {
101      input.parse::<Token![::]>()?;
102      fully_qualified = true;
103    }
104
105    loop {
106      if input.peek(Token![$]) {
107        input.parse::<Token![$]>()?;
108        let mut name: Ident = input.parse()?;
109        let multi = if input.peek(Token![+]) {
110          input.parse::<Token![+]>()?;
111          true
112        } else {
113          false
114        };
115
116        let span = name.span();
117        name.set_span(Span::call_site());
118
119        segments.push(SegmentPattern::Binding { span, name, multi });
120
121        if input.peek(Token![::]) {
122          input.parse::<Token![::]>()?;
123        } else {
124          break;
125        }
126        continue;
127      }
128
129      let ident: Ident = input.parse()?;
130
131      let args = if input.peek(Token![<]) {
132        Some(input.parse()?)
133      } else {
134        None
135      };
136
137      let optional = if input.peek(Token![?]) {
138        input.parse::<Token![?]>()?;
139        true
140      } else {
141        false
142      };
143
144      let matcher = SegmentMatcher { ident, args };
145
146      segments.push(if optional {
147        SegmentPattern::Optional(matcher)
148      } else {
149        SegmentPattern::Required(matcher)
150      });
151
152      if input.peek(Token![::]) {
153        input.parse::<Token![::]>()?;
154      } else {
155        break;
156      }
157    }
158
159    Ok(PathPattern::Path {
160      segments,
161      fully_qualified,
162    })
163  }
164}
165
166impl Parse for ArgumentPattern {
167  fn parse(input: ParseStream) -> Result<Self> {
168    input.parse::<Token![<]>()?;
169    let mut args = Vec::new();
170
171    loop {
172      if input.peek(Token![_]) {
173        input.parse::<Token![_]>()?;
174        args.push(GenericArgumentPattern::Wildcard);
175      } else if input.peek(Token![$]) {
176        input.parse::<Token![$]>()?;
177        let mut name: Ident = input.parse()?;
178        let span = name.span();
179        name.set_span(Span::call_site());
180        args.push(GenericArgumentPattern::Binding(span, name));
181      } else {
182        let ident: Ident = input.parse()?;
183
184        let matcher = SegmentMatcher {
185          ident,
186          args: if input.peek(Token![<]) {
187            Some(input.parse()?)
188          } else {
189            None
190          },
191        };
192        args.push(GenericArgumentPattern::Argument(Box::new(matcher)));
193      }
194
195      if input.peek(Token![,]) {
196        input.parse::<Token![,]>()?;
197      } else {
198        break;
199      }
200    }
201
202    input.parse::<Token![>]>()?;
203    Ok(ArgumentPattern::AngleBracketed(args))
204  }
205}
206
207struct PathMatcher {
208  path_expr: Expr,
209  fully_qualified_check: proc_macro2::TokenStream,
210  length_check: proc_macro2::TokenStream,
211  segment_checks: Vec<proc_macro2::TokenStream>,
212  binding_names: Vec<(Span, Ident)>,
213}
214
215fn generate_path_matcher(
216  path_expr: &Expr,
217  patterns: &Punctuated<PathPattern, Token![|]>,
218) -> Vec<PathMatcher> {
219  let mut out = Vec::new();
220
221  for pattern in patterns {
222    match pattern {
223      PathPattern::Wildcard => {}
224      PathPattern::Path {
225        segments,
226        fully_qualified,
227      } => {
228        let mut binding_names = Vec::new();
229        for seg in segments {
230          match seg {
231            SegmentPattern::Binding { span, name, .. } => {
232              binding_names.push((span.clone(), name.clone()));
233            }
234            SegmentPattern::Required(matcher)
235            | SegmentPattern::Optional(matcher) => {
236              fn handle_segment_matcher(
237                binding_names: &mut Vec<(Span, Ident)>,
238                matcher: &SegmentMatcher,
239              ) {
240                if let Some(ArgumentPattern::AngleBracketed(args)) =
241                  &matcher.args
242                {
243                  for arg in args {
244                    match arg {
245                      GenericArgumentPattern::Binding(span, name) => {
246                        binding_names.push((span.clone(), name.clone()));
247                      }
248                      GenericArgumentPattern::Argument(arg) => {
249                        handle_segment_matcher(binding_names, &*arg);
250                      }
251                      GenericArgumentPattern::Wildcard => {}
252                    }
253                  }
254                }
255              }
256
257              handle_segment_matcher(&mut binding_names, matcher);
258            }
259          }
260        }
261
262        let mut required_segments = Vec::new();
263        let mut optional_segments = Vec::new();
264        let mut has_multi_binding = false;
265
266        for seg in segments {
267          match seg {
268            SegmentPattern::Required(matcher) => {
269              required_segments.push(matcher)
270            }
271            SegmentPattern::Optional(matcher) => {
272              optional_segments.push(matcher)
273            }
274            SegmentPattern::Binding { multi, .. } => {
275              if *multi {
276                has_multi_binding = true;
277              }
278            }
279          }
280        }
281
282        let min_len = segments
283          .iter()
284          .filter(|s| {
285            matches!(
286              s,
287              SegmentPattern::Required(_)
288                | SegmentPattern::Binding { multi: false, .. }
289            )
290          })
291          .count();
292
293        let max_len = if has_multi_binding {
294          None
295        } else {
296          Some(segments.len())
297        };
298
299        let mut segment_checks = Vec::new();
300
301        for (seg_idx, seg) in segments.iter().enumerate() {
302          match seg {
303            SegmentPattern::Binding { name, multi, .. } => {
304              if *multi {
305                let required_after = segments[seg_idx + 1..]
306                  .iter()
307                  .filter(|s| {
308                    matches!(
309                      s,
310                      SegmentPattern::Required(_)
311                        | SegmentPattern::Binding { multi: false, .. }
312                    )
313                  })
314                  .count();
315
316                let check = quote! {
317                  let __end_idx = __segments.len() - #required_after;
318                  if __idx > __end_idx {
319                    break false;
320                  }
321                  #name = Some(__segments.iter().skip(__idx).take(__end_idx - __idx).cloned().collect::<syn::punctuated::Punctuated<_, syn::Token![::]>>());
322                  __idx = __end_idx;
323                };
324                segment_checks.push(check);
325              } else {
326                let check = quote! {
327                  if __idx >= __segments.len() {
328                    break false;
329                  }
330                  #name = Some(&__segments[__idx]);
331                  __idx += 1;
332                };
333                segment_checks.push(check);
334              }
335              continue;
336            }
337            _ => {}
338          }
339
340          let (matcher, is_optional) = match seg {
341            SegmentPattern::Required(m) => (m, false),
342            SegmentPattern::Optional(m) => (m, true),
343            SegmentPattern::Binding { .. } => unreachable!(),
344          };
345
346          fn handle_segment_matcher(
347            matcher: &SegmentMatcher,
348            is_optional: bool,
349          ) -> proc_macro2::TokenStream {
350            let seg_ident = &matcher.ident;
351            let seg_ident_str = seg_ident.to_string();
352
353            let name_check = quote! {
354                __seg.ident == #seg_ident_str
355            };
356
357            let args_check = if let Some(args) = &matcher.args {
358              match args {
359                ArgumentPattern::AngleBracketed(arg_patterns) => {
360                  let mut arg_checks = Vec::new();
361                  for (arg_idx, arg_pattern) in arg_patterns.iter().enumerate()
362                  {
363                    match arg_pattern {
364                      GenericArgumentPattern::Wildcard => {}
365                      GenericArgumentPattern::Argument(segment_matcher) => {
366                        fn generate_nested_arg_check(
367                          segment_matcher: &SegmentMatcher,
368                          arg_var: &str,
369                          depth: usize,
370                        ) -> proc_macro2::TokenStream {
371                          let arg_ident = &segment_matcher.ident;
372                          let arg_ident_str = arg_ident.to_string();
373                          let arg_var_ident =
374                            Ident::new(arg_var, Span::call_site());
375
376                          if let Some(ArgumentPattern::AngleBracketed(
377                            nested_arg_patterns,
378                          )) = &segment_matcher.args
379                          {
380                            let mut nested_arg_checks = Vec::new();
381                            for (nested_arg_idx, nested_arg_pattern) in
382                              nested_arg_patterns.iter().enumerate()
383                            {
384                              match nested_arg_pattern {
385                                GenericArgumentPattern::Binding(
386                                  _span,
387                                  name,
388                                ) => {
389                                  let nested_arg_var =
390                                    format!("__nested_arg_{}", depth);
391                                  let nested_arg_var_ident = Ident::new(
392                                    &nested_arg_var,
393                                    Span::call_site(),
394                                  );
395                                  nested_arg_checks.push(quote! {
396                                    if let Some(#nested_arg_var_ident) = __nested_args.get(#nested_arg_idx) {
397                                      #name = Some(#nested_arg_var_ident);
398                                    } else {
399                                      break false;
400                                    }
401                                  });
402                                }
403                                GenericArgumentPattern::Wildcard => {}
404                                GenericArgumentPattern::Argument(
405                                  inner_segment_matcher,
406                                ) => {
407                                  let nested_arg_var =
408                                    format!("__nested_arg_{}", depth);
409                                  let inner_check = generate_nested_arg_check(
410                                    inner_segment_matcher,
411                                    &nested_arg_var,
412                                    depth + 1,
413                                  );
414                                  let nested_arg_var_ident = Ident::new(
415                                    &nested_arg_var,
416                                    Span::call_site(),
417                                  );
418                                  nested_arg_checks.push(quote! {
419                                    if let Some(#nested_arg_var_ident) = __nested_args.get(#nested_arg_idx) {
420                                      #inner_check
421                                    } else {
422                                      break false;
423                                    }
424                                  });
425                                }
426                              }
427                            }
428
429                            let nested_arg_count = nested_arg_patterns.len();
430                            quote! {
431                              if
432                                let syn::GenericArgument::Type(syn::Type::Path(__nested_type_path)) = #arg_var_ident
433                                  && let Some(__nested_seg) = __nested_type_path.path.segments.last()
434                                  && __nested_seg.ident == #arg_ident_str
435                                  && let syn::PathArguments::AngleBracketed(__nested_angle_args) = &__nested_seg.arguments
436                              {
437                                let __nested_args = &__nested_angle_args.args;
438                                if __nested_args.len() != #nested_arg_count {
439                                  break false;
440                                }
441                                #(#nested_arg_checks)*
442                              } else {
443                                break false;
444                              }
445                            }
446                          } else {
447                            quote! {
448                              if
449                                let syn::GenericArgument::Type(syn::Type::Path(__nested_type_path)) = #arg_var_ident
450                                  && let Some(__nested_seg) = __nested_type_path.path.segments.last()
451                                  && __nested_seg.ident == #arg_ident_str
452                              {
453                                //
454                              } else {
455                                break false;
456                              }
457                            }
458                          }
459                        }
460
461                        let nested_check = generate_nested_arg_check(
462                          segment_matcher,
463                          "__arg",
464                          0,
465                        );
466
467                        arg_checks.push(quote! {
468                          if let Some(__arg) = __args.get(#arg_idx) {
469                            #nested_check
470                          } else {
471                            break false;
472                          }
473                        });
474                      }
475                      GenericArgumentPattern::Binding(_span, name) => {
476                        arg_checks.push(quote! {
477                          if let Some(__arg) = __args.get(#arg_idx) {
478                            #name = Some(__arg);
479                          } else {
480                            break false;
481                          }
482                        });
483                      }
484                    }
485                  }
486
487                  let arg_count = arg_patterns.len();
488                  quote! {
489                    if let syn::PathArguments::AngleBracketed(__angle_args) = &__seg.arguments {
490                      let __args = &__angle_args.args;
491                      if __args.len() != #arg_count {
492                        break false;
493                      }
494                      #(#arg_checks)*
495                    } else {
496                      break false;
497                    }
498                  }
499                }
500              }
501            } else {
502              quote!()
503            };
504
505            let check = if is_optional {
506              quote! {
507                if __idx < __segments.len() {
508                  let __seg = &__segments[__idx];
509                  if #name_check {
510                    #args_check
511                    __idx += 1;
512                  }
513                }
514              }
515            } else {
516              quote! {
517                if __idx >= __segments.len() {
518                  break false;
519                }
520                let __seg = &__segments[__idx];
521                if !(#name_check) {
522                  break false;
523                }
524                #args_check
525                __idx += 1;
526              }
527            };
528
529            check
530          }
531
532          segment_checks.push(handle_segment_matcher(matcher, is_optional));
533        }
534
535        let fully_qualified_check = if *fully_qualified {
536          quote! {
537            if !__path.leading_colon.is_some() {
538              __matched = false;
539            }
540          }
541        } else {
542          quote!()
543        };
544
545        let length_check = match max_len {
546          Some(max) if min_len == max => {
547            quote! {
548              if __segments.len() != #min_len {
549                __matched = false;
550              }
551            }
552          }
553          Some(max) => {
554            quote! {
555              if __segments.len() < #min_len || __segments.len() > #max {
556                __matched = false;
557              }
558            }
559          }
560          None => {
561            quote! {
562              if __segments.len() < #min_len {
563                __matched = false;
564              }
565            }
566          }
567        };
568
569        out.push(PathMatcher {
570          path_expr: path_expr.clone(),
571          fully_qualified_check,
572          length_check,
573          segment_checks,
574          binding_names,
575        })
576      }
577    }
578  }
579
580  out
581}
582
583#[proc_macro]
584pub fn path_match(input: TokenStream) -> TokenStream {
585  let PathMatchArms { path, arms } = parse_macro_input!(input as PathMatchArms);
586
587  let wildcard_arm = arms.last().filter(|arm| {
588    arm
589      .patterns
590      .first()
591      .is_some_and(|pattern| matches!(pattern, PathPattern::Wildcard))
592  });
593
594  if wildcard_arm.is_none() {
595    return syn::Error::new(
596      Span::call_site(),
597      "path_match! requires a wildcard arm `_ => ...` as the last arm",
598    )
599    .to_compile_error()
600    .into();
601  }
602
603  let wildcard_body = &wildcard_arm.unwrap().body;
604  let non_wildcard_arms = &arms[..arms.len() - 1];
605
606  for arm in non_wildcard_arms {
607    if arm
608      .patterns
609      .iter()
610      .any(|pattern| matches!(pattern, PathPattern::Wildcard))
611    {
612      return syn::Error::new(
613        Span::call_site(),
614        "wildcard pattern `_` must be the last arm",
615      )
616      .to_compile_error()
617      .into();
618    }
619  }
620
621  let mut match_checks = Vec::new();
622
623  for arm in non_wildcard_arms {
624    let path_matchers = generate_path_matcher(&path, &arm.patterns);
625    for PathMatcher {
626      path_expr,
627      fully_qualified_check,
628      length_check,
629      segment_checks,
630      binding_names,
631    } in path_matchers
632    {
633      match_checks.push((
634        path_expr,
635        fully_qualified_check,
636        length_check,
637        segment_checks,
638        binding_names,
639        &arm.body,
640      ));
641    }
642  }
643
644  let arms_code = match_checks.into_iter().map(
645    |(path_expr, fq_check, len_check, seg_checks, binding_names, body)| {
646      let spanless_binding_names = binding_names
647        .iter()
648        .map(|(_, name)| name.clone())
649        .collect::<Vec<_>>();
650      let binding_extractions =
651        binding_names.into_iter().map(|(span, name)| {
652          let mut name_in_some = name.clone();
653          name_in_some.set_span(span);
654
655          quote! {
656            let #name_in_some = #name.unwrap();
657          }
658        });
659
660      quote! {
661        {
662          let __path = #path_expr;
663          let __segments = &__path.segments;
664          let mut __idx = 0;
665          let mut __matched = true;
666
667          #(let mut #spanless_binding_names: Option<_> = None;)*
668
669          #fq_check
670
671          if __matched {
672            #len_check
673          }
674
675          if __matched {
676            __matched = loop {
677              #(#seg_checks)*
678              break __matched;
679            }
680          }
681
682          if __matched && __idx != __segments.len() {
683            __matched = false;
684          }
685
686          if __matched {
687            #(#binding_extractions)*
688            return #body;
689          }
690        }
691      }
692    },
693  );
694
695  let expanded = quote! {
696    (|| {
697      #(#arms_code)*
698
699      #wildcard_body
700    })()
701  };
702
703  TokenStream::from(expanded)
704}