1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{ToTokens, quote};
6use syn::{
7 Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type, Visibility,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 spanned::Spanned,
12 token::Comma,
13};
14
// Variant attributes recognized (and stripped) by `rpc_requests`.

/// Name of the variant attribute that declares a request's channels: `#[rpc(tx = .., rx = ..)]`.
const RPC_ATTR_NAME: &str = "rpc";
/// Name of the variant attribute that generates a wrapper struct: `#[wrap(Name, ..)]`.
const WRAP_ATTR_NAME: &str = "wrap";

// Keys accepted inside `#[rpc(..)]`, and the types used when a key is omitted.

/// Key for the `Channels::Tx` type inside `#[rpc(..)]`.
const TX_ATTR: &str = "tx";
/// Key for the `Channels::Rx` type inside `#[rpc(..)]`.
const RX_ATTR: &str = "rx";
/// `Channels::Rx` type used when `#[rpc(..)]` has no `rx` key.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
/// `Channels::Tx` type used when `#[rpc(..)]` has no `tx` key.
const DEFAULT_TX_TYPE: &str = "::irpc::channel::none::NoSender";
27
28#[proc_macro_attribute]
30pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
31 let mut input = parse_macro_input!(item as DeriveInput);
32 let args = parse_macro_input!(attr as MacroArgs);
33
34 let enum_name = &input.ident;
35 let vis = &input.vis;
36
37 let data_enum = match &mut input.data {
38 Data::Enum(data_enum) => data_enum,
39 _ => {
40 return error_tokens(
41 input.span(),
42 "The rpc_requests macro can only be applied to enums",
43 );
44 }
45 };
46
47 let cfg_feature_rpc = match args.rpc_feature.as_ref() {
48 None => quote!(),
49 Some(feature) => quote!(#[cfg(feature = #feature)]),
50 };
51
52 let mut channel_impls = TokenStream2::new();
54 let mut types = HashSet::new();
56 let mut all_variants = Vec::new();
58 let mut variants_with_attr = Vec::new();
60 let mut wrapper_types = TokenStream2::new();
62
63 for variant in &mut data_enum.variants {
64 let rpc_attr = match VariantRpcArgs::from_attrs(&mut variant.attrs) {
65 Ok(args) => args,
66 Err(err) => return err.into_compile_error().into(),
67 };
68
69 let request_type = match rpc_attr.wrap {
70 None => match &mut variant.fields {
71 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
72 fields.unnamed[0].ty.clone()
73 }
74 _ => {
75 return error_tokens(
76 variant.span(),
77 "Each variant must either have exactly one unnamed field, or use the `wrap` argument in the `rpc` attribute.",
78 );
79 }
80 },
81 Some(WrapArgs { ident, derive, vis }) => {
82 let vis = vis.as_ref().unwrap_or(&input.vis).clone();
83 let ty = type_from_ident(&ident);
84 let struc = struct_from_variant_fields(
85 ident,
86 variant.fields.clone(),
87 variant.attrs.clone(),
88 vis,
89 );
90 wrapper_types.extend(quote! {
91 #[derive(::std::fmt::Debug, ::serde::Serialize, ::serde::Deserialize, #(#derive),* )]
92 #struc
93 });
94 variant.fields = single_unnamed_field(ty.clone());
95 ty
96 }
97 };
98
99 all_variants.push((variant.ident.clone(), request_type.clone()));
100
101 if !types.insert(request_type.to_token_stream().to_string()) {
102 return error_tokens(
103 variant.span(),
104 "Each variant must have a unique request type",
105 );
106 }
107
108 if let Some(args) = rpc_attr.rpc {
109 variants_with_attr.push((variant.ident.clone(), request_type.clone()));
110 channel_impls.extend(generate_channels_impl(args, enum_name, &request_type))
111 }
112 }
113
114 let protocol_enum_from_impls =
116 generate_protocol_enum_from_impls(enum_name, &variants_with_attr);
117
118 let type_aliases = if let Some(suffix) = args.alias_suffix {
120 generate_type_aliases(&all_variants, enum_name, &suffix)
122 } else {
123 quote! {}
124 };
125
126 let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() {
128 let message_variants = all_variants
129 .iter()
130 .map(|(variant_name, inner_type)| {
131 quote! {
132 #variant_name(::irpc::WithChannels<#inner_type, #enum_name>)
133 }
134 })
135 .collect::<Vec<_>>();
136
137 let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
139
140 let doc = format!("Message enum for [`{enum_name}`]");
142 let message_enum = quote! {
143 #[doc = #doc]
144 #[allow(missing_docs)]
145 #[derive(::std::fmt::Debug)]
146 #vis enum #message_enum_name {
147 #(#message_variants),*
148 }
149 };
150
151 let parent_span_impl = if !args.no_spans {
153 generate_parent_span_impl(message_enum_name, &variant_names)
154 } else {
155 quote! {}
156 };
157
158 let message_from_impls =
160 generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name);
161
162 let span_propagation = args.span_propagation;
163 let service_impl = quote! {
164 impl ::irpc::Service for #enum_name {
165 type Message = #message_enum_name;
166 const SPAN_PROPAGATION: bool = #span_propagation;
167 }
168 };
169
170 let remote_service_impl = if !args.no_rpc {
171 let block = generate_remote_service_impl(
172 message_enum_name,
173 enum_name,
174 &variants_with_attr,
175 args.span_propagation,
176 );
177 quote! {
178 #cfg_feature_rpc
179 #block
180 }
181 } else {
182 quote! {}
183 };
184
185 quote! {
186 #message_enum
187 #service_impl
188 #remote_service_impl
189 #parent_span_impl
190 #message_from_impls
191 }
192 } else {
193 quote! {}
194 };
195
196 let output = quote! {
198 #input
199
200 #wrapper_types
202
203 #channel_impls
205
206 #protocol_enum_from_impls
208
209 #type_aliases
211
212 #extended_enum_code
214 };
215
216 output.into()
217}
218
219fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
221 quote! {
222 impl #enum_name {
223 pub fn parent_span(&self) -> ::tracing::Span {
225 let span = match self {
226 #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
227 };
228 span.cloned().unwrap_or_else(|| ::tracing::Span::current())
229 }
230 }
231 }
232}
233
234fn generate_channels_impl(
235 args: RpcArgs,
236 service_name: &Ident,
237 request_type: &Type,
238) -> TokenStream2 {
239 let rx = args.rx.unwrap_or_else(|| {
240 syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
242 });
243 let tx = args.tx.unwrap_or_else(|| {
244 syn::parse_str::<Type>(DEFAULT_TX_TYPE).expect("Failed to parse default tx type")
246 });
247
248 quote! {
249 impl ::irpc::Channels<#service_name> for #request_type {
250 type Tx = #tx;
251 type Rx = #rx;
252 }
253 }
254}
255
256fn generate_protocol_enum_from_impls(
258 enum_name: &Ident,
259 variants_with_attr: &[(Ident, Type)],
260) -> TokenStream2 {
261 variants_with_attr
262 .iter()
263 .map(|(variant_name, inner_type)| {
264 quote! {
265 impl From<#inner_type> for #enum_name {
266 fn from(value: #inner_type) -> Self {
267 #enum_name::#variant_name(value)
268 }
269 }
270 }
271 })
272 .collect()
273}
274
275fn generate_message_enum_from_impls(
277 message_enum_name: &Ident,
278 variants_with_attr: &[(Ident, Type)],
279 service_name: &Ident,
280) -> TokenStream2 {
281 variants_with_attr
282 .iter()
283 .map(|(variant_name, inner_type)| {
284 quote! {
285 impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
286 fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
287 #message_enum_name::#variant_name(value)
288 }
289 }
290 }
291 })
292 .collect()
293}
294
295fn generate_remote_service_impl(
300 message_enum_name: &Ident,
301 proto_enum_name: &Ident,
302 variants_with_attr: &[(Ident, Type)],
303 span_propagation: bool,
304) -> TokenStream2 {
305 let variants = variants_with_attr
306 .iter()
307 .map(|(variant_name, _inner_type)| {
308 let span_name = variant_name.to_string();
309
310 if span_propagation {
311 quote! {
313 #proto_enum_name::#variant_name(msg) => {
314 let span = ::tracing::info_span!(#span_name);
316 ::irpc::span_propagation::set_span_parent_from_remote(&span);
318 let _guard = span.enter();
319 #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
320 }
321 }
322 } else {
323 quote! {
324 #proto_enum_name::#variant_name(msg) => {
325 #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
326 }
327 }
328 }
329 });
330
331 quote! {
332 impl ::irpc::rpc::RemoteService for #proto_enum_name {
333 fn with_remote_channels(
334 self,
335 rx: ::irpc::rpc::noq::RecvStream,
336 tx: ::irpc::rpc::noq::SendStream
337 ) -> Self::Message {
338 match self {
339 #(#variants),*
340 }
341 }
342 }
343 }
344}
345
346fn generate_type_aliases(
348 variants: &[(Ident, Type)],
349 service_name: &Ident,
350 suffix: &str,
351) -> TokenStream2 {
352 variants
353 .iter()
354 .map(|(variant_name, inner_type)| {
355 let type_name = format!("{variant_name}{suffix}");
358 let type_ident = Ident::new(&type_name, variant_name.span());
359 quote! {
360 pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
362 }
363 })
364 .collect()
365}
366
367#[derive(Default)]
369struct MacroArgs {
370 message_enum_name: Option<Ident>,
371 alias_suffix: Option<String>,
372 rpc_feature: Option<String>,
373 no_rpc: bool,
374 no_spans: bool,
375 span_propagation: bool,
377}
378
379impl Parse for MacroArgs {
380 fn parse(input: ParseStream) -> syn::Result<Self> {
381 let mut this = Self::default();
382 loop {
383 let arg: Ident = input.parse()?;
384 match arg.to_string().as_str() {
385 "message" => {
386 input.parse::<Token![=]>()?;
387 let value: Ident = input.parse()?;
388 this.message_enum_name = Some(value);
389 }
390 "alias" => {
391 input.parse::<Token![=]>()?;
392 let value: LitStr = input.parse()?;
393 this.alias_suffix = Some(value.value());
394 }
395 "rpc_feature" => {
396 input.parse::<Token![=]>()?;
397 if this.no_rpc {
398 return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
399 }
400 let value: LitStr = input.parse()?;
401 this.rpc_feature = Some(value.value());
402 }
403 "no_rpc" => {
404 if this.rpc_feature.is_some() {
405 return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
406 }
407 this.no_rpc = true;
408 }
409 "no_spans" => {
410 this.no_spans = true;
411 }
412 "span_propagation" => {
413 this.span_propagation = true;
414 }
415 _ => {
416 return syn_err(arg.span(), format!("Unknown parameter: {arg}"));
417 }
418 }
419
420 if input.peek(Token![,]) {
421 input.parse::<Token![,]>()?;
422 } else {
423 break;
424 }
425 }
426
427 Ok(this)
428 }
429}
430
431#[derive(Default)]
432struct VariantRpcArgs {
433 wrap: Option<WrapArgs>,
434 rpc: Option<RpcArgs>,
435}
436
437impl VariantRpcArgs {
438 fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
439 let mut this = Self::default();
440 let mut remaining_attrs = Vec::new();
441 for attr in attrs.drain(..) {
442 let ident = attr.path().get_ident().map(|ident| ident.to_string());
443 match ident.as_deref() {
444 Some(RPC_ATTR_NAME) => {
445 if this.rpc.is_some() {
446 syn_err(attr.span(), "Each variant can have only one rpc attribute")?;
447 }
448 this.rpc = Some(attr.parse_args()?);
449 }
450 Some(WRAP_ATTR_NAME) => {
451 if this.wrap.is_some() {
452 syn_err(attr.span(), "Each variant can have only one wrap attribute")?;
453 }
454 this.wrap = Some(attr.parse_args()?);
455 }
456 _ => remaining_attrs.push(attr),
457 }
458 }
459 *attrs = remaining_attrs;
460 Ok(this)
461 }
462}
463
464#[derive(Default)]
465struct RpcArgs {
466 rx: Option<Type>,
467 tx: Option<Type>,
468}
469
470impl Parse for RpcArgs {
472 fn parse(input: ParseStream) -> syn::Result<Self> {
473 let mut this = Self::default();
474 while !input.is_empty() {
475 let arg: Ident = input.parse()?;
476 let _: Token![=] = input.parse()?;
477 let value: Type = input.parse()?;
478 if arg == RX_ATTR {
479 this.rx = Some(value);
480 } else if arg == TX_ATTR {
481 this.tx = Some(value);
482 } else {
483 syn_err(arg.span(), "Unexpected argument in rpc attribute")?;
484 }
485 if !input.peek(Token![,]) {
486 break;
487 } else {
488 let _: Token![,] = input.parse()?;
489 }
490 }
491
492 Ok(this)
493 }
494}
495
496struct WrapArgs {
497 vis: Option<Visibility>,
498 ident: Ident,
499 derive: Vec<Type>,
500}
501
502impl Parse for WrapArgs {
503 fn parse(input: ParseStream) -> syn::Result<Self> {
504 let vis = match input.parse::<Visibility>()? {
505 Visibility::Inherited => None,
506 vis => Some(vis),
507 };
508 let ident: Ident = input.parse()?;
509 let mut this = Self {
510 ident,
511 derive: Default::default(),
512 vis,
513 };
514 while input.peek(Token![,]) {
515 let _: Token![,] = input.parse()?;
516 let arg: Ident = input.parse()?;
517 match arg.to_string().as_str() {
518 "derive" => {
519 let content;
520 syn::parenthesized!(content in input);
521 let types: Punctuated<Type, Comma> = Punctuated::parse_terminated(&content)?;
522 this.derive = types.into_iter().collect();
523 }
524 _ => syn_err(arg.span(), "Unexpected argument in wrap argument")?,
525 }
526 }
527 if !input.is_empty() {
528 syn_err(input.span(), "Unexpected tokens in wrap argument")?;
529 }
530 Ok(this)
531 }
532}
533
534fn type_from_ident(ident: &Ident) -> Type {
535 Type::Path(syn::TypePath {
536 qself: None,
537 path: syn::Path {
538 leading_colon: None,
539 segments: Punctuated::from_iter([syn::PathSegment::from(ident.clone())]),
540 },
541 })
542}
543
544fn struct_from_variant_fields(
545 ident: Ident,
546 mut fields: Fields,
547 attrs: Vec<Attribute>,
548 vis: Visibility,
549) -> syn::ItemStruct {
550 set_fields_vis(&mut fields, &vis);
551 let span = ident.span();
552 syn::ItemStruct {
553 attrs,
554 vis,
555 struct_token: Token,
556 ident,
557 generics: Default::default(),
558 semi_token: match &fields {
559 Fields::Unit => Some(Token),
560 Fields::Unnamed(_) => Some(Token),
561 Fields::Named(_) => None,
562 },
563 fields,
564 }
565}
566
567fn single_unnamed_field(ty: Type) -> Fields {
568 let field = syn::Field {
569 attrs: vec![],
570 vis: Visibility::Inherited,
571 ident: None,
572 colon_token: None,
573 mutability: syn::FieldMutability::None,
574 ty,
575 };
576 Fields::Unnamed(syn::FieldsUnnamed {
577 paren_token: syn::token::Paren(Span::call_site()),
578 unnamed: Punctuated::from_iter([field]),
579 })
580}
581
582fn set_fields_vis(fields: &mut Fields, vis: &Visibility) {
583 let inner = match fields {
584 Fields::Named(named) => named.named.iter_mut(),
585 Fields::Unnamed(unnamed) => unnamed.unnamed.iter_mut(),
586 Fields::Unit => return,
587 };
588 for field in inner {
589 field.vis = vis.clone();
590 }
591}
592
593fn error_tokens(span: Span, message: &str) -> TokenStream {
595 Error::new(span, message).to_compile_error().into()
596}
597
598fn syn_err<T>(span: Span, message: impl std::fmt::Display) -> syn::Result<T> {
599 Err(Error::new(span, message))
600}