Skip to main content

axum_controller/
lib.rs

1#![doc = include_str!("../README.md")]
2//!
3//! ## Basic route macro usage
4//! See the docs of [`axum_typed_routing`] for details on the route macro.
//! For convenience, the `route` macro and `TypedRouter` are re-exported,
//! so all you need on your side is `use axum_controller::*`.
7//!
8//! ## Controller macro usage
9//!
//! This crate also offers a `controller` attribute macro.
//! Use it like this:
12//!
13//! ```
14#![doc = include_str!("../examples/controller.rs")]
15//! ```
16
17#![forbid(unsafe_code)]
18
19use proc_macro::TokenStream;
20use proc_macro2::Ident;
21
22#[macro_use]
23extern crate quote;
24
25#[macro_use]
26extern crate syn;
27
28use syn::{
29    parse::{Parse, ParseStream},
30    punctuated::Punctuated,
31    ItemImpl, MetaNameValue,
32};
33
34#[derive(Clone, Default)]
35struct MyAttrs {
36    middlewares: Vec<syn::Expr>,
37    path: Option<syn::Expr>,
38    state: Option<syn::Expr>,
39}
40
41impl Parse for MyAttrs {
42    fn parse(input: ParseStream) -> syn::Result<Self> {
43        let mut path: Option<syn::Expr> = None;
44        let mut state: Option<syn::Expr> = None;
45        let mut middlewares: Vec<syn::Expr> = Vec::new();
46
47        for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
48            let segs = nv.path.segments.clone().into_pairs();
49            let seg = segs.into_iter().next().unwrap().into_value();
50            let ident = seg.ident;
51            match ident.to_string().as_str() {
52                "path" => {
53                    if path.is_some() {
54                        return Err(syn::Error::new_spanned(
55                            &nv.path,
56                            "duplicate `path` attribute",
57                        ));
58                    }
59                    path = Some(nv.value);
60                }
61                "state" => {
62                    if state.is_some() {
63                        return Err(syn::Error::new_spanned(
64                            &nv.path,
65                            "duplicate `state` attribute",
66                        ));
67                    }
68                    state = Some(nv.value);
69                }
70                "middleware" => middlewares.push(nv.value),
71                _ => {
72                    return Err(syn::Error::new_spanned(
73                        &nv.path,
74                        format_args!(
75                            "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
76                            ident
77                        ),
78                    ));
79                }
80            }
81        }
82
83        Ok(Self {
84            middlewares,
85            path,
86            state,
87        })
88    }
89}
90
91#[derive(Clone)]
92struct MyItem {
93    struct_name: syn::Type,
94    route_fns: Vec<syn::Ident>,
95}
96
97impl Parse for MyItem {
98    fn parse(input: ParseStream) -> syn::Result<Self> {
99        let ast: ItemImpl = input.parse()?;
100        let struct_name = *(ast.clone().self_ty.clone());
101        let mut route_fns: Vec<syn::Ident> = vec![];
102
103        for item in &ast.items {
104            if let syn::ImplItem::Fn(impl_item_fn) = item {
105                for attr in impl_item_fn.attrs.clone() {
106                    if attr.path().is_ident("route") {
107                        let fn_name: Ident = impl_item_fn.sig.ident.clone();
108                        route_fns.push(fn_name);
109                    }
110                }
111            }
112        }
113
114        Ok(Self {
115            struct_name,
116            route_fns,
117        })
118    }
119}
120
121// TODO add better docs
122/// A macro that generates a `into_router`(\_: State<_>) impl which automatically wires up all `route`'s and the given middlewares, path-prefix etc
123///
124/// ## Syntax:
125/// ```
126/// use axum_controller::controller;
127///
128/// struct ExampleController;
129/// #[controller(
130///   path = "/asd",
131/// )]
132/// impl ExampleController { /* ... */ }
133/// ```
134/// - path
135///   - optional, 0-1 allowed, defaults to `"/"`
136///   - A path to prefix `.nest` the `routes` in the controller Struct under
137/// - state
138///   - optional, 0-1 allowed, defaults to `"()"`)
139///   - The type signature of the state given to the routes
140/// - middleware
141///   - optional, 0-n allowed, default to [] (no middlewares)
142///   - Middlewares to `.layer` in the created router
143///
144#[proc_macro_attribute]
145pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
146    let args = match syn::parse::<MyAttrs>(attr) {
147        Ok(args) => args,
148        Err(err) => return err.to_compile_error().into(),
149    };
150    let item2: proc_macro2::TokenStream = item.clone().into();
151    let myimpl = match syn::parse::<MyItem>(item.clone()) {
152        Ok(myimpl) => myimpl,
153        Err(err) => return err.to_compile_error().into(),
154    };
155
156    let state = args.state.unwrap_or_else(|| parse_quote!(()));
157    let route_fns = myimpl.route_fns;
158    let struct_name = &myimpl.struct_name;
159    let route = args.path.unwrap_or_else(|| syn::parse_quote!("/"));
160
161    let no_routes_warning = if route_fns.is_empty() {
162        quote! {
163            #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
164            const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
165            const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
166        }
167    } else {
168        quote! {}
169    };
170
171    let route_calls = route_fns
172        .into_iter()
173        .map(move |route| {
174            quote! {
175            .typed_route(#struct_name :: #route)
176                  }
177        })
178        .collect::<Vec<_>>();
179
180    let nesting_call = quote! {
181        .nest(#route, nested_router)
182    };
183
184    let nested_router_qoute = quote! {
185        axum::Router::new()
186        #nesting_call
187    };
188    let unnested_router_quote = quote! {
189        nested_router
190    };
191    let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
192        if lit.eq(&syn::parse_quote!("/")) {
193            unnested_router_quote
194        } else {
195            nested_router_qoute
196        }
197    } else {
198        nested_router_qoute
199    };
200
201    let middleware_calls = args
202        .middlewares
203        .clone()
204        .into_iter()
205        .map(|middleware| quote! {.layer(#middleware)})
206        .collect::<Vec<_>>();
207
208    let from_controller_into_router_impl = quote! {
209        impl #struct_name {
210            pub fn into_stateless_router(state: #state) -> axum::Router<()> {
211                Self::into_router()
212                    .with_state(state)
213
214            }
215
216            pub fn into_router() -> axum::Router<#state> {
217                let nested_router = axum::Router::new()
218                    #(#route_calls)*
219                    #(#middleware_calls)*
220                    ;
221
222                    #maybe_nesting_call
223            }
224
225        }
226    };
227
228    let res: TokenStream = quote! {
229        #item2
230        #from_controller_into_router_impl
231        #no_routes_warning
232    }
233    .into();
234
235    res
236}