1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::{Error, Expr, ExprPath, Ident, ImplItem, ItemImpl, Lit, Meta, Token};
7
8#[proc_macro_attribute]
9pub fn router(attr: TokenStream, item: TokenStream) -> TokenStream {
10 let router_attr = match parse_router_attribute(attr) {
11 Ok(v) => v,
12 Err(err) => return err.to_compile_error().into(),
13 };
14 match syn::parse::<ItemImpl>(item) {
15 Ok(item_impl) => handle_router_impl(item_impl, router_attr),
16 Err(_) => Error::new(
17 Span::call_site(),
18 "#[router] can only be applied to impl blocks",
19 )
20 .to_compile_error()
21 .into(),
22 }
23}
24
25struct RouterAttribute {
26 middleware: Vec<ExprPath>,
27}
28
29fn parse_router_attribute(attr: TokenStream) -> Result<RouterAttribute, Error> {
30 if attr.is_empty() {
31 return Ok(RouterAttribute {
32 middleware: Vec::new(),
33 });
34 }
35
36 let meta_items: Punctuated<Meta, Token![,]> =
37 syn::parse::Parser::parse2(Punctuated::parse_terminated, attr.into())?;
38
39 let mut middleware = Vec::new();
40
41 for meta in meta_items {
42 match meta {
43 Meta::NameValue(nv) if nv.path.is_ident("middleware") => {
44 if let Expr::Array(expr_array) = &nv.value {
45 for expr in &expr_array.elems {
46 if let Expr::Path(expr_path) = expr {
47 middleware.push(expr_path.clone());
48 } else {
49 return Err(Error::new_spanned(
50 expr,
51 "Middleware must be a path expression",
52 ));
53 }
54 }
55 } else {
56 return Err(Error::new_spanned(
57 &nv.value,
58 "`middleware` attribute requires an array of paths",
59 ));
60 }
61 }
62 _ => {
63 return Err(Error::new_spanned(
64 meta,
65 "Unsupported attribute format in #[router]",
66 ));
67 }
68 }
69 }
70
71 Ok(RouterAttribute { middleware })
72}
73
74struct RouteAttribute {
75 http_method: proc_macro2::TokenStream,
76 path: String,
77 middleware: Vec<ExprPath>,
78}
79
80fn parse_route_attribute(attr: &syn::Attribute) -> Result<RouteAttribute, Error> {
81 let meta_items: Punctuated<Meta, Token![,]> =
82 attr.parse_args_with(Punctuated::parse_terminated)?;
83
84 let mut http_method = None;
85 let mut path = None;
86 let mut middleware = Vec::new();
87
88 for meta in meta_items {
89 match meta {
90 Meta::Path(path_meta) => {
91 let ident = path_meta
92 .get_ident()
93 .ok_or_else(|| Error::new_spanned(&path_meta, "Expected identifier"))?;
94
95 http_method = Some(match ident.to_string().as_str() {
96 "get" => quote! { ::diode_http::routing::get },
97 "post" => quote! { ::diode_http::routing::post },
98 "delete" => quote! { ::diode_http::routing::delete },
99 "patch" => quote! { ::diode_http::routing::patch },
100 "put" => quote! { ::diode_http::routing::put },
101 "options" => quote! { ::diode_http::routing::options },
102 "connect" => quote! { ::diode_http::routing::connect },
103 "head" => quote! { ::diode_http::routing::head },
104 "trace" => quote! { ::diode_http::routing::trace },
105 "any" => quote! { ::diode_http::routing::any },
106 _ => {
107 return Err(Error::new_spanned(
108 ident,
109 format!("Unsupported HTTP method: {ident}"),
110 ));
111 }
112 });
113 }
114 Meta::NameValue(nv) if nv.path.is_ident("path") => {
115 if let Expr::Lit(expr_lit) = &nv.value
116 && let Lit::Str(lit_str) = &expr_lit.lit
117 {
118 path = Some(lit_str.value());
119 continue;
120 }
121 return Err(Error::new_spanned(
122 &nv.value,
123 "`path` attribute requires a string literal",
124 ));
125 }
126 Meta::NameValue(nv) if nv.path.is_ident("middleware") => {
127 if let Expr::Array(expr_array) = &nv.value {
128 for expr in &expr_array.elems {
129 if let Expr::Path(expr_path) = expr {
130 middleware.push(expr_path.clone());
131 } else {
132 return Err(Error::new_spanned(
133 expr,
134 "Middleware must be a path expression",
135 ));
136 }
137 }
138 } else {
139 return Err(Error::new_spanned(
140 &nv.value,
141 "`middleware` attribute requires an array of paths",
142 ));
143 }
144 }
145 _ => {
146 return Err(Error::new_spanned(
147 meta,
148 "Unsupported attribute format in #[route]",
149 ));
150 }
151 }
152 }
153
154 let http_method = http_method
155 .ok_or_else(|| Error::new_spanned(attr, "Missing HTTP method in #[route] attribute"))?;
156
157 let path =
158 path.ok_or_else(|| Error::new_spanned(attr, "Missing path in #[route] attribute"))?;
159
160 Ok(RouteAttribute {
161 http_method,
162 path,
163 middleware,
164 })
165}
166
167fn handle_router_impl(input: ItemImpl, router_attr: RouterAttribute) -> TokenStream {
168 if input.trait_.is_some() {
169 return Error::new(input.span(), "Trait impls are not supported")
170 .to_compile_error()
171 .into();
172 }
173
174 let self_ty = &input.self_ty;
175 let mut routes = Vec::new();
176 let mut errors = Vec::new();
177
178 let router_middleware = router_attr.middleware;
179
180 let mut cleaned_input = input.clone();
182 for item in &mut cleaned_input.items {
183 if let ImplItem::Fn(fn_item) = item {
184 fn_item.attrs.retain(|attr| !attr.path().is_ident("route"));
185 }
186 }
187
188 for item in &input.items {
189 let ImplItem::Fn(fn_item) = item else {
190 continue;
191 };
192
193 for attr in &fn_item.attrs {
194 if !attr.path().is_ident("route") {
195 continue;
196 }
197
198 match parse_route_attribute(attr) {
199 Ok(RouteAttribute {
200 http_method,
201 path,
202 middleware,
203 }) => {
204 let ident = &fn_item.sig.ident;
205 let arg_count = fn_item.sig.inputs.len().saturating_sub(1); let args: Vec<_> = (0..arg_count)
207 .map(|i| Ident::new(&format!("arg{i}"), Span::call_site()))
208 .collect();
209
210 routes.push(quote! {
211 let mut route = #http_method({
212 let this = self.clone();
213 move |#(#args,)*| {
214 async move { Self::#ident(&this, #(#args,)*).await }
215 }
216 });
217 #(
218 let middleware = app
219 .get_component::<<#middleware as ::diode::Service>::Handle>()
220 .ok_or_else(|| {
221 format!(
222 "Missing component: {}",
223 ::std::any::type_name::<<#middleware as ::diode::Service>::Handle>()
224 )
225 })
226 .unwrap();
227 route = route.layer(::diode_http::MiddlewareLayerImpl(middleware));
228 )*
229 router = router.route(#path, route);
230 });
231 }
232 Err(e) => errors.push(e),
233 }
234 }
235 }
236
237 if !errors.is_empty() {
238 let mut combined_error = Error::new(
239 Span::call_site(),
240 "Errors occurred while processing route attributes",
241 );
242 for error in errors {
243 combined_error.combine(error);
244 }
245 return combined_error.to_compile_error().into();
246 }
247
248 quote! {
249 #cleaned_input
250
251 impl ::diode_http::RouterBuilder for #self_ty {
252 fn build_router(self: ::std::sync::Arc<Self>, app: &::diode::App) -> ::diode_http::Router {
253 let mut router = ::diode_http::Router::new();
254 #(#routes)*
255 #(
256 let middleware = app
257 .get_component::<<#router_middleware as ::diode::Service>::Handle>()
258 .ok_or_else(|| {
259 format!(
260 "Missing component: {}",
261 ::std::any::type_name::<<#router_middleware as ::diode::Service>::Handle>()
262 )
263 })
264 .unwrap();
265 router = router.layer(::diode_http::MiddlewareLayerImpl(middleware));
266 )*
267 router
268 }
269 }
270 }
271 .into()
272}