1use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{
9 parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, LitStr, Meta,
10};
11
12fn endpoint_path_attr(attrs: &[Attribute]) -> syn::Result<LitStr> {
13 for attr in attrs {
14 if !attr.path().is_ident("endpoint") {
15 continue;
16 }
17 let Meta::List(list) = &attr.meta else {
18 return Err(syn::Error::new(attr.span(), "`#[endpoint]` must be a list"));
19 };
20 let mut found = None;
21 list.parse_nested_meta(|meta| {
22 if meta.path.is_ident("path") {
23 let value = meta.value()?;
24 found = Some(value.parse::<LitStr>()?);
25 }
26 Ok(())
27 })?;
28 if let Some(path) = found {
29 return Ok(path);
30 }
31 return Err(syn::Error::new(
32 attr.span(),
33 "`#[endpoint]` requires `path = \"...\"`",
34 ));
35 }
36 Err(syn::Error::new(
37 Span::call_site(),
38 "`#[derive(EndpointParams)]` requires `#[endpoint(path = \"/route/:param\")]`",
39 ))
40}
41
42fn param_key(field: &syn::Field) -> syn::Result<String> {
43 for attr in &field.attrs {
44 if !attr.path().is_ident("param") {
45 continue;
46 }
47 let Meta::List(list) = &attr.meta else {
48 continue;
49 };
50 let mut rename = None;
51 list.parse_nested_meta(|meta| {
52 if meta.path.is_ident("rename") {
53 let value = meta.value()?;
54 rename = Some(value.parse::<LitStr>()?.value());
55 }
56 Ok(())
57 })?;
58 if let Some(name) = rename {
59 return Ok(name);
60 }
61 }
62 let ident = field
63 .ident
64 .as_ref()
65 .ok_or_else(|| syn::Error::new(field.span(), "tuple struct fields are not supported"))?;
66 Ok(ident.to_string())
67}
68
/// Collects the names of all `:param` segments in a route path, in order of appearance.
///
/// The path is split on `/`; every segment that begins with `:` contributes the text after
/// that leading colon. Segments without a leading colon are skipped.
fn path_param_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    for segment in path.split('/') {
        if let Some(name) = segment.strip_prefix(':') {
            names.push(name.to_string());
        }
    }
    names
}
74
75#[proc_macro_derive(EndpointParams, attributes(endpoint, param))]
80pub fn derive_endpoint_params(input: TokenStream) -> TokenStream {
81 let input = parse_macro_input!(input as DeriveInput);
82 match derive_endpoint_params_impl(input) {
83 Ok(tokens) => tokens.into(),
84 Err(err) => err.to_compile_error().into(),
85 }
86}
87
88fn derive_endpoint_params_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
89 let name = &input.ident;
90 let path = endpoint_path_attr(&input.attrs)?;
91 let path_value = path.value();
92
93 let Data::Struct(data) = &input.data else {
94 return Err(syn::Error::new(
95 input.span(),
96 "`EndpointParams` can only be derived for structs",
97 ));
98 };
99
100 let Fields::Named(fields) = &data.fields else {
101 return Err(syn::Error::new(
102 data.fields.span(),
103 "`EndpointParams` requires a struct with named fields",
104 ));
105 };
106
107 let mut field_keys = Vec::new();
108 let mut apply_pairs = Vec::new();
109
110 let mut seen_keys = std::collections::HashSet::new();
111 for field in &fields.named {
112 let ident = field.ident.as_ref().expect("named field");
113 let key = param_key(field)?;
114 if !seen_keys.insert(key.clone()) {
115 return Err(syn::Error::new(
116 field.span(),
117 format!("duplicate path parameter `{key}`"),
118 ));
119 }
120 field_keys.push(key.clone());
121 apply_pairs.push(quote! {
122 builder = builder.param(#key, self.#ident);
123 });
124 }
125
126 let expected = path_param_names(&path_value);
127 let mut seen_segments = std::collections::HashSet::new();
128 for segment in &expected {
129 if !seen_segments.insert(segment.clone()) {
130 return Err(syn::Error::new(
131 path.span(),
132 format!("duplicate `:param` segment `:{segment}` in path"),
133 ));
134 }
135 }
136 if expected.len() != field_keys.len() {
137 return Err(syn::Error::new(
138 path.span(),
139 format!(
140 "path `{path_value}` has {} `:param` segment(s) but the struct has {} field(s)",
141 expected.len(),
142 field_keys.len()
143 ),
144 ));
145 }
146
147 for segment in expected {
148 if !field_keys.iter().any(|key| key == &segment) {
149 return Err(syn::Error::new(
150 path.span(),
151 format!("missing struct field for path parameter `:{segment}`"),
152 ));
153 }
154 }
155
156 Ok(quote! {
157 impl ::better_fetch::EndpointParams for #name {
158 type BuilderState = ::better_fetch::NeedsParams;
159
160 fn apply_params(
161 self,
162 mut builder: ::better_fetch::RequestBuilder<'_>,
163 ) -> ::better_fetch::RequestBuilder<'_> {
164 #(#apply_pairs)*
165 builder
166 }
167 }
168 })
169}
170
171#[proc_macro_derive(EndpointQuery, attributes(query))]
176pub fn derive_endpoint_query(input: TokenStream) -> TokenStream {
177 let input = parse_macro_input!(input as DeriveInput);
178 match derive_endpoint_query_impl(input) {
179 Ok(tokens) => tokens.into(),
180 Err(err) => err.to_compile_error().into(),
181 }
182}
183
184fn derive_endpoint_query_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
185 let name = &input.ident;
186
187 let Data::Struct(data) = &input.data else {
188 return Err(syn::Error::new(
189 input.span(),
190 "`EndpointQuery` can only be derived for structs",
191 ));
192 };
193
194 if !matches!(data.fields, Fields::Named(_)) {
195 return Err(syn::Error::new(
196 data.fields.span(),
197 "`EndpointQuery` requires a struct with named fields",
198 ));
199 }
200
201 Ok(quote! {
202 impl ::better_fetch::EndpointQuery for #name {
203 fn apply_query(
204 self,
205 builder: ::better_fetch::RequestBuilder<'_>,
206 ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
207 ::better_fetch::endpoint::apply_serialized_query(self, builder)
208 }
209 }
210 })
211}
212
213fn endpoint_meta(
214 attrs: &[Attribute],
215) -> syn::Result<(proc_macro2::TokenStream, LitStr, bool, bool)> {
216 for attr in attrs {
217 if !attr.path().is_ident("endpoint") {
218 continue;
219 }
220 let Meta::List(list) = &attr.meta else {
221 return Err(syn::Error::new(attr.span(), "`#[endpoint]` must be a list"));
222 };
223 let mut method = None;
224 let mut path = None;
225 let mut register = false;
226 list.parse_nested_meta(|meta| {
227 if meta.path.is_ident("method") {
228 let value = meta.value()?;
229 method = Some(value.parse::<syn::Path>()?);
230 } else if meta.path.is_ident("path") {
231 let value = meta.value()?;
232 path = Some(value.parse::<LitStr>()?);
233 } else if meta.path.is_ident("register") {
234 register = true;
235 }
236 Ok(())
237 })?;
238 let method_path = method.ok_or_else(|| {
239 syn::Error::new(attr.span(), "`#[endpoint]` requires `method = GET` (etc.)")
240 })?;
241 let path = path.ok_or_else(|| {
242 syn::Error::new(attr.span(), "`#[endpoint]` requires `path = \"...\"`")
243 })?;
244 let is_post = method_path.get_ident().is_some_and(|id| id == "POST")
245 || method_path
246 .segments
247 .last()
248 .is_some_and(|seg| seg.ident == "POST");
249 let method = if let Some(ident) = method_path.get_ident() {
250 quote!(::http::Method::#ident)
251 } else {
252 quote!(#method_path)
253 };
254 return Ok((method, path, is_post, register));
255 }
256 Err(syn::Error::new(
257 Span::call_site(),
258 "`#[derive(Endpoint)]` requires `#[endpoint(method = GET, path = \"...\")]`",
259 ))
260}
261
262fn is_unit_type(ty: &syn::Type) -> bool {
263 matches!(ty, syn::Type::Tuple(t) if t.elems.is_empty())
264}
265
266fn endpoint_field_type(field: &syn::Field, attr: &str) -> Option<syn::Type> {
267 field
268 .attrs
269 .iter()
270 .any(|a| a.path().is_ident(attr))
271 .then(|| field.ty.clone())
272}
273
274#[proc_macro_derive(
287 Endpoint,
288 attributes(endpoint, response, params, query, body, headers, param, query_field)
289)]
290pub fn derive_endpoint(input: TokenStream) -> TokenStream {
291 let input = parse_macro_input!(input as DeriveInput);
292 match derive_endpoint_impl(input) {
293 Ok(tokens) => tokens.into(),
294 Err(err) => err.to_compile_error().into(),
295 }
296}
297
298fn derive_endpoint_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
299 let name = &input.ident;
300 let (method, path, is_post, register) = endpoint_meta(&input.attrs)?;
301 let path_value = path.value();
302
303 let Data::Struct(data) = &input.data else {
304 return Err(syn::Error::new(
305 input.span(),
306 "`Endpoint` can only be derived for structs",
307 ));
308 };
309
310 let Fields::Named(fields) = &data.fields else {
311 return Err(syn::Error::new(
312 data.fields.span(),
313 "`Endpoint` requires a struct with named fields for `#[response]` etc.",
314 ));
315 };
316
317 let mut response = quote!(());
318 let mut params = quote!(());
319 let mut query = quote!(());
320 let mut body = quote!(());
321 let mut headers = quote!(());
322 let mut body_ty: Option<syn::Type> = None;
323 let mut inline_param_fields: Vec<&syn::Field> = Vec::new();
324 let mut inline_query_fields: Vec<&syn::Field> = Vec::new();
325 let mut explicit_params = false;
326 let mut explicit_query = false;
327
328 for field in &fields.named {
329 if field.attrs.iter().any(|a| a.path().is_ident("param")) {
330 inline_param_fields.push(field);
331 continue;
332 }
333 if field.attrs.iter().any(|a| a.path().is_ident("query_field")) {
334 inline_query_fields.push(field);
335 continue;
336 }
337 if let Some(ty) = endpoint_field_type(field, "response") {
338 response = quote!(#ty);
339 } else if let Some(ty) = endpoint_field_type(field, "params") {
340 explicit_params = true;
341 params = quote!(#ty);
342 } else if let Some(ty) = endpoint_field_type(field, "query") {
343 explicit_query = true;
344 query = quote!(#ty);
345 } else if let Some(ty) = endpoint_field_type(field, "body") {
346 body_ty = Some(ty.clone());
347 body = quote!(#ty);
348 } else if let Some(ty) = endpoint_field_type(field, "headers") {
349 headers = quote!(#ty);
350 }
351 }
352
353 if explicit_params && !inline_param_fields.is_empty() {
354 return Err(syn::Error::new(
355 input.span(),
356 "use either `#[params] Type` or `#[param]` fields on the endpoint struct, not both",
357 ));
358 }
359
360 if explicit_query && !inline_query_fields.is_empty() {
361 return Err(syn::Error::new(
362 input.span(),
363 "use either `#[query] Type` or `#[query_field]` fields on the endpoint struct, not both",
364 ));
365 }
366
367 let params_ty_ident = syn::Ident::new(&format!("{name}Params"), name.span());
368 let query_ty_ident = syn::Ident::new(&format!("{name}Query"), name.span());
369 let inline_params_impl = if !inline_param_fields.is_empty() {
370 let mut field_defs = Vec::new();
371 let mut apply_pairs = Vec::new();
372 let mut field_keys = Vec::new();
373 let mut seen_keys = std::collections::HashSet::new();
374
375 for field in &inline_param_fields {
376 let ident = field.ident.as_ref().expect("named field");
377 let key = param_key(field)?;
378 if !seen_keys.insert(key.clone()) {
379 return Err(syn::Error::new(
380 field.span(),
381 format!("duplicate path parameter `{key}`"),
382 ));
383 }
384 field_keys.push(key.clone());
385 let ty = &field.ty;
386 field_defs.push(quote! { pub #ident: #ty });
387 apply_pairs.push(quote! {
388 builder = builder.param(#key, self.#ident);
389 });
390 }
391
392 let expected = path_param_names(&path_value);
393 if expected.len() != field_keys.len() {
394 return Err(syn::Error::new(
395 path.span(),
396 format!(
397 "path `{path_value}` has {} `:param` segment(s) but the endpoint has {} `#[param]` field(s)",
398 expected.len(),
399 field_keys.len()
400 ),
401 ));
402 }
403 for segment in expected {
404 if !field_keys.iter().any(|key| key == &segment) {
405 return Err(syn::Error::new(
406 path.span(),
407 format!("missing `#[param]` field for path parameter `:{segment}`"),
408 ));
409 }
410 }
411
412 params = quote!(#params_ty_ident);
413 quote! {
414 #[derive(Debug, Clone, Default)]
415 pub struct #params_ty_ident {
416 #(#field_defs),*
417 }
418
419 impl ::better_fetch::EndpointParams for #params_ty_ident {
420 type BuilderState = ::better_fetch::NeedsParams;
421
422 fn apply_params(
423 self,
424 mut builder: ::better_fetch::RequestBuilder<'_>,
425 ) -> ::better_fetch::RequestBuilder<'_> {
426 #(#apply_pairs)*
427 builder
428 }
429 }
430 }
431 } else {
432 quote! {}
433 };
434
435 let inline_query_impl = if !inline_query_fields.is_empty() {
436 let mut field_defs = Vec::new();
437 for field in &inline_query_fields {
438 let ident = field.ident.as_ref().expect("named field");
439 let ty = &field.ty;
440 field_defs.push(quote! { pub #ident: #ty });
441 }
442 query = quote!(#query_ty_ident);
443 quote! {
444 #[derive(Debug, Clone, Default, ::serde::Serialize)]
445 pub struct #query_ty_ident {
446 #(#field_defs),*
447 }
448
449 impl ::better_fetch::EndpointQuery for #query_ty_ident {
450 fn apply_query(
451 self,
452 builder: ::better_fetch::RequestBuilder<'_>,
453 ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
454 ::better_fetch::endpoint::apply_serialized_query(self, builder)
455 }
456 }
457 }
458 } else {
459 quote! {}
460 };
461
462 let explicit_query_impl = if explicit_query && inline_query_fields.is_empty() {
463 quote! {
464 impl ::better_fetch::EndpointQuery for #query {
465 fn apply_query(
466 self,
467 builder: ::better_fetch::RequestBuilder<'_>,
468 ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
469 ::better_fetch::endpoint::apply_serialized_query(self, builder)
470 }
471 }
472 }
473 } else {
474 quote! {}
475 };
476
477 let body_required = is_post && body_ty.as_ref().is_some_and(|ty| !is_unit_type(ty));
478
479 let body_required_impl = if let Some(body_type) = body_ty.filter(|_| body_required) {
480 quote! {
481 impl ::better_fetch::EndpointBody for #body_type {
482 type ParamsNext = ::better_fetch::NeedsBody;
483 type CallInitial = ::better_fetch::NeedsBody;
484
485 fn apply_body(
486 self,
487 builder: ::better_fetch::RequestBuilder<'_>,
488 ) -> ::better_fetch::Result<::better_fetch::RequestBuilder<'_>> {
489 builder.json(&self)
490 }
491 }
492
493 impl ::better_fetch::DefaultParamsInitial<#name> for () {
494 fn initial(
495 client: &::better_fetch::Client,
496 ) -> ::better_fetch::EndpointRequestBuilder<'_, #name, ::better_fetch::NeedsBody> {
497 ::better_fetch::EndpointRequestBuilder::new_needs_body(
498 client.request(#method, #path_value),
499 )
500 }
501 }
502 }
503 } else {
504 quote! {}
505 };
506
507 let register_impl = if register {
508 quote! {
509 impl #name {
510 #[cfg(feature = "schema")]
512 pub fn register(registry: &mut ::better_fetch::SchemaRegistry) {
513 registry.register_typed::<#name, #body, #response>();
514 }
515 }
516 }
517 } else {
518 quote! {}
519 };
520
521 Ok(quote! {
522 #inline_params_impl
523 #inline_query_impl
524 #explicit_query_impl
525 impl ::better_fetch::Endpoint for #name {
526 const METHOD: ::http::Method = #method;
527 const PATH: &'static str = #path_value;
528 type Response = #response;
529 type Params = #params;
530 type Query = #query;
531 type Body = #body;
532 type Headers = #headers;
533 }
534 #body_required_impl
535 #register_impl
536 })
537}