1#![warn(clippy::doc_markdown)]
2
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{parse_macro_input, spanned::Spanned, DeriveInput};
6
7#[doc(hidden)]
9mod internal;
10
11trait HasAttributes: Spanned {
12 fn attributes(&self) -> &[syn::Attribute];
13}
14
15impl HasAttributes for syn::DeriveInput {
16 fn attributes(&self) -> &[syn::Attribute] {
17 &self.attrs
18 }
19}
20
21impl HasAttributes for syn::Variant {
22 fn attributes(&self) -> &[syn::Attribute] {
23 &self.attrs
24 }
25}
26
27fn single_field(fields: &syn::Fields) -> Option<TokenStream> {
28 match fields {
29 syn::Fields::Named(fields) => {
30 let mut fields = fields.named.iter();
31 let field = fields.next()?;
32 fields
33 .next()
34 .is_none()
35 .then(|| quote! { { #field: ref field } })
36 }
37 syn::Fields::Unnamed(fields) => {
38 let mut fields = fields.unnamed.iter();
39 let _field = fields.next()?;
40 fields.next().is_none().then(|| quote! { (ref field) })
41 }
42 syn::Fields::Unit => None,
43 }
44}
45
46enum StatusCodeAttribute {
47 StatusCode(syn::Path),
48 Transparent,
49}
50
51impl StatusCodeAttribute {
52 fn parse<T: HasAttributes>(obj: &T) -> syn::Result<Option<Self>> {
53 obj.attributes()
54 .iter()
55 .find(|attribute| attribute.path().is_ident("status_code"))
56 .map(|status_code| {
57 let syn::Meta::List(syn::MetaList { tokens, .. }) = &status_code.meta else {
58 return Err(syn::Error::new(
59 obj.span(),
60 "status_code attr must be in the form #[status_code(...)]",
61 ));
62 };
63
64 let path = syn::parse2::<syn::Path>(tokens.clone())?;
65
66 Ok(if path.is_ident("transparent") {
67 StatusCodeAttribute::Transparent
68 } else {
69 StatusCodeAttribute::StatusCode(
70 syn::parse_quote! { picoserve::response::StatusCode::#path },
71 )
72 })
73 })
74 .transpose()
75 }
76}
77
78fn try_derive_error_with_status_code(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
79 let ident = &input.ident;
80
81 let default_status_code = StatusCodeAttribute::parse(input)?;
82
83 let status_code: syn::Expr = match &input.data {
84 syn::Data::Struct(data_struct) => match default_status_code
85 .ok_or_else(|| syn::Error::new(input.span(), "Missing #[status_code(..)]"))?
86 {
87 StatusCodeAttribute::StatusCode(path) => syn::Expr::Path(syn::ExprPath {
88 attrs: Vec::new(),
89 qself: None,
90 path,
91 }),
92 StatusCodeAttribute::Transparent => {
93 let fields = single_field(&data_struct.fields).ok_or_else(|| {
94 syn::Error::new(input.span(), "Transparent errors must have a single field")
95 })?;
96
97 syn::parse_quote! {
98 let Self #fields = self;
99 picoserve::response::ErrorWithStatusCode::status_code(field)
100 }
101 }
102 },
103 syn::Data::Enum(data_enum) => {
104 let cases = data_enum
105 .variants
106 .iter()
107 .map(|variant| {
108 let variant_status_code = StatusCodeAttribute::parse(variant)?;
109
110 let selected_status_code = variant_status_code
111 .as_ref()
112 .or(default_status_code.as_ref());
113
114 let selected_status_code = selected_status_code.ok_or_else(|| {
115 syn::Error::new(
116 variant.span(),
117 "Either the enum or this variant must have an attribute of status_code",
118 )
119 })?;
120
121 let ident = &variant.ident;
122 let fields;
123 let status_code: syn::Expr;
124
125 match selected_status_code {
126 StatusCodeAttribute::StatusCode(selected_status_code) => {
127 fields = match variant.fields {
128 syn::Fields::Named(..) => quote! { {..} },
129 syn::Fields::Unnamed(..) => quote! { (..) },
130 syn::Fields::Unit => quote! {},
131 };
132
133 status_code = syn::parse_quote! { #selected_status_code };
134 }
135 StatusCodeAttribute::Transparent => {
136 fields = single_field(&variant.fields).ok_or_else(|| {
137 syn::Error::new(
138 variant.span(),
139 "Transparent errors must have a single field",
140 )
141 })?;
142
143 status_code = syn::parse_quote! {
144 picoserve::response::ErrorWithStatusCode::status_code(field)
145 };
146 }
147 }
148
149 Ok(quote! { Self::#ident #fields => #status_code })
150 })
151 .collect::<Result<Vec<_>, syn::Error>>()?;
152
153 syn::parse_quote! {
154 match *self {
155 #(#cases,)*
156 }
157 }
158 }
159 syn::Data::Union(..) => {
160 return Err(syn::Error::new(input.span(), "Must be a struct or an enum"))
161 }
162 };
163
164 let syn::Generics {
165 lt_token,
166 params: generics_params,
167 gt_token,
168 where_clause,
169 } = &input.generics;
170
171 let self_is_display = syn::parse_quote!(Self: core::fmt::Display);
172
173 let where_clause_predicates = where_clause
174 .as_ref()
175 .map(|where_clause| where_clause.predicates.iter())
176 .into_iter()
177 .flatten()
178 .chain(std::iter::once(&self_is_display))
179 .collect::<syn::punctuated::Punctuated<_, syn::token::Comma>>();
180
181 let param_names = generics_params
182 .iter()
183 .map(|param| match param {
184 syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => {
185 lifetime.to_token_stream()
186 }
187 syn::GenericParam::Type(type_param) => type_param.ident.to_token_stream(),
188 syn::GenericParam::Const(const_param) => const_param.ident.to_token_stream(),
189 })
190 .collect::<syn::punctuated::Punctuated<TokenStream, syn::token::Comma>>();
191
192 Ok(quote! {
193 #[allow(unused_qualifications)]
194 #[automatically_derived]
195 impl #lt_token #generics_params #gt_token picoserve::response::ErrorWithStatusCode for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
196 fn status_code(&self) -> picoserve::response::StatusCode {
197 #status_code
198 }
199 }
200
201 #[allow(unused_qualifications)]
202 #[automatically_derived]
203 impl #lt_token #generics_params #gt_token picoserve::response::IntoResponse for #ident #lt_token #param_names #gt_token where #where_clause_predicates {
204 async fn write_to<R: picoserve::io::Read, W: picoserve::response::ResponseWriter<Error = R::Error>>(
205 self,
206 connection: picoserve::response::Connection<'_, R>,
207 response_writer: W,
208 ) -> Result<picoserve::ResponseSent, W::Error> {
209 (picoserve::response::ErrorWithStatusCode::status_code(&self), format_args!("{self}\n"))
210 .write_to(connection, response_writer)
211 .await
212 }
213 }
214 })
215}
216
217#[proc_macro_derive(ErrorWithStatusCode, attributes(status_code))]
236pub fn derive_error_with_status_code(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
237 let input = parse_macro_input!(input as DeriveInput);
238
239 match try_derive_error_with_status_code(&input) {
240 Ok(tokens) => tokens.into(),
241 Err(error) => error.into_compile_error().into(),
242 }
243}
244
245#[doc(hidden)]
247#[proc_macro]
248pub fn generate_method_router(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
249 internal::router::generate_method_router(input)
250}