1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{format_ident, quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use syn::Token;
10use utils::StringExt as _;
11
/// Options collected from a `cors` attribute on a handler function.
///
/// Every `Option` field is `None` when the corresponding key was not
/// supplied in the attribute; `credentials` is `false` in that case.
/// `Default` yields exactly that "nothing supplied" configuration, so callers
/// can write `CorsAttrConfig::default()` or use struct-update syntax instead
/// of spelling out every field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct CorsAttrConfig {
    /// Value of the `origin = "..."` key, if given.
    origin: Option<String>,
    /// Value of the `methods = "..."` key, if given.
    methods: Option<String>,
    /// Value of the `headers = "..."` key, if given.
    headers: Option<String>,
    /// Value of the `max_age = "..."` key, kept as the raw string.
    max_age: Option<String>,
    /// Value of the `credentials = <bool>` key; `false` when absent.
    credentials: bool,
    /// Value of the `expose_headers = "..."` key, if given.
    expose_headers: Option<String>,
}
21
22fn parse_cors_attr(tokens: &proc_macro2::TokenStream) -> CorsAttrConfig {
24 let config = CorsAttrConfig {
25 origin: None,
26 methods: None,
27 headers: None,
28 max_age: None,
29 credentials: false,
30 expose_headers: None,
31 };
32
33 if tokens.is_empty() {
34 return config; }
36
37 use syn::parse::Parser;
39
40 fn parse_inner(input: syn::parse::ParseStream) -> syn::Result<CorsAttrConfig> {
41 let mut config = CorsAttrConfig {
42 origin: None,
43 methods: None,
44 headers: None,
45 max_age: None,
46 credentials: false,
47 expose_headers: None,
48 };
49
50 let vars =
51 syn::punctuated::Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
52 for meta in vars {
53 let key = meta
54 .path
55 .get_ident()
56 .map(|i| i.to_string())
57 .unwrap_or_default();
58 if let syn::Expr::Lit(expr_lit) = &meta.value {
59 match &expr_lit.lit {
60 syn::Lit::Str(s) => {
61 let val = s.value();
62 match key.as_str() {
63 "origin" => config.origin = Some(val),
64 "methods" => config.methods = Some(val),
65 "headers" => config.headers = Some(val),
66 "max_age" => config.max_age = Some(val),
67 "expose_headers" => config.expose_headers = Some(val),
68 _ => {}
69 }
70 }
71 syn::Lit::Bool(b) => {
72 if key == "credentials" {
73 config.credentials = b.value();
74 }
75 }
76 _ => {}
77 }
78 }
79 }
80 Ok(config)
81 }
82
83 match parse_inner.parse2(tokens.clone()) {
84 Ok(cfg) => cfg,
85 Err(e) => panic!("Failed to parse cors attributes: {e}"),
86 }
87}
88
/// Names of the scalar types recognised for handler arguments: `String`,
/// `bool`, and the listed integer and floating-point primitives. The set is
/// built lazily on first access.
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "String", "bool", "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
        "f32", "f64",
    ])
});
97
98fn validate_controller_struct(item_struct: &syn::ItemStruct) -> (bool, bool) {
101 let mut has_once_cache = false;
102 let mut has_session_cache = false;
103
104 if let syn::Fields::Named(fields_named) = &item_struct.fields {
105 for field in &fields_named.named {
106 let field_type_str = field.ty.to_token_stream().to_string().type_simplify();
107
108 if field_type_str.contains("OnceCache") {
110 has_once_cache = true;
111 } else if field_type_str.contains("SessionCache") {
112 has_session_cache = true;
113 } else {
114 panic!(
115 "Controller field must be &OnceCache or &SessionCache, got: {}",
116 field_type_str
117 );
118 }
119 }
120 }
121
122 (has_once_cache, has_session_cache)
123}
124
125fn parse_header_attr(tokens: &proc_macro2::TokenStream) -> Result<(String, String), syn::Error> {
127 use syn::parse::Parser;
128
129 let parser = |input: syn::parse::ParseStream| {
130 let key_ident: Ident = input.parse()?;
134 let key_name = key_ident.to_string();
135
136 if key_name == "Custom" {
137 let content;
139 syn::parenthesized!(content in input);
140 let key_lit: syn::LitStr = content.parse()?;
141 let key = key_lit.value();
142 let _: Token![=] = input.parse()?;
143 let value: syn::LitStr = input.parse()?;
144 Ok((key, value.value()))
145 } else {
146 let _: Token![=] = input.parse()?;
148 let value: syn::LitStr = input.parse()?;
149 Ok((key_name, value.value()))
150 }
151 };
152
153 parser.parse2(tokens.clone())
154}
155
156fn random_ident() -> Ident {
157 let mut rng = rand::thread_rng();
158 let value = format!("__potato_id_{}", rng.r#gen::<u64>());
159 Ident::new(&value, Span::call_site())
160}
161
162fn attr_last_ident(attr: &syn::Attribute) -> Option<String> {
163 attr.meta
164 .path()
165 .segments
166 .iter()
167 .last()
168 .map(|segment| segment.ident.to_string())
169}
170
171fn parse_hook_attr_items(attr: &syn::Attribute, attr_name: &str) -> Vec<Ident> {
172 let parser = syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated;
173 let idents = attr.parse_args_with(parser).unwrap_or_else(|err| {
174 panic!("invalid `{attr_name}` annotation: {err}");
175 });
176 if idents.is_empty() {
177 panic!("`{attr_name}` annotation requires at least one function name");
178 }
179 idents.into_iter().collect()
180}
181
182fn collect_handler_hooks(root_fn: &mut syn::ItemFn) -> (Vec<Ident>, Vec<Ident>) {
183 enum HookKind {
184 Pre,
185 Post,
186 }
187 let mut hooks = vec![];
188 let mut new_attrs = Vec::with_capacity(root_fn.attrs.len());
189 for attr in root_fn.attrs.iter() {
190 match attr_last_ident(attr).as_deref() {
191 Some("preprocess") => {
192 hooks.extend(
193 parse_hook_attr_items(attr, "preprocess")
194 .into_iter()
195 .map(|item| (HookKind::Pre, item)),
196 );
197 }
198 Some("postprocess") => {
199 hooks.extend(
200 parse_hook_attr_items(attr, "postprocess")
201 .into_iter()
202 .map(|item| (HookKind::Post, item)),
203 );
204 }
205 _ => new_attrs.push(attr.clone()),
206 }
207 }
208 root_fn.attrs = new_attrs;
209 let mut preprocess_fns = vec![];
210 let mut postprocess_fns = vec![];
211 for (kind, hook) in hooks.into_iter() {
212 match kind {
213 HookKind::Pre => preprocess_fns.push(hook),
214 HookKind::Post => postprocess_fns.push(hook),
215 }
216 }
217 (preprocess_fns, postprocess_fns)
218}
219
220fn validate_preprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
221 if root_fn.sig.inputs.is_empty() || root_fn.sig.inputs.len() > 3 {
222 panic!("`preprocess` function must accept one to three arguments");
223 }
224 let mut arg_types = vec![];
225 for arg in root_fn.sig.inputs.iter() {
226 match arg {
227 syn::FnArg::Typed(arg) => {
228 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
229 }
230 _ => panic!("`preprocess` function does not support receiver argument"),
231 }
232 }
233 if arg_types[0] != "& mut HttpRequest" {
234 panic!(
235 "`preprocess` first argument type must be `&mut potato::HttpRequest`, got `{}`",
236 arg_types[0]
237 );
238 }
239
240 let has_once_cache = arg_types.iter().any(|t| t == "& mut OnceCache");
241 let has_session_cache = arg_types.iter().any(|t| t == "& mut SessionCache");
242
243 if arg_types.len() == 2 && !has_once_cache && !has_session_cache {
244 panic!(
245 "`preprocess` second argument type must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
246 arg_types[1]
247 );
248 }
249 if arg_types.len() == 3 {
250 if !has_once_cache {
251 panic!("`preprocess` must have `&mut potato::OnceCache` as one of the arguments");
252 }
253 if !has_session_cache {
254 panic!("`preprocess` must have `&mut potato::SessionCache` as one of the arguments");
255 }
256 }
257
258 let ret_type = root_fn
259 .sig
260 .output
261 .to_token_stream()
262 .to_string()
263 .type_simplify();
264 match &ret_type[..] {
265 "Result<Option<HttpResponse>>" | "Option<HttpResponse>" | "Result<()>" | "()" => {}
266 _ => panic!(
267 "unsupported `preprocess` return type: `{ret_type}`, expected `anyhow::Result<Option<potato::HttpResponse>>`, `Option<potato::HttpResponse>`, `anyhow::Result<()>`, or `()`"
268 ),
269 }
270 (ret_type, has_once_cache, has_session_cache)
271}
272
273fn validate_postprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
274 if root_fn.sig.inputs.len() < 2 && root_fn.sig.inputs.len() > 4 {
275 panic!("`postprocess` function must accept two to four arguments");
276 }
277 let mut arg_types = vec![];
278 for arg in root_fn.sig.inputs.iter() {
279 match arg {
280 syn::FnArg::Typed(arg) => {
281 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
282 }
283 _ => panic!("`postprocess` function does not support receiver argument"),
284 }
285 }
286 if arg_types[0] != "& mut HttpRequest" {
287 panic!(
288 "`postprocess` first argument must be `&mut potato::HttpRequest`, got `{}`",
289 arg_types[0]
290 );
291 }
292 if arg_types[1] != "& mut HttpResponse" {
293 panic!(
294 "`postprocess` second argument must be `&mut potato::HttpResponse`, got `{}`",
295 arg_types[1]
296 );
297 }
298
299 let remaining_args = &arg_types[2..];
300 let has_once_cache = remaining_args.iter().any(|t| t == "& mut OnceCache");
301 let has_session_cache = remaining_args.iter().any(|t| t == "& mut SessionCache");
302
303 if arg_types.len() == 3 && !has_once_cache && !has_session_cache {
304 panic!(
305 "`postprocess` third argument must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
306 arg_types[2]
307 );
308 }
309 if arg_types.len() == 4 && (!has_once_cache || !has_session_cache) {
310 panic!(
311 "`postprocess` with 4 arguments must have both `&mut potato::OnceCache` and `&mut potato::SessionCache`"
312 );
313 }
314
315 let ret_type = root_fn
316 .sig
317 .output
318 .to_token_stream()
319 .to_string()
320 .type_simplify();
321 match &ret_type[..] {
322 "Result<()>" | "()" => {}
323 _ => panic!(
324 "unsupported `postprocess` return type: `{ret_type}`, expected `anyhow::Result<()>` or `()`"
325 ),
326 }
327 (ret_type, has_once_cache, has_session_cache)
328}
329
330fn preprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
331 if !attr.is_empty() {
332 return input;
333 }
334 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
335 let fn_name = root_fn.sig.ident.clone();
336 let wrap_name = format_ident!("__potato_preprocess_adapter_{}", fn_name);
337 let wrap_name_inner = format_ident!("__potato_preprocess_adapter_inner_{}", fn_name);
338 let is_async = root_fn.sig.asyncness.is_some();
339 let (ret_type, has_once_cache, has_session_cache) = validate_preprocess_signature(&root_fn);
340
341 let wrap_signature = match (has_once_cache, has_session_cache) {
343 (true, true) => quote! {
344 async fn #wrap_name_inner(
345 req: &mut potato::HttpRequest,
346 once_cache: &mut potato::OnceCache,
347 session_cache: &mut potato::SessionCache,
348 ) -> anyhow::Result<Option<potato::HttpResponse>>
349 },
350 (true, false) => quote! {
351 async fn #wrap_name_inner(
352 req: &mut potato::HttpRequest,
353 once_cache: &mut potato::OnceCache,
354 ) -> anyhow::Result<Option<potato::HttpResponse>>
355 },
356 (false, true) => quote! {
357 async fn #wrap_name_inner(
358 req: &mut potato::HttpRequest,
359 session_cache: &mut potato::SessionCache,
360 ) -> anyhow::Result<Option<potato::HttpResponse>>
361 },
362 (false, false) => quote! {
363 async fn #wrap_name_inner(
364 req: &mut potato::HttpRequest,
365 ) -> anyhow::Result<Option<potato::HttpResponse>>
366 },
367 };
368
369 let call_body = if is_async {
371 match &ret_type[..] {
372 "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
373 (true, true) => {
374 quote! { #fn_name(req, once_cache, session_cache).await }
375 }
376 (true, false) => quote! { #fn_name(req, once_cache).await },
377 (false, true) => quote! { #fn_name(req, session_cache).await },
378 (false, false) => quote! { #fn_name(req).await },
379 },
380 "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
381 (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache).await) },
382 (true, false) => quote! { Ok(#fn_name(req, once_cache).await) },
383 (false, true) => quote! { Ok(#fn_name(req, session_cache).await) },
384 (false, false) => quote! { Ok(#fn_name(req).await) },
385 },
386 "Result<()>" => match (has_once_cache, has_session_cache) {
387 (true, true) => {
388 quote! { #fn_name(req, once_cache, session_cache).await.map(|_| None) }
389 }
390 (true, false) => quote! { #fn_name(req, once_cache).await.map(|_| None) },
391 (false, true) => quote! { #fn_name(req, session_cache).await.map(|_| None) },
392 (false, false) => quote! { #fn_name(req).await.map(|_| None) },
393 },
394 "()" => match (has_once_cache, has_session_cache) {
395 (true, true) => quote! { #fn_name(req, once_cache, session_cache).await; Ok(None) },
396 (true, false) => quote! { #fn_name(req, once_cache).await; Ok(None) },
397 (false, true) => quote! { #fn_name(req, session_cache).await; Ok(None) },
398 (false, false) => quote! { #fn_name(req).await; Ok(None) },
399 },
400 _ => unreachable!(),
401 }
402 } else {
403 match &ret_type[..] {
404 "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
405 (true, true) => quote! { #fn_name(req, once_cache, session_cache) },
406 (true, false) => quote! { #fn_name(req, once_cache) },
407 (false, true) => quote! { #fn_name(req, session_cache) },
408 (false, false) => quote! { #fn_name(req) },
409 },
410 "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
411 (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache)) },
412 (true, false) => quote! { Ok(#fn_name(req, once_cache)) },
413 (false, true) => quote! { Ok(#fn_name(req, session_cache)) },
414 (false, false) => quote! { Ok(#fn_name(req)) },
415 },
416 "Result<()>" => match (has_once_cache, has_session_cache) {
417 (true, true) => quote! { #fn_name(req, once_cache, session_cache).map(|_| None) },
418 (true, false) => quote! { #fn_name(req, once_cache).map(|_| None) },
419 (false, true) => quote! { #fn_name(req, session_cache).map(|_| None) },
420 (false, false) => quote! { #fn_name(req).map(|_| None) },
421 },
422 "()" => match (has_once_cache, has_session_cache) {
423 (true, true) => quote! { #fn_name(req, once_cache, session_cache); Ok(None) },
424 (true, false) => quote! { #fn_name(req, once_cache); Ok(None) },
425 (false, true) => quote! { #fn_name(req, session_cache); Ok(None) },
426 (false, false) => quote! { #fn_name(req); Ok(None) },
427 },
428 _ => unreachable!(),
429 }
430 };
431
432 let wrapper_body = match (has_once_cache, has_session_cache) {
434 (true, true) => quote! {
435 #wrap_name_inner(
436 req,
437 once_cache.expect("OnceCache required but not provided"),
438 session_cache.expect("SessionCache required but not provided"),
439 ).await
440 },
441 (true, false) => quote! {
442 #wrap_name_inner(
443 req,
444 once_cache.expect("OnceCache required but not provided"),
445 ).await
446 },
447 (false, true) => quote! {
448 #wrap_name_inner(
449 req,
450 session_cache.expect("SessionCache required but not provided"),
451 ).await
452 },
453 (false, false) => quote! {
454 #wrap_name_inner(req).await
455 },
456 };
457
458 quote! {
459 #root_fn
460
461 #[doc(hidden)]
462 #wrap_signature {
463 #call_body
464 }
465
466 #[doc(hidden)]
467 pub fn #wrap_name<'a>(
468 req: &'a mut potato::HttpRequest,
469 once_cache: Option<&'a mut potato::OnceCache>,
470 session_cache: Option<&'a mut potato::SessionCache>,
471 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<Option<potato::HttpResponse>>> + Send + 'a>> {
472 Box::pin(async move {
473 #wrapper_body
474 })
475 }
476 }
477 .into()
478}
479
480fn postprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
481 if !attr.is_empty() {
482 return input;
483 }
484 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
485 let fn_name = root_fn.sig.ident.clone();
486 let wrap_name = format_ident!("__potato_postprocess_adapter_{}", fn_name);
487 let wrap_name_inner = format_ident!("__potato_postprocess_adapter_inner_{}", fn_name);
488 let is_async = root_fn.sig.asyncness.is_some();
489 let (ret_type, has_once_cache, has_session_cache) = validate_postprocess_signature(&root_fn);
490
491 let wrap_signature = match (has_once_cache, has_session_cache) {
493 (true, true) => quote! {
494 async fn #wrap_name_inner(
495 req: &mut potato::HttpRequest,
496 res: &mut potato::HttpResponse,
497 once_cache: &mut potato::OnceCache,
498 session_cache: &mut potato::SessionCache,
499 ) -> anyhow::Result<()>
500 },
501 (true, false) => quote! {
502 async fn #wrap_name_inner(
503 req: &mut potato::HttpRequest,
504 res: &mut potato::HttpResponse,
505 once_cache: &mut potato::OnceCache,
506 ) -> anyhow::Result<()>
507 },
508 (false, true) => quote! {
509 async fn #wrap_name_inner(
510 req: &mut potato::HttpRequest,
511 res: &mut potato::HttpResponse,
512 session_cache: &mut potato::SessionCache,
513 ) -> anyhow::Result<()>
514 },
515 (false, false) => quote! {
516 async fn #wrap_name_inner(
517 req: &mut potato::HttpRequest,
518 res: &mut potato::HttpResponse,
519 ) -> anyhow::Result<()>
520 },
521 };
522
523 let call_body = if is_async {
525 match &ret_type[..] {
526 "Result<()>" => {
527 if has_once_cache && has_session_cache {
528 quote! {
529 #fn_name(req, res, once_cache, session_cache).await
530 }
531 } else if has_once_cache {
532 quote! {
533 #fn_name(req, res, once_cache).await
534 }
535 } else if has_session_cache {
536 quote! {
537 #fn_name(req, res, session_cache).await
538 }
539 } else {
540 quote! {
541 #fn_name(req, res).await
542 }
543 }
544 }
545 "()" => {
546 if has_once_cache && has_session_cache {
547 quote! {
548 #fn_name(req, res, once_cache, session_cache).await;
549 Ok(())
550 }
551 } else if has_once_cache {
552 quote! {
553 #fn_name(req, res, once_cache).await;
554 Ok(())
555 }
556 } else if has_session_cache {
557 quote! {
558 #fn_name(req, res, session_cache).await;
559 Ok(())
560 }
561 } else {
562 quote! {
563 #fn_name(req, res).await;
564 Ok(())
565 }
566 }
567 }
568 _ => unreachable!(),
569 }
570 } else {
571 match &ret_type[..] {
572 "Result<()>" => {
573 if has_once_cache && has_session_cache {
574 quote! {
575 #fn_name(req, res, once_cache, session_cache)
576 }
577 } else if has_once_cache {
578 quote! {
579 #fn_name(req, res, once_cache)
580 }
581 } else if has_session_cache {
582 quote! {
583 #fn_name(req, res, session_cache)
584 }
585 } else {
586 quote! {
587 #fn_name(req, res)
588 }
589 }
590 }
591 "()" => {
592 if has_once_cache && has_session_cache {
593 quote! {
594 #fn_name(req, res, once_cache, session_cache);
595 Ok(())
596 }
597 } else if has_once_cache {
598 quote! {
599 #fn_name(req, res, once_cache);
600 Ok(())
601 }
602 } else if has_session_cache {
603 quote! {
604 #fn_name(req, res, session_cache);
605 Ok(())
606 }
607 } else {
608 quote! {
609 #fn_name(req, res);
610 Ok(())
611 }
612 }
613 }
614 _ => unreachable!(),
615 }
616 };
617
618 let wrapper_body = match (has_once_cache, has_session_cache) {
620 (true, true) => quote! {
621 #wrap_name_inner(
622 req,
623 res,
624 once_cache.expect("OnceCache required but not provided"),
625 session_cache.expect("SessionCache required but not provided"),
626 ).await
627 },
628 (true, false) => quote! {
629 #wrap_name_inner(
630 req,
631 res,
632 once_cache.expect("OnceCache required but not provided"),
633 ).await
634 },
635 (false, true) => quote! {
636 #wrap_name_inner(
637 req,
638 res,
639 session_cache.expect("SessionCache required but not provided"),
640 ).await
641 },
642 (false, false) => quote! {
643 #wrap_name_inner(req, res).await
644 },
645 };
646
647 quote! {
648 #root_fn
649
650 #[doc(hidden)]
651 #wrap_signature {
652 #call_body
653 }
654
655 #[doc(hidden)]
656 pub fn #wrap_name<'a>(
657 req: &'a mut potato::HttpRequest,
658 res: &'a mut potato::HttpResponse,
659 once_cache: Option<&'a mut potato::OnceCache>,
660 session_cache: Option<&'a mut potato::SessionCache>,
661 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>> {
662 Box::pin(async move {
663 #wrapper_body
664 })
665 }
666 }
667 .into()
668}
669
670fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
671 let req_name = Ident::new(req_name, Span::call_site());
672
673 let root_fn_for_check = syn::parse::<syn::ItemFn>(input.clone());
675 let has_receiver = if let Ok(ref func) = root_fn_for_check {
676 func.sig
677 .inputs
678 .iter()
679 .any(|arg| matches!(arg, syn::FnArg::Receiver(_)))
680 } else {
681 false
682 };
683
684 let (route_path, is_send) = {
685 let mut oroute_path: Option<String> = None;
686 let mut is_send = true; let attr_stream: proc_macro2::TokenStream = attr.into();
690 let tokens: Vec<_> = attr_stream.into_iter().collect();
691 let mut i = 0;
692 while i < tokens.len() {
693 let token = &tokens[i];
694 if let proc_macro2::TokenTree::Literal(lit) = token {
696 let lit_str = lit.to_string();
697 if lit_str.starts_with('"') && lit_str.ends_with('"') {
699 let path = &lit_str[1..lit_str.len() - 1];
700 oroute_path = Some(path.to_string());
701 }
702 i += 1;
703 } else if let proc_macro2::TokenTree::Ident(ident) = token {
704 if ident.to_string() == "Send" {
705 is_send = true;
706 i += 1;
707 } else if ident.to_string() == "path" {
708 if i + 2 < tokens.len() {
710 if let proc_macro2::TokenTree::Punct(punct) = &tokens[i + 1] {
711 if punct.as_char() == '=' {
712 if let proc_macro2::TokenTree::Literal(lit) = &tokens[i + 2] {
713 let lit_str = lit.to_string();
714 if lit_str.starts_with('"') && lit_str.ends_with('"') {
715 let path = &lit_str[1..lit_str.len() - 1];
716 oroute_path = Some(path.to_string());
717 }
718 i += 3;
719 continue;
720 }
721 }
722 }
723 }
724 i += 1;
725 } else {
726 i += 1;
727 }
728 } else {
729 i += 1;
730 }
731 }
732
733 if oroute_path.is_none() && has_receiver {
735 } else if oroute_path.is_none() {
737 panic!("`path` argument is required for non-controller methods");
738 }
739
740 let route_path = oroute_path.unwrap_or_default();
741
742 let route_path = if has_receiver {
744 if route_path.is_empty() {
745 String::new()
747 } else {
748 route_path
751 }
752 } else {
753 if route_path.is_empty() {
754 panic!("`path` argument is required for non-controller methods");
755 }
756 route_path
757 };
758
759 if !route_path.is_empty() && !route_path.starts_with('/') {
760 panic!("route path must start with '/'");
761 }
762 (route_path, is_send)
763 };
764
765 let mut root_fn = syn::parse_macro_input!(input as syn::ItemFn);
767 let mut fn_headers: Vec<(String, String)> = Vec::new();
768 let mut cors_config: Option<CorsAttrConfig> = None;
769 let mut max_concurrency: Option<usize> = None;
770 let mut remaining_attrs = Vec::new();
771
772 for attr in root_fn.attrs.iter() {
773 let is_header_attr = attr.path().is_ident("header")
775 || (attr.path().segments.len() == 2
776 && attr
777 .path()
778 .segments
779 .iter()
780 .next()
781 .map(|s| s.ident.to_string())
782 == Some("potato".to_string())
783 && attr
784 .path()
785 .segments
786 .iter()
787 .last()
788 .map(|s| s.ident.to_string())
789 == Some("header".to_string()));
790
791 if is_header_attr {
792 if let syn::Meta::List(meta_list) = &attr.meta {
793 if let Ok((key, value)) = parse_header_attr(&meta_list.tokens) {
795 fn_headers.push((key, value));
796 }
797 }
798 continue;
799 }
800
801 let is_cors_attr = attr.path().is_ident("cors")
803 || (attr.path().segments.len() == 2
804 && attr
805 .path()
806 .segments
807 .iter()
808 .next()
809 .map(|s| s.ident.to_string())
810 == Some("potato".to_string())
811 && attr
812 .path()
813 .segments
814 .iter()
815 .last()
816 .map(|s| s.ident.to_string())
817 == Some("cors".to_string()));
818
819 if is_cors_attr {
820 if let syn::Meta::List(meta_list) = &attr.meta {
821 cors_config = Some(parse_cors_attr(&meta_list.tokens));
822 } else {
823 cors_config = Some(CorsAttrConfig {
825 origin: None,
826 methods: None,
827 headers: None,
828 max_age: None,
829 credentials: false,
830 expose_headers: None,
831 });
832 }
833 continue;
834 }
835
836 let is_max_concurrency_attr = attr.path().is_ident("max_concurrency")
838 || (attr.path().segments.len() == 2
839 && attr
840 .path()
841 .segments
842 .iter()
843 .next()
844 .map(|s| s.ident.to_string())
845 == Some("potato".to_string())
846 && attr
847 .path()
848 .segments
849 .iter()
850 .last()
851 .map(|s| s.ident.to_string())
852 == Some("max_concurrency".to_string()));
853
854 if is_max_concurrency_attr {
855 if let syn::Meta::List(meta_list) = &attr.meta {
856 let tokens = &meta_list.tokens;
857 if let Ok(lit_int) = syn::parse2::<syn::LitInt>(tokens.clone()) {
859 if let Ok(val) = lit_int.base10_parse::<usize>() {
860 if val == 0 {
861 panic!("max_concurrency must be greater than 0");
862 }
863 max_concurrency = Some(val);
864 } else {
865 panic!("invalid max_concurrency value");
866 }
867 } else {
868 panic!(
869 "max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]"
870 );
871 }
872 } else if let syn::Meta::NameValue(name_value) = &attr.meta {
873 if let syn::Expr::Lit(expr_lit) = &name_value.value {
874 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
875 if let Ok(val) = lit_int.base10_parse::<usize>() {
876 if val == 0 {
877 panic!("max_concurrency must be greater than 0");
878 }
879 max_concurrency = Some(val);
880 } else {
881 panic!("invalid max_concurrency value");
882 }
883 } else {
884 panic!("max_concurrency requires a numeric value");
885 }
886 } else {
887 panic!("max_concurrency requires a numeric value");
888 }
889 } else {
890 panic!("max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]");
891 }
892 continue;
893 }
894
895 remaining_attrs.push(attr.clone());
896 }
897
898 let all_headers = fn_headers;
899
900 root_fn.attrs = remaining_attrs;
901 let (preprocess_fns, postprocess_fns) = collect_handler_hooks(&mut root_fn);
902
903 let handler_has_once_cache = root_fn.sig.inputs.iter().any(|arg| {
905 if let syn::FnArg::Typed(arg) = arg {
906 arg.ty.to_token_stream().to_string().type_simplify() == "& mut OnceCache"
907 } else {
908 false
909 }
910 });
911 let handler_has_session_cache = root_fn.sig.inputs.iter().any(|arg| {
912 if let syn::FnArg::Typed(arg) = arg {
913 arg.ty.to_token_stream().to_string().type_simplify() == "& mut SessionCache"
914 } else {
915 false
916 }
917 });
918
919 let need_once_cache = handler_has_once_cache;
924 let need_session_cache = handler_has_session_cache;
925
926 let preprocess_adapters: Vec<Ident> = preprocess_fns
927 .iter()
928 .map(|name| format_ident!("__potato_preprocess_adapter_{}", name))
929 .collect();
930 let postprocess_adapters: Vec<Ident> = postprocess_fns
931 .iter()
932 .map(|name| format_ident!("__potato_postprocess_adapter_{}", name))
933 .collect();
934 let doc_show = {
935 let mut doc_show = true;
936 for attr in root_fn.attrs.iter() {
937 if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
938 if let Ok(meta_list) = attr.meta.require_list() {
939 if meta_list.tokens.to_string() == "hidden" {
940 doc_show = false;
941 break;
942 }
943 }
944 }
945 }
946 doc_show
947 };
948 let doc_auth = need_session_cache;
949 let doc_summary = {
950 let mut docs = vec![];
951 for attr in root_fn.attrs.iter() {
952 if let Ok(attr) = attr.meta.require_name_value() {
953 if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
954 let mut doc = attr.value.to_token_stream().to_string();
955 if doc.starts_with('\"') {
956 doc.remove(0);
957 doc.pop();
958 }
959 docs.push(doc);
960 }
961 }
962 }
963 if docs.iter().all(|d| d.starts_with(' ')) {
964 for doc in docs.iter_mut() {
965 doc.remove(0);
966 }
967 }
968 docs.join("\n")
969 };
970 let doc_desp = "";
971 let fn_name = root_fn.sig.ident.clone();
972 let is_async = root_fn.sig.asyncness.is_some();
973
974 let has_receiver = root_fn
976 .sig
977 .inputs
978 .iter()
979 .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
980
981 let final_path = if has_receiver {
986 if route_path.is_empty() {
989 panic!("Controller methods must specify a path (e.g., #[potato::http_get(\"/\")])");
992 } else {
993 route_path
994 }
995 } else {
996 if route_path.is_empty() {
997 panic!("`path` argument is required for non-controller methods");
998 }
999 route_path
1000 };
1001
1002 let final_path_expr = quote! { #final_path };
1003
1004 let tag_expr = if has_receiver {
1006 quote! { "" }
1009 } else {
1010 quote! { "" }
1011 };
1012
1013 let wrap_func_name = random_ident();
1014 let mut args = vec![];
1015 let mut arg_names = vec![];
1016 let mut arg_types = vec![];
1017 let mut doc_args = vec![];
1018 for arg in root_fn.sig.inputs.iter() {
1019 if let syn::FnArg::Receiver(_receiver) = arg {
1021 continue;
1024 }
1025
1026 if let syn::FnArg::Typed(arg) = arg {
1027 let arg_type_str = arg
1028 .ty
1029 .as_ref()
1030 .to_token_stream()
1031 .to_string()
1032 .type_simplify();
1033 let arg_name_str = arg.pat.to_token_stream().to_string();
1034 let arg_value = match &arg_type_str[..] {
1035 "& mut HttpRequest" => quote! { req },
1036 "& mut OnceCache" => {
1037 quote! { __potato_once_cache.as_mut().expect("OnceCache not available") }
1038 }
1039 "& mut SessionCache" => {
1040 quote! { __potato_session_cache.as_mut().expect("SessionCache not available") }
1041 }
1042 "PostFile" => {
1043 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1044 quote! {
1045 match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
1046 Some(file) => file,
1047 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1048 }
1049 }
1050 }
1051 arg_type_str if ARG_TYPES.contains(arg_type_str) => {
1052 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1053 let mut arg_value = quote! {
1054 match req.body_pairs
1055 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1056 .map(|p| p.to_string()) {
1057 Some(val) => val,
1058 None => match req.url_query
1059 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1060 .map(|p| p.as_str().to_string()) {
1061 Some(val) => val,
1062 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1063 },
1064 }
1065 };
1066 if arg_type_str != "String" {
1067 arg_value = quote! {
1068 match #arg_value.parse() {
1069 Ok(val) => val,
1070 Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
1071 }
1072 }
1073 }
1074 arg_value
1075 }
1076 _ => panic!("unsupported arg type: [{arg_type_str}]"),
1077 };
1078 args.push(arg_value);
1079 arg_names.push(random_ident());
1080 arg_types.push(arg_type_str);
1082 }
1083 }
1084 let wrap_func_name2 = random_ident();
1085 let ret_type = root_fn
1086 .sig
1087 .output
1088 .to_token_stream()
1089 .to_string()
1090 .type_simplify();
1091
1092 let _controller_create_fn = if has_receiver {
1095 quote! {
1096 }
1099 } else {
1100 quote! {}
1101 };
1102
1103 let call_args: Vec<_> = args
1105 .iter()
1106 .enumerate()
1107 .map(|(i, _arg)| {
1108 let arg_name = &arg_names[i];
1109 let arg_type = &arg_types[i];
1110 if arg_type == "& mut HttpRequest" {
1112 quote! { req }
1113 } else {
1114 quote! { #arg_name }
1115 }
1116 })
1117 .collect();
1118
1119 let call_expr = if has_receiver {
1120 match args.len() {
1124 0 => quote! { #fn_name() },
1125 1 => {
1126 let arg_name = &arg_names[0];
1127 let arg = &args[0];
1128 let arg_type = &arg_types[0];
1129 if arg_type == "& mut HttpRequest" {
1130 quote! { #fn_name(req) }
1131 } else {
1132 quote! {{
1133 let #arg_name = #arg;
1134 #fn_name(#arg_name)
1135 }}
1136 }
1137 }
1138 _ => {
1139 let let_bindings: Vec<_> = arg_types
1140 .iter()
1141 .zip(arg_names.iter())
1142 .zip(args.iter())
1143 .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1144 .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1145 .collect();
1146
1147 quote! {{
1148 #(#let_bindings)*
1149 #fn_name(#(#call_args),*)
1150 }}
1151 }
1152 }
1153 } else {
1154 match args.len() {
1156 0 => quote! { #fn_name() },
1157 1 => {
1158 let arg_name = &arg_names[0];
1159 let arg = &args[0];
1160 let arg_type = &arg_types[0];
1161 if arg_type == "& mut HttpRequest" {
1163 quote! { #fn_name(req) }
1164 } else {
1165 quote! {{
1166 let #arg_name = #arg;
1167 #fn_name(#arg_name)
1168 }}
1169 }
1170 }
1171 _ => {
1172 let let_bindings: Vec<_> = arg_types
1174 .iter()
1175 .zip(arg_names.iter())
1176 .zip(args.iter())
1177 .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1178 .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1179 .collect();
1180
1181 quote! {{
1182 #(#let_bindings)*
1183 #fn_name(#(#call_args),*)
1184 }}
1185 }
1186 }
1187 };
1188 let handler_wrap_func_body = generate_response_handler(call_expr, &ret_type, is_async);
1189 let doc_args = serde_json::to_string(&doc_args).unwrap();
1190
1191 let add_headers_code = if all_headers.is_empty() {
1193 quote! {}
1194 } else {
1195 let header_statements = all_headers.iter().map(|(key, value)| {
1196 let http_key = key.replace("_", "-");
1198 quote! {
1199 __potato_response.add_header(
1200 std::borrow::Cow::Borrowed(#http_key),
1201 std::borrow::Cow::Borrowed(#value)
1202 );
1203 }
1204 });
1205 quote! {
1206 #(#header_statements)*
1207 }
1208 };
1209
1210 let cors_headers_code = if let Some(cors) = &cors_config {
1212 let mut statements = vec![];
1213
1214 let origin_val = cors.origin.as_deref().unwrap_or("*");
1216 statements.push(quote! {
1217 __potato_response.add_header(
1218 "Access-Control-Allow-Origin".into(),
1219 #origin_val.into()
1220 );
1221 });
1222
1223 if let Some(ref methods) = cors.methods {
1225 let mut methods_list: Vec<&str> = methods.split(',').map(|s| s.trim()).collect();
1226 if !methods_list.contains(&"HEAD") {
1227 methods_list.push("HEAD");
1228 }
1229 if !methods_list.contains(&"OPTIONS") {
1230 methods_list.push("OPTIONS");
1231 }
1232 let methods_str = methods_list.join(",");
1233 statements.push(quote! {
1234 __potato_response.add_header(
1235 "Access-Control-Allow-Methods".into(),
1236 #methods_str.into()
1237 );
1238 });
1239 }
1240
1241 let headers_val = cors.headers.as_deref().unwrap_or("*");
1243 statements.push(quote! {
1244 __potato_response.add_header(
1245 "Access-Control-Allow-Headers".into(),
1246 #headers_val.into()
1247 );
1248 });
1249
1250 if let Some(ref max_age) = cors.max_age {
1252 statements.push(quote! {
1253 __potato_response.add_header(
1254 "Access-Control-Max-Age".into(),
1255 #max_age.into()
1256 );
1257 });
1258 } else {
1259 statements.push(quote! {
1260 __potato_response.add_header(
1261 "Access-Control-Max-Age".into(),
1262 "86400".into()
1263 );
1264 });
1265 }
1266
1267 if cors.credentials {
1268 statements.push(quote! {
1269 __potato_response.add_header(
1270 "Access-Control-Allow-Credentials".into(),
1271 "true".into()
1272 );
1273 });
1274 }
1275
1276 if let Some(ref expose_headers) = cors.expose_headers {
1277 statements.push(quote! {
1278 __potato_response.add_header(
1279 "Access-Control-Expose-Headers".into(),
1280 #expose_headers.into()
1281 );
1282 });
1283 }
1284
1285 quote! { #(#statements)* }
1286 } else {
1287 quote! {}
1288 };
1289
1290 let auto_head_handler = if cors_config.is_some()
1292 && (req_name == "POST" || req_name == "PUT" || req_name == "DELETE")
1293 {
1294 let head_wrap_name = format_ident!("__potato_cors_head_{}", fn_name);
1295 Some(quote! {
1296 #[doc(hidden)]
1297 fn #head_wrap_name(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1298 potato::HttpResponse::html("")
1301 }
1302 })
1303 } else {
1304 None
1305 };
1306
1307 let semaphore_static = if let Some(max_conn) = max_concurrency {
1309 let semaphore_name =
1310 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1311 Some(quote! {
1312 #[doc(hidden)]
1313 #[allow(non_upper_case_globals)]
1314 static #semaphore_name: std::sync::LazyLock<tokio::sync::Semaphore> =
1315 std::sync::LazyLock::new(|| tokio::sync::Semaphore::new(#max_conn));
1316 })
1317 } else {
1318 None
1319 };
1320
1321 let wrap_func_body = if is_async {
1322 if max_concurrency.is_some() {
1323 let semaphore_name =
1324 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1325 quote! {
1326 let __potato_permit = #semaphore_name.acquire().await;
1327
1328 let __potato_error_handler: Option<potato::ErrorHandler> = {
1330 let mut handler = None;
1331 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1332 handler = Some(flag.handler.clone());
1333 break;
1334 }
1335 handler
1336 };
1337
1338 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1340 Some(potato::OnceCache::new())
1341 } else {
1342 None
1343 };
1344 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1345 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1347 let header_value = h.as_str();
1348 if header_value.starts_with("Bearer ") {
1349 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1350 } else {
1351 None
1352 }
1353 } else {
1354 None
1355 }
1356 } else {
1357 None
1358 };
1359
1360 if #need_session_cache && __potato_session_cache.is_none() {
1362 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1363 __potato_resp.http_code = 401;
1364 return __potato_resp;
1365 }
1366
1367 if let Some(ref mut session_cache) = __potato_session_cache {
1369 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1370 session_cache.parse_request_cookies(cookie_header.as_str());
1371 }
1372 }
1373
1374 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1375 #(
1376 if __potato_pre_response.is_none() {
1377 __potato_pre_response = match #preprocess_adapters(
1378 req,
1379 __potato_once_cache.as_mut(),
1380 __potato_session_cache.as_mut(),
1381 ).await {
1382 Ok(Some(ret)) => Some(ret),
1383 Ok(None) => None,
1384 Err(err) => {
1385 let handler = &__potato_error_handler;
1386 Some(match handler {
1387 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1388 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1389 None => potato::HttpResponse::error(format!("{err:?}")),
1390 })
1391 }
1392 };
1393 }
1394 )*
1395
1396 let mut __potato_response = match __potato_pre_response {
1397 Some(ret) => ret,
1398 None => match #handler_wrap_func_body {
1399 Ok(resp) => resp,
1400 Err(err) => {
1401 let handler = &__potato_error_handler;
1402 match handler {
1403 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1404 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1405 None => potato::HttpResponse::error(format!("{err:?}")),
1406 }
1407 }
1408 },
1409 };
1410
1411 #(
1412 if let Err(err) = #postprocess_adapters(
1413 req,
1414 &mut __potato_response,
1415 __potato_once_cache.as_mut(),
1416 __potato_session_cache.as_mut(),
1417 ).await {
1418 drop(__potato_permit);
1419 let handler = &__potato_error_handler;
1420 return match handler {
1421 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1422 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1423 None => potato::HttpResponse::error(format!("{err:?}")),
1424 };
1425 }
1426 )*
1427
1428 #add_headers_code
1429 #cors_headers_code
1430
1431 if let Some(ref session_cache) = __potato_session_cache {
1433 session_cache.apply_cookies(&mut __potato_response);
1434 }
1435
1436 drop(__potato_permit);
1437 __potato_response
1438 }
1439 } else {
1440 quote! {
1441 let __potato_error_handler: Option<potato::ErrorHandler> = {
1443 let mut handler = None;
1444 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1445 handler = Some(flag.handler.clone());
1446 break;
1447 }
1448 handler
1449 };
1450
1451 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1453 Some(potato::OnceCache::new())
1454 } else {
1455 None
1456 };
1457 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1458 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1460 let header_value = h.as_str();
1461 if header_value.starts_with("Bearer ") {
1462 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1463 } else {
1464 None
1465 }
1466 } else {
1467 None
1468 }
1469 } else {
1470 None
1471 };
1472
1473 if #need_session_cache && __potato_session_cache.is_none() {
1475 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1476 __potato_resp.http_code = 401;
1477 return __potato_resp;
1478 }
1479
1480 if let Some(ref mut session_cache) = __potato_session_cache {
1482 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1483 session_cache.parse_request_cookies(cookie_header.as_str());
1484 }
1485 }
1486
1487 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1488 #(
1489 if __potato_pre_response.is_none() {
1490 __potato_pre_response = match #preprocess_adapters(
1491 req,
1492 __potato_once_cache.as_mut(),
1493 __potato_session_cache.as_mut(),
1494 ).await {
1495 Ok(Some(ret)) => Some(ret),
1496 Ok(None) => None,
1497 Err(err) => {
1498 let handler = &__potato_error_handler;
1499 Some(match handler {
1500 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1501 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1502 None => potato::HttpResponse::error(format!("{err:?}")),
1503 })
1504 }
1505 };
1506 }
1507 )*
1508
1509 let mut __potato_response = match __potato_pre_response {
1510 Some(ret) => ret,
1511 None => match #handler_wrap_func_body {
1512 Ok(resp) => resp,
1513 Err(err) => {
1514 let handler = &__potato_error_handler;
1515 match handler {
1516 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1517 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1518 None => potato::HttpResponse::error(format!("{err:?}")),
1519 }
1520 }
1521 },
1522 };
1523
1524 #(
1525 if let Err(err) = #postprocess_adapters(
1526 req,
1527 &mut __potato_response,
1528 __potato_once_cache.as_mut(),
1529 __potato_session_cache.as_mut(),
1530 ).await {
1531 let handler = &__potato_error_handler;
1532 return match handler {
1533 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1534 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1535 None => potato::HttpResponse::error(format!("{err:?}")),
1536 };
1537 }
1538 )*
1539
1540 #add_headers_code
1541 #cors_headers_code
1542
1543 __potato_response
1544 }
1545 }
1546 } else {
1547 if max_concurrency.is_some() {
1548 let semaphore_name =
1549 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1550 quote! {
1551 let __potato_permit = #semaphore_name.acquire().await;
1552
1553 let __potato_error_handler: Option<potato::ErrorHandler> = {
1555 let mut handler = None;
1556 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1557 handler = Some(flag.handler.clone());
1558 break;
1559 }
1560 handler
1561 };
1562
1563 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1565 Some(potato::OnceCache::new())
1566 } else {
1567 None
1568 };
1569 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1570 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1572 let header_value = h.as_str();
1573 if header_value.starts_with("Bearer ") {
1574 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1575 } else {
1576 None
1577 }
1578 } else {
1579 None
1580 }
1581 } else {
1582 None
1583 };
1584
1585 if #need_session_cache && __potato_session_cache.is_none() {
1587 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1588 __potato_resp.http_code = 401;
1589 return __potato_resp;
1590 }
1591
1592 if let Some(ref mut session_cache) = __potato_session_cache {
1594 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1595 session_cache.parse_request_cookies(cookie_header.as_str());
1596 }
1597 }
1598
1599 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1600 #(
1601 if __potato_pre_response.is_none() {
1602 __potato_pre_response = match #preprocess_adapters(
1603 req,
1604 __potato_once_cache.as_mut(),
1605 __potato_session_cache.as_mut(),
1606 ).await {
1607 Ok(Some(ret)) => Some(ret),
1608 Ok(None) => None,
1609 Err(err) => {
1610 let handler = &__potato_error_handler;
1611 Some(match handler {
1612 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1613 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1614 None => potato::HttpResponse::error(format!("{err:?}")),
1615 })
1616 }
1617 };
1618 }
1619 )*
1620
1621 let mut __potato_response = match __potato_pre_response {
1622 Some(ret) => ret,
1623 None => match #handler_wrap_func_body {
1624 Ok(resp) => resp,
1625 Err(err) => {
1626 let handler = &__potato_error_handler;
1627 match handler {
1628 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1629 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1630 None => potato::HttpResponse::error(format!("{err:?}")),
1631 }
1632 }
1633 },
1634 };
1635
1636 #(
1637 if let Err(err) = #postprocess_adapters(
1638 req,
1639 &mut __potato_response,
1640 __potato_once_cache.as_mut(),
1641 __potato_session_cache.as_mut(),
1642 ).await {
1643 drop(__potato_permit);
1644 let handler = &__potato_error_handler;
1645 return match handler {
1646 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1647 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1648 None => potato::HttpResponse::error(format!("{err:?}")),
1649 };
1650 }
1651 )*
1652
1653 #add_headers_code
1654 #cors_headers_code
1655
1656 if let Some(ref session_cache) = __potato_session_cache {
1658 session_cache.apply_cookies(&mut __potato_response);
1659 }
1660
1661 drop(__potato_permit);
1662 __potato_response
1663 }
1664 } else {
1665 quote! {
1666 let __potato_error_handler: Option<potato::ErrorHandler> = {
1668 let mut handler = None;
1669 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1670 handler = Some(flag.handler.clone());
1671 break;
1672 }
1673 handler
1674 };
1675
1676 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1678 Some(potato::OnceCache::new())
1679 } else {
1680 None
1681 };
1682 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1683 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1685 let header_value = h.as_str();
1686 if header_value.starts_with("Bearer ") {
1687 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1688 } else {
1689 None
1690 }
1691 } else {
1692 None
1693 }
1694 } else {
1695 None
1696 };
1697
1698 if #need_session_cache && __potato_session_cache.is_none() {
1700 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1701 __potato_resp.http_code = 401;
1702 return __potato_resp;
1703 }
1704
1705 if let Some(ref mut session_cache) = __potato_session_cache {
1707 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1708 session_cache.parse_request_cookies(cookie_header.as_str());
1709 }
1710 }
1711
1712 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1713 #(
1714 if __potato_pre_response.is_none() {
1715 __potato_pre_response = match #preprocess_adapters(
1716 req,
1717 __potato_once_cache.as_mut(),
1718 __potato_session_cache.as_mut(),
1719 ).await {
1720 Ok(Some(ret)) => Some(ret),
1721 Ok(None) => None,
1722 Err(err) => {
1723 let handler = &__potato_error_handler;
1724 Some(match handler {
1725 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1726 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1727 None => potato::HttpResponse::error(format!("{err:?}")),
1728 })
1729 }
1730 };
1731 }
1732 )*
1733
1734 let mut __potato_response = match __potato_pre_response {
1735 Some(ret) => ret,
1736 None => match #handler_wrap_func_body {
1737 Ok(resp) => resp,
1738 Err(err) => {
1739 let handler = &__potato_error_handler;
1740 match handler {
1741 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1742 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1743 None => potato::HttpResponse::error(format!("{err:?}")),
1744 }
1745 }
1746 },
1747 };
1748
1749 #(
1750 if let Err(err) = #postprocess_adapters(
1751 req,
1752 &mut __potato_response,
1753 __potato_once_cache.as_mut(),
1754 __potato_session_cache.as_mut(),
1755 ).await {
1756 let handler = &__potato_error_handler;
1757 return match handler {
1758 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1759 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1760 None => potato::HttpResponse::error(format!("{err:?}")),
1761 };
1762 }
1763 )*
1764
1765 #add_headers_code
1766 #cors_headers_code
1767
1768 if let Some(ref session_cache) = __potato_session_cache {
1770 session_cache.apply_cookies(&mut __potato_response);
1771 }
1772
1773 __potato_response
1774 }
1775 }
1776 };
1777
1778 if is_async {
1779 let (wrapper_sig, handler_variant) = if is_send {
1780 (
1781 quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> },
1782 quote! { potato::HttpHandler::Async },
1783 )
1784 } else {
1785 (
1786 quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + '_>> },
1787 quote! { potato::HttpHandler::AsyncNoSend },
1788 )
1789 };
1790 quote! {
1791 #root_fn
1792
1793 #auto_head_handler
1794
1795 #semaphore_static
1796
1797 #[doc(hidden)]
1798 async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1799 #wrap_func_body
1800 }
1801
1802 #[doc(hidden)]
1803 #wrapper_sig {
1804 Box::pin(#wrap_func_name2(req))
1805 }
1806
1807 potato::inventory::submit!{potato::RequestHandlerFlag::new(
1808 potato::HttpMethod::#req_name,
1809 #final_path_expr,
1810 #handler_variant(#wrap_func_name),
1811 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1812 )}
1813 }
1814 .into()
1815 } else {
1816 let (wrapper_sig, handler_variant) = if is_send {
1817 (
1818 quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> },
1819 quote! { potato::HttpHandler::Async },
1820 )
1821 } else {
1822 (
1823 quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + '_>> },
1824 quote! { potato::HttpHandler::AsyncNoSend },
1825 )
1826 };
1827 quote! {
1828 #root_fn
1829
1830 #auto_head_handler
1831
1832 #semaphore_static
1833
1834 #[doc(hidden)]
1835 async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1836 #wrap_func_body
1837 }
1838
1839 #[doc(hidden)]
1840 #wrapper_sig {
1841 Box::pin(#wrap_func_name2(req))
1842 }
1843
1844 potato::inventory::submit!{potato::RequestHandlerFlag::new(
1845 potato::HttpMethod::#req_name,
1846 #final_path_expr,
1847 #handler_variant(#wrap_func_name),
1848 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1849 )}
1850 }
1851 .into()
1852 }
1853 }
1857
1858fn generate_response_handler(
1865 call_expr: proc_macro2::TokenStream,
1866 ret_type: &str,
1867 is_async: bool,
1868) -> proc_macro2::TokenStream {
1869 let call_with_await = if is_async {
1870 quote! { #call_expr.await }
1871 } else {
1872 quote! { #call_expr }
1873 };
1874
1875 match ret_type {
1876 "Result<()>" | "anyhow::Result<()>" => quote! {
1878 match #call_with_await {
1879 Ok(_) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::text("ok")),
1880 Err(err) => Err(err),
1881 }
1882 },
1883 "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1884 match #call_with_await {
1885 Ok(ret) => Ok(ret),
1886 Err(err) => Err(err),
1887 }
1888 },
1889 "Result<String>" | "anyhow::Result<String>" => quote! {
1890 match #call_with_await {
1891 Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(ret)),
1892 Err(err) => Err(err),
1893 }
1894 },
1895 "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1896 match #call_with_await {
1897 Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(ret)),
1898 Err(err) => Err(err),
1899 }
1900 },
1901 "Result<serde_json::Value>" | "anyhow::Result<serde_json::Value>" => quote! {
1902 match #call_with_await {
1903 Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::json(serde_json::to_string(&ret).unwrap_or_else(|_| "{}".to_string()))),
1904 Err(err) => Err(err),
1905 }
1906 },
1907 "()" => quote! {
1909 {
1910 #call_with_await;
1911 Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::text("ok"))
1912 }
1913 },
1914 "HttpResponse" => quote! {
1915 Ok::<potato::HttpResponse, anyhow::Error>(#call_with_await)
1916 },
1917 "String" => quote! {
1918 Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(#call_with_await))
1919 },
1920 "& 'static str" => quote! {
1921 Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(#call_with_await))
1922 },
1923 "serde_json::Value" => quote! {
1924 Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::json(serde_json::to_string(&#call_with_await).unwrap_or_else(|_| "{}".to_string())))
1925 },
1926 _ => panic!("unsupported ret type: {}", ret_type),
1928 }
1929}
1930
/// Attribute macro entry point for `#[http_get]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"GET"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "GET")
}
1935
/// Attribute macro entry point for `#[http_post]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"POST"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "POST")
}
1940
/// Attribute macro entry point for `#[http_put]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"PUT"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "PUT")
}
1945
/// Attribute macro entry point for `#[http_delete]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"DELETE"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "DELETE")
}
1950
/// Attribute macro entry point for `#[http_options]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"OPTIONS"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "OPTIONS")
}
1955
/// Attribute macro entry point for `#[http_head]`.
///
/// Forwards `attr` and the annotated item unchanged to `http_handler_macro`, passing `"HEAD"`
/// as the third argument (the HTTP method, judging by the macro name).
#[proc_macro_attribute]
pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
    http_handler_macro(attr, input, "HEAD")
}
1960
/// Attribute macro entry point for `#[controller]`.
///
/// The expansion itself is implemented by `controller_macro`, which receives `attr` and
/// `input` unchanged.
#[proc_macro_attribute]
pub fn controller(attr: TokenStream, input: TokenStream) -> TokenStream {
    controller_macro(attr, input)
}
1998
1999fn controller_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2000 let input_clone = input.clone();
2002 if let Ok(item_impl) = syn::parse::<syn::ItemImpl>(input_clone) {
2003 return controller_impl_macro(attr, item_impl);
2005 }
2006
2007 let item_struct = syn::parse_macro_input!(input as syn::ItemStruct);
2009
2010 let base_path = if attr.is_empty() {
2012 quote! {}
2014 } else {
2015 let attr_str = attr.to_string();
2016 let base_path = attr_str.trim_matches('"').to_string();
2017 quote! {
2018 #[doc(hidden)]
2019 const __POTATO_CONTROLLER_BASE_PATH: &str = #base_path;
2020 }
2021 };
2022
2023 let (has_once_cache, has_session_cache) = validate_controller_struct(&item_struct);
2025 let struct_name = &item_struct.ident;
2026 let struct_name_str = struct_name.to_string();
2027
2028 let controller_creation_fn = if has_session_cache {
2031 quote! {
2034 #[doc(hidden)]
2035 #[allow(dead_code)]
2036 async fn __potato_create_controller(req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2037 let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2039
2040 let session_cache = {
2042 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2043 let header_value = h.as_str();
2044 if header_value.starts_with("Bearer ") {
2045 potato::SessionCache::from_token(&header_value[7..]).await.ok()
2046 } else {
2047 None
2048 }
2049 } else {
2050 None
2051 }
2052 };
2053
2054 let session_cache = match session_cache {
2055 Some(cache) => cache,
2056 None => {
2057 let mut resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
2058 resp.http_code = 401;
2059 return Err(resp);
2060 }
2061 };
2062 let session_cache = Box::leak(Box::new(session_cache));
2063
2064 let controller = Self {
2066 once_cache,
2067 sess_cache: session_cache,
2068 };
2069
2070 Ok(Box::new(controller))
2071 }
2072 }
2073 } else {
2074 quote! {
2077 #[doc(hidden)]
2078 #[allow(dead_code)]
2079 async fn __potato_create_controller(_req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2080 let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2082
2083 let _temp_session_cache = Box::leak(Box::new(potato::SessionCache::new()));
2085
2086 let controller = Self {
2088 once_cache,
2089 };
2090
2091 Ok(Box::new(controller))
2092 }
2093 }
2094 };
2095
2096 let struct_generics = &item_struct.generics;
2098 let (impl_generics, type_generics, where_clause) = struct_generics.split_for_impl();
2099
2100 let controller_name_const =
2102 quote::format_ident!("__POTATO_CONTROLLER_NAME_{}", struct_name_str);
2103
2104 let output = quote! {
2105 #item_struct
2106
2107 #base_path
2108
2109 #[doc(hidden)]
2110 const #controller_name_const: &str = #struct_name_str;
2111
2112 potato::inventory::submit! {
2114 potato::ControllerStructFlag::new(
2115 #struct_name_str,
2116 potato::ControllerStructFieldInfo {
2117 has_once_cache: #has_once_cache,
2118 has_session_cache: #has_session_cache,
2119 }
2120 )
2121 }
2122
2123 impl #impl_generics #struct_name #type_generics #where_clause {
2125 #controller_creation_fn
2126 }
2127 };
2128
2129 output.into()
2130}
2131
2132fn controller_impl_macro(attr: TokenStream, item_impl: syn::ItemImpl) -> TokenStream {
2134 let base_path_str = if attr.is_empty() {
2136 None
2139 } else {
2140 let attr_str = attr.to_string();
2141 Some(attr_str.trim_matches('"').to_string())
2142 };
2143
2144 let self_type = &item_impl.self_ty;
2146
2147 let self_type_name = match &*item_impl.self_ty {
2149 syn::Type::Path(type_path) => {
2150 if let Some(segment) = type_path.path.segments.last() {
2152 let ident = &segment.ident;
2153 quote! { #ident }
2154 } else {
2155 quote! { #self_type }
2156 }
2157 }
2158 _ => quote! { #self_type },
2159 };
2160
2161 let self_type_tag = match &*item_impl.self_ty {
2163 syn::Type::Path(type_path) => {
2164 if let Some(segment) = type_path.path.segments.last() {
2166 segment.ident.to_string()
2167 } else {
2168 self_type.to_token_stream().to_string()
2169 }
2170 }
2171 _ => self_type.to_token_stream().to_string(),
2172 };
2173
2174 let mut cleaned_items = Vec::new();
2176 let mut generated_code = Vec::new();
2177
2178 for item in &item_impl.items {
2179 if let syn::ImplItem::Fn(method) = item {
2180 let has_http_attr = method.attrs.iter().any(|attr| {
2182 let attr_name = attr.path().to_token_stream().to_string();
2183 attr_name.contains("http_get")
2184 || attr_name.contains("http_post")
2185 || attr_name.contains("http_put")
2186 || attr_name.contains("http_delete")
2187 || attr_name.contains("http_head")
2188 || attr_name.contains("http_patch")
2189 || attr_name.contains("http_options")
2190 });
2191
2192 if has_http_attr {
2193 let mut cleaned_method = method.clone();
2195 cleaned_method.attrs = method
2196 .attrs
2197 .iter()
2198 .filter(|attr| {
2199 let attr_name = attr.path().to_token_stream().to_string();
2200 !attr_name.contains("http_get")
2201 && !attr_name.contains("http_post")
2202 && !attr_name.contains("http_put")
2203 && !attr_name.contains("http_delete")
2204 && !attr_name.contains("http_head")
2205 && !attr_name.contains("http_patch")
2206 && !attr_name.contains("http_options")
2207 })
2208 .cloned()
2209 .collect();
2210
2211 cleaned_items.push(syn::ImplItem::Fn(cleaned_method));
2212
2213 for attr in &method.attrs {
2215 let attr_name = attr.path().to_token_stream().to_string();
2216 if attr_name.contains("http_get")
2217 || attr_name.contains("http_post")
2218 || attr_name.contains("http_put")
2219 || attr_name.contains("http_delete")
2220 || attr_name.contains("http_head")
2221 || attr_name.contains("http_patch")
2222 || attr_name.contains("http_options")
2223 {
2224 let http_method = if attr_name.contains("http_get") {
2226 "GET"
2227 } else if attr_name.contains("http_post") {
2228 "POST"
2229 } else if attr_name.contains("http_put") {
2230 "PUT"
2231 } else if attr_name.contains("http_delete") {
2232 "DELETE"
2233 } else if attr_name.contains("http_head") {
2234 "HEAD"
2235 } else if attr_name.contains("http_patch") {
2236 "PATCH"
2237 } else {
2238 "OPTIONS"
2239 };
2240
2241 let (method_path, is_send) = match &attr.meta {
2243 syn::Meta::List(list) => {
2244 let tokens: Vec<_> = list.tokens.clone().into_iter().collect();
2245 let mut path = String::new();
2246 let mut send = true; let mut i = 0;
2249 while i < tokens.len() {
2250 if let proc_macro2::TokenTree::Literal(lit) = &tokens[i] {
2251 let lit_str = lit.to_string();
2252 if lit_str.starts_with('"') && lit_str.ends_with('"') {
2253 path = lit_str[1..lit_str.len() - 1].to_string();
2254 }
2255 i += 1;
2256 } else if let proc_macro2::TokenTree::Ident(ident) = &tokens[i]
2257 {
2258 if ident.to_string() == "Send" {
2259 send = true;
2260 }
2261 i += 1;
2262 } else {
2263 i += 1;
2264 }
2265 }
2266
2267 (path, send)
2268 }
2269 _ => (String::new(), true),
2270 };
2271
2272 let final_path = if let Some(ref base_path) = base_path_str {
2274 if method_path.is_empty() {
2276 base_path.clone()
2277 } else {
2278 format!("{}{}", base_path, method_path)
2279 }
2280 } else {
2281 panic!("impl block controller must specify a base path, e.g., #[potato::controller(\"/api/users\")]");
2283 };
2284
2285 let final_path_lit =
2287 syn::LitStr::new(&final_path, proc_macro2::Span::call_site());
2288
2289 let fn_name = &method.sig.ident;
2291 let wrapper_fn_name = quote::format_ident!("__potato_ctrl_{}", fn_name);
2292 let is_async = method.sig.asyncness.is_some();
2293
2294 let (has_receiver, _is_mut_receiver) = method
2296 .sig
2297 .inputs
2298 .iter()
2299 .filter_map(|arg| {
2300 if let syn::FnArg::Receiver(recv) = arg {
2301 Some((true, recv.mutability.is_some()))
2302 } else {
2303 None
2304 }
2305 })
2306 .next()
2307 .unwrap_or((false, false));
2308
2309 let other_params: Vec<_> = method
2311 .sig
2312 .inputs
2313 .iter()
2314 .filter_map(|arg| {
2315 if let syn::FnArg::Typed(pat_type) = arg {
2316 Some(pat_type.clone())
2317 } else {
2318 None
2319 }
2320 })
2321 .collect();
2322
2323 let method_has_session_cache = other_params.iter().any(|pat_type| {
2325 pat_type.ty.to_token_stream().to_string().type_simplify()
2326 == "& mut SessionCache"
2327 });
2328
2329 let doc_auth = has_receiver || method_has_session_cache;
2331
2332 let param_names: Vec<_> = other_params
2334 .iter()
2335 .filter_map(|pat_type| {
2336 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2337 Some(pat_ident.ident.clone())
2338 } else {
2339 None
2340 }
2341 })
2342 .collect();
2343
2344 let method_call = if has_receiver {
2346 if param_names.is_empty() {
2351 if is_async {
2352 quote! {
2353 {
2354 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2355 Ok(boxed) => boxed,
2356 Err(resp) => return resp,
2357 };
2358 controller.#fn_name().await
2359 }
2360 }
2361 } else {
2362 quote! {
2363 {
2364 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2365 Ok(boxed) => boxed,
2366 Err(resp) => return resp,
2367 };
2368 controller.#fn_name()
2369 }
2370 }
2371 }
2372 } else {
2373 if is_async {
2374 quote! {
2375 {
2376 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2377 Ok(boxed) => boxed,
2378 Err(resp) => return resp,
2379 };
2380 controller.#fn_name(#(#param_names),*).await
2381 }
2382 }
2383 } else {
2384 quote! {
2385 {
2386 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2387 Ok(boxed) => boxed,
2388 Err(resp) => return resp,
2389 };
2390 controller.#fn_name(#(#param_names),*)
2391 }
2392 }
2393 }
2394 }
2395 } else {
2396 if param_names.is_empty() {
2400 if is_async {
2401 quote! { #self_type_name::#fn_name().await }
2402 } else {
2403 quote! { #self_type_name::#fn_name() }
2404 }
2405 } else {
2406 let mut param_bindings = Vec::new();
2409 for (i, param) in other_params.iter().enumerate() {
2410 let param_type_str =
2411 param.ty.to_token_stream().to_string().type_simplify();
2412 let param_name = ¶m_names[i];
2413
2414 match ¶m_type_str[..] {
2415 "& mut OnceCache" => {
2416 param_bindings.push(quote! {
2417 let #param_name = &mut __potato_once_cache;
2418 });
2419 }
2420 "& mut SessionCache" => {
2421 param_bindings.push(quote! {
2422 let #param_name = &mut __potato_session_cache;
2423 });
2424 }
2425 _ => {
2426 }
2428 }
2429 }
2430
2431 let needs_session_cache = other_params.iter().any(|p| {
2433 p.ty.to_token_stream().to_string().type_simplify()
2434 == "& mut SessionCache"
2435 });
2436
2437 let session_cache_init = if needs_session_cache {
2438 quote! {
2439 {
2440 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2442 let header_value = h.as_str();
2443 if header_value.starts_with("Bearer ") {
2444 potato::SessionCache::from_token(&header_value[7..]).await.ok()
2445 } else {
2446 None
2447 }
2448 } else {
2449 None
2450 }
2451 }
2452 }
2453 } else {
2454 quote! { None }
2455 };
2456
2457 if is_async {
2458 quote! {
2459 {
2460 let mut __potato_once_cache = potato::OnceCache::new();
2461 let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2462 #(#param_bindings)*
2463 #self_type_name::#fn_name(#(#param_names),*).await
2464 }
2465 }
2466 } else {
2467 quote! {
2468 {
2469 let mut __potato_once_cache = potato::OnceCache::new();
2470 let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2471 #(#param_bindings)*
2472 #self_type_name::#fn_name(#(#param_names),*)
2473 }
2474 }
2475 }
2476 }
2477 };
2478
2479 let ret_type_str = method
2481 .sig
2482 .output
2483 .to_token_stream()
2484 .to_string()
2485 .type_simplify();
2486
2487 let resp_handler =
2489 generate_response_handler(method_call.clone(), &ret_type_str, false);
2490
2491 let wrapper_fn = if is_async {
2493 quote! {
2494 #[doc(hidden)]
2495 fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2496 Box::pin(async move {
2497 match #resp_handler {
2498 Ok(resp) => resp,
2499 Err(err) => potato::HttpResponse::error(err.to_string()),
2500 }
2501 })
2502 }
2503 }
2504 } else {
2505 quote! {
2506 #[doc(hidden)]
2507 fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2508 Box::pin(async move {
2509 match #resp_handler {
2510 Ok(resp) => resp,
2511 Err(err) => potato::HttpResponse::error(err.to_string()),
2512 }
2513 })
2514 }
2515 }
2516 };
2517
2518 generated_code.push(wrapper_fn);
2519
2520 let http_method_ident = quote::format_ident!("{}", http_method);
2522
2523 let route_register = if is_send {
2524 quote! {
2525 potato::inventory::submit! {
2526 potato::RequestHandlerFlag::new(
2527 potato::HttpMethod::#http_method_ident,
2528 #final_path_lit,
2529 potato::HttpHandler::Async(#wrapper_fn_name),
2530 potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2531 )
2532 }
2533 }
2534 } else {
2535 quote! {
2536 potato::inventory::submit! {
2537 potato::RequestHandlerFlag::new(
2538 potato::HttpMethod::#http_method_ident,
2539 #final_path_lit,
2540 potato::HttpHandler::AsyncNoSend(#wrapper_fn_name),
2541 potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2542 )
2543 }
2544 }
2545 };
2546
2547 generated_code.push(route_register);
2548 }
2549 }
2550 } else {
2551 cleaned_items.push(item.clone());
2553 }
2554 } else {
2555 cleaned_items.push(item.clone());
2557 }
2558 }
2559
2560 let mut cleaned_impl = item_impl.clone();
2562 cleaned_impl.items = cleaned_items;
2563
2564 let output = quote! {
2566 #cleaned_impl
2567
2568 #(#generated_code)*
2569 };
2570
2571 output.into()
2572}
2573
/// Attribute macro entry point for `#[preprocess]`.
///
/// Passes the attribute arguments (`attr`) and the annotated item (`input`)
/// straight to `preprocess_macro` and returns the token stream it produces.
#[proc_macro_attribute]
pub fn preprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
    preprocess_macro(attr, input)
}
2578
/// Attribute macro entry point for `#[postprocess]`.
///
/// Passes the attribute arguments (`attr`) and the annotated item (`input`)
/// straight to `postprocess_macro` and returns the token stream it produces.
#[proc_macro_attribute]
pub fn postprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
    postprocess_macro(attr, input)
}
2583
2584fn handle_error_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2602 if !attr.is_empty() {
2603 return input;
2604 }
2605
2606 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2607 let fn_name = root_fn.sig.ident.clone();
2608 let is_async = root_fn.sig.asyncness.is_some();
2609
2610 if root_fn.sig.inputs.len() != 2 {
2612 panic!("`handle_error` function must accept exactly two arguments");
2613 }
2614
2615 let mut arg_types = vec![];
2616 for arg in root_fn.sig.inputs.iter() {
2617 match arg {
2618 syn::FnArg::Typed(arg) => {
2619 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
2620 }
2621 _ => panic!("`handle_error` function does not support receiver argument"),
2622 }
2623 }
2624
2625 if arg_types[0] != "& mut HttpRequest" {
2626 panic!(
2627 "`handle_error` first argument must be `&mut potato::HttpRequest`, got `{}`",
2628 arg_types[0]
2629 );
2630 }
2631 if arg_types[1] != "anyhow::Error" {
2632 panic!(
2633 "`handle_error` second argument must be `anyhow::Error`, got `{}`",
2634 arg_types[1]
2635 );
2636 }
2637
2638 let ret_type = root_fn
2639 .sig
2640 .output
2641 .to_token_stream()
2642 .to_string()
2643 .type_simplify();
2644 if ret_type != "HttpResponse" {
2645 panic!(
2646 "`handle_error` return type must be `potato::HttpResponse`, got `{}`",
2647 ret_type
2648 );
2649 }
2650
2651 let wrap_name = format_ident!("__potato_error_handler_adapter_{}", fn_name);
2653 let wrap_name_inner = format_ident!("__potato_error_handler_adapter_inner_{}", fn_name);
2654
2655 let call_body = if is_async {
2657 quote! { #fn_name(req, err).await }
2658 } else {
2659 quote! { #fn_name(req, err) }
2660 };
2661
2662 quote! {
2663 #root_fn
2664
2665 #[doc(hidden)]
2666 async fn #wrap_name_inner(
2667 req: &mut potato::HttpRequest,
2668 err: anyhow::Error,
2669 ) -> potato::HttpResponse {
2670 #call_body
2671 }
2672
2673 #[doc(hidden)]
2674 pub fn #wrap_name(
2675 req: &mut potato::HttpRequest,
2676 err: anyhow::Error,
2677 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2678 Box::pin(#wrap_name_inner(req, err))
2679 }
2680
2681 potato::inventory::submit! {
2682 potato::ErrorHandlerFlag::new(
2683 potato::ErrorHandler::Async(#wrap_name)
2684 )
2685 }
2686 }
2687 .into()
2688}
2689
/// Attribute macro entry point for `#[handle_error]`.
///
/// Passes the attribute arguments (`attr`) and the annotated item (`input`)
/// straight to `handle_error_macro` and returns the token stream it produces.
#[proc_macro_attribute]
pub fn handle_error(attr: TokenStream, input: TokenStream) -> TokenStream {
    handle_error_macro(attr, input)
}
2694
/// Attribute macro entry point for `#[limit_size(...)]`.
///
/// Passes the attribute arguments (`attr`) and the annotated item (`input`)
/// straight to `limit_size_macro` and returns the token stream it produces.
#[proc_macro_attribute]
pub fn limit_size(attr: TokenStream, input: TokenStream) -> TokenStream {
    limit_size_macro(attr, input)
}
2721
2722fn limit_size_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2723 let (_max_header, max_body) = {
2725 let attr_tokens: proc_macro2::TokenStream = attr.clone().into();
2726 if attr_tokens.is_empty() {
2727 (None, None)
2729 } else {
2730 let result = syn::parse::Parser::parse2(
2732 |input: syn::parse::ParseStream| -> syn::Result<(Option<syn::Expr>, Option<syn::Expr>)> {
2733 let mut header_expr = None;
2734 let mut body_expr = None;
2735
2736 while !input.is_empty() {
2738 let ident: Ident = input.parse()?;
2739 input.parse::<Token![=]>()?;
2740 let value: syn::Expr = input.parse()?;
2741
2742 match ident.to_string().as_str() {
2743 "header" => header_expr = Some(value),
2744 "body" => body_expr = Some(value),
2745 _ => return Err(syn::Error::new(ident.span(), "expected 'header' or 'body'")),
2746 }
2747
2748 if input.peek(Token![,]) {
2750 input.parse::<Token![,]>()?;
2751 }
2752 }
2753
2754 Ok((header_expr, body_expr))
2755 },
2756 attr_tokens.clone(),
2757 );
2758
2759 match result {
2760 Ok((h, b)) => (h, b),
2761 Err(_) => {
2762 if let Ok(expr) = syn::parse2::<syn::Expr>(attr_tokens) {
2764 (None, Some(expr))
2765 } else {
2766 (None, None)
2767 }
2768 }
2769 }
2770 }
2771 };
2772
2773 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2774
2775 let body_check = if let Some(body_expr) = max_body {
2777 quote! {
2778 let body_len = req.body.len();
2780 if body_len > #body_expr {
2781 let mut res = potato::HttpResponse::text(format!(
2782 "Payload Too Large: body size {} bytes exceeds limit {} bytes",
2783 body_len, #body_expr
2784 ));
2785 res.http_code = 413;
2786 return res;
2787 }
2788 }
2789 } else {
2790 quote! {}
2791 };
2792
2793 let mut wrapped_fn = root_fn.clone();
2795 let original_block = root_fn.block.as_ref();
2796 let new_block: syn::Block = syn::parse_quote!({
2797 #body_check
2798 #original_block
2799 });
2800 wrapped_fn.block = Box::new(new_block);
2801
2802 quote! {
2803 #wrapped_fn
2804 }
2805 .into()
2806}
2807
/// Attribute macro `#[header(...)]`.
///
/// This entry point performs no expansion of its own: the attribute arguments
/// are ignored and the annotated item is returned unchanged.
// NOTE(review): the arguments are presumably read by another macro in this
// crate that inspects the item's attributes — confirm.
#[proc_macro_attribute]
pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream {
    input
}
2816
/// Attribute macro `#[cors(...)]`.
///
/// This entry point performs no expansion of its own: the attribute arguments
/// are ignored and the annotated item is returned unchanged.
// NOTE(review): the arguments are presumably read elsewhere in this crate
// (see `parse_cors_attr`) by a macro that inspects the item's attributes —
// confirm.
#[proc_macro_attribute]
pub fn cors(_attr: TokenStream, input: TokenStream) -> TokenStream {
    input
}
2825
2826#[proc_macro]
2827pub fn embed_dir(input: TokenStream) -> TokenStream {
2828 let path = syn::parse_macro_input!(input as syn::LitStr).value();
2829 quote! {{
2830 #[derive(potato::rust_embed::Embed)]
2831 #[folder = #path]
2832 struct Asset;
2833
2834 potato::load_embed::<Asset>()
2835 }}
2836 .into()
2837}
2838
2839#[proc_macro_derive(StandardHeader)]
2840pub fn standard_header_derive(input: TokenStream) -> TokenStream {
2841 let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
2842 let enum_name = root_enum.ident;
2843 let mut try_from_str_items = vec![];
2844 let mut to_str_items = vec![];
2845 let mut headers_items = vec![];
2846 let mut headers_apply_items = vec![];
2847 for root_field in root_enum.variants.iter() {
2848 let name = root_field.ident.clone();
2849 if root_field.fields.iter().next().is_some() {
2850 panic!("unsupported enum type");
2851 }
2852 let str_name = name.to_string().replace("_", "-");
2853 let len = str_name.len();
2854 try_from_str_items
2855 .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
2856 to_str_items.push(quote! { Self::#name => #str_name, });
2857 headers_items.push(quote! { #name(String), });
2858 headers_apply_items
2859 .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
2860 }
2861 let r = quote! {
2862 impl #enum_name {
2863 pub fn try_from_str(value: &str) -> Option<Self> {
2864 match value.len() {
2865 #( #try_from_str_items )*
2866 _ => None,
2867 }
2868 }
2869
2870 pub fn to_str(&self) -> &'static str {
2871 match self {
2872 #( #to_str_items )*
2873 }
2874 }
2875 }
2876
2877 pub enum Headers {
2878 #( #headers_items )*
2879 Custom((String, String)),
2880 }
2881
2882 impl HttpRequest {
2883 pub fn apply_header(&mut self, header: Headers) {
2884 match header {
2885 #( #headers_apply_items )*
2886 Headers::Custom((k, v)) => self.set_header(&k[..], v),
2887 }
2888 }
2889 }
2890 };
2891 r.into()
2892}