Skip to main content

better_fetch_macros/
lib.rs

//! Proc-macro helpers for [`better-fetch`](https://docs.rs/better-fetch).
//!
//! Enable the `macros` feature on `better-fetch` to use these derives.
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, LitStr, Meta,
10};
11
12fn endpoint_path_attr(attrs: &[Attribute]) -> syn::Result<LitStr> {
13    for attr in attrs {
14        if !attr.path().is_ident("endpoint") {
15            continue;
16        }
17        let Meta::List(list) = &attr.meta else {
18            return Err(syn::Error::new(attr.span(), "`#[endpoint]` must be a list"));
19        };
20        let mut found = None;
21        list.parse_nested_meta(|meta| {
22            if meta.path.is_ident("path") {
23                let value = meta.value()?;
24                found = Some(value.parse::<LitStr>()?);
25            }
26            Ok(())
27        })?;
28        if let Some(path) = found {
29            return Ok(path);
30        }
31        return Err(syn::Error::new(
32            attr.span(),
33            "`#[endpoint]` requires `path = \"...\"`",
34        ));
35    }
36    Err(syn::Error::new(
37        Span::call_site(),
38        "`#[derive(EndpointParams)]` requires `#[endpoint(path = \"/route/:param\")]`",
39    ))
40}
41
42fn param_key(field: &syn::Field) -> syn::Result<String> {
43    for attr in &field.attrs {
44        if !attr.path().is_ident("param") {
45            continue;
46        }
47        let Meta::List(list) = &attr.meta else {
48            continue;
49        };
50        let mut rename = None;
51        list.parse_nested_meta(|meta| {
52            if meta.path.is_ident("rename") {
53                let value = meta.value()?;
54                rename = Some(value.parse::<LitStr>()?.value());
55            }
56            Ok(())
57        })?;
58        if let Some(name) = rename {
59            return Ok(name);
60        }
61    }
62    let ident = field
63        .ident
64        .as_ref()
65        .ok_or_else(|| syn::Error::new(field.span(), "tuple struct fields are not supported"))?;
66    Ok(ident.to_string())
67}
68
/// Lists the names of the `:name` segments in a `/`-separated route template, in order.
///
/// Only the leading `:` is stripped. A bare `:` segment therefore yields an empty name, and
/// duplicates are kept (callers check for them).
fn path_param_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    for segment in path.split('/') {
        if let Some(name) = segment.strip_prefix(':') {
            names.push(name.to_owned());
        }
    }
    names
}
74
75/// Derives [`EndpointParams`](https://docs.rs/better-fetch/latest/better_fetch/trait.EndpointParams.html)
76/// for a struct with one field per `:param` segment in `#[endpoint(path = "...")]`.
77///
78/// Optional `#[param(rename = "segmentName")]` overrides the path segment for a field.
79#[proc_macro_derive(EndpointParams, attributes(endpoint, param))]
80pub fn derive_endpoint_params(input: TokenStream) -> TokenStream {
81    let input = parse_macro_input!(input as DeriveInput);
82    match derive_endpoint_params_impl(input) {
83        Ok(tokens) => tokens.into(),
84        Err(err) => err.to_compile_error().into(),
85    }
86}
87
88fn derive_endpoint_params_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
89    let name = &input.ident;
90    let path = endpoint_path_attr(&input.attrs)?;
91    let path_value = path.value();
92
93    let Data::Struct(data) = &input.data else {
94        return Err(syn::Error::new(
95            input.span(),
96            "`EndpointParams` can only be derived for structs",
97        ));
98    };
99
100    let Fields::Named(fields) = &data.fields else {
101        return Err(syn::Error::new(
102            data.fields.span(),
103            "`EndpointParams` requires a struct with named fields",
104        ));
105    };
106
107    let mut field_keys = Vec::new();
108    let mut apply_pairs = Vec::new();
109
110    let mut seen_keys = std::collections::HashSet::new();
111    for field in &fields.named {
112        let ident = field.ident.as_ref().expect("named field");
113        let key = param_key(field)?;
114        if !seen_keys.insert(key.clone()) {
115            return Err(syn::Error::new(
116                field.span(),
117                format!("duplicate path parameter `{key}`"),
118            ));
119        }
120        field_keys.push(key.clone());
121        apply_pairs.push(quote! {
122            builder = builder.param(#key, self.#ident);
123        });
124    }
125
126    let expected = path_param_names(&path_value);
127    let mut seen_segments = std::collections::HashSet::new();
128    for segment in &expected {
129        if !seen_segments.insert(segment.clone()) {
130            return Err(syn::Error::new(
131                path.span(),
132                format!("duplicate `:param` segment `:{segment}` in path"),
133            ));
134        }
135    }
136    if expected.len() != field_keys.len() {
137        return Err(syn::Error::new(
138            path.span(),
139            format!(
140                "path `{path_value}` has {} `:param` segment(s) but the struct has {} field(s)",
141                expected.len(),
142                field_keys.len()
143            ),
144        ));
145    }
146
147    for segment in expected {
148        if !field_keys.iter().any(|key| key == &segment) {
149            return Err(syn::Error::new(
150                path.span(),
151                format!("missing struct field for path parameter `:{segment}`"),
152            ));
153        }
154    }
155
156    Ok(quote! {
157        impl ::better_fetch::EndpointParams for #name {
158            type BuilderState = ::better_fetch::NeedsParams;
159
160            fn apply_params(
161                self,
162                mut builder: ::better_fetch::RequestBuilder<'_>,
163            ) -> ::better_fetch::RequestBuilder<'_> {
164                #(#apply_pairs)*
165                builder
166            }
167        }
168    })
169}
170
171/// Derives [`EndpointQuery`](https://docs.rs/better-fetch/latest/better_fetch/trait.EndpointQuery.html)
172/// for a serde-serializable query struct.
173///
174/// Requires `Serialize` on the type (typically via `#[derive(Serialize)]`).
175#[proc_macro_derive(EndpointQuery, attributes(query))]
176pub fn derive_endpoint_query(input: TokenStream) -> TokenStream {
177    let input = parse_macro_input!(input as DeriveInput);
178    match derive_endpoint_query_impl(input) {
179        Ok(tokens) => tokens.into(),
180        Err(err) => err.to_compile_error().into(),
181    }
182}
183
184fn derive_endpoint_query_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
185    let name = &input.ident;
186
187    let Data::Struct(data) = &input.data else {
188        return Err(syn::Error::new(
189            input.span(),
190            "`EndpointQuery` can only be derived for structs",
191        ));
192    };
193
194    if !matches!(data.fields, Fields::Named(_)) {
195        return Err(syn::Error::new(
196            data.fields.span(),
197            "`EndpointQuery` requires a struct with named fields",
198        ));
199    }
200
201    Ok(quote! {
202        impl ::better_fetch::EndpointQuery for #name {
203            fn apply_query(
204                self,
205                builder: ::better_fetch::RequestBuilder<'_>,
206            ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
207                ::better_fetch::endpoint::apply_serialized_query(self, builder)
208            }
209        }
210    })
211}
212
213fn endpoint_meta(
214    attrs: &[Attribute],
215) -> syn::Result<(proc_macro2::TokenStream, LitStr, bool, bool)> {
216    for attr in attrs {
217        if !attr.path().is_ident("endpoint") {
218            continue;
219        }
220        let Meta::List(list) = &attr.meta else {
221            return Err(syn::Error::new(attr.span(), "`#[endpoint]` must be a list"));
222        };
223        let mut method = None;
224        let mut path = None;
225        let mut register = false;
226        list.parse_nested_meta(|meta| {
227            if meta.path.is_ident("method") {
228                let value = meta.value()?;
229                method = Some(value.parse::<syn::Path>()?);
230            } else if meta.path.is_ident("path") {
231                let value = meta.value()?;
232                path = Some(value.parse::<LitStr>()?);
233            } else if meta.path.is_ident("register") {
234                register = true;
235            }
236            Ok(())
237        })?;
238        let method_path = method.ok_or_else(|| {
239            syn::Error::new(attr.span(), "`#[endpoint]` requires `method = GET` (etc.)")
240        })?;
241        let path = path.ok_or_else(|| {
242            syn::Error::new(attr.span(), "`#[endpoint]` requires `path = \"...\"`")
243        })?;
244        let is_post = method_path.get_ident().is_some_and(|id| id == "POST")
245            || method_path
246                .segments
247                .last()
248                .is_some_and(|seg| seg.ident == "POST");
249        let method = if let Some(ident) = method_path.get_ident() {
250            quote!(::http::Method::#ident)
251        } else {
252            quote!(#method_path)
253        };
254        return Ok((method, path, is_post, register));
255    }
256    Err(syn::Error::new(
257        Span::call_site(),
258        "`#[derive(Endpoint)]` requires `#[endpoint(method = GET, path = \"...\")]`",
259    ))
260}
261
262fn is_unit_type(ty: &syn::Type) -> bool {
263    matches!(ty, syn::Type::Tuple(t) if t.elems.is_empty())
264}
265
266fn endpoint_field_type(field: &syn::Field, attr: &str) -> Option<syn::Type> {
267    field
268        .attrs
269        .iter()
270        .any(|a| a.path().is_ident(attr))
271        .then(|| field.ty.clone())
272}
273
274/// Derives [`Endpoint`](https://docs.rs/better-fetch/latest/better_fetch/trait.Endpoint.html).
275///
276/// ```ignore
277/// #[derive(Endpoint)]
278/// #[endpoint(method = GET, path = "/items/:id")]
279/// struct GetItem {
280///     #[response]
281///     Item,
282///     #[params]
283///     ItemParams,
284/// }
285/// ```
286#[proc_macro_derive(
287    Endpoint,
288    attributes(endpoint, response, params, query, body, headers, param, query_field)
289)]
290pub fn derive_endpoint(input: TokenStream) -> TokenStream {
291    let input = parse_macro_input!(input as DeriveInput);
292    match derive_endpoint_impl(input) {
293        Ok(tokens) => tokens.into(),
294        Err(err) => err.to_compile_error().into(),
295    }
296}
297
298fn derive_endpoint_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
299    let name = &input.ident;
300    let (method, path, is_post, register) = endpoint_meta(&input.attrs)?;
301    let path_value = path.value();
302
303    let Data::Struct(data) = &input.data else {
304        return Err(syn::Error::new(
305            input.span(),
306            "`Endpoint` can only be derived for structs",
307        ));
308    };
309
310    let Fields::Named(fields) = &data.fields else {
311        return Err(syn::Error::new(
312            data.fields.span(),
313            "`Endpoint` requires a struct with named fields for `#[response]` etc.",
314        ));
315    };
316
317    let mut response = quote!(());
318    let mut params = quote!(());
319    let mut query = quote!(());
320    let mut body = quote!(());
321    let mut headers = quote!(());
322    let mut body_ty: Option<syn::Type> = None;
323    let mut inline_param_fields: Vec<&syn::Field> = Vec::new();
324    let mut inline_query_fields: Vec<&syn::Field> = Vec::new();
325    let mut explicit_params = false;
326    let mut explicit_query = false;
327
328    for field in &fields.named {
329        if field.attrs.iter().any(|a| a.path().is_ident("param")) {
330            inline_param_fields.push(field);
331            continue;
332        }
333        if field.attrs.iter().any(|a| a.path().is_ident("query_field")) {
334            inline_query_fields.push(field);
335            continue;
336        }
337        if let Some(ty) = endpoint_field_type(field, "response") {
338            response = quote!(#ty);
339        } else if let Some(ty) = endpoint_field_type(field, "params") {
340            explicit_params = true;
341            params = quote!(#ty);
342        } else if let Some(ty) = endpoint_field_type(field, "query") {
343            explicit_query = true;
344            query = quote!(#ty);
345        } else if let Some(ty) = endpoint_field_type(field, "body") {
346            body_ty = Some(ty.clone());
347            body = quote!(#ty);
348        } else if let Some(ty) = endpoint_field_type(field, "headers") {
349            headers = quote!(#ty);
350        }
351    }
352
353    if explicit_params && !inline_param_fields.is_empty() {
354        return Err(syn::Error::new(
355            input.span(),
356            "use either `#[params] Type` or `#[param]` fields on the endpoint struct, not both",
357        ));
358    }
359
360    if explicit_query && !inline_query_fields.is_empty() {
361        return Err(syn::Error::new(
362            input.span(),
363            "use either `#[query] Type` or `#[query_field]` fields on the endpoint struct, not both",
364        ));
365    }
366
367    let params_ty_ident = syn::Ident::new(&format!("{name}Params"), name.span());
368    let query_ty_ident = syn::Ident::new(&format!("{name}Query"), name.span());
369    let inline_params_impl = if !inline_param_fields.is_empty() {
370        let mut field_defs = Vec::new();
371        let mut apply_pairs = Vec::new();
372        let mut field_keys = Vec::new();
373        let mut seen_keys = std::collections::HashSet::new();
374
375        for field in &inline_param_fields {
376            let ident = field.ident.as_ref().expect("named field");
377            let key = param_key(field)?;
378            if !seen_keys.insert(key.clone()) {
379                return Err(syn::Error::new(
380                    field.span(),
381                    format!("duplicate path parameter `{key}`"),
382                ));
383            }
384            field_keys.push(key.clone());
385            let ty = &field.ty;
386            field_defs.push(quote! { pub #ident: #ty });
387            apply_pairs.push(quote! {
388                builder = builder.param(#key, self.#ident);
389            });
390        }
391
392        let expected = path_param_names(&path_value);
393        if expected.len() != field_keys.len() {
394            return Err(syn::Error::new(
395                path.span(),
396                format!(
397                    "path `{path_value}` has {} `:param` segment(s) but the endpoint has {} `#[param]` field(s)",
398                    expected.len(),
399                    field_keys.len()
400                ),
401            ));
402        }
403        for segment in expected {
404            if !field_keys.iter().any(|key| key == &segment) {
405                return Err(syn::Error::new(
406                    path.span(),
407                    format!("missing `#[param]` field for path parameter `:{segment}`"),
408                ));
409            }
410        }
411
412        params = quote!(#params_ty_ident);
413        quote! {
414            #[derive(Debug, Clone, Default)]
415            pub struct #params_ty_ident {
416                #(#field_defs),*
417            }
418
419            impl ::better_fetch::EndpointParams for #params_ty_ident {
420                type BuilderState = ::better_fetch::NeedsParams;
421
422                fn apply_params(
423                    self,
424                    mut builder: ::better_fetch::RequestBuilder<'_>,
425                ) -> ::better_fetch::RequestBuilder<'_> {
426                    #(#apply_pairs)*
427                    builder
428                }
429            }
430        }
431    } else {
432        quote! {}
433    };
434
435    let inline_query_impl = if !inline_query_fields.is_empty() {
436        let mut field_defs = Vec::new();
437        for field in &inline_query_fields {
438            let ident = field.ident.as_ref().expect("named field");
439            let ty = &field.ty;
440            field_defs.push(quote! { pub #ident: #ty });
441        }
442        query = quote!(#query_ty_ident);
443        quote! {
444            #[derive(Debug, Clone, Default, ::serde::Serialize)]
445            pub struct #query_ty_ident {
446                #(#field_defs),*
447            }
448
449            impl ::better_fetch::EndpointQuery for #query_ty_ident {
450                fn apply_query(
451                    self,
452                    builder: ::better_fetch::RequestBuilder<'_>,
453                ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
454                    ::better_fetch::endpoint::apply_serialized_query(self, builder)
455                }
456            }
457        }
458    } else {
459        quote! {}
460    };
461
462    let explicit_query_impl = if explicit_query && inline_query_fields.is_empty() {
463        quote! {
464            impl ::better_fetch::EndpointQuery for #query {
465                fn apply_query(
466                    self,
467                    builder: ::better_fetch::RequestBuilder<'_>,
468                ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
469                    ::better_fetch::endpoint::apply_serialized_query(self, builder)
470                }
471            }
472        }
473    } else {
474        quote! {}
475    };
476
477    let body_required = is_post && body_ty.as_ref().is_some_and(|ty| !is_unit_type(ty));
478
479    let body_required_impl = if let Some(body_type) = body_ty.filter(|_| body_required) {
480        quote! {
481            impl ::better_fetch::EndpointBody for #body_type {
482                type ParamsNext = ::better_fetch::NeedsBody;
483                type CallInitial = ::better_fetch::NeedsBody;
484
485                fn apply_body(
486                    self,
487                    builder: ::better_fetch::RequestBuilder<'_>,
488                ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
489                    builder.json(&self)
490                }
491            }
492
493            impl ::better_fetch::DefaultParamsInitial<#name> for () {
494                fn initial(
495                    client: &::better_fetch::Client,
496                ) -> ::better_fetch::EndpointRequestBuilder<'_, #name, ::better_fetch::NeedsBody> {
497                    ::better_fetch::EndpointRequestBuilder::new_needs_body(
498                        client.request(#method, #path_value),
499                    )
500                }
501            }
502        }
503    } else {
504        quote! {}
505    };
506
507    let register_impl = if register {
508        quote! {
509            impl #name {
510                /// Registers this route in a [`SchemaRegistry`](::better_fetch::SchemaRegistry).
511                #[cfg(feature = "schema")]
512                pub fn register(registry: &mut ::better_fetch::SchemaRegistry) {
513                    registry.register_typed::<#name, #body, #response>();
514                }
515            }
516        }
517    } else {
518        quote! {}
519    };
520
521    Ok(quote! {
522        #inline_params_impl
523        #inline_query_impl
524        #explicit_query_impl
525        impl ::better_fetch::Endpoint for #name {
526            const METHOD: ::http::Method = #method;
527            const PATH: &'static str = #path_value;
528            type Response = #response;
529            type Params = #params;
530            type Query = #query;
531            type Body = #body;
532            type Headers = #headers;
533        }
534        #body_required_impl
535        #register_impl
536    })
537}