1#![doc = include_str!("../README.md")]
2#![doc = include_str!("../examples/controller.rs")]
13#![forbid(unsafe_code)]
16
17use proc_macro::TokenStream;
18use proc_macro2::Ident;
19use quote::quote;
20use syn::parse_quote;
21use syn::Token;
22
23use syn::{
24 parse::{Parse, ParseStream},
25 punctuated::Punctuated,
26 ItemImpl, MetaNameValue,
27};
28
29#[derive(Clone, Default)]
30struct ControllerAttrs {
31 middlewares: Vec<syn::Expr>,
32 path: Option<syn::Expr>,
33 state: Option<syn::Expr>,
34}
35
36impl Parse for ControllerAttrs {
37 fn parse(input: ParseStream) -> syn::Result<Self> {
38 let mut path: Option<syn::Expr> = None;
39 let mut state: Option<syn::Expr> = None;
40 let mut middlewares: Vec<syn::Expr> = Vec::new();
41
42 for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
43 let segs = nv.path.segments.clone().into_pairs();
44 let seg = segs.into_iter().next().unwrap().into_value();
45 let ident = seg.ident;
46 match ident.to_string().as_str() {
47 "path" => {
48 if path.is_some() {
49 return Err(syn::Error::new_spanned(
50 &nv.path,
51 "duplicate `path` attribute",
52 ));
53 }
54 path = Some(nv.value);
55 }
56 "state" => {
57 if state.is_some() {
58 return Err(syn::Error::new_spanned(
59 &nv.path,
60 "duplicate `state` attribute",
61 ));
62 }
63 state = Some(nv.value);
64 }
65 "middleware" => middlewares.push(nv.value),
66 _ => {
67 return Err(syn::Error::new_spanned(
68 &nv.path,
69 format_args!(
70 "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
71 ident
72 ),
73 ));
74 }
75 }
76 }
77
78 Ok(Self {
79 middlewares,
80 path,
81 state,
82 })
83 }
84}
85
86#[derive(Clone)]
87struct ControllerImpl {
88 struct_name: syn::Type,
89 route_fns: Vec<syn::Ident>,
90}
91
92impl Parse for ControllerImpl {
93 fn parse(input: ParseStream) -> syn::Result<Self> {
94 let ast: ItemImpl = input.parse()?;
95 let struct_name = *(ast.clone().self_ty.clone());
96 let mut route_fns: Vec<syn::Ident> = vec![];
97
98 for item in &ast.items {
99 if let syn::ImplItem::Fn(impl_item_fn) = item {
100 for attr in impl_item_fn.attrs.clone() {
101 if attr.path().is_ident("route") {
102 let fn_name: Ident = impl_item_fn.sig.ident.clone();
103 route_fns.push(fn_name);
104 }
105 }
106 }
107 }
108
109 Ok(Self {
110 struct_name,
111 route_fns,
112 })
113 }
114}
115
116#[proc_macro_attribute]
140pub fn controller(attrs: TokenStream, c_impl: TokenStream) -> TokenStream {
141 let parsed_attrs = match syn::parse::<ControllerAttrs>(attrs) {
142 Ok(args) => args,
143 Err(err) => return err.to_compile_error().into(),
144 };
145 let parsed_impl = match syn::parse::<ControllerImpl>(c_impl.clone()) {
146 Ok(myimpl) => myimpl,
147 Err(err) => return err.to_compile_error().into(),
148 };
149
150 let state = parsed_attrs.state.unwrap_or_else(|| parse_quote!(()));
151 let route_fns = parsed_impl.route_fns;
152 let struct_name = &parsed_impl.struct_name;
153 let route = parsed_attrs.path.unwrap_or_else(|| syn::parse_quote!("/"));
154
155 let no_routes_warning = if route_fns.is_empty() {
156 quote! {
157 #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
158 const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
159 const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
160 }
161 } else {
162 quote! {}
163 };
164
165 let route_calls = route_fns
166 .into_iter()
167 .map(move |route| {
168 quote! {
169 .typed_route(#struct_name :: #route)
170 }
171 })
172 .collect::<Vec<_>>();
173
174 let nesting_call = quote! {
175 .nest(#route, nested_router)
176 };
177
178 let nested_router_quote = quote! {
179 axum::Router::new()
180 #nesting_call
181 };
182 let unnested_router_quote = quote! {
183 nested_router
184 };
185 let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
186 if lit.eq(&syn::parse_quote!("/")) {
187 unnested_router_quote
188 } else {
189 nested_router_quote
190 }
191 } else {
192 nested_router_quote
193 };
194
195 let middleware_calls = parsed_attrs
196 .middlewares
197 .clone()
198 .into_iter()
199 .map(|middleware| quote! {.layer(#middleware)})
200 .collect::<Vec<_>>();
201
202 let from_controller_into_router_impl = quote! {
203 impl #struct_name {
204 pub fn into_stateless_router(state: #state) -> axum::Router<()> {
205 Self::into_router()
206 .with_state(state)
207
208 }
209
210 pub fn into_router() -> axum::Router<#state> {
211 let nested_router = axum::Router::new()
212 #(#route_calls)*
213 #(#middleware_calls)*
214 ;
215
216 #maybe_nesting_call
217 }
218
219 }
220 };
221
222 let c_impl: proc_macro2::TokenStream = c_impl.clone().into();
223 let res: TokenStream = quote! {
224 #c_impl
225 #from_controller_into_router_impl
226 #no_routes_warning
227 }
228 .into();
229
230 res
231}