hsnet_rpc_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Attribute, FnArg, ItemTrait, Meta, Pat, TraitItem};
4
5/// RPC trait 宏
6///
7/// 用法:
8/// ```ignore
9/// #[rpc(server)]
10/// pub trait ControlRpc {
11///     #[method(name = "wg_start")]
12///     async fn wg_start(&self, config: &WgConfig) -> Result<String, RpcError>;
13/// }
14/// ```
15#[proc_macro_attribute]
16pub fn rpc(_args: TokenStream, input: TokenStream) -> TokenStream {
17    let input_trait = parse_macro_input!(input as ItemTrait);
18
19    let trait_name = &input_trait.ident;
20    let trait_vis = &input_trait.vis;
21
22    // 提取所有 RPC 方法
23    let mut rpc_methods = Vec::new();
24
25    for item in &input_trait.items {
26        if let TraitItem::Fn(method) = item {
27            // 查找 #[method] 属性
28            let method_attr = method.attrs.iter().find(|attr| attr.path().is_ident("method"));
29
30            if let Some(attr) = method_attr {
31                // 解析 #[method(name = "xxx")]
32                let rpc_name = parse_method_name(attr);
33                if rpc_name.is_empty() {
34                    continue;
35                }
36
37                let method_name = &method.sig.ident;
38                let method_inputs = &method.sig.inputs;
39
40                // 提取参数(跳过 &self)
41                let params: Vec<_> = method_inputs
42                    .iter()
43                    .skip(1) // 跳过 &self
44                    .filter_map(|arg| {
45                        if let FnArg::Typed(pat_type) = arg {
46                            if let Pat::Ident(pat_ident) = &*pat_type.pat {
47                                let param_name = &pat_ident.ident;
48                                let param_type = &pat_type.ty;
49                                return Some((param_name.clone(), param_type.clone()));
50                            }
51                        }
52                        None
53                    })
54                    .collect();
55
56                // 提取 cfg 属性
57                let cfg_attrs: Vec<_> = method
58                    .attrs
59                    .iter()
60                    .filter(|attr| attr.path().is_ident("cfg"))
61                    .cloned()
62                    .collect();
63
64                rpc_methods.push(RpcMethod {
65                    rpc_name,
66                    method_name: method_name.clone(),
67                    params,
68                    cfg_attrs,
69                });
70            }
71        }
72    }
73
74    // 生成 handler 注册代码
75    let handlers = rpc_methods.iter().map(|method| {
76        let rpc_name = &method.rpc_name;
77        let method_name = &method.method_name;
78        let cfg_attrs = &method.cfg_attrs;
79
80        if method.params.is_empty() {
81            // 无参数方法
82            quote! {
83                #(#cfg_attrs)*
84                {
85                    server = server.handle(#rpc_name, {
86                        let service = service.clone();
87                        move |_: ()| {
88                            let service = service.clone();
89                            async move { service.#method_name().await }
90                        }
91                    });
92                }
93            }
94        } else if method.params.len() == 1 {
95            // 单参数方法 - 直接传递值,不加引用
96            let param_name = &method.params[0].0;
97            let param_type = &method.params[0].1;
98
99            quote! {
100                #(#cfg_attrs)*
101                {
102                    server = server.handle(#rpc_name, {
103                        let service = service.clone();
104                        move |#param_name: #param_type| {
105                            let service = service.clone();
106                            async move { service.#method_name(#param_name).await }
107                        }
108                    });
109                }
110            }
111        } else {
112            // 多参数 - 暂不支持
113            quote! {
114                compile_error!("Multiple parameters not yet supported");
115            }
116        }
117    });
118
119    // 过滤掉 #[rpc] 属性,保留其他属性
120    let trait_attrs: Vec<_> = input_trait
121        .attrs
122        .iter()
123        .filter(|attr| !attr.path().is_ident("rpc"))
124        .collect();
125
126    // 清理 trait items,去掉所有 #[method] 属性
127    let cleaned_items: Vec<_> = input_trait
128        .items
129        .iter()
130        .map(|item| {
131            if let TraitItem::Fn(method) = item {
132                let mut method = method.clone();
133                method.attrs.retain(|attr| !attr.path().is_ident("method"));
134                TraitItem::Fn(method)
135            } else {
136                item.clone()
137            }
138        })
139        .collect();
140
141    let trait_generics = &input_trait.generics;
142
143    // 生成最终代码
144    let expanded = quote! {
145        // 保留原始 trait 定义(去掉 #[rpc] 和 #[method] 属性)
146        #(#trait_attrs)*
147        #trait_vis trait #trait_name #trait_generics {
148            #(#cleaned_items)*
149        }
150
151        // 生成扩展 trait
152        #trait_vis trait IntoRpcServer {
153            fn into_rpc(self: std::sync::Arc<Self>) -> hsnet_rpc::RpcServer;
154        }
155
156        // 为所有实现了原 trait 的类型实现 IntoRpcServer
157        impl<T: #trait_name + Send + Sync + 'static> IntoRpcServer for T {
158            fn into_rpc(self: std::sync::Arc<Self>) -> hsnet_rpc::RpcServer {
159                let service = self;
160                let mut server = hsnet_rpc::RpcServer::new();
161
162                #(#handlers)*
163
164                server
165            }
166        }
167    };
168
169    TokenStream::from(expanded)
170}
171
172// 辅助函数:解析 #[method(name = "xxx")]
173fn parse_method_name(attr: &Attribute) -> String {
174    if let Meta::List(meta_list) = &attr.meta {
175        let tokens_str = meta_list.tokens.to_string();
176
177        // 简单解析 name = "xxx"
178        if let Some(start) = tokens_str.find('"') {
179            if let Some(end) = tokens_str[start + 1..].find('"') {
180                return tokens_str[start + 1..start + 1 + end].to_string();
181            }
182        }
183    }
184    String::new()
185}
186
187struct RpcMethod {
188    rpc_name: String,
189    method_name: syn::Ident,
190    params: Vec<(syn::Ident, Box<syn::Type>)>,
191    cfg_attrs: Vec<Attribute>,
192}