// potatonet_codegen/lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::quote;
6use syn::spanned::Spanned;
7use syn::{
8    parse_macro_input, AngleBracketedGenericArguments, Attribute, AttributeArgs, Block,
9    DeriveInput, Error, FnArg, GenericArgument, ImplItem, ImplItemMethod, ItemImpl, LitStr, Meta,
10    NestedMeta, Pat, PatIdent, PathArguments, Result, ReturnType, Type, TypePath,
11};
12
/// Kind of service method, selected by the attribute placed on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MethodType {
    /// `#[call]`: the generated client awaits a response value.
    Call,
    /// `#[notify]`: no return type is allowed; the client sends without a reply.
    Notify,
}
18
19struct MethodInfo {
20    ty: MethodType,
21    name: Option<LitStr>,
22}
23
24enum MethodResult<'a> {
25    Default,
26    Value(&'a TypePath),
27    Result(&'a TypePath),
28}
29
30struct Method<'a> {
31    ty: MethodType,
32    name: Ident,
33    context: Option<&'a PatIdent>,
34    args: Vec<(&'a PatIdent, &'a TypePath)>,
35    result: MethodResult<'a>,
36    block: &'a Block,
37}
38
39/// 解析method定义
40fn parse_method_info(attrs: &[Attribute]) -> Option<MethodInfo> {
41    for attr in attrs {
42        match attr.parse_meta() {
43            Ok(Meta::Path(path)) => {
44                if path.is_ident("call") {
45                    return Some(MethodInfo {
46                        ty: MethodType::Call,
47                        name: None,
48                    });
49                } else if path.is_ident("notify") {
50                    return Some(MethodInfo {
51                        ty: MethodType::Notify,
52                        name: None,
53                    });
54                }
55            }
56            Ok(Meta::List(list)) => {
57                let ty = if list.path.is_ident("call") {
58                    Some(MethodType::Call)
59                } else if list.path.is_ident("notify") {
60                    Some(MethodType::Notify)
61                } else {
62                    None
63                };
64
65                if let Some(ty) = ty {
66                    let mut name = None;
67                    for arg in list.nested {
68                        if let NestedMeta::Meta(Meta::NameValue(nv)) = arg {
69                            if nv.path.is_ident("name") {
70                                if let syn::Lit::Str(lit) = nv.lit {
71                                    name = Some(lit);
72                                }
73                            }
74                        }
75                    }
76                    return Some(MethodInfo { ty, name });
77                }
78            }
79            _ => {}
80        }
81    }
82
83    None
84}
85
86/// 解析method
87fn parse_method(info: MethodInfo, method: &ImplItemMethod) -> Result<Method> {
88    let name = info
89        .name
90        .map(|lit| Ident::new(&lit.value(), lit.span()))
91        .unwrap_or(method.sig.ident.clone());
92
93    if method.sig.asyncness.is_none() {
94        return Err(Error::new(method.span(), "invalid method"));
95    }
96
97    // 解析参数
98    let mut args = Vec::new();
99    let mut context = None;
100    for (idx, arg) in method.sig.inputs.iter().enumerate() {
101        if let FnArg::Receiver(receiver) = arg {
102            if idx != 0 {
103                // self必须是第一个参数
104                return Err(Error::new(receiver.span(), "invalid method"));
105            }
106            if receiver.mutability.is_some() {
107                // 不能是可变借用
108                return Err(Error::new(receiver.mutability.span(), "invalid method"));
109            }
110        } else if let FnArg::Typed(pat) = arg {
111            if idx == 0 {
112                // 第一个参数必须是self
113                return Err(Error::new(pat.span(), "invalid method"));
114            }
115
116            match (&*pat.pat, &*pat.ty) {
117                // 参数
118                (Pat::Ident(id), Type::Path(ty)) => args.push((id, ty)),
119                // Context
120                (Pat::Ident(id), Type::Reference(ty)) => {
121                    if idx != 1 {
122                        // context必须是第二个参数
123                        return Err(Error::new(pat.span(), "invalid method"));
124                    }
125
126                    if ty.mutability.is_some() {
127                        // context必须是不可变借用
128                        return Err(Error::new(pat.span(), "invalid method"));
129                    }
130
131                    if let Type::Path(path) = ty.elem.as_ref() {
132                        if path.path.segments.last().unwrap().ident.to_string() == "NodeContext" {
133                            let seg = &path.path.segments.last().unwrap();
134                            if let PathArguments::AngleBracketed(angle_args) = &seg.arguments {
135                                if angle_args.args.len() != 1 {
136                                    // context的泛型参数错误
137                                    return Err(Error::new(pat.span(), "invalid method"));
138                                }
139                                if let GenericArgument::Lifetime(life) = &angle_args.args[0] {
140                                    if life.ident.to_string() != "_" {
141                                        // context的泛型参数错误
142                                        return Err(Error::new(pat.span(), "invalid method"));
143                                    }
144                                    context = Some(id);
145                                } else {
146                                    // context的泛型参数错误
147                                    return Err(Error::new(pat.span(), "invalid method"));
148                                }
149                            } else {
150                                // context的泛型参数错误
151                                return Err(Error::new(pat.span(), "invalid method"));
152                            }
153                        } else {
154                            // 不是context类型
155                            return Err(Error::new(pat.span(), "invalid method"));
156                        }
157                    } else {
158                        // 不是context类型
159                        return Err(Error::new(pat.span(), "invalid method"));
160                    }
161                }
162                _ => return Err(Error::new(pat.span(), "invalid method")),
163            }
164        }
165    }
166
167    // 解析返回值
168    let result = match info.ty {
169        MethodType::Call => {
170            match &method.sig.output {
171                ReturnType::Default => MethodResult::Default,
172                ReturnType::Type(_, ty) => {
173                    if let Type::Path(type_path) = ty.as_ref() {
174                        let is_result = if type_path.path.segments.len() == 1 {
175                            type_path.path.segments[0].ident.to_string() == "Result"
176                        } else {
177                            false
178                        };
179
180                        if is_result {
181                            if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
182                                args,
183                                ..
184                            }) = &type_path.path.segments[0].arguments
185                            {
186                                if args.len() != 1 {
187                                    // 错误的result类型
188                                    return Err(Error::new(
189                                        method.sig.output.span(),
190                                        "invalid method",
191                                    ));
192                                }
193                                let value = match &args[0] {
194                                    GenericArgument::Type(Type::Path(path)) => path,
195                                    _ => {
196                                        return Err(Error::new(
197                                            method.sig.output.span(),
198                                            "invalid method",
199                                        ))
200                                    }
201                                };
202                                MethodResult::Result(value)
203                            } else {
204                                // 错误的result类型
205                                return Err(Error::new(method.sig.output.span(), "invalid method"));
206                            }
207                        } else {
208                            MethodResult::Value(type_path)
209                        }
210                    } else {
211                        // 不支持的返回值类型
212                        return Err(Error::new(method.sig.output.span(), "invalid method"));
213                    }
214                }
215            }
216        }
217        MethodType::Notify => {
218            // notify不能有返回值
219            match method.sig.output {
220                ReturnType::Default => MethodResult::Default,
221                _ => return Err(Error::new(method.sig.output.span(), "invalid method")),
222            }
223        }
224    };
225
226    Ok(Method {
227        ty: info.ty,
228        name,
229        context,
230        args,
231        result,
232        block: &method.block,
233    })
234}
235
236#[proc_macro_attribute]
237pub fn service(_args: TokenStream, input: TokenStream) -> TokenStream {
238    let impl_item = parse_macro_input!(input as ItemImpl);
239    let (self_ty, self_name) = match impl_item.self_ty.as_ref() {
240        Type::Path(path) => (
241            path,
242            path.path
243                .segments
244                .last()
245                .map(|s| s.ident.to_string())
246                .unwrap(),
247        ),
248        _ => {
249            return Error::new(impl_item.span(), "invalid method")
250                .to_compile_error()
251                .into()
252        }
253    };
254    let client_ty = Ident::new(&format!("{}Client", self_name), self_ty.span());
255    let client_notifyto_ty = Ident::new(&format!("{}ClientNotifyTo", self_name), self_ty.span());
256    let req_type_name = Ident::new(&format!("__RequestType_{}", self_name), self_ty.span());
257    let rep_type_name = Ident::new(&format!("__ResponseType{}", self_name), self_ty.span());
258    let notify_type_name = Ident::new(&format!("__NotifyType{}", self_name), self_ty.span());
259    let mut methods = Vec::new();
260    let mut other_methods = Vec::new();
261    let mut internal_methods = Vec::new();
262
263    for item in &impl_item.items {
264        if let ImplItem::Method(method) = item {
265            let ident = method.sig.ident.to_string();
266            if let Some(method_info) = parse_method_info(&method.attrs) {
267                let method = match parse_method(method_info, method) {
268                    Ok(method) => method,
269                    Err(err) => return err.to_compile_error().into(),
270                };
271                methods.push(method);
272            } else if ident == "start" || ident == "stop" {
273                // 开始或者停止服务
274                other_methods.push(item);
275            } else {
276                // 内部函数
277                internal_methods.push(item);
278            }
279        }
280    }
281
282    let expanded = {
283        // 请求类型
284        let req_type = {
285            let mut reqs = Vec::new();
286            for method in methods
287                .iter()
288                .filter(|method| method.ty == MethodType::Call)
289            {
290                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
291                let types = method.args.iter().map(|(_, ty)| ty).collect::<Vec<_>>();
292                reqs.push(quote! { #name(#(#types),*) });
293            }
294            quote! {
295                #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
296                pub enum #req_type_name { #(#reqs),* }
297            }
298        };
299
300        // 响应类型
301        let rep_type = {
302            let mut reps = Vec::new();
303            for method in methods
304                .iter()
305                .filter(|method| method.ty == MethodType::Call)
306            {
307                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
308                match &method.result {
309                    MethodResult::Value(ty) => reps.push(quote! { #name(#ty) }),
310                    MethodResult::Result(ty) => reps.push(quote! { #name(#ty) }),
311                    MethodResult::Default => {}
312                }
313            }
314            quote! {
315                #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
316                pub enum #rep_type_name { #(#reps),* }
317            }
318        };
319
320        // 通知类型
321        let notify_type = {
322            let mut notify = Vec::new();
323            for method in methods
324                .iter()
325                .filter(|method| method.ty == MethodType::Notify)
326            {
327                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
328                let types = method.args.iter().map(|(_, ty)| ty).collect::<Vec<_>>();
329                notify.push(quote! { #name(#(#types),*) });
330            }
331            quote! {
332                #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
333                pub enum #notify_type_name { #(#notify),* }
334            }
335        };
336
337        // 请求处理代码
338        let req_handler = {
339            let mut list = Vec::new();
340
341            for (method_id, method) in methods
342                .iter()
343                .enumerate()
344                .filter(|(_, method)| method.ty == MethodType::Call)
345            {
346                let method_id = method_id as u32;
347                let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
348                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
349                let block = method.block;
350                let ctx = match method.context {
351                    Some(id) => quote! { let #id = ctx; },
352                    None => quote! {},
353                };
354
355                match &method.result {
356                    MethodResult::Default => {
357                        list.push(quote! {
358                            if request.method == #method_id {
359                                if let #req_type_name::#name(#(#vars),*) = request.data {
360                                    #ctx
361                                    return Ok(potatonet::Response::new(#rep_type_name::#name(#block)));
362                                }
363                            }
364                        });
365                    }
366                    MethodResult::Value(_) => {
367                        list.push(quote! {
368                            if request.method == #method_id {
369                                if let #req_type_name::#name(#(#vars),*) = request.data {
370                                    #ctx
371                                    let res = #block;
372                                    return Ok(potatonet::Response::new(#rep_type_name::#name(res)));
373                                }
374                            }
375                        });
376                    }
377                    MethodResult::Result(_) => {
378                        list.push(quote! {
379                            if request.method == #method_id {
380                                if let #req_type_name::#name(#(#vars),*) = request.data {
381                                    #ctx
382                                    let res: potatonet::Result<potatonet::Response<Self::Rep>> = #block.map(|x| potatonet::Response::new(#rep_type_name::#name(x)));
383                                    return res;
384                                }
385                            }
386                        });
387                    }
388                }
389            }
390
391            quote! { #(#list)* }
392        };
393
394        // 通知处理代码
395        let notify_handler = {
396            let mut list = Vec::new();
397
398            for (method_id, method) in methods
399                .iter()
400                .enumerate()
401                .filter(|(_, method)| method.ty == MethodType::Notify)
402            {
403                let method_id = method_id as u32;
404                let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
405                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
406                let ctx = match method.context {
407                    Some(id) => quote! { let #id = ctx; },
408                    None => quote! {},
409                };
410                let block = method.block;
411
412                list.push(quote! {
413                    if request.method == #method_id {
414                        if let #notify_type_name::#name(#(#vars),*) = request.data {
415                            #ctx
416                            #block
417                        }
418                    }
419                });
420            }
421
422            quote! { #(#list)* }
423        };
424
425        // 客户端函数
426        let client_methods = {
427            let mut client_methods = Vec::new();
428            for (method_id, method) in methods.iter().enumerate() {
429                let method_id = method_id as u32;
430                let client_method = {
431                    let method_name = &method.name;
432                    let name =
433                        Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
434                    let params = method.args.iter().map(|(name, ty)| {
435                        quote! { #name: #ty }
436                    });
437                    let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
438                    match method.ty {
439                        MethodType::Call => {
440                            let res_type = match &method.result {
441                                MethodResult::Default => quote! { () },
442                                MethodResult::Value(value) => quote! { #value },
443                                MethodResult::Result(value) => quote! { #value },
444                            };
445                            quote! {
446                                pub async fn #method_name(&self, #(#params),*) -> potatonet::Result<#res_type> {
447                                    let res = self.ctx.call::<_, #rep_type_name>(&self.service_name, potatonet::Request::new(#method_id, #req_type_name::#name(#(#vars),*))).await?;
448                                    if let potatonet::Response{data: #rep_type_name::#name(value)} = res {
449                                        Ok(value)
450                                    } else {
451                                        unreachable!()
452                                    }
453                                }
454                            }
455                        }
456                        MethodType::Notify => {
457                            quote! {
458                                pub async fn #method_name(&self, #(#params),*) {
459                                    self.ctx.notify(&self.service_name, potatonet::Request::new(#method_id, #notify_type_name::#name(#(#vars),*))).await
460                                }
461                            }
462                        }
463                    }
464                };
465                client_methods.push(client_method);
466            }
467            client_methods
468        };
469
470        // 定向通知的客户端函数
471        let client_notifyto_methods = {
472            let mut client_methods = Vec::new();
473            for (method_id, method) in methods.iter().enumerate() {
474                let method_id = method_id as u32;
475                let method_name = &method.name;
476                let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
477                let params = method.args.iter().map(|(name, ty)| {
478                    quote! { #name: #ty }
479                });
480                let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
481                match method.ty {
482                    MethodType::Notify => {
483                        client_methods.push(quote! {
484                                pub async fn #method_name(&self, #(#params),*) {
485                                    self.ctx.notify_to(self.to, potatonet::Request::new(#method_id, #notify_type_name::#name(#(#vars),*))).await
486                                }
487                            });
488                    }
489                    _ => {}
490                }
491            }
492            client_methods
493        };
494
495        quote! {
496            #[allow(non_camel_case_types)] #req_type
497            #[allow(non_camel_case_types)] #rep_type
498            #[allow(non_camel_case_types)] #notify_type
499
500            // 服务代码
501            #[potatonet::async_trait::async_trait]
502            impl potatonet::node::Service for #self_ty {
503                type Req = #req_type_name;
504                type Rep = #rep_type_name;
505                type Notify = #notify_type_name;
506
507                #(#other_methods)*
508
509                #[allow(unused_variables)]
510                async fn call(&self, ctx: &potatonet::node::NodeContext<'_>, request: potatonet::Request<Self::Req>) ->
511                    potatonet::Result<potatonet::Response<Self::Rep>> {
512                    #req_handler
513                    Err(potatonet::Error::MethodNotFound { method: request.method }.into())
514                }
515
516                #[allow(unused_variables)]
517                async fn notify(&self, ctx: &potatonet::node::NodeContext<'_>, request: potatonet::Request<Self::Notify>) {
518                    #notify_handler
519                }
520            }
521
522            impl potatonet::node::NamedService for #self_ty {
523                fn name(&self) -> &'static str {
524                    #self_name
525                }
526            }
527
528            impl #self_ty {
529                #(#internal_methods)*
530            }
531
532            // 客户端代码
533            pub struct #client_ty<'a, C> {
534                ctx: &'a C,
535                service_name: std::borrow::Cow<'a, str>,
536            }
537
538            impl<'a, C: potatonet::Context> #client_ty<'a, C> {
539                pub fn new(ctx: &'a C) -> Self {
540                    Self { ctx, service_name: std::borrow::Cow::Borrowed(#self_name) }
541                }
542
543                pub fn with_name<N>(ctx: &'a C, name: N) -> Self where N: Into<std::borrow::Cow<'a, str>> {
544                    Self { ctx, service_name: name.into() }
545                }
546
547                pub fn to(&self, to: potatonet::ServiceId) -> #client_notifyto_ty<'a, C> {
548                    #client_notifyto_ty { ctx: self.ctx, to }
549                }
550
551                #(#client_methods)*
552            }
553
554            // 定向通知客户端
555            pub struct #client_notifyto_ty<'a, C> {
556                ctx: &'a C,
557                to: potatonet::ServiceId,
558            }
559
560            impl<'a, C: potatonet::Context> #client_notifyto_ty<'a, C> {
561                #(#client_notifyto_methods)*
562            }
563        }
564    };
565
566    //        println!("{}", expanded.to_string());
567    expanded.into()
568}
569
570#[proc_macro_attribute]
571pub fn message(_args: TokenStream, input: TokenStream) -> TokenStream {
572    let input = parse_macro_input!(input as DeriveInput);
573    let expanded = quote! {
574        #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
575        #input
576    };
577    expanded.into()
578}
579
580#[proc_macro_attribute]
581pub fn topic(args: TokenStream, input: TokenStream) -> TokenStream {
582    let args = parse_macro_input!(args as AttributeArgs);
583    let mut name = None;
584
585    for arg in args {
586        match arg {
587            NestedMeta::Meta(Meta::NameValue(nv)) => {
588                if nv.path.is_ident("name") {
589                    if let syn::Lit::Str(lit) = nv.lit {
590                        name = Some(lit.value());
591                    }
592                }
593            }
594            _ => {}
595        }
596    }
597
598    let input = parse_macro_input!(input as DeriveInput);
599    let name = name.unwrap_or_else(|| input.ident.to_string());
600    let ident = &input.ident;
601    let msg_type = quote! {
602        #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
603        #input
604
605        impl Topic for #ident {
606            fn name() -> &'static str {
607                #name
608            }
609        }
610    };
611
612    let expanded = quote! {
613        #msg_type
614    };
615    expanded.into()
616}