#![warn(rust_2018_idioms)]
#![forbid(unsafe_code)]
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use std::mem::replace;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
Expr, Ident, LitStr, Pat, PatIdent, Result, Token,
};
macro_rules! abort {
($span:expr, $($format:tt)+) => {{
return Err(syn::Error::new($span, format!($($format)+)).into());
}};
}
macro_rules! bail_if_err {
($expr:expr) => {
match $expr {
Ok(x) => x,
Err(e) => return e.into_compile_error().into(),
}
};
}
struct MacroInput {
in_value: Expr,
pat: Pat,
guard: Option<(Token![if], Box<Expr>)>,
wrap_output_fn: proc_macro2::TokenStream,
fallback_arm: proc_macro2::TokenStream,
}
impl Parse for MacroInput {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let in_value = input.parse()?;
input.parse::<Token![,]>()?;
let pat = Pat::parse_multi_with_leading_vert(input)?;
let guard = if input.peek(Token![if]) {
let if_token: Token![if] = input.parse()?;
let guard: Expr = input.parse()?;
Some((if_token, Box::new(guard)))
} else {
None
};
input.parse::<Token![,]>()?;
let wrap_output_fn = input.parse::<proc_macro2::Group>()?.stream();
let fallback_arm = input.parse::<proc_macro2::Group>()?.stream();
Ok(Self {
pat,
in_value,
guard,
wrap_output_fn,
fallback_arm,
})
}
}
#[proc_macro]
pub fn implicit_try_match_inner(input: TokenStream) -> TokenStream {
let MacroInput {
mut pat,
in_value,
guard,
wrap_output_fn,
fallback_arm,
} = parse_macro_input!(input);
let mut visitor = ForceVariableBindings {
introduced_subpat: false,
};
bail_if_err!(pat_visit_mut::Visit::visit_pat(&mut visitor, &mut pat));
let ForceVariableBindings { introduced_subpat } = visitor;
let mut visitor = CollectVariableBindings {
bindings: Vec::new(),
};
pat_visit::Visit::visit_pat(&mut visitor, &pat)
.expect("errors should have already been reported");
let CollectVariableBindings {
bindings: mut idents,
} = visitor;
idents.sort_unstable_by_key(|i| &i.ident);
idents.dedup_by_key(|i| &i.ident);
let success_output =
if let Some(tokens) = bail_if_err!(check_tuple_captures(&idents)) {
pat_visit_mut::Visit::visit_pat(&mut SuppressClippyIdentWarnings, &mut pat)
.expect("errors should have already been reported");
tokens
} else {
match &idents[..] {
[] => {
quote! {()}
}
[single] => {
let ident = &single.ident;
quote! {#ident}
}
multi => {
let idents: Vec<_> = multi.iter().map(|p| p.ident.clone()).collect();
let ty_params: Vec<_> = (0..idents.len())
.map(|i| Ident::new(&format!("_M_{}", i), Span::call_site()))
.collect();
let ident_strs: Vec<_> = idents
.iter()
.map(|i| LitStr::new(&format!("{}", i), i.span()))
.collect();
let ty_name = Ident::new("__Match", Span::call_site());
let debug_impl = {
let fmt = quote! { ::core::fmt };
quote! {
impl<#(#ty_params),*> #fmt::Debug for #ty_name<#(#ty_params),*>
where
#(#ty_params: #fmt::Debug),*
{
fn fmt(&self, f: &mut #fmt::Formatter<'_>) -> #fmt::Result {
f.debug_struct("<anonymous>")
#(.field(#ident_strs, &self.#idents))*
.finish()
}
}
}
};
quote! {{
#[derive(Clone, Copy)]
struct #ty_name<#(#ty_params),*> {
#(
#idents: #ty_params
),*
}
#debug_impl
#ty_name { #(#idents),* }
}}
}
}
};
let guard = guard.map(|(t0, t1)| quote! { #t0 #t1 });
let allow_redundant_pattern = if introduced_subpat {
quote! { #[allow(clippy::redundant_pattern)] }
} else {
quote! {}
};
let output = quote! {
match #in_value {
#allow_redundant_pattern
#pat #guard => #wrap_output_fn(#success_output),
#fallback_arm
}
};
TokenStream::from(output)
}
struct ForceVariableBindings {
introduced_subpat: bool,
}
impl pat_visit_mut::Visit for ForceVariableBindings {
type Error = syn::Error;
fn visit_pat_ident(&mut self, pat: &mut PatIdent) -> Result<()> {
if pat.subpat.is_some() {
Ok(())
} else {
match is_likely_variable_binding(&pat.ident) {
Some(true) => {
pat.subpat = Some((
syn::token::At {
spans: [Span::call_site()],
},
Box::new(Pat::Wild(syn::PatWild {
attrs: Vec::new(),
underscore_token: syn::token::Underscore {
spans: [Span::call_site()],
},
})),
));
self.introduced_subpat = true;
Ok(())
}
Some(false) => Ok(()),
None => abort!(pat.span(), "ambiguous identifier pattern"),
}
}
}
fn visit_field_pat(&mut self, field: &mut syn::FieldPat) -> Result<()> {
pat_visit_mut::visit_field_pat(self, field)?;
if field.colon_token.is_none() {
field.colon_token = Some(syn::token::Colon {
spans: [Span::call_site()],
});
}
Ok(())
}
}
struct CollectVariableBindings<'a> {
bindings: Vec<&'a PatIdent>,
}
impl<'a> pat_visit::Visit<'a> for CollectVariableBindings<'a> {
type Error = syn::Error;
fn visit_pat_ident(&mut self, pat: &'a PatIdent) -> Result<()> {
if pat.subpat.is_some() {
self.bindings.push(pat);
}
pat_visit::visit_pat_ident(self, pat)
}
}
struct SuppressClippyIdentWarnings;
impl pat_visit_mut::Visit for SuppressClippyIdentWarnings {
type Error = syn::Error;
fn visit_pat_ident(&mut self, pat: &mut PatIdent) -> Result<()> {
if let Some((_, subpat)) = &mut pat.subpat {
let old_subpat = replace(
&mut **subpat,
Pat::Verbatim(proc_macro2::TokenStream::default()),
);
**subpat = Pat::Paren(syn::PatParen {
attrs: Vec::new(),
paren_token: syn::token::Paren::default(),
pat: Box::new(old_subpat),
});
}
Ok(())
}
}
#[allow(clippy::manual_strip)] fn check_tuple_captures(idents: &[&PatIdent]) -> Result<Option<proc_macro2::TokenStream>> {
let mut some_tuple_cap = None;
let mut some_non_tuple_cap = None;
let mut indices: Vec<(u128, &Ident)> = idents
.iter()
.map(|i| {
let index = {
let text = i.ident.to_string();
if text.starts_with('_') {
text[1..].parse().ok()
} else {
None
}
};
if let Some(index) = index {
some_tuple_cap = Some(*i);
(index, &i.ident)
} else {
some_non_tuple_cap = Some(*i);
(0 , &i.ident)
}
})
.collect();
if let (Some(i1), Some(i2)) = (some_tuple_cap, some_non_tuple_cap) {
abort!(
i1.span(),
"can't have both of a tuple binding `{}` and a non-tuple binding \
`{}` at the same time",
i1.ident,
i2.ident
);
}
if some_tuple_cap.is_some() {
indices.sort_unstable_by_key(|e| e.0);
for (&(ind, ref ident), i) in indices.iter().zip(0u128..) {
use std::cmp::Ordering::*;
match ind.cmp(&i) {
Equal => {}
Less => {
assert_eq!(ind, i - 1);
abort!(
ident.span(),
"duplicate tuple binding: `_{}` is defined for multiple times",
ind
);
}
Greater if ind - 1 == i => {
abort!(
ident.span(),
"non-contiguous tuple binding: `_{}` is missing",
ind - 1
);
}
Greater => {
abort!(
ident.span(),
"non-contiguous tuple binding: `_{}` .. `_{}` are missing",
i,
ind - 1
);
}
}
}
let idents: Vec<_> = indices.into_iter().map(|p| p.1).collect();
Ok(Some(quote! { ( #(#idents),* ) }))
} else {
Ok(None)
}
}
macro_rules! define_visit_pat_trait {
(
#[mut($($mut:tt)*)]
#[iter($iter:ident)]
#[out_lifetime($($out_lifetime:tt)*)]
fn visit_pat(..);
..
) => {
use syn::{FieldPat, Pat, PatIdent, spanned::Spanned};
pub trait Visit<$($out_lifetime)*> {
type Error: From<syn::Error>;
fn visit_pat(
&mut self,
pat: & $($out_lifetime)* $($mut)* Pat,
) -> Result<(), Self::Error> {
visit_pat(self, pat)
}
fn visit_pat_ident(
&mut self,
pat: & $($out_lifetime)* $($mut)* PatIdent,
) -> Result<(), Self::Error> {
visit_pat_ident(self, pat)
}
fn visit_field_pat(
&mut self,
field: & $($out_lifetime)* $($mut)* FieldPat,
) -> Result<(), Self::Error> {
visit_field_pat(self, field)
}
}
pub fn visit_pat<'a, V: ?Sized + Visit< $($out_lifetime)* >>(
visit: &mut V,
pat: &'a $($mut)* Pat,
) -> Result<(), V::Error> {
match pat {
Pat::Const(_) => Ok(()),
Pat::Ident(pat) => visit.visit_pat_ident(pat),
Pat::Lit(_) => Ok(()),
Pat::Macro(_) => Ok(()),
Pat::Or(pat) => {
for case in pat.cases.$iter() {
visit.visit_pat(case)?;
}
Ok(())
}
Pat::Paren(pat) => visit.visit_pat(& $($mut)* pat.pat),
Pat::Path(_) => Ok(()),
Pat::Range(_) => Ok(()),
Pat::Reference(pat) => visit.visit_pat(& $($mut)* pat.pat),
Pat::Rest(_) => Ok(()),
Pat::Slice(pat) => {
for elem in pat.elems.$iter() {
visit.visit_pat(elem)?;
}
Ok(())
}
Pat::Struct(pat) => {
for field in pat.fields.$iter() {
visit.visit_field_pat(field)?;
}
Ok(())
}
Pat::Tuple(pat) => {
for elem in pat.elems.$iter() {
visit.visit_pat(elem)?;
}
Ok(())
}
Pat::TupleStruct(pat) => {
for elem in pat.elems.$iter() {
visit.visit_pat(elem)?;
}
Ok(())
}
Pat::Type(pat) => visit.visit_pat(& $($mut)* pat.pat),
Pat::Wild(_) => Ok(()),
_ => abort!(pat.span(), "unsupported pattern"),
}
}
pub fn visit_pat_ident<'a, V: ?Sized + Visit< $($out_lifetime)* >>(
visit: &mut V,
pat: &'a $($mut)* PatIdent,
) -> Result<(), V::Error> {
if let Some((_, subpat)) = & $($mut)* pat.subpat {
visit.visit_pat(subpat)?;
}
Ok(())
}
pub fn visit_field_pat<'a, V: ?Sized + Visit< $($out_lifetime)* >>(
visit: &mut V,
field: &'a $($mut)* FieldPat,
) -> Result<(), V::Error> {
visit.visit_pat(& $($mut)* field.pat)
}
}
}
mod pat_visit {
define_visit_pat_trait! {
#[mut()]
#[iter(iter)]
#[out_lifetime('a)]
fn visit_pat(..);
..
}
}
#[allow(clippy::needless_lifetimes)]
mod pat_visit_mut {
define_visit_pat_trait! {
#[mut(mut)]
#[iter(iter_mut)]
#[out_lifetime()]
fn visit_pat(..);
..
}
}
fn is_likely_variable_binding(ident: &Ident) -> Option<bool> {
struct Scanner {
raw_ident: RawIdent,
ident: Ident,
}
enum RawIdent {
ExpectingR,
ExpectingHash,
Done,
}
enum Ident {
ExpectingAlphanumOrUnderscore,
ExpectingAlphanum,
Done(Option<bool>),
}
use std::fmt::Write;
impl Write for Scanner {
fn write_str(&mut self, s: &str) -> std::fmt::Result {
for b in s.bytes() {
match (&self.raw_ident, b) {
(RawIdent::ExpectingR, b'r') => {
self.raw_ident = RawIdent::ExpectingHash;
}
(RawIdent::ExpectingHash, b'#') => {
self.raw_ident = RawIdent::Done;
self.ident = Ident::ExpectingAlphanumOrUnderscore;
continue;
}
_ => {
self.raw_ident = RawIdent::Done;
}
}
match (&self.ident, b) {
(Ident::Done(_), _) => {}
(Ident::ExpectingAlphanumOrUnderscore | Ident::ExpectingAlphanum, b)
if b.is_ascii_alphanumeric() =>
{
self.ident = Ident::Done(Some(!b.is_ascii_uppercase()));
}
(Ident::ExpectingAlphanumOrUnderscore, b'_') => {
self.ident = Ident::ExpectingAlphanum;
}
_ => {
self.ident = Ident::Done(None);
}
}
}
Ok(())
}
}
let mut scanner = Scanner {
raw_ident: RawIdent::ExpectingR,
ident: Ident::ExpectingAlphanumOrUnderscore,
};
write!(scanner, "{}", ident).unwrap();
match scanner.ident {
Ident::ExpectingAlphanumOrUnderscore => None, Ident::ExpectingAlphanum => None, Ident::Done(output) => output,
}
}