#![recursion_limit = "128"]
extern crate proc_macro;
mod url_parse;
use proc_macro2::Span;
use quote::quote;
use std::collections::HashSet;
use syn::spanned::Spanned;
/// Returns early from the enclosing proc-macro function with a compile error.
///
/// Builds a `syn::Error` carrying `$msg` at `$span` and returns it converted
/// (via `.into()`) into the enclosing function's return type, normally a
/// `proc_macro::TokenStream` containing a `compile_error!` invocation.
///
/// The expansion is a bare `return` expression with no trailing semicolon, so
/// the macro is valid both as a statement and in expression position (for
/// example as a `match` arm). A trailing comma after `$msg` is accepted.
macro_rules! error {
    ($span:expr, $msg:expr $(,)?) => {
        return syn::Error::new($span, $msg).to_compile_error().into()
    };
}
#[proc_macro_derive(Routes, attributes(route))]
pub fn routes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(input as syn::DeriveInput);
let syn::DeriveInput {
ident,
generics,
data,
..
} = input;
let route_enum = match data {
syn::Data::Enum(e) => e,
syn::Data::Struct(syn::DataStruct { struct_token, .. }) => {
error!(struct_token.span, "`Routes` macro only works on enums")
}
syn::Data::Union(syn::DataUnion { union_token, .. }) => {
error!(union_token.span, "`Routes` macro only works on enums")
}
};
if generics.type_params().next().is_some()
|| generics.lifetimes().next().is_some()
|| generics.const_params().next().is_some()
{
error!(
ident.span(),
"cannot derive Routes for an enum with generic parameters"
)
}
let mut variant_data: Vec<(syn::Variant, Vec<url_parse::Token>)> = Vec::new();
for variant in route_enum.variants.iter() {
let route = match variant
.attrs
.iter()
.find(|attr| attr.path.is_ident("route"))
{
Some(route) => route,
None => error!(
variant.ident.span(),
format!("missing a `#[route \"..\"]` attribute")
),
};
let url_str = match route.parse_meta().unwrap() {
syn::Meta::NameValue(syn::MetaNameValue {
lit: syn::Lit::Str(url),
..
}) => url,
_ => error!(route.span(), "not in the form `#[route = \"..\"]`"),
};
let url_str_lit = url_str.value();
let url = match url_parse::parse_url(&url_str_lit).collect::<Result<Vec<_>, _>>() {
Ok(url) => url,
Err(e) => error!(url_str.span(), format!("cannot parse url: {}", e)),
};
for (idx, _) in url
.iter()
.enumerate()
.filter(|(_, val)| val.is_placeholder())
{
let followed_by_slash = match url.get(idx + 1) {
Some(tok) if tok.begins_with_forward_slash() => true,
Some(_) => false,
None => true,
};
if !followed_by_slash {
error!(route.span(), "placeholders must be followed by a '/'");
}
}
let mut url_placeholders = url
.iter()
.filter_map(|tok| tok.clone().to_placeholder().map(|v| v.to_string()))
.collect::<HashSet<_>>();
match &variant.fields {
syn::Fields::Named(fields) => {
for field in fields.named.iter() {
let field_name = field.ident.clone().unwrap();
if !url_placeholders.remove(&field_name.to_string()) {
error!(field.ident.span(), "field missing in url");
}
}
if !url_placeholders.is_empty() {
error!(
url_str.span(),
format!("url parameters {:?} missing from variant", url_placeholders)
)
}
}
syn::Fields::Unnamed(_) => error!(
variant.fields.span(),
"unnamed fields are not yet supported"
),
syn::Fields::Unit => {
if !url_placeholders.is_empty() {
error!(
variant.ident.span(),
"there are fields in the url but not in the enum variant"
)
}
}
}
variant_data.push((variant.clone(), url));
}
let format_route = variant_data
.iter()
.map(|(variant, url)| {
let variant_ident = &variant.ident;
let variant_fields = url
.iter()
.filter_map(|tok| tok.clone().to_placeholder())
.collect::<Vec<_>>();
let format_parts = url
.iter()
.map(|tok| match tok {
url_parse::Token::Literal(lit) => quote! {
f.write_str(#lit)?;
},
url_parse::Token::Placeholder(ident) => quote! {
#[cfg(debug_assertions)]
{
let mut test_val = format!("{}", #ident);
if test_val.contains('/') {
panic!("A parameter of a route contained a '/' \
(this makes parsing impossible and can be a security risk: see\
the documentation for `rooty_derive::Routes`)");
}
}
std::fmt::Display::fmt(#ident, f)?;
},
})
.collect::<Vec<_>>();
quote! {
#ident :: #variant_ident { #(ref #variant_fields),* } => {
#(#format_parts)*
Ok(())
}
}
})
.collect::<Vec<_>>();
let parse_route = variant_data
.iter()
.map(|(variant, url)| {
let variant_ident = &variant.ident;
let variant_fields = url
.iter()
.filter_map(|tok| tok.clone().to_placeholder())
.collect::<Vec<_>>();
let format_parts = url
.iter()
.map(|tok| match tok {
url_parse::Token::Literal(lit) => quote! {
input = rooty::consume_literal(input, #lit)?;
},
url_parse::Token::Placeholder(ident) => quote! {
let (next_input, #ident) = rooty::consume_placeholder(input)?;
input = next_input;
},
})
.collect::<Vec<_>>();
let variant_parser_name =
syn::Ident::new(&format!("__parser_{}", variant_ident), variant_ident.span());
quote! {
#[allow(non_snake_case)]
fn #variant_parser_name(mut input: &str) -> Result<#ident, rooty::NotFound> {
#(#format_parts)*
if input.is_empty() {
Ok(#ident :: #variant_ident { #(#variant_fields),* })
} else {
Err(rooty::NotFound)
}
}
if let Ok(t) = #variant_parser_name(input) {
return Ok(t);
}
}
})
.collect::<Vec<_>>();
let display_name = syn::Ident::new(&format!("_Display_{}", ident), Span::call_site());
let expanded = quote! {
#[repr(transparent)]
pub struct #display_name #generics (#ident #generics);
impl #generics std::fmt::Display for #display_name #generics {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self.0 {
#(#format_route),*
}
}
}
impl #generics rooty::Routes for #ident #generics {
type UrlDisplay = #display_name;
fn url(&self) -> &Self::UrlDisplay {
unsafe { std::mem::transmute(self) }
}
fn parse_url(input: &str) -> Result<Self, rooty::NotFound> {
#(#parse_route)*
Err(rooty::NotFound)
}
}
};
proc_macro::TokenStream::from(expanded)
}