Skip to main content

axum_controller/
lib.rs

1#![doc = include_str!("../README.md")]
2//!
3//! ## Basic route macro usage
4//! See the docs of [`axum_typed_routing`] for details on the route macro.
//! For convenience, the `route` macro and `TypedRouter` are re-exported,
//! so all you need on your side is `use axum_controller::*`.
7//!
8//! ## Controller macro usage
9//!
//! This crate also offers a `controller` attribute macro.
//! Use it like this:
12//!
13//! ```
14#![doc = include_str!("../examples/controller.rs")]
15//! ```
16
17#![forbid(unsafe_code)]
18
19use proc_macro::TokenStream;
20use proc_macro2::Ident;
21
22#[macro_use]
23extern crate quote;
24
25#[macro_use]
26extern crate syn;
27
28use syn::{
29    parse::{Parse, ParseStream},
30    punctuated::Punctuated,
31    ItemImpl, MetaNameValue,
32};
33
34#[derive(Clone, Default)]
35struct MyAttrs {
36    middlewares: Vec<syn::Expr>,
37    path: Option<syn::Expr>,
38    state: Option<syn::Expr>,
39}
40
41impl Parse for MyAttrs {
42    fn parse(input: ParseStream) -> syn::Result<Self> {
43        let mut path: Option<syn::Expr> = None;
44        let mut state: Option<syn::Expr> = None;
45        let mut middlewares: Vec<syn::Expr> = Vec::new();
46
47        for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
48            let segs = nv.path.segments.clone().into_pairs();
49            let seg = segs.into_iter().next().unwrap().into_value();
50            let ident = seg.ident;
51            match ident.to_string().as_str() {
52                "path" => {
53                    if path.is_some() {
54                        return Err(syn::Error::new_spanned(
55                            &nv.path,
56                            "duplicate `path` attribute",
57                        ));
58                    }
59                    path = Some(nv.value);
60                }
61                "state" => {
62                    if state.is_some() {
63                        return Err(syn::Error::new_spanned(
64                            &nv.path,
65                            "duplicate `state` attribute",
66                        ));
67                    }
68                    state = Some(nv.value);
69                }
70                "middleware" => middlewares.push(nv.value),
71                _ => {
72                    return Err(syn::Error::new_spanned(
73                        &nv.path,
74                        format_args!(
75                            "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
76                            ident
77                        ),
78                    ));
79                }
80            }
81        }
82
83        Ok(Self {
84            middlewares,
85            path,
86            state,
87        })
88    }
89}
90
91#[derive(Clone)]
92struct MyItem {
93    struct_name: syn::Type,
94    route_fns: Vec<syn::Ident>,
95}
96
97impl Parse for MyItem {
98    fn parse(input: ParseStream) -> syn::Result<Self> {
99        let ast: ItemImpl = input.parse()?;
100        let struct_name = *(ast.clone().self_ty.clone());
101        let mut route_fns: Vec<syn::Ident> = vec![];
102
103        for item in &ast.items {
104            if let syn::ImplItem::Fn(impl_item_fn) = item {
105                for attr in impl_item_fn.attrs.clone() {
106                    if attr.path().is_ident("route") {
107                        let fn_name: Ident = impl_item_fn.sig.ident.clone();
108                        route_fns.push(fn_name);
109                    }
110                }
111            }
112        }
113
114        Ok(Self {
115            struct_name,
116            route_fns,
117        })
118    }
119}
120
121// TODO add better docs
122/// A macro that generates a `into_router`(\_: State<_>) impl which automatically wires up all `route`'s and the given middlewares, path-prefix etc
123///
124/// ## Syntax:
125/// ```
126/// use axum_controller::controller;
127///
128/// struct ExampleController;
129/// #[controller(
130///   path = "/asd",
131/// )]
132/// impl ExampleController { /* ... */ }
133/// ```
134/// - path
135///   - optional, 0-1 allowed, defaults to `"/"`
136///   - A path to prefix `.nest` the `routes` in the controller Struct under
137/// - state
138///   - optional, 0-1 allowed, defaults to `"()"`)
139///   - The type signature of the state given to the routes
140/// - middleware
141///   - optional, 0-n allowed, default to [] (no middlewares)
142///   - Middlewares to `.layer` in the created router
143///
144#[proc_macro_attribute]
145pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
146    let args = match syn::parse::<MyAttrs>(attr) {
147        Ok(args) => args,
148        Err(err) => return err.to_compile_error().into(),
149    };
150    let item2: proc_macro2::TokenStream = item.clone().into();
151    let myimpl = match syn::parse::<MyItem>(item.clone()) {
152        Ok(myimpl) => myimpl,
153        Err(err) => return err.to_compile_error().into(),
154    };
155
156    let state = args.state.unwrap_or_else(|| parse_quote!(()));
157    let route_fns = myimpl.route_fns;
158    let struct_name = &myimpl.struct_name;
159    let route = args.path.unwrap_or_else(|| syn::parse_quote!("/"));
160
161    let route_calls = route_fns
162        .into_iter()
163        .map(move |route| {
164            quote! {
165            .typed_route(#struct_name :: #route)
166                  }
167        })
168        .collect::<Vec<_>>();
169
170    let nesting_call = quote! {
171        .nest(#route, nested_router)
172    };
173
174    let nested_router_qoute = quote! {
175        axum::Router::new()
176        #nesting_call
177    };
178    let unnested_router_quote = quote! {
179        nested_router
180    };
181    let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
182        if lit.eq(&syn::parse_quote!("/")) {
183            unnested_router_quote
184        } else {
185            nested_router_qoute
186        }
187    } else {
188        nested_router_qoute
189    };
190
191    let middleware_calls = args
192        .middlewares
193        .clone()
194        .into_iter()
195        .map(|middleware| quote! {.layer(#middleware)})
196        .collect::<Vec<_>>();
197
198    let from_controller_into_router_impl = quote! {
199        impl #struct_name {
200            pub fn into_stateless_router(state: #state) -> axum::Router<()> {
201                Self::into_router()
202                    .with_state(state)
203
204            }
205
206            pub fn into_router() -> axum::Router<#state> {
207                let nested_router = axum::Router::new()
208                    #(#route_calls)*
209                    #(#middleware_calls)*
210                    ;
211
212                    #maybe_nesting_call
213            }
214
215        }
216    };
217
218    let res: TokenStream = quote! {
219        #item2
220        #from_controller_into_router_impl
221    }
222    .into();
223
224    res
225}