api_forge_macro/
lib.rs

1use darling::{FromDeriveInput, FromField};
2use proc_macro::TokenStream;
3use proc_macro2::Ident;
4use quote::quote;
5use syn::{parse_macro_input, DeriveInput, LitStr};
6
7#[derive(Debug, FromDeriveInput, Clone)]
8#[darling(attributes(request))]
9struct RequestArgs {
10    data: darling::ast::Data<(), HeaderField>,
11
12    endpoint: String,
13
14    #[darling(default, rename = "response_type")]
15    response_type: Option<Ident>,
16
17    #[darling(default, rename = "method")]
18    method: Option<Ident>,
19
20    #[darling(default, rename = "transmission")]
21    transmission: Option<Ident>,
22
23    #[darling(default, rename = "authentication")]
24    authentication: Option<Ident>,
25
26    #[darling(default, rename = "path_parameters")]
27    path_parameters: Option<Vec<LitStr>>,
28}
29
30#[derive(Debug, FromField, Clone)]
31#[darling(attributes(request))]
32struct HeaderField {
33    ident: Option<Ident>,
34
35    // This will capture the name of the header
36    #[darling(default)]
37    header_name: Option<LitStr>,
38}
39
40#[proc_macro_derive(Request, attributes(request))]
41pub fn derive_request(input: TokenStream) -> TokenStream {
42    // Parse the input into a DeriveInput struct using syn
43    let input = parse_macro_input!(input as DeriveInput);
44
45    // Use `darling` to parse the attributes from the input
46    let args = RequestArgs::from_derive_input(&input).unwrap_or_else(|e| {
47        let error = e.write_errors();
48
49        panic!("{}", error);
50    });
51
52    let name = &input.ident;
53    let data = args.data.clone();
54    let mut header_inserts = vec![];
55
56    data.map_struct_fields(|field| {
57        if field.header_name.is_some() {
58            // Add the #[serde(skip)] attribute to the header fields
59            let header_field_ident = field.ident.as_ref().unwrap();
60            let header_name = field.header_name.as_ref().unwrap().value();
61
62            let header_insert = quote! {
63                if let Some(value) = self.#header_field_ident.as_ref() {
64                    builder = builder.header(#header_name, value);
65                }
66            };
67
68            header_inserts.push(header_insert);
69        }
70    });
71
72    let endpoint = args.endpoint;
73    let response_type = args
74        .response_type
75        .unwrap_or_else(|| Ident::new("EmptyResponse", proc_macro2::Span::call_site()));
76    let method = args
77        .method
78        .unwrap_or_else(|| Ident::new("GET", proc_macro2::Span::call_site()));
79    let transmission_method = args
80        .transmission
81        .unwrap_or_else(|| Ident::new("QueryParams", proc_macro2::Span::call_site()));
82    let authentication_method = args
83        .authentication
84        .unwrap_or_else(|| Ident::new("None", proc_macro2::Span::call_site()));
85    let path_parameters = args.path_parameters.unwrap_or(Vec::new());
86    let path_parameters = path_parameters
87        .iter()
88        .map(|p| p.value())
89        .collect::<Vec<_>>();
90    let path_parameters_idents = path_parameters
91        .iter()
92        .map(|p| Ident::new(p, proc_macro2::Span::call_site()))
93        .collect::<Vec<_>>();
94
95    let res_type = if response_type == Ident::new("EmptyResponse", proc_macro2::Span::call_site()) {
96        quote!(())
97    } else {
98        quote!(#response_type)
99    };
100
101    // Generate the final code for the derive macro
102    let expanded = quote! {
103        impl api_forge::ApiRequest<#res_type> for #name {
104            const ENDPOINT: &'static str = #endpoint;
105            const METHOD: reqwest::Method = reqwest::Method::#method;
106            const DATA_TRANSMISSION_METHOD: api_forge::DataTransmissionMethod = api_forge::DataTransmissionMethod::#transmission_method;
107            const AUTHENTICATION_METHOD: api_forge::AuthenticationMethod = api_forge::AuthenticationMethod::#authentication_method;
108
109            fn generate_request(
110                &self,
111                base_url: &str,
112                headers: Option<reqwest::header::HeaderMap>,
113                token: Option<(String, Option<String>)>,
114            ) -> reqwest::RequestBuilder {
115                let mut url = format!("{}{}", base_url, Self::ENDPOINT);
116
117                #(
118                    url = url.replace(&format!("{{{}}}", #path_parameters), &self.#path_parameters_idents.to_string());
119                )*
120
121                let client = reqwest::Client::new();
122
123                let mut builder = match Self::METHOD {
124                    reqwest::Method::GET => client.get(&url),
125                    reqwest::Method::POST => client.post(&url),
126                    reqwest::Method::PUT => client.put(&url),
127                    reqwest::Method::DELETE => client.delete(&url),
128                    reqwest::Method::PATCH => client.patch(&url),
129                    reqwest::Method::HEAD => client.head(&url),
130                    _ => client.get(&url),
131                };
132
133                builder = match Self::DATA_TRANSMISSION_METHOD {
134                    api_forge::DataTransmissionMethod::QueryParams => builder.query(self),
135                    api_forge::DataTransmissionMethod::Json => builder.json(self),
136                    _ => builder.form(self),
137                };
138
139                if let Some((token, password)) = token {
140                    builder = match Self::AUTHENTICATION_METHOD {
141                        api_forge::AuthenticationMethod::Basic => builder.basic_auth(token, password),
142                        api_forge::AuthenticationMethod::Bearer => builder.bearer_auth(token),
143                        api_forge::AuthenticationMethod::None => builder,
144                    };
145                }
146
147                let mut all_headers = reqwest::header::HeaderMap::new();
148
149                #(#header_inserts)*
150
151                if let Some(headers) = headers {
152                    all_headers.extend(headers);
153                }
154
155                builder.headers(all_headers)
156            }
157        }
158    };
159
160    TokenStream::from(expanded)
161}