rooty_derive 0.1.4

see the `rooty` crate
Documentation
#![recursion_limit = "128"]
extern crate proc_macro;

mod url_parse;

use proc_macro2::Span;
use quote::quote;
use std::collections::HashSet;
use syn::spanned::Spanned;

/// Aborts the surrounding proc-macro function with a compile error.
///
/// Builds a `syn::Error` from `$span` and `$msg`, converts it into a
/// `compile_error!` token stream and returns it from the enclosing function.
///
/// The expansion deliberately has no trailing semicolon: it is a single
/// diverging `return` expression, so the macro can be used both as a statement
/// (`error!(..);`) and in expression position (e.g. as a `match` arm) without
/// relying on the compiler ignoring a stray `;` in the expansion.
macro_rules! error {
    ($span:expr, $msg:expr) => {
        return syn::Error::new($span, $msg).to_compile_error().into()
    };
}

/// Use this proc-macro to auto-derive rooty::Routes for your route enum.
///
/// All the fields of your enum variants must implement `Display` and `FromStr`.
///
/// **NOTE** when `Display`ed, no fields should print any forward slashes '/'. This cannot be
/// checked at compile-time. The route will panic in debug mode when you try to display it, in
/// release mode it will just print but you will not be able to parse it again.
///
/// # Example
///
/// ```
/// use chrono::NaiveDate;
/// # use rooty_derive::Routes;
///
/// #[derive(Debug, Routes)]
/// pub enum MyRoutes {
///     #[route = "/"]
///     Home,
///     #[route = "/about"]
///     About,
///     #[route = "/users/{id}"]
///     User { id: i32 },
///     #[route = "/posts/{date}"]
///     Posts { date: NaiveDate },
///     #[route = "/post/{title}"]
///     Post { title: String },
/// }
/// ```

#[proc_macro_derive(Routes, attributes(route))]
pub fn routes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = syn::parse_macro_input!(input as syn::DeriveInput);
    let syn::DeriveInput {
        ident,
        generics,
        data,
        ..
    } = input;
    let route_enum = match data {
        syn::Data::Enum(e) => e,
        syn::Data::Struct(syn::DataStruct { struct_token, .. }) => {
            error!(struct_token.span, "`Routes` macro only works on enums")
        }
        syn::Data::Union(syn::DataUnion { union_token, .. }) => {
            error!(union_token.span, "`Routes` macro only works on enums")
        }
    };
    // Check there is no generic parameters for the route enum (TODO think about whether some
    // generics should be supoorted)
    if generics.type_params().next().is_some()
        || generics.lifetimes().next().is_some()
        || generics.const_params().next().is_some()
    {
        error!(
            ident.span(),
            "cannot derive Routes for an enum with generic parameters"
        )
    }
    // use this vec to collect the information we need to impl the to/from string methods
    let mut variant_data: Vec<(syn::Variant, Vec<url_parse::Token>)> = Vec::new();
    for variant in route_enum.variants.iter() {
        let route = match variant
            .attrs
            .iter()
            .find(|attr| attr.path.is_ident("route"))
        {
            Some(route) => route,
            None => error!(
                variant.ident.span(),
                format!("missing a `#[route \"..\"]` attribute")
            ),
        };

        // Parse the path and get the placeholders that need to be filled in
        let url_str = match route.parse_meta().unwrap() {
            syn::Meta::NameValue(syn::MetaNameValue {
                lit: syn::Lit::Str(url),
                ..
            }) => url,
            _ => error!(route.span(), "not in the form `#[route = \"..\"]`"),
        };
        let url_str_lit = url_str.value();
        let url = match url_parse::parse_url(&url_str_lit).collect::<Result<Vec<_>, _>>() {
            Ok(url) => url,
            Err(e) => error!(url_str.span(), format!("cannot parse url: {}", e)),
        };
        // Check that all placeholders are followed by forwards slashes. Without this, parsing
        // isn't possible, as we wouldn't know the length to pass to the parser.
        for (idx, _) in url
            .iter()
            .enumerate()
            .filter(|(_, val)| val.is_placeholder())
        {
            let followed_by_slash = match url.get(idx + 1) {
                Some(tok) if tok.begins_with_forward_slash() => true,
                Some(_) => false,
                None => true,
            };
            if !followed_by_slash {
                error!(route.span(), "placeholders must be followed by a '/'");
            }
        }
        let mut url_placeholders = url
            .iter()
            .filter_map(|tok| tok.clone().to_placeholder().map(|v| v.to_string()))
            .collect::<HashSet<_>>();

        // check that fields in the variant match fields in the url
        match &variant.fields {
            syn::Fields::Named(fields) => {
                for field in fields.named.iter() {
                    // Cannot fail - we already know we have named fields here
                    let field_name = field.ident.clone().unwrap();
                    if !url_placeholders.remove(&field_name.to_string()) {
                        error!(field.ident.span(), "field missing in url");
                    }
                }
                if !url_placeholders.is_empty() {
                    error!(
                        url_str.span(),
                        format!("url parameters {:?} missing from variant", url_placeholders)
                    )
                }
            }
            syn::Fields::Unnamed(_) => error!(
                variant.fields.span(),
                "unnamed fields are not yet supported"
            ),
            syn::Fields::Unit => {
                if !url_placeholders.is_empty() {
                    error!(
                        variant.ident.span(),
                        "there are fields in the url but not in the enum variant"
                    )
                }
            }
        }
        variant_data.push((variant.clone(), url));
    }
    // build string formatting tokens
    let format_route = variant_data
        .iter()
        .map(|(variant, url)| {
            let variant_ident = &variant.ident;
            let variant_fields = url
                .iter()
                .filter_map(|tok| tok.clone().to_placeholder())
                .collect::<Vec<_>>();
            let format_parts = url
                .iter()
                .map(|tok| match tok {
                    url_parse::Token::Literal(lit) => quote! {
                        f.write_str(#lit)?;
                    },
                    url_parse::Token::Placeholder(ident) => quote! {
                        #[cfg(debug_assertions)]
                        {
                            let mut test_val = format!("{}", #ident);
                            if test_val.contains('/') {
                                panic!("A parameter of a route contained a '/' \
                                    (this makes parsing impossible and can be a security risk: see\
                                    the documentation for `rooty_derive::Routes`)");
                            }
                        }
                        std::fmt::Display::fmt(#ident, f)?;
                    },
                })
                .collect::<Vec<_>>();
            quote! {
                #ident :: #variant_ident { #(ref #variant_fields),* } => {
                    #(#format_parts)*
                    Ok(())
                }
            }
        })
        .collect::<Vec<_>>();
    // build string formatting tokens
    let parse_route = variant_data
        .iter()
        .map(|(variant, url)| {
            let variant_ident = &variant.ident;
            let variant_fields = url
                .iter()
                .filter_map(|tok| tok.clone().to_placeholder())
                .collect::<Vec<_>>();
            let format_parts = url
                .iter()
                .map(|tok| match tok {
                    url_parse::Token::Literal(lit) => quote! {
                        input = rooty::consume_literal(input, #lit)?;
                    },
                    url_parse::Token::Placeholder(ident) => quote! {
                        let (next_input, #ident) = rooty::consume_placeholder(input)?;
                        input = next_input;
                    },
                })
                .collect::<Vec<_>>();
            // We implement parsing by creating a parsing function for each variant and then trying
            // them in order.
            let variant_parser_name =
                syn::Ident::new(&format!("__parser_{}", variant_ident), variant_ident.span());
            quote! {
                #[allow(non_snake_case)]
                fn #variant_parser_name(mut input: &str) -> Result<#ident, rooty::NotFound> {
                    #(#format_parts)*
                    if input.is_empty() {
                        Ok(#ident :: #variant_ident { #(#variant_fields),* })
                    } else {
                        Err(rooty::NotFound)
                    }
                }
                if let Ok(t) = #variant_parser_name(input) {
                    return Ok(t);
                }
            }
        })
        .collect::<Vec<_>>();

    let display_name = syn::Ident::new(&format!("_Display_{}", ident), Span::call_site());

    let expanded = quote! {
        #[repr(transparent)]
        pub struct #display_name #generics (#ident #generics);

        impl #generics std::fmt::Display for #display_name #generics {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                match self.0 {
                    #(#format_route),*
                }
            }
        }
        impl #generics rooty::Routes for #ident #generics {
            type UrlDisplay = #display_name;
            fn url(&self) -> &Self::UrlDisplay {
                // safety: transmute between two identical structs (`repr(transparent)`)
                unsafe { std::mem::transmute(self) }
            }
            fn parse_url(input: &str) -> Result<Self, rooty::NotFound> {
                #(#parse_route)*
                Err(rooty::NotFound)
            }
        }
    };

    proc_macro::TokenStream::from(expanded)
}