use proc_macro::TokenStream as TokenStream1;
use proc_macro2::{Delimiter, Group, Ident, Punct, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, quote_spanned};
use std::{
collections::hash_map::RandomState,
convert::identity,
hash::{BuildHasher, Hasher},
};
#[proc_macro_attribute]
pub fn macro_vis(attr: TokenStream1, item: TokenStream1) -> TokenStream1 {
macro_vis_inner(attr.into(), item.into())
.unwrap_or_else(identity)
.into()
}
fn macro_vis_inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream, TokenStream> {
let vis = parse_vis(attr)?;
let Macro {
attrs,
macro_rules,
bang,
name,
arms,
rules,
semi,
} = parse_macro(item)?;
let real_name = format_ident!("__{}_{}", name, RandomState::new().build_hasher().finish());
Ok(match vis {
Vis::Local { pub_token, scope } => {
quote! {
#attrs
#macro_rules #bang #real_name #arms #semi
#pub_token #scope use #real_name as #name;
}
}
Vis::Public { pub_token } => {
let macro_token = Ident::new("macro", macro_rules.span());
let mut arms_2_0 = Group::new(Delimiter::Brace, macro_2_0_arms(&rules));
arms_2_0.set_span(arms.span());
let display_name = format_ident!("{}ǃ", name);
quote! {
#[cfg(not(doc_nightly))]
#[doc(hidden)]
#[macro_export]
#macro_rules #bang #real_name #arms #semi
#[cfg(not(doc_nightly))]
#[doc(hidden)]
#pub_token use #real_name as #name;
#[cfg(all(doc, not(doc_nightly)))]
#[doc = "<sup>**\\[macro\\]**</sup>"]
#attrs
#pub_token fn #display_name() {}
#[cfg(doc_nightly)]
#[rustc_macro_transparency = "semitransparent"]
#attrs
#pub_token #macro_token #name #arms_2_0
}
}
})
}
/// Visibility requested by the `#[macro_vis(...)]` argument, as produced by `parse_vis`.
#[derive(Debug)]
enum Vis {
    /// Bare `pub`, holding the `pub` keyword token.
    Public { pub_token: Ident },
    /// Either no argument at all (`pub_token` and `scope` are both `None`), or `pub(...)`
    /// (both are `Some`).
    Local {
        /// The `pub` keyword, when one was written.
        pub_token: Option<Ident>,
        /// The parenthesized group following `pub`, e.g. `(crate)`; its contents are not
        /// validated.
        scope: Option<Group>,
    },
}
/// A `macro_rules!` item split into its tokens, as produced by `parse_macro`.
#[derive(Debug)]
struct Macro {
    /// Every `#` + `[...]` attribute pair that preceded `macro_rules`, in source order.
    attrs: TokenStream,
    /// The `macro_rules` keyword token.
    macro_rules: Ident,
    /// The `!` following `macro_rules`.
    bang: Punct,
    /// The macro's name.
    name: Ident,
    /// The whole rules group, with its original delimiter.
    arms: Group,
    /// The rules inside `arms`, parsed one by one.
    rules: Vec<MacroRule>,
    /// The `;` after a `()`- or `[]`-delimited `arms`; `None` when `arms` is brace-delimited.
    semi: Option<Punct>,
}
/// One `matcher => transcriber` rule of a `macro_rules!` macro.
#[derive(Debug)]
struct MacroRule {
    /// The group to the left of `=>`.
    matcher: Group,
    /// First character of `=>`.
    equals: Punct,
    /// Second character of `=>`.
    greater_than: Punct,
    /// The group to the right of `=>`.
    transcriber: Group,
    /// The `;` after this rule, or `None` for a final rule with no trailing `;`.
    semi: Option<Punct>,
}
fn parse_vis(vis: TokenStream) -> Result<Vis, TokenStream> {
let mut vis = vis.into_iter();
let pub_token = match vis.next() {
Some(TokenTree::Ident(pub_token)) if pub_token == "pub" => pub_token,
Some(token) => {
return Err(error(token.span(), "expected visibility"));
}
None => {
return Ok(Vis::Local {
pub_token: None,
scope: None,
})
}
};
let scope = match vis.next() {
Some(TokenTree::Group(scope)) if scope.delimiter() == Delimiter::Parenthesis => scope,
Some(token) => {
return Err(error(token.span(), "expected parenthesis"));
}
None => return Ok(Vis::Public { pub_token }),
};
if let Some(trailing) = vis.next() {
return Err(error(trailing.span(), "trailing tokens"));
}
Ok(Vis::Local {
pub_token: Some(pub_token),
scope: Some(scope),
})
}
fn parse_macro(item: TokenStream) -> Result<Macro, TokenStream> {
let mut item = item.into_iter();
let mut attrs = TokenStream::new();
let macro_rules = loop {
match item.next() {
Some(TokenTree::Punct(punct)) if punct.as_char() == '#' => {
let next = item.next().expect("unexpected EOF in attribute");
if !matches!(&next, TokenTree::Group(group) if group.delimiter() == Delimiter::Bracket)
{
unreachable!("attribute without square brackets");
}
attrs.extend([TokenTree::Punct(punct), next]);
}
Some(TokenTree::Ident(macro_rules)) if macro_rules == "macro_rules" => {
break macro_rules;
}
token => {
return Err(error(opt_span(&token), "expected macro_rules! macro"));
}
}
};
let bang = match item.next() {
Some(TokenTree::Punct(p)) if p.as_char() == '!' => p,
token => {
return Err(error(opt_span(&token), "expected exclamation mark"));
}
};
let name = match item.next() {
Some(TokenTree::Ident(ident)) => ident,
token => {
return Err(error(opt_span(&token), "expected identifier"));
}
};
let arms = match item.next() {
Some(TokenTree::Group(group)) => group,
token => {
return Err(error(opt_span(&token), "expected macro arms"));
}
};
let mut rule_tokens = arms.stream().into_iter();
let mut rules = Vec::new();
loop {
let matcher = match rule_tokens.next() {
Some(TokenTree::Group(group)) => group,
Some(token) => {
return Err(error(token.span(), "expected macro matcher"));
}
None if rules.is_empty() => {
return Err(error(arms.span(), "expected macro rules"));
}
None => break,
};
let equals = match rule_tokens.next() {
Some(TokenTree::Punct(equals)) if equals.as_char() == '=' => equals,
token => return Err(error(opt_span(&token), "expected =>")),
};
let greater_than = match rule_tokens.next() {
Some(TokenTree::Punct(greater_than)) if greater_than.as_char() == '>' => greater_than,
_ => return Err(error(equals.span(), "expected =>")),
};
let transcriber = match rule_tokens.next() {
Some(TokenTree::Group(group)) => group,
token => return Err(error(opt_span(&token), "expected macro transcriber")),
};
let mut rule = MacroRule {
matcher,
equals,
greater_than,
transcriber,
semi: None,
};
match rule_tokens.next() {
Some(TokenTree::Punct(semi)) if semi.as_char() == ';' => {
rule.semi = Some(semi);
rules.push(rule);
}
None => {
rules.push(rule);
break;
}
Some(token) => {
return Err(error(token.span(), "expected semicolon"));
}
}
}
let semi = if arms.delimiter() != Delimiter::Brace {
Some(match item.next() {
Some(TokenTree::Punct(semi)) if semi.as_char() == ';' => semi,
_ => unreachable!("no semicolon after () or []-delimited macro"),
})
} else {
None
};
if item.next().is_some() {
unreachable!("trailing tokens after macro_rules! macro");
}
Ok(Macro {
attrs,
macro_rules,
bang,
name,
arms,
rules,
semi,
})
}
/// Span to report an error at: the token's own span, or `Span::call_site()` when the input
/// ran out (`None`).
fn opt_span(token: &Option<TokenTree>) -> Span {
    match token {
        Some(tree) => tree.span(),
        None => Span::call_site(),
    }
}
fn macro_2_0_arms(rules: &[MacroRule]) -> TokenStream {
rules
.iter()
.map(
|MacroRule {
matcher,
equals,
greater_than,
transcriber,
semi,
}| {
let comma = semi.as_ref().map(|semi| {
let mut comma = Punct::new(',', semi.spacing());
comma.set_span(semi.span());
comma
});
quote!(#matcher #equals #greater_than #transcriber #comma)
},
)
.collect()
}
/// Builds a `::core::compile_error!(msg)` invocation whose tokens carry `span`, so the
/// diagnostic is reported at that location.
fn error(span: Span, msg: &str) -> TokenStream {
    quote_spanned! {span=>
        ::core::compile_error!(#msg)
    }
}
// Unit tests for `parse_vis` and `parse_macro`. Token streams are compared through their
// `to_string()` rendering, which ignores spans, so only the token text is checked.
#[cfg(test)]
mod tests {
use crate::{parse_macro, parse_vis, Macro, Vis};
use proc_macro2::TokenStream;
use quote::quote;
// Each accepted argument form maps to the expected `Vis` variant:
// - empty input is a `Local` with no `pub` and no scope;
// - bare `pub` is `Public`;
// - `pub(...)` is a scoped `Local`, whatever tokens the parentheses contain.
#[test]
fn vis_parse() {
assert!(matches!(
parse_vis(TokenStream::new()),
Ok(Vis::Local {
pub_token: None,
scope: None
})
));
assert!(matches!(
parse_vis(quote!(pub)),
Ok(Vis::Public { pub_token }) if pub_token == "pub"
));
assert!(matches!(
parse_vis(quote!(pub(crate))),
Ok(Vis::Local { pub_token: Some(pub_token), scope: Some(scope) })
if pub_token == "pub" && scope.to_string() == quote!((crate)).to_string()
));
// The scope group's contents are not validated by `parse_vis`.
assert!(matches!(
parse_vis(quote!(pub(foo bar))),
Ok(Vis::Local { pub_token: Some(pub_token), scope: Some(scope) })
if pub_token == "pub" && scope.to_string() == quote!((foo bar)).to_string()
));
}
// Rejected visibility arguments. `assert_err!((input) -> "msg")` asserts that `parse_vis`
// returns exactly `::core::compile_error!("msg")` for `input`.
#[test]
fn vis_error() {
macro_rules! assert_err {
(($($input:tt)*) -> $e:literal) => {
assert_eq!(
parse_vis(quote!($($input)*)).unwrap_err().to_string(),
quote!(::core::compile_error!($e)).to_string(),
);
};
}
assert_err!((priv) -> "expected visibility");
assert_err!((pub[crate]) -> "expected parenthesis");
assert_err!((pub() trailing) -> "trailing tokens");
}
// Successful parses:
// - a brace-delimited macro with one rule and no trailing `;` (so `semi` is `None`);
// - a `[]`-delimited macro with attributes, `;`-terminated rules and the required `;` after
//   the group.
#[test]
fn macro_parse() {
assert!(matches!(
parse_macro(quote!(macro_rules! foo { (m) => { t } })),
Ok(Macro { attrs, macro_rules, bang, name, arms, rules, semi: None })
if attrs.is_empty()
&& macro_rules == "macro_rules"
&& bang.as_char() == '!'
&& name == "foo"
&& arms.to_string() == quote!({ (m) => { t } }).to_string()
&& rules.len() == 1
&& rules[0].matcher.to_string() == quote!((m)).to_string()
&& rules[0].equals.as_char() == '='
&& rules[0].greater_than.as_char() == '>'
&& rules[0].transcriber.to_string() == quote!({ t }).to_string()
&& rules[0].semi.is_none()
));
assert!(matches!(
parse_macro(quote! {
#[attr1]
#[attr2 = "foo"]
macro_rules! foo [
{} => ();
[$] => [[]];
];
}),
Ok(Macro { attrs, arms, rules, semi: Some(semi), .. })
if attrs.to_string() == quote!(#[attr1] #[attr2 = "foo"]).to_string()
&& arms.to_string() == quote!([{} => (); [$] => [[]];]).to_string()
&& semi.as_char() == ';'
&& rules.len() == 2
&& rules[0].matcher.to_string() == quote!({}).to_string()
&& rules[0].transcriber.to_string() == quote!(()).to_string()
&& rules[0].semi.as_ref().map_or(false, |semi| semi.as_char() == ';')
&& rules[1].matcher.to_string() == quote!([$]).to_string()
&& rules[1].transcriber.to_string() == quote!([[]]).to_string()
&& rules[1].semi.as_ref().map_or(false, |semi| semi.as_char() == ';')
));
}
// Malformed `macro_rules!` items and the exact `compile_error!` message each must produce.
#[test]
fn macro_error() {
macro_rules! assert_err {
(($($input:tt)*) -> $e:literal) => {
assert_eq!(
parse_macro(quote!($($input)*)).unwrap_err().to_string(),
quote!(::core::compile_error!($e)).to_string(),
);
}
}
assert_err!(() -> "expected macro_rules! macro");
assert_err!((const _: () = {};) -> "expected macro_rules! macro");
assert_err!((macro_rules x {}) -> "expected exclamation mark");
assert_err!((macro_rules! { () => {} }) -> "expected identifier");
assert_err!((macro_rules! foo) -> "expected macro arms");
assert_err!((macro_rules! foo { }) -> "expected macro rules");
assert_err!((macro_rules! foo { # }) -> "expected macro matcher");
// Both a missing `=` and a missing `>` report "expected =>".
assert_err!((macro_rules! foo { () }) -> "expected =>");
assert_err!((macro_rules! foo { () = }) -> "expected =>");
assert_err!((macro_rules! foo { () => }) -> "expected macro transcriber");
assert_err!((macro_rules! foo { () => {} () => {} }) -> "expected semicolon");
}
}