Skip to main content

picoserve_derive/
lib.rs

1#![warn(clippy::doc_markdown)]
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{parse_macro_input, spanned::Spanned, DeriveInput};
6
7/// Helpers used interally by picoserve
8#[doc(hidden)]
9mod internal;
10
11trait HasAttributes: Spanned {
12    fn attributes(&self) -> &[syn::Attribute];
13}
14
15impl HasAttributes for syn::DeriveInput {
16    fn attributes(&self) -> &[syn::Attribute] {
17        &self.attrs
18    }
19}
20
21impl HasAttributes for syn::Variant {
22    fn attributes(&self) -> &[syn::Attribute] {
23        &self.attrs
24    }
25}
26
27fn single_field(fields: &syn::Fields) -> Option<TokenStream> {
28    match fields {
29        syn::Fields::Named(fields) => {
30            let mut fields = fields.named.iter();
31            let field = fields.next()?;
32            fields
33                .next()
34                .is_none()
35                .then(|| quote! { { #field: ref field } })
36        }
37        syn::Fields::Unnamed(fields) => {
38            let mut fields = fields.unnamed.iter();
39            let _field = fields.next()?;
40            fields.next().is_none().then(|| quote! { (ref field) })
41        }
42        syn::Fields::Unit => None,
43    }
44}
45
46enum StatusCodeAttribute {
47    StatusCode(syn::Path),
48    Transparent,
49}
50
51impl StatusCodeAttribute {
52    fn parse<T: HasAttributes>(obj: &T) -> syn::Result<Option<Self>> {
53        obj.attributes()
54            .iter()
55            .find(|attribute| attribute.path().is_ident("status_code"))
56            .map(|status_code| {
57                let syn::Meta::List(syn::MetaList { tokens, .. }) = &status_code.meta else {
58                    return Err(syn::Error::new(
59                        obj.span(),
60                        "status_code attr must be in the form #[status_code(...)]",
61                    ));
62                };
63
64                let path = syn::parse2::<syn::Path>(tokens.clone())?;
65
66                Ok(if path.is_ident("transparent") {
67                    StatusCodeAttribute::Transparent
68                } else {
69                    StatusCodeAttribute::StatusCode(
70                        syn::parse_quote! { picoserve::response::StatusCode::#path },
71                    )
72                })
73            })
74            .transpose()
75    }
76}
77
78fn try_derive_error_with_status_code(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
79    let ident = &input.ident;
80
81    let default_status_code = StatusCodeAttribute::parse(input)?;
82
83    let status_code: syn::Expr = match &input.data {
84        syn::Data::Struct(data_struct) => match default_status_code
85            .ok_or_else(|| syn::Error::new(input.span(), "Missing #[status_code(..)]"))?
86        {
87            StatusCodeAttribute::StatusCode(path) => syn::Expr::Path(syn::ExprPath {
88                attrs: Vec::new(),
89                qself: None,
90                path,
91            }),
92            StatusCodeAttribute::Transparent => {
93                let fields = single_field(&data_struct.fields).ok_or_else(|| {
94                    syn::Error::new(input.span(), "Transparent errors must have a single field")
95                })?;
96
97                syn::parse_quote! {
98                    let Self #fields = self;
99                    picoserve::response::ErrorWithStatusCode::status_code(field)
100                }
101            }
102        },
103        syn::Data::Enum(data_enum) => {
104            let cases = data_enum
105                .variants
106                .iter()
107                .map(|variant| {
108                    let variant_status_code = StatusCodeAttribute::parse(variant)?;
109
110                    let selected_status_code = variant_status_code
111                        .as_ref()
112                        .or(default_status_code.as_ref());
113
114                    let selected_status_code = selected_status_code.ok_or_else(|| {
115                        syn::Error::new(
116                            variant.span(),
117                            "Either the enum or this variant must have an attribute of status_code",
118                        )
119                    })?;
120
121                    let ident = &variant.ident;
122                    let fields;
123                    let status_code: syn::Expr;
124
125                    match selected_status_code {
126                        StatusCodeAttribute::StatusCode(selected_status_code) => {
127                            fields = match variant.fields {
128                                syn::Fields::Named(..) => quote! { {..} },
129                                syn::Fields::Unnamed(..) => quote! { (..) },
130                                syn::Fields::Unit => quote! {},
131                            };
132
133                            status_code = syn::parse_quote! { #selected_status_code };
134                        }
135                        StatusCodeAttribute::Transparent => {
136                            fields = single_field(&variant.fields).ok_or_else(|| {
137                                syn::Error::new(
138                                    variant.span(),
139                                    "Transparent errors must have a single field",
140                                )
141                            })?;
142
143                            status_code = syn::parse_quote! {
144                                picoserve::response::ErrorWithStatusCode::status_code(field)
145                            };
146                        }
147                    }
148
149                    Ok(quote! { Self::#ident #fields => #status_code })
150                })
151                .collect::<Result<Vec<_>, syn::Error>>()?;
152
153            syn::parse_quote! {
154                match *self {
155                    #(#cases,)*
156                }
157            }
158        }
159        syn::Data::Union(..) => {
160            return Err(syn::Error::new(input.span(), "Must be a struct or an enum"))
161        }
162    };
163
164    let syn::Generics {
165        lt_token,
166        params: generics_params,
167        gt_token,
168        where_clause,
169    } = &input.generics;
170
171    let self_is_display = syn::parse_quote!(Self: core::fmt::Display);
172
173    let where_clause_predicates = where_clause
174        .as_ref()
175        .map(|where_clause| where_clause.predicates.iter())
176        .into_iter()
177        .flatten()
178        .chain(std::iter::once(&self_is_display))
179        .collect::<syn::punctuated::Punctuated<_, syn::token::Comma>>();
180
181    let param_names = generics_params
182        .iter()
183        .map(|param| match param {
184            syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => {
185                lifetime.to_token_stream()
186            }
187            syn::GenericParam::Type(type_param) => type_param.ident.to_token_stream(),
188            syn::GenericParam::Const(const_param) => const_param.ident.to_token_stream(),
189        })
190        .collect::<syn::punctuated::Punctuated<TokenStream, syn::token::Comma>>();
191
192    Ok(quote! {
193        #[allow(unused_qualifications)]
194        #[automatically_derived]
195        impl #lt_token #generics_params #gt_token picoserve::response::ErrorWithStatusCode for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
196            fn status_code(&self) -> picoserve::response::StatusCode {
197                #status_code
198            }
199        }
200
201        #[allow(unused_qualifications)]
202        #[automatically_derived]
203        impl #lt_token #generics_params #gt_token picoserve::response::IntoResponse for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
204            async fn write_to<R: picoserve::io::Read, W: picoserve::response::ResponseWriter<Error = R::Error>>(
205                self,
206                connection: picoserve::response::Connection<'_, R>,
207                response_writer: W,
208            ) -> Result<picoserve::ResponseSent, W::Error> {
209                (picoserve::response::ErrorWithStatusCode::status_code(&self), format_args!("{self}\n"))
210                    .write_to(connection, response_writer)
211                    .await
212            }
213        }
214    })
215}
216
217/// Derive `ErrorWithStatusCode` for a struct or an enum.
218///
219/// This will also derive `IntoResponse`, returning a `Response` with the given status code and a `text/plain` body of the `Display` implementation.
220///
221/// # Structs
222///
223/// There must be an attribute `status_code` containing the [`StatusCode`](https://docs.rs/picoserve/latest/picoserve/response/status/struct.StatusCode.html) of the error, e.g. `#[status_code(INTERNAL_SERVER_ERROR)]`.
224///
225/// If the `status_code` is `transparent`, the struct must contain a single field which implements `ErrorWithStatusCode`.
226///
227/// # Enums
228///
229/// There may be an attribute `status_code` on the enum itself containing the default [`StatusCode`](https://docs.rs/picoserve/latest/picoserve/response/status/struct.StatusCode.html) of the error.
230///
231/// There may also be an attribute `status_code` on a variant, which overrides the default [`StatusCode`](https://docs.rs/picoserve/latest/picoserve/response/status/struct.StatusCode.html).
232/// If all variants have their own attribute `status_code`, the default may be omitted.
233///
234/// Variants with a `status_code` of `transparent` must contain a single field which implements `ErrorWithStatusCode`.
235#[proc_macro_derive(ErrorWithStatusCode, attributes(status_code))]
236pub fn derive_error_with_status_code(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
237    let input = parse_macro_input!(input as DeriveInput);
238
239    match try_derive_error_with_status_code(&input) {
240        Ok(tokens) => tokens.into(),
241        Err(error) => error.into_compile_error().into(),
242    }
243}
244
245/// Used internally by `picoserve`.
246#[doc(hidden)]
247#[proc_macro]
248pub fn generate_method_router(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
249    internal::router::generate_method_router(input)
250}