1#![doc = include_str!("../README.md")]
2#![doc = include_str!("../examples/controller.rs")]
15#![forbid(unsafe_code)]
18
19use proc_macro::TokenStream;
20use proc_macro2::Ident;
21
22#[macro_use]
23extern crate quote;
24
25#[macro_use]
26extern crate syn;
27
28use syn::{
29 parse::{Parse, ParseStream},
30 punctuated::Punctuated,
31 ItemImpl, MetaNameValue,
32};
33
34#[derive(Clone, Default)]
35struct MyAttrs {
36 middlewares: Vec<syn::Expr>,
37 path: Option<syn::Expr>,
38 state: Option<syn::Expr>,
39}
40
41impl Parse for MyAttrs {
42 fn parse(input: ParseStream) -> syn::Result<Self> {
43 let mut path: Option<syn::Expr> = None;
44 let mut state: Option<syn::Expr> = None;
45 let mut middlewares: Vec<syn::Expr> = Vec::new();
46
47 for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
48 let segs = nv.path.segments.clone().into_pairs();
49 let seg = segs.into_iter().next().unwrap().into_value();
50 let ident = seg.ident;
51 match ident.to_string().as_str() {
52 "path" => {
53 if path.is_some() {
54 return Err(syn::Error::new_spanned(
55 &nv.path,
56 "duplicate `path` attribute",
57 ));
58 }
59 path = Some(nv.value);
60 }
61 "state" => {
62 if state.is_some() {
63 return Err(syn::Error::new_spanned(
64 &nv.path,
65 "duplicate `state` attribute",
66 ));
67 }
68 state = Some(nv.value);
69 }
70 "middleware" => middlewares.push(nv.value),
71 _ => {
72 return Err(syn::Error::new_spanned(
73 &nv.path,
74 format_args!(
75 "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
76 ident
77 ),
78 ));
79 }
80 }
81 }
82
83 Ok(Self {
84 middlewares,
85 path,
86 state,
87 })
88 }
89}
90
91#[derive(Clone)]
92struct MyItem {
93 struct_name: syn::Type,
94 route_fns: Vec<syn::Ident>,
95}
96
97impl Parse for MyItem {
98 fn parse(input: ParseStream) -> syn::Result<Self> {
99 let ast: ItemImpl = input.parse()?;
100 let struct_name = *(ast.clone().self_ty.clone());
101 let mut route_fns: Vec<syn::Ident> = vec![];
102
103 for item in &ast.items {
104 if let syn::ImplItem::Fn(impl_item_fn) = item {
105 for attr in impl_item_fn.attrs.clone() {
106 if attr.path().is_ident("route") {
107 let fn_name: Ident = impl_item_fn.sig.ident.clone();
108 route_fns.push(fn_name);
109 }
110 }
111 }
112 }
113
114 Ok(Self {
115 struct_name,
116 route_fns,
117 })
118 }
119}
120
121#[proc_macro_attribute]
145pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
146 let args = match syn::parse::<MyAttrs>(attr) {
147 Ok(args) => args,
148 Err(err) => return err.to_compile_error().into(),
149 };
150 let item2: proc_macro2::TokenStream = item.clone().into();
151 let myimpl = match syn::parse::<MyItem>(item.clone()) {
152 Ok(myimpl) => myimpl,
153 Err(err) => return err.to_compile_error().into(),
154 };
155
156 let state = args.state.unwrap_or_else(|| parse_quote!(()));
157 let route_fns = myimpl.route_fns;
158 let struct_name = &myimpl.struct_name;
159 let route = args.path.unwrap_or_else(|| syn::parse_quote!("/"));
160
161 let no_routes_warning = if route_fns.is_empty() {
162 quote! {
163 #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
164 const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
165 const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
166 }
167 } else {
168 quote! {}
169 };
170
171 let route_calls = route_fns
172 .into_iter()
173 .map(move |route| {
174 quote! {
175 .typed_route(#struct_name :: #route)
176 }
177 })
178 .collect::<Vec<_>>();
179
180 let nesting_call = quote! {
181 .nest(#route, nested_router)
182 };
183
184 let nested_router_qoute = quote! {
185 axum::Router::new()
186 #nesting_call
187 };
188 let unnested_router_quote = quote! {
189 nested_router
190 };
191 let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
192 if lit.eq(&syn::parse_quote!("/")) {
193 unnested_router_quote
194 } else {
195 nested_router_qoute
196 }
197 } else {
198 nested_router_qoute
199 };
200
201 let middleware_calls = args
202 .middlewares
203 .clone()
204 .into_iter()
205 .map(|middleware| quote! {.layer(#middleware)})
206 .collect::<Vec<_>>();
207
208 let from_controller_into_router_impl = quote! {
209 impl #struct_name {
210 pub fn into_stateless_router(state: #state) -> axum::Router<()> {
211 Self::into_router()
212 .with_state(state)
213
214 }
215
216 pub fn into_router() -> axum::Router<#state> {
217 let nested_router = axum::Router::new()
218 #(#route_calls)*
219 #(#middleware_calls)*
220 ;
221
222 #maybe_nesting_call
223 }
224
225 }
226 };
227
228 let res: TokenStream = quote! {
229 #item2
230 #from_controller_into_router_impl
231 #no_routes_warning
232 }
233 .into();
234
235 res
236}