Skip to main content

potato_macro/
lib.rs

1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use utils::StringExt as _;
10
/// Scalar argument types a handler may take by value.
///
/// Handler arguments of these types are read by name from the request's body
/// pairs (falling back to the URL query) and converted with `str::parse`,
/// except `String`, which is used as-is.
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        // text and flags
        "String", "bool",
        // unsigned integers
        "u8", "u16", "u32", "u64", "usize",
        // signed integers
        "i8", "i16", "i32", "i64", "isize",
        // floating point
        "f32", "f64",
    ])
});
19
20fn random_ident() -> Ident {
21    let mut rng = rand::thread_rng();
22    let value = format!("__potato_id_{}", rng.r#gen::<u64>());
23    Ident::new(&value, Span::call_site())
24}
25
26fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
27    let req_name = Ident::new(req_name, Span::call_site());
28    let (route_path, oauth_arg) = {
29        let mut oroute_path = syn::parse::<syn::LitStr>(attr.clone())
30            .ok()
31            .map(|path| path.value());
32        let mut oauth_arg = None;
33        //
34        if oroute_path.is_none() {
35            let http_parser = syn::meta::parser(|meta| {
36                if meta.path.is_ident("path") {
37                    if let Ok(arg) = meta.value() {
38                        if let Ok(route_path) = arg.parse::<syn::LitStr>() {
39                            let route_path = route_path.value();
40                            oroute_path = Some(route_path);
41                        }
42                    }
43                    Ok(())
44                } else if meta.path.is_ident("auth_arg") {
45                    if let Ok(arg) = meta.value() {
46                        if let Ok(tmp_field) = arg.parse::<Ident>() {
47                            oauth_arg = Some(tmp_field.to_string());
48                        }
49                    }
50                    Ok(())
51                } else {
52                    Err(meta.error("unsupported annotation property"))
53                }
54            });
55            syn::parse_macro_input!(attr with http_parser);
56        }
57        if oroute_path.is_none() {
58            panic!("`path` argument is required");
59        }
60        let route_path = oroute_path.unwrap();
61        if !route_path.starts_with('/') {
62            panic!("route path must start with '/'");
63        }
64        (route_path, oauth_arg)
65    };
66    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
67    let doc_show = {
68        let mut doc_show = true;
69        for attr in root_fn.attrs.iter() {
70            if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
71                if let Ok(meta_list) = attr.meta.require_list() {
72                    if meta_list.tokens.to_string() == "hidden" {
73                        doc_show = false;
74                        break;
75                    }
76                }
77            }
78        }
79        doc_show
80    };
81    let doc_auth = oauth_arg.is_some();
82    let doc_summary = {
83        let mut docs = vec![];
84        for attr in root_fn.attrs.iter() {
85            if let Ok(attr) = attr.meta.require_name_value() {
86                if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
87                    let mut doc = attr.value.to_token_stream().to_string();
88                    if doc.starts_with('\"') {
89                        doc.remove(0);
90                        doc.pop();
91                    }
92                    docs.push(doc);
93                }
94            }
95        }
96        if docs.iter().all(|d| d.starts_with(' ')) {
97            for doc in docs.iter_mut() {
98                doc.remove(0);
99            }
100        }
101        docs.join("\n")
102    };
103    let doc_desp = "";
104    let fn_name = root_fn.sig.ident.clone();
105    let is_async = root_fn.sig.asyncness.is_some();
106    let wrap_func_name = random_ident();
107    let mut args = vec![];
108    let mut arg_names = vec![];
109    let mut doc_args = vec![];
110    let mut arg_auth_mark = false;
111    for arg in root_fn.sig.inputs.iter() {
112        if let syn::FnArg::Typed(arg) = arg {
113            let arg_type_str = arg
114                .ty
115                .as_ref()
116                .to_token_stream()
117                .to_string()
118                .type_simplify();
119            let arg_name_str = arg.pat.to_token_stream().to_string();
120            args.push(match &arg_type_str[..] {
121                "& mut HttpRequest" => quote! { req },
122                "PostFile" => {
123                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
124                    quote! {
125                        match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
126                            Some(file) => file,
127                            None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
128                        }
129                    }
130                },
131                arg_type_str if ARG_TYPES.contains(arg_type_str) => {
132                    let is_auth_arg = match oauth_arg.as_ref() {
133                        Some(auth_arg) => auth_arg == &arg_name_str,
134                        None => false,
135                    };
136                    if is_auth_arg {
137                        if arg_type_str != "String" {
138                            panic!("auth_arg argument is must String type");
139                        }
140                        arg_auth_mark = true;
141                        if is_async {
142                            quote! {
143                                match req.headers
144                                    .get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization"))
145                                    .map(|v| v.to_str()) {
146                                    Some(mut auth) => {
147                                        if auth.starts_with("Bearer ") {
148                                            auth = &auth[7..];
149                                        }
150                                        match potato::ServerAuth::jwt_check(&auth).await {
151                                            Ok(payload) => payload,
152                                            Err(err) => return potato::HttpResponse::error(format!("auth failed: {err:?}")),
153                                        }
154                                    }
155                                    None => return potato::HttpResponse::error("miss header : Authorization"),
156                                }
157                            }
158                        } else {
159                            quote! {
160                                match req.headers
161                                    .get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization"))
162                                    .map(|v| v.to_str()) {
163                                    Some(mut auth) => {
164                                        if auth.starts_with("Bearer ") {
165                                            auth = &auth[7..];
166                                        }
167                                        match tokio::task::block_in_place(|| {
168                                            tokio::runtime::Handle::current().block_on(potato::ServerAuth::jwt_check(&auth))
169                                        }) {
170                                            Ok(payload) => payload,
171                                            Err(err) => return potato::HttpResponse::error(format!("auth failed: {err:?}")),
172                                        }
173                                    }
174                                    None => return potato::HttpResponse::error("miss header : Authorization"),
175                                }
176                            }
177                        }
178                    } else {
179                        doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
180                        let mut arg_value = quote! {
181                            match req.body_pairs
182                                .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
183                                .map(|p| p.to_string()) {
184                                Some(val) => val,
185                                None => match req.url_query
186                                    .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
187                                    .map(|p| p.as_str().to_string()) {
188                                    Some(val) => val,
189                                    None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
190                                },
191                            }
192                        };
193                        if arg_type_str != "String" {
194                            arg_value = quote! {
195                                match #arg_value.parse() {
196                                    Ok(val) => val,
197                                    Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
198                                }
199                            }
200                        }
201                        arg_value
202                    }
203                },
204                _ => panic!("unsupported arg type: [{arg_type_str}]"),
205            });
206            arg_names.push(random_ident());
207        } else {
208            panic!("unsupported: {}", arg.to_token_stream().to_string());
209        }
210    }
211    if !arg_auth_mark && doc_auth {
212        panic!("`auth_arg` attribute is must point to an existing argument");
213    }
214    let wrap_func_name2 = random_ident();
215    let ret_type = root_fn
216        .sig
217        .output
218        .to_token_stream()
219        .to_string()
220        .type_simplify();
221    let call_expr = match args.len() {
222        0 => quote! { #fn_name() },
223        1 => quote! {{
224            let #(#arg_names),* = #(#args),*;
225            #fn_name(#(#arg_names),*)
226        }},
227        _ => quote! {{
228            let (#(#arg_names),*) = (#(#args),*);
229            #fn_name(#(#arg_names),*)
230        }},
231    };
232    let wrap_func_body = if is_async {
233        match &ret_type[..] {
234            "Result<()>" => quote! {
235                match #call_expr.await {
236                    Ok(_) => potato::HttpResponse::text("ok"),
237                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
238                }
239            },
240            "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
241                match #call_expr.await {
242                    Ok(ret) => ret,
243                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
244                }
245            },
246            "()" => quote! {
247                #call_expr.await;
248                potato::HttpResponse::text("ok")
249            },
250            "HttpResponse" => quote! {
251                #call_expr.await
252            },
253            _ => panic!("unsupported ret type: {ret_type}"),
254        }
255    } else {
256        match &ret_type[..] {
257            "Result<()>" => quote! {
258                match #call_expr {
259                    Ok(_) => potato::HttpResponse::text("ok"),
260                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
261                }
262            },
263            "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
264                match #call_expr {
265                    Ok(ret) => ret,
266                    Err(err) => potato::HttpResponse::error(format!("{err:?}")),
267                }
268            },
269            "()" => quote! {
270                #call_expr;
271                potato::HttpResponse::text("ok")
272            },
273            "HttpResponse" => quote! {
274                #call_expr
275            },
276            _ => panic!("unsupported ret type: {ret_type}"),
277        }
278    };
279    let doc_args = serde_json::to_string(&doc_args).unwrap();
280    if is_async {
281        quote! {
282            #root_fn
283
284            #[doc(hidden)]
285            async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
286                #wrap_func_body
287            }
288
289            #[doc(hidden)]
290            fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
291                Box::pin(#wrap_func_name2(req))
292            }
293
294            potato::inventory::submit!{potato::RequestHandlerFlag::new(
295                potato::HttpMethod::#req_name,
296                #route_path,
297                potato::HttpHandler::Async(#wrap_func_name),
298                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args)
299            )}
300        }
301        .into()
302    } else {
303        quote! {
304            #root_fn
305
306            #[doc(hidden)]
307            fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
308                #wrap_func_body
309            }
310
311            potato::inventory::submit!{potato::RequestHandlerFlag::new(
312                potato::HttpMethod::#req_name,
313                #route_path,
314                potato::HttpHandler::Sync(#wrap_func_name2),
315                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args)
316            )}
317        }
318        .into()
319    }
320    //}.to_string();
321    //panic!("{content}");
322    //todo!()
323}
324
325#[proc_macro_attribute]
326pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
327    http_handler_macro(attr, input, "GET")
328}
329
330#[proc_macro_attribute]
331pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
332    http_handler_macro(attr, input, "POST")
333}
334
335#[proc_macro_attribute]
336pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
337    http_handler_macro(attr, input, "PUT")
338}
339
340#[proc_macro_attribute]
341pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
342    http_handler_macro(attr, input, "DELETE")
343}
344
345#[proc_macro_attribute]
346pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
347    http_handler_macro(attr, input, "OPTIONS")
348}
349
350#[proc_macro_attribute]
351pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
352    http_handler_macro(attr, input, "HEAD")
353}
354
355#[proc_macro]
356pub fn embed_dir(input: TokenStream) -> TokenStream {
357    let path = syn::parse_macro_input!(input as syn::LitStr).value();
358    quote! {{
359        #[derive(potato::rust_embed::Embed)]
360        #[folder = #path]
361        struct Asset;
362
363        potato::load_embed::<Asset>()
364    }}
365    .into()
366}
367
368#[proc_macro_derive(StandardHeader)]
369pub fn standard_header_derive(input: TokenStream) -> TokenStream {
370    let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
371    let enum_name = root_enum.ident;
372    let mut try_from_str_items = vec![];
373    let mut to_str_items = vec![];
374    let mut headers_items = vec![];
375    let mut headers_apply_items = vec![];
376    for root_field in root_enum.variants.iter() {
377        let name = root_field.ident.clone();
378        if root_field.fields.iter().next().is_some() {
379            panic!("unsupported enum type");
380        }
381        let str_name = name.to_string().replace("_", "-");
382        let len = str_name.len();
383        try_from_str_items
384            .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
385        to_str_items.push(quote! { Self::#name => #str_name, });
386        headers_items.push(quote! { #name(String), });
387        headers_apply_items
388            .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
389    }
390    let r = quote! {
391        impl #enum_name {
392            pub fn try_from_str(value: &str) -> Option<Self> {
393                match value.len() {
394                    #( #try_from_str_items )*
395                    _ => None,
396                }
397            }
398
399            pub fn to_str(&self) -> &'static str {
400                match self {
401                    #( #to_str_items )*
402                }
403            }
404        }
405
406        pub enum Headers {
407            #( #headers_items )*
408            Custom((String, String)),
409        }
410
411        impl HttpRequest {
412            pub fn apply_header(&mut self, header: Headers) {
413                match header {
414                    #( #headers_apply_items )*
415                    Headers::Custom((k, v)) => self.set_header(&k[..], v),
416                }
417            }
418        }
419    };
420    r.into()
421}