Skip to main content

altaria_macros/
lib.rs

1use proc_macro::TokenStream;
2use std::collections::HashMap;
3use proc_macro2::Ident;
4use quote::{quote, ToTokens};
5use syn::{parse_macro_input, LitStr};
6use syn::spanned::Spanned;
7
8#[proc_macro_attribute]
9pub fn post(args: TokenStream, item: TokenStream) -> TokenStream {
10    expand(args, item, Some(Ident::new("POST", proc_macro2::Span::call_site())))
11}
12
13#[proc_macro_attribute]
14pub fn get(args: TokenStream, item: TokenStream) -> TokenStream {
15    expand(args, item, Some(Ident::new("GET", proc_macro2::Span::call_site())))
16}
17
18#[proc_macro_attribute]
19pub fn put(args: TokenStream, item: TokenStream) -> TokenStream {
20    expand(args, item, Some(Ident::new("PUT", proc_macro2::Span::call_site())))
21}
22
23#[proc_macro_attribute]
24pub fn delete(args: TokenStream, item: TokenStream) -> TokenStream {
25    expand(args, item, Some(Ident::new("DELETE", proc_macro2::Span::call_site())))
26}
27
28#[proc_macro_attribute]
29pub fn patch(args: TokenStream, item: TokenStream) -> TokenStream {
30    expand(args, item, Some(Ident::new("PATCH", proc_macro2::Span::call_site())))
31}
32
33#[proc_macro_attribute]
34pub fn head(args: TokenStream, item: TokenStream) -> TokenStream {
35    expand(args, item, Some(Ident::new("HEAD", proc_macro2::Span::call_site())))
36}
37
38#[proc_macro_attribute]
39pub fn options(args: TokenStream, item: TokenStream) -> TokenStream {
40    expand(args, item, Some(Ident::new("OPTIONS", proc_macro2::Span::call_site())))
41}
42
43#[proc_macro_attribute]
44pub fn trace(args: TokenStream, item: TokenStream) -> TokenStream {
45    expand(args, item, Some(Ident::new("TRACE", proc_macro2::Span::call_site())))
46}
47
48#[proc_macro_attribute]
49pub fn connect(args: TokenStream, item: TokenStream) -> TokenStream {
50    expand(args, item, Some(Ident::new("CONNECT", proc_macro2::Span::call_site())))
51}
52
53#[proc_macro_attribute]
54pub fn handler(args: TokenStream, item: TokenStream) -> TokenStream {
55    expand(args, item, None)
56}
57
58fn expand(
59    args: TokenStream,
60    item: TokenStream,
61    method: Option<Ident>,
62) -> TokenStream {
63    let mut function_item = parse_macro_input!(item as syn::ItemFn);
64    let function_ident = function_item.sig.ident.clone();
65
66    let arg = parse_macro_input!(args as LitStr);
67    let const_name = format!("_AltariaEndpoint{}", function_ident.to_string().to_uppercase());
68    let const_ident = Ident::new(&const_name, function_ident.span());
69
70    let path = arg.value();
71    let query_index = path.find('?');
72
73    let url = if let Some(index) = query_index { &path[..index] } else { &path };
74    let query_part = if let Some(index) = query_index { &path[index + 1..] } else { "" };
75
76    let params = url.split('/')
77        .filter(|s| s.starts_with('{') && s.ends_with('}'))
78        .map(|s| &s[1..s.len() - 1])
79        .collect::<Vec<&str>>();
80
81    let query_params: HashMap<String, String> = query_part.split('&')
82        .map(|s| s.split('=').collect::<Vec<&str>>())
83        .filter(|s| s.len() == 2)
84        .filter(|s| !s[0].is_empty() && !s[1].is_empty())
85        .map(|s| (s[0].to_string(), s[1].to_string()))
86        .filter(|(_, value)| value.starts_with('{') && value.ends_with('}'))
87        .map(|(key, value)| (value[1..value.len() - 1].to_string(), key)) // <- flip
88        .collect();
89
90    let mut inputs: Vec<_> = function_item.sig.inputs.iter().cloned().collect();
91    inputs.sort_by_key(|arg| {
92        if let syn::FnArg::Typed(pat_type) = arg {
93            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
94                let index_of_param = params.iter().position(|param| *param == pat_ident.ident.to_string());
95                if let Some(index) = index_of_param {
96                    return index;
97                }
98            }
99        }
100        params.len()
101    });
102
103    let mut accesses = Vec::new();
104    let mut idents = Vec::new();
105    let mut extractors = Vec::new();
106    let mut extractions = Vec::new();
107
108    for (index, arg) in inputs.iter().enumerate() {
109        if let syn::FnArg::Typed(pat_type) = arg {
110            if let syn::Type::Path(type_path) = &*pat_type.ty {
111                let variable_ident = Ident::new(&format!("param_{}", index), index.span());
112                if let syn::Pat::Ident(ident) = &*pat_type.pat {
113                    let name = ident.ident.to_string();
114                    if params.contains(&name.as_str()) {
115                        let extractor = quote! { altaria::extractor::param::Param::<#type_path> };
116                        let access = quote! { #variable_ident.0 };
117                        accesses.push(access);
118                        idents.push(variable_ident.clone());
119                        extractors.push(extractor.clone());
120                        extractions.push(quote! {
121                            let #variable_ident = #extractor::from_request(#index, &mut request).await?;
122                        });
123                        continue;
124                    } else if query_params.contains_key(&name) {
125                        let actual_name = query_params.get(&name).unwrap();
126
127                        let true_type = extract_option_type_param(type_path);
128                        let extractor = if let Some(type_path) = true_type {
129                            quote! { altaria::extractor::query::OptionalQuery::<#type_path> }
130                        } else {
131                            quote! { altaria::extractor::query::Query::<#type_path> }
132                        };
133                        let access = quote! { #variable_ident.0 };
134                        accesses.push(access);
135                        idents.push(variable_ident.clone());
136                        extractors.push(extractor.clone());
137                        extractions.push(quote! {
138                            let #variable_ident = #extractor::from_request_by_name(#actual_name, &request)?;
139                        });
140                        continue;
141                    }
142                }
143                let extractor_name = type_path.to_token_stream().to_string().replace("<", "::<").replace(" ", "");
144                let extractor: proc_macro2::TokenStream = syn::parse_str(&extractor_name).expect("");
145                let access = quote! { #variable_ident };
146                accesses.push(access);
147                idents.push(variable_ident.clone());
148                extractors.push(extractor.clone());
149                extractions.push(quote! {
150                    let #variable_ident = #extractor::from_request(#index, &mut request).await?;
151                });
152            } else {
153                panic!("Invalid function argument: it's either not a simple identifier or not a type");
154            }
155        } else {
156            panic!("Invalid function argument: it's either not a simple identifier or not a type");
157        }
158    }
159
160    let method = match method {
161        Some(method) => quote! { Some(altaria::request::HttpMethod::#method) },
162        None => quote! { None }
163    };
164
165    function_item.sig.inputs = syn::punctuated::Punctuated::from_iter(inputs);
166    TokenStream::from(quote! {
167        pub(crate) struct #const_ident;
168
169        impl #const_ident {
170            #[inline(always)]
171            pub const fn new() -> Self {
172                Self
173            }
174
175            #[inline(always)]
176            pub const fn get_endpoint() -> &'static str {
177                #path
178            }
179        }
180
181        #[altaria::async_trait::async_trait]
182        impl altaria::router::func::FunctionRouteHandler<(#(#extractors),*)> for #const_ident {
183            fn get_method(&self) -> Option<altaria::request::HttpMethod> {
184                #method
185            }
186
187            async fn handle_request(&self, mut request: altaria::request::HttpRequest) -> altaria::response::HttpResponse {
188                let extract_values = async {
189                    use altaria::extractor::FromRequest;
190                    use altaria::extractor::query::NamedExtractor;
191                    #(#extractions)*
192                    Result::<_, altaria::extractor::ExtractorError>::Ok((#(#idents),*))
193                }.await;
194
195                match extract_values {
196                    Ok((#(#idents),*)) => {
197                        let response = #function_ident(#(#accesses),*).await;
198                        response.into_response()
199                    },
200                    Err(err) => altaria::router::func::handle_function_failure(err)
201                }
202            }
203        }
204
205        #function_item
206    })
207}
208
209fn extract_option_type_param(type_path: &syn::TypePath) -> Option<syn::Type> {
210    if let Some(segment) = type_path.path.segments.last() {
211        if segment.ident == "Option" {
212            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
213                if let Some(arg) = args.args.first() {
214                    if let syn::GenericArgument::Type(ty) = arg {
215                        return Some(ty.clone())
216                    }
217                }
218            }
219        }
220    }
221    None
222}