// potato_macro/lib.rs

1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{format_ident, quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use syn::Token;
10use utils::StringExt as _;
11
/// CORS settings collected from a `cors(...)` attribute (used only inside this macro crate).
///
/// Every option is optional. `Default` yields the configuration in which nothing has been
/// specified (`None` for every string option, `credentials == false`), so construction
/// sites do not have to spell out all six fields by hand.
#[derive(Debug, Default)]
struct CorsAttrConfig {
    /// Value of the `origin = "..."` key; `None` when the key was not given.
    origin: Option<String>,
    /// Value of the `methods = "..."` key; `None` when the key was not given.
    methods: Option<String>,
    /// Value of the `headers = "..."` key; `None` when the key was not given.
    headers: Option<String>,
    /// Value of the `max_age = "..."` key (kept as the raw string); `None` when not given.
    max_age: Option<String>,
    /// Value of the `credentials = <bool>` key; `false` when the key was not given.
    credentials: bool,
    /// Value of the `expose_headers = "..."` key; `None` when the key was not given.
    expose_headers: Option<String>,
}
21
22/// 解析CORS属性
23fn parse_cors_attr(tokens: &proc_macro2::TokenStream) -> CorsAttrConfig {
24    let config = CorsAttrConfig {
25        origin: None,
26        methods: None,
27        headers: None,
28        max_age: None,
29        credentials: false,
30        expose_headers: None,
31    };
32
33    if tokens.is_empty() {
34        return config; // 返回最小限制配置(origin="*", headers="*", methods自动计算)
35    }
36
37    // 解析 key = value 格式
38    use syn::parse::Parser;
39
40    fn parse_inner(input: syn::parse::ParseStream) -> syn::Result<CorsAttrConfig> {
41        let mut config = CorsAttrConfig {
42            origin: None,
43            methods: None,
44            headers: None,
45            max_age: None,
46            credentials: false,
47            expose_headers: None,
48        };
49
50        let vars =
51            syn::punctuated::Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
52        for meta in vars {
53            let key = meta
54                .path
55                .get_ident()
56                .map(|i| i.to_string())
57                .unwrap_or_default();
58            if let syn::Expr::Lit(expr_lit) = &meta.value {
59                match &expr_lit.lit {
60                    syn::Lit::Str(s) => {
61                        let val = s.value();
62                        match key.as_str() {
63                            "origin" => config.origin = Some(val),
64                            "methods" => config.methods = Some(val),
65                            "headers" => config.headers = Some(val),
66                            "max_age" => config.max_age = Some(val),
67                            "expose_headers" => config.expose_headers = Some(val),
68                            _ => {}
69                        }
70                    }
71                    syn::Lit::Bool(b) => {
72                        if key == "credentials" {
73                            config.credentials = b.value();
74                        }
75                    }
76                    _ => {}
77                }
78            }
79        }
80        Ok(config)
81    }
82
83    match parse_inner.parse2(tokens.clone()) {
84        Ok(cfg) => cfg,
85        Err(e) => panic!("Failed to parse cors attributes: {e}"),
86    }
87}
88
/// Set of type names (as plain strings) built lazily on first access: `String`, `bool`,
/// the 8- to 64-bit and pointer-sized integers, and both float widths.
///
/// NOTE(review): the use sites are outside this part of the file; presumably this is the
/// list of argument types the handler macros can convert directly. Confirm against callers.
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        // text and flags
        "String", "bool",
        // unsigned integers
        "u8", "u16", "u32", "u64", "usize",
        // signed integers
        "i8", "i16", "i32", "i64", "isize",
        // floating point
        "f32", "f64",
    ])
});
97
98/// Controller 字段类型
99/// 验证 Controller 结构体字段
100fn validate_controller_struct(item_struct: &syn::ItemStruct) -> (bool, bool) {
101    let mut has_once_cache = false;
102    let mut has_session_cache = false;
103
104    if let syn::Fields::Named(fields_named) = &item_struct.fields {
105        for field in &fields_named.named {
106            let field_type_str = field.ty.to_token_stream().to_string().type_simplify();
107
108            // 验证类型必须是 &OnceCache 或 &SessionCache(支持生命周期参数)
109            if field_type_str.contains("OnceCache") {
110                has_once_cache = true;
111            } else if field_type_str.contains("SessionCache") {
112                has_session_cache = true;
113            } else {
114                panic!(
115                    "Controller field must be &OnceCache or &SessionCache, got: {}",
116                    field_type_str
117                );
118            }
119        }
120    }
121
122    (has_once_cache, has_session_cache)
123}
124
125/// 解析 header 标注的 tokens,返回 (key, value)
126fn parse_header_attr(tokens: &proc_macro2::TokenStream) -> Result<(String, String), syn::Error> {
127    use syn::parse::Parser;
128
129    let parser = |input: syn::parse::ParseStream| {
130        // 支持两种格式:
131        // 1. Key = "value" (标准 header)
132        // 2. Custom("key") = "value" (自定义 header)
133        let key_ident: Ident = input.parse()?;
134        let key_name = key_ident.to_string();
135
136        if key_name == "Custom" {
137            // Custom("key") = "value" 格式
138            let content;
139            syn::parenthesized!(content in input);
140            let key_lit: syn::LitStr = content.parse()?;
141            let key = key_lit.value();
142            let _: Token![=] = input.parse()?;
143            let value: syn::LitStr = input.parse()?;
144            Ok((key, value.value()))
145        } else {
146            // Key = "value" 格式
147            let _: Token![=] = input.parse()?;
148            let value: syn::LitStr = input.parse()?;
149            Ok((key_name, value.value()))
150        }
151    };
152
153    parser.parse2(tokens.clone())
154}
155
156fn random_ident() -> Ident {
157    let mut rng = rand::thread_rng();
158    let value = format!("__potato_id_{}", rng.r#gen::<u64>());
159    Ident::new(&value, Span::call_site())
160}
161
162fn attr_last_ident(attr: &syn::Attribute) -> Option<String> {
163    attr.meta
164        .path()
165        .segments
166        .iter()
167        .last()
168        .map(|segment| segment.ident.to_string())
169}
170
171fn parse_hook_attr_items(attr: &syn::Attribute, attr_name: &str) -> Vec<Ident> {
172    let parser = syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated;
173    let idents = attr.parse_args_with(parser).unwrap_or_else(|err| {
174        panic!("invalid `{attr_name}` annotation: {err}");
175    });
176    if idents.is_empty() {
177        panic!("`{attr_name}` annotation requires at least one function name");
178    }
179    idents.into_iter().collect()
180}
181
182fn collect_handler_hooks(root_fn: &mut syn::ItemFn) -> (Vec<Ident>, Vec<Ident>) {
183    enum HookKind {
184        Pre,
185        Post,
186    }
187    let mut hooks = vec![];
188    let mut new_attrs = Vec::with_capacity(root_fn.attrs.len());
189    for attr in root_fn.attrs.iter() {
190        match attr_last_ident(attr).as_deref() {
191            Some("preprocess") => {
192                hooks.extend(
193                    parse_hook_attr_items(attr, "preprocess")
194                        .into_iter()
195                        .map(|item| (HookKind::Pre, item)),
196                );
197            }
198            Some("postprocess") => {
199                hooks.extend(
200                    parse_hook_attr_items(attr, "postprocess")
201                        .into_iter()
202                        .map(|item| (HookKind::Post, item)),
203                );
204            }
205            _ => new_attrs.push(attr.clone()),
206        }
207    }
208    root_fn.attrs = new_attrs;
209    let mut preprocess_fns = vec![];
210    let mut postprocess_fns = vec![];
211    for (kind, hook) in hooks.into_iter() {
212        match kind {
213            HookKind::Pre => preprocess_fns.push(hook),
214            HookKind::Post => postprocess_fns.push(hook),
215        }
216    }
217    (preprocess_fns, postprocess_fns)
218}
219
220fn validate_preprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
221    if root_fn.sig.inputs.is_empty() || root_fn.sig.inputs.len() > 3 {
222        panic!("`preprocess` function must accept one to three arguments");
223    }
224    let mut arg_types = vec![];
225    for arg in root_fn.sig.inputs.iter() {
226        match arg {
227            syn::FnArg::Typed(arg) => {
228                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
229            }
230            _ => panic!("`preprocess` function does not support receiver argument"),
231        }
232    }
233    if arg_types[0] != "& mut HttpRequest" {
234        panic!(
235            "`preprocess` first argument type must be `&mut potato::HttpRequest`, got `{}`",
236            arg_types[0]
237        );
238    }
239
240    let has_once_cache = arg_types.iter().any(|t| t == "& mut OnceCache");
241    let has_session_cache = arg_types.iter().any(|t| t == "& mut SessionCache");
242
243    if arg_types.len() == 2 && !has_once_cache && !has_session_cache {
244        panic!(
245            "`preprocess` second argument type must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
246            arg_types[1]
247        );
248    }
249    if arg_types.len() == 3 {
250        if !has_once_cache {
251            panic!("`preprocess` must have `&mut potato::OnceCache` as one of the arguments");
252        }
253        if !has_session_cache {
254            panic!("`preprocess` must have `&mut potato::SessionCache` as one of the arguments");
255        }
256    }
257
258    let ret_type = root_fn
259        .sig
260        .output
261        .to_token_stream()
262        .to_string()
263        .type_simplify();
264    match &ret_type[..] {
265        "Result<Option<HttpResponse>>" | "Option<HttpResponse>" | "Result<()>" | "()" => {}
266        _ => panic!(
267            "unsupported `preprocess` return type: `{ret_type}`, expected `anyhow::Result<Option<potato::HttpResponse>>`, `Option<potato::HttpResponse>`, `anyhow::Result<()>`, or `()`"
268        ),
269    }
270    (ret_type, has_once_cache, has_session_cache)
271}
272
273fn validate_postprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
274    if root_fn.sig.inputs.len() < 2 && root_fn.sig.inputs.len() > 4 {
275        panic!("`postprocess` function must accept two to four arguments");
276    }
277    let mut arg_types = vec![];
278    for arg in root_fn.sig.inputs.iter() {
279        match arg {
280            syn::FnArg::Typed(arg) => {
281                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
282            }
283            _ => panic!("`postprocess` function does not support receiver argument"),
284        }
285    }
286    if arg_types[0] != "& mut HttpRequest" {
287        panic!(
288            "`postprocess` first argument must be `&mut potato::HttpRequest`, got `{}`",
289            arg_types[0]
290        );
291    }
292    if arg_types[1] != "& mut HttpResponse" {
293        panic!(
294            "`postprocess` second argument must be `&mut potato::HttpResponse`, got `{}`",
295            arg_types[1]
296        );
297    }
298
299    let remaining_args = &arg_types[2..];
300    let has_once_cache = remaining_args.iter().any(|t| t == "& mut OnceCache");
301    let has_session_cache = remaining_args.iter().any(|t| t == "& mut SessionCache");
302
303    if arg_types.len() == 3 && !has_once_cache && !has_session_cache {
304        panic!(
305            "`postprocess` third argument must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
306            arg_types[2]
307        );
308    }
309    if arg_types.len() == 4 && (!has_once_cache || !has_session_cache) {
310        panic!(
311            "`postprocess` with 4 arguments must have both `&mut potato::OnceCache` and `&mut potato::SessionCache`"
312        );
313    }
314
315    let ret_type = root_fn
316        .sig
317        .output
318        .to_token_stream()
319        .to_string()
320        .type_simplify();
321    match &ret_type[..] {
322        "Result<()>" | "()" => {}
323        _ => panic!(
324            "unsupported `postprocess` return type: `{ret_type}`, expected `anyhow::Result<()>` or `()`"
325        ),
326    }
327    (ret_type, has_once_cache, has_session_cache)
328}
329
330fn preprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
331    if !attr.is_empty() {
332        return input;
333    }
334    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
335    let fn_name = root_fn.sig.ident.clone();
336    let wrap_name = format_ident!("__potato_preprocess_adapter_{}", fn_name);
337    let wrap_name_inner = format_ident!("__potato_preprocess_adapter_inner_{}", fn_name);
338    let is_async = root_fn.sig.asyncness.is_some();
339    let (ret_type, has_once_cache, has_session_cache) = validate_preprocess_signature(&root_fn);
340
341    // 根据是否需要缓存生成不同的函数签名
342    let wrap_signature = match (has_once_cache, has_session_cache) {
343        (true, true) => quote! {
344            async fn #wrap_name_inner(
345                req: &mut potato::HttpRequest,
346                once_cache: &mut potato::OnceCache,
347                session_cache: &mut potato::SessionCache,
348            ) -> anyhow::Result<Option<potato::HttpResponse>>
349        },
350        (true, false) => quote! {
351            async fn #wrap_name_inner(
352                req: &mut potato::HttpRequest,
353                once_cache: &mut potato::OnceCache,
354            ) -> anyhow::Result<Option<potato::HttpResponse>>
355        },
356        (false, true) => quote! {
357            async fn #wrap_name_inner(
358                req: &mut potato::HttpRequest,
359                session_cache: &mut potato::SessionCache,
360            ) -> anyhow::Result<Option<potato::HttpResponse>>
361        },
362        (false, false) => quote! {
363            async fn #wrap_name_inner(
364                req: &mut potato::HttpRequest,
365            ) -> anyhow::Result<Option<potato::HttpResponse>>
366        },
367    };
368
369    // 根据实际使用情况调用函数
370    let call_body = if is_async {
371        match &ret_type[..] {
372            "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
373                (true, true) => {
374                    quote! { #fn_name(req, once_cache, session_cache).await }
375                }
376                (true, false) => quote! { #fn_name(req, once_cache).await },
377                (false, true) => quote! { #fn_name(req, session_cache).await },
378                (false, false) => quote! { #fn_name(req).await },
379            },
380            "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
381                (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache).await) },
382                (true, false) => quote! { Ok(#fn_name(req, once_cache).await) },
383                (false, true) => quote! { Ok(#fn_name(req, session_cache).await) },
384                (false, false) => quote! { Ok(#fn_name(req).await) },
385            },
386            "Result<()>" => match (has_once_cache, has_session_cache) {
387                (true, true) => {
388                    quote! { #fn_name(req, once_cache, session_cache).await.map(|_| None) }
389                }
390                (true, false) => quote! { #fn_name(req, once_cache).await.map(|_| None) },
391                (false, true) => quote! { #fn_name(req, session_cache).await.map(|_| None) },
392                (false, false) => quote! { #fn_name(req).await.map(|_| None) },
393            },
394            "()" => match (has_once_cache, has_session_cache) {
395                (true, true) => quote! { #fn_name(req, once_cache, session_cache).await; Ok(None) },
396                (true, false) => quote! { #fn_name(req, once_cache).await; Ok(None) },
397                (false, true) => quote! { #fn_name(req, session_cache).await; Ok(None) },
398                (false, false) => quote! { #fn_name(req).await; Ok(None) },
399            },
400            _ => unreachable!(),
401        }
402    } else {
403        match &ret_type[..] {
404            "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
405                (true, true) => quote! { #fn_name(req, once_cache, session_cache) },
406                (true, false) => quote! { #fn_name(req, once_cache) },
407                (false, true) => quote! { #fn_name(req, session_cache) },
408                (false, false) => quote! { #fn_name(req) },
409            },
410            "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
411                (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache)) },
412                (true, false) => quote! { Ok(#fn_name(req, once_cache)) },
413                (false, true) => quote! { Ok(#fn_name(req, session_cache)) },
414                (false, false) => quote! { Ok(#fn_name(req)) },
415            },
416            "Result<()>" => match (has_once_cache, has_session_cache) {
417                (true, true) => quote! { #fn_name(req, once_cache, session_cache).map(|_| None) },
418                (true, false) => quote! { #fn_name(req, once_cache).map(|_| None) },
419                (false, true) => quote! { #fn_name(req, session_cache).map(|_| None) },
420                (false, false) => quote! { #fn_name(req).map(|_| None) },
421            },
422            "()" => match (has_once_cache, has_session_cache) {
423                (true, true) => quote! { #fn_name(req, once_cache, session_cache); Ok(None) },
424                (true, false) => quote! { #fn_name(req, once_cache); Ok(None) },
425                (false, true) => quote! { #fn_name(req, session_cache); Ok(None) },
426                (false, false) => quote! { #fn_name(req); Ok(None) },
427            },
428            _ => unreachable!(),
429        }
430    };
431
432    // 生成wrapper函数,根据cache需求调用inner函数
433    let wrapper_body = match (has_once_cache, has_session_cache) {
434        (true, true) => quote! {
435            #wrap_name_inner(
436                req,
437                once_cache.expect("OnceCache required but not provided"),
438                session_cache.expect("SessionCache required but not provided"),
439            ).await
440        },
441        (true, false) => quote! {
442            #wrap_name_inner(
443                req,
444                once_cache.expect("OnceCache required but not provided"),
445            ).await
446        },
447        (false, true) => quote! {
448            #wrap_name_inner(
449                req,
450                session_cache.expect("SessionCache required but not provided"),
451            ).await
452        },
453        (false, false) => quote! {
454            #wrap_name_inner(req).await
455        },
456    };
457
458    quote! {
459        #root_fn
460
461        #[doc(hidden)]
462        #wrap_signature {
463            #call_body
464        }
465
466        #[doc(hidden)]
467        pub async fn #wrap_name(
468            req: &mut potato::HttpRequest,
469            once_cache: Option<&mut potato::OnceCache>,
470            session_cache: Option<&mut potato::SessionCache>,
471        ) -> anyhow::Result<Option<potato::HttpResponse>> {
472            #wrapper_body
473        }
474    }
475    .into()
476}
477
478fn postprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
479    if !attr.is_empty() {
480        return input;
481    }
482    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
483    let fn_name = root_fn.sig.ident.clone();
484    let wrap_name = format_ident!("__potato_postprocess_adapter_{}", fn_name);
485    let wrap_name_inner = format_ident!("__potato_postprocess_adapter_inner_{}", fn_name);
486    let is_async = root_fn.sig.asyncness.is_some();
487    let (ret_type, has_once_cache, has_session_cache) = validate_postprocess_signature(&root_fn);
488
489    // 根据是否需要缓存生成不同的函数签名
490    let wrap_signature = match (has_once_cache, has_session_cache) {
491        (true, true) => quote! {
492            async fn #wrap_name_inner(
493                req: &mut potato::HttpRequest,
494                res: &mut potato::HttpResponse,
495                once_cache: &mut potato::OnceCache,
496                session_cache: &mut potato::SessionCache,
497            ) -> anyhow::Result<()>
498        },
499        (true, false) => quote! {
500            async fn #wrap_name_inner(
501                req: &mut potato::HttpRequest,
502                res: &mut potato::HttpResponse,
503                once_cache: &mut potato::OnceCache,
504            ) -> anyhow::Result<()>
505        },
506        (false, true) => quote! {
507            async fn #wrap_name_inner(
508                req: &mut potato::HttpRequest,
509                res: &mut potato::HttpResponse,
510                session_cache: &mut potato::SessionCache,
511            ) -> anyhow::Result<()>
512        },
513        (false, false) => quote! {
514            async fn #wrap_name_inner(
515                req: &mut potato::HttpRequest,
516                res: &mut potato::HttpResponse,
517            ) -> anyhow::Result<()>
518        },
519    };
520
521    // 根据实际使用情况调用函数
522    let call_body = if is_async {
523        match &ret_type[..] {
524            "Result<()>" => {
525                if has_once_cache && has_session_cache {
526                    quote! {
527                        #fn_name(req, res, once_cache, session_cache).await
528                    }
529                } else if has_once_cache {
530                    quote! {
531                        #fn_name(req, res, once_cache).await
532                    }
533                } else if has_session_cache {
534                    quote! {
535                        #fn_name(req, res, session_cache).await
536                    }
537                } else {
538                    quote! {
539                        #fn_name(req, res).await
540                    }
541                }
542            }
543            "()" => {
544                if has_once_cache && has_session_cache {
545                    quote! {
546                        #fn_name(req, res, once_cache, session_cache).await;
547                        Ok(())
548                    }
549                } else if has_once_cache {
550                    quote! {
551                        #fn_name(req, res, once_cache).await;
552                        Ok(())
553                    }
554                } else if has_session_cache {
555                    quote! {
556                        #fn_name(req, res, session_cache).await;
557                        Ok(())
558                    }
559                } else {
560                    quote! {
561                        #fn_name(req, res).await;
562                        Ok(())
563                    }
564                }
565            }
566            _ => unreachable!(),
567        }
568    } else {
569        match &ret_type[..] {
570            "Result<()>" => {
571                if has_once_cache && has_session_cache {
572                    quote! {
573                        #fn_name(req, res, once_cache, session_cache)
574                    }
575                } else if has_once_cache {
576                    quote! {
577                        #fn_name(req, res, once_cache)
578                    }
579                } else if has_session_cache {
580                    quote! {
581                        #fn_name(req, res, session_cache)
582                    }
583                } else {
584                    quote! {
585                        #fn_name(req, res)
586                    }
587                }
588            }
589            "()" => {
590                if has_once_cache && has_session_cache {
591                    quote! {
592                        #fn_name(req, res, once_cache, session_cache);
593                        Ok(())
594                    }
595                } else if has_once_cache {
596                    quote! {
597                        #fn_name(req, res, once_cache);
598                        Ok(())
599                    }
600                } else if has_session_cache {
601                    quote! {
602                        #fn_name(req, res, session_cache);
603                        Ok(())
604                    }
605                } else {
606                    quote! {
607                        #fn_name(req, res);
608                        Ok(())
609                    }
610                }
611            }
612            _ => unreachable!(),
613        }
614    };
615
616    // 生成wrapper函数,根据cache需求调用inner函数
617    let wrapper_body = match (has_once_cache, has_session_cache) {
618        (true, true) => quote! {
619            #wrap_name_inner(
620                req,
621                res,
622                once_cache.expect("OnceCache required but not provided"),
623                session_cache.expect("SessionCache required but not provided"),
624            ).await
625        },
626        (true, false) => quote! {
627            #wrap_name_inner(
628                req,
629                res,
630                once_cache.expect("OnceCache required but not provided"),
631            ).await
632        },
633        (false, true) => quote! {
634            #wrap_name_inner(
635                req,
636                res,
637                session_cache.expect("SessionCache required but not provided"),
638            ).await
639        },
640        (false, false) => quote! {
641            #wrap_name_inner(req, res).await
642        },
643    };
644
645    quote! {
646        #root_fn
647
648        #[doc(hidden)]
649        #wrap_signature {
650            #call_body
651        }
652
653        #[doc(hidden)]
654        pub async fn #wrap_name(
655            req: &mut potato::HttpRequest,
656            res: &mut potato::HttpResponse,
657            once_cache: Option<&mut potato::OnceCache>,
658            session_cache: Option<&mut potato::SessionCache>,
659        ) -> anyhow::Result<()> {
660            #wrapper_body
661        }
662    }
663    .into()
664}
665
666fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
667    let req_name = Ident::new(req_name, Span::call_site());
668
669    // 解析函数,检查是否有 receiver(&self / &mut self)
670    let root_fn_for_check = syn::parse::<syn::ItemFn>(input.clone());
671    let has_receiver = if let Ok(ref func) = root_fn_for_check {
672        func.sig
673            .inputs
674            .iter()
675            .any(|arg| matches!(arg, syn::FnArg::Receiver(_)))
676    } else {
677        false
678    };
679
680    let (route_path, default_headers) = {
681        let mut oroute_path = syn::parse::<syn::LitStr>(attr.clone())
682            .ok()
683            .map(|path| path.value());
684        let mut default_headers: Vec<(String, String)> = Vec::new();
685        //
686        if oroute_path.is_none() {
687            let http_parser = syn::meta::parser(|meta| {
688                if meta.path.is_ident("path") {
689                    if let Ok(arg) = meta.value() {
690                        if let Ok(route_path) = arg.parse::<syn::LitStr>() {
691                            let route_path = route_path.value();
692                            oroute_path = Some(route_path);
693                        }
694                    }
695                    Ok(())
696                } else if meta.path.is_ident("header") {
697                    // 解析 header(key = value) 格式
698                    let content;
699                    syn::parenthesized!(content in meta.input);
700                    let key: Ident = content.parse()?;
701                    let _: syn::Token![=] = content.parse()?;
702                    let value: syn::LitStr = content.parse()?;
703                    default_headers.push((key.to_string(), value.value()));
704                    Ok(())
705                } else {
706                    Err(meta.error("unsupported annotation property"))
707                }
708            });
709            syn::parse_macro_input!(attr with http_parser);
710        }
711
712        // 如果没有提供 path 且有 receiver,可能是 controller 方法
713        if oroute_path.is_none() && has_receiver {
714            // 这将在后续代码中处理,先设置为空
715        } else if oroute_path.is_none() {
716            panic!("`path` argument is required for non-controller methods");
717        }
718
719        let route_path = oroute_path.unwrap_or_default();
720
721        // 如果是 controller 方法,需要处理路径拼接
722        let route_path = if has_receiver {
723            if route_path.is_empty() {
724                // 没有指定 path,使用 controller base path(稍后在生成的代码中读取常量)
725                String::new()
726            } else {
727                // 指定了 path,需要拼接到 controller base path
728                // 这里先标记,稍后在生成的代码中处理
729                route_path
730            }
731        } else {
732            if route_path.is_empty() {
733                panic!("`path` argument is required for non-controller methods");
734            }
735            route_path
736        };
737
738        if !route_path.is_empty() && !route_path.starts_with('/') {
739            panic!("route path must start with '/'");
740        }
741        (route_path, default_headers)
742    };
743
744    // 解析函数上的 #[potato::header(...)] 标注
745    let mut root_fn = syn::parse_macro_input!(input as syn::ItemFn);
746    let mut fn_headers: Vec<(String, String)> = Vec::new();
747    let mut cors_config: Option<CorsAttrConfig> = None;
748    let mut max_concurrency: Option<usize> = None;
749    let mut remaining_attrs = Vec::new();
750
751    for attr in root_fn.attrs.iter() {
752        // 检查是否是 header 标注(支持 #[header(...)] 和 #[potato::header(...)] 两种形式)
753        let is_header_attr = attr.path().is_ident("header")
754            || (attr.path().segments.len() == 2
755                && attr
756                    .path()
757                    .segments
758                    .iter()
759                    .next()
760                    .map(|s| s.ident.to_string())
761                    == Some("potato".to_string())
762                && attr
763                    .path()
764                    .segments
765                    .iter()
766                    .last()
767                    .map(|s| s.ident.to_string())
768                    == Some("header".to_string()));
769
770        if is_header_attr {
771            if let syn::Meta::List(meta_list) = &attr.meta {
772                // 解析 header(Cache_Control = "no-store, no-cache, max-age=0") 或 header(Custom("key") = "value")
773                if let Ok((key, value)) = parse_header_attr(&meta_list.tokens) {
774                    fn_headers.push((key, value));
775                }
776            }
777            continue;
778        }
779
780        // 检查是否是 cors 标注
781        let is_cors_attr = attr.path().is_ident("cors")
782            || (attr.path().segments.len() == 2
783                && attr
784                    .path()
785                    .segments
786                    .iter()
787                    .next()
788                    .map(|s| s.ident.to_string())
789                    == Some("potato".to_string())
790                && attr
791                    .path()
792                    .segments
793                    .iter()
794                    .last()
795                    .map(|s| s.ident.to_string())
796                    == Some("cors".to_string()));
797
798        if is_cors_attr {
799            if let syn::Meta::List(meta_list) = &attr.meta {
800                cors_config = Some(parse_cors_attr(&meta_list.tokens));
801            } else {
802                // 无参数时使用最小限制配置
803                cors_config = Some(CorsAttrConfig {
804                    origin: None,
805                    methods: None,
806                    headers: None,
807                    max_age: None,
808                    credentials: false,
809                    expose_headers: None,
810                });
811            }
812            continue;
813        }
814
815        // 检查是否是 max_concurrency 标注
816        let is_max_concurrency_attr = attr.path().is_ident("max_concurrency")
817            || (attr.path().segments.len() == 2
818                && attr
819                    .path()
820                    .segments
821                    .iter()
822                    .next()
823                    .map(|s| s.ident.to_string())
824                    == Some("potato".to_string())
825                && attr
826                    .path()
827                    .segments
828                    .iter()
829                    .last()
830                    .map(|s| s.ident.to_string())
831                    == Some("max_concurrency".to_string()));
832
833        if is_max_concurrency_attr {
834            if let syn::Meta::List(meta_list) = &attr.meta {
835                let tokens = &meta_list.tokens;
836                // 直接解析为数字
837                if let Ok(lit_int) = syn::parse2::<syn::LitInt>(tokens.clone()) {
838                    if let Ok(val) = lit_int.base10_parse::<usize>() {
839                        if val == 0 {
840                            panic!("max_concurrency must be greater than 0");
841                        }
842                        max_concurrency = Some(val);
843                    } else {
844                        panic!("invalid max_concurrency value");
845                    }
846                } else {
847                    panic!(
848                        "max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]"
849                    );
850                }
851            } else if let syn::Meta::NameValue(name_value) = &attr.meta {
852                if let syn::Expr::Lit(expr_lit) = &name_value.value {
853                    if let syn::Lit::Int(lit_int) = &expr_lit.lit {
854                        if let Ok(val) = lit_int.base10_parse::<usize>() {
855                            if val == 0 {
856                                panic!("max_concurrency must be greater than 0");
857                            }
858                            max_concurrency = Some(val);
859                        } else {
860                            panic!("invalid max_concurrency value");
861                        }
862                    } else {
863                        panic!("max_concurrency requires a numeric value");
864                    }
865                } else {
866                    panic!("max_concurrency requires a numeric value");
867                }
868            } else {
869                panic!("max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]");
870            }
871            continue;
872        }
873
874        remaining_attrs.push(attr.clone());
875    }
876
877    // 合并默认headers和函数headers
878    let mut all_headers = default_headers;
879    all_headers.extend(fn_headers);
880
881    root_fn.attrs = remaining_attrs;
882    let (preprocess_fns, postprocess_fns) = collect_handler_hooks(&mut root_fn);
883
884    // 检测handler自身是否需要缓存
885    let handler_has_once_cache = root_fn.sig.inputs.iter().any(|arg| {
886        if let syn::FnArg::Typed(arg) = arg {
887            arg.ty.to_token_stream().to_string().type_simplify() == "& mut OnceCache"
888        } else {
889            false
890        }
891    });
892    let handler_has_session_cache = root_fn.sig.inputs.iter().any(|arg| {
893        if let syn::FnArg::Typed(arg) = arg {
894            arg.ty.to_token_stream().to_string().type_simplify() == "& mut SessionCache"
895        } else {
896            false
897        }
898    });
899
900    // 修复:只有当 handler 本身需要缓存时,才设置 need_session_cache 和 need_once_cache
901    // preprocess/postprocess 钩子如果需要缓存,它们可以通过参数声明
902    // 但如果 handler 不需要缓存,我们不应该强制要求 Authorization header
903    // 这样可以避免给不需要认证的 handler 添加不必要的认证要求
904    let need_once_cache = handler_has_once_cache;
905    let need_session_cache = handler_has_session_cache;
906
907    let preprocess_adapters: Vec<Ident> = preprocess_fns
908        .iter()
909        .map(|name| format_ident!("__potato_preprocess_adapter_{}", name))
910        .collect();
911    let postprocess_adapters: Vec<Ident> = postprocess_fns
912        .iter()
913        .map(|name| format_ident!("__potato_postprocess_adapter_{}", name))
914        .collect();
915    let doc_show = {
916        let mut doc_show = true;
917        for attr in root_fn.attrs.iter() {
918            if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
919                if let Ok(meta_list) = attr.meta.require_list() {
920                    if meta_list.tokens.to_string() == "hidden" {
921                        doc_show = false;
922                        break;
923                    }
924                }
925            }
926        }
927        doc_show
928    };
929    let doc_auth = need_session_cache;
930    let doc_summary = {
931        let mut docs = vec![];
932        for attr in root_fn.attrs.iter() {
933            if let Ok(attr) = attr.meta.require_name_value() {
934                if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
935                    let mut doc = attr.value.to_token_stream().to_string();
936                    if doc.starts_with('\"') {
937                        doc.remove(0);
938                        doc.pop();
939                    }
940                    docs.push(doc);
941                }
942            }
943        }
944        if docs.iter().all(|d| d.starts_with(' ')) {
945            for doc in docs.iter_mut() {
946                doc.remove(0);
947            }
948        }
949        docs.join("\n")
950    };
951    let doc_desp = "";
952    let fn_name = root_fn.sig.ident.clone();
953    let is_async = root_fn.sig.asyncness.is_some();
954
955    // 检测是否有 receiver(&self / &mut self)- controller 方法
956    let has_receiver = root_fn
957        .sig
958        .inputs
959        .iter()
960        .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
961
962    // 生成最终路径(如果是 controller 方法,需要拼接)
963    // 注意:由于路由注册需要编译期常量,路径拼接必须在宏展开时完成
964    // 但 controller 宏和 http_get 宏是独立展开的,无法直接共享信息
965    // 因此这里采用简化方案:直接使用 route_path,路径拼接由用户保证正确
966    let final_path = if has_receiver {
967        // Controller 方法:如果 route_path 为空,说明用户希望使用 controller 的 base path
968        // 但这里无法获取 base path,所以要求用户必须指定完整路径或相对路径
969        if route_path.is_empty() {
970            // 暂时使用一个占位符,实际应该在 controller 宏中处理
971            // 这里我们先要求用户必须提供路径
972            panic!("Controller methods must specify a path (e.g., #[potato::http_get(\"/\")])");
973        } else {
974            route_path
975        }
976    } else {
977        if route_path.is_empty() {
978            panic!("`path` argument is required for non-controller methods");
979        }
980        route_path
981    };
982
983    let final_path_expr = quote! { #final_path };
984
985    // 生成 tag 表达式
986    let tag_expr = if has_receiver {
987        quote! { __POTATO_CONTROLLER_NAME }
988    } else {
989        quote! { "" }
990    };
991
992    let wrap_func_name = random_ident();
993    let mut args = vec![];
994    let mut arg_names = vec![];
995    let mut arg_types = vec![];
996    let mut doc_args = vec![];
997    for arg in root_fn.sig.inputs.iter() {
998        // 支持 receiver 参数(&self / &mut self)- controller 方法
999        if let syn::FnArg::Receiver(_receiver) = arg {
1000            // 跳过 receiver,不生成参数绑定代码
1001            // controller 实例将在包装函数中创建
1002            continue;
1003        }
1004
1005        if let syn::FnArg::Typed(arg) = arg {
1006            let arg_type_str = arg
1007                .ty
1008                .as_ref()
1009                .to_token_stream()
1010                .to_string()
1011                .type_simplify();
1012            let arg_name_str = arg.pat.to_token_stream().to_string();
1013            let arg_value = match &arg_type_str[..] {
1014                "& mut HttpRequest" => quote! { req },
1015                "& mut OnceCache" => {
1016                    quote! { __potato_once_cache.as_mut().expect("OnceCache not available") }
1017                }
1018                "& mut SessionCache" => {
1019                    quote! { __potato_session_cache.as_mut().expect("SessionCache not available") }
1020                }
1021                "PostFile" => {
1022                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1023                    quote! {
1024                        match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
1025                            Some(file) => file,
1026                            None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1027                        }
1028                    }
1029                }
1030                arg_type_str if ARG_TYPES.contains(arg_type_str) => {
1031                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1032                    let mut arg_value = quote! {
1033                        match req.body_pairs
1034                            .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1035                            .map(|p| p.to_string()) {
1036                            Some(val) => val,
1037                            None => match req.url_query
1038                                .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1039                                .map(|p| p.as_str().to_string()) {
1040                                Some(val) => val,
1041                                None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1042                            },
1043                        }
1044                    };
1045                    if arg_type_str != "String" {
1046                        arg_value = quote! {
1047                            match #arg_value.parse() {
1048                                Ok(val) => val,
1049                                Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
1050                            }
1051                        }
1052                    }
1053                    arg_value
1054                }
1055                _ => panic!("unsupported arg type: [{arg_type_str}]"),
1056            };
1057            args.push(arg_value);
1058            arg_names.push(random_ident());
1059            // 保存参数类型信息,用于后续生成 call_expr
1060            arg_types.push(arg_type_str);
1061        }
1062    }
1063    let wrap_func_name2 = random_ident();
1064    let ret_type = root_fn
1065        .sig
1066        .output
1067        .to_token_stream()
1068        .to_string()
1069        .type_simplify();
1070
1071    // 如果有 receiver,需要生成 __potato_create_controller 函数
1072    // 通过检查是否存在 controller 常量来判断
1073    let _controller_create_fn = if has_receiver {
1074        quote! {
1075            // 这个函数应该由 controller 宏生成,这里只是引用
1076            // 如果编译出错,说明没有正确使用 #[potato::controller]
1077        }
1078    } else {
1079        quote! {}
1080    };
1081
1082    // 为每个参数生成调用代码
1083    let call_args: Vec<_> = args
1084        .iter()
1085        .enumerate()
1086        .map(|(i, _arg)| {
1087            let arg_name = &arg_names[i];
1088            let arg_type = &arg_types[i];
1089            // 对于 HttpRequest,直接使用 req,不要通过中间变量
1090            if arg_type == "& mut HttpRequest" {
1091                quote! { req }
1092            } else {
1093                quote! { #arg_name }
1094            }
1095        })
1096        .collect();
1097
1098    let call_expr = if has_receiver {
1099        // Controller 方法:直接调用方法(暂不支持字段注入)
1100        // 注意:当前版本不支持 controller 字段,方法应该是静态方法
1101        // 如果要支持字段,需要在包装函数中实例化 controller
1102        match args.len() {
1103            0 => quote! { #fn_name() },
1104            1 => {
1105                let arg_name = &arg_names[0];
1106                let arg = &args[0];
1107                let arg_type = &arg_types[0];
1108                if arg_type == "& mut HttpRequest" {
1109                    quote! { #fn_name(req) }
1110                } else {
1111                    quote! {{
1112                        let #arg_name = #arg;
1113                        #fn_name(#arg_name)
1114                    }}
1115                }
1116            }
1117            _ => {
1118                let let_bindings: Vec<_> = arg_types
1119                    .iter()
1120                    .zip(arg_names.iter())
1121                    .zip(args.iter())
1122                    .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1123                    .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1124                    .collect();
1125
1126                quote! {{
1127                    #(#let_bindings)*
1128                    #fn_name(#(#call_args),*)
1129                }}
1130            }
1131        }
1132    } else {
1133        // 普通方法:直接调用函数
1134        match args.len() {
1135            0 => quote! { #fn_name() },
1136            1 => {
1137                let arg_name = &arg_names[0];
1138                let arg = &args[0];
1139                let arg_type = &arg_types[0];
1140                // 检查是否是 HttpRequest 类型
1141                if arg_type == "& mut HttpRequest" {
1142                    quote! { #fn_name(req) }
1143                } else {
1144                    quote! {{
1145                        let #arg_name = #arg;
1146                        #fn_name(#arg_name)
1147                    }}
1148                }
1149            }
1150            _ => {
1151                // 只为非 HttpRequest 类型的参数创建中间变量
1152                let let_bindings: Vec<_> = arg_types
1153                    .iter()
1154                    .zip(arg_names.iter())
1155                    .zip(args.iter())
1156                    .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1157                    .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1158                    .collect();
1159
1160                quote! {{
1161                    #(#let_bindings)*
1162                    #fn_name(#(#call_args),*)
1163                }}
1164            }
1165        }
1166    };
1167    let handler_wrap_func_body = if is_async {
1168        match &ret_type[..] {
1169            "Result<()>" => quote! {
1170                match #call_expr.await {
1171                    Ok(_) => Ok(potato::HttpResponse::text("ok")),
1172                    Err(err) => Err(err),
1173                }
1174            },
1175            "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1176                match #call_expr.await {
1177                    Ok(ret) => Ok(ret),
1178                    Err(err) => Err(err),
1179                }
1180            },
1181            "Result<String>" | "anyhow::Result<String>" => quote! {
1182                match #call_expr.await {
1183                    Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1184                    Err(err) => Err(err),
1185                }
1186            },
1187            "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1188                match #call_expr.await {
1189                    Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1190                    Err(err) => Err(err),
1191                }
1192            },
1193            "()" => quote! {
1194                #call_expr.await;
1195                Ok(potato::HttpResponse::text("ok"))
1196            },
1197            "HttpResponse" => quote! {
1198                Ok(#call_expr.await)
1199            },
1200            "String" => quote! {
1201                Ok(potato::HttpResponse::html(#call_expr.await))
1202            },
1203            "& 'static str" => quote! {
1204                Ok(potato::HttpResponse::html(#call_expr.await))
1205            },
1206            _ => panic!("unsupported ret type: {ret_type}"),
1207        }
1208    } else {
1209        match &ret_type[..] {
1210            "Result<()>" => quote! {
1211                match #call_expr {
1212                    Ok(_) => Ok(potato::HttpResponse::text("ok")),
1213                    Err(err) => Err(err),
1214                }
1215            },
1216            "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1217                match #call_expr {
1218                    Ok(ret) => Ok(ret),
1219                    Err(err) => Err(err),
1220                }
1221            },
1222            "Result<String>" | "anyhow::Result<String>" => quote! {
1223                match #call_expr {
1224                    Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1225                    Err(err) => Err(err),
1226                }
1227            },
1228            "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1229                match #call_expr {
1230                    Ok(ret) => Ok(potato::HttpResponse::html(ret)),
1231                    Err(err) => Err(err),
1232                }
1233            },
1234            "()" => quote! {
1235                #call_expr;
1236                Ok(potato::HttpResponse::text("ok"))
1237            },
1238            "HttpResponse" => quote! {
1239                Ok(#call_expr)
1240            },
1241            "String" => quote! {
1242                Ok(potato::HttpResponse::html(#call_expr))
1243            },
1244            "& 'static str" => quote! {
1245                Ok(potato::HttpResponse::html(#call_expr))
1246            },
1247            _ => panic!("unsupported ret type: {ret_type}"),
1248        }
1249    };
1250    let doc_args = serde_json::to_string(&doc_args).unwrap();
1251
1252    // 生成添加headers的代码
1253    let add_headers_code = if all_headers.is_empty() {
1254        quote! {}
1255    } else {
1256        let header_statements = all_headers.iter().map(|(key, value)| {
1257            // 将下划线转换为HTTP标准命名 (例如 Cache_Control -> Cache-Control)
1258            let http_key = key.replace("_", "-");
1259            quote! {
1260                __potato_response.add_header(
1261                    std::borrow::Cow::Borrowed(#http_key),
1262                    std::borrow::Cow::Borrowed(#value)
1263                );
1264            }
1265        });
1266        quote! {
1267            #(#header_statements)*
1268        }
1269    };
1270
1271    // 如果存在CORS配置,生成CORS headers注入代码
1272    let cors_headers_code = if let Some(cors) = &cors_config {
1273        let mut statements = vec![];
1274
1275        // origin: 默认"*"
1276        let origin_val = cors.origin.as_deref().unwrap_or("*");
1277        statements.push(quote! {
1278            __potato_response.add_header(
1279                "Access-Control-Allow-Origin".into(),
1280                #origin_val.into()
1281            );
1282        });
1283
1284        // methods: 仅在用户指定时添加,否则由OPTIONS请求自动计算
1285        if let Some(ref methods) = cors.methods {
1286            let mut methods_list: Vec<&str> = methods.split(',').map(|s| s.trim()).collect();
1287            if !methods_list.contains(&"HEAD") {
1288                methods_list.push("HEAD");
1289            }
1290            if !methods_list.contains(&"OPTIONS") {
1291                methods_list.push("OPTIONS");
1292            }
1293            let methods_str = methods_list.join(",");
1294            statements.push(quote! {
1295                __potato_response.add_header(
1296                    "Access-Control-Allow-Methods".into(),
1297                    #methods_str.into()
1298                );
1299            });
1300        }
1301
1302        // headers: 默认"*"
1303        let headers_val = cors.headers.as_deref().unwrap_or("*");
1304        statements.push(quote! {
1305            __potato_response.add_header(
1306                "Access-Control-Allow-Headers".into(),
1307                #headers_val.into()
1308            );
1309        });
1310
1311        // max_age: 默认"86400"
1312        if let Some(ref max_age) = cors.max_age {
1313            statements.push(quote! {
1314                __potato_response.add_header(
1315                    "Access-Control-Max-Age".into(),
1316                    #max_age.into()
1317                );
1318            });
1319        } else {
1320            statements.push(quote! {
1321                __potato_response.add_header(
1322                    "Access-Control-Max-Age".into(),
1323                    "86400".into()
1324                );
1325            });
1326        }
1327
1328        if cors.credentials {
1329            statements.push(quote! {
1330                __potato_response.add_header(
1331                    "Access-Control-Allow-Credentials".into(),
1332                    "true".into()
1333                );
1334            });
1335        }
1336
1337        if let Some(ref expose_headers) = cors.expose_headers {
1338            statements.push(quote! {
1339                __potato_response.add_header(
1340                    "Access-Control-Expose-Headers".into(),
1341                    #expose_headers.into()
1342                );
1343            });
1344        }
1345
1346        quote! { #(#statements)* }
1347    } else {
1348        quote! {}
1349    };
1350
1351    // 如果存在CORS配置且是PUT/POST/DELETE,自动生成HEAD handler
1352    let auto_head_handler = if cors_config.is_some()
1353        && (req_name == "POST" || req_name == "PUT" || req_name == "DELETE")
1354    {
1355        let head_wrap_name = format_ident!("__potato_cors_head_{}", fn_name);
1356        Some(quote! {
1357            #[doc(hidden)]
1358            fn #head_wrap_name(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1359                // HEAD请求直接返回空响应,不执行原handler
1360                // CORS headers会通过postprocess机制自动添加
1361                potato::HttpResponse::html("")
1362            }
1363        })
1364    } else {
1365        None
1366    };
1367
1368    // 如果指定了max_concurrency,生成静态信号量
1369    let semaphore_static = if let Some(max_conn) = max_concurrency {
1370        let semaphore_name =
1371            format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1372        Some(quote! {
1373            #[doc(hidden)]
1374            #[allow(non_upper_case_globals)]
1375            static #semaphore_name: std::sync::LazyLock<tokio::sync::Semaphore> =
1376                std::sync::LazyLock::new(|| tokio::sync::Semaphore::new(#max_conn));
1377        })
1378    } else {
1379        None
1380    };
1381
1382    let wrap_func_body = if is_async {
1383        if max_concurrency.is_some() {
1384            let semaphore_name =
1385                format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1386            quote! {
1387                let __potato_permit = #semaphore_name.acquire().await;
1388
1389                // 获取自定义错误处理器
1390                let __potato_error_handler: Option<potato::ErrorHandler> = {
1391                    let mut handler = None;
1392                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1393                        handler = Some(flag.handler.clone());
1394                        break;
1395                    }
1396                    handler
1397                };
1398
1399                // 按需创建缓存对象
1400                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1401                    Some(potato::OnceCache::new())
1402                } else {
1403                    None
1404                };
1405                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1406                    // 从 Authorization header 中提取 Bearer token 并加载 session
1407                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1408                        let header_value = h.as_str();
1409                        if header_value.starts_with("Bearer ") {
1410                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1411                        } else {
1412                            None
1413                        }
1414                    } else {
1415                        None
1416                    }
1417                } else {
1418                    None
1419                };
1420
1421                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1422                if #need_session_cache && __potato_session_cache.is_none() {
1423                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1424                    __potato_resp.http_code = 401;
1425                    return __potato_resp;
1426                }
1427
1428                // 自动解析请求中的Cookie
1429                if let Some(ref mut session_cache) = __potato_session_cache {
1430                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1431                        session_cache.parse_request_cookies(cookie_header.as_str());
1432                    }
1433                }
1434
1435                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1436                #(
1437                    if __potato_pre_response.is_none() {
1438                        __potato_pre_response = match #preprocess_adapters(
1439                            req,
1440                            __potato_once_cache.as_mut(),
1441                            __potato_session_cache.as_mut(),
1442                        ).await {
1443                            Ok(Some(ret)) => Some(ret),
1444                            Ok(None) => None,
1445                            Err(err) => {
1446                                let handler = &__potato_error_handler;
1447                                Some(match handler {
1448                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1449                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1450                                    None => potato::HttpResponse::error(format!("{err:?}")),
1451                                })
1452                            }
1453                        };
1454                    }
1455                )*
1456
1457                let mut __potato_response = match __potato_pre_response {
1458                    Some(ret) => ret,
1459                    None => match #handler_wrap_func_body {
1460                        Ok(resp) => resp,
1461                        Err(err) => {
1462                            let handler = &__potato_error_handler;
1463                            match handler {
1464                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1465                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1466                                None => potato::HttpResponse::error(format!("{err:?}")),
1467                            }
1468                        }
1469                    },
1470                };
1471
1472                #(
1473                    if let Err(err) = #postprocess_adapters(
1474                        req,
1475                        &mut __potato_response,
1476                        __potato_once_cache.as_mut(),
1477                        __potato_session_cache.as_mut(),
1478                    ).await {
1479                        drop(__potato_permit);
1480                        let handler = &__potato_error_handler;
1481                        return match handler {
1482                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1483                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1484                            None => potato::HttpResponse::error(format!("{err:?}")),
1485                        };
1486                    }
1487                )*
1488
1489                #add_headers_code
1490                #cors_headers_code
1491
1492                // 自动应用SessionCache中的cookies到响应
1493                if let Some(ref session_cache) = __potato_session_cache {
1494                    session_cache.apply_cookies(&mut __potato_response);
1495                }
1496
1497                drop(__potato_permit);
1498                __potato_response
1499            }
1500        } else {
1501            quote! {
1502                // 获取自定义错误处理器
1503                let __potato_error_handler: Option<potato::ErrorHandler> = {
1504                    let mut handler = None;
1505                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1506                        handler = Some(flag.handler.clone());
1507                        break;
1508                    }
1509                    handler
1510                };
1511
1512                // 按需创建缓存对象
1513                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1514                    Some(potato::OnceCache::new())
1515                } else {
1516                    None
1517                };
1518                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1519                    // 从 Authorization header 中提取 Bearer token 并加载 session
1520                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1521                        let header_value = h.as_str();
1522                        if header_value.starts_with("Bearer ") {
1523                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1524                        } else {
1525                            None
1526                        }
1527                    } else {
1528                        None
1529                    }
1530                } else {
1531                    None
1532                };
1533
1534                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1535                if #need_session_cache && __potato_session_cache.is_none() {
1536                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1537                    __potato_resp.http_code = 401;
1538                    return __potato_resp;
1539                }
1540
1541                // 自动解析请求中的Cookie
1542                if let Some(ref mut session_cache) = __potato_session_cache {
1543                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1544                        session_cache.parse_request_cookies(cookie_header.as_str());
1545                    }
1546                }
1547
1548                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1549                #(
1550                    if __potato_pre_response.is_none() {
1551                        __potato_pre_response = match #preprocess_adapters(
1552                            req,
1553                            __potato_once_cache.as_mut(),
1554                            __potato_session_cache.as_mut(),
1555                        ).await {
1556                            Ok(Some(ret)) => Some(ret),
1557                            Ok(None) => None,
1558                            Err(err) => {
1559                                let handler = &__potato_error_handler;
1560                                Some(match handler {
1561                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1562                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1563                                    None => potato::HttpResponse::error(format!("{err:?}")),
1564                                })
1565                            }
1566                        };
1567                    }
1568                )*
1569
1570                let mut __potato_response = match __potato_pre_response {
1571                    Some(ret) => ret,
1572                    None => match #handler_wrap_func_body {
1573                        Ok(resp) => resp,
1574                        Err(err) => {
1575                            let handler = &__potato_error_handler;
1576                            match handler {
1577                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1578                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1579                                None => potato::HttpResponse::error(format!("{err:?}")),
1580                            }
1581                        }
1582                    },
1583                };
1584
1585                #(
1586                    if let Err(err) = #postprocess_adapters(
1587                        req,
1588                        &mut __potato_response,
1589                        __potato_once_cache.as_mut(),
1590                        __potato_session_cache.as_mut(),
1591                    ).await {
1592                        let handler = &__potato_error_handler;
1593                        return match handler {
1594                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1595                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1596                            None => potato::HttpResponse::error(format!("{err:?}")),
1597                        };
1598                    }
1599                )*
1600
1601                #add_headers_code
1602                #cors_headers_code
1603
1604                __potato_response
1605            }
1606        }
1607    } else {
1608        if max_concurrency.is_some() {
1609            let semaphore_name =
1610                format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1611            quote! {
1612                let __potato_permit = #semaphore_name.acquire().await;
1613
1614                // 获取自定义错误处理器
1615                let __potato_error_handler: Option<potato::ErrorHandler> = {
1616                    let mut handler = None;
1617                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1618                        handler = Some(flag.handler.clone());
1619                        break;
1620                    }
1621                    handler
1622                };
1623
1624                // 按需创建缓存对象
1625                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1626                    Some(potato::OnceCache::new())
1627                } else {
1628                    None
1629                };
1630                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1631                    // 从 Authorization header 中提取 Bearer token 并加载 session
1632                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1633                        let header_value = h.as_str();
1634                        if header_value.starts_with("Bearer ") {
1635                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1636                        } else {
1637                            None
1638                        }
1639                    } else {
1640                        None
1641                    }
1642                } else {
1643                    None
1644                };
1645
1646                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1647                if #need_session_cache && __potato_session_cache.is_none() {
1648                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1649                    __potato_resp.http_code = 401;
1650                    return __potato_resp;
1651                }
1652
1653                // 自动解析请求中的Cookie
1654                if let Some(ref mut session_cache) = __potato_session_cache {
1655                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1656                        session_cache.parse_request_cookies(cookie_header.as_str());
1657                    }
1658                }
1659
1660                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1661                #(
1662                    if __potato_pre_response.is_none() {
1663                        __potato_pre_response = match #preprocess_adapters(
1664                            req,
1665                            __potato_once_cache.as_mut(),
1666                            __potato_session_cache.as_mut(),
1667                        ).await {
1668                            Ok(Some(ret)) => Some(ret),
1669                            Ok(None) => None,
1670                            Err(err) => {
1671                                let handler = &__potato_error_handler;
1672                                Some(match handler {
1673                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1674                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1675                                    None => potato::HttpResponse::error(format!("{err:?}")),
1676                                })
1677                            }
1678                        };
1679                    }
1680                )*
1681
1682                let mut __potato_response = match __potato_pre_response {
1683                    Some(ret) => ret,
1684                    None => match #handler_wrap_func_body {
1685                        Ok(resp) => resp,
1686                        Err(err) => {
1687                            let handler = &__potato_error_handler;
1688                            match handler {
1689                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1690                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1691                                None => potato::HttpResponse::error(format!("{err:?}")),
1692                            }
1693                        }
1694                    },
1695                };
1696
1697                #(
1698                    if let Err(err) = #postprocess_adapters(
1699                        req,
1700                        &mut __potato_response,
1701                        __potato_once_cache.as_mut(),
1702                        __potato_session_cache.as_mut(),
1703                    ).await {
1704                        drop(__potato_permit);
1705                        let handler = &__potato_error_handler;
1706                        return match handler {
1707                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1708                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1709                            None => potato::HttpResponse::error(format!("{err:?}")),
1710                        };
1711                    }
1712                )*
1713
1714                #add_headers_code
1715                #cors_headers_code
1716
1717                // 自动应用SessionCache中的cookies到响应
1718                if let Some(ref session_cache) = __potato_session_cache {
1719                    session_cache.apply_cookies(&mut __potato_response);
1720                }
1721
1722                drop(__potato_permit);
1723                __potato_response
1724            }
1725        } else {
1726            quote! {
1727                // 获取自定义错误处理器
1728                let __potato_error_handler: Option<potato::ErrorHandler> = {
1729                    let mut handler = None;
1730                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1731                        handler = Some(flag.handler.clone());
1732                        break;
1733                    }
1734                    handler
1735                };
1736
1737                // 按需创建缓存对象
1738                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1739                    Some(potato::OnceCache::new())
1740                } else {
1741                    None
1742                };
1743                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1744                    // 从 Authorization header 中提取 Bearer token 并加载 session
1745                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1746                        let header_value = h.as_str();
1747                        if header_value.starts_with("Bearer ") {
1748                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1749                        } else {
1750                            None
1751                        }
1752                    } else {
1753                        None
1754                    }
1755                } else {
1756                    None
1757                };
1758
1759                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1760                if #need_session_cache && __potato_session_cache.is_none() {
1761                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1762                    __potato_resp.http_code = 401;
1763                    return __potato_resp;
1764                }
1765
1766                // 自动解析请求中的Cookie
1767                if let Some(ref mut session_cache) = __potato_session_cache {
1768                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1769                        session_cache.parse_request_cookies(cookie_header.as_str());
1770                    }
1771                }
1772
1773                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1774                #(
1775                    if __potato_pre_response.is_none() {
1776                        __potato_pre_response = match #preprocess_adapters(
1777                            req,
1778                            __potato_once_cache.as_mut(),
1779                            __potato_session_cache.as_mut(),
1780                        ).await {
1781                            Ok(Some(ret)) => Some(ret),
1782                            Ok(None) => None,
1783                            Err(err) => {
1784                                let handler = &__potato_error_handler;
1785                                Some(match handler {
1786                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1787                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1788                                    None => potato::HttpResponse::error(format!("{err:?}")),
1789                                })
1790                            }
1791                        };
1792                    }
1793                )*
1794
1795                let mut __potato_response = match __potato_pre_response {
1796                    Some(ret) => ret,
1797                    None => match #handler_wrap_func_body {
1798                        Ok(resp) => resp,
1799                        Err(err) => {
1800                            let handler = &__potato_error_handler;
1801                            match handler {
1802                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1803                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1804                                None => potato::HttpResponse::error(format!("{err:?}")),
1805                            }
1806                        }
1807                    },
1808                };
1809
1810                #(
1811                    if let Err(err) = #postprocess_adapters(
1812                        req,
1813                        &mut __potato_response,
1814                        __potato_once_cache.as_mut(),
1815                        __potato_session_cache.as_mut(),
1816                    ).await {
1817                        let handler = &__potato_error_handler;
1818                        return match handler {
1819                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1820                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1821                            None => potato::HttpResponse::error(format!("{err:?}")),
1822                        };
1823                    }
1824                )*
1825
1826                #add_headers_code
1827                #cors_headers_code
1828
1829                // 自动应用SessionCache中的cookies到响应
1830                if let Some(ref session_cache) = __potato_session_cache {
1831                    session_cache.apply_cookies(&mut __potato_response);
1832                }
1833
1834                __potato_response
1835            }
1836        }
1837    };
1838
1839    if is_async {
1840        quote! {
1841            #root_fn
1842
1843            #auto_head_handler
1844
1845            #semaphore_static
1846
1847            #[doc(hidden)]
1848            async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1849                #wrap_func_body
1850            }
1851
1852            #[doc(hidden)]
1853            fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
1854                Box::pin(#wrap_func_name2(req))
1855            }
1856
1857            potato::inventory::submit!{potato::RequestHandlerFlag::new(
1858                potato::HttpMethod::#req_name,
1859                #final_path_expr,
1860                potato::HttpHandler::Async(#wrap_func_name),
1861                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1862            )}
1863        }
1864        .into()
1865    } else {
1866        quote! {
1867            #root_fn
1868
1869            #auto_head_handler
1870
1871            #semaphore_static
1872
1873            #[doc(hidden)]
1874            async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1875                #wrap_func_body
1876            }
1877
1878            #[doc(hidden)]
1879            fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
1880                Box::pin(#wrap_func_name2(req))
1881            }
1882
1883            potato::inventory::submit!{potato::RequestHandlerFlag::new(
1884                potato::HttpMethod::#req_name,
1885                #final_path_expr,
1886                potato::HttpHandler::Async(#wrap_func_name),
1887                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1888            )}
1889        }
1890        .into()
1891    }
1892    //}.to_string();
1893    //panic!("{content}");
1894    //todo!()
1895}
1896
1897#[proc_macro_attribute]
1898pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
1899    http_handler_macro(attr, input, "GET")
1900}
1901
1902#[proc_macro_attribute]
1903pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
1904    http_handler_macro(attr, input, "POST")
1905}
1906
1907#[proc_macro_attribute]
1908pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
1909    http_handler_macro(attr, input, "PUT")
1910}
1911
1912#[proc_macro_attribute]
1913pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
1914    http_handler_macro(attr, input, "DELETE")
1915}
1916
1917#[proc_macro_attribute]
1918pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
1919    http_handler_macro(attr, input, "OPTIONS")
1920}
1921
1922#[proc_macro_attribute]
1923pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
1924    http_handler_macro(attr, input, "HEAD")
1925}
1926
1927/// Controller 属性宏 - 定义控制器结构体
1928///
1929/// # 功能
1930/// - 为结构体的 impl 块中的所有方法提供统一的路由前缀
1931/// - 支持 preprocess/postprocess 中间件继承
1932/// - 自动为 Swagger 文档分组(tag 为结构体名称)
1933///
1934/// # 结构体字段限制
1935/// 只能包含以下类型的字段(0个或多个):
1936/// - `&'a potato::OnceCache`
1937/// - `&'a potato::SessionCache`
1938///
1939/// # 示例
1940/// ```rust,ignore
1941/// #[potato::controller("/api/users")]
1942/// pub struct UsersController<'a> {
1943///     pub once_cache: &'a potato::OnceCache,
1944///     pub sess_cache: &'a potato::SessionCache,
1945/// }
1946///
1947/// #[potato::preprocess(my_preprocess)]
1948/// impl<'a> UsersController<'a> {
1949///     #[potato::http_get] // 地址为 "/api/users"
1950///     pub async fn get(&self) -> anyhow::Result<&'static str> {
1951///         Ok("get users data")
1952///     }
1953///
1954///     #[potato::http_get("/any")] // 地址为 "/api/users/any"
1955///     pub async fn get_any(&self) -> anyhow::Result<&'static str> {
1956///         Ok("get any data")
1957///     }
1958/// }
1959/// ```
1960#[proc_macro_attribute]
1961pub fn controller(attr: TokenStream, input: TokenStream) -> TokenStream {
1962    controller_macro(attr, input)
1963}
1964
1965fn controller_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
1966    // 尝试解析为 impl 块
1967    let input_clone = input.clone();
1968    if let Ok(item_impl) = syn::parse::<syn::ItemImpl>(input_clone) {
1969        // 这是 impl 块,需要提取方法并生成路由注册
1970        return controller_impl_macro(attr, item_impl);
1971    }
1972
1973    // 否则解析为结构体
1974    let item_struct = syn::parse_macro_input!(input as syn::ItemStruct);
1975
1976    // 解析 base path(结构体上的 controller 可以没有 path,由 impl 块指定)
1977    let base_path = if attr.is_empty() {
1978        // 结构体上没有 path,不生成常量,由 impl 块上的 controller 指定
1979        quote! {}
1980    } else {
1981        let attr_str = attr.to_string();
1982        let base_path = attr_str.trim_matches('"').to_string();
1983        quote! {
1984            #[doc(hidden)]
1985            const __POTATO_CONTROLLER_BASE_PATH: &str = #base_path;
1986        }
1987    };
1988
1989    // 验证结构体字段并获取字段信息
1990    let (has_once_cache, has_session_cache) = validate_controller_struct(&item_struct);
1991    let struct_name = &item_struct.ident;
1992    let struct_name_str = struct_name.to_string();
1993
1994    // 生成结构体定义、常量和 inventory 提交
1995    // 同时生成隐藏的 controller 创建辅助函数
1996    let controller_creation_fn = if has_session_cache {
1997        // 结构体有 SessionCache 字段,生成包含鉴权的创建函数
1998        // 直接创建并返回 Box<Self>
1999        quote! {
2000            #[doc(hidden)]
2001            #[allow(dead_code)]
2002            async fn __potato_create_controller(req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2003                // 在堆上分配缓存
2004                let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2005
2006                // 从 Authorization header 中提取 Bearer token 并加载 session
2007                let session_cache = {
2008                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2009                        let header_value = h.as_str();
2010                        if header_value.starts_with("Bearer ") {
2011                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
2012                        } else {
2013                            None
2014                        }
2015                    } else {
2016                        None
2017                    }
2018                };
2019
2020                let session_cache = match session_cache {
2021                    Some(cache) => cache,
2022                    None => {
2023                        let mut resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
2024                        resp.http_code = 401;
2025                        return Err(resp);
2026                    }
2027                };
2028                let session_cache = Box::leak(Box::new(session_cache));
2029
2030                // 创建 controller 实例
2031                let controller = Self {
2032                    once_cache,
2033                    sess_cache: session_cache,
2034                };
2035
2036                Ok(Box::new(controller))
2037            }
2038        }
2039    } else {
2040        // 结构体没有 SessionCache 字段,生成不包含鉴权的创建函数
2041        // 创建临时的 SessionCache(但不使用),返回 Box<Self>
2042        quote! {
2043            #[doc(hidden)]
2044            #[allow(dead_code)]
2045            async fn __potato_create_controller(_req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2046                // 在堆上分配缓存
2047                let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2048
2049                // 创建临时的 SessionCache(不需要鉴权,也不使用)
2050                let _temp_session_cache = Box::leak(Box::new(potato::SessionCache::new()));
2051
2052                // 创建 controller 实例(不包含 sess_cache 字段)
2053                let controller = Self {
2054                    once_cache,
2055                };
2056
2057                Ok(Box::new(controller))
2058            }
2059        }
2060    };
2061
2062    // 提取结构体的泛型参数(包括生命周期)
2063    let struct_generics = &item_struct.generics;
2064    let (impl_generics, type_generics, where_clause) = struct_generics.split_for_impl();
2065
2066    let output = quote! {
2067        #item_struct
2068
2069        #base_path
2070
2071        #[doc(hidden)]
2072        const __POTATO_CONTROLLER_NAME: &str = #struct_name_str;
2073
2074        // 提交 Controller 结构体字段信息到 inventory
2075        potato::inventory::submit! {
2076            potato::ControllerStructFlag::new(
2077                #struct_name_str,
2078                potato::ControllerStructFieldInfo {
2079                    has_once_cache: #has_once_cache,
2080                    has_session_cache: #has_session_cache,
2081                }
2082            )
2083        }
2084
2085        // 生成隐藏的 controller 创建辅助函数
2086        impl #impl_generics #struct_name #type_generics #where_clause {
2087            #controller_creation_fn
2088        }
2089    };
2090
2091    output.into()
2092}
2093
2094/// 处理 impl 块的 controller 宏
2095fn controller_impl_macro(attr: TokenStream, item_impl: syn::ItemImpl) -> TokenStream {
2096    // 解析 base path(如果 attr 为空,则从常量读取)
2097    let base_path_str = if attr.is_empty() {
2098        // 从结构体上的 controller 宏生成的常量读取
2099        // 这种情况暂不支持,因为宏展开时无法读取常量值
2100        None
2101    } else {
2102        let attr_str = attr.to_string();
2103        Some(attr_str.trim_matches('"').to_string())
2104    };
2105
2106    // 从 impl 块中提取类型名称
2107    let self_type = &item_impl.self_ty;
2108
2109    // 提取不带生命周期参数的类型名称(用于实例化)
2110    let self_type_name = match &*item_impl.self_ty {
2111        syn::Type::Path(type_path) => {
2112            // 获取路径的最后一段(类型名称),不包含泛型参数
2113            if let Some(segment) = type_path.path.segments.last() {
2114                let ident = &segment.ident;
2115                quote! { #ident }
2116            } else {
2117                quote! { #self_type }
2118            }
2119        }
2120        _ => quote! { #self_type },
2121    };
2122
2123    // 提取不带生命周期参数的类型名称字符串(用于 Swagger tag)
2124    let self_type_tag = match &*item_impl.self_ty {
2125        syn::Type::Path(type_path) => {
2126            // 获取路径的最后一段(类型名称),不包含泛型参数
2127            if let Some(segment) = type_path.path.segments.last() {
2128                segment.ident.to_string()
2129            } else {
2130                self_type.to_token_stream().to_string()
2131            }
2132        }
2133        _ => self_type.to_token_stream().to_string(),
2134    };
2135
2136    // 创建清理后的 impl 块(移除方法上的 http_* 标注)
2137    let mut cleaned_items = Vec::new();
2138    let mut generated_code = Vec::new();
2139
2140    for item in &item_impl.items {
2141        if let syn::ImplItem::Fn(method) = item {
2142            // 检查是否有 http_* 标注
2143            let has_http_attr = method.attrs.iter().any(|attr| {
2144                let attr_name = attr.path().to_token_stream().to_string();
2145                attr_name.contains("http_get")
2146                    || attr_name.contains("http_post")
2147                    || attr_name.contains("http_put")
2148                    || attr_name.contains("http_delete")
2149                    || attr_name.contains("http_head")
2150                    || attr_name.contains("http_patch")
2151                    || attr_name.contains("http_options")
2152            });
2153
2154            if has_http_attr {
2155                // 有 http_* 标注,创建清理后的方法(移除 http_* 标注)
2156                let mut cleaned_method = method.clone();
2157                cleaned_method.attrs = method
2158                    .attrs
2159                    .iter()
2160                    .filter(|attr| {
2161                        let attr_name = attr.path().to_token_stream().to_string();
2162                        !attr_name.contains("http_get")
2163                            && !attr_name.contains("http_post")
2164                            && !attr_name.contains("http_put")
2165                            && !attr_name.contains("http_delete")
2166                            && !attr_name.contains("http_head")
2167                            && !attr_name.contains("http_patch")
2168                            && !attr_name.contains("http_options")
2169                    })
2170                    .cloned()
2171                    .collect();
2172
2173                cleaned_items.push(syn::ImplItem::Fn(cleaned_method));
2174
2175                // 为每个 http_* 标注生成包装函数和路由注册
2176                for attr in &method.attrs {
2177                    let attr_name = attr.path().to_token_stream().to_string();
2178                    if attr_name.contains("http_get")
2179                        || attr_name.contains("http_post")
2180                        || attr_name.contains("http_put")
2181                        || attr_name.contains("http_delete")
2182                        || attr_name.contains("http_head")
2183                        || attr_name.contains("http_patch")
2184                        || attr_name.contains("http_options")
2185                    {
2186                        // 提取 HTTP 方法名
2187                        let http_method = if attr_name.contains("http_get") {
2188                            "GET"
2189                        } else if attr_name.contains("http_post") {
2190                            "POST"
2191                        } else if attr_name.contains("http_put") {
2192                            "PUT"
2193                        } else if attr_name.contains("http_delete") {
2194                            "DELETE"
2195                        } else if attr_name.contains("http_head") {
2196                            "HEAD"
2197                        } else if attr_name.contains("http_patch") {
2198                            "PATCH"
2199                        } else {
2200                            "OPTIONS"
2201                        };
2202
2203                        // 提取 path 参数
2204                        let method_path = match &attr.meta {
2205                            syn::Meta::List(list) => {
2206                                if let Ok(lit_str) =
2207                                    syn::parse::<syn::LitStr>(list.tokens.clone().into())
2208                                {
2209                                    lit_str.value()
2210                                } else {
2211                                    String::new()
2212                                }
2213                            }
2214                            _ => String::new(),
2215                        };
2216
2217                        // 拼接路径(在宏展开时完成)
2218                        let final_path = if let Some(ref base_path) = base_path_str {
2219                            // base path 已知,直接在宏展开时拼接
2220                            if method_path.is_empty() {
2221                                base_path.clone()
2222                            } else {
2223                                format!("{}{}", base_path, method_path)
2224                            }
2225                        } else {
2226                            // base path 未知(从常量读取),这种情况暂不支持
2227                            panic!("impl block controller must specify a base path, e.g., #[potato::controller(\"/api/users\")]");
2228                        };
2229
2230                        // 将 final_path 转换为 LitStr
2231                        let final_path_lit =
2232                            syn::LitStr::new(&final_path, proc_macro2::Span::call_site());
2233
2234                        // 生成包装函数名
2235                        let fn_name = &method.sig.ident;
2236                        let wrapper_fn_name = quote::format_ident!("__potato_ctrl_{}", fn_name);
2237                        let is_async = method.sig.asyncness.is_some();
2238
2239                        // 检测方法是否有 receiver,以及是否是 mutable
2240                        let (has_receiver, _is_mut_receiver) = method
2241                            .sig
2242                            .inputs
2243                            .iter()
2244                            .filter_map(|arg| {
2245                                if let syn::FnArg::Receiver(recv) = arg {
2246                                    Some((true, recv.mutability.is_some()))
2247                                } else {
2248                                    None
2249                                }
2250                            })
2251                            .next()
2252                            .unwrap_or((false, false));
2253
2254                        // 提取非 receiver 参数
2255                        let other_params: Vec<_> = method
2256                            .sig
2257                            .inputs
2258                            .iter()
2259                            .filter_map(|arg| {
2260                                if let syn::FnArg::Typed(pat_type) = arg {
2261                                    Some(pat_type.clone())
2262                                } else {
2263                                    None
2264                                }
2265                            })
2266                            .collect();
2267
2268                        // 检测非 receiver 参数中是否包含 SessionCache
2269                        let method_has_session_cache = other_params.iter().any(|pat_type| {
2270                            pat_type.ty.to_token_stream().to_string().type_simplify()
2271                                == "& mut SessionCache"
2272                        });
2273
2274                        // 如果方法有 receiver(&self 或 &mut self),或者参数包含 SessionCache,则需要鉴权
2275                        let doc_auth = has_receiver || method_has_session_cache;
2276
2277                        // 提取参数名
2278                        let param_names: Vec<_> = other_params
2279                            .iter()
2280                            .filter_map(|pat_type| {
2281                                if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2282                                    Some(pat_ident.ident.clone())
2283                                } else {
2284                                    None
2285                                }
2286                            })
2287                            .collect();
2288
2289                        // 生成方法调用
2290                        let method_call = if has_receiver {
2291                            // 有 receiver,需要实例化 controller
2292                            // 使用结构体生成的 __potato_create_controller 函数
2293                            // 返回 Box<Self>
2294
2295                            if param_names.is_empty() {
2296                                if is_async {
2297                                    quote! {
2298                                        {
2299                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2300                                                Ok(boxed) => boxed,
2301                                                Err(resp) => return resp,
2302                                            };
2303                                            controller.#fn_name().await
2304                                        }
2305                                    }
2306                                } else {
2307                                    quote! {
2308                                        {
2309                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2310                                                Ok(boxed) => boxed,
2311                                                Err(resp) => return resp,
2312                                            };
2313                                            controller.#fn_name()
2314                                        }
2315                                    }
2316                                }
2317                            } else {
2318                                if is_async {
2319                                    quote! {
2320                                        {
2321                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2322                                                Ok(boxed) => boxed,
2323                                                Err(resp) => return resp,
2324                                            };
2325                                            controller.#fn_name(#(#param_names),*).await
2326                                        }
2327                                    }
2328                                } else {
2329                                    quote! {
2330                                        {
2331                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2332                                                Ok(boxed) => boxed,
2333                                                Err(resp) => return resp,
2334                                            };
2335                                            controller.#fn_name(#(#param_names),*)
2336                                        }
2337                                    }
2338                                }
2339                            }
2340                        } else {
2341                            // 没有 receiver,直接调用关联函数
2342                            // 但仍需要处理参数(如 SessionCache)
2343
2344                            if param_names.is_empty() {
2345                                if is_async {
2346                                    quote! { #self_type_name::#fn_name().await }
2347                                } else {
2348                                    quote! { #self_type_name::#fn_name() }
2349                                }
2350                            } else {
2351                                // 有参数,需要根据参数类型生成绑定代码
2352                                // 生成参数绑定
2353                                let mut param_bindings = Vec::new();
2354                                for (i, param) in other_params.iter().enumerate() {
2355                                    let param_type_str =
2356                                        param.ty.to_token_stream().to_string().type_simplify();
2357                                    let param_name = &param_names[i];
2358
2359                                    match &param_type_str[..] {
2360                                        "& mut OnceCache" => {
2361                                            param_bindings.push(quote! {
2362                                                let #param_name = &mut __potato_once_cache;
2363                                            });
2364                                        }
2365                                        "& mut SessionCache" => {
2366                                            param_bindings.push(quote! {
2367                                                let #param_name = &mut __potato_session_cache;
2368                                            });
2369                                        }
2370                                        _ => {
2371                                            // 其他参数类型暂不支持
2372                                        }
2373                                    }
2374                                }
2375
2376                                // 生成 SessionCache 加载逻辑(从 Authorization header)
2377                                let needs_session_cache = other_params.iter().any(|p| {
2378                                    p.ty.to_token_stream().to_string().type_simplify()
2379                                        == "& mut SessionCache"
2380                                });
2381
2382                                let session_cache_init = if needs_session_cache {
2383                                    quote! {
2384                                        {
2385                                            // 从 Authorization header 中提取 Bearer token 并加载 session
2386                                            if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2387                                                let header_value = h.as_str();
2388                                                if header_value.starts_with("Bearer ") {
2389                                                    potato::SessionCache::from_token(&header_value[7..]).await.ok()
2390                                                } else {
2391                                                    None
2392                                                }
2393                                            } else {
2394                                                None
2395                                            }
2396                                        }
2397                                    }
2398                                } else {
2399                                    quote! { None }
2400                                };
2401
2402                                if is_async {
2403                                    quote! {
2404                                        {
2405                                            let mut __potato_once_cache = potato::OnceCache::new();
2406                                            let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2407                                            #(#param_bindings)*
2408                                            #self_type_name::#fn_name(#(#param_names),*).await
2409                                        }
2410                                    }
2411                                } else {
2412                                    quote! {
2413                                        {
2414                                            let mut __potato_once_cache = potato::OnceCache::new();
2415                                            let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2416                                            #(#param_bindings)*
2417                                            #self_type_name::#fn_name(#(#param_names),*)
2418                                        }
2419                                    }
2420                                }
2421                            }
2422                        };
2423
2424                        // 生成包装函数 - 简化版,直接返回文本
2425                        let wrapper_fn = if is_async {
2426                            quote! {
2427                                #[doc(hidden)]
2428                                fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2429                                    Box::pin(async move {
2430                                        match #method_call {
2431                                            Ok(resp) => potato::HttpResponse::text(resp.to_string()),
2432                                            Err(err) => potato::HttpResponse::error(err.to_string()),
2433                                        }
2434                                    })
2435                                }
2436                            }
2437                        } else {
2438                            quote! {
2439                                #[doc(hidden)]
2440                                fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2441                                    Box::pin(async move {
2442                                        match #method_call {
2443                                            Ok(resp) => potato::HttpResponse::text(resp.to_string()),
2444                                            Err(err) => potato::HttpResponse::error(err.to_string()),
2445                                        }
2446                                    })
2447                                }
2448                            }
2449                        };
2450
2451                        generated_code.push(wrapper_fn);
2452
2453                        // 生成路由注册
2454                        let http_method_ident = quote::format_ident!("{}", http_method);
2455
2456                        let route_register = quote! {
2457                            potato::inventory::submit! {
2458                                potato::RequestHandlerFlag::new(
2459                                    potato::HttpMethod::#http_method_ident,
2460                                    #final_path_lit,
2461                                    potato::HttpHandler::Async(#wrapper_fn_name),
2462                                    potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2463                                )
2464                            }
2465                        };
2466
2467                        generated_code.push(route_register);
2468                    }
2469                }
2470            } else {
2471                // 没有 http_* 标注,保留原方法
2472                cleaned_items.push(item.clone());
2473            }
2474        } else {
2475            // 非方法项,直接保留
2476            cleaned_items.push(item.clone());
2477        }
2478    }
2479
2480    // 生成清理后的 impl 块
2481    let mut cleaned_impl = item_impl.clone();
2482    cleaned_impl.items = cleaned_items;
2483
2484    // 生成最终代码:清理后的 impl 块 + 生成的独立函数和路由注册
2485    let output = quote! {
2486        #cleaned_impl
2487
2488        #(#generated_code)*
2489    };
2490
2491    output.into()
2492}
2493
/// Attribute macro entry point for `preprocess`.
///
/// Forwards the attribute arguments and the annotated item, unchanged, to
/// `preprocess_macro`, which produces the expansion.
#[proc_macro_attribute]
pub fn preprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
    preprocess_macro(attr, input)
}
2498
/// Attribute macro entry point for `postprocess`.
///
/// Forwards the attribute arguments and the annotated item, unchanged, to
/// `postprocess_macro`, which produces the expansion.
#[proc_macro_attribute]
pub fn postprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
    postprocess_macro(attr, input)
}
2503
2504/// handle_error 属性宏 - 定义全局错误处理函数
2505///
2506/// # 签名要求
2507/// - 参数1: `req: &mut HttpRequest`
2508/// - 参数2: `err: anyhow::Error`
2509/// - 返回: `HttpResponse`
2510/// - 支持 async fn 和普通 fn
2511///
2512/// # 示例
2513/// ```rust,ignore
2514/// #[potato::handle_error]
2515/// async fn handle_error(req: &mut HttpRequest, err: anyhow::Error) -> HttpResponse {
2516///     HttpResponse::json(serde_json::json!({
2517///         "error": format!("{}", err)
2518///     }))
2519/// }
2520/// ```
2521fn handle_error_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2522    if !attr.is_empty() {
2523        return input;
2524    }
2525
2526    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2527    let fn_name = root_fn.sig.ident.clone();
2528    let is_async = root_fn.sig.asyncness.is_some();
2529
2530    // 验证函数签名
2531    if root_fn.sig.inputs.len() != 2 {
2532        panic!("`handle_error` function must accept exactly two arguments");
2533    }
2534
2535    let mut arg_types = vec![];
2536    for arg in root_fn.sig.inputs.iter() {
2537        match arg {
2538            syn::FnArg::Typed(arg) => {
2539                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
2540            }
2541            _ => panic!("`handle_error` function does not support receiver argument"),
2542        }
2543    }
2544
2545    if arg_types[0] != "& mut HttpRequest" {
2546        panic!(
2547            "`handle_error` first argument must be `&mut potato::HttpRequest`, got `{}`",
2548            arg_types[0]
2549        );
2550    }
2551    if arg_types[1] != "anyhow::Error" {
2552        panic!(
2553            "`handle_error` second argument must be `anyhow::Error`, got `{}`",
2554            arg_types[1]
2555        );
2556    }
2557
2558    let ret_type = root_fn
2559        .sig
2560        .output
2561        .to_token_stream()
2562        .to_string()
2563        .type_simplify();
2564    if ret_type != "HttpResponse" {
2565        panic!(
2566            "`handle_error` return type must be `potato::HttpResponse`, got `{}`",
2567            ret_type
2568        );
2569    }
2570
2571    // 生成适配器函数
2572    let wrap_name = format_ident!("__potato_error_handler_adapter_{}", fn_name);
2573    let wrap_name_inner = format_ident!("__potato_error_handler_adapter_inner_{}", fn_name);
2574
2575    // 生成内部函数
2576    let call_body = if is_async {
2577        quote! { #fn_name(req, err).await }
2578    } else {
2579        quote! { #fn_name(req, err) }
2580    };
2581
2582    quote! {
2583        #root_fn
2584
2585        #[doc(hidden)]
2586        async fn #wrap_name_inner(
2587            req: &mut potato::HttpRequest,
2588            err: anyhow::Error,
2589        ) -> potato::HttpResponse {
2590            #call_body
2591        }
2592
2593        #[doc(hidden)]
2594        pub fn #wrap_name(
2595            req: &mut potato::HttpRequest,
2596            err: anyhow::Error,
2597        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2598            Box::pin(#wrap_name_inner(req, err))
2599        }
2600
2601        potato::inventory::submit! {
2602            potato::ErrorHandlerFlag::new(
2603                potato::ErrorHandler::Async(#wrap_name)
2604            )
2605        }
2606    }
2607    .into()
2608}
2609
/// Attribute macro entry point for `handle_error`.
///
/// Forwards the attribute arguments and the annotated function to
/// `handle_error_macro`, which validates the signature and generates the
/// error-handler registration.
#[proc_macro_attribute]
pub fn handle_error(attr: TokenStream, input: TokenStream) -> TokenStream {
    handle_error_macro(attr, input)
}
2614
/// `limit_size` attribute macro - sets an independent request body size limit
/// for a handler.
///
/// # Arguments
/// * Single value: `#[potato::limit_size(1024 * 1024 * 1024)]` - limits only the body, to 1 GB
/// * Named arguments: `#[potato::limit_size(header = 2 * 1024 * 1024, body = 500 * 1024 * 1024)]`
///
/// NOTE(review): `limit_size_macro` parses the `header` value but does not
/// use it when generating the check - confirm whether the header limit is
/// enforced elsewhere.
///
/// # Example
/// ```rust,ignore
/// // Limit the body to 1 GB
/// #[potato::http_post("/upload")]
/// #[potato::limit_size(1024 * 1024 * 1024)]
/// async fn large_upload(req: &mut potato::HttpRequest) -> potato::HttpResponse {
///     todo!()
/// }
///
/// // Limit header and body separately
/// #[potato::http_post("/upload")]
/// #[potato::limit_size(header = 2 * 1024 * 1024, body = 500 * 1024 * 1024)]
/// async fn medium_upload(req: &mut potato::HttpRequest) -> potato::HttpResponse {
///     todo!()
/// }
/// ```
#[proc_macro_attribute]
pub fn limit_size(attr: TokenStream, input: TokenStream) -> TokenStream {
    limit_size_macro(attr, input)
}
2641
2642fn limit_size_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2643    // 解析参数
2644    let (_max_header, max_body) = {
2645        let attr_tokens: proc_macro2::TokenStream = attr.clone().into();
2646        if attr_tokens.is_empty() {
2647            // 默认值: 不限制 header,使用全局 body 限制
2648            (None, None)
2649        } else {
2650            // 尝试解析为命名参数或单值
2651            let result = syn::parse::Parser::parse2(
2652                |input: syn::parse::ParseStream| -> syn::Result<(Option<syn::Expr>, Option<syn::Expr>)> {
2653                    let mut header_expr = None;
2654                    let mut body_expr = None;
2655
2656                    // 尝试解析命名参数
2657                    while !input.is_empty() {
2658                        let ident: Ident = input.parse()?;
2659                        input.parse::<Token![=]>()?;
2660                        let value: syn::Expr = input.parse()?;
2661
2662                        match ident.to_string().as_str() {
2663                            "header" => header_expr = Some(value),
2664                            "body" => body_expr = Some(value),
2665                            _ => return Err(syn::Error::new(ident.span(), "expected 'header' or 'body'")),
2666                        }
2667
2668                        // 可选的逗号
2669                        if input.peek(Token![,]) {
2670                            input.parse::<Token![,]>()?;
2671                        }
2672                    }
2673
2674                    Ok((header_expr, body_expr))
2675                },
2676                attr_tokens.clone(),
2677            );
2678
2679            match result {
2680                Ok((h, b)) => (h, b),
2681                Err(_) => {
2682                    // 解析失败,尝试作为单值(body 限制)
2683                    if let Ok(expr) = syn::parse2::<syn::Expr>(attr_tokens) {
2684                        (None, Some(expr))
2685                    } else {
2686                        (None, None)
2687                    }
2688                }
2689            }
2690        }
2691    };
2692
2693    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2694
2695    // 生成检查代码
2696    let body_check = if let Some(body_expr) = max_body {
2697        quote! {
2698            // 检查 body 大小
2699            let body_len = req.body.len();
2700            if body_len > #body_expr {
2701                let mut res = potato::HttpResponse::text(format!(
2702                    "Payload Too Large: body size {} bytes exceeds limit {} bytes",
2703                    body_len, #body_expr
2704                ));
2705                res.http_code = 413;
2706                return res;
2707            }
2708        }
2709    } else {
2710        quote! {}
2711    };
2712
2713    // 克隆整个函数,然后修改 block
2714    let mut wrapped_fn = root_fn.clone();
2715    let original_block = root_fn.block.as_ref();
2716    let new_block: syn::Block = syn::parse_quote!({
2717        #body_check
2718        #original_block
2719    });
2720    wrapped_fn.block = Box::new(new_block);
2721
2722    quote! {
2723        #wrapped_fn
2724    }
2725    .into()
2726}
2727
/// `header` attribute macro - a placeholder; per the original author's note
/// the actual parsing is done in `http_handler_macro`.
/// Its existence lets the compiler recognize the `#[potato::header(...)]` syntax.
#[proc_macro_attribute]
pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream {
    // Return the original function unmodified.
    // The actual header parsing and handling is done in the http_get/http_post etc. macros.
    input
}
2736
/// `cors` attribute macro - a placeholder; per the original author's note
/// the actual parsing is done in `http_handler_macro`.
/// Its existence lets the compiler recognize the `#[potato::cors(...)]` syntax.
#[proc_macro_attribute]
pub fn cors(_attr: TokenStream, input: TokenStream) -> TokenStream {
    // Return the original function unmodified.
    // The actual cors parsing and handling is done in http_handler_macro.
    input
}
2745
2746#[proc_macro]
2747pub fn embed_dir(input: TokenStream) -> TokenStream {
2748    let path = syn::parse_macro_input!(input as syn::LitStr).value();
2749    quote! {{
2750        #[derive(potato::rust_embed::Embed)]
2751        #[folder = #path]
2752        struct Asset;
2753
2754        potato::load_embed::<Asset>()
2755    }}
2756    .into()
2757}
2758
2759#[proc_macro_derive(StandardHeader)]
2760pub fn standard_header_derive(input: TokenStream) -> TokenStream {
2761    let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
2762    let enum_name = root_enum.ident;
2763    let mut try_from_str_items = vec![];
2764    let mut to_str_items = vec![];
2765    let mut headers_items = vec![];
2766    let mut headers_apply_items = vec![];
2767    for root_field in root_enum.variants.iter() {
2768        let name = root_field.ident.clone();
2769        if root_field.fields.iter().next().is_some() {
2770            panic!("unsupported enum type");
2771        }
2772        let str_name = name.to_string().replace("_", "-");
2773        let len = str_name.len();
2774        try_from_str_items
2775            .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
2776        to_str_items.push(quote! { Self::#name => #str_name, });
2777        headers_items.push(quote! { #name(String), });
2778        headers_apply_items
2779            .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
2780    }
2781    let r = quote! {
2782        impl #enum_name {
2783            pub fn try_from_str(value: &str) -> Option<Self> {
2784                match value.len() {
2785                    #( #try_from_str_items )*
2786                    _ => None,
2787                }
2788            }
2789
2790            pub fn to_str(&self) -> &'static str {
2791                match self {
2792                    #( #to_str_items )*
2793                }
2794            }
2795        }
2796
2797        pub enum Headers {
2798            #( #headers_items )*
2799            Custom((String, String)),
2800        }
2801
2802        impl HttpRequest {
2803            pub fn apply_header(&mut self, header: Headers) {
2804                match header {
2805                    #( #headers_apply_items )*
2806                    Headers::Custom((k, v)) => self.set_header(&k[..], v),
2807                }
2808            }
2809        }
2810    };
2811    r.into()
2812}