Skip to main content

axum_controller/
lib.rs

#![doc = include_str!("../README.md")]
//!
//! This crate is an extension of [`axum_typed_routing`]. It uses the `#[route]`
//! attribute and the `TypedRouter` trait from that crate to discover and wire up
//! handler methods. You **must** add `axum-typed-routing` to your own
//! `Cargo.toml` dependencies for this crate to work.
//!
//! The code generated by `#[controller]` calls `.typed_route(...)` on an
//! `axum::Router` without importing anything itself, so the `TypedRouter`
//! trait must be in scope in the module where the macro is used.
//!
//! ## Route macro usage
//! See the docs of [`axum_typed_routing`] for details on the `#[route]` macro.
//!
//! ## Controller macro usage
//!
//! This crate offers the `controller()` attribute macro.
//! Use it like this:
//!
//! ```
#![doc = include_str!("../examples/controller.rs")]
//! ```

#![forbid(unsafe_code)]
21
22use proc_macro::TokenStream;
23use proc_macro2::Ident;
24use quote::quote;
25use syn::parse_quote;
26use syn::Token;
27
28use syn::{
29    parse::{Parse, ParseStream},
30    punctuated::Punctuated,
31    ItemImpl, MetaNameValue,
32};
33
34#[derive(Default)]
35struct ControllerAttrs {
36    middlewares: Vec<syn::Expr>,
37    path: Option<syn::Expr>,
38    state: Option<syn::Expr>,
39}
40
41impl Parse for ControllerAttrs {
42    fn parse(input: ParseStream) -> syn::Result<Self> {
43        let mut path: Option<syn::Expr> = None;
44        let mut state: Option<syn::Expr> = None;
45        let mut middlewares: Vec<syn::Expr> = Vec::new();
46
47        for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
48            let segs = nv.path.segments.clone().into_pairs();
49            let seg = segs.into_iter().next().unwrap().into_value();
50            let ident = seg.ident;
51            match ident.to_string().as_str() {
52                "path" => {
53                    if path.is_some() {
54                        return Err(syn::Error::new_spanned(
55                            &nv.path,
56                            "duplicate `path` attribute",
57                        ));
58                    }
59                    path = Some(nv.value);
60                }
61                "state" => {
62                    if state.is_some() {
63                        return Err(syn::Error::new_spanned(
64                            &nv.path,
65                            "duplicate `state` attribute",
66                        ));
67                    }
68                    state = Some(nv.value);
69                }
70                "middleware" => middlewares.push(nv.value),
71                _ => {
72                    return Err(syn::Error::new_spanned(
73                        &nv.path,
74                        format_args!(
75                            "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
76                            ident
77                        ),
78                    ));
79                }
80            }
81        }
82
83        Ok(Self {
84            middlewares,
85            path,
86            state,
87        })
88    }
89}
90
91struct ControllerImpl {
92    struct_name: syn::Path,
93    route_fns: Vec<syn::Ident>,
94}
95
96impl Parse for ControllerImpl {
97    fn parse(input: ParseStream) -> syn::Result<Self> {
98        let ast: ItemImpl = input.parse()?;
99        let struct_name = match *ast.self_ty {
100            syn::Type::Path(syn::TypePath { path, .. }) => path,
101            other => {
102                return Err(syn::Error::new_spanned(
103                    other,
104                    "expected a path (struct name) as the impl target",
105                ))
106            }
107        };
108        let mut route_fns: Vec<syn::Ident> = vec![];
109
110        for item in &ast.items {
111            if let syn::ImplItem::Fn(impl_item_fn) = item {
112                for attr in &impl_item_fn.attrs {
113                    if attr.path().is_ident("route") {
114                        let fn_name: Ident = impl_item_fn.sig.ident.clone();
115                        route_fns.push(fn_name);
116                    }
117                }
118            }
119        }
120
121        Ok(Self {
122            struct_name,
123            route_fns,
124        })
125    }
126}
127
128/// A macro that generates `into_router` and `into_app_router` methods which
129/// automatically wire up all `#[route]` handlers and the given middlewares,
130/// path-prefix etc.
131///
132/// ## Syntax:
133/// ```
134/// use axum_controller::controller;
135///
136/// struct ExampleController;
137/// #[controller(
138///   path = "/asd",
139/// )]
140/// impl ExampleController { /* ... */ }
141/// ```
142/// - path
143///   - optional, 0-1 allowed, defaults to `"/"`
144///   - A path to prefix `.nest` the `routes` in the controller Struct under
145/// - state
146///   - optional, 0-1 allowed, defaults to `"()"`)
147///   - The type signature of the state given to the routes
148/// - middleware
149///   - optional, 0-n allowed, default to [] (no middlewares)
150///   - Middlewares to `.layer` in the created router
151///
152#[proc_macro_attribute]
153pub fn controller(attrs: TokenStream, c_impl: TokenStream) -> TokenStream {
154    let parsed_attrs = match syn::parse::<ControllerAttrs>(attrs) {
155        Ok(args) => args,
156        Err(err) => return err.to_compile_error().into(),
157    };
158    let parsed_impl = match syn::parse::<ControllerImpl>(c_impl.clone()) {
159        Ok(myimpl) => myimpl,
160        Err(err) => return err.to_compile_error().into(),
161    };
162
163    let state = parsed_attrs.state.unwrap_or_else(|| parse_quote!(()));
164    let route_fns = parsed_impl.route_fns;
165    let struct_name = &parsed_impl.struct_name;
166    let route = parsed_attrs.path.unwrap_or_else(|| syn::parse_quote!("/"));
167
168    let no_routes_warning = if route_fns.is_empty() {
169        quote! {
170            #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
171            const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
172            const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
173        }
174    } else {
175        quote! {}
176    };
177
178    let route_calls = route_fns
179        .into_iter()
180        .map(move |route| {
181            quote! {
182            .typed_route(#struct_name :: #route)
183                  }
184        })
185        .collect::<Vec<_>>();
186
187    let nesting_call = quote! {
188        .nest(#route, nested_router)
189    };
190
191    let nested_router_quote = quote! {
192        axum::Router::new()
193        #nesting_call
194    };
195    let unnested_router_quote = quote! {
196        nested_router
197    };
198    let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
199        if lit.eq(&syn::parse_quote!("/")) {
200            unnested_router_quote
201        } else {
202            nested_router_quote
203        }
204    } else {
205        nested_router_quote
206    };
207
208    let middleware_calls = parsed_attrs
209        .middlewares
210        .into_iter()
211        .map(|middleware| quote! {.layer(#middleware)})
212        .collect::<Vec<_>>();
213
214    let from_controller_into_router_impl = quote! {
215        impl #struct_name {
216            pub fn into_app_router(state: #state) -> axum::Router<()> {
217                Self::into_router()
218                    .with_state(state)
219            }
220
221            pub fn into_router() -> axum::Router<#state> {
222                let nested_router = axum::Router::new()
223                    #(#route_calls)*
224                    #(#middleware_calls)*
225                    ;
226
227                    #maybe_nesting_call
228            }
229        }
230    };
231
232    let c_impl: proc_macro2::TokenStream = c_impl.into();
233    let res: TokenStream = quote! {
234        #c_impl
235        #from_controller_into_router_impl
236        #no_routes_warning
237    }
238    .into();
239
240    res
241}