Skip to main content

silent_openapi_macros/
lib.rs

1use convert_case::Casing;
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::Token;
5use syn::punctuated::Punctuated;
6use syn::{
7    Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Result as SynResult, parse::Parse, parse::ParseStream,
8};
9
10fn endpoint_impl(
11    attr: proc_macro2::TokenStream,
12    item: proc_macro2::TokenStream,
13) -> proc_macro2::TokenStream {
14    struct MetaArgs(Punctuated<Meta, Token![,]>);
15    impl Parse for MetaArgs {
16        fn parse(input: ParseStream) -> SynResult<Self> {
17            Ok(MetaArgs(Punctuated::parse_terminated(input)?))
18        }
19    }
20    let MetaArgs(args) = syn::parse2::<MetaArgs>(attr).expect("parse attr");
21    let mut summary_arg: Option<String> = None;
22    let mut description_arg: Option<String> = None;
23    for meta in args {
24        if let Meta::NameValue(nv) = meta {
25            if nv.path.is_ident("summary")
26                && let Expr::Lit(ExprLit {
27                    lit: Lit::Str(s), ..
28                }) = &nv.value
29            {
30                summary_arg = Some(s.value());
31            } else if nv.path.is_ident("description")
32                && let Expr::Lit(ExprLit {
33                    lit: Lit::Str(s), ..
34                }) = &nv.value
35            {
36                description_arg = Some(s.value());
37            }
38        }
39    }
40
41    let input: ItemFn = syn::parse2(item).expect("parse item fn");
42    let vis = &input.vis;
43    let sig = input.sig.clone();
44    let attrs = &input.attrs;
45    let block = &input.block;
46    let name = &sig.ident;
47
48    // 收集文档注释作为默认 summary/description
49    let mut doc_lines: Vec<String> = Vec::new();
50    for a in attrs.iter() {
51        if a.path().is_ident("doc") {
52            let _ = a.parse_nested_meta(|meta| {
53                let lit: syn::LitStr = meta.value()?.parse()?;
54                let v = lit.value();
55                doc_lines.push(v.trim().to_string());
56                Ok(())
57            });
58        }
59    }
60    let (def_summary, def_description) = if !doc_lines.is_empty() {
61        let mut it = doc_lines.into_iter().filter(|s| !s.is_empty());
62        if let Some(first) = it.next() {
63            let rest = it.collect::<Vec<_>>().join("\n");
64            (Some(first), if rest.is_empty() { None } else { Some(rest) })
65        } else {
66            (None, None)
67        }
68    } else {
69        (None, None)
70    };
71
72    let summary = summary_arg.or(def_summary);
73    let description = description_arg.or(def_description);
74
75    // 真实处理函数改名
76    let impl_name = format_ident!("{}_impl", name);
77    // 生成实现函数签名(重命名)
78    let mut impl_sig = sig.clone();
79    impl_sig.ident = impl_name.clone();
80
81    // 端点类型 + 常量(实现与原 `.get(get_xxx)` 风格兼容)
82    let ep_ty = format_ident!(
83        "{}Endpoint",
84        name.to_string().to_case(convert_case::Case::UpperCamel)
85    );
86    let sum_tokens = if let Some(s) = &summary {
87        let lit = syn::LitStr::new(s, proc_macro2::Span::call_site());
88        quote!(Some(#lit))
89    } else {
90        quote!(None)
91    };
92    let desc_tokens = if let Some(s) = &description {
93        let lit = syn::LitStr::new(s, proc_macro2::Span::call_site());
94        quote!(Some(#lit))
95    } else {
96        quote!(None)
97    };
98
99    // 解析返回类型 Ok(T) -> ResponseMeta
100    let ret_meta = {
101        match &sig.output {
102            syn::ReturnType::Type(_, ty) => {
103                if let syn::Type::Path(tp) = ty.as_ref() {
104                    if let Some(seg) = tp.path.segments.last() {
105                        if seg.ident == "Result" || seg.ident == "SilentResult" {
106                            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
107                                if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
108                                    match ok_ty {
109                                        syn::Type::Path(tpath) => {
110                                            if let Some(id) = tpath.path.segments.last() {
111                                                if id.ident == "Response" {
112                                                    quote!(None)
113                                                } else if id.ident == "String" {
114                                                    quote!(Some(::silent_openapi::doc::ResponseMeta::TextPlain))
115                                                } else {
116                                                    let tn = id.ident.to_string();
117                                                    quote!(Some(::silent_openapi::doc::ResponseMeta::Json { type_name: #tn }))
118                                                }
119                                            } else {
120                                                quote!(None)
121                                            }
122                                        }
123                                        syn::Type::Reference(r) => {
124                                            if let syn::Type::Path(tp2) = r.elem.as_ref() {
125                                                if let Some(id) = tp2.path.segments.last() {
126                                                    if id.ident == "str" {
127                                                        quote!(Some(::silent_openapi::doc::ResponseMeta::TextPlain))
128                                                    } else {
129                                                        let tn = id.ident.to_string();
130                                                        quote!(Some(::silent_openapi::doc::ResponseMeta::Json { type_name: #tn }))
131                                                    }
132                                                } else {
133                                                    quote!(None)
134                                                }
135                                            } else {
136                                                quote!(None)
137                                            }
138                                        }
139                                        _ => quote!(None),
140                                    }
141                                } else {
142                                    quote!(None)
143                                }
144                            } else {
145                                quote!(None)
146                            }
147                        } else {
148                            quote!(None)
149                        }
150                    } else {
151                        quote!(None)
152                    }
153                } else {
154                    quote!(None)
155                }
156            }
157            _ => quote!(None),
158        }
159    };
160
161    // 为自定义 Ok(T) 注册 ToSchema 完整 schema
162    let ret_schema_register = {
163        match &sig.output {
164            syn::ReturnType::Type(_, ty) => {
165                if let syn::Type::Path(tp) = ty.as_ref() {
166                    if let Some(seg) = tp.path.segments.last() {
167                        if seg.ident == "Result" || seg.ident == "SilentResult" {
168                            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
169                                if let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first() {
170                                    match ok_ty {
171                                        syn::Type::Path(tpath) => {
172                                            if let Some(id) = tpath.path.segments.last() {
173                                                if id.ident == "Response" || id.ident == "String" {
174                                                    quote!()
175                                                } else {
176                                                    let ty = ok_ty.clone();
177                                                    quote!(::silent_openapi::doc::register_schema_for::<#ty>();)
178                                                }
179                                            } else {
180                                                quote!()
181                                            }
182                                        }
183                                        syn::Type::Reference(r) => {
184                                            if let syn::Type::Path(tp2) = r.elem.as_ref() {
185                                                if let Some(id) = tp2.path.segments.last() {
186                                                    if id.ident == "str" {
187                                                        quote!()
188                                                    } else {
189                                                        let inner = tp2.clone();
190                                                        quote!(::silent_openapi::doc::register_schema_for::<#inner>();)
191                                                    }
192                                                } else {
193                                                    quote!()
194                                                }
195                                            } else {
196                                                quote!()
197                                            }
198                                        }
199                                        _ => quote!(),
200                                    }
201                                } else {
202                                    quote!()
203                                }
204                            } else {
205                                quote!()
206                            }
207                        } else {
208                            quote!()
209                        }
210                    } else {
211                        quote!()
212                    }
213                } else {
214                    quote!()
215                }
216            }
217            _ => quote!(),
218        }
219    };
220
221    // 从提取器类型中生成请求元信息注册代码
222    fn gen_request_meta_register(ty: &syn::Type) -> proc_macro2::TokenStream {
223        if let syn::Type::Path(tp) = ty {
224            if let Some(seg) = tp.path.segments.last() {
225                let ident = seg.ident.to_string();
226                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
227                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
228                        // 获取内部类型名称
229                        let inner_name = if let syn::Type::Path(inner_tp) = inner_ty {
230                            inner_tp
231                                .path
232                                .segments
233                                .last()
234                                .map(|s| s.ident.to_string())
235                                .unwrap_or_default()
236                        } else {
237                            String::new()
238                        };
239
240                        if !inner_name.is_empty() {
241                            match ident.as_str() {
242                                "Json" => {
243                                    let inner = inner_ty.clone();
244                                    return quote! {
245                                        ::silent_openapi::doc::register_request_by_ptr(
246                                            ptr,
247                                            ::silent_openapi::doc::RequestMeta::JsonBody { type_name: #inner_name },
248                                        );
249                                        ::silent_openapi::doc::register_schema_for::<#inner>();
250                                    };
251                                }
252                                "Form" => {
253                                    let inner = inner_ty.clone();
254                                    return quote! {
255                                        ::silent_openapi::doc::register_request_by_ptr(
256                                            ptr,
257                                            ::silent_openapi::doc::RequestMeta::FormBody { type_name: #inner_name },
258                                        );
259                                        ::silent_openapi::doc::register_schema_for::<#inner>();
260                                    };
261                                }
262                                "Query" => {
263                                    let inner = inner_ty.clone();
264                                    return quote! {
265                                        ::silent_openapi::doc::register_request_by_ptr(
266                                            ptr,
267                                            ::silent_openapi::doc::RequestMeta::QueryParams { type_name: #inner_name },
268                                        );
269                                        ::silent_openapi::doc::register_schema_for::<#inner>();
270                                    };
271                                }
272                                _ => {}
273                            }
274                        }
275                    }
276                }
277            }
278        }
279        quote!()
280    }
281
282    // 根据函数参数形态生成 IntoRouteHandler 实现
283    let inputs = sig.inputs.clone().into_iter().collect::<Vec<_>>();
284    let impls = if inputs.len() == 1 {
285        match &inputs[0] {
286            FnArg::Typed(pat_ty) => {
287                let ty = &pat_ty.ty;
288                // 简单规则:类型标识名为 Request 则认为是 Request 形态
289                let is_request = matches!(
290                    &**ty,
291                    syn::Type::Path(tp) if tp.path.segments.last().map(|s| s.ident == "Request").unwrap_or(false)
292                );
293                if is_request {
294                    quote! {
295                        impl ::silent::prelude::IntoRouteHandler<::silent::Request> for #ep_ty {
296                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
297                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(#impl_name));
298                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
299                                ::silent_openapi::doc::register_doc_by_ptr(
300                                    ptr,
301                                    #sum_tokens,
302                                    #desc_tokens,
303                                );
304                                #ret_schema_register
305                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
306                                handler
307                            }
308                        }
309                    }
310                } else {
311                    // 单萃取器参数
312                    let req_meta_register = gen_request_meta_register(ty);
313                    quote! {
314                        impl ::silent::prelude::IntoRouteHandler<#ty> for #ep_ty {
315                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
316                                let adapted = ::silent::extractor::handler_from_extractor::<#ty, _, _, _>(#impl_name);
317                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(adapted));
318                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
319                                ::silent_openapi::doc::register_doc_by_ptr(
320                                    ptr,
321                                    #sum_tokens,
322                                    #desc_tokens,
323                                );
324                                #ret_schema_register
325                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
326                                #req_meta_register
327                                handler
328                            }
329                        }
330                    }
331                }
332            }
333            _ => quote! {},
334        }
335    } else if inputs.len() == 2 {
336        match (&inputs[0], &inputs[1]) {
337            (FnArg::Typed(first), FnArg::Typed(second)) => {
338                let ty1 = &first.ty;
339                let ty2 = &second.ty;
340                // 期望形态: (Request, Args)
341                let is_request_first = matches!(
342                    &**ty1,
343                    syn::Type::Path(tp) if tp.path.segments.last().map(|s| s.ident == "Request").unwrap_or(false)
344                );
345                if is_request_first {
346                    let req_meta_register = gen_request_meta_register(ty2);
347                    quote! {
348                        impl ::silent::prelude::IntoRouteHandler<(::silent::Request, #ty2)> for #ep_ty {
349                            fn into_handler(self) -> std::sync::Arc<dyn ::silent::Handler> {
350                                let adapted = ::silent::extractor::handler_from_extractor_with_request::<#ty2, _, _, _>(#impl_name);
351                                let handler = std::sync::Arc::new(::silent::HandlerWrapper::new(adapted));
352                                let ptr = std::sync::Arc::as_ptr(&handler) as *const () as usize;
353                                ::silent_openapi::doc::register_doc_by_ptr(
354                                    ptr,
355                                    #sum_tokens,
356                                    #desc_tokens,
357                                );
358                                #ret_schema_register
359                                if let Some(meta) = #ret_meta { ::silent_openapi::doc::register_response_by_ptr(ptr, meta); }
360                                #req_meta_register
361                                handler
362                            }
363                        }
364                    }
365                } else {
366                    quote! {}
367                }
368            }
369            _ => quote! {},
370        }
371    } else {
372        quote! {}
373    };
374
375    let code = quote! {
376        // 原函数体改名为实现函数
377        #(#attrs)*
378        #impl_sig #block
379
380        // 端点类型(零尺寸) + 常量,同名以保留 `.get(get_xxx)` 调用方式
381        pub struct #ep_ty;
382        #[allow(non_upper_case_globals)]
383        #vis const #name: #ep_ty = #ep_ty;
384
385        #impls
386    };
387
388    code
389}
390
391#[proc_macro_attribute]
392pub fn endpoint(attr: TokenStream, item: TokenStream) -> TokenStream {
393    endpoint_impl(attr.into(), item.into()).into()
394}
395
#[cfg(test)]
mod tests {
    use quote::quote;

    /// Runs the macro expansion and returns the generated tokens as text.
    fn expand(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> String {
        super::endpoint_impl(attr, item).to_string()
    }

    #[test]
    fn generates_endpoint_type_and_const_for_request_sig() {
        let expanded = expand(
            quote!(summary = "hello", description = "world"),
            quote!(
                async fn get_hello(_req: ::silent::Request) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("struct GetHelloEndpoint"));
        assert!(expanded.contains("const get_hello"));
    }

    #[test]
    fn generates_into_route_handler_for_extractor_sig() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn get_user(_id: Path<u64>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        // Endpoint type, constant and the `IntoRouteHandler` impl are all emitted.
        assert!(expanded.contains("struct GetUserEndpoint"));
        assert!(expanded.contains("const get_user"));
        assert!(expanded.contains("IntoRouteHandler"));
        assert!(expanded.contains("GetUserEndpoint"));
    }

    #[test]
    fn registers_request_meta_for_json_extractor() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn create_user(body: Json<UserInput>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("RequestMeta :: JsonBody"));
        assert!(expanded.contains("register_request_by_ptr"));
        assert!(expanded.contains("register_schema_for"));
    }

    #[test]
    fn registers_request_meta_for_query_extractor() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn list_users(params: Query<ListParams>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("RequestMeta :: QueryParams"));
        assert!(expanded.contains("register_request_by_ptr"));
    }

    #[test]
    fn registers_request_meta_for_form_extractor() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn submit_form(data: Form<FormData>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("RequestMeta :: FormBody"));
        assert!(expanded.contains("register_request_by_ptr"));
    }

    #[test]
    fn registers_request_meta_for_request_with_extractor() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn update_user(
                    _req: ::silent::Request,
                    body: Json<UserInput>,
                ) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("RequestMeta :: JsonBody"));
        assert!(expanded.contains("register_request_by_ptr"));
    }

    #[test]
    fn no_request_meta_for_plain_request() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn health(_req: ::silent::Request) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(!expanded.contains("register_request_by_ptr"));
    }

    #[test]
    fn registers_schema_for_enum_return_type() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn get_status(_req: ::silent::Request) -> ::silent::Result<ApiResponse> {
                    unimplemented!()
                }
            ),
        );
        // An enum `Ok` type yields JSON response metadata plus schema registration.
        assert!(expanded.contains("ResponseMeta :: Json"));
        assert!(expanded.contains("register_schema_for"));
        assert!(expanded.contains("ApiResponse"));
    }

    #[test]
    fn registers_schema_for_enum_request_body() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn create_item(body: Json<CreateAction>) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        // An enum request body also gets its schema registered.
        assert!(expanded.contains("RequestMeta :: JsonBody"));
        assert!(expanded.contains("register_schema_for"));
        assert!(expanded.contains("CreateAction"));
    }

    #[test]
    fn doc_comment_as_summary_and_description() {
        let expanded = expand(
            quote!(),
            quote!(
                /// 获取用户信息
                ///
                /// 根据用户 ID 查询完整的用户资料
                async fn get_user(_req: ::silent::Request) -> ::silent::Result<::silent::Response> {
                    unimplemented!()
                }
            ),
        );
        assert!(expanded.contains("获取用户信息"));
        assert!(expanded.contains("根据用户 ID 查询完整的用户资料"));
    }

    #[test]
    fn registers_response_meta_for_string() {
        let expanded = expand(
            quote!(),
            quote!(
                async fn ping(_req: ::silent::Request) -> ::silent::Result<String> {
                    unimplemented!()
                }
            ),
        );
        // A `String` `Ok` type registers a plain-text response.
        assert!(expanded.contains("ResponseMeta :: TextPlain"));
    }
}