Skip to main content

potato_macro/
lib.rs

1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use utils::StringExt as _;
10
/// Simplified type names of the scalar handler parameters the `http_*` macros
/// know how to extract from a request (`String`, `bool`, and the primitive
/// integer and float types).
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "String", "bool", "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
        "f32", "f64",
    ])
});
19
20fn random_ident() -> Ident {
21    let mut rng = rand::thread_rng();
22    let value = format!("__potato_id_{}", rng.r#gen::<u64>());
23    Ident::new(&value, Span::call_site())
24}
25
26fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
27    let req_name = Ident::new(req_name, Span::call_site());
28    let (route_path, oauth_arg) = {
29        let mut oroute_path = syn::parse::<syn::LitStr>(attr.clone())
30            .ok()
31            .map(|path| path.value());
32        let mut oauth_arg = None;
33        //
34        if oroute_path.is_none() {
35            let http_parser = syn::meta::parser(|meta| {
36                if meta.path.is_ident("path") {
37                    if let Ok(arg) = meta.value() {
38                        if let Ok(route_path) = arg.parse::<syn::LitStr>() {
39                            let route_path = route_path.value();
40                            oroute_path = Some(route_path);
41                        }
42                    }
43                    Ok(())
44                } else if meta.path.is_ident("auth_arg") {
45                    if let Ok(arg) = meta.value() {
46                        if let Ok(tmp_field) = arg.parse::<Ident>() {
47                            oauth_arg = Some(tmp_field.to_string());
48                        }
49                    }
50                    Ok(())
51                } else {
52                    Err(meta.error("unsupported annotation property"))
53                }
54            });
55            syn::parse_macro_input!(attr with http_parser);
56        }
57        if oroute_path.is_none() {
58            panic!("`path` argument is required");
59        }
60        let route_path = oroute_path.unwrap();
61        if !route_path.starts_with('/') {
62            panic!("route path must start with '/'");
63        }
64        (route_path, oauth_arg)
65    };
66    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
67    let doc_show = {
68        let mut doc_show = true;
69        for attr in root_fn.attrs.iter() {
70            if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
71                if let Ok(meta_list) = attr.meta.require_list() {
72                    if meta_list.tokens.to_string() == "hidden" {
73                        doc_show = false;
74                        break;
75                    }
76                }
77            }
78        }
79        doc_show
80    };
81    let doc_auth = oauth_arg.is_some();
82    let doc_summary = {
83        let mut docs = vec![];
84        for attr in root_fn.attrs.iter() {
85            if let Ok(attr) = attr.meta.require_name_value() {
86                if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
87                    let mut doc = attr.value.to_token_stream().to_string();
88                    if doc.starts_with('\"') {
89                        doc.remove(0);
90                        doc.pop();
91                    }
92                    docs.push(doc);
93                }
94            }
95        }
96        if docs.iter().all(|d| d.starts_with(' ')) {
97            for doc in docs.iter_mut() {
98                doc.remove(0);
99            }
100        }
101        docs.join("\n")
102    };
103    let doc_desp = "";
104    let fn_name = root_fn.sig.ident.clone();
105    let wrap_func_name = random_ident();
106    let mut args = vec![];
107    let mut arg_names = vec![];
108    let mut doc_args = vec![];
109    let mut arg_auth_mark = false;
110    for arg in root_fn.sig.inputs.iter() {
111        if let syn::FnArg::Typed(arg) = arg {
112            let arg_type_str = arg
113                .ty
114                .as_ref()
115                .to_token_stream()
116                .to_string()
117                .type_simplify();
118            let arg_name_str = arg.pat.to_token_stream().to_string();
119            args.push(match &arg_type_str[..] {
120                "& mut HttpRequest" => quote! { req },
121                "PostFile" => {
122                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
123                    quote! {
124                        match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
125                            Some(file) => file,
126                            None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
127                        }
128                    }
129                },
130                arg_type_str if ARG_TYPES.contains(arg_type_str) => {
131                    let is_auth_arg = match oauth_arg.as_ref() {
132                        Some(auth_arg) => auth_arg == &arg_name_str,
133                        None => false,
134                    };
135                    if is_auth_arg {
136                        if arg_type_str != "String" {
137                            panic!("auth_arg argument is must String type");
138                        }
139                        arg_auth_mark = true;
140                        quote! {
141                            match req.headers
142                                .get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization"))
143                                .map(|v| v.to_str()) {
144                                Some(mut auth) => {
145                                    if auth.starts_with("Bearer ") {
146                                        auth = &auth[7..];
147                                    }
148                                    match potato::ServerAuth::jwt_check(&auth).await {
149                                        Ok(payload) => payload,
150                                        Err(err) => return potato::HttpResponse::error(format!("auth failed: {err:?}")),
151                                    }
152                                }
153                                None => return potato::HttpResponse::error("miss header : Authorization"),
154                            }
155                        }
156                    } else {
157                        doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
158                        let mut arg_value = quote! {
159                            match req.body_pairs
160                                .get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str))
161                                .map(|p| p.to_string()) {
162                                Some(val) => val,
163                                None => match req.url_query
164                                    .get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str))
165                                    .map(|p| p.to_str().to_string()) {
166                                    Some(val) => val,
167                                    None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
168                                },
169                            }
170                        };
171                        if arg_type_str != "String" {
172                            arg_value = quote! {
173                                match #arg_value.parse() {
174                                    Ok(val) => val,
175                                    Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
176                                }
177                            }
178                        }
179                        arg_value
180                    }
181                },
182                _ => panic!("unsupported arg type: [{arg_type_str}]"),
183            });
184            arg_names.push(random_ident());
185        } else {
186            panic!("unsupported: {}", arg.to_token_stream().to_string());
187        }
188    }
189    if !arg_auth_mark && doc_auth {
190        panic!("`auth_arg` attribute is must point to an existing argument");
191    }
192    let wrap_func_name2 = random_ident();
193    let ret_type = root_fn
194        .sig
195        .output
196        .to_token_stream()
197        .to_string()
198        .type_simplify();
199    let wrap_func_body = match args.len() {
200        0 => match &ret_type[..] {
201            "Result < () >" => quote! {
202                match #fn_name().await {
203                    Ok(ret) => potato::HttpResponse::text("ok"),
204                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
205                }
206            },
207            "Result < HttpResponse >" | "Result < potato :: HttpResponse >" => quote! {
208                match #fn_name().await {
209                    Ok(ret) => ret,
210                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
211                }
212            },
213            "()" => quote! {
214                #fn_name().await;
215                potato::HttpResponse::text("ok")
216            },
217            "HttpResponse" | "potato :: HttpResponse" => quote! {
218                #fn_name().await
219            },
220            _ => panic!("unsupported ret type: {ret_type}"),
221        },
222        1 => match &ret_type[..] {
223            "Result < () >" => quote! {
224                let #(#arg_names),* = #(#args),*;
225                match #fn_name(#(#arg_names),*).await {
226                    Ok(ret) => potato::HttpResponse::text("ok"),
227                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
228                }
229            },
230            "Result < HttpResponse >" | "Result < potato :: HttpResponse >" => quote! {
231                let #(#arg_names),* = #(#args),*;
232                match #fn_name(#(#arg_names),*).await {
233                    Ok(ret) => ret,
234                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
235                }
236            },
237            "()" => quote! {
238                let #(#arg_names),* = #(#args),*;
239                #fn_name(#(#arg_names),*).await;
240                potato::HttpResponse::text("ok")
241            },
242            "HttpResponse" | "potato :: HttpResponse" => quote! {
243                let #(#arg_names),* = #(#args),*;
244                #fn_name(#(#arg_names),*).await
245            },
246            _ => panic!("unsupported ret type: {ret_type}"),
247        },
248        _ => match &ret_type[..] {
249            "Result < () >" => quote! {
250                let (#(#arg_names),*) = (#(#args),*);
251                match #fn_name(#(#arg_names),*).await {
252                    Ok(ret) => potato::HttpResponse::text("ok"),
253                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
254                }
255            },
256            "Result < HttpResponse >" | "Result < potato :: HttpResponse >" => quote! {
257                let (#(#arg_names),*) = (#(#args),*);
258                match #fn_name(#(#arg_names),*).await {
259                    Ok(ret) => ret,
260                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
261                }
262            },
263            "()" => quote! {
264                let (#(#arg_names),*) = (#(#args),*);
265                #fn_name(#(#arg_names),*).await;
266                potato::HttpResponse::text("ok")
267            },
268            "HttpResponse" | "potato :: HttpResponse" => quote! {
269                let (#(#arg_names),*) = (#(#args),*);
270                #fn_name(#(#arg_names),*).await
271            },
272            _ => panic!("unsupported ret type: {ret_type}"),
273        },
274    };
275    let doc_args = serde_json::to_string(&doc_args).unwrap();
276    //let mut content =
277    quote! {
278        #root_fn
279
280        #[doc(hidden)]
281        async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
282            #wrap_func_body
283        }
284
285        #[doc(hidden)]
286        fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
287            Box::pin(#wrap_func_name2(req))
288        }
289
290        potato::inventory::submit!{potato::RequestHandlerFlag::new(
291            potato::HttpMethod::#req_name,
292            #route_path,
293            #wrap_func_name,
294            potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args)
295        )}
296    }.into()
297    //}.to_string();
298    //panic!("{content}");
299    //todo!()
300}
301
302#[proc_macro_attribute]
303pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
304    http_handler_macro(attr, input, "GET")
305}
306
307#[proc_macro_attribute]
308pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
309    http_handler_macro(attr, input, "POST")
310}
311
312#[proc_macro_attribute]
313pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
314    http_handler_macro(attr, input, "PUT")
315}
316
317#[proc_macro_attribute]
318pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
319    http_handler_macro(attr, input, "DELETE")
320}
321
322#[proc_macro_attribute]
323pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
324    http_handler_macro(attr, input, "OPTIONS")
325}
326
327#[proc_macro_attribute]
328pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
329    http_handler_macro(attr, input, "HEAD")
330}
331
332#[proc_macro]
333pub fn embed_dir(input: TokenStream) -> TokenStream {
334    let path = syn::parse_macro_input!(input as syn::LitStr).value();
335    quote! {{
336        #[derive(potato::rust_embed::Embed)]
337        #[folder = #path]
338        struct Asset;
339
340        potato::load_embed::<Asset>()
341    }}
342    .into()
343}
344
345#[proc_macro_derive(StandardHeader)]
346pub fn standard_header_derive(input: TokenStream) -> TokenStream {
347    let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
348    let enum_name = root_enum.ident;
349    let mut try_from_str_items = vec![];
350    let mut to_str_items = vec![];
351    let mut headers_items = vec![];
352    let mut headers_apply_items = vec![];
353    for root_field in root_enum.variants.iter() {
354        let name = root_field.ident.clone();
355        if root_field.fields.iter().next().is_some() {
356            panic!("unsupported enum type");
357        }
358        let str_name = name.to_string().replace("_", "-");
359        let len = str_name.len();
360        try_from_str_items
361            .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
362        to_str_items.push(quote! { Self::#name => #str_name, });
363        headers_items.push(quote! { #name(String), });
364        headers_apply_items
365            .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
366    }
367    let r = quote! {
368        impl #enum_name {
369            pub fn try_from_str(value: &str) -> Option<Self> {
370                match value.len() {
371                    #( #try_from_str_items )*
372                    _ => None,
373                }
374            }
375
376            pub fn to_str(&self) -> &'static str {
377                match self {
378                    #( #to_str_items )*
379                }
380            }
381        }
382
383        pub enum Headers {
384            #( #headers_items )*
385            Custom((String, String)),
386        }
387
388        impl HttpRequest {
389            pub fn apply_header(&mut self, header: Headers) {
390                match header {
391                    #( #headers_apply_items )*
392                    Headers::Custom((k, v)) => self.set_header(&k[..], v),
393                }
394            }
395        }
396    };
397    r.into()
398}