#![doc = include_str!("../README.md")]
#![doc = include_str!("../examples/controller.rs")]
#![forbid(unsafe_code)]
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::parse_quote;
use syn::Token;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
ItemImpl, MetaNameValue,
};
/// Arguments accepted inside `#[controller(...)]`.
#[derive(Clone, Default)]
struct ControllerAttrs {
    /// `path = ...`: prefix the generated router is nested under.
    /// `None` is treated as `"/"` by `controller`.
    path: Option<syn::Expr>,
    /// `state = ...`: the router state, spliced into type position
    /// (`axum::Router<#state>`). `None` is treated as `()` by `controller`.
    state: Option<syn::Expr>,
    /// Every `middleware = ...` argument, kept in the order it was written.
    middlewares: Vec<syn::Expr>,
}
impl Parse for ControllerAttrs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut path: Option<syn::Expr> = None;
let mut state: Option<syn::Expr> = None;
let mut middlewares: Vec<syn::Expr> = Vec::new();
for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
let segs = nv.path.segments.clone().into_pairs();
let seg = segs.into_iter().next().unwrap().into_value();
let ident = seg.ident;
match ident.to_string().as_str() {
"path" => {
if path.is_some() {
return Err(syn::Error::new_spanned(
&nv.path,
"duplicate `path` attribute",
));
}
path = Some(nv.value);
}
"state" => {
if state.is_some() {
return Err(syn::Error::new_spanned(
&nv.path,
"duplicate `state` attribute",
));
}
state = Some(nv.value);
}
"middleware" => middlewares.push(nv.value),
_ => {
return Err(syn::Error::new_spanned(
&nv.path,
format_args!(
"unknown attribute `{}`; expected `path`, `state`, or `middleware`",
ident
),
));
}
}
}
Ok(Self {
middlewares,
path,
state,
})
}
}
/// The parts of the annotated `impl` block that `controller` needs.
#[derive(Clone)]
struct ControllerImpl {
    /// Names of the methods that carry a `#[route]` attribute, in declaration order.
    route_fns: Vec<syn::Ident>,
    /// The impl's self type; the generated router functions are added to it.
    struct_name: syn::Type,
}
impl Parse for ControllerImpl {
    /// Parses an `impl` block, recording its self type and the name of every
    /// method that carries a `#[route]` attribute, in declaration order.
    ///
    /// Only the bare `#[route]` form is recognised (`is_ident("route")`); a
    /// path-qualified attribute such as `#[some_crate::route]` is not.
    ///
    /// # Errors
    /// Fails if the input is not an `impl` item.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let ast: ItemImpl = input.parse()?;
        let route_fns: Vec<Ident> = ast
            .items
            .iter()
            .filter_map(|item| match item {
                syn::ImplItem::Fn(method) => Some(method),
                _ => None,
            })
            // `any` lists a method once even if it has several `#[route]`
            // attributes, so it is not passed to `.typed_route` twice.
            .filter(|method| {
                method
                    .attrs
                    .iter()
                    .any(|attr| attr.path().is_ident("route"))
            })
            .map(|method| method.sig.ident.clone())
            .collect();
        Ok(Self {
            // Move the self type out of its box; the rest of `ast` is no longer needed.
            struct_name: *ast.self_ty,
            route_fns,
        })
    }
}
#[proc_macro_attribute]
pub fn controller(attrs: TokenStream, c_impl: TokenStream) -> TokenStream {
let parsed_attrs = match syn::parse::<ControllerAttrs>(attrs) {
Ok(args) => args,
Err(err) => return err.to_compile_error().into(),
};
let parsed_impl = match syn::parse::<ControllerImpl>(c_impl.clone()) {
Ok(myimpl) => myimpl,
Err(err) => return err.to_compile_error().into(),
};
let state = parsed_attrs.state.unwrap_or_else(|| parse_quote!(()));
let route_fns = parsed_impl.route_fns;
let struct_name = &parsed_impl.struct_name;
let route = parsed_attrs.path.unwrap_or_else(|| syn::parse_quote!("/"));
let no_routes_warning = if route_fns.is_empty() {
quote! {
#[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
}
} else {
quote! {}
};
let route_calls = route_fns
.into_iter()
.map(move |route| {
quote! {
.typed_route(#struct_name :: #route)
}
})
.collect::<Vec<_>>();
let nesting_call = quote! {
.nest(#route, nested_router)
};
let nested_router_quote = quote! {
axum::Router::new()
#nesting_call
};
let unnested_router_quote = quote! {
nested_router
};
let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
if lit.eq(&syn::parse_quote!("/")) {
unnested_router_quote
} else {
nested_router_quote
}
} else {
nested_router_quote
};
let middleware_calls = parsed_attrs
.middlewares
.clone()
.into_iter()
.map(|middleware| quote! {.layer(#middleware)})
.collect::<Vec<_>>();
let from_controller_into_router_impl = quote! {
impl #struct_name {
pub fn into_stateless_router(state: #state) -> axum::Router<()> {
Self::into_router()
.with_state(state)
}
pub fn into_router() -> axum::Router<#state> {
let nested_router = axum::Router::new()
#(#route_calls)*
#(#middleware_calls)*
;
#maybe_nesting_call
}
}
};
let c_impl: proc_macro2::TokenStream = c_impl.clone().into();
let res: TokenStream = quote! {
#c_impl
#from_controller_into_router_impl
#no_routes_warning
}
.into();
res
}