Skip to main content

potato_macro/
lib.rs

1mod utils;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{format_ident, quote, ToTokens};
6use rand::Rng;
7use serde_json::json;
8use std::{collections::HashSet, sync::LazyLock};
9use syn::Token;
10use utils::StringExt as _;
11
/// CORS settings collected from a `#[cors(...)]` attribute (macro-internal).
///
/// Every field starts out "unset" (`None` / `false`). `Default` produces
/// exactly that state, so callers do not have to spell out each field.
#[derive(Default)]
struct CorsAttrConfig {
    /// Text of the `origin = "..."` key, if given.
    origin: Option<String>,
    /// Text of the `methods = "..."` key, if given.
    methods: Option<String>,
    /// Text of the `headers = "..."` key, if given.
    headers: Option<String>,
    /// Text of the `max_age = ...` key, kept as a string, if given.
    max_age: Option<String>,
    /// Value of the `credentials = true|false` key; `false` when absent.
    credentials: bool,
    /// Text of the `expose_headers = "..."` key, if given.
    expose_headers: Option<String>,
}
21
22/// 解析CORS属性
23fn parse_cors_attr(tokens: &proc_macro2::TokenStream) -> CorsAttrConfig {
24    let config = CorsAttrConfig {
25        origin: None,
26        methods: None,
27        headers: None,
28        max_age: None,
29        credentials: false,
30        expose_headers: None,
31    };
32
33    if tokens.is_empty() {
34        return config; // 返回最小限制配置(origin="*", headers="*", methods自动计算)
35    }
36
37    // 解析 key = value 格式
38    use syn::parse::Parser;
39
40    fn parse_inner(input: syn::parse::ParseStream) -> syn::Result<CorsAttrConfig> {
41        let mut config = CorsAttrConfig {
42            origin: None,
43            methods: None,
44            headers: None,
45            max_age: None,
46            credentials: false,
47            expose_headers: None,
48        };
49
50        let vars =
51            syn::punctuated::Punctuated::<syn::MetaNameValue, Token![,]>::parse_terminated(input)?;
52        for meta in vars {
53            let key = meta
54                .path
55                .get_ident()
56                .map(|i| i.to_string())
57                .unwrap_or_default();
58            if let syn::Expr::Lit(expr_lit) = &meta.value {
59                match &expr_lit.lit {
60                    syn::Lit::Str(s) => {
61                        let val = s.value();
62                        match key.as_str() {
63                            "origin" => config.origin = Some(val),
64                            "methods" => config.methods = Some(val),
65                            "headers" => config.headers = Some(val),
66                            "max_age" => config.max_age = Some(val),
67                            "expose_headers" => config.expose_headers = Some(val),
68                            _ => {}
69                        }
70                    }
71                    syn::Lit::Bool(b) => {
72                        if key == "credentials" {
73                            config.credentials = b.value();
74                        }
75                    }
76                    _ => {}
77                }
78            }
79        }
80        Ok(config)
81    }
82
83    match parse_inner.parse2(tokens.clone()) {
84        Ok(cfg) => cfg,
85        Err(e) => panic!("Failed to parse cors attributes: {e}"),
86    }
87}
88
/// Source spellings of `String`, `bool` and the built-in integer/float types
/// recognised by the macro. The users of this set lie outside this view;
/// presumably it drives handler argument handling.
static ARG_TYPES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "String", "bool", // non-numeric
        "u8", "u16", "u32", "u64", "usize", // unsigned
        "i8", "i16", "i32", "i64", "isize", // signed
        "f32", "f64", // floating point
    ])
});
97
98/// Controller 字段类型
99/// 验证 Controller 结构体字段
100fn validate_controller_struct(item_struct: &syn::ItemStruct) -> (bool, bool) {
101    let mut has_once_cache = false;
102    let mut has_session_cache = false;
103
104    if let syn::Fields::Named(fields_named) = &item_struct.fields {
105        for field in &fields_named.named {
106            let field_type_str = field.ty.to_token_stream().to_string().type_simplify();
107
108            // 验证类型必须是 &OnceCache 或 &SessionCache(支持生命周期参数)
109            if field_type_str.contains("OnceCache") {
110                has_once_cache = true;
111            } else if field_type_str.contains("SessionCache") {
112                has_session_cache = true;
113            } else {
114                panic!(
115                    "Controller field must be &OnceCache or &SessionCache, got: {}",
116                    field_type_str
117                );
118            }
119        }
120    }
121
122    (has_once_cache, has_session_cache)
123}
124
125/// 解析 header 标注的 tokens,返回 (key, value)
126fn parse_header_attr(tokens: &proc_macro2::TokenStream) -> Result<(String, String), syn::Error> {
127    use syn::parse::Parser;
128
129    let parser = |input: syn::parse::ParseStream| {
130        // 支持两种格式:
131        // 1. Key = "value" (标准 header)
132        // 2. Custom("key") = "value" (自定义 header)
133        let key_ident: Ident = input.parse()?;
134        let key_name = key_ident.to_string();
135
136        if key_name == "Custom" {
137            // Custom("key") = "value" 格式
138            let content;
139            syn::parenthesized!(content in input);
140            let key_lit: syn::LitStr = content.parse()?;
141            let key = key_lit.value();
142            let _: Token![=] = input.parse()?;
143            let value: syn::LitStr = input.parse()?;
144            Ok((key, value.value()))
145        } else {
146            // Key = "value" 格式
147            let _: Token![=] = input.parse()?;
148            let value: syn::LitStr = input.parse()?;
149            Ok((key_name, value.value()))
150        }
151    };
152
153    parser.parse2(tokens.clone())
154}
155
156fn random_ident() -> Ident {
157    let mut rng = rand::thread_rng();
158    let value = format!("__potato_id_{}", rng.r#gen::<u64>());
159    Ident::new(&value, Span::call_site())
160}
161
162fn attr_last_ident(attr: &syn::Attribute) -> Option<String> {
163    attr.meta
164        .path()
165        .segments
166        .iter()
167        .last()
168        .map(|segment| segment.ident.to_string())
169}
170
171fn parse_hook_attr_items(attr: &syn::Attribute, attr_name: &str) -> Vec<Ident> {
172    let parser = syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated;
173    let idents = attr.parse_args_with(parser).unwrap_or_else(|err| {
174        panic!("invalid `{attr_name}` annotation: {err}");
175    });
176    if idents.is_empty() {
177        panic!("`{attr_name}` annotation requires at least one function name");
178    }
179    idents.into_iter().collect()
180}
181
182fn collect_handler_hooks(root_fn: &mut syn::ItemFn) -> (Vec<Ident>, Vec<Ident>) {
183    enum HookKind {
184        Pre,
185        Post,
186    }
187    let mut hooks = vec![];
188    let mut new_attrs = Vec::with_capacity(root_fn.attrs.len());
189    for attr in root_fn.attrs.iter() {
190        match attr_last_ident(attr).as_deref() {
191            Some("preprocess") => {
192                hooks.extend(
193                    parse_hook_attr_items(attr, "preprocess")
194                        .into_iter()
195                        .map(|item| (HookKind::Pre, item)),
196                );
197            }
198            Some("postprocess") => {
199                hooks.extend(
200                    parse_hook_attr_items(attr, "postprocess")
201                        .into_iter()
202                        .map(|item| (HookKind::Post, item)),
203                );
204            }
205            _ => new_attrs.push(attr.clone()),
206        }
207    }
208    root_fn.attrs = new_attrs;
209    let mut preprocess_fns = vec![];
210    let mut postprocess_fns = vec![];
211    for (kind, hook) in hooks.into_iter() {
212        match kind {
213            HookKind::Pre => preprocess_fns.push(hook),
214            HookKind::Post => postprocess_fns.push(hook),
215        }
216    }
217    (preprocess_fns, postprocess_fns)
218}
219
220fn validate_preprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
221    if root_fn.sig.inputs.is_empty() || root_fn.sig.inputs.len() > 3 {
222        panic!("`preprocess` function must accept one to three arguments");
223    }
224    let mut arg_types = vec![];
225    for arg in root_fn.sig.inputs.iter() {
226        match arg {
227            syn::FnArg::Typed(arg) => {
228                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
229            }
230            _ => panic!("`preprocess` function does not support receiver argument"),
231        }
232    }
233    if arg_types[0] != "& mut HttpRequest" {
234        panic!(
235            "`preprocess` first argument type must be `&mut potato::HttpRequest`, got `{}`",
236            arg_types[0]
237        );
238    }
239
240    let has_once_cache = arg_types.iter().any(|t| t == "& mut OnceCache");
241    let has_session_cache = arg_types.iter().any(|t| t == "& mut SessionCache");
242
243    if arg_types.len() == 2 && !has_once_cache && !has_session_cache {
244        panic!(
245            "`preprocess` second argument type must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
246            arg_types[1]
247        );
248    }
249    if arg_types.len() == 3 {
250        if !has_once_cache {
251            panic!("`preprocess` must have `&mut potato::OnceCache` as one of the arguments");
252        }
253        if !has_session_cache {
254            panic!("`preprocess` must have `&mut potato::SessionCache` as one of the arguments");
255        }
256    }
257
258    let ret_type = root_fn
259        .sig
260        .output
261        .to_token_stream()
262        .to_string()
263        .type_simplify();
264    match &ret_type[..] {
265        "Result<Option<HttpResponse>>" | "Option<HttpResponse>" | "Result<()>" | "()" => {}
266        _ => panic!(
267            "unsupported `preprocess` return type: `{ret_type}`, expected `anyhow::Result<Option<potato::HttpResponse>>`, `Option<potato::HttpResponse>`, `anyhow::Result<()>`, or `()`"
268        ),
269    }
270    (ret_type, has_once_cache, has_session_cache)
271}
272
273fn validate_postprocess_signature(root_fn: &syn::ItemFn) -> (String, bool, bool) {
274    if root_fn.sig.inputs.len() < 2 && root_fn.sig.inputs.len() > 4 {
275        panic!("`postprocess` function must accept two to four arguments");
276    }
277    let mut arg_types = vec![];
278    for arg in root_fn.sig.inputs.iter() {
279        match arg {
280            syn::FnArg::Typed(arg) => {
281                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
282            }
283            _ => panic!("`postprocess` function does not support receiver argument"),
284        }
285    }
286    if arg_types[0] != "& mut HttpRequest" {
287        panic!(
288            "`postprocess` first argument must be `&mut potato::HttpRequest`, got `{}`",
289            arg_types[0]
290        );
291    }
292    if arg_types[1] != "& mut HttpResponse" {
293        panic!(
294            "`postprocess` second argument must be `&mut potato::HttpResponse`, got `{}`",
295            arg_types[1]
296        );
297    }
298
299    let remaining_args = &arg_types[2..];
300    let has_once_cache = remaining_args.iter().any(|t| t == "& mut OnceCache");
301    let has_session_cache = remaining_args.iter().any(|t| t == "& mut SessionCache");
302
303    if arg_types.len() == 3 && !has_once_cache && !has_session_cache {
304        panic!(
305            "`postprocess` third argument must be `&mut potato::OnceCache` or `&mut potato::SessionCache`, got `{}`",
306            arg_types[2]
307        );
308    }
309    if arg_types.len() == 4 && (!has_once_cache || !has_session_cache) {
310        panic!(
311            "`postprocess` with 4 arguments must have both `&mut potato::OnceCache` and `&mut potato::SessionCache`"
312        );
313    }
314
315    let ret_type = root_fn
316        .sig
317        .output
318        .to_token_stream()
319        .to_string()
320        .type_simplify();
321    match &ret_type[..] {
322        "Result<()>" | "()" => {}
323        _ => panic!(
324            "unsupported `postprocess` return type: `{ret_type}`, expected `anyhow::Result<()>` or `()`"
325        ),
326    }
327    (ret_type, has_once_cache, has_session_cache)
328}
329
330fn preprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
331    if !attr.is_empty() {
332        return input;
333    }
334    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
335    let fn_name = root_fn.sig.ident.clone();
336    let wrap_name = format_ident!("__potato_preprocess_adapter_{}", fn_name);
337    let wrap_name_inner = format_ident!("__potato_preprocess_adapter_inner_{}", fn_name);
338    let is_async = root_fn.sig.asyncness.is_some();
339    let (ret_type, has_once_cache, has_session_cache) = validate_preprocess_signature(&root_fn);
340
341    // 根据是否需要缓存生成不同的函数签名
342    let wrap_signature = match (has_once_cache, has_session_cache) {
343        (true, true) => quote! {
344            async fn #wrap_name_inner(
345                req: &mut potato::HttpRequest,
346                once_cache: &mut potato::OnceCache,
347                session_cache: &mut potato::SessionCache,
348            ) -> anyhow::Result<Option<potato::HttpResponse>>
349        },
350        (true, false) => quote! {
351            async fn #wrap_name_inner(
352                req: &mut potato::HttpRequest,
353                once_cache: &mut potato::OnceCache,
354            ) -> anyhow::Result<Option<potato::HttpResponse>>
355        },
356        (false, true) => quote! {
357            async fn #wrap_name_inner(
358                req: &mut potato::HttpRequest,
359                session_cache: &mut potato::SessionCache,
360            ) -> anyhow::Result<Option<potato::HttpResponse>>
361        },
362        (false, false) => quote! {
363            async fn #wrap_name_inner(
364                req: &mut potato::HttpRequest,
365            ) -> anyhow::Result<Option<potato::HttpResponse>>
366        },
367    };
368
369    // 根据实际使用情况调用函数
370    let call_body = if is_async {
371        match &ret_type[..] {
372            "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
373                (true, true) => {
374                    quote! { #fn_name(req, once_cache, session_cache).await }
375                }
376                (true, false) => quote! { #fn_name(req, once_cache).await },
377                (false, true) => quote! { #fn_name(req, session_cache).await },
378                (false, false) => quote! { #fn_name(req).await },
379            },
380            "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
381                (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache).await) },
382                (true, false) => quote! { Ok(#fn_name(req, once_cache).await) },
383                (false, true) => quote! { Ok(#fn_name(req, session_cache).await) },
384                (false, false) => quote! { Ok(#fn_name(req).await) },
385            },
386            "Result<()>" => match (has_once_cache, has_session_cache) {
387                (true, true) => {
388                    quote! { #fn_name(req, once_cache, session_cache).await.map(|_| None) }
389                }
390                (true, false) => quote! { #fn_name(req, once_cache).await.map(|_| None) },
391                (false, true) => quote! { #fn_name(req, session_cache).await.map(|_| None) },
392                (false, false) => quote! { #fn_name(req).await.map(|_| None) },
393            },
394            "()" => match (has_once_cache, has_session_cache) {
395                (true, true) => quote! { #fn_name(req, once_cache, session_cache).await; Ok(None) },
396                (true, false) => quote! { #fn_name(req, once_cache).await; Ok(None) },
397                (false, true) => quote! { #fn_name(req, session_cache).await; Ok(None) },
398                (false, false) => quote! { #fn_name(req).await; Ok(None) },
399            },
400            _ => unreachable!(),
401        }
402    } else {
403        match &ret_type[..] {
404            "Result<Option<HttpResponse>>" => match (has_once_cache, has_session_cache) {
405                (true, true) => quote! { #fn_name(req, once_cache, session_cache) },
406                (true, false) => quote! { #fn_name(req, once_cache) },
407                (false, true) => quote! { #fn_name(req, session_cache) },
408                (false, false) => quote! { #fn_name(req) },
409            },
410            "Option<HttpResponse>" => match (has_once_cache, has_session_cache) {
411                (true, true) => quote! { Ok(#fn_name(req, once_cache, session_cache)) },
412                (true, false) => quote! { Ok(#fn_name(req, once_cache)) },
413                (false, true) => quote! { Ok(#fn_name(req, session_cache)) },
414                (false, false) => quote! { Ok(#fn_name(req)) },
415            },
416            "Result<()>" => match (has_once_cache, has_session_cache) {
417                (true, true) => quote! { #fn_name(req, once_cache, session_cache).map(|_| None) },
418                (true, false) => quote! { #fn_name(req, once_cache).map(|_| None) },
419                (false, true) => quote! { #fn_name(req, session_cache).map(|_| None) },
420                (false, false) => quote! { #fn_name(req).map(|_| None) },
421            },
422            "()" => match (has_once_cache, has_session_cache) {
423                (true, true) => quote! { #fn_name(req, once_cache, session_cache); Ok(None) },
424                (true, false) => quote! { #fn_name(req, once_cache); Ok(None) },
425                (false, true) => quote! { #fn_name(req, session_cache); Ok(None) },
426                (false, false) => quote! { #fn_name(req); Ok(None) },
427            },
428            _ => unreachable!(),
429        }
430    };
431
432    // 生成wrapper函数,根据cache需求调用inner函数
433    let wrapper_body = match (has_once_cache, has_session_cache) {
434        (true, true) => quote! {
435            #wrap_name_inner(
436                req,
437                once_cache.expect("OnceCache required but not provided"),
438                session_cache.expect("SessionCache required but not provided"),
439            ).await
440        },
441        (true, false) => quote! {
442            #wrap_name_inner(
443                req,
444                once_cache.expect("OnceCache required but not provided"),
445            ).await
446        },
447        (false, true) => quote! {
448            #wrap_name_inner(
449                req,
450                session_cache.expect("SessionCache required but not provided"),
451            ).await
452        },
453        (false, false) => quote! {
454            #wrap_name_inner(req).await
455        },
456    };
457
458    quote! {
459        #root_fn
460
461        #[doc(hidden)]
462        #wrap_signature {
463            #call_body
464        }
465
466        #[doc(hidden)]
467        pub fn #wrap_name<'a>(
468            req: &'a mut potato::HttpRequest,
469            once_cache: Option<&'a mut potato::OnceCache>,
470            session_cache: Option<&'a mut potato::SessionCache>,
471        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<Option<potato::HttpResponse>>> + Send + 'a>> {
472            Box::pin(async move {
473                #wrapper_body
474            })
475        }
476    }
477    .into()
478}
479
480fn postprocess_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
481    if !attr.is_empty() {
482        return input;
483    }
484    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
485    let fn_name = root_fn.sig.ident.clone();
486    let wrap_name = format_ident!("__potato_postprocess_adapter_{}", fn_name);
487    let wrap_name_inner = format_ident!("__potato_postprocess_adapter_inner_{}", fn_name);
488    let is_async = root_fn.sig.asyncness.is_some();
489    let (ret_type, has_once_cache, has_session_cache) = validate_postprocess_signature(&root_fn);
490
491    // 根据是否需要缓存生成不同的函数签名
492    let wrap_signature = match (has_once_cache, has_session_cache) {
493        (true, true) => quote! {
494            async fn #wrap_name_inner(
495                req: &mut potato::HttpRequest,
496                res: &mut potato::HttpResponse,
497                once_cache: &mut potato::OnceCache,
498                session_cache: &mut potato::SessionCache,
499            ) -> anyhow::Result<()>
500        },
501        (true, false) => quote! {
502            async fn #wrap_name_inner(
503                req: &mut potato::HttpRequest,
504                res: &mut potato::HttpResponse,
505                once_cache: &mut potato::OnceCache,
506            ) -> anyhow::Result<()>
507        },
508        (false, true) => quote! {
509            async fn #wrap_name_inner(
510                req: &mut potato::HttpRequest,
511                res: &mut potato::HttpResponse,
512                session_cache: &mut potato::SessionCache,
513            ) -> anyhow::Result<()>
514        },
515        (false, false) => quote! {
516            async fn #wrap_name_inner(
517                req: &mut potato::HttpRequest,
518                res: &mut potato::HttpResponse,
519            ) -> anyhow::Result<()>
520        },
521    };
522
523    // 根据实际使用情况调用函数
524    let call_body = if is_async {
525        match &ret_type[..] {
526            "Result<()>" => {
527                if has_once_cache && has_session_cache {
528                    quote! {
529                        #fn_name(req, res, once_cache, session_cache).await
530                    }
531                } else if has_once_cache {
532                    quote! {
533                        #fn_name(req, res, once_cache).await
534                    }
535                } else if has_session_cache {
536                    quote! {
537                        #fn_name(req, res, session_cache).await
538                    }
539                } else {
540                    quote! {
541                        #fn_name(req, res).await
542                    }
543                }
544            }
545            "()" => {
546                if has_once_cache && has_session_cache {
547                    quote! {
548                        #fn_name(req, res, once_cache, session_cache).await;
549                        Ok(())
550                    }
551                } else if has_once_cache {
552                    quote! {
553                        #fn_name(req, res, once_cache).await;
554                        Ok(())
555                    }
556                } else if has_session_cache {
557                    quote! {
558                        #fn_name(req, res, session_cache).await;
559                        Ok(())
560                    }
561                } else {
562                    quote! {
563                        #fn_name(req, res).await;
564                        Ok(())
565                    }
566                }
567            }
568            _ => unreachable!(),
569        }
570    } else {
571        match &ret_type[..] {
572            "Result<()>" => {
573                if has_once_cache && has_session_cache {
574                    quote! {
575                        #fn_name(req, res, once_cache, session_cache)
576                    }
577                } else if has_once_cache {
578                    quote! {
579                        #fn_name(req, res, once_cache)
580                    }
581                } else if has_session_cache {
582                    quote! {
583                        #fn_name(req, res, session_cache)
584                    }
585                } else {
586                    quote! {
587                        #fn_name(req, res)
588                    }
589                }
590            }
591            "()" => {
592                if has_once_cache && has_session_cache {
593                    quote! {
594                        #fn_name(req, res, once_cache, session_cache);
595                        Ok(())
596                    }
597                } else if has_once_cache {
598                    quote! {
599                        #fn_name(req, res, once_cache);
600                        Ok(())
601                    }
602                } else if has_session_cache {
603                    quote! {
604                        #fn_name(req, res, session_cache);
605                        Ok(())
606                    }
607                } else {
608                    quote! {
609                        #fn_name(req, res);
610                        Ok(())
611                    }
612                }
613            }
614            _ => unreachable!(),
615        }
616    };
617
618    // 生成wrapper函数,根据cache需求调用inner函数
619    let wrapper_body = match (has_once_cache, has_session_cache) {
620        (true, true) => quote! {
621            #wrap_name_inner(
622                req,
623                res,
624                once_cache.expect("OnceCache required but not provided"),
625                session_cache.expect("SessionCache required but not provided"),
626            ).await
627        },
628        (true, false) => quote! {
629            #wrap_name_inner(
630                req,
631                res,
632                once_cache.expect("OnceCache required but not provided"),
633            ).await
634        },
635        (false, true) => quote! {
636            #wrap_name_inner(
637                req,
638                res,
639                session_cache.expect("SessionCache required but not provided"),
640            ).await
641        },
642        (false, false) => quote! {
643            #wrap_name_inner(req, res).await
644        },
645    };
646
647    quote! {
648        #root_fn
649
650        #[doc(hidden)]
651        #wrap_signature {
652            #call_body
653        }
654
655        #[doc(hidden)]
656        pub fn #wrap_name<'a>(
657            req: &'a mut potato::HttpRequest,
658            res: &'a mut potato::HttpResponse,
659            once_cache: Option<&'a mut potato::OnceCache>,
660            session_cache: Option<&'a mut potato::SessionCache>,
661        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = anyhow::Result<()>> + Send + 'a>> {
662            Box::pin(async move {
663                #wrapper_body
664            })
665        }
666    }
667    .into()
668}
669
670fn http_handler_macro(attr: TokenStream, input: TokenStream, req_name: &str) -> TokenStream {
671    let req_name = Ident::new(req_name, Span::call_site());
672
673    // 解析函数,检查是否有 receiver(&self / &mut self)
674    let root_fn_for_check = syn::parse::<syn::ItemFn>(input.clone());
675    let has_receiver = if let Ok(ref func) = root_fn_for_check {
676        func.sig
677            .inputs
678            .iter()
679            .any(|arg| matches!(arg, syn::FnArg::Receiver(_)))
680    } else {
681        false
682    };
683
684    let (route_path, is_send) = {
685        let mut oroute_path: Option<String> = None;
686        let mut is_send = true; // 默认使用Send
687
688        // 首先尝试解析为逗号分隔的列表:path_str, Send 或 path_str
689        let attr_stream: proc_macro2::TokenStream = attr.into();
690        let tokens: Vec<_> = attr_stream.into_iter().collect();
691        let mut i = 0;
692        while i < tokens.len() {
693            let token = &tokens[i];
694            // 尝试解析为字符串字面量(路径)
695            if let proc_macro2::TokenTree::Literal(lit) = token {
696                let lit_str = lit.to_string();
697                // 移除引号
698                if lit_str.starts_with('"') && lit_str.ends_with('"') {
699                    let path = &lit_str[1..lit_str.len() - 1];
700                    oroute_path = Some(path.to_string());
701                }
702                i += 1;
703            } else if let proc_macro2::TokenTree::Ident(ident) = token {
704                if ident.to_string() == "Send" {
705                    is_send = true;
706                    i += 1;
707                } else if ident.to_string() == "path" {
708                    // 处理 path = "/api/path" 格式
709                    if i + 2 < tokens.len() {
710                        if let proc_macro2::TokenTree::Punct(punct) = &tokens[i + 1] {
711                            if punct.as_char() == '=' {
712                                if let proc_macro2::TokenTree::Literal(lit) = &tokens[i + 2] {
713                                    let lit_str = lit.to_string();
714                                    if lit_str.starts_with('"') && lit_str.ends_with('"') {
715                                        let path = &lit_str[1..lit_str.len() - 1];
716                                        oroute_path = Some(path.to_string());
717                                    }
718                                    i += 3;
719                                    continue;
720                                }
721                            }
722                        }
723                    }
724                    i += 1;
725                } else {
726                    i += 1;
727                }
728            } else {
729                i += 1;
730            }
731        }
732
733        // 如果没有找到路径且有 receiver,可能是 controller 方法
734        if oroute_path.is_none() && has_receiver {
735            // 这将在后续代码中处理,先设置为空
736        } else if oroute_path.is_none() {
737            panic!("`path` argument is required for non-controller methods");
738        }
739
740        let route_path = oroute_path.unwrap_or_default();
741
742        // 如果是 controller 方法,需要处理路径拼接
743        let route_path = if has_receiver {
744            if route_path.is_empty() {
745                // 没有指定 path,使用 controller base path(稍后在生成的代码中读取常量)
746                String::new()
747            } else {
748                // 指定了 path,需要拼接到 controller base path
749                // 这里先标记,稍后在生成的代码中处理
750                route_path
751            }
752        } else {
753            if route_path.is_empty() {
754                panic!("`path` argument is required for non-controller methods");
755            }
756            route_path
757        };
758
759        if !route_path.is_empty() && !route_path.starts_with('/') {
760            panic!("route path must start with '/'");
761        }
762        (route_path, is_send)
763    };
764
765    // 解析函数上的 #[potato::header(...)] 标注
766    let mut root_fn = syn::parse_macro_input!(input as syn::ItemFn);
767    let mut fn_headers: Vec<(String, String)> = Vec::new();
768    let mut cors_config: Option<CorsAttrConfig> = None;
769    let mut max_concurrency: Option<usize> = None;
770    let mut remaining_attrs = Vec::new();
771
772    for attr in root_fn.attrs.iter() {
773        // 检查是否是 header 标注(支持 #[header(...)] 和 #[potato::header(...)] 两种形式)
774        let is_header_attr = attr.path().is_ident("header")
775            || (attr.path().segments.len() == 2
776                && attr
777                    .path()
778                    .segments
779                    .iter()
780                    .next()
781                    .map(|s| s.ident.to_string())
782                    == Some("potato".to_string())
783                && attr
784                    .path()
785                    .segments
786                    .iter()
787                    .last()
788                    .map(|s| s.ident.to_string())
789                    == Some("header".to_string()));
790
791        if is_header_attr {
792            if let syn::Meta::List(meta_list) = &attr.meta {
793                // 解析 header(Cache_Control = "no-store, no-cache, max-age=0") 或 header(Custom("key") = "value")
794                if let Ok((key, value)) = parse_header_attr(&meta_list.tokens) {
795                    fn_headers.push((key, value));
796                }
797            }
798            continue;
799        }
800
801        // 检查是否是 cors 标注
802        let is_cors_attr = attr.path().is_ident("cors")
803            || (attr.path().segments.len() == 2
804                && attr
805                    .path()
806                    .segments
807                    .iter()
808                    .next()
809                    .map(|s| s.ident.to_string())
810                    == Some("potato".to_string())
811                && attr
812                    .path()
813                    .segments
814                    .iter()
815                    .last()
816                    .map(|s| s.ident.to_string())
817                    == Some("cors".to_string()));
818
819        if is_cors_attr {
820            if let syn::Meta::List(meta_list) = &attr.meta {
821                cors_config = Some(parse_cors_attr(&meta_list.tokens));
822            } else {
823                // 无参数时使用最小限制配置
824                cors_config = Some(CorsAttrConfig {
825                    origin: None,
826                    methods: None,
827                    headers: None,
828                    max_age: None,
829                    credentials: false,
830                    expose_headers: None,
831                });
832            }
833            continue;
834        }
835
836        // 检查是否是 max_concurrency 标注
837        let is_max_concurrency_attr = attr.path().is_ident("max_concurrency")
838            || (attr.path().segments.len() == 2
839                && attr
840                    .path()
841                    .segments
842                    .iter()
843                    .next()
844                    .map(|s| s.ident.to_string())
845                    == Some("potato".to_string())
846                && attr
847                    .path()
848                    .segments
849                    .iter()
850                    .last()
851                    .map(|s| s.ident.to_string())
852                    == Some("max_concurrency".to_string()));
853
854        if is_max_concurrency_attr {
855            if let syn::Meta::List(meta_list) = &attr.meta {
856                let tokens = &meta_list.tokens;
857                // 直接解析为数字
858                if let Ok(lit_int) = syn::parse2::<syn::LitInt>(tokens.clone()) {
859                    if let Ok(val) = lit_int.base10_parse::<usize>() {
860                        if val == 0 {
861                            panic!("max_concurrency must be greater than 0");
862                        }
863                        max_concurrency = Some(val);
864                    } else {
865                        panic!("invalid max_concurrency value");
866                    }
867                } else {
868                    panic!(
869                        "max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]"
870                    );
871                }
872            } else if let syn::Meta::NameValue(name_value) = &attr.meta {
873                if let syn::Expr::Lit(expr_lit) = &name_value.value {
874                    if let syn::Lit::Int(lit_int) = &expr_lit.lit {
875                        if let Ok(val) = lit_int.base10_parse::<usize>() {
876                            if val == 0 {
877                                panic!("max_concurrency must be greater than 0");
878                            }
879                            max_concurrency = Some(val);
880                        } else {
881                            panic!("invalid max_concurrency value");
882                        }
883                    } else {
884                        panic!("max_concurrency requires a numeric value");
885                    }
886                } else {
887                    panic!("max_concurrency requires a numeric value");
888                }
889            } else {
890                panic!("max_concurrency requires a numeric value, e.g., #[max_concurrency(10)]");
891            }
892            continue;
893        }
894
895        remaining_attrs.push(attr.clone());
896    }
897
898    let all_headers = fn_headers;
899
900    root_fn.attrs = remaining_attrs;
901    let (preprocess_fns, postprocess_fns) = collect_handler_hooks(&mut root_fn);
902
903    // 检测handler自身是否需要缓存
904    let handler_has_once_cache = root_fn.sig.inputs.iter().any(|arg| {
905        if let syn::FnArg::Typed(arg) = arg {
906            arg.ty.to_token_stream().to_string().type_simplify() == "& mut OnceCache"
907        } else {
908            false
909        }
910    });
911    let handler_has_session_cache = root_fn.sig.inputs.iter().any(|arg| {
912        if let syn::FnArg::Typed(arg) = arg {
913            arg.ty.to_token_stream().to_string().type_simplify() == "& mut SessionCache"
914        } else {
915            false
916        }
917    });
918
919    // 修复:只有当 handler 本身需要缓存时,才设置 need_session_cache 和 need_once_cache
920    // preprocess/postprocess 钩子如果需要缓存,它们可以通过参数声明
921    // 但如果 handler 不需要缓存,我们不应该强制要求 Authorization header
922    // 这样可以避免给不需要认证的 handler 添加不必要的认证要求
923    let need_once_cache = handler_has_once_cache;
924    let need_session_cache = handler_has_session_cache;
925
926    let preprocess_adapters: Vec<Ident> = preprocess_fns
927        .iter()
928        .map(|name| format_ident!("__potato_preprocess_adapter_{}", name))
929        .collect();
930    let postprocess_adapters: Vec<Ident> = postprocess_fns
931        .iter()
932        .map(|name| format_ident!("__potato_postprocess_adapter_{}", name))
933        .collect();
934    let doc_show = {
935        let mut doc_show = true;
936        for attr in root_fn.attrs.iter() {
937            if attr.meta.path().get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
938                if let Ok(meta_list) = attr.meta.require_list() {
939                    if meta_list.tokens.to_string() == "hidden" {
940                        doc_show = false;
941                        break;
942                    }
943                }
944            }
945        }
946        doc_show
947    };
948    let doc_auth = need_session_cache;
949    let doc_summary = {
950        let mut docs = vec![];
951        for attr in root_fn.attrs.iter() {
952            if let Ok(attr) = attr.meta.require_name_value() {
953                if attr.path.get_ident().map(|p| p.to_string()) == Some("doc".to_string()) {
954                    let mut doc = attr.value.to_token_stream().to_string();
955                    if doc.starts_with('\"') {
956                        doc.remove(0);
957                        doc.pop();
958                    }
959                    docs.push(doc);
960                }
961            }
962        }
963        if docs.iter().all(|d| d.starts_with(' ')) {
964            for doc in docs.iter_mut() {
965                doc.remove(0);
966            }
967        }
968        docs.join("\n")
969    };
970    let doc_desp = "";
971    let fn_name = root_fn.sig.ident.clone();
972    let is_async = root_fn.sig.asyncness.is_some();
973
974    // 检测是否有 receiver(&self / &mut self)- controller 方法
975    let has_receiver = root_fn
976        .sig
977        .inputs
978        .iter()
979        .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
980
981    // 生成最终路径(如果是 controller 方法,需要拼接)
982    // 注意:由于路由注册需要编译期常量,路径拼接必须在宏展开时完成
983    // 但 controller 宏和 http_get 宏是独立展开的,无法直接共享信息
984    // 因此这里采用简化方案:直接使用 route_path,路径拼接由用户保证正确
985    let final_path = if has_receiver {
986        // Controller 方法:如果 route_path 为空,说明用户希望使用 controller 的 base path
987        // 但这里无法获取 base path,所以要求用户必须指定完整路径或相对路径
988        if route_path.is_empty() {
989            // 暂时使用一个占位符,实际应该在 controller 宏中处理
990            // 这里我们先要求用户必须提供路径
991            panic!("Controller methods must specify a path (e.g., #[potato::http_get(\"/\")])");
992        } else {
993            route_path
994        }
995    } else {
996        if route_path.is_empty() {
997            panic!("`path` argument is required for non-controller methods");
998        }
999        route_path
1000    };
1001
1002    let final_path_expr = quote! { #final_path };
1003
1004    // 生成 tag 表达式
1005    let tag_expr = if has_receiver {
1006        // 使用类型名来生成唯一的常量名
1007        // 这个将在 impl controller 宏中被替换为实际的常量引用
1008        quote! { "" }
1009    } else {
1010        quote! { "" }
1011    };
1012
1013    let wrap_func_name = random_ident();
1014    let mut args = vec![];
1015    let mut arg_names = vec![];
1016    let mut arg_types = vec![];
1017    let mut doc_args = vec![];
1018    for arg in root_fn.sig.inputs.iter() {
1019        // 支持 receiver 参数(&self / &mut self)- controller 方法
1020        if let syn::FnArg::Receiver(_receiver) = arg {
1021            // 跳过 receiver,不生成参数绑定代码
1022            // controller 实例将在包装函数中创建
1023            continue;
1024        }
1025
1026        if let syn::FnArg::Typed(arg) = arg {
1027            let arg_type_str = arg
1028                .ty
1029                .as_ref()
1030                .to_token_stream()
1031                .to_string()
1032                .type_simplify();
1033            let arg_name_str = arg.pat.to_token_stream().to_string();
1034            let arg_value = match &arg_type_str[..] {
1035                "& mut HttpRequest" => quote! { req },
1036                "& mut OnceCache" => {
1037                    quote! { __potato_once_cache.as_mut().expect("OnceCache not available") }
1038                }
1039                "& mut SessionCache" => {
1040                    quote! { __potato_session_cache.as_mut().expect("SessionCache not available") }
1041                }
1042                "PostFile" => {
1043                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1044                    quote! {
1045                        match req.body_files.get(&potato::utils::refstr::LocalHipStr<'static>::from_str(#arg_name_str)).cloned() {
1046                            Some(file) => file,
1047                            None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1048                        }
1049                    }
1050                }
1051                arg_type_str if ARG_TYPES.contains(arg_type_str) => {
1052                    doc_args.push(json!({ "name": arg_name_str, "type": arg_type_str }));
1053                    let mut arg_value = quote! {
1054                        match req.body_pairs
1055                            .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1056                            .map(|p| p.to_string()) {
1057                            Some(val) => val,
1058                            None => match req.url_query
1059                                .get(&potato::hipstr::LocalHipStr::from(#arg_name_str))
1060                                .map(|p| p.as_str().to_string()) {
1061                                Some(val) => val,
1062                                None => return potato::HttpResponse::error(format!("miss arg: {}", #arg_name_str)),
1063                            },
1064                        }
1065                    };
1066                    if arg_type_str != "String" {
1067                        arg_value = quote! {
1068                            match #arg_value.parse() {
1069                                Ok(val) => val,
1070                                Err(err) => return potato::HttpResponse::error(format!("arg[{}] is not {} type", #arg_name_str, #arg_type_str)),
1071                            }
1072                        }
1073                    }
1074                    arg_value
1075                }
1076                _ => panic!("unsupported arg type: [{arg_type_str}]"),
1077            };
1078            args.push(arg_value);
1079            arg_names.push(random_ident());
1080            // 保存参数类型信息,用于后续生成 call_expr
1081            arg_types.push(arg_type_str);
1082        }
1083    }
1084    let wrap_func_name2 = random_ident();
1085    let ret_type = root_fn
1086        .sig
1087        .output
1088        .to_token_stream()
1089        .to_string()
1090        .type_simplify();
1091
1092    // 如果有 receiver,需要生成 __potato_create_controller 函数
1093    // 通过检查是否存在 controller 常量来判断
1094    let _controller_create_fn = if has_receiver {
1095        quote! {
1096            // 这个函数应该由 controller 宏生成,这里只是引用
1097            // 如果编译出错,说明没有正确使用 #[potato::controller]
1098        }
1099    } else {
1100        quote! {}
1101    };
1102
1103    // 为每个参数生成调用代码
1104    let call_args: Vec<_> = args
1105        .iter()
1106        .enumerate()
1107        .map(|(i, _arg)| {
1108            let arg_name = &arg_names[i];
1109            let arg_type = &arg_types[i];
1110            // 对于 HttpRequest,直接使用 req,不要通过中间变量
1111            if arg_type == "& mut HttpRequest" {
1112                quote! { req }
1113            } else {
1114                quote! { #arg_name }
1115            }
1116        })
1117        .collect();
1118
1119    let call_expr = if has_receiver {
1120        // Controller 方法:直接调用方法(暂不支持字段注入)
1121        // 注意:当前版本不支持 controller 字段,方法应该是静态方法
1122        // 如果要支持字段,需要在包装函数中实例化 controller
1123        match args.len() {
1124            0 => quote! { #fn_name() },
1125            1 => {
1126                let arg_name = &arg_names[0];
1127                let arg = &args[0];
1128                let arg_type = &arg_types[0];
1129                if arg_type == "& mut HttpRequest" {
1130                    quote! { #fn_name(req) }
1131                } else {
1132                    quote! {{
1133                        let #arg_name = #arg;
1134                        #fn_name(#arg_name)
1135                    }}
1136                }
1137            }
1138            _ => {
1139                let let_bindings: Vec<_> = arg_types
1140                    .iter()
1141                    .zip(arg_names.iter())
1142                    .zip(args.iter())
1143                    .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1144                    .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1145                    .collect();
1146
1147                quote! {{
1148                    #(#let_bindings)*
1149                    #fn_name(#(#call_args),*)
1150                }}
1151            }
1152        }
1153    } else {
1154        // 普通方法:直接调用函数
1155        match args.len() {
1156            0 => quote! { #fn_name() },
1157            1 => {
1158                let arg_name = &arg_names[0];
1159                let arg = &args[0];
1160                let arg_type = &arg_types[0];
1161                // 检查是否是 HttpRequest 类型
1162                if arg_type == "& mut HttpRequest" {
1163                    quote! { #fn_name(req) }
1164                } else {
1165                    quote! {{
1166                        let #arg_name = #arg;
1167                        #fn_name(#arg_name)
1168                    }}
1169                }
1170            }
1171            _ => {
1172                // 只为非 HttpRequest 类型的参数创建中间变量
1173                let let_bindings: Vec<_> = arg_types
1174                    .iter()
1175                    .zip(arg_names.iter())
1176                    .zip(args.iter())
1177                    .filter(|((arg_type, _), _)| *arg_type != "& mut HttpRequest")
1178                    .map(|((_, arg_name), arg)| quote! { let #arg_name = #arg; })
1179                    .collect();
1180
1181                quote! {{
1182                    #(#let_bindings)*
1183                    #fn_name(#(#call_args),*)
1184                }}
1185            }
1186        }
1187    };
1188    let handler_wrap_func_body = generate_response_handler(call_expr, &ret_type, is_async);
1189    let doc_args = serde_json::to_string(&doc_args).unwrap();
1190
1191    // 生成添加headers的代码
1192    let add_headers_code = if all_headers.is_empty() {
1193        quote! {}
1194    } else {
1195        let header_statements = all_headers.iter().map(|(key, value)| {
1196            // 将下划线转换为HTTP标准命名 (例如 Cache_Control -> Cache-Control)
1197            let http_key = key.replace("_", "-");
1198            quote! {
1199                __potato_response.add_header(
1200                    std::borrow::Cow::Borrowed(#http_key),
1201                    std::borrow::Cow::Borrowed(#value)
1202                );
1203            }
1204        });
1205        quote! {
1206            #(#header_statements)*
1207        }
1208    };
1209
1210    // 如果存在CORS配置,生成CORS headers注入代码
1211    let cors_headers_code = if let Some(cors) = &cors_config {
1212        let mut statements = vec![];
1213
1214        // origin: 默认"*"
1215        let origin_val = cors.origin.as_deref().unwrap_or("*");
1216        statements.push(quote! {
1217            __potato_response.add_header(
1218                "Access-Control-Allow-Origin".into(),
1219                #origin_val.into()
1220            );
1221        });
1222
1223        // methods: 仅在用户指定时添加,否则由OPTIONS请求自动计算
1224        if let Some(ref methods) = cors.methods {
1225            let mut methods_list: Vec<&str> = methods.split(',').map(|s| s.trim()).collect();
1226            if !methods_list.contains(&"HEAD") {
1227                methods_list.push("HEAD");
1228            }
1229            if !methods_list.contains(&"OPTIONS") {
1230                methods_list.push("OPTIONS");
1231            }
1232            let methods_str = methods_list.join(",");
1233            statements.push(quote! {
1234                __potato_response.add_header(
1235                    "Access-Control-Allow-Methods".into(),
1236                    #methods_str.into()
1237                );
1238            });
1239        }
1240
1241        // headers: 默认"*"
1242        let headers_val = cors.headers.as_deref().unwrap_or("*");
1243        statements.push(quote! {
1244            __potato_response.add_header(
1245                "Access-Control-Allow-Headers".into(),
1246                #headers_val.into()
1247            );
1248        });
1249
1250        // max_age: 默认"86400"
1251        if let Some(ref max_age) = cors.max_age {
1252            statements.push(quote! {
1253                __potato_response.add_header(
1254                    "Access-Control-Max-Age".into(),
1255                    #max_age.into()
1256                );
1257            });
1258        } else {
1259            statements.push(quote! {
1260                __potato_response.add_header(
1261                    "Access-Control-Max-Age".into(),
1262                    "86400".into()
1263                );
1264            });
1265        }
1266
1267        if cors.credentials {
1268            statements.push(quote! {
1269                __potato_response.add_header(
1270                    "Access-Control-Allow-Credentials".into(),
1271                    "true".into()
1272                );
1273            });
1274        }
1275
1276        if let Some(ref expose_headers) = cors.expose_headers {
1277            statements.push(quote! {
1278                __potato_response.add_header(
1279                    "Access-Control-Expose-Headers".into(),
1280                    #expose_headers.into()
1281                );
1282            });
1283        }
1284
1285        quote! { #(#statements)* }
1286    } else {
1287        quote! {}
1288    };
1289
1290    // 如果存在CORS配置且是PUT/POST/DELETE,自动生成HEAD handler
1291    let auto_head_handler = if cors_config.is_some()
1292        && (req_name == "POST" || req_name == "PUT" || req_name == "DELETE")
1293    {
1294        let head_wrap_name = format_ident!("__potato_cors_head_{}", fn_name);
1295        Some(quote! {
1296            #[doc(hidden)]
1297            fn #head_wrap_name(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1298                // HEAD请求直接返回空响应,不执行原handler
1299                // CORS headers会通过postprocess机制自动添加
1300                potato::HttpResponse::html("")
1301            }
1302        })
1303    } else {
1304        None
1305    };
1306
1307    // 如果指定了max_concurrency,生成静态信号量
1308    let semaphore_static = if let Some(max_conn) = max_concurrency {
1309        let semaphore_name =
1310            format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1311        Some(quote! {
1312            #[doc(hidden)]
1313            #[allow(non_upper_case_globals)]
1314            static #semaphore_name: std::sync::LazyLock<tokio::sync::Semaphore> =
1315                std::sync::LazyLock::new(|| tokio::sync::Semaphore::new(#max_conn));
1316        })
1317    } else {
1318        None
1319    };
1320
1321    let wrap_func_body = if is_async {
1322        if max_concurrency.is_some() {
1323            let semaphore_name =
1324                format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1325            quote! {
1326                let __potato_permit = #semaphore_name.acquire().await;
1327
1328                // 获取自定义错误处理器
1329                let __potato_error_handler: Option<potato::ErrorHandler> = {
1330                    let mut handler = None;
1331                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1332                        handler = Some(flag.handler.clone());
1333                        break;
1334                    }
1335                    handler
1336                };
1337
1338                // 按需创建缓存对象
1339                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1340                    Some(potato::OnceCache::new())
1341                } else {
1342                    None
1343                };
1344                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1345                    // 从 Authorization header 中提取 Bearer token 并加载 session
1346                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1347                        let header_value = h.as_str();
1348                        if header_value.starts_with("Bearer ") {
1349                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1350                        } else {
1351                            None
1352                        }
1353                    } else {
1354                        None
1355                    }
1356                } else {
1357                    None
1358                };
1359
1360                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1361                if #need_session_cache && __potato_session_cache.is_none() {
1362                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1363                    __potato_resp.http_code = 401;
1364                    return __potato_resp;
1365                }
1366
1367                // 自动解析请求中的Cookie
1368                if let Some(ref mut session_cache) = __potato_session_cache {
1369                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1370                        session_cache.parse_request_cookies(cookie_header.as_str());
1371                    }
1372                }
1373
1374                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1375                #(
1376                    if __potato_pre_response.is_none() {
1377                        __potato_pre_response = match #preprocess_adapters(
1378                            req,
1379                            __potato_once_cache.as_mut(),
1380                            __potato_session_cache.as_mut(),
1381                        ).await {
1382                            Ok(Some(ret)) => Some(ret),
1383                            Ok(None) => None,
1384                            Err(err) => {
1385                                let handler = &__potato_error_handler;
1386                                Some(match handler {
1387                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1388                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1389                                    None => potato::HttpResponse::error(format!("{err:?}")),
1390                                })
1391                            }
1392                        };
1393                    }
1394                )*
1395
1396                let mut __potato_response = match __potato_pre_response {
1397                    Some(ret) => ret,
1398                    None => match #handler_wrap_func_body {
1399                        Ok(resp) => resp,
1400                        Err(err) => {
1401                            let handler = &__potato_error_handler;
1402                            match handler {
1403                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1404                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1405                                None => potato::HttpResponse::error(format!("{err:?}")),
1406                            }
1407                        }
1408                    },
1409                };
1410
1411                #(
1412                    if let Err(err) = #postprocess_adapters(
1413                        req,
1414                        &mut __potato_response,
1415                        __potato_once_cache.as_mut(),
1416                        __potato_session_cache.as_mut(),
1417                    ).await {
1418                        drop(__potato_permit);
1419                        let handler = &__potato_error_handler;
1420                        return match handler {
1421                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1422                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1423                            None => potato::HttpResponse::error(format!("{err:?}")),
1424                        };
1425                    }
1426                )*
1427
1428                #add_headers_code
1429                #cors_headers_code
1430
1431                // 自动应用SessionCache中的cookies到响应
1432                if let Some(ref session_cache) = __potato_session_cache {
1433                    session_cache.apply_cookies(&mut __potato_response);
1434                }
1435
1436                drop(__potato_permit);
1437                __potato_response
1438            }
1439        } else {
1440            quote! {
1441                // 获取自定义错误处理器
1442                let __potato_error_handler: Option<potato::ErrorHandler> = {
1443                    let mut handler = None;
1444                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1445                        handler = Some(flag.handler.clone());
1446                        break;
1447                    }
1448                    handler
1449                };
1450
1451                // 按需创建缓存对象
1452                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1453                    Some(potato::OnceCache::new())
1454                } else {
1455                    None
1456                };
1457                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1458                    // 从 Authorization header 中提取 Bearer token 并加载 session
1459                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1460                        let header_value = h.as_str();
1461                        if header_value.starts_with("Bearer ") {
1462                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1463                        } else {
1464                            None
1465                        }
1466                    } else {
1467                        None
1468                    }
1469                } else {
1470                    None
1471                };
1472
1473                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1474                if #need_session_cache && __potato_session_cache.is_none() {
1475                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1476                    __potato_resp.http_code = 401;
1477                    return __potato_resp;
1478                }
1479
1480                // 自动解析请求中的Cookie
1481                if let Some(ref mut session_cache) = __potato_session_cache {
1482                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1483                        session_cache.parse_request_cookies(cookie_header.as_str());
1484                    }
1485                }
1486
1487                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1488                #(
1489                    if __potato_pre_response.is_none() {
1490                        __potato_pre_response = match #preprocess_adapters(
1491                            req,
1492                            __potato_once_cache.as_mut(),
1493                            __potato_session_cache.as_mut(),
1494                        ).await {
1495                            Ok(Some(ret)) => Some(ret),
1496                            Ok(None) => None,
1497                            Err(err) => {
1498                                let handler = &__potato_error_handler;
1499                                Some(match handler {
1500                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1501                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1502                                    None => potato::HttpResponse::error(format!("{err:?}")),
1503                                })
1504                            }
1505                        };
1506                    }
1507                )*
1508
1509                let mut __potato_response = match __potato_pre_response {
1510                    Some(ret) => ret,
1511                    None => match #handler_wrap_func_body {
1512                        Ok(resp) => resp,
1513                        Err(err) => {
1514                            let handler = &__potato_error_handler;
1515                            match handler {
1516                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1517                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1518                                None => potato::HttpResponse::error(format!("{err:?}")),
1519                            }
1520                        }
1521                    },
1522                };
1523
1524                #(
1525                    if let Err(err) = #postprocess_adapters(
1526                        req,
1527                        &mut __potato_response,
1528                        __potato_once_cache.as_mut(),
1529                        __potato_session_cache.as_mut(),
1530                    ).await {
1531                        let handler = &__potato_error_handler;
1532                        return match handler {
1533                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1534                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1535                            None => potato::HttpResponse::error(format!("{err:?}")),
1536                        };
1537                    }
1538                )*
1539
1540                #add_headers_code
1541                #cors_headers_code
1542
1543                __potato_response
1544            }
1545        }
1546    } else {
1547        if max_concurrency.is_some() {
1548            let semaphore_name =
1549                format_ident!("__POTATO_SEMAPHORE_{}", fn_name.to_string().to_uppercase());
1550            quote! {
1551                let __potato_permit = #semaphore_name.acquire().await;
1552
1553                // 获取自定义错误处理器
1554                let __potato_error_handler: Option<potato::ErrorHandler> = {
1555                    let mut handler = None;
1556                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1557                        handler = Some(flag.handler.clone());
1558                        break;
1559                    }
1560                    handler
1561                };
1562
1563                // 按需创建缓存对象
1564                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1565                    Some(potato::OnceCache::new())
1566                } else {
1567                    None
1568                };
1569                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1570                    // 从 Authorization header 中提取 Bearer token 并加载 session
1571                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1572                        let header_value = h.as_str();
1573                        if header_value.starts_with("Bearer ") {
1574                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1575                        } else {
1576                            None
1577                        }
1578                    } else {
1579                        None
1580                    }
1581                } else {
1582                    None
1583                };
1584
1585                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1586                if #need_session_cache && __potato_session_cache.is_none() {
1587                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1588                    __potato_resp.http_code = 401;
1589                    return __potato_resp;
1590                }
1591
1592                // 自动解析请求中的Cookie
1593                if let Some(ref mut session_cache) = __potato_session_cache {
1594                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1595                        session_cache.parse_request_cookies(cookie_header.as_str());
1596                    }
1597                }
1598
1599                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1600                #(
1601                    if __potato_pre_response.is_none() {
1602                        __potato_pre_response = match #preprocess_adapters(
1603                            req,
1604                            __potato_once_cache.as_mut(),
1605                            __potato_session_cache.as_mut(),
1606                        ).await {
1607                            Ok(Some(ret)) => Some(ret),
1608                            Ok(None) => None,
1609                            Err(err) => {
1610                                let handler = &__potato_error_handler;
1611                                Some(match handler {
1612                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1613                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1614                                    None => potato::HttpResponse::error(format!("{err:?}")),
1615                                })
1616                            }
1617                        };
1618                    }
1619                )*
1620
1621                let mut __potato_response = match __potato_pre_response {
1622                    Some(ret) => ret,
1623                    None => match #handler_wrap_func_body {
1624                        Ok(resp) => resp,
1625                        Err(err) => {
1626                            let handler = &__potato_error_handler;
1627                            match handler {
1628                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1629                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1630                                None => potato::HttpResponse::error(format!("{err:?}")),
1631                            }
1632                        }
1633                    },
1634                };
1635
1636                #(
1637                    if let Err(err) = #postprocess_adapters(
1638                        req,
1639                        &mut __potato_response,
1640                        __potato_once_cache.as_mut(),
1641                        __potato_session_cache.as_mut(),
1642                    ).await {
1643                        drop(__potato_permit);
1644                        let handler = &__potato_error_handler;
1645                        return match handler {
1646                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1647                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1648                            None => potato::HttpResponse::error(format!("{err:?}")),
1649                        };
1650                    }
1651                )*
1652
1653                #add_headers_code
1654                #cors_headers_code
1655
1656                // 自动应用SessionCache中的cookies到响应
1657                if let Some(ref session_cache) = __potato_session_cache {
1658                    session_cache.apply_cookies(&mut __potato_response);
1659                }
1660
1661                drop(__potato_permit);
1662                __potato_response
1663            }
1664        } else {
1665            quote! {
1666                // 获取自定义错误处理器
1667                let __potato_error_handler: Option<potato::ErrorHandler> = {
1668                    let mut handler = None;
1669                    for flag in potato::inventory::iter::<potato::ErrorHandlerFlag> {
1670                        handler = Some(flag.handler.clone());
1671                        break;
1672                    }
1673                    handler
1674                };
1675
1676                // 按需创建缓存对象
1677                let mut __potato_once_cache: Option<potato::OnceCache> = if #need_once_cache {
1678                    Some(potato::OnceCache::new())
1679                } else {
1680                    None
1681                };
1682                let mut __potato_session_cache: Option<potato::SessionCache> = if #need_session_cache {
1683                    // 从 Authorization header 中提取 Bearer token 并加载 session
1684                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
1685                        let header_value = h.as_str();
1686                        if header_value.starts_with("Bearer ") {
1687                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
1688                        } else {
1689                            None
1690                        }
1691                    } else {
1692                        None
1693                    }
1694                } else {
1695                    None
1696                };
1697
1698                // 如果 handler 需要 SessionCache 但没有提供 Authorization header,返回 401
1699                if #need_session_cache && __potato_session_cache.is_none() {
1700                    let mut __potato_resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
1701                    __potato_resp.http_code = 401;
1702                    return __potato_resp;
1703                }
1704
1705                // 自动解析请求中的Cookie
1706                if let Some(ref mut session_cache) = __potato_session_cache {
1707                    if let Some(cookie_header) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Cookie")) {
1708                        session_cache.parse_request_cookies(cookie_header.as_str());
1709                    }
1710                }
1711
1712                let mut __potato_pre_response: Option<potato::HttpResponse> = None;
1713                #(
1714                    if __potato_pre_response.is_none() {
1715                        __potato_pre_response = match #preprocess_adapters(
1716                            req,
1717                            __potato_once_cache.as_mut(),
1718                            __potato_session_cache.as_mut(),
1719                        ).await {
1720                            Ok(Some(ret)) => Some(ret),
1721                            Ok(None) => None,
1722                            Err(err) => {
1723                                let handler = &__potato_error_handler;
1724                                Some(match handler {
1725                                    Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1726                                    Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1727                                    None => potato::HttpResponse::error(format!("{err:?}")),
1728                                })
1729                            }
1730                        };
1731                    }
1732                )*
1733
1734                let mut __potato_response = match __potato_pre_response {
1735                    Some(ret) => ret,
1736                    None => match #handler_wrap_func_body {
1737                        Ok(resp) => resp,
1738                        Err(err) => {
1739                            let handler = &__potato_error_handler;
1740                            match handler {
1741                                Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1742                                Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1743                                None => potato::HttpResponse::error(format!("{err:?}")),
1744                            }
1745                        }
1746                    },
1747                };
1748
1749                #(
1750                    if let Err(err) = #postprocess_adapters(
1751                        req,
1752                        &mut __potato_response,
1753                        __potato_once_cache.as_mut(),
1754                        __potato_session_cache.as_mut(),
1755                    ).await {
1756                        let handler = &__potato_error_handler;
1757                        return match handler {
1758                            Some(potato::ErrorHandler::Async(h)) => h(req, err).await,
1759                            Some(potato::ErrorHandler::Sync(h)) => h(req, err),
1760                            None => potato::HttpResponse::error(format!("{err:?}")),
1761                        };
1762                    }
1763                )*
1764
1765                #add_headers_code
1766                #cors_headers_code
1767
1768                // 自动应用SessionCache中的cookies到响应
1769                if let Some(ref session_cache) = __potato_session_cache {
1770                    session_cache.apply_cookies(&mut __potato_response);
1771                }
1772
1773                __potato_response
1774            }
1775        }
1776    };
1777
1778    if is_async {
1779        let (wrapper_sig, handler_variant) = if is_send {
1780            (
1781                quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> },
1782                quote! { potato::HttpHandler::Async },
1783            )
1784        } else {
1785            (
1786                quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + '_>> },
1787                quote! { potato::HttpHandler::AsyncNoSend },
1788            )
1789        };
1790        quote! {
1791            #root_fn
1792
1793            #auto_head_handler
1794
1795            #semaphore_static
1796
1797            #[doc(hidden)]
1798            async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1799                #wrap_func_body
1800            }
1801
1802            #[doc(hidden)]
1803            #wrapper_sig {
1804                Box::pin(#wrap_func_name2(req))
1805            }
1806
1807            potato::inventory::submit!{potato::RequestHandlerFlag::new(
1808                potato::HttpMethod::#req_name,
1809                #final_path_expr,
1810                #handler_variant(#wrap_func_name),
1811                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1812            )}
1813        }
1814        .into()
1815    } else {
1816        let (wrapper_sig, handler_variant) = if is_send {
1817            (
1818                quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> },
1819                quote! { potato::HttpHandler::Async },
1820            )
1821        } else {
1822            (
1823                quote! { fn #wrap_func_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + '_>> },
1824                quote! { potato::HttpHandler::AsyncNoSend },
1825            )
1826        };
1827        quote! {
1828            #root_fn
1829
1830            #auto_head_handler
1831
1832            #semaphore_static
1833
1834            #[doc(hidden)]
1835            async fn #wrap_func_name2(req: &mut potato::HttpRequest) -> potato::HttpResponse {
1836                #wrap_func_body
1837            }
1838
1839            #[doc(hidden)]
1840            #wrapper_sig {
1841                Box::pin(#wrap_func_name2(req))
1842            }
1843
1844            potato::inventory::submit!{potato::RequestHandlerFlag::new(
1845                potato::HttpMethod::#req_name,
1846                #final_path_expr,
1847                #handler_variant(#wrap_func_name),
1848                potato::RequestHandlerFlagDoc::new(#doc_show, #doc_auth, #doc_summary, #doc_desp, #doc_args, #tag_expr)
1849            )}
1850        }
1851        .into()
1852    }
1853    //}.to_string();
1854    //panic!("{content}");
1855    //todo!()
1856}
1857
1858/// 生成返回值处理代码(统一用于全局 handler 和 controller)
1859///
1860/// # 参数
1861/// - `call_expr`: 调用表达式(函数调用)
1862/// - `ret_type`: 返回值类型字符串
1863/// - `is_async`: 是否是异步函数(true 时会自动添加 .await)
1864fn generate_response_handler(
1865    call_expr: proc_macro2::TokenStream,
1866    ret_type: &str,
1867    is_async: bool,
1868) -> proc_macro2::TokenStream {
1869    let call_with_await = if is_async {
1870        quote! { #call_expr.await }
1871    } else {
1872        quote! { #call_expr }
1873    };
1874
1875    match ret_type {
1876        // Result 类型
1877        "Result<()>" | "anyhow::Result<()>" => quote! {
1878            match #call_with_await {
1879                Ok(_) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::text("ok")),
1880                Err(err) => Err(err),
1881            }
1882        },
1883        "Result<HttpResponse>" | "anyhow::Result<HttpResponse>" => quote! {
1884            match #call_with_await {
1885                Ok(ret) => Ok(ret),
1886                Err(err) => Err(err),
1887            }
1888        },
1889        "Result<String>" | "anyhow::Result<String>" => quote! {
1890            match #call_with_await {
1891                Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(ret)),
1892                Err(err) => Err(err),
1893            }
1894        },
1895        "Result<& 'static str>" | "anyhow::Result<& 'static str>" => quote! {
1896            match #call_with_await {
1897                Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(ret)),
1898                Err(err) => Err(err),
1899            }
1900        },
1901        "Result<serde_json::Value>" | "anyhow::Result<serde_json::Value>" => quote! {
1902            match #call_with_await {
1903                Ok(ret) => Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::json(serde_json::to_string(&ret).unwrap_or_else(|_| "{}".to_string()))),
1904                Err(err) => Err(err),
1905            }
1906        },
1907        // 非 Result 类型
1908        "()" => quote! {
1909            {
1910                #call_with_await;
1911                Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::text("ok"))
1912            }
1913        },
1914        "HttpResponse" => quote! {
1915            Ok::<potato::HttpResponse, anyhow::Error>(#call_with_await)
1916        },
1917        "String" => quote! {
1918            Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(#call_with_await))
1919        },
1920        "& 'static str" => quote! {
1921            Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::html(#call_with_await))
1922        },
1923        "serde_json::Value" => quote! {
1924            Ok::<potato::HttpResponse, anyhow::Error>(potato::HttpResponse::json(serde_json::to_string(&#call_with_await).unwrap_or_else(|_| "{}".to_string())))
1925        },
1926        // 不支持的类型
1927        _ => panic!("unsupported ret type: {}", ret_type),
1928    }
1929}
1930
1931#[proc_macro_attribute]
1932pub fn http_get(attr: TokenStream, input: TokenStream) -> TokenStream {
1933    http_handler_macro(attr, input, "GET")
1934}
1935
1936#[proc_macro_attribute]
1937pub fn http_post(attr: TokenStream, input: TokenStream) -> TokenStream {
1938    http_handler_macro(attr, input, "POST")
1939}
1940
1941#[proc_macro_attribute]
1942pub fn http_put(attr: TokenStream, input: TokenStream) -> TokenStream {
1943    http_handler_macro(attr, input, "PUT")
1944}
1945
1946#[proc_macro_attribute]
1947pub fn http_delete(attr: TokenStream, input: TokenStream) -> TokenStream {
1948    http_handler_macro(attr, input, "DELETE")
1949}
1950
1951#[proc_macro_attribute]
1952pub fn http_options(attr: TokenStream, input: TokenStream) -> TokenStream {
1953    http_handler_macro(attr, input, "OPTIONS")
1954}
1955
1956#[proc_macro_attribute]
1957pub fn http_head(attr: TokenStream, input: TokenStream) -> TokenStream {
1958    http_handler_macro(attr, input, "HEAD")
1959}
1960
1961/// Controller 属性宏 - 定义控制器结构体
1962///
1963/// # 功能
1964/// - 为结构体的 impl 块中的所有方法提供统一的路由前缀
1965/// - 支持 preprocess/postprocess 中间件继承
1966/// - 自动为 Swagger 文档分组(tag 为结构体名称)
1967///
1968/// # 结构体字段限制
1969/// 只能包含以下类型的字段(0个或多个):
1970/// - `&'a potato::OnceCache`
1971/// - `&'a potato::SessionCache`
1972///
1973/// # 示例
1974/// ```rust,ignore
1975/// #[potato::controller("/api/users")]
1976/// pub struct UsersController<'a> {
1977///     pub once_cache: &'a potato::OnceCache,
1978///     pub sess_cache: &'a potato::SessionCache,
1979/// }
1980///
1981/// #[potato::preprocess(my_preprocess)]
1982/// impl<'a> UsersController<'a> {
1983///     #[potato::http_get] // 地址为 "/api/users"
1984///     pub async fn get(&self) -> anyhow::Result<&'static str> {
1985///         Ok("get users data")
1986///     }
1987///
1988///     #[potato::http_get("/any")] // 地址为 "/api/users/any"
1989///     pub async fn get_any(&self) -> anyhow::Result<&'static str> {
1990///         Ok("get any data")
1991///     }
1992/// }
1993/// ```
1994#[proc_macro_attribute]
1995pub fn controller(attr: TokenStream, input: TokenStream) -> TokenStream {
1996    controller_macro(attr, input)
1997}
1998
1999fn controller_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2000    // 尝试解析为 impl 块
2001    let input_clone = input.clone();
2002    if let Ok(item_impl) = syn::parse::<syn::ItemImpl>(input_clone) {
2003        // 这是 impl 块,需要提取方法并生成路由注册
2004        return controller_impl_macro(attr, item_impl);
2005    }
2006
2007    // 否则解析为结构体
2008    let item_struct = syn::parse_macro_input!(input as syn::ItemStruct);
2009
2010    // 解析 base path(结构体上的 controller 可以没有 path,由 impl 块指定)
2011    let base_path = if attr.is_empty() {
2012        // 结构体上没有 path,不生成常量,由 impl 块上的 controller 指定
2013        quote! {}
2014    } else {
2015        let attr_str = attr.to_string();
2016        let base_path = attr_str.trim_matches('"').to_string();
2017        quote! {
2018            #[doc(hidden)]
2019            const __POTATO_CONTROLLER_BASE_PATH: &str = #base_path;
2020        }
2021    };
2022
2023    // 验证结构体字段并获取字段信息
2024    let (has_once_cache, has_session_cache) = validate_controller_struct(&item_struct);
2025    let struct_name = &item_struct.ident;
2026    let struct_name_str = struct_name.to_string();
2027
2028    // 生成结构体定义、常量和 inventory 提交
2029    // 同时生成隐藏的 controller 创建辅助函数
2030    let controller_creation_fn = if has_session_cache {
2031        // 结构体有 SessionCache 字段,生成包含鉴权的创建函数
2032        // 直接创建并返回 Box<Self>
2033        quote! {
2034            #[doc(hidden)]
2035            #[allow(dead_code)]
2036            async fn __potato_create_controller(req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2037                // 在堆上分配缓存
2038                let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2039
2040                // 从 Authorization header 中提取 Bearer token 并加载 session
2041                let session_cache = {
2042                    if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2043                        let header_value = h.as_str();
2044                        if header_value.starts_with("Bearer ") {
2045                            potato::SessionCache::from_token(&header_value[7..]).await.ok()
2046                        } else {
2047                            None
2048                        }
2049                    } else {
2050                        None
2051                    }
2052                };
2053
2054                let session_cache = match session_cache {
2055                    Some(cache) => cache,
2056                    None => {
2057                        let mut resp = potato::HttpResponse::text("Unauthorized: Missing or invalid Authorization header");
2058                        resp.http_code = 401;
2059                        return Err(resp);
2060                    }
2061                };
2062                let session_cache = Box::leak(Box::new(session_cache));
2063
2064                // 创建 controller 实例
2065                let controller = Self {
2066                    once_cache,
2067                    sess_cache: session_cache,
2068                };
2069
2070                Ok(Box::new(controller))
2071            }
2072        }
2073    } else {
2074        // 结构体没有 SessionCache 字段,生成不包含鉴权的创建函数
2075        // 创建临时的 SessionCache(但不使用),返回 Box<Self>
2076        quote! {
2077            #[doc(hidden)]
2078            #[allow(dead_code)]
2079            async fn __potato_create_controller(_req: &potato::HttpRequest) -> Result<Box<Self>, potato::HttpResponse> {
2080                // 在堆上分配缓存
2081                let once_cache = Box::leak(Box::new(potato::OnceCache::new()));
2082
2083                // 创建临时的 SessionCache(不需要鉴权,也不使用)
2084                let _temp_session_cache = Box::leak(Box::new(potato::SessionCache::new()));
2085
2086                // 创建 controller 实例(不包含 sess_cache 字段)
2087                let controller = Self {
2088                    once_cache,
2089                };
2090
2091                Ok(Box::new(controller))
2092            }
2093        }
2094    };
2095
2096    // 提取结构体的泛型参数(包括生命周期)
2097    let struct_generics = &item_struct.generics;
2098    let (impl_generics, type_generics, where_clause) = struct_generics.split_for_impl();
2099
2100    // 生成唯一的常量名
2101    let controller_name_const =
2102        quote::format_ident!("__POTATO_CONTROLLER_NAME_{}", struct_name_str);
2103
2104    let output = quote! {
2105        #item_struct
2106
2107        #base_path
2108
2109        #[doc(hidden)]
2110        const #controller_name_const: &str = #struct_name_str;
2111
2112        // 提交 Controller 结构体字段信息到 inventory
2113        potato::inventory::submit! {
2114            potato::ControllerStructFlag::new(
2115                #struct_name_str,
2116                potato::ControllerStructFieldInfo {
2117                    has_once_cache: #has_once_cache,
2118                    has_session_cache: #has_session_cache,
2119                }
2120            )
2121        }
2122
2123        // 生成隐藏的 controller 创建辅助函数
2124        impl #impl_generics #struct_name #type_generics #where_clause {
2125            #controller_creation_fn
2126        }
2127    };
2128
2129    output.into()
2130}
2131
2132/// 处理 impl 块的 controller 宏
2133fn controller_impl_macro(attr: TokenStream, item_impl: syn::ItemImpl) -> TokenStream {
2134    // 解析 base path(如果 attr 为空,则从常量读取)
2135    let base_path_str = if attr.is_empty() {
2136        // 从结构体上的 controller 宏生成的常量读取
2137        // 这种情况暂不支持,因为宏展开时无法读取常量值
2138        None
2139    } else {
2140        let attr_str = attr.to_string();
2141        Some(attr_str.trim_matches('"').to_string())
2142    };
2143
2144    // 从 impl 块中提取类型名称
2145    let self_type = &item_impl.self_ty;
2146
2147    // 提取不带生命周期参数的类型名称(用于实例化)
2148    let self_type_name = match &*item_impl.self_ty {
2149        syn::Type::Path(type_path) => {
2150            // 获取路径的最后一段(类型名称),不包含泛型参数
2151            if let Some(segment) = type_path.path.segments.last() {
2152                let ident = &segment.ident;
2153                quote! { #ident }
2154            } else {
2155                quote! { #self_type }
2156            }
2157        }
2158        _ => quote! { #self_type },
2159    };
2160
2161    // 提取不带生命周期参数的类型名称字符串(用于 Swagger tag)
2162    let self_type_tag = match &*item_impl.self_ty {
2163        syn::Type::Path(type_path) => {
2164            // 获取路径的最后一段(类型名称),不包含泛型参数
2165            if let Some(segment) = type_path.path.segments.last() {
2166                segment.ident.to_string()
2167            } else {
2168                self_type.to_token_stream().to_string()
2169            }
2170        }
2171        _ => self_type.to_token_stream().to_string(),
2172    };
2173
2174    // 创建清理后的 impl 块(移除方法上的 http_* 标注)
2175    let mut cleaned_items = Vec::new();
2176    let mut generated_code = Vec::new();
2177
2178    for item in &item_impl.items {
2179        if let syn::ImplItem::Fn(method) = item {
2180            // 检查是否有 http_* 标注
2181            let has_http_attr = method.attrs.iter().any(|attr| {
2182                let attr_name = attr.path().to_token_stream().to_string();
2183                attr_name.contains("http_get")
2184                    || attr_name.contains("http_post")
2185                    || attr_name.contains("http_put")
2186                    || attr_name.contains("http_delete")
2187                    || attr_name.contains("http_head")
2188                    || attr_name.contains("http_patch")
2189                    || attr_name.contains("http_options")
2190            });
2191
2192            if has_http_attr {
2193                // 有 http_* 标注,创建清理后的方法(移除 http_* 标注)
2194                let mut cleaned_method = method.clone();
2195                cleaned_method.attrs = method
2196                    .attrs
2197                    .iter()
2198                    .filter(|attr| {
2199                        let attr_name = attr.path().to_token_stream().to_string();
2200                        !attr_name.contains("http_get")
2201                            && !attr_name.contains("http_post")
2202                            && !attr_name.contains("http_put")
2203                            && !attr_name.contains("http_delete")
2204                            && !attr_name.contains("http_head")
2205                            && !attr_name.contains("http_patch")
2206                            && !attr_name.contains("http_options")
2207                    })
2208                    .cloned()
2209                    .collect();
2210
2211                cleaned_items.push(syn::ImplItem::Fn(cleaned_method));
2212
2213                // 为每个 http_* 标注生成包装函数和路由注册
2214                for attr in &method.attrs {
2215                    let attr_name = attr.path().to_token_stream().to_string();
2216                    if attr_name.contains("http_get")
2217                        || attr_name.contains("http_post")
2218                        || attr_name.contains("http_put")
2219                        || attr_name.contains("http_delete")
2220                        || attr_name.contains("http_head")
2221                        || attr_name.contains("http_patch")
2222                        || attr_name.contains("http_options")
2223                    {
2224                        // 提取 HTTP 方法名
2225                        let http_method = if attr_name.contains("http_get") {
2226                            "GET"
2227                        } else if attr_name.contains("http_post") {
2228                            "POST"
2229                        } else if attr_name.contains("http_put") {
2230                            "PUT"
2231                        } else if attr_name.contains("http_delete") {
2232                            "DELETE"
2233                        } else if attr_name.contains("http_head") {
2234                            "HEAD"
2235                        } else if attr_name.contains("http_patch") {
2236                            "PATCH"
2237                        } else {
2238                            "OPTIONS"
2239                        };
2240
2241                        // 提取 path 参数和 Send 标志
2242                        let (method_path, is_send) = match &attr.meta {
2243                            syn::Meta::List(list) => {
2244                                let tokens: Vec<_> = list.tokens.clone().into_iter().collect();
2245                                let mut path = String::new();
2246                                let mut send = true; // 默认使用Send
2247
2248                                let mut i = 0;
2249                                while i < tokens.len() {
2250                                    if let proc_macro2::TokenTree::Literal(lit) = &tokens[i] {
2251                                        let lit_str = lit.to_string();
2252                                        if lit_str.starts_with('"') && lit_str.ends_with('"') {
2253                                            path = lit_str[1..lit_str.len() - 1].to_string();
2254                                        }
2255                                        i += 1;
2256                                    } else if let proc_macro2::TokenTree::Ident(ident) = &tokens[i]
2257                                    {
2258                                        if ident.to_string() == "Send" {
2259                                            send = true;
2260                                        }
2261                                        i += 1;
2262                                    } else {
2263                                        i += 1;
2264                                    }
2265                                }
2266
2267                                (path, send)
2268                            }
2269                            _ => (String::new(), true),
2270                        };
2271
2272                        // 拼接路径(在宏展开时完成)
2273                        let final_path = if let Some(ref base_path) = base_path_str {
2274                            // base path 已知,直接在宏展开时拼接
2275                            if method_path.is_empty() {
2276                                base_path.clone()
2277                            } else {
2278                                format!("{}{}", base_path, method_path)
2279                            }
2280                        } else {
2281                            // base path 未知(从常量读取),这种情况暂不支持
2282                            panic!("impl block controller must specify a base path, e.g., #[potato::controller(\"/api/users\")]");
2283                        };
2284
2285                        // 将 final_path 转换为 LitStr
2286                        let final_path_lit =
2287                            syn::LitStr::new(&final_path, proc_macro2::Span::call_site());
2288
2289                        // 生成包装函数名
2290                        let fn_name = &method.sig.ident;
2291                        let wrapper_fn_name = quote::format_ident!("__potato_ctrl_{}", fn_name);
2292                        let is_async = method.sig.asyncness.is_some();
2293
2294                        // 检测方法是否有 receiver,以及是否是 mutable
2295                        let (has_receiver, _is_mut_receiver) = method
2296                            .sig
2297                            .inputs
2298                            .iter()
2299                            .filter_map(|arg| {
2300                                if let syn::FnArg::Receiver(recv) = arg {
2301                                    Some((true, recv.mutability.is_some()))
2302                                } else {
2303                                    None
2304                                }
2305                            })
2306                            .next()
2307                            .unwrap_or((false, false));
2308
2309                        // 提取非 receiver 参数
2310                        let other_params: Vec<_> = method
2311                            .sig
2312                            .inputs
2313                            .iter()
2314                            .filter_map(|arg| {
2315                                if let syn::FnArg::Typed(pat_type) = arg {
2316                                    Some(pat_type.clone())
2317                                } else {
2318                                    None
2319                                }
2320                            })
2321                            .collect();
2322
2323                        // 检测非 receiver 参数中是否包含 SessionCache
2324                        let method_has_session_cache = other_params.iter().any(|pat_type| {
2325                            pat_type.ty.to_token_stream().to_string().type_simplify()
2326                                == "& mut SessionCache"
2327                        });
2328
2329                        // 如果方法有 receiver(&self 或 &mut self),或者参数包含 SessionCache,则需要鉴权
2330                        let doc_auth = has_receiver || method_has_session_cache;
2331
2332                        // 提取参数名
2333                        let param_names: Vec<_> = other_params
2334                            .iter()
2335                            .filter_map(|pat_type| {
2336                                if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2337                                    Some(pat_ident.ident.clone())
2338                                } else {
2339                                    None
2340                                }
2341                            })
2342                            .collect();
2343
2344                        // 生成方法调用
2345                        let method_call = if has_receiver {
2346                            // 有 receiver,需要实例化 controller
2347                            // 使用结构体生成的 __potato_create_controller 函数
2348                            // 返回 Box<Self>
2349
2350                            if param_names.is_empty() {
2351                                if is_async {
2352                                    quote! {
2353                                        {
2354                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2355                                                Ok(boxed) => boxed,
2356                                                Err(resp) => return resp,
2357                                            };
2358                                            controller.#fn_name().await
2359                                        }
2360                                    }
2361                                } else {
2362                                    quote! {
2363                                        {
2364                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2365                                                Ok(boxed) => boxed,
2366                                                Err(resp) => return resp,
2367                                            };
2368                                            controller.#fn_name()
2369                                        }
2370                                    }
2371                                }
2372                            } else {
2373                                if is_async {
2374                                    quote! {
2375                                        {
2376                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2377                                                Ok(boxed) => boxed,
2378                                                Err(resp) => return resp,
2379                                            };
2380                                            controller.#fn_name(#(#param_names),*).await
2381                                        }
2382                                    }
2383                                } else {
2384                                    quote! {
2385                                        {
2386                                            let mut controller = match #self_type_name::__potato_create_controller(req).await {
2387                                                Ok(boxed) => boxed,
2388                                                Err(resp) => return resp,
2389                                            };
2390                                            controller.#fn_name(#(#param_names),*)
2391                                        }
2392                                    }
2393                                }
2394                            }
2395                        } else {
2396                            // 没有 receiver,直接调用关联函数
2397                            // 但仍需要处理参数(如 SessionCache)
2398
2399                            if param_names.is_empty() {
2400                                if is_async {
2401                                    quote! { #self_type_name::#fn_name().await }
2402                                } else {
2403                                    quote! { #self_type_name::#fn_name() }
2404                                }
2405                            } else {
2406                                // 有参数,需要根据参数类型生成绑定代码
2407                                // 生成参数绑定
2408                                let mut param_bindings = Vec::new();
2409                                for (i, param) in other_params.iter().enumerate() {
2410                                    let param_type_str =
2411                                        param.ty.to_token_stream().to_string().type_simplify();
2412                                    let param_name = &param_names[i];
2413
2414                                    match &param_type_str[..] {
2415                                        "& mut OnceCache" => {
2416                                            param_bindings.push(quote! {
2417                                                let #param_name = &mut __potato_once_cache;
2418                                            });
2419                                        }
2420                                        "& mut SessionCache" => {
2421                                            param_bindings.push(quote! {
2422                                                let #param_name = &mut __potato_session_cache;
2423                                            });
2424                                        }
2425                                        _ => {
2426                                            // 其他参数类型暂不支持
2427                                        }
2428                                    }
2429                                }
2430
2431                                // 生成 SessionCache 加载逻辑(从 Authorization header)
2432                                let needs_session_cache = other_params.iter().any(|p| {
2433                                    p.ty.to_token_stream().to_string().type_simplify()
2434                                        == "& mut SessionCache"
2435                                });
2436
2437                                let session_cache_init = if needs_session_cache {
2438                                    quote! {
2439                                        {
2440                                            // 从 Authorization header 中提取 Bearer token 并加载 session
2441                                            if let Some(h) = req.headers.get(&potato::utils::refstr::HeaderOrHipStr::from_str("Authorization")) {
2442                                                let header_value = h.as_str();
2443                                                if header_value.starts_with("Bearer ") {
2444                                                    potato::SessionCache::from_token(&header_value[7..]).await.ok()
2445                                                } else {
2446                                                    None
2447                                                }
2448                                            } else {
2449                                                None
2450                                            }
2451                                        }
2452                                    }
2453                                } else {
2454                                    quote! { None }
2455                                };
2456
2457                                if is_async {
2458                                    quote! {
2459                                        {
2460                                            let mut __potato_once_cache = potato::OnceCache::new();
2461                                            let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2462                                            #(#param_bindings)*
2463                                            #self_type_name::#fn_name(#(#param_names),*).await
2464                                        }
2465                                    }
2466                                } else {
2467                                    quote! {
2468                                        {
2469                                            let mut __potato_once_cache = potato::OnceCache::new();
2470                                            let mut __potato_session_cache = #session_cache_init.unwrap_or_else(|| potato::SessionCache::new());
2471                                            #(#param_bindings)*
2472                                            #self_type_name::#fn_name(#(#param_names),*)
2473                                        }
2474                                    }
2475                                }
2476                            }
2477                        };
2478
2479                        // 检测返回类型
2480                        let ret_type_str = method
2481                            .sig
2482                            .output
2483                            .to_token_stream()
2484                            .to_string()
2485                            .type_simplify();
2486
2487                        // 使用统一的返回值处理函数(method_call 已经包含了 .await,所以传入 false)
2488                        let resp_handler =
2489                            generate_response_handler(method_call.clone(), &ret_type_str, false);
2490
2491                        // 生成包装函数
2492                        let wrapper_fn = if is_async {
2493                            quote! {
2494                                #[doc(hidden)]
2495                                fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2496                                    Box::pin(async move {
2497                                        match #resp_handler {
2498                                            Ok(resp) => resp,
2499                                            Err(err) => potato::HttpResponse::error(err.to_string()),
2500                                        }
2501                                    })
2502                                }
2503                            }
2504                        } else {
2505                            quote! {
2506                                #[doc(hidden)]
2507                                fn #wrapper_fn_name(req: &mut potato::HttpRequest) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2508                                    Box::pin(async move {
2509                                        match #resp_handler {
2510                                            Ok(resp) => resp,
2511                                            Err(err) => potato::HttpResponse::error(err.to_string()),
2512                                        }
2513                                    })
2514                                }
2515                            }
2516                        };
2517
2518                        generated_code.push(wrapper_fn);
2519
2520                        // 生成路由注册(根据 is_send 标志选择 Async 或 AsyncNoSend)
2521                        let http_method_ident = quote::format_ident!("{}", http_method);
2522
2523                        let route_register = if is_send {
2524                            quote! {
2525                                potato::inventory::submit! {
2526                                    potato::RequestHandlerFlag::new(
2527                                        potato::HttpMethod::#http_method_ident,
2528                                        #final_path_lit,
2529                                        potato::HttpHandler::Async(#wrapper_fn_name),
2530                                        potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2531                                    )
2532                                }
2533                            }
2534                        } else {
2535                            quote! {
2536                                potato::inventory::submit! {
2537                                    potato::RequestHandlerFlag::new(
2538                                        potato::HttpMethod::#http_method_ident,
2539                                        #final_path_lit,
2540                                        potato::HttpHandler::AsyncNoSend(#wrapper_fn_name),
2541                                        potato::RequestHandlerFlagDoc::new(true, #doc_auth, "", "", "", #self_type_tag)
2542                                    )
2543                                }
2544                            }
2545                        };
2546
2547                        generated_code.push(route_register);
2548                    }
2549                }
2550            } else {
2551                // 没有 http_* 标注,保留原方法
2552                cleaned_items.push(item.clone());
2553            }
2554        } else {
2555            // 非方法项,直接保留
2556            cleaned_items.push(item.clone());
2557        }
2558    }
2559
2560    // 生成清理后的 impl 块
2561    let mut cleaned_impl = item_impl.clone();
2562    cleaned_impl.items = cleaned_items;
2563
2564    // 生成最终代码:清理后的 impl 块 + 生成的独立函数和路由注册
2565    let output = quote! {
2566        #cleaned_impl
2567
2568        #(#generated_code)*
2569    };
2570
2571    output.into()
2572}
2573
2574#[proc_macro_attribute]
2575pub fn preprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
2576    preprocess_macro(attr, input)
2577}
2578
2579#[proc_macro_attribute]
2580pub fn postprocess(attr: TokenStream, input: TokenStream) -> TokenStream {
2581    postprocess_macro(attr, input)
2582}
2583
2584/// handle_error 属性宏 - 定义全局错误处理函数
2585///
2586/// # 签名要求
2587/// - 参数1: `req: &mut HttpRequest`
2588/// - 参数2: `err: anyhow::Error`
2589/// - 返回: `HttpResponse`
2590/// - 支持 async fn 和普通 fn
2591///
2592/// # 示例
2593/// ```rust,ignore
2594/// #[potato::handle_error]
2595/// async fn handle_error(req: &mut HttpRequest, err: anyhow::Error) -> HttpResponse {
2596///     HttpResponse::json(serde_json::json!({
2597///         "error": format!("{}", err)
2598///     }))
2599/// }
2600/// ```
2601fn handle_error_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2602    if !attr.is_empty() {
2603        return input;
2604    }
2605
2606    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2607    let fn_name = root_fn.sig.ident.clone();
2608    let is_async = root_fn.sig.asyncness.is_some();
2609
2610    // 验证函数签名
2611    if root_fn.sig.inputs.len() != 2 {
2612        panic!("`handle_error` function must accept exactly two arguments");
2613    }
2614
2615    let mut arg_types = vec![];
2616    for arg in root_fn.sig.inputs.iter() {
2617        match arg {
2618            syn::FnArg::Typed(arg) => {
2619                arg_types.push(arg.ty.to_token_stream().to_string().type_simplify())
2620            }
2621            _ => panic!("`handle_error` function does not support receiver argument"),
2622        }
2623    }
2624
2625    if arg_types[0] != "& mut HttpRequest" {
2626        panic!(
2627            "`handle_error` first argument must be `&mut potato::HttpRequest`, got `{}`",
2628            arg_types[0]
2629        );
2630    }
2631    if arg_types[1] != "anyhow::Error" {
2632        panic!(
2633            "`handle_error` second argument must be `anyhow::Error`, got `{}`",
2634            arg_types[1]
2635        );
2636    }
2637
2638    let ret_type = root_fn
2639        .sig
2640        .output
2641        .to_token_stream()
2642        .to_string()
2643        .type_simplify();
2644    if ret_type != "HttpResponse" {
2645        panic!(
2646            "`handle_error` return type must be `potato::HttpResponse`, got `{}`",
2647            ret_type
2648        );
2649    }
2650
2651    // 生成适配器函数
2652    let wrap_name = format_ident!("__potato_error_handler_adapter_{}", fn_name);
2653    let wrap_name_inner = format_ident!("__potato_error_handler_adapter_inner_{}", fn_name);
2654
2655    // 生成内部函数
2656    let call_body = if is_async {
2657        quote! { #fn_name(req, err).await }
2658    } else {
2659        quote! { #fn_name(req, err) }
2660    };
2661
2662    quote! {
2663        #root_fn
2664
2665        #[doc(hidden)]
2666        async fn #wrap_name_inner(
2667            req: &mut potato::HttpRequest,
2668            err: anyhow::Error,
2669        ) -> potato::HttpResponse {
2670            #call_body
2671        }
2672
2673        #[doc(hidden)]
2674        pub fn #wrap_name(
2675            req: &mut potato::HttpRequest,
2676            err: anyhow::Error,
2677        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = potato::HttpResponse> + Send + '_>> {
2678            Box::pin(#wrap_name_inner(req, err))
2679        }
2680
2681        potato::inventory::submit! {
2682            potato::ErrorHandlerFlag::new(
2683                potato::ErrorHandler::Async(#wrap_name)
2684            )
2685        }
2686    }
2687    .into()
2688}
2689
2690#[proc_macro_attribute]
2691pub fn handle_error(attr: TokenStream, input: TokenStream) -> TokenStream {
2692    handle_error_macro(attr, input)
2693}
2694
2695/// limit_size 属性宏 - 为 handler 设置独立的请求体大小限制
2696///
2697/// # 参数
2698/// * 单个值: `#[potato::limit_size(1024 * 1024 * 1024)]` - 仅限制 body 为 1GB
2699/// * 命名参数: `#[potato::limit_size(header = 2 * 1024 * 1024, body = 500 * 1024 * 1024)]`
2700///
2701/// # 示例
2702/// ```rust,ignore
2703/// // 限制 body 为 1GB
2704/// #[potato::http_post("/upload")]
2705/// #[potato::limit_size(1024 * 1024 * 1024)]
2706/// async fn large_upload(req: &mut potato::HttpRequest) -> potato::HttpResponse {
2707///     todo!()
2708/// }
2709///
2710/// // 分别限制 header 和 body
2711/// #[potato::http_post("/upload")]
2712/// #[potato::limit_size(header = 2 * 1024 * 1024, body = 500 * 1024 * 1024)]
2713/// async fn medium_upload(req: &mut potato::HttpRequest) -> potato::HttpResponse {
2714///     todo!()
2715/// }
2716/// ```
2717#[proc_macro_attribute]
2718pub fn limit_size(attr: TokenStream, input: TokenStream) -> TokenStream {
2719    limit_size_macro(attr, input)
2720}
2721
2722fn limit_size_macro(attr: TokenStream, input: TokenStream) -> TokenStream {
2723    // 解析参数
2724    let (_max_header, max_body) = {
2725        let attr_tokens: proc_macro2::TokenStream = attr.clone().into();
2726        if attr_tokens.is_empty() {
2727            // 默认值: 不限制 header,使用全局 body 限制
2728            (None, None)
2729        } else {
2730            // 尝试解析为命名参数或单值
2731            let result = syn::parse::Parser::parse2(
2732                |input: syn::parse::ParseStream| -> syn::Result<(Option<syn::Expr>, Option<syn::Expr>)> {
2733                    let mut header_expr = None;
2734                    let mut body_expr = None;
2735
2736                    // 尝试解析命名参数
2737                    while !input.is_empty() {
2738                        let ident: Ident = input.parse()?;
2739                        input.parse::<Token![=]>()?;
2740                        let value: syn::Expr = input.parse()?;
2741
2742                        match ident.to_string().as_str() {
2743                            "header" => header_expr = Some(value),
2744                            "body" => body_expr = Some(value),
2745                            _ => return Err(syn::Error::new(ident.span(), "expected 'header' or 'body'")),
2746                        }
2747
2748                        // 可选的逗号
2749                        if input.peek(Token![,]) {
2750                            input.parse::<Token![,]>()?;
2751                        }
2752                    }
2753
2754                    Ok((header_expr, body_expr))
2755                },
2756                attr_tokens.clone(),
2757            );
2758
2759            match result {
2760                Ok((h, b)) => (h, b),
2761                Err(_) => {
2762                    // 解析失败,尝试作为单值(body 限制)
2763                    if let Ok(expr) = syn::parse2::<syn::Expr>(attr_tokens) {
2764                        (None, Some(expr))
2765                    } else {
2766                        (None, None)
2767                    }
2768                }
2769            }
2770        }
2771    };
2772
2773    let root_fn = syn::parse_macro_input!(input as syn::ItemFn);
2774
2775    // 生成检查代码
2776    let body_check = if let Some(body_expr) = max_body {
2777        quote! {
2778            // 检查 body 大小
2779            let body_len = req.body.len();
2780            if body_len > #body_expr {
2781                let mut res = potato::HttpResponse::text(format!(
2782                    "Payload Too Large: body size {} bytes exceeds limit {} bytes",
2783                    body_len, #body_expr
2784                ));
2785                res.http_code = 413;
2786                return res;
2787            }
2788        }
2789    } else {
2790        quote! {}
2791    };
2792
2793    // 克隆整个函数,然后修改 block
2794    let mut wrapped_fn = root_fn.clone();
2795    let original_block = root_fn.block.as_ref();
2796    let new_block: syn::Block = syn::parse_quote!({
2797        #body_check
2798        #original_block
2799    });
2800    wrapped_fn.block = Box::new(new_block);
2801
2802    quote! {
2803        #wrapped_fn
2804    }
2805    .into()
2806}
2807
2808/// header 属性宏 - 这是一个占位宏,实际解析在 http_handler_macro 中完成
2809/// 这个宏的存在使得 #[potato::header(...)] 语法能够被编译器识别
2810#[proc_macro_attribute]
2811pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream {
2812    // 直接返回原始函数,不做任何修改
2813    // 实际的 header 解析和处理在 http_get/http_post 等宏中完成
2814    input
2815}
2816
2817/// cors 属性宏 - 这是一个占位宏,实际解析在 http_handler_macro 中完成
2818/// 这个宏的存在使得 #[potato::cors(...)] 语法能够被编译器识别
2819#[proc_macro_attribute]
2820pub fn cors(_attr: TokenStream, input: TokenStream) -> TokenStream {
2821    // 直接返回原始函数,不做任何修改
2822    // 实际的 cors 解析和处理在 http_handler_macro 中完成
2823    input
2824}
2825
2826#[proc_macro]
2827pub fn embed_dir(input: TokenStream) -> TokenStream {
2828    let path = syn::parse_macro_input!(input as syn::LitStr).value();
2829    quote! {{
2830        #[derive(potato::rust_embed::Embed)]
2831        #[folder = #path]
2832        struct Asset;
2833
2834        potato::load_embed::<Asset>()
2835    }}
2836    .into()
2837}
2838
2839#[proc_macro_derive(StandardHeader)]
2840pub fn standard_header_derive(input: TokenStream) -> TokenStream {
2841    let root_enum = syn::parse_macro_input!(input as syn::ItemEnum);
2842    let enum_name = root_enum.ident;
2843    let mut try_from_str_items = vec![];
2844    let mut to_str_items = vec![];
2845    let mut headers_items = vec![];
2846    let mut headers_apply_items = vec![];
2847    for root_field in root_enum.variants.iter() {
2848        let name = root_field.ident.clone();
2849        if root_field.fields.iter().next().is_some() {
2850            panic!("unsupported enum type");
2851        }
2852        let str_name = name.to_string().replace("_", "-");
2853        let len = str_name.len();
2854        try_from_str_items
2855            .push(quote! { #len if value.eq_ignore_ascii_case(#str_name) => Some(Self::#name), });
2856        to_str_items.push(quote! { Self::#name => #str_name, });
2857        headers_items.push(quote! { #name(String), });
2858        headers_apply_items
2859            .push(quote! { Headers::#name(s) => self.set_header(HeaderItem::#name.to_str(), s), });
2860    }
2861    let r = quote! {
2862        impl #enum_name {
2863            pub fn try_from_str(value: &str) -> Option<Self> {
2864                match value.len() {
2865                    #( #try_from_str_items )*
2866                    _ => None,
2867                }
2868            }
2869
2870            pub fn to_str(&self) -> &'static str {
2871                match self {
2872                    #( #to_str_items )*
2873                }
2874            }
2875        }
2876
2877        pub enum Headers {
2878            #( #headers_items )*
2879            Custom((String, String)),
2880        }
2881
2882        impl HttpRequest {
2883            pub fn apply_header(&mut self, header: Headers) {
2884                match header {
2885                    #( #headers_apply_items )*
2886                    Headers::Custom((k, v)) => self.set_header(&k[..], v),
2887                }
2888            }
2889        }
2890    };
2891    r.into()
2892}