rtactor_macros/
lib.rs

1//! Proc macros for the rtactor library.
2//!
3//! # Generate a Response enum from a Request enum with `derive(ResponseEnum)`
4//! ```rs
5//! #[derive(ResponseEnum)]
6//! pub enum Request {
7//!     SetValue{val: i32},
8//!    
9//!     #[response_val(i32)]
10//!     GetValue{},
11//! }
12//! ```
13//!
14//! Will generate:
//! ```rs
16//! pub enum Response
17//! {
18//!     SetValue(),
19//!     GetValue(i32)
20//! }
21//! ```
22//!
23//! # Generate a synchronous access trait from a Notification enum with `SyncNotifier`
24//!
25//! ```rs
26//! #[derive(SyncNotifier)]
27//! pub enum Notification {
//!     TemperatureChanged{temp: f32}
29//! }
30//! ```
31//!
32//! Will generate:
33//!
34//! ```rs
//! pub trait SyncNotifier : ::rtactor::SyncAccessor
//! {
//!     fn temperature_changed(&mut self, temp: f32) -> Result<(), ::rtactor::Error> {[...]}
38//! }
39//! ```
40//!
//! A structure gets the generated methods by implementing the generated `SyncNotifier`
//! trait; since every generated method has a default body, only the methods of the
//! supertrait `SyncAccessor` have to be provided. The macro `define_sync_accessor!()`
//! found in crate `rtactor` can be used to generate a struct that provides them through
//! its internal ActiveMailbox:
45//! ```rs
//! define_sync_accessor!(MyNotifSyncAccessor, SyncNotifier);
//!
//! fn test(addr: rtactor::Addr)
//! {
//!     let mut accessor = MyNotifSyncAccessor::new(&addr);
//!     accessor.temperature_changed(13.2f32).unwrap();
53//! }
54//! ```
55//!
56//! # Generate a synchronous access trait from a Request enum with `SyncRequester`
57//!
58//! ```rs
59//! #[derive(ResponseEnum, SyncRequester)]
60//! pub enum Request {
61//!     SetValue{val: i32},
62//!    
63//!     #[response_val(i32)]
64//!     GetValue{},
65//! }
66//! ```
67//!
68//! Will generate for the `SyncRequester` part:
69//! ```rs
70//! pub trait SyncRequester: ::rtactor::SyncAccessor
71//! {
//!     fn set_value(&mut self, val: i32, duration: std::time::Duration) -> Result<(), ::rtactor::Error> {[...]}
//!     fn get_value(&mut self, duration: std::time::Duration) -> Result<i32, ::rtactor::Error> {[...]}
74//! }
75//! ```
76//!
//! A structure gets the generated methods by implementing the generated `SyncRequester`
//! trait; only the methods of the supertrait `SyncAccessor` have to be provided. Every
//! generated method takes a trailing `duration` argument that is forwarded to `request_for`.
//! The macro `define_sync_accessor!()` found in crate `rtactor` can be used to generate a
//! struct that provides them through its internal ActiveMailbox:
81//! ```rs
//! define_sync_accessor!(MySyncAccessor, SyncNotifier, SyncRequester);
//!
//! fn test(addr: rtactor::Addr)
//! {
//!     let mut accessor = MySyncAccessor::new(&addr);
//!     let timeout = std::time::Duration::from_secs(1);
//!     accessor.temperature_changed(13.2f32).unwrap();
//!     accessor.set_value(72, timeout).unwrap();
//!     assert!(accessor.get_value(timeout).unwrap() == 72);
//! }
90//! }
91//! ```
92//!
93
94extern crate proc_macro;
95
96use convert_case::{Case, Casing};
97use proc_macro::TokenStream;
98use proc_macro2::TokenStream as TokenStream2;
99use quote::{format_ident, quote, ToTokens};
100use syn::{parse_macro_input, Data, DeriveInput, Fields};
101
/// Debug switch: when set to `true`, each derive macro of this crate prints the code it
/// generated to stdout (as `expanded='...'`) while it is being expanded, i.e. in the
/// compiler's output when building the crate that uses the derive.
const PRINT_GENERATED_MACRO_CODE: bool = false;
104
105// see for attributes():
106// https://stackoverflow.com/questions/42484062/how-do-i-process-enum-struct-field-attributes-in-a-procedural-macro
107#[proc_macro_derive(ResponseEnum, attributes(response_val))]
108pub fn derive_response_enum(input: TokenStream) -> TokenStream {
109    // See https://doc.servo.org/syn/derive/struct.DeriveInput.html
110    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
111
112    // get enum name
113    let enum_name = &input.ident;
114    let data = &input.data;
115
116    let mut response_variants;
117
118    // data is of type syn::Data
119    // See https://doc.servo.org/syn/enum.Data.html
120    match data {
121        // Only if data is an enum, we do parsing
122        Data::Enum(data_enum) => {
123            // data_enum is of type syn::DataEnum
124            // https://doc.servo.org/syn/struct.DataEnum.html
125
126            response_variants = TokenStream2::new();
127
128            // Iterate over enum variants
129            // `variants` if of type `Punctuated` which implements IntoIterator
130            //
131            // https://doc.servo.org/syn/punctuated/struct.Punctuated.html
132            // https://doc.servo.org/syn/struct.Variant.html
133            for variant in &data_enum.variants {
134                // Variant's name
135                let variant_name = &variant.ident;
136
137                // construct an identifier named <variant_name> for function name
138                // We convert it to snake case using `to_case(Case::Snake)`
139                // For example, if variant is `HelloWorld`, it will generate `is_hello_world`
140                let mut request_func_name =
141                    format_ident!("{}", variant_name.to_string().to_case(Case::Snake));
142                request_func_name.set_span(variant_name.span());
143
144                if let Some(ref a) = variant.attrs.iter().find(|a| match a.path.get_ident() {
145                    Some(ident) => ident == "response_val",
146                    None => false,
147                }) {
148                    if let Ok(response_val_type) = a.parse_args::<syn::Type>() {
149                        let response_val_type = response_val_type.to_token_stream();
150
151                        response_variants.extend(quote!(
152                            #variant_name(#response_val_type),
153                        ));
154                    } else if a.parse_args::<syn::parse::Nothing>().is_ok() {
155                        response_variants.extend(quote!(
156                            #variant_name(),
157                        ));
158                    } else {
159                        panic!(
160                            "attribute '{}' parsing failed for variant '{}'",
161                            a.to_token_stream(),
162                            variant_name
163                        );
164                    }
165                } else {
166                    response_variants.extend(quote!(
167                        #variant_name(),
168                    ));
169                };
170            }
171        }
172        _ => panic!(
173            "ResponseEnum is only valid for enums and '{}' is not one.",
174            enum_name
175        ),
176    };
177
178    let response_enum_name = format_ident!("{}", "Response");
179
180    let expanded = quote! {
181        pub enum #response_enum_name {
182            #response_variants
183        }
184    };
185
186    if PRINT_GENERATED_MACRO_CODE {
187        println!("expanded='{}'", expanded);
188    }
189    TokenStream::from(expanded)
190}
191
192#[proc_macro_derive(SyncNotifier)]
193pub fn derive_sync_notifier(input: TokenStream) -> TokenStream {
194    // See https://doc.servo.org/syn/derive/struct.DeriveInput.html
195    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
196
197    // get enum name
198    let enum_name = &input.ident;
199    let data = &input.data;
200
201    let mut variant_notifier_functions;
202
203    // data is of type syn::Data
204    // See https://doc.servo.org/syn/enum.Data.html
205    match data {
206        // Only if data is an enum, we do parsing
207        Data::Enum(data_enum) => {
208            // data_enum is of type syn::DataEnum
209            // https://doc.servo.org/syn/struct.DataEnum.html
210
211            variant_notifier_functions = TokenStream2::new();
212
213            // Iterate over enum variants
214            // `variants` if of type `Punctuated` which implements IntoIterator
215            //
216            // https://doc.servo.org/syn/punctuated/struct.Punctuated.html
217            // https://doc.servo.org/syn/struct.Variant.html
218            for variant in &data_enum.variants {
219                // Variant's name
220                let variant_name = &variant.ident;
221
222                // construct an identifier named <variant_name> for function name
223                // We convert it to snake case using `to_case(Case::Snake)`
224                // For example, if variant is `HelloWorld`, it will generate `is_hello_world`
225                let mut notify_func_name =
226                    format_ident!("{}", variant_name.to_string().to_case(Case::Snake));
227                notify_func_name.set_span(variant_name.span());
228
229                // Variant can have unnamed fields like `Variant(i32, i64)`
230                // Variant can have named fields like `Variant {x: i32, y: i32}`
231                // Variant can be named Unit like `Variant`
232                match &variant.fields {
233                    Fields::Named(fields) => {
234                        let field_name: Vec<_> =
235                            fields.named.iter().map(|field| &field.ident).collect();
236                        let field_type: Vec<_> =
237                            fields.named.iter().map(|field| &field.ty).collect();
238
239                        variant_notifier_functions.extend(quote!(
240                            fn #notify_func_name( &mut self, #( #field_name : #field_type, )*) -> Result<(), ::rtactor::Error> {
241                                self.send_notification(#enum_name::#variant_name { #( #field_name : #field_name, )*})
242                            }
243                        ));
244                    }
245                    Fields::Unnamed(_) =>
246                    panic!("SyncNotifier is not valid for Unnamed variant like '{}', use Named (i.e. 'MyVariant{{arg1: i32}}') variant.", variant_name),
247                    Fields::Unit => {
248                        variant_notifier_functions.extend(quote!(
249                            fn #notify_func_name(&mut self) -> Result<(), ::rtactor::Error> {
250                                self.send_notification(#enum_name::#variant_name)
251                            }
252                        ));
253                    }
254                };
255            }
256        }
257        _ => panic!(
258            "SyncNotifier is only valid for enums and '{}' is not one.",
259            enum_name
260        ),
261    };
262
263    let trait_name = format_ident!("{}", "SyncNotifier");
264
265    let expanded = quote! {
266        pub trait #trait_name : ::rtactor::SyncAccessor {
267            #variant_notifier_functions
268        }
269    };
270
271    if PRINT_GENERATED_MACRO_CODE {
272        println!("expanded='{}'", expanded);
273    }
274    TokenStream::from(expanded)
275}
276
277// see for attributes():
278// https://stackoverflow.com/questions/42484062/how-do-i-process-enum-struct-field-attributes-in-a-procedural-macro
279#[proc_macro_derive(SyncRequester, attributes(response_val))]
280pub fn derive_sync_requester(input: TokenStream) -> TokenStream {
281    // See https://doc.servo.org/syn/derive/struct.DeriveInput.html
282    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
283
284    // get enum name
285    let enum_name = &input.ident;
286    let data = &input.data;
287
288    let mut variant_requester_functions;
289
290    // data is of type syn::Data
291    // See https://doc.servo.org/syn/enum.Data.html
292    match data {
293        // Only if data is an enum, we do parsing
294        Data::Enum(data_enum) => {
295            // data_enum is of type syn::DataEnum
296            // https://doc.servo.org/syn/struct.DataEnum.html
297
298            variant_requester_functions = TokenStream2::new();
299
300            // Iterate over enum variants
301            // `variants` if of type `Punctuated` which implements IntoIterator
302            //
303            // https://doc.servo.org/syn/punctuated/struct.Punctuated.html
304            // https://doc.servo.org/syn/struct.Variant.html
305            for variant in &data_enum.variants {
306                // Variant's name
307                let variant_name = &variant.ident;
308
309                // construct an identifier named <variant_name> for function name
310                // We convert it to snake case using `to_case(Case::Snake)`
311                // For example, if variant is `HelloWorld`, it will generate `is_hello_world`
312                let mut request_func_name =
313                    format_ident!("{}", variant_name.to_string().to_case(Case::Snake));
314                request_func_name.set_span(variant_name.span());
315
316                let return_type = if let Some(ref a) =
317                    variant.attrs.iter().find(|a| match a.path.get_ident() {
318                        Some(ident) => ident == "response_val",
319                        None => false,
320                    }) {
321                    if let Ok(types) = a.parse_args::<syn::Type>() {
322                        Some(types)
323                    } else if a.parse_args::<syn::parse::Nothing>().is_ok() {
324                        None
325                    } else {
326                        panic!(
327                            "attribute '{}' parsing failed for variant '{}'",
328                            a.to_token_stream(),
329                            variant_name
330                        );
331                    }
332                } else {
333                    None
334                };
335
336                let method_return_type = match return_type.clone() {
337                    Some(ret_type) => {
338                        let token_stream = ret_type.into_token_stream();
339                        quote!(#token_stream)
340                    }
341                    None => quote!(()),
342                };
343
344                let ok_var_name = match return_type.clone() {
345                    Some(_) => quote!(variant_data),
346                    None => quote!(),
347                };
348
349                let ok_ret_value = match return_type.clone() {
350                    Some(_) => quote!(variant_data),
351                    None => quote!(()),
352                };
353
354                // Variant can have unnamed fields like `Variant(i32, i64)`
355                // Variant can have named fields like `Variant {x: i32, y: i32}`
356                // Variant can be named Unit like `Variant`
357                match &variant.fields {
358                    Fields::Named(fields) => {
359                        let field_name: Vec<_> =
360                            fields.named.iter().map(|field| &field.ident).collect();
361                        let field_type: Vec<_> =
362                            fields.named.iter().map(|field| &field.ty).collect();
363
364                        variant_requester_functions.extend(quote!(
365                            fn #request_func_name( &mut self, #( #field_name : #field_type, )* duration: std::time::Duration) -> Result<#method_return_type, ::rtactor::Error> {
366                                match self.request_for::<#enum_name, Response>(#enum_name::#variant_name { #( #field_name : #field_name, )*}, duration)
367                                {
368                                    Ok(Response::#variant_name(#ok_var_name)) => Ok(#ok_ret_value),
369                                    Ok(_) => Err(::rtactor::Error::DowncastFailed),
370                                    Err(err) => Err(err),
371                                }
372                            }
373                        ));
374                    }
375                    Fields::Unnamed(_) => {
376                        panic!(
377                            "SyncRequester do not accept Unnamed variant and '{}' is one.",
378                            variant_name
379                        );
380                    }
381                    Fields::Unit => {
382                        variant_requester_functions.extend(quote!(
383                            fn #request_func_name( &mut self, duration: std::time::Duration) -> Result<#method_return_type, ::rtactor::Error> {
384                                match self.request_for::<#enum_name, Response>(Request::#variant_name, duration)
385                                {
386                                    Ok(Response::#variant_name(#ok_var_name)) => Ok(#ok_ret_value),
387                                    Ok(_) => Err(::rtactor::Error::DowncastFailed),
388                                    Err(err) => Err(err),
389                                }
390                            }
391                        ));
392                    }
393                };
394            }
395        }
396        _ => panic!(
397            "SyncRequester is only valid for enums and '{}' is not one.",
398            enum_name
399        ),
400    };
401
402    let trait_name = format_ident!("{}", "SyncRequester");
403
404    let expanded = quote! {
405        pub trait #trait_name : ::rtactor::SyncAccessor {
406            #variant_requester_functions
407        }
408    };
409
410    if PRINT_GENERATED_MACRO_CODE {
411        println!("expanded='{}'", expanded);
412    }
413    TokenStream::from(expanded)
414}
415
416// see for attributes():
417// https://stackoverflow.com/questions/42484062/how-do-i-process-enum-struct-field-attributes-in-a-procedural-macro
418#[proc_macro_derive(AsyncRequester, attributes(response_val))]
419pub fn derive_async_requester(input: TokenStream) -> TokenStream {
420    // See https://doc.servo.org/syn/derive/struct.DeriveInput.html
421    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
422
423    // get enum name
424    let enum_name = &input.ident;
425    let data = &input.data;
426
427    let mut variant_requester_functions;
428
429    // data is of type syn::Data
430    // See https://doc.servo.org/syn/enum.Data.html
431    match data {
432        // Only if data is an enum, we do parsing
433        Data::Enum(data_enum) => {
434            // data_enum is of type syn::DataEnum
435            // https://doc.servo.org/syn/struct.DataEnum.html
436
437            variant_requester_functions = TokenStream2::new();
438
439            // Iterate over enum variants
440            // `variants` if of type `Punctuated` which implements IntoIterator
441            //
442            // https://doc.servo.org/syn/punctuated/struct.Punctuated.html
443            // https://doc.servo.org/syn/struct.Variant.html
444            for variant in &data_enum.variants {
445                // Variant's name
446                let variant_name = &variant.ident;
447
448                // construct an identifier named <variant_name> for function name
449                // We convert it to snake case using `to_case(Case::Snake)`
450                // For example, if variant is `HelloWorld`, it will generate `is_hello_world`
451                let mut request_func_name =
452                    format_ident!("{}", variant_name.to_string().to_case(Case::Snake));
453                request_func_name.set_span(variant_name.span());
454
455                let return_type = if let Some(ref a) =
456                    variant.attrs.iter().find(|a| match a.path.get_ident() {
457                        Some(ident) => ident == "response_val",
458                        None => false,
459                    }) {
460                    if let Ok(types) = a.parse_args::<syn::Type>() {
461                        Some(types)
462                    } else if a.parse_args::<syn::parse::Nothing>().is_ok() {
463                        None
464                    } else {
465                        panic!(
466                            "attribute '{}' parsing failed for variant '{}'",
467                            a.to_token_stream(),
468                            variant_name
469                        );
470                    }
471                } else {
472                    None
473                };
474
475                let method_return_type = match return_type.clone() {
476                    Some(ret_type) => {
477                        let token_stream = ret_type.into_token_stream();
478                        quote!(#token_stream)
479                    }
480                    None => quote!(()),
481                };
482
483                let ok_var_name = match return_type.clone() {
484                    Some(_) => quote!(variant_data),
485                    None => quote!(),
486                };
487
488                let ok_ret_value = match return_type.clone() {
489                    Some(_) => quote!(variant_data),
490                    None => quote!(()),
491                };
492
493                // Variant can have unnamed fields like `Variant(i32, i64)`
494                // Variant can have named fields like `Variant {x: i32, y: i32}`
495                // Variant can be named Unit like `Variant`
496                match &variant.fields {
497                    Fields::Named(fields) => {
498                        let field_name: Vec<_> =
499                            fields.named.iter().map(|field| &field.ident).collect();
500                        let field_type: Vec<_> =
501                            fields.named.iter().map(|field| &field.ty).collect();
502
503                        variant_requester_functions.extend(quote!(
504                            async fn #request_func_name( &mut self, #( #field_name : #field_type, )* duration: std::time::Duration) -> Result<#method_return_type, ::rtactor::Error> {
505                                match self.request_for::<#enum_name, Response>(#enum_name::#variant_name { #( #field_name : #field_name, )*}, duration).await
506                                {
507                                    Ok(Response::#variant_name(#ok_var_name)) => Ok(#ok_ret_value),
508                                    Ok(_) => Err(::rtactor::Error::DowncastFailed),
509                                    Err(err) => Err(err),
510                                }
511                            }
512                        ));
513                    }
514                    Fields::Unnamed(_) => {
515                        panic!(
516                            "AsyncRequester do not accept Unnamed variant and '{}' is one.",
517                            variant_name
518                        );
519                    }
520                    Fields::Unit => {
521                        variant_requester_functions.extend(quote!(
522                            async fn #request_func_name( &mut self, duration: std::time::Duration) -> Result<#method_return_type, ::rtactor::Error> {
523                                match self.request_for::<#enum_name, Response>(Request::#variant_name, duration).await
524                                {
525                                    Ok(Response::#variant_name(#ok_var_name)) => Ok(#ok_ret_value),
526                                    Ok(_) => Err(::rtactor::Error::DowncastFailed),
527                                    Err(err) => Err(err),
528                                }
529                            }
530                        ));
531                    }
532                };
533            }
534        }
535        _ => panic!(
536            "AsyncRequester is only valid for enums and '{}' is not one.",
537            enum_name
538        ),
539    };
540
541    let trait_name = format_ident!("{}", "AsyncRequester");
542
543    let expanded = quote! {
544        pub trait #trait_name : ::rtactor::AsyncAccessor {
545            #variant_requester_functions
546        }
547    };
548
549    if PRINT_GENERATED_MACRO_CODE {
550        println!("expanded='{}'", expanded);
551    }
552    TokenStream::from(expanded)
553}