1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use utils::StringExt as _;
10
/// Primitive handler parameter types (besides `&mut HttpRequest` and `PostFile`)
/// accepted by the `http_*` attribute macros. Values of these types are read from
/// the request's body pairs or URL query; everything except `String` is then
/// converted with `str::parse`.
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "String", // text, passed through unparsed
        "bool",   // boolean
        "u8", "u16", "u32", "u64", "usize", // unsigned integers
        "i8", "i16", "i32", "i64", "isize", // signed integers
        "f32", "f64", // floats
    ])
});
19
20fn random_ident() -> Ident {
21 let mut rng = rand::thread_rng();
22 let value = format!("__potato_id_{}", rng.r#gen::<u64>());
23 Ident::new(&value, Span::call_site())
24}
25
26fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
27 let req_name = Ident::new(req_name, Span::call_site());
28 let (route_path, oauth_arg) = {
29 let mut oroute_path = syn::parse::<syn::LitStr>(attr.clone())
30 .ok()
31 .map(|path| path.value());
32 let mut oauth_arg = None;
33 if oroute_path.is_none() {
35 let http_parser = syn::meta::parser(|meta| {
36 if meta.path.is_ident("path") {
37 if let Ok(arg) = meta.value() {
38 if let Ok(route_path) = arg.parse::<syn::LitStr>() {
39 let route_path = route_path.value();
40 oroute_path = Some(route_path);
41 }
42 }
43 Ok(())
44 } else if meta.path.is_ident("auth_arg") {
45 if let Ok(arg) = meta.value() {
46 if let Ok(tmp_field) = arg.parse::<Ident>() {
47 oauth_arg = Some(tmp_field.to_string());
48 }
49 }
50 Ok(())
51 } else {
52 Err(meta.error("unsupported annotation property"))
53 }
54 });
55 syn::parse_macro_input!(attr with http_parser);
56 }
57 if oroute_path.is_none() {
58 panic!("`path` argument is required");
59 }
60 let route_path = oroute_path.unwrap();
61 if !route_path.starts_with('/') {
62 panic!("route path must start with '/'");
63 }
64 (route_path, oauth_arg)
65 };
66 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
67 let doc_show = {
68 let mut doc_show = true;
69 for attr in root_fn.attrs.iter() {
70 if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
71 if let Ok(meta_list) = attr.meta.require_list() {
72 if meta_list.tokens.to_string() == "hidden" {
73 doc_show = false;
74 break;
75 }
76 }
77 }
78 }
79 doc_show
80 };
81 let doc_auth = oauth_arg.is_some();
82 let doc_summary = {
83 let mut docs = vec![];
84 for attr in root_fn.attrs.iter() {
85 if let Ok(attr) = attr.meta.require_name_value() {
86 if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
87 let mut doc = attr.value.to_token_stream().to_string();
88 if doc.starts_with('\"') {
89 doc.remove(0);
90 doc.pop();
91 }
92 docs.push(doc);
93 }
94 }
95 }
96 if docs.iter().all(|d| d.starts_with(' ')) {
97 for doc in docs.iter_mut() {
98 doc.remove(0);
99 }
100 }
101 docs.join("\n")
102 };
103 let doc_desp = "";
104 let fn_name = root_fn.sig.ident.clone();
105 let is_async = root_fn.sig.asyncness.is_some();
106 let wrap_func_name = random_ident();
107 let mut args = vec![];
108 let mut arg_names = vec![];
109 let mut doc_args = vec![];
110 let mut arg_auth_mark = false;
111 for arg in root_fn.sig.inputs.iter() {
112 if let syn::FnArg::Typed(arg) = arg {
113 let arg_type_str = arg
114 .ty
115 .as_ref()
116 .to_token_stream()
117 .to_string()
118 .type_simplify();
119 let arg_name_str = arg.pat.to_token_stream().to_string();
120 args.push(match &arg_type_str[..] {
121 "& mut HttpRequest" => quote! { req },
122 "PostFile" => {
123 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
124 quote! {
125 match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
126 Some(file) => file,
127 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
128 }
129 }
130 },
131 arg_type_str if ARG_TYPES.contains(arg_type_str) => {
132 let is_auth_arg = match oauth_arg.as_ref() {
133 Some(auth_arg) => auth_arg == &arg_name_str,
134 None => false,
135 };
136 if is_auth_arg {
137 if arg_type_str != "String" {
138 panic!("auth_arg argument is must String type");
139 }
140 arg_auth_mark = true;
141 if is_async {
142 quote! {
143 match req.headers
144 .get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization"))
145 .map(|v| v.to_str()) {
146 Some(mut auth) => {
147 if auth.starts_with("Bearer ") {
148 auth = &auth[7..];
149 }
150 match potato::ServerAuth::jwt_check(&auth).await {
151 Ok(payload) => payload,
152 Err(err) => return potato::HttpResponse::error(format!("auth failed: {err:?}")),
153 }
154 }
155 None => return potato::HttpResponse::error("miss header : Authorization"),
156 }
157 }
158 } else {
159 quote! {
160 match req.headers
161 .get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization"))
162 .map(|v| v.to_str()) {
163 Some(mut auth) => {
164 if auth.starts_with("Bearer ") {
165 auth = &auth[7..];
166 }
167 match tokio::task::block_in_place(|| {
168 tokio::runtime::Handle::current().block_on(potato::ServerAuth::jwt_check(&auth))
169 }) {
170 Ok(payload) => payload,
171 Err(err) => return potato::HttpResponse::error(format!("auth failed: {err:?}")),
172 }
173 }
174 None => return potato::HttpResponse::error("miss header : Authorization"),
175 }
176 }
177 }
178 } else {
179 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
180 let mut arg_value = quote! {
181 match req.body_pairs
182 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
183 .map(|p| p.to_string()) {
184 Some(val) => val,
185 None => match req.url_query
186 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
187 .map(|p| p.as_str().to_string()) {
188 Some(val) => val,
189 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
190 },
191 }
192 };
193 if arg_type_str != "String" {
194 arg_value = quote! {
195 match #arg_value.parse() {
196 Ok(val) => val,
197 Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
198 }
199 }
200 }
201 arg_value
202 }
203 },
204 _ => panic!("unsupported arg type: [{arg_type_str}]"),
205 });
206 arg_names.push(random_ident());
207 } else {
208 panic!("unsupported: {}", arg.to_token_stream().to_string());
209 }
210 }
211 if !arg_auth_mark && doc_auth {
212 panic!("`auth_arg` attribute is must point to an existing argument");
213 }
214 let wrap_func_name2 = random_ident();
215 let ret_type = root_fn
216 .sig
217 .output
218 .to_token_stream()
219 .to_string()
220 .type_simplify();
221 let call_expr = match args.len() {
222 0 => quote! { #fn_name() },
223 1 => quote! {{
224 let #(#arg_names),* = #(#args),*;
225 #fn_name(#(#arg_names),*)
226 }},
227 _ => quote! {{
228 let (#(#arg_names),*) = (#(#args),*);
229 #fn_name(#(#arg_names),*)
230 }},
231 };
232 let wrap_func_body = if is_async {
233 match &ret_type[..] {
234 "Result<()>" => quote! {
235 match #call_expr.await {
236 Ok(_) => potato::HttpResponse::text("ok"),
237 Err(err) => potato::HttpResponse::error(format!("{err:?}")),
238 }
239 },
240 "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
241 match #call_expr.await {
242 Ok(ret) => ret,
243 Err(err) => potato::HttpResponse::error(format!("{err:?}")),
244 }
245 },
246 "()" => quote! {
247 #call_expr.await;
248 potato::HttpResponse::text("ok")
249 },
250 "HttpResponse" => quote! {
251 #call_expr.await
252 },
253 _ => panic!("unsupported ret type: {ret_type}"),
254 }
255 } else {
256 match &ret_type[..] {
257 "Result<()>" => quote! {
258 match #call_expr {
259 Ok(_) => potato::HttpResponse::text("ok"),
260 Err(err) => potato::HttpResponse::error(format!("{err:?}")),
261 }
262 },
263 "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
264 match #call_expr {
265 Ok(ret) => ret,
266 Err(err) => potato::HttpResponse::error(format!("{err:?}")),
267 }
268 },
269 "()" => quote! {
270 #call_expr;
271 potato::HttpResponse::text("ok")
272 },
273 "HttpResponse" => quote! {
274 #call_expr
275 },
276 _ => panic!("unsupported ret type: {ret_type}"),
277 }
278 };
279 let doc_args = serde_json::to_string(&doc_args).unwrap();
280 if is_async {
281 quote! {
282 #root_fn
283
284 #[doc(hidden)]
285 async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
286 #wrap_func_body
287 }
288
289 #[doc(hidden)]
290 fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
291 Box::pin(#wrap_func_name2(req))
292 }
293
294 potato::inventory::submit!{potato::RequestHandlerFlag::new(
295 potato::HttpMethod::#req_name,
296 #route_path,
297 potato::HttpHandler::Async(#wrap_func_name),
298 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args)
299 )}
300 }
301 .into()
302 } else {
303 quote! {
304 #root_fn
305
306 #[doc(hidden)]
307 fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
308 #wrap_func_body
309 }
310
311 potato::inventory::submit!{potato::RequestHandlerFlag::new(
312 potato::HttpMethod::#req_name,
313 #route_path,
314 potato::HttpHandler::Sync(#wrap_func_name2),
315 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args)
316 )}
317 }
318 .into()
319 }
320 }
324
325#[proc_macro_attribute]
326pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
327 http_handler_macro(attr, input, "GET")
328}
329
330#[proc_macro_attribute]
331pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
332 http_handler_macro(attr, input, "POST")
333}
334
335#[proc_macro_attribute]
336pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
337 http_handler_macro(attr, input, "PUT")
338}
339
340#[proc_macro_attribute]
341pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
342 http_handler_macro(attr, input, "DELETE")
343}
344
345#[proc_macro_attribute]
346pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
347 http_handler_macro(attr, input, "OPTIONS")
348}
349
350#[proc_macro_attribute]
351pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
352 http_handler_macro(attr, input, "HEAD")
353}
354
355#[proc_macro]
356pub fn embed_dir(input: TokenStream) -> TokenStream {
357 let path = syn::parse_macro_input!(input as syn::LitStr).value();
358 quote! {{
359 #[derive(potato::rust_embed::Embed)]
360 #[folder = #path]
361 struct Asset;
362
363 potato::load_embed::<Asset>()
364 }}
365 .into()
366}
367
368#[proc_macro_derive(StandardHeader)]
369pub fn standard_header_derive(input: TokenStream) -> TokenStream {
370 let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
371 let enum_name = root_enum.ident;
372 let mut try_from_str_items = vec![];
373 let mut to_str_items = vec![];
374 let mut headers_items = vec![];
375 let mut headers_apply_items = vec![];
376 for root_field in root_enum.variants.iter() {
377 let name = root_field.ident.clone();
378 if root_field.fields.iter().next().is_some() {
379 panic!("unsupported enum type");
380 }
381 let str_name = name.to_string().replace("_", "-");
382 let len = str_name.len();
383 try_from_str_items
384 .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
385 to_str_items.push(quote! { Self::#name => #str_name, });
386 headers_items.push(quote! { #name(String), });
387 headers_apply_items
388 .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
389 }
390 let r = quote! {
391 impl #enum_name {
392 pub fn try_from_str(value: &str) -> Option<Self> {
393 match value.len() {
394 #( #try_from_str_items )*
395 _ => None,
396 }
397 }
398
399 pub fn to_str(&self) -> &'static str {
400 match self {
401 #( #to_str_items )*
402 }
403 }
404 }
405
406 pub enum Headers {
407 #( #headers_items )*
408 Custom((String, String)),
409 }
410
411 impl HttpRequest {
412 pub fn apply_header(&mut self, header: Headers) {
413 match header {
414 #( #headers_apply_items )*
415 Headers::Custom((k, v)) => self.set_header(&k[..], v),
416 }
417 }
418 }
419 };
420 r.into()
421}