Skip to main content

axum_controller/
lib.rs

1#![doc = include_str!("../README.md")]
2//!
3//! ## Route macro usage
4//! See the docs of [`axum_typed_routing`] for details on the route macro.
5//!
6//! ## Controller macro usage
7//!
//! This crate offers the `#[controller]` attribute macro.
//! Use it like this:
10//!
11//! ```
12#![doc = include_str!("../examples/controller.rs")]
13//! ```
14
15#![forbid(unsafe_code)]
16
17use proc_macro::TokenStream;
18use proc_macro2::Ident;
19use quote::quote;
20use syn::parse_quote;
21use syn::Token;
22
23use syn::{
24    parse::{Parse, ParseStream},
25    punctuated::Punctuated,
26    ItemImpl, MetaNameValue,
27};
28
29#[derive(Clone, Default)]
30struct ControllerAttrs {
31    middlewares: Vec<syn::Expr>,
32    path: Option<syn::Expr>,
33    state: Option<syn::Expr>,
34}
35
36impl Parse for ControllerAttrs {
37    fn parse(input: ParseStream) -> syn::Result<Self> {
38        let mut path: Option<syn::Expr> = None;
39        let mut state: Option<syn::Expr> = None;
40        let mut middlewares: Vec<syn::Expr> = Vec::new();
41
42        for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
43            let segs = nv.path.segments.clone().into_pairs();
44            let seg = segs.into_iter().next().unwrap().into_value();
45            let ident = seg.ident;
46            match ident.to_string().as_str() {
47                "path" => {
48                    if path.is_some() {
49                        return Err(syn::Error::new_spanned(
50                            &nv.path,
51                            "duplicate `path` attribute",
52                        ));
53                    }
54                    path = Some(nv.value);
55                }
56                "state" => {
57                    if state.is_some() {
58                        return Err(syn::Error::new_spanned(
59                            &nv.path,
60                            "duplicate `state` attribute",
61                        ));
62                    }
63                    state = Some(nv.value);
64                }
65                "middleware" => middlewares.push(nv.value),
66                _ => {
67                    return Err(syn::Error::new_spanned(
68                        &nv.path,
69                        format_args!(
70                            "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
71                            ident
72                        ),
73                    ));
74                }
75            }
76        }
77
78        Ok(Self {
79            middlewares,
80            path,
81            state,
82        })
83    }
84}
85
86#[derive(Clone)]
87struct ControllerImpl {
88    struct_name: syn::Type,
89    route_fns: Vec<syn::Ident>,
90}
91
92impl Parse for ControllerImpl {
93    fn parse(input: ParseStream) -> syn::Result<Self> {
94        let ast: ItemImpl = input.parse()?;
95        let struct_name = *(ast.clone().self_ty.clone());
96        let mut route_fns: Vec<syn::Ident> = vec![];
97
98        for item in &ast.items {
99            if let syn::ImplItem::Fn(impl_item_fn) = item {
100                for attr in impl_item_fn.attrs.clone() {
101                    if attr.path().is_ident("route") {
102                        let fn_name: Ident = impl_item_fn.sig.ident.clone();
103                        route_fns.push(fn_name);
104                    }
105                }
106            }
107        }
108
109        Ok(Self {
110            struct_name,
111            route_fns,
112        })
113    }
114}
115
116// TODO add better docs
117/// A macro that generates a `into_router`(\_: State<_>) impl which automatically wires up all `route`'s and the given middlewares, path-prefix etc
118///
119/// ## Syntax:
120/// ```
121/// use axum_controller::controller;
122///
123/// struct ExampleController;
124/// #[controller(
125///   path = "/asd",
126/// )]
127/// impl ExampleController { /* ... */ }
128/// ```
129/// - path
130///   - optional, 0-1 allowed, defaults to `"/"`
131///   - A path to prefix `.nest` the `routes` in the controller Struct under
132/// - state
133///   - optional, 0-1 allowed, defaults to `"()"`)
134///   - The type signature of the state given to the routes
135/// - middleware
136///   - optional, 0-n allowed, default to [] (no middlewares)
137///   - Middlewares to `.layer` in the created router
138///
139#[proc_macro_attribute]
140pub fn controller(attrs: TokenStream, c_impl: TokenStream) -> TokenStream {
141    let parsed_attrs = match syn::parse::<ControllerAttrs>(attrs) {
142        Ok(args) => args,
143        Err(err) => return err.to_compile_error().into(),
144    };
145    let parsed_impl = match syn::parse::<ControllerImpl>(c_impl.clone()) {
146        Ok(myimpl) => myimpl,
147        Err(err) => return err.to_compile_error().into(),
148    };
149
150    let state = parsed_attrs.state.unwrap_or_else(|| parse_quote!(()));
151    let route_fns = parsed_impl.route_fns;
152    let struct_name = &parsed_impl.struct_name;
153    let route = parsed_attrs.path.unwrap_or_else(|| syn::parse_quote!("/"));
154
155    let no_routes_warning = if route_fns.is_empty() {
156        quote! {
157            #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
158            const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
159            const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
160        }
161    } else {
162        quote! {}
163    };
164
165    let route_calls = route_fns
166        .into_iter()
167        .map(move |route| {
168            quote! {
169            .typed_route(#struct_name :: #route)
170                  }
171        })
172        .collect::<Vec<_>>();
173
174    let nesting_call = quote! {
175        .nest(#route, nested_router)
176    };
177
178    let nested_router_quote = quote! {
179        axum::Router::new()
180        #nesting_call
181    };
182    let unnested_router_quote = quote! {
183        nested_router
184    };
185    let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
186        if lit.eq(&syn::parse_quote!("/")) {
187            unnested_router_quote
188        } else {
189            nested_router_quote
190        }
191    } else {
192        nested_router_quote
193    };
194
195    let middleware_calls = parsed_attrs
196        .middlewares
197        .clone()
198        .into_iter()
199        .map(|middleware| quote! {.layer(#middleware)})
200        .collect::<Vec<_>>();
201
202    let from_controller_into_router_impl = quote! {
203        impl #struct_name {
204            pub fn into_stateless_router(state: #state) -> axum::Router<()> {
205                Self::into_router()
206                    .with_state(state)
207
208            }
209
210            pub fn into_router() -> axum::Router<#state> {
211                let nested_router = axum::Router::new()
212                    #(#route_calls)*
213                    #(#middleware_calls)*
214                    ;
215
216                    #maybe_nesting_call
217            }
218
219        }
220    };
221
222    let c_impl: proc_macro2::TokenStream = c_impl.clone().into();
223    let res: TokenStream = quote! {
224        #c_impl
225        #from_controller_into_router_impl
226        #no_routes_warning
227    }
228    .into();
229
230    res
231}