Skip to main content

better_fetch_macros/
lib.rs

//! Procedural derive macros for [`better-fetch`](https://docs.rs/better-fetch).
//!
//! Enable the `macros` feature on `better-fetch` to use these derives.
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, LitStr, Meta,
10};
11
12fn endpoint_path_attr(attrs: &[Attribute]) -> syn::Result<LitStr> {
13    for attr in attrs {
14        if !attr.path().is_ident("endpoint") {
15            continue;
16        }
17        let Meta::List(list) = &attr.meta else {
18            return Err(syn::Error::new(attr.span(), "`#[endpoint]` must be a list"));
19        };
20        let mut found = None;
21        list.parse_nested_meta(|meta| {
22            if meta.path.is_ident("path") {
23                let value = meta.value()?;
24                found = Some(value.parse::<LitStr>()?);
25            }
26            Ok(())
27        })?;
28        if let Some(path) = found {
29            return Ok(path);
30        }
31        return Err(syn::Error::new(
32            attr.span(),
33            "`#[endpoint]` requires `path = \"...\"`",
34        ));
35    }
36    Err(syn::Error::new(
37        Span::call_site(),
38        "`#[derive(EndpointParams)]` requires `#[endpoint(path = \"/route/:param\")]`",
39    ))
40}
41
42fn param_key(field: &syn::Field) -> syn::Result<String> {
43    for attr in &field.attrs {
44        if !attr.path().is_ident("param") {
45            continue;
46        }
47        let Meta::List(list) = &attr.meta else {
48            continue;
49        };
50        let mut rename = None;
51        list.parse_nested_meta(|meta| {
52            if meta.path.is_ident("rename") {
53                let value = meta.value()?;
54                rename = Some(value.parse::<LitStr>()?.value());
55            }
56            Ok(())
57        })?;
58        if let Some(name) = rename {
59            return Ok(name);
60        }
61    }
62    let ident = field
63        .ident
64        .as_ref()
65        .ok_or_else(|| syn::Error::new(field.span(), "tuple struct fields are not supported"))?;
66    Ok(ident.to_string())
67}
68
/// Collects the names of the `:param` segments in `path`, in order of appearance.
///
/// Only whole `/`-separated segments that start with `:` count, and the leading
/// colon is stripped. For example, `"/users/:id/posts/:post_id"` yields
/// `["id", "post_id"]`.
fn path_param_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    for segment in path.split('/') {
        if let Some(name) = segment.strip_prefix(':') {
            names.push(name.to_owned());
        }
    }
    names
}
74
75/// Derives [`EndpointParams`](https://docs.rs/better-fetch/latest/better_fetch/trait.EndpointParams.html)
76/// for a struct with one field per `:param` segment in `#[endpoint(path = "...")]`.
77///
78/// Optional `#[param(rename = "segmentName")]` overrides the path segment for a field.
79#[proc_macro_derive(EndpointParams, attributes(endpoint, param))]
80pub fn derive_endpoint_params(input: TokenStream) -> TokenStream {
81    let input = parse_macro_input!(input as DeriveInput);
82    match derive_endpoint_params_impl(input) {
83        Ok(tokens) => tokens.into(),
84        Err(err) => err.to_compile_error().into(),
85    }
86}
87
88fn derive_endpoint_params_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
89    let name = &input.ident;
90    let path = endpoint_path_attr(&input.attrs)?;
91    let path_value = path.value();
92
93    let Data::Struct(data) = &input.data else {
94        return Err(syn::Error::new(
95            input.span(),
96            "`EndpointParams` can only be derived for structs",
97        ));
98    };
99
100    let Fields::Named(fields) = &data.fields else {
101        return Err(syn::Error::new(
102            data.fields.span(),
103            "`EndpointParams` requires a struct with named fields",
104        ));
105    };
106
107    let mut field_keys = Vec::new();
108    let mut apply_pairs = Vec::new();
109
110    for field in &fields.named {
111        let ident = field.ident.as_ref().expect("named field");
112        let key = param_key(field)?;
113        field_keys.push(key.clone());
114        apply_pairs.push(quote! {
115            builder = builder.param(#key, self.#ident);
116        });
117    }
118
119    let expected = path_param_names(&path_value);
120    if expected.len() != field_keys.len() {
121        return Err(syn::Error::new(
122            path.span(),
123            format!(
124                "path `{path_value}` has {} `:param` segment(s) but the struct has {} field(s)",
125                expected.len(),
126                field_keys.len()
127            ),
128        ));
129    }
130
131    for segment in expected {
132        if !field_keys.iter().any(|key| key == &segment) {
133            return Err(syn::Error::new(
134                path.span(),
135                format!("missing struct field for path parameter `:{segment}`"),
136            ));
137        }
138    }
139
140    Ok(quote! {
141        impl ::better_fetch::EndpointParams for #name {
142            type BuilderState = ::better_fetch::NeedsParams;
143
144            fn apply_params(
145                self,
146                mut builder: ::better_fetch::RequestBuilder<'_>,
147            ) -> ::better_fetch::RequestBuilder<'_> {
148                #(#apply_pairs)*
149                builder
150            }
151        }
152    })
153}
154
155/// Derives [`EndpointQuery`](https://docs.rs/better-fetch/latest/better_fetch/trait.EndpointQuery.html)
156/// for a serde-serializable query struct.
157///
158/// Requires `Serialize` on the type (typically via `#[derive(Serialize)]`).
159#[proc_macro_derive(EndpointQuery, attributes(query))]
160pub fn derive_endpoint_query(input: TokenStream) -> TokenStream {
161    let input = parse_macro_input!(input as DeriveInput);
162    match derive_endpoint_query_impl(input) {
163        Ok(tokens) => tokens.into(),
164        Err(err) => err.to_compile_error().into(),
165    }
166}
167
168fn derive_endpoint_query_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
169    let name = &input.ident;
170
171    let Data::Struct(data) = &input.data else {
172        return Err(syn::Error::new(
173            input.span(),
174            "`EndpointQuery` can only be derived for structs",
175        ));
176    };
177
178    if !matches!(data.fields, Fields::Named(_)) {
179        return Err(syn::Error::new(
180            data.fields.span(),
181            "`EndpointQuery` requires a struct with named fields",
182        ));
183    }
184
185    Ok(quote! {
186        impl ::better_fetch::EndpointQuery for #name {
187            fn apply_query(
188                self,
189                builder: ::better_fetch::RequestBuilder<'_>,
190            ) -> ::better_fetch::RequestBuilder<'_> {
191                ::better_fetch::endpoint::apply_serialized_query(self, builder)
192            }
193        }
194    })
195}