1use proc_macro::TokenStream;
2use std::collections::HashMap;
3use proc_macro2::Ident;
4use quote::{quote, ToTokens};
5use syn::{parse_macro_input, LitStr};
6use syn::spanned::Spanned;
7
8#[proc_macro_attribute]
9pub fn post(args: TokenStream, item: TokenStream) -> TokenStream {
10 expand(args, item, Some(Ident::new("POST", proc_macro2::Span::call_site())))
11}
12
13#[proc_macro_attribute]
14pub fn get(args: TokenStream, item: TokenStream) -> TokenStream {
15 expand(args, item, Some(Ident::new("GET", proc_macro2::Span::call_site())))
16}
17
18#[proc_macro_attribute]
19pub fn put(args: TokenStream, item: TokenStream) -> TokenStream {
20 expand(args, item, Some(Ident::new("PUT", proc_macro2::Span::call_site())))
21}
22
23#[proc_macro_attribute]
24pub fn delete(args: TokenStream, item: TokenStream) -> TokenStream {
25 expand(args, item, Some(Ident::new("DELETE", proc_macro2::Span::call_site())))
26}
27
28#[proc_macro_attribute]
29pub fn patch(args: TokenStream, item: TokenStream) -> TokenStream {
30 expand(args, item, Some(Ident::new("PATCH", proc_macro2::Span::call_site())))
31}
32
33#[proc_macro_attribute]
34pub fn head(args: TokenStream, item: TokenStream) -> TokenStream {
35 expand(args, item, Some(Ident::new("HEAD", proc_macro2::Span::call_site())))
36}
37
38#[proc_macro_attribute]
39pub fn options(args: TokenStream, item: TokenStream) -> TokenStream {
40 expand(args, item, Some(Ident::new("OPTIONS", proc_macro2::Span::call_site())))
41}
42
43#[proc_macro_attribute]
44pub fn trace(args: TokenStream, item: TokenStream) -> TokenStream {
45 expand(args, item, Some(Ident::new("TRACE", proc_macro2::Span::call_site())))
46}
47
48#[proc_macro_attribute]
49pub fn connect(args: TokenStream, item: TokenStream) -> TokenStream {
50 expand(args, item, Some(Ident::new("CONNECT", proc_macro2::Span::call_site())))
51}
52
53#[proc_macro_attribute]
54pub fn handler(args: TokenStream, item: TokenStream) -> TokenStream {
55 expand(args, item, None)
56}
57
58fn expand(
59 args: TokenStream,
60 item: TokenStream,
61 method: Option<Ident>,
62) -> TokenStream {
63 let mut function_item = parse_macro_input!(item as syn::ItemFn);
64 let function_ident = function_item.sig.ident.clone();
65
66 let arg = parse_macro_input!(args as LitStr);
67 let const_name = format!("_AltariaEndpoint{}", function_ident.to_string().to_uppercase());
68 let const_ident = Ident::new(&const_name, function_ident.span());
69
70 let path = arg.value();
71 let query_index = path.find('?');
72
73 let url = if let Some(index) = query_index { &path[..index] } else { &path };
74 let query_part = if let Some(index) = query_index { &path[index + 1..] } else { "" };
75
76 let params = url.split('/')
77 .filter(|s| s.starts_with('{') && s.ends_with('}'))
78 .map(|s| &s[1..s.len() - 1])
79 .collect::<Vec<&str>>();
80
81 let query_params: HashMap<String, String> = query_part.split('&')
82 .map(|s| s.split('=').collect::<Vec<&str>>())
83 .filter(|s| s.len() == 2)
84 .filter(|s| !s[0].is_empty() && !s[1].is_empty())
85 .map(|s| (s[0].to_string(), s[1].to_string()))
86 .filter(|(_, value)| value.starts_with('{') && value.ends_with('}'))
87 .map(|(key, value)| (value[1..value.len() - 1].to_string(), key)) .collect();
89
90 let mut inputs: Vec<_> = function_item.sig.inputs.iter().cloned().collect();
91 inputs.sort_by_key(|arg| {
92 if let syn::FnArg::Typed(pat_type) = arg {
93 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
94 let index_of_param = params.iter().position(|param| *param == pat_ident.ident.to_string());
95 if let Some(index) = index_of_param {
96 return index;
97 }
98 }
99 }
100 params.len()
101 });
102
103 let mut accesses = Vec::new();
104 let mut idents = Vec::new();
105 let mut extractors = Vec::new();
106 let mut extractions = Vec::new();
107
108 for (index, arg) in inputs.iter().enumerate() {
109 if let syn::FnArg::Typed(pat_type) = arg {
110 if let syn::Type::Path(type_path) = &*pat_type.ty {
111 let variable_ident = Ident::new(&format!("param_{}", index), index.span());
112 if let syn::Pat::Ident(ident) = &*pat_type.pat {
113 let name = ident.ident.to_string();
114 if params.contains(&name.as_str()) {
115 let extractor = quote! { altaria::extractor::param::Param::<#type_path> };
116 let access = quote! { #variable_ident.0 };
117 accesses.push(access);
118 idents.push(variable_ident.clone());
119 extractors.push(extractor.clone());
120 extractions.push(quote! {
121 let #variable_ident = #extractor::from_request(#index, &mut request).await?;
122 });
123 continue;
124 } else if query_params.contains_key(&name) {
125 let actual_name = query_params.get(&name).unwrap();
126
127 let true_type = extract_option_type_param(type_path);
128 let extractor = if let Some(type_path) = true_type {
129 quote! { altaria::extractor::query::OptionalQuery::<#type_path> }
130 } else {
131 quote! { altaria::extractor::query::Query::<#type_path> }
132 };
133 let access = quote! { #variable_ident.0 };
134 accesses.push(access);
135 idents.push(variable_ident.clone());
136 extractors.push(extractor.clone());
137 extractions.push(quote! {
138 let #variable_ident = #extractor::from_request_by_name(#actual_name, &request)?;
139 });
140 continue;
141 }
142 }
143 let extractor_name = type_path.to_token_stream().to_string().replace("<", "::<").replace(" ", "");
144 let extractor: proc_macro2::TokenStream = syn::parse_str(&extractor_name).expect("");
145 let access = quote! { #variable_ident };
146 accesses.push(access);
147 idents.push(variable_ident.clone());
148 extractors.push(extractor.clone());
149 extractions.push(quote! {
150 let #variable_ident = #extractor::from_request(#index, &mut request).await?;
151 });
152 } else {
153 panic!("Invalid function argument: it's either not a simple identifier or not a type");
154 }
155 } else {
156 panic!("Invalid function argument: it's either not a simple identifier or not a type");
157 }
158 }
159
160 let method = match method {
161 Some(method) => quote! { Some(altaria::request::HttpMethod::#method) },
162 None => quote! { None }
163 };
164
165 function_item.sig.inputs = syn::punctuated::Punctuated::from_iter(inputs);
166 TokenStream::from(quote! {
167 pub(crate) struct #const_ident;
168
169 impl #const_ident {
170 #[inline(always)]
171 pub const fn new() -> Self {
172 Self
173 }
174
175 #[inline(always)]
176 pub const fn get_endpoint() -> &'static str {
177 #path
178 }
179 }
180
181 #[altaria::async_trait::async_trait]
182 impl altaria::router::func::FunctionRouteHandler<(#(#extractors),*)> for #const_ident {
183 fn get_method(&self) -> Option<altaria::request::HttpMethod> {
184 #method
185 }
186
187 async fn handle_request(&self, mut request: altaria::request::HttpRequest) -> altaria::response::HttpResponse {
188 let extract_values = async {
189 use altaria::extractor::FromRequest;
190 use altaria::extractor::query::NamedExtractor;
191 #(#extractions)*
192 Result::<_, altaria::extractor::ExtractorError>::Ok((#(#idents),*))
193 }.await;
194
195 match extract_values {
196 Ok((#(#idents),*)) => {
197 let response = #function_ident(#(#accesses),*).await;
198 response.into_response()
199 },
200 Err(err) => altaria::router::func::handle_function_failure(err)
201 }
202 }
203 }
204
205 #function_item
206 })
207}
208
209fn extract_option_type_param(type_path: &syn::TypePath) -> Option<syn::Type> {
210 if let Some(segment) = type_path.path.segments.last() {
211 if segment.ident == "Option" {
212 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
213 if let Some(arg) = args.args.first() {
214 if let syn::GenericArgument::Type(ty) = arg {
215 return Some(ty.clone())
216 }
217 }
218 }
219 }
220 }
221 None
222}