1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Ident;
5use quote::quote;
6use syn::spanned::Spanned;
7use syn::{
8 parse_macro_input, AngleBracketedGenericArguments, Attribute, AttributeArgs, Block,
9 DeriveInput, Error, FnArg, GenericArgument, ImplItem, ImplItemMethod, ItemImpl, LitStr, Meta,
10 NestedMeta, Pat, PatIdent, PathArguments, Result, ReturnType, Type, TypePath,
11};
12
/// Kind of endpoint an annotated service method is exposed as.
///
/// Selected by the attribute on the method: `#[call]` or `#[notify]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MethodType {
    /// `#[call]`: the handler's value is sent back to the caller in a response.
    Call,
    /// `#[notify]`: the handler must not declare a return type; no reply is produced.
    Notify,
}
18
19struct MethodInfo {
20 ty: MethodType,
21 name: Option<LitStr>,
22}
23
24enum MethodResult<'a> {
25 Default,
26 Value(&'a TypePath),
27 Result(&'a TypePath),
28}
29
30struct Method<'a> {
31 ty: MethodType,
32 name: Ident,
33 context: Option<&'a PatIdent>,
34 args: Vec<(&'a PatIdent, &'a TypePath)>,
35 result: MethodResult<'a>,
36 block: &'a Block,
37}
38
39fn parse_method_info(attrs: &[Attribute]) -> Option<MethodInfo> {
41 for attr in attrs {
42 match attr.parse_meta() {
43 Ok(Meta::Path(path)) => {
44 if path.is_ident("call") {
45 return Some(MethodInfo {
46 ty: MethodType::Call,
47 name: None,
48 });
49 } else if path.is_ident("notify") {
50 return Some(MethodInfo {
51 ty: MethodType::Notify,
52 name: None,
53 });
54 }
55 }
56 Ok(Meta::List(list)) => {
57 let ty = if list.path.is_ident("call") {
58 Some(MethodType::Call)
59 } else if list.path.is_ident("notify") {
60 Some(MethodType::Notify)
61 } else {
62 None
63 };
64
65 if let Some(ty) = ty {
66 let mut name = None;
67 for arg in list.nested {
68 if let NestedMeta::Meta(Meta::NameValue(nv)) = arg {
69 if nv.path.is_ident("name") {
70 if let syn::Lit::Str(lit) = nv.lit {
71 name = Some(lit);
72 }
73 }
74 }
75 }
76 return Some(MethodInfo { ty, name });
77 }
78 }
79 _ => {}
80 }
81 }
82
83 None
84}
85
86fn parse_method(info: MethodInfo, method: &ImplItemMethod) -> Result<Method> {
88 let name = info
89 .name
90 .map(|lit| Ident::new(&lit.value(), lit.span()))
91 .unwrap_or(method.sig.ident.clone());
92
93 if method.sig.asyncness.is_none() {
94 return Err(Error::new(method.span(), "invalid method"));
95 }
96
97 let mut args = Vec::new();
99 let mut context = None;
100 for (idx, arg) in method.sig.inputs.iter().enumerate() {
101 if let FnArg::Receiver(receiver) = arg {
102 if idx != 0 {
103 return Err(Error::new(receiver.span(), "invalid method"));
105 }
106 if receiver.mutability.is_some() {
107 return Err(Error::new(receiver.mutability.span(), "invalid method"));
109 }
110 } else if let FnArg::Typed(pat) = arg {
111 if idx == 0 {
112 return Err(Error::new(pat.span(), "invalid method"));
114 }
115
116 match (&*pat.pat, &*pat.ty) {
117 (Pat::Ident(id), Type::Path(ty)) => args.push((id, ty)),
119 (Pat::Ident(id), Type::Reference(ty)) => {
121 if idx != 1 {
122 return Err(Error::new(pat.span(), "invalid method"));
124 }
125
126 if ty.mutability.is_some() {
127 return Err(Error::new(pat.span(), "invalid method"));
129 }
130
131 if let Type::Path(path) = ty.elem.as_ref() {
132 if path.path.segments.last().unwrap().ident.to_string() == "NodeContext" {
133 let seg = &path.path.segments.last().unwrap();
134 if let PathArguments::AngleBracketed(angle_args) = &seg.arguments {
135 if angle_args.args.len() != 1 {
136 return Err(Error::new(pat.span(), "invalid method"));
138 }
139 if let GenericArgument::Lifetime(life) = &angle_args.args[0] {
140 if life.ident.to_string() != "_" {
141 return Err(Error::new(pat.span(), "invalid method"));
143 }
144 context = Some(id);
145 } else {
146 return Err(Error::new(pat.span(), "invalid method"));
148 }
149 } else {
150 return Err(Error::new(pat.span(), "invalid method"));
152 }
153 } else {
154 return Err(Error::new(pat.span(), "invalid method"));
156 }
157 } else {
158 return Err(Error::new(pat.span(), "invalid method"));
160 }
161 }
162 _ => return Err(Error::new(pat.span(), "invalid method")),
163 }
164 }
165 }
166
167 let result = match info.ty {
169 MethodType::Call => {
170 match &method.sig.output {
171 ReturnType::Default => MethodResult::Default,
172 ReturnType::Type(_, ty) => {
173 if let Type::Path(type_path) = ty.as_ref() {
174 let is_result = if type_path.path.segments.len() == 1 {
175 type_path.path.segments[0].ident.to_string() == "Result"
176 } else {
177 false
178 };
179
180 if is_result {
181 if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
182 args,
183 ..
184 }) = &type_path.path.segments[0].arguments
185 {
186 if args.len() != 1 {
187 return Err(Error::new(
189 method.sig.output.span(),
190 "invalid method",
191 ));
192 }
193 let value = match &args[0] {
194 GenericArgument::Type(Type::Path(path)) => path,
195 _ => {
196 return Err(Error::new(
197 method.sig.output.span(),
198 "invalid method",
199 ))
200 }
201 };
202 MethodResult::Result(value)
203 } else {
204 return Err(Error::new(method.sig.output.span(), "invalid method"));
206 }
207 } else {
208 MethodResult::Value(type_path)
209 }
210 } else {
211 return Err(Error::new(method.sig.output.span(), "invalid method"));
213 }
214 }
215 }
216 }
217 MethodType::Notify => {
218 match method.sig.output {
220 ReturnType::Default => MethodResult::Default,
221 _ => return Err(Error::new(method.sig.output.span(), "invalid method")),
222 }
223 }
224 };
225
226 Ok(Method {
227 ty: info.ty,
228 name,
229 context,
230 args,
231 result,
232 block: &method.block,
233 })
234}
235
236#[proc_macro_attribute]
237pub fn service(_args: TokenStream, input: TokenStream) -> TokenStream {
238 let impl_item = parse_macro_input!(input as ItemImpl);
239 let (self_ty, self_name) = match impl_item.self_ty.as_ref() {
240 Type::Path(path) => (
241 path,
242 path.path
243 .segments
244 .last()
245 .map(|s| s.ident.to_string())
246 .unwrap(),
247 ),
248 _ => {
249 return Error::new(impl_item.span(), "invalid method")
250 .to_compile_error()
251 .into()
252 }
253 };
254 let client_ty = Ident::new(&format!("{}Client", self_name), self_ty.span());
255 let client_notifyto_ty = Ident::new(&format!("{}ClientNotifyTo", self_name), self_ty.span());
256 let req_type_name = Ident::new(&format!("__RequestType_{}", self_name), self_ty.span());
257 let rep_type_name = Ident::new(&format!("__ResponseType{}", self_name), self_ty.span());
258 let notify_type_name = Ident::new(&format!("__NotifyType{}", self_name), self_ty.span());
259 let mut methods = Vec::new();
260 let mut other_methods = Vec::new();
261 let mut internal_methods = Vec::new();
262
263 for item in &impl_item.items {
264 if let ImplItem::Method(method) = item {
265 let ident = method.sig.ident.to_string();
266 if let Some(method_info) = parse_method_info(&method.attrs) {
267 let method = match parse_method(method_info, method) {
268 Ok(method) => method,
269 Err(err) => return err.to_compile_error().into(),
270 };
271 methods.push(method);
272 } else if ident == "start" || ident == "stop" {
273 other_methods.push(item);
275 } else {
276 internal_methods.push(item);
278 }
279 }
280 }
281
282 let expanded = {
283 let req_type = {
285 let mut reqs = Vec::new();
286 for method in methods
287 .iter()
288 .filter(|method| method.ty == MethodType::Call)
289 {
290 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
291 let types = method.args.iter().map(|(_, ty)| ty).collect::<Vec<_>>();
292 reqs.push(quote! { #name(#(#types),*) });
293 }
294 quote! {
295 #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
296 pub enum #req_type_name { #(#reqs),* }
297 }
298 };
299
300 let rep_type = {
302 let mut reps = Vec::new();
303 for method in methods
304 .iter()
305 .filter(|method| method.ty == MethodType::Call)
306 {
307 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
308 match &method.result {
309 MethodResult::Value(ty) => reps.push(quote! { #name(#ty) }),
310 MethodResult::Result(ty) => reps.push(quote! { #name(#ty) }),
311 MethodResult::Default => {}
312 }
313 }
314 quote! {
315 #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
316 pub enum #rep_type_name { #(#reps),* }
317 }
318 };
319
320 let notify_type = {
322 let mut notify = Vec::new();
323 for method in methods
324 .iter()
325 .filter(|method| method.ty == MethodType::Notify)
326 {
327 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
328 let types = method.args.iter().map(|(_, ty)| ty).collect::<Vec<_>>();
329 notify.push(quote! { #name(#(#types),*) });
330 }
331 quote! {
332 #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
333 pub enum #notify_type_name { #(#notify),* }
334 }
335 };
336
337 let req_handler = {
339 let mut list = Vec::new();
340
341 for (method_id, method) in methods
342 .iter()
343 .enumerate()
344 .filter(|(_, method)| method.ty == MethodType::Call)
345 {
346 let method_id = method_id as u32;
347 let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
348 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
349 let block = method.block;
350 let ctx = match method.context {
351 Some(id) => quote! { let #id = ctx; },
352 None => quote! {},
353 };
354
355 match &method.result {
356 MethodResult::Default => {
357 list.push(quote! {
358 if request.method == #method_id {
359 if let #req_type_name::#name(#(#vars),*) = request.data {
360 #ctx
361 return Ok(potatonet::Response::new(#rep_type_name::#name(#block)));
362 }
363 }
364 });
365 }
366 MethodResult::Value(_) => {
367 list.push(quote! {
368 if request.method == #method_id {
369 if let #req_type_name::#name(#(#vars),*) = request.data {
370 #ctx
371 let res = #block;
372 return Ok(potatonet::Response::new(#rep_type_name::#name(res)));
373 }
374 }
375 });
376 }
377 MethodResult::Result(_) => {
378 list.push(quote! {
379 if request.method == #method_id {
380 if let #req_type_name::#name(#(#vars),*) = request.data {
381 #ctx
382 let res: potatonet::Result<potatonet::Response<Self::Rep>> = #block.map(|x| potatonet::Response::new(#rep_type_name::#name(x)));
383 return res;
384 }
385 }
386 });
387 }
388 }
389 }
390
391 quote! { #(#list)* }
392 };
393
394 let notify_handler = {
396 let mut list = Vec::new();
397
398 for (method_id, method) in methods
399 .iter()
400 .enumerate()
401 .filter(|(_, method)| method.ty == MethodType::Notify)
402 {
403 let method_id = method_id as u32;
404 let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
405 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
406 let ctx = match method.context {
407 Some(id) => quote! { let #id = ctx; },
408 None => quote! {},
409 };
410 let block = method.block;
411
412 list.push(quote! {
413 if request.method == #method_id {
414 if let #notify_type_name::#name(#(#vars),*) = request.data {
415 #ctx
416 #block
417 }
418 }
419 });
420 }
421
422 quote! { #(#list)* }
423 };
424
425 let client_methods = {
427 let mut client_methods = Vec::new();
428 for (method_id, method) in methods.iter().enumerate() {
429 let method_id = method_id as u32;
430 let client_method = {
431 let method_name = &method.name;
432 let name =
433 Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
434 let params = method.args.iter().map(|(name, ty)| {
435 quote! { #name: #ty }
436 });
437 let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
438 match method.ty {
439 MethodType::Call => {
440 let res_type = match &method.result {
441 MethodResult::Default => quote! { () },
442 MethodResult::Value(value) => quote! { #value },
443 MethodResult::Result(value) => quote! { #value },
444 };
445 quote! {
446 pub async fn #method_name(&self, #(#params),*) -> potatonet::Result<#res_type> {
447 let res = self.ctx.call::<_, #rep_type_name>(&self.service_name, potatonet::Request::new(#method_id, #req_type_name::#name(#(#vars),*))).await?;
448 if let potatonet::Response{data: #rep_type_name::#name(value)} = res {
449 Ok(value)
450 } else {
451 unreachable!()
452 }
453 }
454 }
455 }
456 MethodType::Notify => {
457 quote! {
458 pub async fn #method_name(&self, #(#params),*) {
459 self.ctx.notify(&self.service_name, potatonet::Request::new(#method_id, #notify_type_name::#name(#(#vars),*))).await
460 }
461 }
462 }
463 }
464 };
465 client_methods.push(client_method);
466 }
467 client_methods
468 };
469
470 let client_notifyto_methods = {
472 let mut client_methods = Vec::new();
473 for (method_id, method) in methods.iter().enumerate() {
474 let method_id = method_id as u32;
475 let method_name = &method.name;
476 let name = Ident::new(&method.name.to_string().to_uppercase(), method.name.span());
477 let params = method.args.iter().map(|(name, ty)| {
478 quote! { #name: #ty }
479 });
480 let vars = method.args.iter().map(|(name, _)| name).collect::<Vec<_>>();
481 match method.ty {
482 MethodType::Notify => {
483 client_methods.push(quote! {
484 pub async fn #method_name(&self, #(#params),*) {
485 self.ctx.notify_to(self.to, potatonet::Request::new(#method_id, #notify_type_name::#name(#(#vars),*))).await
486 }
487 });
488 }
489 _ => {}
490 }
491 }
492 client_methods
493 };
494
495 quote! {
496 #[allow(non_camel_case_types)] #req_type
497 #[allow(non_camel_case_types)] #rep_type
498 #[allow(non_camel_case_types)] #notify_type
499
500 #[potatonet::async_trait::async_trait]
502 impl potatonet::node::Service for #self_ty {
503 type Req = #req_type_name;
504 type Rep = #rep_type_name;
505 type Notify = #notify_type_name;
506
507 #(#other_methods)*
508
509 #[allow(unused_variables)]
510 async fn call(&self, ctx: &potatonet::node::NodeContext<'_>, request: potatonet::Request<Self::Req>) ->
511 potatonet::Result<potatonet::Response<Self::Rep>> {
512 #req_handler
513 Err(potatonet::Error::MethodNotFound { method: request.method }.into())
514 }
515
516 #[allow(unused_variables)]
517 async fn notify(&self, ctx: &potatonet::node::NodeContext<'_>, request: potatonet::Request<Self::Notify>) {
518 #notify_handler
519 }
520 }
521
522 impl potatonet::node::NamedService for #self_ty {
523 fn name(&self) -> &'static str {
524 #self_name
525 }
526 }
527
528 impl #self_ty {
529 #(#internal_methods)*
530 }
531
532 pub struct #client_ty<'a, C> {
534 ctx: &'a C,
535 service_name: std::borrow::Cow<'a, str>,
536 }
537
538 impl<'a, C: potatonet::Context> #client_ty<'a, C> {
539 pub fn new(ctx: &'a C) -> Self {
540 Self { ctx, service_name: std::borrow::Cow::Borrowed(#self_name) }
541 }
542
543 pub fn with_name<N>(ctx: &'a C, name: N) -> Self where N: Into<std::borrow::Cow<'a, str>> {
544 Self { ctx, service_name: name.into() }
545 }
546
547 pub fn to(&self, to: potatonet::ServiceId) -> #client_notifyto_ty<'a, C> {
548 #client_notifyto_ty { ctx: self.ctx, to }
549 }
550
551 #(#client_methods)*
552 }
553
554 pub struct #client_notifyto_ty<'a, C> {
556 ctx: &'a C,
557 to: potatonet::ServiceId,
558 }
559
560 impl<'a, C: potatonet::Context> #client_notifyto_ty<'a, C> {
561 #(#client_notifyto_methods)*
562 }
563 }
564 };
565
566 expanded.into()
568}
569
570#[proc_macro_attribute]
571pub fn message(_args: TokenStream, input: TokenStream) -> TokenStream {
572 let input = parse_macro_input!(input as DeriveInput);
573 let expanded = quote! {
574 #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
575 #input
576 };
577 expanded.into()
578}
579
580#[proc_macro_attribute]
581pub fn topic(args: TokenStream, input: TokenStream) -> TokenStream {
582 let args = parse_macro_input!(args as AttributeArgs);
583 let mut name = None;
584
585 for arg in args {
586 match arg {
587 NestedMeta::Meta(Meta::NameValue(nv)) => {
588 if nv.path.is_ident("name") {
589 if let syn::Lit::Str(lit) = nv.lit {
590 name = Some(lit.value());
591 }
592 }
593 }
594 _ => {}
595 }
596 }
597
598 let input = parse_macro_input!(input as DeriveInput);
599 let name = name.unwrap_or_else(|| input.ident.to_string());
600 let ident = &input.ident;
601 let msg_type = quote! {
602 #[derive(potatonet::serde_derive::Serialize, potatonet::serde_derive::Deserialize)]
603 #input
604
605 impl Topic for #ident {
606 fn name() -> &'static str {
607 #name
608 }
609 }
610 };
611
612 let expanded = quote! {
613 #msg_type
614 };
615 expanded.into()
616}