1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{format_ident, quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use syn::Token;
10use utils::StringExt as _;
11
/// Options collected from a `#[cors(...)]` / `#[potato::cors(...)]` handler attribute.
///
/// Every textual option is kept verbatim as written in the attribute; `None` means the
/// option was not given.
struct CorsAttrConfig {
    /// Value for the allowed-origin setting (`origin = "..."`).
    origin: Option<String>,
    /// Value for the allowed-methods setting (`methods = "..."`).
    methods: Option<String>,
    /// Value for the allowed-headers setting (`headers = "..."`).
    headers: Option<String>,
    /// Value for the exposed-headers setting (`expose_headers = "..."`).
    expose_headers: Option<String>,
    /// Value for the max-age setting (`max_age = ...`), kept as text.
    max_age: Option<String>,
    /// Whether credentials are allowed (`credentials = true`); `false` when absent.
    credentials: bool,
}
21
22fn parse_cors_attr(tokens: &proc_macro2::TokenStream) -> CorsAttrConfig {
24 let config = CorsAttrConfig {
25 origin: None,
26 methods: None,
27 headers: None,
28 max_age: None,
29 credentials: false,
30 expose_headers: None,
31 };
32
33 if tokens.is_empty() {
34 return config; }
36
37 use syn::parse::Parser;
39
40 fn parse_inner(input: syn::parse::ParseStream) -> syn::Result<CorsAttrConfig> {
41 let mut config = CorsAttrConfig {
42 origin: None,
43 methods: None,
44 headers: None,
45 max_age: None,
46 credentials: false,
47 expose_headers: None,
48 };
49
50 let vars =
51 syn::punctuated::Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
52 for meta in vars {
53 let key = meta
54 .path
55 .get_ident()
56 .map(|i| i.to_string())
57 .unwrap_or_default();
58 if let syn::Expr::Lit(expr_lit) = &meta.value {
59 match &expr_lit.lit {
60 syn::Lit::Str(s) => {
61 let val = s.value();
62 match key.as_str() {
63 "origin" => config.origin = Some(val),
64 "methods" => config.methods = Some(val),
65 "headers" => config.headers = Some(val),
66 "max_age" => config.max_age = Some(val),
67 "expose_headers" => config.expose_headers = Some(val),
68 _ => {}
69 }
70 }
71 syn::Lit::Bool(b) => {
72 if key == "credentials" {
73 config.credentials = b.value();
74 }
75 }
76 _ => {}
77 }
78 }
79 }
80 Ok(config)
81 }
82
83 match parse_inner.parse2(tokens.clone()) {
84 Ok(cfg) => cfg,
85 Err(e) => panic!("Failed to parse cors attributes: {e}"),
86 }
87}
88
/// Simplified type names a handler argument may have when its value is taken from the
/// request's body pairs or URL query (parsed from text unless it is `String`).
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        // text
        "String",
        "bool",
        // unsigned integers
        "u8",
        "u16",
        "u32",
        "u64",
        "usize",
        // signed integers
        "i8",
        "i16",
        "i32",
        "i64",
        "isize",
        // floats
        "f32",
        "f64",
    ])
});
97
98fn validate_controller_struct(item_struct: &syn::ItemStruct) -> (bool, bool) {
101 let mut has_once_cache = false;
102 let mut has_session_cache = false;
103
104 if let syn::Fields::Named(fields_named) = &item_struct.fields {
105 for field in &fields_named.named {
106 let field_type_str = field.ty.to_token_stream().to_string().type_simplify();
107
108 if field_type_str.contains("OnceCache") {
110 has_once_cache = true;
111 } else if field_type_str.contains("SessionCache") {
112 has_session_cache = true;
113 } else {
114 panic!(
115 "Controller field must be &OnceCache or &SessionCache, got: {}",
116 field_type_str
117 );
118 }
119 }
120 }
121
122 (has_once_cache, has_session_cache)
123}
124
125fn parse_header_attr(tokens: &proc_macro2::TokenStream) -> Result<(String, String), syn::Error> {
127 use syn::parse::Parser;
128
129 let parser = |input: syn::parse::ParseStream| {
130 let key_ident: Ident = input.parse()?;
134 let key_name = key_ident.to_string();
135
136 if key_name == "Custom" {
137 let content;
139 syn::parenthesized!(content in input);
140 let key_lit: syn::LitStr = content.parse()?;
141 let key = key_lit.value();
142 let _: Token![=] = input.parse()?;
143 let value: syn::LitStr = input.parse()?;
144 Ok((key, value.value()))
145 } else {
146 let _: Token![=] = input.parse()?;
148 let value: syn::LitStr = input.parse()?;
149 Ok((key_name, value.value()))
150 }
151 };
152
153 parser.parse2(tokens.clone())
154}
155
156fn random_ident() -> Ident {
157 let mut rng = rand::thread_rng();
158 let value = format!("__potato_id_{}", rng.r#gen::<u64>());
159 Ident::new(&value, Span::call_site())
160}
161
162fn attr_last_ident(attr: &syn::Attribute) -> Option<String> {
163 attr.meta
164 .path()
165 .segments
166 .iter()
167 .last()
168 .map(|segment| segment.ident.to_string())
169}
170
171fn parse_hook_attr_items(attr: &syn::Attribute, attr_name: &str) -> Vec<Ident> {
172 let parser = syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated;
173 let idents = attr.parse_args_with(parser).unwrap_or_else(|err| {
174 panic!("invalid `{attr_name}` annotation: {err}");
175 });
176 if idents.is_empty() {
177 panic!("`{attr_name}` annotation requires at least one function name");
178 }
179 idents.into_iter().collect()
180}
181
182fn collect_handler_hooks(root_fn: &mut syn::ItemFn) -> (Vec<Ident>, Vec<Ident>) {
183 enum HookKind {
184 Pre,
185 Post,
186 }
187 let mut hooks = vec![];
188 let mut new_attrs = Vec::with_capacity(root_fn.attrs.len());
189 for attr in root_fn.attrs.iter() {
190 match attr_last_ident(attr).as_deref() {
191 Some("preprocess") => {
192 hooks.extend(
193 parse_hook_attr_items(attr, "preprocess")
194 .into_iter()
195 .map(|item| (HookKind::Pre, item)),
196 );
197 }
198 Some("postprocess") => {
199 hooks.extend(
200 parse_hook_attr_items(attr, "postprocess")
201 .into_iter()
202 .map(|item| (HookKind::Post, item)),
203 );
204 }
205 _ => new_attrs.push(attr.clone()),
206 }
207 }
208 root_fn.attrs = new_attrs;
209 let mut preprocess_fns = vec![];
210 let mut postprocess_fns = vec![];
211 for (kind, hook) in hooks.into_iter() {
212 match kind {
213 HookKind::Pre => preprocess_fns.push(hook),
214 HookKind::Post => postprocess_fns.push(hook),
215 }
216 }
217 (preprocess_fns, postprocess_fns)
218}
219
220fn validate_preprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
221 if root_fn.sig.inputs.is_empty() || root_fn.sig.inputs.len() > 3 {
222 panic!("`preprocess` function must accept one to three arguments");
223 }
224 let mut arg_types = vec![];
225 for arg in root_fn.sig.inputs.iter() {
226 match arg {
227 syn::FnArg::Typed(arg) => {
228 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
229 }
230 _ => panic!("`preprocess` function does not support receiver argument"),
231 }
232 }
233 if arg_types[0] != "& mut HttpRequest" {
234 panic!(
235 "`preprocess` first argument type must be `&mut potato::HttpRequest`, got `{}`",
236 arg_types[0]
237 );
238 }
239
240 let has_once_cache = arg_types.iter().any(|t| t == "& mut OnceCache");
241 let has_session_cache = arg_types.iter().any(|t| t == "& mut SessionCache");
242
243 if arg_types.len() == 2 && !has_once_cache && !has_session_cache {
244 panic!(
245 "`preprocess` second argument type must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
246 arg_types[1]
247 );
248 }
249 if arg_types.len() == 3 {
250 if !has_once_cache {
251 panic!("`preprocess` must have `&mut potato::OnceCache` as one of the arguments");
252 }
253 if !has_session_cache {
254 panic!("`preprocess` must have `&mut potato::SessionCache` as one of the arguments");
255 }
256 }
257
258 let ret_type = root_fn
259 .sig
260 .output
261 .to_token_stream()
262 .to_string()
263 .type_simplify();
264 match &ret_type[..] {
265 "Result<Option<HttpResponse>>" | "Option<HttpResponse>" | "Result<()>" | "()" => {}
266 _ => panic!(
267 "unsupported `preprocess` return type: `{ret_type}`, expected `anyhow::Result<Option<potato::HttpResponse>>`, `Option<potato::HttpResponse>`, `anyhow::Result<()>`, or `()`"
268 ),
269 }
270 (ret_type, has_once_cache, has_session_cache)
271}
272
273fn validate_postprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
274 if root_fn.sig.inputs.len() < 2 && root_fn.sig.inputs.len() > 4 {
275 panic!("`postprocess` function must accept two to four arguments");
276 }
277 let mut arg_types = vec![];
278 for arg in root_fn.sig.inputs.iter() {
279 match arg {
280 syn::FnArg::Typed(arg) => {
281 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
282 }
283 _ => panic!("`postprocess` function does not support receiver argument"),
284 }
285 }
286 if arg_types[0] != "& mut HttpRequest" {
287 panic!(
288 "`postprocess` first argument must be `&mut potato::HttpRequest`, got `{}`",
289 arg_types[0]
290 );
291 }
292 if arg_types[1] != "& mut HttpResponse" {
293 panic!(
294 "`postprocess` second argument must be `&mut potato::HttpResponse`, got `{}`",
295 arg_types[1]
296 );
297 }
298
299 let remaining_args = &arg_types[2..];
300 let has_once_cache = remaining_args.iter().any(|t| t == "& mut OnceCache");
301 let has_session_cache = remaining_args.iter().any(|t| t == "& mut SessionCache");
302
303 if arg_types.len() == 3 && !has_once_cache && !has_session_cache {
304 panic!(
305 "`postprocess` third argument must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
306 arg_types[2]
307 );
308 }
309 if arg_types.len() == 4 && (!has_once_cache || !has_session_cache) {
310 panic!(
311 "`postprocess` with 4 arguments must have both `&mut potato::OnceCache` and `&mut potato::SessionCache`"
312 );
313 }
314
315 let ret_type = root_fn
316 .sig
317 .output
318 .to_token_stream()
319 .to_string()
320 .type_simplify();
321 match &ret_type[..] {
322 "Result<()>" | "()" => {}
323 _ => panic!(
324 "unsupported `postprocess` return type: `{ret_type}`, expected `anyhow::Result<()>` or `()`"
325 ),
326 }
327 (ret_type, has_once_cache, has_session_cache)
328}
329
330fn preprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
331 if !attr.is_empty() {
332 return input;
333 }
334 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
335 let fn_name = root_fn.sig.ident.clone();
336 let wrap_name = format_ident!("__potato_preprocess_adapter_{}", fn_name);
337 let wrap_name_inner = format_ident!("__potato_preprocess_adapter_inner_{}", fn_name);
338 let is_async = root_fn.sig.asyncness.is_some();
339 let (ret_type, has_once_cache, has_session_cache) = validate_preprocess_signature(&root_fn);
340
341 let wrap_signature = match (has_once_cache, has_session_cache) {
343 (true, true) => quote! {
344 async fn #wrap_name_inner(
345 req: &mut potato::HttpRequest,
346 once_cache: &mut potato::OnceCache,
347 session_cache: &mut potato::SessionCache,
348 ) -> anyhow::Result<Option<potato::HttpResponse>>
349 },
350 (true, false) => quote! {
351 async fn #wrap_name_inner(
352 req: &mut potato::HttpRequest,
353 once_cache: &mut potato::OnceCache,
354 ) -> anyhow::Result<Option<potato::HttpResponse>>
355 },
356 (false, true) => quote! {
357 async fn #wrap_name_inner(
358 req: &mut potato::HttpRequest,
359 session_cache: &mut potato::SessionCache,
360 ) -> anyhow::Result<Option<potato::HttpResponse>>
361 },
362 (false, false) => quote! {
363 async fn #wrap_name_inner(
364 req: &mut potato::HttpRequest,
365 ) -> anyhow::Result<Option<potato::HttpResponse>>
366 },
367 };
368
369 let call_body = if is_async {
371 match &ret_type[..] {
372 "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
373 (true, true) => {
374 quote! { #fn_name(req, once_cache, session_cache).await }
375 }
376 (true, false) => quote! { #fn_name(req, once_cache).await },
377 (false, true) => quote! { #fn_name(req, session_cache).await },
378 (false, false) => quote! { #fn_name(req).await },
379 },
380 "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
381 (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache).await) },
382 (true, false) => quote! { Ok(#fn_name(req, once_cache).await) },
383 (false, true) => quote! { Ok(#fn_name(req, session_cache).await) },
384 (false, false) => quote! { Ok(#fn_name(req).await) },
385 },
386 "Result<()>" => match (has_once_cache, has_session_cache) {
387 (true, true) => {
388 quote! { #fn_name(req, once_cache, session_cache).await.map(|_| None) }
389 }
390 (true, false) => quote! { #fn_name(req, once_cache).await.map(|_| None) },
391 (false, true) => quote! { #fn_name(req, session_cache).await.map(|_| None) },
392 (false, false) => quote! { #fn_name(req).await.map(|_| None) },
393 },
394 "()" => match (has_once_cache, has_session_cache) {
395 (true, true) => quote! { #fn_name(req, once_cache, session_cache).await; Ok(None) },
396 (true, false) => quote! { #fn_name(req, once_cache).await; Ok(None) },
397 (false, true) => quote! { #fn_name(req, session_cache).await; Ok(None) },
398 (false, false) => quote! { #fn_name(req).await; Ok(None) },
399 },
400 _ => unreachable!(),
401 }
402 } else {
403 match &ret_type[..] {
404 "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
405 (true, true) => quote! { #fn_name(req, once_cache, session_cache) },
406 (true, false) => quote! { #fn_name(req, once_cache) },
407 (false, true) => quote! { #fn_name(req, session_cache) },
408 (false, false) => quote! { #fn_name(req) },
409 },
410 "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
411 (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache)) },
412 (true, false) => quote! { Ok(#fn_name(req, once_cache)) },
413 (false, true) => quote! { Ok(#fn_name(req, session_cache)) },
414 (false, false) => quote! { Ok(#fn_name(req)) },
415 },
416 "Result<()>" => match (has_once_cache, has_session_cache) {
417 (true, true) => quote! { #fn_name(req, once_cache, session_cache).map(|_| None) },
418 (true, false) => quote! { #fn_name(req, once_cache).map(|_| None) },
419 (false, true) => quote! { #fn_name(req, session_cache).map(|_| None) },
420 (false, false) => quote! { #fn_name(req).map(|_| None) },
421 },
422 "()" => match (has_once_cache, has_session_cache) {
423 (true, true) => quote! { #fn_name(req, once_cache, session_cache); Ok(None) },
424 (true, false) => quote! { #fn_name(req, once_cache); Ok(None) },
425 (false, true) => quote! { #fn_name(req, session_cache); Ok(None) },
426 (false, false) => quote! { #fn_name(req); Ok(None) },
427 },
428 _ => unreachable!(),
429 }
430 };
431
432 let wrapper_body = match (has_once_cache, has_session_cache) {
434 (true, true) => quote! {
435 #wrap_name_inner(
436 req,
437 once_cache.expect("OnceCache required but not provided"),
438 session_cache.expect("SessionCache required but not provided"),
439 ).await
440 },
441 (true, false) => quote! {
442 #wrap_name_inner(
443 req,
444 once_cache.expect("OnceCache required but not provided"),
445 ).await
446 },
447 (false, true) => quote! {
448 #wrap_name_inner(
449 req,
450 session_cache.expect("SessionCache required but not provided"),
451 ).await
452 },
453 (false, false) => quote! {
454 #wrap_name_inner(req).await
455 },
456 };
457
458 quote! {
459 #root_fn
460
461 #[doc(hidden)]
462 #wrap_signature {
463 #call_body
464 }
465
466 #[doc(hidden)]
467 pub async fn #wrap_name(
468 req: &mut potato::HttpRequest,
469 once_cache: Option<&mut potato::OnceCache>,
470 session_cache: Option<&mut potato::SessionCache>,
471 ) -> anyhow::Result<Option<potato::HttpResponse>> {
472 #wrapper_body
473 }
474 }
475 .into()
476}
477
478fn postprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
479 if !attr.is_empty() {
480 return input;
481 }
482 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
483 let fn_name = root_fn.sig.ident.clone();
484 let wrap_name = format_ident!("__potato_postprocess_adapter_{}", fn_name);
485 let wrap_name_inner = format_ident!("__potato_postprocess_adapter_inner_{}", fn_name);
486 let is_async = root_fn.sig.asyncness.is_some();
487 let (ret_type, has_once_cache, has_session_cache) = validate_postprocess_signature(&root_fn);
488
489 let wrap_signature = match (has_once_cache, has_session_cache) {
491 (true, true) => quote! {
492 async fn #wrap_name_inner(
493 req: &mut potato::HttpRequest,
494 res: &mut potato::HttpResponse,
495 once_cache: &mut potato::OnceCache,
496 session_cache: &mut potato::SessionCache,
497 ) -> anyhow::Result<()>
498 },
499 (true, false) => quote! {
500 async fn #wrap_name_inner(
501 req: &mut potato::HttpRequest,
502 res: &mut potato::HttpResponse,
503 once_cache: &mut potato::OnceCache,
504 ) -> anyhow::Result<()>
505 },
506 (false, true) => quote! {
507 async fn #wrap_name_inner(
508 req: &mut potato::HttpRequest,
509 res: &mut potato::HttpResponse,
510 session_cache: &mut potato::SessionCache,
511 ) -> anyhow::Result<()>
512 },
513 (false, false) => quote! {
514 async fn #wrap_name_inner(
515 req: &mut potato::HttpRequest,
516 res: &mut potato::HttpResponse,
517 ) -> anyhow::Result<()>
518 },
519 };
520
521 let call_body = if is_async {
523 match &ret_type[..] {
524 "Result<()>" => {
525 if has_once_cache && has_session_cache {
526 quote! {
527 #fn_name(req, res, once_cache, session_cache).await
528 }
529 } else if has_once_cache {
530 quote! {
531 #fn_name(req, res, once_cache).await
532 }
533 } else if has_session_cache {
534 quote! {
535 #fn_name(req, res, session_cache).await
536 }
537 } else {
538 quote! {
539 #fn_name(req, res).await
540 }
541 }
542 }
543 "()" => {
544 if has_once_cache && has_session_cache {
545 quote! {
546 #fn_name(req, res, once_cache, session_cache).await;
547 Ok(())
548 }
549 } else if has_once_cache {
550 quote! {
551 #fn_name(req, res, once_cache).await;
552 Ok(())
553 }
554 } else if has_session_cache {
555 quote! {
556 #fn_name(req, res, session_cache).await;
557 Ok(())
558 }
559 } else {
560 quote! {
561 #fn_name(req, res).await;
562 Ok(())
563 }
564 }
565 }
566 _ => unreachable!(),
567 }
568 } else {
569 match &ret_type[..] {
570 "Result<()>" => {
571 if has_once_cache && has_session_cache {
572 quote! {
573 #fn_name(req, res, once_cache, session_cache)
574 }
575 } else if has_once_cache {
576 quote! {
577 #fn_name(req, res, once_cache)
578 }
579 } else if has_session_cache {
580 quote! {
581 #fn_name(req, res, session_cache)
582 }
583 } else {
584 quote! {
585 #fn_name(req, res)
586 }
587 }
588 }
589 "()" => {
590 if has_once_cache && has_session_cache {
591 quote! {
592 #fn_name(req, res, once_cache, session_cache);
593 Ok(())
594 }
595 } else if has_once_cache {
596 quote! {
597 #fn_name(req, res, once_cache);
598 Ok(())
599 }
600 } else if has_session_cache {
601 quote! {
602 #fn_name(req, res, session_cache);
603 Ok(())
604 }
605 } else {
606 quote! {
607 #fn_name(req, res);
608 Ok(())
609 }
610 }
611 }
612 _ => unreachable!(),
613 }
614 };
615
616 let wrapper_body = match (has_once_cache, has_session_cache) {
618 (true, true) => quote! {
619 #wrap_name_inner(
620 req,
621 res,
622 once_cache.expect("OnceCache required but not provided"),
623 session_cache.expect("SessionCache required but not provided"),
624 ).await
625 },
626 (true, false) => quote! {
627 #wrap_name_inner(
628 req,
629 res,
630 once_cache.expect("OnceCache required but not provided"),
631 ).await
632 },
633 (false, true) => quote! {
634 #wrap_name_inner(
635 req,
636 res,
637 session_cache.expect("SessionCache required but not provided"),
638 ).await
639 },
640 (false, false) => quote! {
641 #wrap_name_inner(req, res).await
642 },
643 };
644
645 quote! {
646 #root_fn
647
648 #[doc(hidden)]
649 #wrap_signature {
650 #call_body
651 }
652
653 #[doc(hidden)]
654 pub async fn #wrap_name(
655 req: &mut potato::HttpRequest,
656 res: &mut potato::HttpResponse,
657 once_cache: Option<&mut potato::OnceCache>,
658 session_cache: Option<&mut potato::SessionCache>,
659 ) -> anyhow::Result<()> {
660 #wrapper_body
661 }
662 }
663 .into()
664}
665
666fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
667 let req_name = Ident::new(req_name, Span::call_site());
668
669 let root_fn_for_check = syn::parse::<syn::ItemFn>(input.clone());
671 let has_receiver = if let Ok(ref func) = root_fn_for_check {
672 func.sig
673 .inputs
674 .iter()
675 .any(|arg| matches!(arg, syn::FnArg::Receiver(_)))
676 } else {
677 false
678 };
679
680 let (route_path, default_headers) = {
681 let mut oroute_path = syn::parse::<syn::LitStr>(attr.clone())
682 .ok()
683 .map(|path| path.value());
684 let mut default_headers: Vec<(String, String)> = Vec::new();
685 if oroute_path.is_none() {
687 let http_parser = syn::meta::parser(|meta| {
688 if meta.path.is_ident("path") {
689 if let Ok(arg) = meta.value() {
690 if let Ok(route_path) = arg.parse::<syn::LitStr>() {
691 let route_path = route_path.value();
692 oroute_path = Some(route_path);
693 }
694 }
695 Ok(())
696 } else if meta.path.is_ident("header") {
697 let content;
699 syn::parenthesized!(content in meta.input);
700 let key: Ident = content.parse()?;
701 let _: syn::Token![=] = content.parse()?;
702 let value: syn::LitStr = content.parse()?;
703 default_headers.push((key.to_string(), value.value()));
704 Ok(())
705 } else {
706 Err(meta.error("unsupported annotation property"))
707 }
708 });
709 syn::parse_macro_input!(attr with http_parser);
710 }
711
712 if oroute_path.is_none() && has_receiver {
714 } else if oroute_path.is_none() {
716 panic!("`path` argument is required for non-controller methods");
717 }
718
719 let route_path = oroute_path.unwrap_or_default();
720
721 let route_path = if has_receiver {
723 if route_path.is_empty() {
724 String::new()
726 } else {
727 route_path
730 }
731 } else {
732 if route_path.is_empty() {
733 panic!("`path` argument is required for non-controller methods");
734 }
735 route_path
736 };
737
738 if !route_path.is_empty() && !route_path.starts_with('/') {
739 panic!("route path must start with '/'");
740 }
741 (route_path, default_headers)
742 };
743
744 let mut root_fn = syn::parse_macro_input!(input as syn::ItemFn);
746 let mut fn_headers: Vec<(String, String)> = Vec::new();
747 let mut cors_config: Option<CorsAttrConfig> = None;
748 let mut max_concurrency: Option<usize> = None;
749 let mut remaining_attrs = Vec::new();
750
751 for attr in root_fn.attrs.iter() {
752 let is_header_attr = attr.path().is_ident("header")
754 || (attr.path().segments.len() == 2
755 && attr
756 .path()
757 .segments
758 .iter()
759 .next()
760 .map(|s| s.ident.to_string())
761 == Some("potato".to_string())
762 && attr
763 .path()
764 .segments
765 .iter()
766 .last()
767 .map(|s| s.ident.to_string())
768 == Some("header".to_string()));
769
770 if is_header_attr {
771 if let syn::Meta::List(meta_list) = &attr.meta {
772 if let Ok((key, value)) = parse_header_attr(&meta_list.tokens) {
774 fn_headers.push((key, value));
775 }
776 }
777 continue;
778 }
779
780 let is_cors_attr = attr.path().is_ident("cors")
782 || (attr.path().segments.len() == 2
783 && attr
784 .path()
785 .segments
786 .iter()
787 .next()
788 .map(|s| s.ident.to_string())
789 == Some("potato".to_string())
790 && attr
791 .path()
792 .segments
793 .iter()
794 .last()
795 .map(|s| s.ident.to_string())
796 == Some("cors".to_string()));
797
798 if is_cors_attr {
799 if let syn::Meta::List(meta_list) = &attr.meta {
800 cors_config = Some(parse_cors_attr(&meta_list.tokens));
801 } else {
802 cors_config = Some(CorsAttrConfig {
804 origin: None,
805 methods: None,
806 headers: None,
807 max_age: None,
808 credentials: false,
809 expose_headers: None,
810 });
811 }
812 continue;
813 }
814
815 let is_max_concurrency_attr = attr.path().is_ident("max_concurrency")
817 || (attr.path().segments.len() == 2
818 && attr
819 .path()
820 .segments
821 .iter()
822 .next()
823 .map(|s| s.ident.to_string())
824 == Some("potato".to_string())
825 && attr
826 .path()
827 .segments
828 .iter()
829 .last()
830 .map(|s| s.ident.to_string())
831 == Some("max_concurrency".to_string()));
832
833 if is_max_concurrency_attr {
834 if let syn::Meta::List(meta_list) = &attr.meta {
835 let tokens = &meta_list.tokens;
836 if let Ok(lit_int) = syn::parse2::<syn::LitInt>(tokens.clone()) {
838 if let Ok(val) = lit_int.base10_parse::<usize>() {
839 if val == 0 {
840 panic!("max_concurrency must be greater than 0");
841 }
842 max_concurrency = Some(val);
843 } else {
844 panic!("invalid max_concurrency value");
845 }
846 } else {
847 panic!(
848 "max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]"
849 );
850 }
851 } else if let syn::Meta::NameValue(name_value) = &attr.meta {
852 if let syn::Expr::Lit(expr_lit) = &name_value.value {
853 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
854 if let Ok(val) = lit_int.base10_parse::<usize>() {
855 if val == 0 {
856 panic!("max_concurrency must be greater than 0");
857 }
858 max_concurrency = Some(val);
859 } else {
860 panic!("invalid max_concurrency value");
861 }
862 } else {
863 panic!("max_concurrency requires a numeric value");
864 }
865 } else {
866 panic!("max_concurrency requires a numeric value");
867 }
868 } else {
869 panic!("max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]");
870 }
871 continue;
872 }
873
874 remaining_attrs.push(attr.clone());
875 }
876
877 let mut all_headers = default_headers;
879 all_headers.extend(fn_headers);
880
881 root_fn.attrs = remaining_attrs;
882 let (preprocess_fns, postprocess_fns) = collect_handler_hooks(&mut root_fn);
883
884 let handler_has_once_cache = root_fn.sig.inputs.iter().any(|arg| {
886 if let syn::FnArg::Typed(arg) = arg {
887 arg.ty.to_token_stream().to_string().type_simplify() == "& mut OnceCache"
888 } else {
889 false
890 }
891 });
892 let handler_has_session_cache = root_fn.sig.inputs.iter().any(|arg| {
893 if let syn::FnArg::Typed(arg) = arg {
894 arg.ty.to_token_stream().to_string().type_simplify() == "& mut SessionCache"
895 } else {
896 false
897 }
898 });
899
900 let need_once_cache = handler_has_once_cache;
905 let need_session_cache = handler_has_session_cache;
906
907 let preprocess_adapters: Vec<Ident> = preprocess_fns
908 .iter()
909 .map(|name| format_ident!("__potato_preprocess_adapter_{}", name))
910 .collect();
911 let postprocess_adapters: Vec<Ident> = postprocess_fns
912 .iter()
913 .map(|name| format_ident!("__potato_postprocess_adapter_{}", name))
914 .collect();
915 let doc_show = {
916 let mut doc_show = true;
917 for attr in root_fn.attrs.iter() {
918 if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
919 if let Ok(meta_list) = attr.meta.require_list() {
920 if meta_list.tokens.to_string() == "hidden" {
921 doc_show = false;
922 break;
923 }
924 }
925 }
926 }
927 doc_show
928 };
929 let doc_auth = need_session_cache;
930 let doc_summary = {
931 let mut docs = vec![];
932 for attr in root_fn.attrs.iter() {
933 if let Ok(attr) = attr.meta.require_name_value() {
934 if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
935 let mut doc = attr.value.to_token_stream().to_string();
936 if doc.starts_with('\"') {
937 doc.remove(0);
938 doc.pop();
939 }
940 docs.push(doc);
941 }
942 }
943 }
944 if docs.iter().all(|d| d.starts_with(' ')) {
945 for doc in docs.iter_mut() {
946 doc.remove(0);
947 }
948 }
949 docs.join("\n")
950 };
951 let doc_desp = "";
952 let fn_name = root_fn.sig.ident.clone();
953 let is_async = root_fn.sig.asyncness.is_some();
954
955 let has_receiver = root_fn
957 .sig
958 .inputs
959 .iter()
960 .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
961
962 let final_path = if has_receiver {
967 if route_path.is_empty() {
970 panic!("Controller methods must specify a path (e.g., #[potato::http_get(\"/\")])");
973 } else {
974 route_path
975 }
976 } else {
977 if route_path.is_empty() {
978 panic!("`path` argument is required for non-controller methods");
979 }
980 route_path
981 };
982
983 let final_path_expr = quote! { #final_path };
984
985 let tag_expr = if has_receiver {
987 quote! { __POTATO_CONTROLLER_NAME }
988 } else {
989 quote! { "" }
990 };
991
992 let wrap_func_name = random_ident();
993 let mut args = vec![];
994 let mut arg_names = vec![];
995 let mut arg_types = vec![];
996 let mut doc_args = vec![];
997 for arg in root_fn.sig.inputs.iter() {
998 if let syn::FnArg::Receiver(_receiver) = arg {
1000 continue;
1003 }
1004
1005 if let syn::FnArg::Typed(arg) = arg {
1006 let arg_type_str = arg
1007 .ty
1008 .as_ref()
1009 .to_token_stream()
1010 .to_string()
1011 .type_simplify();
1012 let arg_name_str = arg.pat.to_token_stream().to_string();
1013 let arg_value = match &arg_type_str[..] {
1014 "& mut HttpRequest" => quote! { req },
1015 "& mut OnceCache" => {
1016 quote! { __potato_once_cache.as_mut().expect("OnceCache not available") }
1017 }
1018 "& mut SessionCache" => {
1019 quote! { __potato_session_cache.as_mut().expect("SessionCache not available") }
1020 }
1021 "PostFile" => {
1022 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1023 quote! {
1024 match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
1025 Some(file) => file,
1026 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1027 }
1028 }
1029 }
1030 arg_type_str if ARG_TYPES.contains(arg_type_str) => {
1031 doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1032 let mut arg_value = quote! {
1033 match req.body_pairs
1034 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1035 .map(|p| p.to_string()) {
1036 Some(val) => val,
1037 None => match req.url_query
1038 .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1039 .map(|p| p.as_str().to_string()) {
1040 Some(val) => val,
1041 None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1042 },
1043 }
1044 };
1045 if arg_type_str != "String" {
1046 arg_value = quote! {
1047 match #arg_value.parse() {
1048 Ok(val) => val,
1049 Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
1050 }
1051 }
1052 }
1053 arg_value
1054 }
1055 _ => panic!("unsupported arg type: [{arg_type_str}]"),
1056 };
1057 args.push(arg_value);
1058 arg_names.push(random_ident());
1059 arg_types.push(arg_type_str);
1061 }
1062 }
1063 let wrap_func_name2 = random_ident();
1064 let ret_type = root_fn
1065 .sig
1066 .output
1067 .to_token_stream()
1068 .to_string()
1069 .type_simplify();
1070
1071 let _controller_create_fn = if has_receiver {
1074 quote! {
1075 }
1078 } else {
1079 quote! {}
1080 };
1081
1082 let call_args: Vec<_> = args
1084 .iter()
1085 .enumerate()
1086 .map(|(i, _arg)| {
1087 let arg_name = &arg_names[i];
1088 let arg_type = &arg_types[i];
1089 if arg_type == "& mut HttpRequest" {
1091 quote! { req }
1092 } else {
1093 quote! { #arg_name }
1094 }
1095 })
1096 .collect();
1097
1098 let call_expr = if has_receiver {
1099 match args.len() {
1103 0 => quote! { #fn_name() },
1104 1 => {
1105 let arg_name = &arg_names[0];
1106 let arg = &args[0];
1107 let arg_type = &arg_types[0];
1108 if arg_type == "& mut HttpRequest" {
1109 quote! { #fn_name(req) }
1110 } else {
1111 quote! {{
1112 let #arg_name = #arg;
1113 #fn_name(#arg_name)
1114 }}
1115 }
1116 }
1117 _ => {
1118 let let_bindings: Vec<_> = arg_types
1119 .iter()
1120 .zip(arg_names.iter())
1121 .zip(args.iter())
1122 .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1123 .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1124 .collect();
1125
1126 quote! {{
1127 #(#let_bindings)*
1128 #fn_name(#(#call_args),*)
1129 }}
1130 }
1131 }
1132 } else {
1133 match args.len() {
1135 0 => quote! { #fn_name() },
1136 1 => {
1137 let arg_name = &arg_names[0];
1138 let arg = &args[0];
1139 let arg_type = &arg_types[0];
1140 if arg_type == "& mut HttpRequest" {
1142 quote! { #fn_name(req) }
1143 } else {
1144 quote! {{
1145 let #arg_name = #arg;
1146 #fn_name(#arg_name)
1147 }}
1148 }
1149 }
1150 _ => {
1151 let let_bindings: Vec<_> = arg_types
1153 .iter()
1154 .zip(arg_names.iter())
1155 .zip(args.iter())
1156 .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1157 .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1158 .collect();
1159
1160 quote! {{
1161 #(#let_bindings)*
1162 #fn_name(#(#call_args),*)
1163 }}
1164 }
1165 }
1166 };
1167 let handler_wrap_func_body = if is_async {
1168 match &ret_type[..] {
1169 "Result<()>" => quote! {
1170 match #call_expr.await {
1171 Ok(_) => Ok(potato::HttpResponse::text("ok")),
1172 Err(err) => Err(err),
1173 }
1174 },
1175 "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1176 match #call_expr.await {
1177 Ok(ret) => Ok(ret),
1178 Err(err) => Err(err),
1179 }
1180 },
1181 "Result<String>" | "anyhow::Result<String>" => quote! {
1182 match #call_expr.await {
1183 Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1184 Err(err) => Err(err),
1185 }
1186 },
1187 "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1188 match #call_expr.await {
1189 Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1190 Err(err) => Err(err),
1191 }
1192 },
1193 "()" => quote! {
1194 #call_expr.await;
1195 Ok(potato::HttpResponse::text("ok"))
1196 },
1197 "HttpResponse" => quote! {
1198 Ok(#call_expr.await)
1199 },
1200 "String" => quote! {
1201 Ok(potato::HttpResponse::html(#call_expr.await))
1202 },
1203 "& 'static str" => quote! {
1204 Ok(potato::HttpResponse::html(#call_expr.await))
1205 },
1206 _ => panic!("unsupported ret type: {ret_type}"),
1207 }
1208 } else {
1209 match &ret_type[..] {
1210 "Result<()>" => quote! {
1211 match #call_expr {
1212 Ok(_) => Ok(potato::HttpResponse::text("ok")),
1213 Err(err) => Err(err),
1214 }
1215 },
1216 "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1217 match #call_expr {
1218 Ok(ret) => Ok(ret),
1219 Err(err) => Err(err),
1220 }
1221 },
1222 "Result<String>" | "anyhow::Result<String>" => quote! {
1223 match #call_expr {
1224 Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1225 Err(err) => Err(err),
1226 }
1227 },
1228 "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1229 match #call_expr {
1230 Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1231 Err(err) => Err(err),
1232 }
1233 },
1234 "()" => quote! {
1235 #call_expr;
1236 Ok(potato::HttpResponse::text("ok"))
1237 },
1238 "HttpResponse" => quote! {
1239 Ok(#call_expr)
1240 },
1241 "String" => quote! {
1242 Ok(potato::HttpResponse::html(#call_expr))
1243 },
1244 "& 'static str" => quote! {
1245 Ok(potato::HttpResponse::html(#call_expr))
1246 },
1247 _ => panic!("unsupported ret type: {ret_type}"),
1248 }
1249 };
1250 let doc_args = serde_json::to_string(&doc_args).unwrap();
1251
1252 let add_headers_code = if all_headers.is_empty() {
1254 quote! {}
1255 } else {
1256 let header_statements = all_headers.iter().map(|(key, value)| {
1257 let http_key = key.replace("_", "-");
1259 quote! {
1260 __potato_response.add_header(
1261 std::borrow::Cow::Borrowed(#http_key),
1262 std::borrow::Cow::Borrowed(#value)
1263 );
1264 }
1265 });
1266 quote! {
1267 #(#header_statements)*
1268 }
1269 };
1270
1271 let cors_headers_code = if let Some(cors) = &cors_config {
1273 let mut statements = vec![];
1274
1275 let origin_val = cors.origin.as_deref().unwrap_or("*");
1277 statements.push(quote! {
1278 __potato_response.add_header(
1279 "Access-Control-Allow-Origin".into(),
1280 #origin_val.into()
1281 );
1282 });
1283
1284 if let Some(ref methods) = cors.methods {
1286 let mut methods_list: Vec<&str> = methods.split(',').map(|s| s.trim()).collect();
1287 if !methods_list.contains(&"HEAD") {
1288 methods_list.push("HEAD");
1289 }
1290 if !methods_list.contains(&"OPTIONS") {
1291 methods_list.push("OPTIONS");
1292 }
1293 let methods_str = methods_list.join(",");
1294 statements.push(quote! {
1295 __potato_response.add_header(
1296 "Access-Control-Allow-Methods".into(),
1297 #methods_str.into()
1298 );
1299 });
1300 }
1301
1302 let headers_val = cors.headers.as_deref().unwrap_or("*");
1304 statements.push(quote! {
1305 __potato_response.add_header(
1306 "Access-Control-Allow-Headers".into(),
1307 #headers_val.into()
1308 );
1309 });
1310
1311 if let Some(ref max_age) = cors.max_age {
1313 statements.push(quote! {
1314 __potato_response.add_header(
1315 "Access-Control-Max-Age".into(),
1316 #max_age.into()
1317 );
1318 });
1319 } else {
1320 statements.push(quote! {
1321 __potato_response.add_header(
1322 "Access-Control-Max-Age".into(),
1323 "86400".into()
1324 );
1325 });
1326 }
1327
1328 if cors.credentials {
1329 statements.push(quote! {
1330 __potato_response.add_header(
1331 "Access-Control-Allow-Credentials".into(),
1332 "true".into()
1333 );
1334 });
1335 }
1336
1337 if let Some(ref expose_headers) = cors.expose_headers {
1338 statements.push(quote! {
1339 __potato_response.add_header(
1340 "Access-Control-Expose-Headers".into(),
1341 #expose_headers.into()
1342 );
1343 });
1344 }
1345
1346 quote! { #(#statements)* }
1347 } else {
1348 quote! {}
1349 };
1350
1351 let auto_head_handler = if cors_config.is_some()
1353 && (req_name == "POST" || req_name == "PUT" || req_name == "DELETE")
1354 {
1355 let head_wrap_name = format_ident!("__potato_cors_head_{}", fn_name);
1356 Some(quote! {
1357 #[doc(hidden)]
1358 fn #head_wrap_name(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1359 potato::HttpResponse::html("")
1362 }
1363 })
1364 } else {
1365 None
1366 };
1367
1368 let semaphore_static = if let Some(max_conn) = max_concurrency {
1370 let semaphore_name =
1371 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1372 Some(quote! {
1373 #[doc(hidden)]
1374 #[allow(non_upper_case_globals)]
1375 static #semaphore_name: std::sync::LazyLock<tokio::sync::Semaphore> =
1376 std::sync::LazyLock::new(|| tokio::sync::Semaphore::new(#max_conn));
1377 })
1378 } else {
1379 None
1380 };
1381
1382 let wrap_func_body = if is_async {
1383 if max_concurrency.is_some() {
1384 let semaphore_name =
1385 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1386 quote! {
1387 let __potato_permit = #semaphore_name.acquire().await;
1388
1389 let __potato_error_handler: Option<potato::ErrorHandler> = {
1391 let mut handler = None;
1392 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1393 handler = Some(flag.handler.clone());
1394 break;
1395 }
1396 handler
1397 };
1398
1399 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1401 Some(potato::OnceCache::new())
1402 } else {
1403 None
1404 };
1405 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1406 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1408 let header_value = h.as_str();
1409 if header_value.starts_with("Bearer ") {
1410 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1411 } else {
1412 None
1413 }
1414 } else {
1415 None
1416 }
1417 } else {
1418 None
1419 };
1420
1421 if #need_session_cache && __potato_session_cache.is_none() {
1423 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1424 __potato_resp.http_code = 401;
1425 return __potato_resp;
1426 }
1427
1428 if let Some(ref mut session_cache) = __potato_session_cache {
1430 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1431 session_cache.parse_request_cookies(cookie_header.as_str());
1432 }
1433 }
1434
1435 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1436 #(
1437 if __potato_pre_response.is_none() {
1438 __potato_pre_response = match #preprocess_adapters(
1439 req,
1440 __potato_once_cache.as_mut(),
1441 __potato_session_cache.as_mut(),
1442 ).await {
1443 Ok(Some(ret)) => Some(ret),
1444 Ok(None) => None,
1445 Err(err) => {
1446 let handler = &__potato_error_handler;
1447 Some(match handler {
1448 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1449 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1450 None => potato::HttpResponse::error(format!("{err:?}")),
1451 })
1452 }
1453 };
1454 }
1455 )*
1456
1457 let mut __potato_response = match __potato_pre_response {
1458 Some(ret) => ret,
1459 None => match #handler_wrap_func_body {
1460 Ok(resp) => resp,
1461 Err(err) => {
1462 let handler = &__potato_error_handler;
1463 match handler {
1464 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1465 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1466 None => potato::HttpResponse::error(format!("{err:?}")),
1467 }
1468 }
1469 },
1470 };
1471
1472 #(
1473 if let Err(err) = #postprocess_adapters(
1474 req,
1475 &mut __potato_response,
1476 __potato_once_cache.as_mut(),
1477 __potato_session_cache.as_mut(),
1478 ).await {
1479 drop(__potato_permit);
1480 let handler = &__potato_error_handler;
1481 return match handler {
1482 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1483 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1484 None => potato::HttpResponse::error(format!("{err:?}")),
1485 };
1486 }
1487 )*
1488
1489 #add_headers_code
1490 #cors_headers_code
1491
1492 if let Some(ref session_cache) = __potato_session_cache {
1494 session_cache.apply_cookies(&mut __potato_response);
1495 }
1496
1497 drop(__potato_permit);
1498 __potato_response
1499 }
1500 } else {
1501 quote! {
1502 let __potato_error_handler: Option<potato::ErrorHandler> = {
1504 let mut handler = None;
1505 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1506 handler = Some(flag.handler.clone());
1507 break;
1508 }
1509 handler
1510 };
1511
1512 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1514 Some(potato::OnceCache::new())
1515 } else {
1516 None
1517 };
1518 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1519 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1521 let header_value = h.as_str();
1522 if header_value.starts_with("Bearer ") {
1523 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1524 } else {
1525 None
1526 }
1527 } else {
1528 None
1529 }
1530 } else {
1531 None
1532 };
1533
1534 if #need_session_cache && __potato_session_cache.is_none() {
1536 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1537 __potato_resp.http_code = 401;
1538 return __potato_resp;
1539 }
1540
1541 if let Some(ref mut session_cache) = __potato_session_cache {
1543 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1544 session_cache.parse_request_cookies(cookie_header.as_str());
1545 }
1546 }
1547
1548 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1549 #(
1550 if __potato_pre_response.is_none() {
1551 __potato_pre_response = match #preprocess_adapters(
1552 req,
1553 __potato_once_cache.as_mut(),
1554 __potato_session_cache.as_mut(),
1555 ).await {
1556 Ok(Some(ret)) => Some(ret),
1557 Ok(None) => None,
1558 Err(err) => {
1559 let handler = &__potato_error_handler;
1560 Some(match handler {
1561 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1562 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1563 None => potato::HttpResponse::error(format!("{err:?}")),
1564 })
1565 }
1566 };
1567 }
1568 )*
1569
1570 let mut __potato_response = match __potato_pre_response {
1571 Some(ret) => ret,
1572 None => match #handler_wrap_func_body {
1573 Ok(resp) => resp,
1574 Err(err) => {
1575 let handler = &__potato_error_handler;
1576 match handler {
1577 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1578 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1579 None => potato::HttpResponse::error(format!("{err:?}")),
1580 }
1581 }
1582 },
1583 };
1584
1585 #(
1586 if let Err(err) = #postprocess_adapters(
1587 req,
1588 &mut __potato_response,
1589 __potato_once_cache.as_mut(),
1590 __potato_session_cache.as_mut(),
1591 ).await {
1592 let handler = &__potato_error_handler;
1593 return match handler {
1594 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1595 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1596 None => potato::HttpResponse::error(format!("{err:?}")),
1597 };
1598 }
1599 )*
1600
1601 #add_headers_code
1602 #cors_headers_code
1603
1604 __potato_response
1605 }
1606 }
1607 } else {
1608 if max_concurrency.is_some() {
1609 let semaphore_name =
1610 format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1611 quote! {
1612 let __potato_permit = #semaphore_name.acquire().await;
1613
1614 let __potato_error_handler: Option<potato::ErrorHandler> = {
1616 let mut handler = None;
1617 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1618 handler = Some(flag.handler.clone());
1619 break;
1620 }
1621 handler
1622 };
1623
1624 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1626 Some(potato::OnceCache::new())
1627 } else {
1628 None
1629 };
1630 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1631 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1633 let header_value = h.as_str();
1634 if header_value.starts_with("Bearer ") {
1635 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1636 } else {
1637 None
1638 }
1639 } else {
1640 None
1641 }
1642 } else {
1643 None
1644 };
1645
1646 if #need_session_cache && __potato_session_cache.is_none() {
1648 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1649 __potato_resp.http_code = 401;
1650 return __potato_resp;
1651 }
1652
1653 if let Some(ref mut session_cache) = __potato_session_cache {
1655 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1656 session_cache.parse_request_cookies(cookie_header.as_str());
1657 }
1658 }
1659
1660 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1661 #(
1662 if __potato_pre_response.is_none() {
1663 __potato_pre_response = match #preprocess_adapters(
1664 req,
1665 __potato_once_cache.as_mut(),
1666 __potato_session_cache.as_mut(),
1667 ).await {
1668 Ok(Some(ret)) => Some(ret),
1669 Ok(None) => None,
1670 Err(err) => {
1671 let handler = &__potato_error_handler;
1672 Some(match handler {
1673 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1674 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1675 None => potato::HttpResponse::error(format!("{err:?}")),
1676 })
1677 }
1678 };
1679 }
1680 )*
1681
1682 let mut __potato_response = match __potato_pre_response {
1683 Some(ret) => ret,
1684 None => match #handler_wrap_func_body {
1685 Ok(resp) => resp,
1686 Err(err) => {
1687 let handler = &__potato_error_handler;
1688 match handler {
1689 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1690 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1691 None => potato::HttpResponse::error(format!("{err:?}")),
1692 }
1693 }
1694 },
1695 };
1696
1697 #(
1698 if let Err(err) = #postprocess_adapters(
1699 req,
1700 &mut __potato_response,
1701 __potato_once_cache.as_mut(),
1702 __potato_session_cache.as_mut(),
1703 ).await {
1704 drop(__potato_permit);
1705 let handler = &__potato_error_handler;
1706 return match handler {
1707 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1708 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1709 None => potato::HttpResponse::error(format!("{err:?}")),
1710 };
1711 }
1712 )*
1713
1714 #add_headers_code
1715 #cors_headers_code
1716
1717 if let Some(ref session_cache) = __potato_session_cache {
1719 session_cache.apply_cookies(&mut __potato_response);
1720 }
1721
1722 drop(__potato_permit);
1723 __potato_response
1724 }
1725 } else {
1726 quote! {
1727 let __potato_error_handler: Option<potato::ErrorHandler> = {
1729 let mut handler = None;
1730 for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1731 handler = Some(flag.handler.clone());
1732 break;
1733 }
1734 handler
1735 };
1736
1737 let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1739 Some(potato::OnceCache::new())
1740 } else {
1741 None
1742 };
1743 let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1744 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1746 let header_value = h.as_str();
1747 if header_value.starts_with("Bearer ") {
1748 potato::SessionCache::from_token(&header_value[7..]).await.ok()
1749 } else {
1750 None
1751 }
1752 } else {
1753 None
1754 }
1755 } else {
1756 None
1757 };
1758
1759 if #need_session_cache && __potato_session_cache.is_none() {
1761 let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1762 __potato_resp.http_code = 401;
1763 return __potato_resp;
1764 }
1765
1766 if let Some(ref mut session_cache) = __potato_session_cache {
1768 if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1769 session_cache.parse_request_cookies(cookie_header.as_str());
1770 }
1771 }
1772
1773 let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1774 #(
1775 if __potato_pre_response.is_none() {
1776 __potato_pre_response = match #preprocess_adapters(
1777 req,
1778 __potato_once_cache.as_mut(),
1779 __potato_session_cache.as_mut(),
1780 ).await {
1781 Ok(Some(ret)) => Some(ret),
1782 Ok(None) => None,
1783 Err(err) => {
1784 let handler = &__potato_error_handler;
1785 Some(match handler {
1786 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1787 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1788 None => potato::HttpResponse::error(format!("{err:?}")),
1789 })
1790 }
1791 };
1792 }
1793 )*
1794
1795 let mut __potato_response = match __potato_pre_response {
1796 Some(ret) => ret,
1797 None => match #handler_wrap_func_body {
1798 Ok(resp) => resp,
1799 Err(err) => {
1800 let handler = &__potato_error_handler;
1801 match handler {
1802 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1803 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1804 None => potato::HttpResponse::error(format!("{err:?}")),
1805 }
1806 }
1807 },
1808 };
1809
1810 #(
1811 if let Err(err) = #postprocess_adapters(
1812 req,
1813 &mut __potato_response,
1814 __potato_once_cache.as_mut(),
1815 __potato_session_cache.as_mut(),
1816 ).await {
1817 let handler = &__potato_error_handler;
1818 return match handler {
1819 Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1820 Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1821 None => potato::HttpResponse::error(format!("{err:?}")),
1822 };
1823 }
1824 )*
1825
1826 #add_headers_code
1827 #cors_headers_code
1828
1829 if let Some(ref session_cache) = __potato_session_cache {
1831 session_cache.apply_cookies(&mut __potato_response);
1832 }
1833
1834 __potato_response
1835 }
1836 }
1837 };
1838
1839 if is_async {
1840 quote! {
1841 #root_fn
1842
1843 #auto_head_handler
1844
1845 #semaphore_static
1846
1847 #[doc(hidden)]
1848 async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1849 #wrap_func_body
1850 }
1851
1852 #[doc(hidden)]
1853 fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
1854 Box::pin(#wrap_func_name2(req))
1855 }
1856
1857 potato::inventory::submit!{potato::RequestHandlerFlag::new(
1858 potato::HttpMethod::#req_name,
1859 #final_path_expr,
1860 potato::HttpHandler::Async(#wrap_func_name),
1861 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1862 )}
1863 }
1864 .into()
1865 } else {
1866 quote! {
1867 #root_fn
1868
1869 #auto_head_handler
1870
1871 #semaphore_static
1872
1873 #[doc(hidden)]
1874 async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1875 #wrap_func_body
1876 }
1877
1878 #[doc(hidden)]
1879 fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
1880 Box::pin(#wrap_func_name2(req))
1881 }
1882
1883 potato::inventory::submit!{potato::RequestHandlerFlag::new(
1884 potato::HttpMethod::#req_name,
1885 #final_path_expr,
1886 potato::HttpHandler::Async(#wrap_func_name),
1887 potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1888 )}
1889 }
1890 .into()
1891 }
1892 }
1896
1897#[proc_macro_attribute]
1898pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
1899 http_handler_macro(attr, input, "GET")
1900}
1901
1902#[proc_macro_attribute]
1903pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
1904 http_handler_macro(attr, input, "POST")
1905}
1906
1907#[proc_macro_attribute]
1908pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
1909 http_handler_macro(attr, input, "PUT")
1910}
1911
1912#[proc_macro_attribute]
1913pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
1914 http_handler_macro(attr, input, "DELETE")
1915}
1916
1917#[proc_macro_attribute]
1918pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
1919 http_handler_macro(attr, input, "OPTIONS")
1920}
1921
1922#[proc_macro_attribute]
1923pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
1924 http_handler_macro(attr, input, "HEAD")
1925}
1926
1927#[proc_macro_attribute]
1961pub fn controller(attr: TokenStream, input: TokenStream) -> TokenStream {
1962 controller_macro(attr, input)
1963}
1964
1965fn controller_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
1966 let input_clone = input.clone();
1968 if let Ok(item_impl) = syn::parse::<syn::ItemImpl>(input_clone) {
1969 return controller_impl_macro(attr, item_impl);
1971 }
1972
1973 let item_struct = syn::parse_macro_input!(input as syn::ItemStruct);
1975
1976 let base_path = if attr.is_empty() {
1978 quote! {}
1980 } else {
1981 let attr_str = attr.to_string();
1982 let base_path = attr_str.trim_matches('"').to_string();
1983 quote! {
1984 #[doc(hidden)]
1985 const __POTATO_CONTROLLER_BASE_PATH: &str = #base_path;
1986 }
1987 };
1988
1989 let (has_once_cache, has_session_cache) = validate_controller_struct(&item_struct);
1991 let struct_name = &item_struct.ident;
1992 let struct_name_str = struct_name.to_string();
1993
1994 let controller_creation_fn = if has_session_cache {
1997 quote! {
2000 #[doc(hidden)]
2001 #[allow(dead_code)]
2002 async fn __potato_create_controller(req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2003 let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2005
2006 let session_cache = {
2008 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2009 let header_value = h.as_str();
2010 if header_value.starts_with("Bearer ") {
2011 potato::SessionCache::from_token(&header_value[7..]).await.ok()
2012 } else {
2013 None
2014 }
2015 } else {
2016 None
2017 }
2018 };
2019
2020 let session_cache = match session_cache {
2021 Some(cache) => cache,
2022 None => {
2023 let mut resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
2024 resp.http_code = 401;
2025 return Err(resp);
2026 }
2027 };
2028 let session_cache = Box::leak(Box::new(session_cache));
2029
2030 let controller = Self {
2032 once_cache,
2033 sess_cache: session_cache,
2034 };
2035
2036 Ok(Box::new(controller))
2037 }
2038 }
2039 } else {
2040 quote! {
2043 #[doc(hidden)]
2044 #[allow(dead_code)]
2045 async fn __potato_create_controller(_req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2046 let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2048
2049 let _temp_session_cache = Box::leak(Box::new(potato::SessionCache::new()));
2051
2052 let controller = Self {
2054 once_cache,
2055 };
2056
2057 Ok(Box::new(controller))
2058 }
2059 }
2060 };
2061
2062 let struct_generics = &item_struct.generics;
2064 let (impl_generics, type_generics, where_clause) = struct_generics.split_for_impl();
2065
2066 let output = quote! {
2067 #item_struct
2068
2069 #base_path
2070
2071 #[doc(hidden)]
2072 const __POTATO_CONTROLLER_NAME: &str = #struct_name_str;
2073
2074 potato::inventory::submit! {
2076 potato::ControllerStructFlag::new(
2077 #struct_name_str,
2078 potato::ControllerStructFieldInfo {
2079 has_once_cache: #has_once_cache,
2080 has_session_cache: #has_session_cache,
2081 }
2082 )
2083 }
2084
2085 impl #impl_generics #struct_name #type_generics #where_clause {
2087 #controller_creation_fn
2088 }
2089 };
2090
2091 output.into()
2092}
2093
2094fn controller_impl_macro(attr: TokenStream, item_impl: syn::ItemImpl) -> TokenStream {
2096 let base_path_str = if attr.is_empty() {
2098 None
2101 } else {
2102 let attr_str = attr.to_string();
2103 Some(attr_str.trim_matches('"').to_string())
2104 };
2105
2106 let self_type = &item_impl.self_ty;
2108
2109 let self_type_name = match &*item_impl.self_ty {
2111 syn::Type::Path(type_path) => {
2112 if let Some(segment) = type_path.path.segments.last() {
2114 let ident = &segment.ident;
2115 quote! { #ident }
2116 } else {
2117 quote! { #self_type }
2118 }
2119 }
2120 _ => quote! { #self_type },
2121 };
2122
2123 let self_type_tag = match &*item_impl.self_ty {
2125 syn::Type::Path(type_path) => {
2126 if let Some(segment) = type_path.path.segments.last() {
2128 segment.ident.to_string()
2129 } else {
2130 self_type.to_token_stream().to_string()
2131 }
2132 }
2133 _ => self_type.to_token_stream().to_string(),
2134 };
2135
2136 let mut cleaned_items = Vec::new();
2138 let mut generated_code = Vec::new();
2139
2140 for item in &item_impl.items {
2141 if let syn::ImplItem::Fn(method) = item {
2142 let has_http_attr = method.attrs.iter().any(|attr| {
2144 let attr_name = attr.path().to_token_stream().to_string();
2145 attr_name.contains("http_get")
2146 || attr_name.contains("http_post")
2147 || attr_name.contains("http_put")
2148 || attr_name.contains("http_delete")
2149 || attr_name.contains("http_head")
2150 || attr_name.contains("http_patch")
2151 || attr_name.contains("http_options")
2152 });
2153
2154 if has_http_attr {
2155 let mut cleaned_method = method.clone();
2157 cleaned_method.attrs = method
2158 .attrs
2159 .iter()
2160 .filter(|attr| {
2161 let attr_name = attr.path().to_token_stream().to_string();
2162 !attr_name.contains("http_get")
2163 && !attr_name.contains("http_post")
2164 && !attr_name.contains("http_put")
2165 && !attr_name.contains("http_delete")
2166 && !attr_name.contains("http_head")
2167 && !attr_name.contains("http_patch")
2168 && !attr_name.contains("http_options")
2169 })
2170 .cloned()
2171 .collect();
2172
2173 cleaned_items.push(syn::ImplItem::Fn(cleaned_method));
2174
2175 for attr in &method.attrs {
2177 let attr_name = attr.path().to_token_stream().to_string();
2178 if attr_name.contains("http_get")
2179 || attr_name.contains("http_post")
2180 || attr_name.contains("http_put")
2181 || attr_name.contains("http_delete")
2182 || attr_name.contains("http_head")
2183 || attr_name.contains("http_patch")
2184 || attr_name.contains("http_options")
2185 {
2186 let http_method = if attr_name.contains("http_get") {
2188 "GET"
2189 } else if attr_name.contains("http_post") {
2190 "POST"
2191 } else if attr_name.contains("http_put") {
2192 "PUT"
2193 } else if attr_name.contains("http_delete") {
2194 "DELETE"
2195 } else if attr_name.contains("http_head") {
2196 "HEAD"
2197 } else if attr_name.contains("http_patch") {
2198 "PATCH"
2199 } else {
2200 "OPTIONS"
2201 };
2202
2203 let method_path = match &attr.meta {
2205 syn::Meta::List(list) => {
2206 if let Ok(lit_str) =
2207 syn::parse::<syn::LitStr>(list.tokens.clone().into())
2208 {
2209 lit_str.value()
2210 } else {
2211 String::new()
2212 }
2213 }
2214 _ => String::new(),
2215 };
2216
2217 let final_path = if let Some(ref base_path) = base_path_str {
2219 if method_path.is_empty() {
2221 base_path.clone()
2222 } else {
2223 format!("{}{}", base_path, method_path)
2224 }
2225 } else {
2226 panic!("impl block controller must specify a base path, e.g., #[potato::controller(\"/api/users\")]");
2228 };
2229
2230 let final_path_lit =
2232 syn::LitStr::new(&final_path, proc_macro2::Span::call_site());
2233
2234 let fn_name = &method.sig.ident;
2236 let wrapper_fn_name = quote::format_ident!("__potato_ctrl_{}", fn_name);
2237 let is_async = method.sig.asyncness.is_some();
2238
2239 let (has_receiver, _is_mut_receiver) = method
2241 .sig
2242 .inputs
2243 .iter()
2244 .filter_map(|arg| {
2245 if let syn::FnArg::Receiver(recv) = arg {
2246 Some((true, recv.mutability.is_some()))
2247 } else {
2248 None
2249 }
2250 })
2251 .next()
2252 .unwrap_or((false, false));
2253
2254 let other_params: Vec<_> = method
2256 .sig
2257 .inputs
2258 .iter()
2259 .filter_map(|arg| {
2260 if let syn::FnArg::Typed(pat_type) = arg {
2261 Some(pat_type.clone())
2262 } else {
2263 None
2264 }
2265 })
2266 .collect();
2267
2268 let method_has_session_cache = other_params.iter().any(|pat_type| {
2270 pat_type.ty.to_token_stream().to_string().type_simplify()
2271 == "& mut SessionCache"
2272 });
2273
2274 let doc_auth = has_receiver || method_has_session_cache;
2276
2277 let param_names: Vec<_> = other_params
2279 .iter()
2280 .filter_map(|pat_type| {
2281 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2282 Some(pat_ident.ident.clone())
2283 } else {
2284 None
2285 }
2286 })
2287 .collect();
2288
2289 let method_call = if has_receiver {
2291 if param_names.is_empty() {
2296 if is_async {
2297 quote! {
2298 {
2299 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2300 Ok(boxed) => boxed,
2301 Err(resp) => return resp,
2302 };
2303 controller.#fn_name().await
2304 }
2305 }
2306 } else {
2307 quote! {
2308 {
2309 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2310 Ok(boxed) => boxed,
2311 Err(resp) => return resp,
2312 };
2313 controller.#fn_name()
2314 }
2315 }
2316 }
2317 } else {
2318 if is_async {
2319 quote! {
2320 {
2321 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2322 Ok(boxed) => boxed,
2323 Err(resp) => return resp,
2324 };
2325 controller.#fn_name(#(#param_names),*).await
2326 }
2327 }
2328 } else {
2329 quote! {
2330 {
2331 let mut controller = match #self_type_name::__potato_create_controller(req).await {
2332 Ok(boxed) => boxed,
2333 Err(resp) => return resp,
2334 };
2335 controller.#fn_name(#(#param_names),*)
2336 }
2337 }
2338 }
2339 }
2340 } else {
2341 if param_names.is_empty() {
2345 if is_async {
2346 quote! { #self_type_name::#fn_name().await }
2347 } else {
2348 quote! { #self_type_name::#fn_name() }
2349 }
2350 } else {
2351 let mut param_bindings = Vec::new();
2354 for (i, param) in other_params.iter().enumerate() {
2355 let param_type_str =
2356 param.ty.to_token_stream().to_string().type_simplify();
2357 let param_name = ¶m_names[i];
2358
2359 match ¶m_type_str[..] {
2360 "& mut OnceCache" => {
2361 param_bindings.push(quote! {
2362 let #param_name = &mut __potato_once_cache;
2363 });
2364 }
2365 "& mut SessionCache" => {
2366 param_bindings.push(quote! {
2367 let #param_name = &mut __potato_session_cache;
2368 });
2369 }
2370 _ => {
2371 }
2373 }
2374 }
2375
2376 let needs_session_cache = other_params.iter().any(|p| {
2378 p.ty.to_token_stream().to_string().type_simplify()
2379 == "& mut SessionCache"
2380 });
2381
2382 let session_cache_init = if needs_session_cache {
2383 quote! {
2384 {
2385 if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2387 let header_value = h.as_str();
2388 if header_value.starts_with("Bearer ") {
2389 potato::SessionCache::from_token(&header_value[7..]).await.ok()
2390 } else {
2391 None
2392 }
2393 } else {
2394 None
2395 }
2396 }
2397 }
2398 } else {
2399 quote! { None }
2400 };
2401
2402 if is_async {
2403 quote! {
2404 {
2405 let mut __potato_once_cache = potato::OnceCache::new();
2406 let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2407 #(#param_bindings)*
2408 #self_type_name::#fn_name(#(#param_names),*).await
2409 }
2410 }
2411 } else {
2412 quote! {
2413 {
2414 let mut __potato_once_cache = potato::OnceCache::new();
2415 let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2416 #(#param_bindings)*
2417 #self_type_name::#fn_name(#(#param_names),*)
2418 }
2419 }
2420 }
2421 }
2422 };
2423
2424 let wrapper_fn = if is_async {
2426 quote! {
2427 #[doc(hidden)]
2428 fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2429 Box::pin(async move {
2430 match #method_call {
2431 Ok(resp) => potato::HttpResponse::text(resp.to_string()),
2432 Err(err) => potato::HttpResponse::error(err.to_string()),
2433 }
2434 })
2435 }
2436 }
2437 } else {
2438 quote! {
2439 #[doc(hidden)]
2440 fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2441 Box::pin(async move {
2442 match #method_call {
2443 Ok(resp) => potato::HttpResponse::text(resp.to_string()),
2444 Err(err) => potato::HttpResponse::error(err.to_string()),
2445 }
2446 })
2447 }
2448 }
2449 };
2450
2451 generated_code.push(wrapper_fn);
2452
2453 let http_method_ident = quote::format_ident!("{}", http_method);
2455
2456 let route_register = quote! {
2457 potato::inventory::submit! {
2458 potato::RequestHandlerFlag::new(
2459 potato::HttpMethod::#http_method_ident,
2460 #final_path_lit,
2461 potato::HttpHandler::Async(#wrapper_fn_name),
2462 potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2463 )
2464 }
2465 };
2466
2467 generated_code.push(route_register);
2468 }
2469 }
2470 } else {
2471 cleaned_items.push(item.clone());
2473 }
2474 } else {
2475 cleaned_items.push(item.clone());
2477 }
2478 }
2479
2480 let mut cleaned_impl = item_impl.clone();
2482 cleaned_impl.items = cleaned_items;
2483
2484 let output = quote! {
2486 #cleaned_impl
2487
2488 #(#generated_code)*
2489 };
2490
2491 output.into()
2492}
2493
2494#[proc_macro_attribute]
2495pub fn preprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
2496 preprocess_macro(attr, input)
2497}
2498
2499#[proc_macro_attribute]
2500pub fn postprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
2501 postprocess_macro(attr, input)
2502}
2503
2504fn handle_error_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2522 if !attr.is_empty() {
2523 return input;
2524 }
2525
2526 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2527 let fn_name = root_fn.sig.ident.clone();
2528 let is_async = root_fn.sig.asyncness.is_some();
2529
2530 if root_fn.sig.inputs.len() != 2 {
2532 panic!("`handle_error` function must accept exactly two arguments");
2533 }
2534
2535 let mut arg_types = vec![];
2536 for arg in root_fn.sig.inputs.iter() {
2537 match arg {
2538 syn::FnArg::Typed(arg) => {
2539 arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
2540 }
2541 _ => panic!("`handle_error` function does not support receiver argument"),
2542 }
2543 }
2544
2545 if arg_types[0] != "& mut HttpRequest" {
2546 panic!(
2547 "`handle_error` first argument must be `&mut potato::HttpRequest`, got `{}`",
2548 arg_types[0]
2549 );
2550 }
2551 if arg_types[1] != "anyhow::Error" {
2552 panic!(
2553 "`handle_error` second argument must be `anyhow::Error`, got `{}`",
2554 arg_types[1]
2555 );
2556 }
2557
2558 let ret_type = root_fn
2559 .sig
2560 .output
2561 .to_token_stream()
2562 .to_string()
2563 .type_simplify();
2564 if ret_type != "HttpResponse" {
2565 panic!(
2566 "`handle_error` return type must be `potato::HttpResponse`, got `{}`",
2567 ret_type
2568 );
2569 }
2570
2571 let wrap_name = format_ident!("__potato_error_handler_adapter_{}", fn_name);
2573 let wrap_name_inner = format_ident!("__potato_error_handler_adapter_inner_{}", fn_name);
2574
2575 let call_body = if is_async {
2577 quote! { #fn_name(req, err).await }
2578 } else {
2579 quote! { #fn_name(req, err) }
2580 };
2581
2582 quote! {
2583 #root_fn
2584
2585 #[doc(hidden)]
2586 async fn #wrap_name_inner(
2587 req: &mut potato::HttpRequest,
2588 err: anyhow::Error,
2589 ) -> potato::HttpResponse {
2590 #call_body
2591 }
2592
2593 #[doc(hidden)]
2594 pub fn #wrap_name(
2595 req: &mut potato::HttpRequest,
2596 err: anyhow::Error,
2597 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2598 Box::pin(#wrap_name_inner(req, err))
2599 }
2600
2601 potato::inventory::submit! {
2602 potato::ErrorHandlerFlag::new(
2603 potato::ErrorHandler::Async(#wrap_name)
2604 )
2605 }
2606 }
2607 .into()
2608}
2609
2610#[proc_macro_attribute]
2611pub fn handle_error(attr: TokenStream, input: TokenStream) -> TokenStream {
2612 handle_error_macro(attr, input)
2613}
2614
2615#[proc_macro_attribute]
2638pub fn limit_size(attr: TokenStream, input: TokenStream) -> TokenStream {
2639 limit_size_macro(attr, input)
2640}
2641
2642fn limit_size_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2643 let (_max_header, max_body) = {
2645 let attr_tokens: proc_macro2::TokenStream = attr.clone().into();
2646 if attr_tokens.is_empty() {
2647 (None, None)
2649 } else {
2650 let result = syn::parse::Parser::parse2(
2652 |input: syn::parse::ParseStream| -> syn::Result<(Option<syn::Expr>, Option<syn::Expr>)> {
2653 let mut header_expr = None;
2654 let mut body_expr = None;
2655
2656 while !input.is_empty() {
2658 let ident: Ident = input.parse()?;
2659 input.parse::<Token![=]>()?;
2660 let value: syn::Expr = input.parse()?;
2661
2662 match ident.to_string().as_str() {
2663 "header" => header_expr = Some(value),
2664 "body" => body_expr = Some(value),
2665 _ => return Err(syn::Error::new(ident.span(), "expected 'header' or 'body'")),
2666 }
2667
2668 if input.peek(Token![,]) {
2670 input.parse::<Token![,]>()?;
2671 }
2672 }
2673
2674 Ok((header_expr, body_expr))
2675 },
2676 attr_tokens.clone(),
2677 );
2678
2679 match result {
2680 Ok((h, b)) => (h, b),
2681 Err(_) => {
2682 if let Ok(expr) = syn::parse2::<syn::Expr>(attr_tokens) {
2684 (None, Some(expr))
2685 } else {
2686 (None, None)
2687 }
2688 }
2689 }
2690 }
2691 };
2692
2693 let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2694
2695 let body_check = if let Some(body_expr) = max_body {
2697 quote! {
2698 let body_len = req.body.len();
2700 if body_len > #body_expr {
2701 let mut res = potato::HttpResponse::text(format!(
2702 "Payload Too Large: body size {} bytes exceeds limit {} bytes",
2703 body_len, #body_expr
2704 ));
2705 res.http_code = 413;
2706 return res;
2707 }
2708 }
2709 } else {
2710 quote! {}
2711 };
2712
2713 let mut wrapped_fn = root_fn.clone();
2715 let original_block = root_fn.block.as_ref();
2716 let new_block: syn::Block = syn::parse_quote!({
2717 #body_check
2718 #original_block
2719 });
2720 wrapped_fn.block = Box::new(new_block);
2721
2722 quote! {
2723 #wrapped_fn
2724 }
2725 .into()
2726}
2727
2728#[proc_macro_attribute]
2731pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream {
2732 input
2735}
2736
2737#[proc_macro_attribute]
2740pub fn cors(_attr: TokenStream, input: TokenStream) -> TokenStream {
2741 input
2744}
2745
2746#[proc_macro]
2747pub fn embed_dir(input: TokenStream) -> TokenStream {
2748 let path = syn::parse_macro_input!(input as syn::LitStr).value();
2749 quote! {{
2750 #[derive(potato::rust_embed::Embed)]
2751 #[folder = #path]
2752 struct Asset;
2753
2754 potato::load_embed::<Asset>()
2755 }}
2756 .into()
2757}
2758
2759#[proc_macro_derive(StandardHeader)]
2760pub fn standard_header_derive(input: TokenStream) -> TokenStream {
2761 let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
2762 let enum_name = root_enum.ident;
2763 let mut try_from_str_items = vec![];
2764 let mut to_str_items = vec![];
2765 let mut headers_items = vec![];
2766 let mut headers_apply_items = vec![];
2767 for root_field in root_enum.variants.iter() {
2768 let name = root_field.ident.clone();
2769 if root_field.fields.iter().next().is_some() {
2770 panic!("unsupported enum type");
2771 }
2772 let str_name = name.to_string().replace("_", "-");
2773 let len = str_name.len();
2774 try_from_str_items
2775 .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
2776 to_str_items.push(quote! { Self::#name => #str_name, });
2777 headers_items.push(quote! { #name(String), });
2778 headers_apply_items
2779 .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
2780 }
2781 let r = quote! {
2782 impl #enum_name {
2783 pub fn try_from_str(value: &str) -> Option<Self> {
2784 match value.len() {
2785 #( #try_from_str_items )*
2786 _ => None,
2787 }
2788 }
2789
2790 pub fn to_str(&self) -> &'static str {
2791 match self {
2792 #( #to_str_items )*
2793 }
2794 }
2795 }
2796
2797 pub enum Headers {
2798 #( #headers_items )*
2799 Custom((String, String)),
2800 }
2801
2802 impl HttpRequest {
2803 pub fn apply_header(&mut self, header: Headers) {
2804 match header {
2805 #( #headers_apply_items )*
2806 Headers::Custom((k, v)) => self.set_header(&k[..], v),
2807 }
2808 }
2809 }
2810 };
2811 r.into()
2812}