diode_http_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::{Error, Expr, ExprPath, Ident, ImplItem, ItemImpl, Lit, Meta, Token};
7
8#[proc_macro_attribute]
9pub fn router(attr: TokenStream, item: TokenStream) -> TokenStream {
10    let router_attr = match parse_router_attribute(attr) {
11        Ok(v) => v,
12        Err(err) => return err.to_compile_error().into(),
13    };
14    match syn::parse::<ItemImpl>(item) {
15        Ok(item_impl) => handle_router_impl(item_impl, router_attr),
16        Err(_) => Error::new(
17            Span::call_site(),
18            "#[router] can only be applied to impl blocks",
19        )
20        .to_compile_error()
21        .into(),
22    }
23}
24
/// Parsed arguments of the `#[router(...)]` attribute.
struct RouterAttribute {
    /// Paths listed in `middleware = [...]`, in the order they were written.
    /// Empty when the attribute has no arguments.
    middleware: Vec<ExprPath>,
}
28
29fn parse_router_attribute(attr: TokenStream) -> Result<RouterAttribute, Error> {
30    if attr.is_empty() {
31        return Ok(RouterAttribute {
32            middleware: Vec::new(),
33        });
34    }
35
36    let meta_items: Punctuated<Meta, Token![,]> =
37        syn::parse::Parser::parse2(Punctuated::parse_terminated, attr.into())?;
38
39    let mut middleware = Vec::new();
40
41    for meta in meta_items {
42        match meta {
43            Meta::NameValue(nv) if nv.path.is_ident("middleware") => {
44                if let Expr::Array(expr_array) = &nv.value {
45                    for expr in &expr_array.elems {
46                        if let Expr::Path(expr_path) = expr {
47                            middleware.push(expr_path.clone());
48                        } else {
49                            return Err(Error::new_spanned(
50                                expr,
51                                "Middleware must be a path expression",
52                            ));
53                        }
54                    }
55                } else {
56                    return Err(Error::new_spanned(
57                        &nv.value,
58                        "`middleware` attribute requires an array of paths",
59                    ));
60                }
61            }
62            _ => {
63                return Err(Error::new_spanned(
64                    meta,
65                    "Unsupported attribute format in #[router]",
66                ));
67            }
68        }
69    }
70
71    Ok(RouterAttribute { middleware })
72}
73
/// Parsed arguments of a single `#[route(...)]` attribute.
struct RouteAttribute {
    /// Tokens of the `::diode_http::routing::<method>` path selected by the
    /// bare HTTP method identifier (for example `get` or `post`).
    http_method: proc_macro2::TokenStream,
    /// Value of the `path = "..."` string literal.
    path: String,
    /// Paths listed in `middleware = [...]`, in the order they were written.
    middleware: Vec<ExprPath>,
}
79
80fn parse_route_attribute(attr: &syn::Attribute) -> Result<RouteAttribute, Error> {
81    let meta_items: Punctuated<Meta, Token![,]> =
82        attr.parse_args_with(Punctuated::parse_terminated)?;
83
84    let mut http_method = None;
85    let mut path = None;
86    let mut middleware = Vec::new();
87
88    for meta in meta_items {
89        match meta {
90            Meta::Path(path_meta) => {
91                let ident = path_meta
92                    .get_ident()
93                    .ok_or_else(|| Error::new_spanned(&path_meta, "Expected identifier"))?;
94
95                http_method = Some(match ident.to_string().as_str() {
96                    "get" => quote! { ::diode_http::routing::get },
97                    "post" => quote! { ::diode_http::routing::post },
98                    "delete" => quote! { ::diode_http::routing::delete },
99                    "patch" => quote! { ::diode_http::routing::patch },
100                    "put" => quote! { ::diode_http::routing::put },
101                    "options" => quote! { ::diode_http::routing::options },
102                    "connect" => quote! { ::diode_http::routing::connect },
103                    "head" => quote! { ::diode_http::routing::head },
104                    "trace" => quote! { ::diode_http::routing::trace },
105                    "any" => quote! { ::diode_http::routing::any },
106                    _ => {
107                        return Err(Error::new_spanned(
108                            ident,
109                            format!("Unsupported HTTP method: {ident}"),
110                        ));
111                    }
112                });
113            }
114            Meta::NameValue(nv) if nv.path.is_ident("path") => {
115                if let Expr::Lit(expr_lit) = &nv.value
116                    && let Lit::Str(lit_str) = &expr_lit.lit
117                {
118                    path = Some(lit_str.value());
119                    continue;
120                }
121                return Err(Error::new_spanned(
122                    &nv.value,
123                    "`path` attribute requires a string literal",
124                ));
125            }
126            Meta::NameValue(nv) if nv.path.is_ident("middleware") => {
127                if let Expr::Array(expr_array) = &nv.value {
128                    for expr in &expr_array.elems {
129                        if let Expr::Path(expr_path) = expr {
130                            middleware.push(expr_path.clone());
131                        } else {
132                            return Err(Error::new_spanned(
133                                expr,
134                                "Middleware must be a path expression",
135                            ));
136                        }
137                    }
138                } else {
139                    return Err(Error::new_spanned(
140                        &nv.value,
141                        "`middleware` attribute requires an array of paths",
142                    ));
143                }
144            }
145            _ => {
146                return Err(Error::new_spanned(
147                    meta,
148                    "Unsupported attribute format in #[route]",
149                ));
150            }
151        }
152    }
153
154    let http_method = http_method
155        .ok_or_else(|| Error::new_spanned(attr, "Missing HTTP method in #[route] attribute"))?;
156
157    let path =
158        path.ok_or_else(|| Error::new_spanned(attr, "Missing path in #[route] attribute"))?;
159
160    Ok(RouteAttribute {
161        http_method,
162        path,
163        middleware,
164    })
165}
166
167fn handle_router_impl(input: ItemImpl, router_attr: RouterAttribute) -> TokenStream {
168    if input.trait_.is_some() {
169        return Error::new(input.span(), "Trait impls are not supported")
170            .to_compile_error()
171            .into();
172    }
173
174    let self_ty = &input.self_ty;
175    let mut routes = Vec::new();
176    let mut errors = Vec::new();
177
178    let router_middleware = router_attr.middleware;
179
180    // Create cleaned impl with route attributes removed
181    let mut cleaned_input = input.clone();
182    for item in &mut cleaned_input.items {
183        if let ImplItem::Fn(fn_item) = item {
184            fn_item.attrs.retain(|attr| !attr.path().is_ident("route"));
185        }
186    }
187
188    for item in &input.items {
189        let ImplItem::Fn(fn_item) = item else {
190            continue;
191        };
192
193        for attr in &fn_item.attrs {
194            if !attr.path().is_ident("route") {
195                continue;
196            }
197
198            match parse_route_attribute(attr) {
199                Ok(RouteAttribute {
200                    http_method,
201                    path,
202                    middleware,
203                }) => {
204                    let ident = &fn_item.sig.ident;
205                    let arg_count = fn_item.sig.inputs.len().saturating_sub(1); // Exclude self
206                    let args: Vec<_> = (0..arg_count)
207                        .map(|i| Ident::new(&format!("arg{i}"), Span::call_site()))
208                        .collect();
209
210                    routes.push(quote! {
211                        let mut route = #http_method({
212                            let this = self.clone();
213                            move |#(#args,)*| {
214                                async move { Self::#ident(&this, #(#args,)*).await }
215                            }
216                        });
217                        #(
218                            let middleware = app
219                                .get_component::<<#middleware as ::diode::Service>::Handle>()
220                                .ok_or_else(|| {
221                                    format!(
222                                        "Missing component: {}",
223                                        ::std::any::type_name::<<#middleware as ::diode::Service>::Handle>()
224                                    )
225                                })
226                                .unwrap();
227                            route = route.layer(::diode_http::MiddlewareLayerImpl(middleware));
228                        )*
229                        router = router.route(#path, route);
230                    });
231                }
232                Err(e) => errors.push(e),
233            }
234        }
235    }
236
237    if !errors.is_empty() {
238        let mut combined_error = Error::new(
239            Span::call_site(),
240            "Errors occurred while processing route attributes",
241        );
242        for error in errors {
243            combined_error.combine(error);
244        }
245        return combined_error.to_compile_error().into();
246    }
247
248    quote! {
249        #cleaned_input
250
251        impl ::diode_http::RouterBuilder for #self_ty {
252            fn build_router(self: ::std::sync::Arc<Self>, app: &::diode::App) -> ::diode_http::Router {
253                let mut router = ::diode_http::Router::new();
254                #(#routes)*
255                #(
256                    let middleware = app
257                        .get_component::<<#router_middleware as ::diode::Service>::Handle>()
258                        .ok_or_else(|| {
259                            format!(
260                                "Missing component: {}",
261                                ::std::any::type_name::<<#router_middleware as ::diode::Service>::Handle>()
262                            )
263                        })
264                        .unwrap();
265                    router = router.layer(::diode_http::MiddlewareLayerImpl(middleware));
266                )*
267                router
268            }
269        }
270    }
271    .into()
272}