1#![doc = include_str!("../README.md")]
2#![doc = include_str!("../examples/controller.rs")]
18#![forbid(unsafe_code)]
21
22use proc_macro::TokenStream;
23use proc_macro2::Ident;
24use quote::quote;
25use syn::parse_quote;
26use syn::Token;
27
28use syn::{
29 parse::{Parse, ParseStream},
30 punctuated::Punctuated,
31 ItemImpl, MetaNameValue,
32};
33
34#[derive(Default)]
35struct ControllerAttrs {
36 middlewares: Vec<syn::Expr>,
37 path: Option<syn::Expr>,
38 state: Option<syn::Expr>,
39}
40
41impl Parse for ControllerAttrs {
42 fn parse(input: ParseStream) -> syn::Result<Self> {
43 let mut path: Option<syn::Expr> = None;
44 let mut state: Option<syn::Expr> = None;
45 let mut middlewares: Vec<syn::Expr> = Vec::new();
46
47 for nv in Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)? {
48 let segs = nv.path.segments.clone().into_pairs();
49 let seg = segs.into_iter().next().unwrap().into_value();
50 let ident = seg.ident;
51 match ident.to_string().as_str() {
52 "path" => {
53 if path.is_some() {
54 return Err(syn::Error::new_spanned(
55 &nv.path,
56 "duplicate `path` attribute",
57 ));
58 }
59 path = Some(nv.value);
60 }
61 "state" => {
62 if state.is_some() {
63 return Err(syn::Error::new_spanned(
64 &nv.path,
65 "duplicate `state` attribute",
66 ));
67 }
68 state = Some(nv.value);
69 }
70 "middleware" => middlewares.push(nv.value),
71 _ => {
72 return Err(syn::Error::new_spanned(
73 &nv.path,
74 format_args!(
75 "unknown attribute `{}`; expected `path`, `state`, or `middleware`",
76 ident
77 ),
78 ));
79 }
80 }
81 }
82
83 Ok(Self {
84 middlewares,
85 path,
86 state,
87 })
88 }
89}
90
91struct ControllerImpl {
92 struct_name: syn::Path,
93 route_fns: Vec<syn::Ident>,
94}
95
96impl Parse for ControllerImpl {
97 fn parse(input: ParseStream) -> syn::Result<Self> {
98 let ast: ItemImpl = input.parse()?;
99 let struct_name = match *ast.self_ty {
100 syn::Type::Path(syn::TypePath { path, .. }) => path,
101 other => {
102 return Err(syn::Error::new_spanned(
103 other,
104 "expected a path (struct name) as the impl target",
105 ))
106 }
107 };
108 let mut route_fns: Vec<syn::Ident> = vec![];
109
110 for item in &ast.items {
111 if let syn::ImplItem::Fn(impl_item_fn) = item {
112 for attr in &impl_item_fn.attrs {
113 if attr.path().is_ident("route") {
114 let fn_name: Ident = impl_item_fn.sig.ident.clone();
115 route_fns.push(fn_name);
116 }
117 }
118 }
119 }
120
121 Ok(Self {
122 struct_name,
123 route_fns,
124 })
125 }
126}
127
128#[proc_macro_attribute]
153pub fn controller(attrs: TokenStream, c_impl: TokenStream) -> TokenStream {
154 let parsed_attrs = match syn::parse::<ControllerAttrs>(attrs) {
155 Ok(args) => args,
156 Err(err) => return err.to_compile_error().into(),
157 };
158 let parsed_impl = match syn::parse::<ControllerImpl>(c_impl.clone()) {
159 Ok(myimpl) => myimpl,
160 Err(err) => return err.to_compile_error().into(),
161 };
162
163 let state = parsed_attrs.state.unwrap_or_else(|| parse_quote!(()));
164 let route_fns = parsed_impl.route_fns;
165 let struct_name = &parsed_impl.struct_name;
166 let route = parsed_attrs.path.unwrap_or_else(|| syn::parse_quote!("/"));
167
168 let no_routes_warning = if route_fns.is_empty() {
169 quote! {
170 #[deprecated(note = "#[controller] applied to impl with no #[route] attributes")]
171 const __AXUM_CONTROLLER_NO_ROUTES_WARNING: () = ();
172 const _: () = __AXUM_CONTROLLER_NO_ROUTES_WARNING;
173 }
174 } else {
175 quote! {}
176 };
177
178 let route_calls = route_fns
179 .into_iter()
180 .map(move |route| {
181 quote! {
182 .typed_route(#struct_name :: #route)
183 }
184 })
185 .collect::<Vec<_>>();
186
187 let nesting_call = quote! {
188 .nest(#route, nested_router)
189 };
190
191 let nested_router_quote = quote! {
192 axum::Router::new()
193 #nesting_call
194 };
195 let unnested_router_quote = quote! {
196 nested_router
197 };
198 let maybe_nesting_call = if let syn::Expr::Lit(lit) = route {
199 if lit.eq(&syn::parse_quote!("/")) {
200 unnested_router_quote
201 } else {
202 nested_router_quote
203 }
204 } else {
205 nested_router_quote
206 };
207
208 let middleware_calls = parsed_attrs
209 .middlewares
210 .into_iter()
211 .map(|middleware| quote! {.layer(#middleware)})
212 .collect::<Vec<_>>();
213
214 let from_controller_into_router_impl = quote! {
215 impl #struct_name {
216 pub fn into_app_router(state: #state) -> axum::Router<()> {
217 Self::into_router()
218 .with_state(state)
219 }
220
221 pub fn into_router() -> axum::Router<#state> {
222 let nested_router = axum::Router::new()
223 #(#route_calls)*
224 #(#middleware_calls)*
225 ;
226
227 #maybe_nesting_call
228 }
229 }
230 };
231
232 let c_impl: proc_macro2::TokenStream = c_impl.into();
233 let res: TokenStream = quote! {
234 #c_impl
235 #from_controller_into_router_impl
236 #no_routes_warning
237 }
238 .into();
239
240 res
241}