fal_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, FnArg, ItemFn,
5    Meta, PatType, ReturnType, Token,
6};
7
/// A single comma-separated argument accepted inside `#[endpoint(...)]`.
enum EndpointAttr {
    /// `endpoint = "..."` — the string that the generated function passes as the
    /// first argument of `FalRequest::new`.
    Path(String),
    /// `in_fal_crate` — generated code refers to `crate::request::…` instead of
    /// `fal::request::…`.
    InFalCrate,
}
12
13impl Parse for EndpointAttr {
14    fn parse(input: ParseStream) -> syn::Result<Self> {
15        let meta: Meta = input.parse()?;
16        match meta {
17            Meta::Path(path) if path.is_ident("in_fal_crate") => Ok(EndpointAttr::InFalCrate),
18            Meta::NameValue(nv) if nv.path.is_ident("endpoint") => match nv.value {
19                syn::Expr::Lit(lit) => {
20                    if let syn::Lit::Str(s) = lit.lit {
21                        return Ok(EndpointAttr::Path(s.value()));
22                    }
23                    Err(syn::Error::new_spanned(lit, "expected string literal"))
24                }
25                _ => Err(syn::Error::new_spanned(nv.value, "expected string literal")),
26            },
27            _ => Err(syn::Error::new_spanned(
28                meta,
29                "expected endpoint = \"...\" or in_fal_crate",
30            )),
31        }
32    }
33}
34
35#[doc = include_str!("../README.md")]
36#[proc_macro_attribute]
37pub fn endpoint(attr: TokenStream, item: TokenStream) -> TokenStream {
38    let attr =
39        parse_macro_input!(attr with Punctuated::<EndpointAttr, Token![,]>::parse_terminated);
40    let input_fn = parse_macro_input!(item as ItemFn);
41
42    // Extract the endpoint string and in_fal_crate flag from the attributes
43    let mut endpoint_str = None;
44    let mut in_fal_crate = false;
45
46    for attr in attr.iter() {
47        match attr {
48            EndpointAttr::Path(s) => endpoint_str = Some(s),
49            EndpointAttr::InFalCrate => in_fal_crate = true,
50        }
51    }
52
53    let endpoint_str = endpoint_str.expect("endpoint attribute must be provided");
54
55    let fn_name = &input_fn.sig.ident;
56    let fn_name_str = fn_name.to_string();
57    let camel_case_name = snake_to_upper_camel(&fn_name_str);
58    let struct_name = syn::Ident::new(&format!("{}Params", camel_case_name), fn_name.span());
59    let vis = &input_fn.vis;
60
61    // Extract function parameters
62    let mut param_fields = Vec::new();
63    let mut param_names = Vec::new();
64
65    for arg in input_fn.sig.inputs.iter() {
66        if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
67            if let syn::Pat::Ident(pat_ident) = *pat.clone() {
68                let ident = pat_ident.ident.clone();
69                param_fields.push(quote! { pub #ident: #ty });
70                param_names.push(pat_ident.ident);
71            }
72        }
73    }
74
75    // Extract return type
76    let return_type = match &input_fn.sig.output {
77        ReturnType::Type(_, ty) => ty,
78        _ => panic!("Function must have a return type"),
79    };
80
81    // Choose the appropriate crate reference
82    let crate_ref = if in_fal_crate {
83        quote! { crate }
84    } else {
85        quote! { fal }
86    };
87
88    // Generate the expanded code
89    let struct_def = quote! {
90        #[derive(serde::Serialize)]
91        #vis struct #struct_name {
92            #(#param_fields),*
93        }
94    };
95
96    let inputs = input_fn.sig.inputs;
97
98    let fn_def = quote! {
99        #vis fn #fn_name(#inputs) -> #crate_ref::request::FalRequest<#struct_name, #return_type> {
100            #crate_ref::request::FalRequest::new(
101                #endpoint_str,
102                #struct_name {
103                    #(#param_names: #param_names),*
104                }
105            )
106        }
107    };
108
109    let expanded = quote! {
110        #struct_def
111        #fn_def
112    };
113
114    TokenStream::from(expanded)
115}
116
/// Converts a `snake_case` identifier into `UpperCamelCase`.
///
/// Underscores are removed. The first character of each underscore-separated
/// segment is ASCII-uppercased, and all other characters are copied unchanged.
/// Empty segments produced by leading, trailing or doubled underscores contribute
/// nothing.
fn snake_to_upper_camel(input: &str) -> String {
    input
        .split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|first| first.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}