crosstalk_macros/
lib.rs

1#![doc(html_root_url = "https://docs.rs/crosstalk-macros/1.0")]
2//! Macros for [`crosstalk`](https://crates.io/crates/crosstalk)
3//! 
4//! ## License
5//! 
6//! Crosstalk is released under the MIT license [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)
7// --------------------------------------------------
8// external
9// --------------------------------------------------
10use quote::{quote, format_ident};
11use syn::{
12    Data,
13    Path,
14    Type,
15    Token,
16    parse::{
17        Parse,
18        ParseStream,
19    },
20    DeriveInput,
21    parse_macro_input,
22    punctuated::Punctuated,
23};
24use proc_macro::TokenStream;
25use std::collections::HashSet;
26use proc_macro2::TokenStream as TokenStream2;
27
28#[proc_macro]
29#[inline(always)]
30/// Macro used to initialize a [`crosstalk`](https://crates.io/crates/crosstalk) node
31/// 
32/// This correlates variants in the enum (topics) with datatypes used to 
33/// communicate on said topics. These channels broadcast messages from publishers
34/// to subscribers, without catastrophic consumption of the data
35/// 
36/// Any variants missing from the macro will automatically be added
37/// using the [`String`] datatype
38/// 
39/// The enum variants and the datatypes are formatted in the following fashion:
40/// 
41/// ```rust ignore
42/// crosstalk::init!{
43///     TopicEnum::Variant1 => bool,
44///     TopicEnum::Variant2 => String,
45///     TopicEnum::Variant3 => i32,
46/// }
47/// ```
48/// 
49/// Where `TopicEnum::<VariantName>` is the name of the enum, followed by 
50/// a `=>` and the datatype to be used on that topic of communication
51/// 
52/// # Examples
53/// 
54/// ```ignore
55/// use crosstalk::AsTopic;
56/// 
57/// #[derive(AsTopic)]
58/// enum ExampleTopics {
59///     BoolChannel,
60///     StringChannel,
61///     IntChannel,
62///     MissingChannel,
63/// }
64/// 
65/// crosstalk::init!{
66///     ExampleTopics::BoolChannel => bool,
67///     ExampleTopics::StringChannel => String,
68///     ExampleTopics::IntChannel => i32,
69/// }
70/// // `ExampleTopics::MissingChannel`` will be added automatically with datatype `String``
71/// ```
72pub fn init(input: TokenStream) -> TokenStream {
73    init_inner(input, quote!(::crosstalk::))
74}
75
76#[proc_macro]
77#[inline(always)]
78/// The [`init`] macro for testing. This is meant
79/// to be internal to `crosstalk` only.
80/// 
81/// See [`init`] for actual usage
82pub fn init_test(input: TokenStream) -> TokenStream {
83    init_inner(input, quote!(crate::))
84}
85
86/// Internal implementation
87/// 
88/// This is only used for testing, so that `mod tests;` can use `crosstalk::init!`
89/// expansion, replacing all instances of `crosstalk` with `crate`
90/// 
91/// See [`init`] for the proper usage
92fn init_inner(input: TokenStream, source: TokenStream2) -> TokenStream {
93    // --------------------------------------------------
94    // parse
95    // --------------------------------------------------
96    let NodeFields(fields) = parse_macro_input!(input as NodeFields);
97    
98    // --------------------------------------------------
99    // see if there are multiple enums for topics
100    // --------------------------------------------------
101    let unique_enum_names = fields
102        .iter()
103        .map(|nf| nf
104            .topic
105            .segments
106            .first()
107            .map(|s|
108                s
109                .ident
110                .to_string()
111            ).expect("Expected enum name to be path-like")
112        )
113        .collect::<HashSet<_>>()
114        .into_iter()
115        .collect::<Vec<_>>();
116    if unique_enum_names.len() > 1 {
117        let error_desc = "Multiple topic enum's found in crosstalk node initialization";
118        let error_msg = "Please use only one enum to represent topics, and use the absolute path to the enum (e.g. `TopicEnum::MyTopic` instead of `MyTopic`)";
119        panic!("\n{}\nFound:\n{:#?}\n{}", error_desc, unique_enum_names, error_msg);
120    }
121    let enum_master = format_ident!( "{}", &unique_enum_names[0]);
122
123    // --------------------------------------------------
124    // get topic names/types
125    // --------------------------------------------------
126    let nt = fields
127        .iter()
128        .map(|nf| (nf.topic.clone(), nf.dtype.clone()))
129        .collect::<Vec<_>>();
130
131    // --------------------------------------------------
132    // default type 
133    // --------------------------------------------------
134    let dt: Type = syn::parse_quote! { String };
135
136    // --------------------------------------------------
137    // publisher arms
138    // - add default case
139    // --------------------------------------------------
140    let mut pub_arms: Vec<TokenStream2> = nt
141        .iter()
142        .map(|(n, t)| get_publisher_arm(Some(n), t, &source))
143        .collect();
144    pub_arms.push(get_publisher_arm(None, &dt, &source));
145
146    // --------------------------------------------------
147    // subscriber arms
148    // - add default case
149    // --------------------------------------------------
150    let mut sub_arms: Vec<TokenStream2> = nt
151        .iter()
152        .map(|(n, t)| get_subscriber_arm(Some(n), t, &source))
153        .collect::<Vec<_>>();
154    sub_arms.push(get_subscriber_arm(None, &dt, &source));
155
156    // --------------------------------------------------
157    // output
158    // --------------------------------------------------
159    let output: TokenStream2 = quote! {
160        #[automatically_derived]
161        impl #source CrosstalkPubSub<#enum_master> for #source ImplementedBoundedNode<#enum_master> {
162            #[doc = " Get a [`crosstalk::Publisher`] for the given topic"]
163            #[doc = ""]
164            #[doc = " See [`crosstalk::BoundedNode::publisher`] for more information"]
165            fn publisher<D: #source CrosstalkData>(&mut self, topic: #enum_master) -> Result<#source Publisher<D, #enum_master>, #source Error> {
166                match topic {
167                    #(#pub_arms,)*
168                }
169            }
170            
171            #[doc = " Get a [`crosstalk::Subscriber`] for the given topic"]
172            #[doc = ""]
173            #[doc = " See [`crosstalk::BoundedNode::subscriber`] for more information"]
174            fn subscriber<D: #source CrosstalkData>(&mut self, topic: #enum_master) -> Result<#source Subscriber<D, #enum_master>, #source Error> {
175                match topic {
176                    #(#sub_arms,)*
177                }
178            }
179            
180            #[inline(always)]
181            #[doc = " Get a [`crosstalk::Publisher`] and [`crosstalk::Subscriber`] for the given topic"]
182            #[doc = ""]
183            #[doc = " See [`crosstalk::BoundedNode::pubsub`] for more information"]
184            fn pubsub<D: #source CrosstalkData>(&mut self, topic: #enum_master) -> Result<(#source Publisher<D, #enum_master>, #source Subscriber<D, #enum_master>), #source Error> {
185                match (self.publisher(topic), self.subscriber(topic)) {
186                    (Ok(publisher), Ok(subscriber)) => Ok((publisher, subscriber)),
187                    (Err(err), _) => Err(err),
188                    (_, Err(err)) => Err(err),
189                    _ => unreachable!(),
190                }
191            }
192        }
193    };
194
195    // --------------------------------------------------
196    // return
197    // --------------------------------------------------
198    TokenStream::from(output)
199}
200
201#[proc_macro_derive(AsTopic)]
202#[inline(always)]
203/// The [`AsTopic`] derive macro
204/// 
205/// This is used to distinguish enums as "crosstalk topics"
206/// by implementing the [`crosstalk::CrosstalkTopic`] trait
207/// amongst other traits
208/// 
209/// This will automatically implement the following traits:
210/// 
211/// * Clone
212/// * Copy
213/// * PartialEq
214/// * Eq
215/// * Hash
216///
217/// # Example
218/// 
219/// ```ignore
220/// use crosstalk::AsTopic;
221/// 
222/// #[derive(AsTopic)]
223/// enum ExampleTopics {
224///     BoolChannel,
225///     StringChannel,
226///     IntChannel,
227///     MissingChannel,
228/// }
229/// ```
230pub fn derive_enum_as_topic(input: TokenStream) -> TokenStream {
231    derive_enum_as_topic_inner(input, quote!(::crosstalk::))
232}
233
234#[proc_macro_derive(AsTopicTest)]
235#[inline(always)]
236/// The [`AsTopicTest`] derive macro for testing. This is meant
237/// to be internal to crosstalk only.
238/// 
239/// See [`derive_enum_as_topic`] for actual usage
240pub fn derive_enum_as_topic_test(input: TokenStream) -> TokenStream {
241    derive_enum_as_topic_inner(input, quote!(crate::))
242}
243
244/// Internal implementation
245/// 
246/// This is only used for testing, so that `mod tests;` can implement `crosstalk::CrosstalkTopic`
247/// as `crate::CrosstalkTopic`
248/// 
249/// See [`derive_enum_as_topic`] for the proper usage
250fn derive_enum_as_topic_inner(input: TokenStream, source: TokenStream2) -> TokenStream {
251    let input = parse_macro_input!(input as DeriveInput);
252    
253    let Data::Enum(_) = &input.data else {
254        panic!("CrosstalkTopic can only be derived on enums");
255    };
256    
257    let name = &input.ident;
258
259    let expanded = quote! {
260        #[automatically_derived]
261        impl ::core::clone::Clone for #name {
262            #[inline]
263            fn clone(&self) -> #name { *self }
264        }
265
266        #[automatically_derived]
267        impl ::core::marker::Copy for #name {}
268        
269        // TODO: impl this once stable
270        // #[automatically_derived]
271        // impl ::core::marker::StructuralPartialEq for #name {}
272        
273        #[automatically_derived]
274        impl ::core::cmp::PartialEq for #name {
275            #[inline]
276            fn eq(&self, other: &#name) -> bool {
277                let __self_tag = ::std::mem::discriminant(self);
278                let __arg1_tag = ::std::mem::discriminant(other);
279                __self_tag == __arg1_tag
280            }
281        }
282        
283        #[automatically_derived]
284        impl ::core::cmp::Eq for #name { }
285        
286        #[automatically_derived]
287        impl ::core::hash::Hash for #name {
288            #[inline]
289            fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
290                let __self_tag = ::std::mem::discriminant(self);
291                ::core::hash::Hash::hash(&__self_tag, state)
292            }
293        }
294        
295        #[automatically_derived]
296        impl #source CrosstalkTopic for #name {}
297    };
298
299    TokenStream::from(expanded)
300}
301
302/// Individual field for the [`crosstalk_macros::init!`] macro
303/// 
304/// # Format
305/// 
306/// ```text
307/// `<Enum>::<Variant> => <Type>`
308/// ```
309struct NodeField {
310    topic: Path,
311    _arrow: Token![=>],
312    dtype: Type,
313}
314/// [`NodeField`] implementation of [`syn::parse::Parse`]
315impl Parse for NodeField {
316    fn parse(input: ParseStream) -> syn::Result<Self> {
317        Ok(NodeField {
318            topic: input.parse()?,
319            _arrow: input.parse()?,
320            dtype: input.parse()?,
321        })
322    }
323}
324
325/// Fields for the [`crosstalk_macros::init!`] macro
326/// 
327/// # Format
328/// 
329/// ```text
330/// crosstalk_macros::init!{
331///     `<Enum>::<Variant> => <Type>`,
332///     `<Enum>::<Variant> => <Type>`,
333/// }
334/// ```
335struct NodeFields(Punctuated<NodeField, Token![,]>);
336
337/// [`NodeFields`] implementation of [`syn::parse::Parse`]
338impl Parse for NodeFields {
339    fn parse(input: ParseStream) -> syn::Result<Self> {
340        let content = Punctuated::<NodeField, Token![,]>::parse_terminated(input)?;
341        Ok(NodeFields(content))
342    }
343}
344
345/// Get publisher arm (used in type-matching within the [`crosstalk_macros::init!`] macro)
346/// 
347/// This helps fill in the `match` statement in the [`crosstalk_macros::init!`] macro
348/// with all the arms that are valid for a given topic and datatype
349fn get_publisher_arm(case: Option<&Path>, dtype: &Type, source: &TokenStream2) -> TokenStream2 {
350    let contents = quote! {
351        => {
352            let err = #source Error::PublisherMismatch(
353                ::std::any::type_name::<D>(),
354                ::std::any::type_name::<#dtype>(),
355            );
356            if ::std::any::TypeId::of::<D>()
357            == ::std::any::TypeId::of::<#dtype>() {
358                // --------------------------------------------------
359                // if the datatype matches, get the tokio brdcst sender
360                // to create the subscriber
361                // --------------------------------------------------
362                let tsen = match self.senders.contains_key(&topic) {
363                    true => {
364                        #[allow(clippy::unwrap_used)]
365                        // can't use .get here because it returns a reference,
366                        // and need to consume the value in order to downcast
367                        // it. as a result, .contains_key() is used in junction
368                        // with .remove.unwrap()
369                        let tsen_ = self.senders.remove(&topic).unwrap();
370                        let tsen_ = #source __macro_exports::downcast::<#source __macro_exports::broadcast::Sender<#dtype>>(tsen_, err)?;
371                        let tsen = tsen_.clone();
372                        self.senders.insert(topic, Box::new(tsen_));
373                        tsen
374                    },
375                    false => {
376                        // size is defined during crosstalk::BoundedNode::new(size)
377                        let (sender, _) =  #source __macro_exports::broadcast::channel::<#dtype>(self.size);
378                        self.senders.insert(topic, Box::new(sender.clone()));
379                        sender
380                    },
381                };
382                // --------------------------------------------------
383                // create and return publisher
384                // --------------------------------------------------
385                // this downcasts from #dtype -> D. These are the same type,
386                // due to check made above
387                // --------------------------------------------------
388                let sender = #source __macro_exports::downcast::<#source __macro_exports::broadcast::Sender<D>>(Box::new(tsen), err)?;
389                Ok(#source Publisher::new(topic, sender))
390            } else {
391                // --------------------------------------------------
392                // if the datatype does not match, return an error
393                // --------------------------------------------------
394                Err(err)
395            }
396        }
397    };
398    // --------------------------------------------------
399    // if arm not specified, use default (_)
400    // --------------------------------------------------
401    match case {
402        Some(case) => quote! { #case #contents },
403        None => quote! { _ #contents }
404    }
405}
406
407/// Get subscriber arm (used in type-matching within the [`crosstalk_macros::init!`] macro)
408/// 
409/// This helps fill in the `match` statement in the [`crosstalk_macros::init!`] macro
410/// with all the arms that are valid for a given topic and datatype
411fn get_subscriber_arm(case: Option<&Path>, dtype: &Type, source: &TokenStream2) -> TokenStream2 {
412    let contents = quote! {
413        => {
414            let err = #source Error::SubscriberMismatch(
415                ::std::any::type_name::<D>(),
416                ::std::any::type_name::<#dtype>(),
417            );
418            if ::std::any::TypeId::of::<D>()
419            == ::std::any::TypeId::of::<#dtype>() {
420                // --------------------------------------------------
421                // if the datatype matches, get the tokio brdcst sender
422                // to create the subscriber
423                // --------------------------------------------------
424                let tsen = match self.senders.contains_key(&topic) {
425                    true => {
426                        #[allow(clippy::unwrap_used)]
427                        // can't use .get here because it returns a reference,
428                        // and need to consume the value in order to downcast
429                        // it. as a result, .contains_key() is used in junction
430                        // with .remove.unwrap()
431                        let tsen_ = self.senders.remove(&topic).unwrap();
432                        let tsen_ = #source __macro_exports::downcast::<#source __macro_exports::broadcast::Sender<#dtype>>(tsen_, err)?;
433                        let tsen = tsen_.clone();
434                        self.senders.insert(topic, Box::new(tsen_));
435                        tsen
436                    },
437                    false => {
438                        // size is defined during crosstalk::BoundedNode::new(size)
439                        let (sender, _) =  #source __macro_exports::broadcast::channel::<#dtype>(self.size);
440                        self.senders.insert(topic, Box::new(sender.clone()));
441                        sender
442                    },
443                };
444                // --------------------------------------------------
445                // create and return subscriber
446                // --------------------------------------------------
447                // this downcasts from #dtype -> D. These are the same type,
448                // due to check made above
449                // --------------------------------------------------
450                let sender = #source __macro_exports::downcast::<#source __macro_exports::broadcast::Sender<D>>(Box::new(tsen), err)?;
451                Ok(#source Subscriber::new(topic, None, ::std::sync::Arc::new(sender)))
452            } else {
453                // --------------------------------------------------
454                // if the datatype does not match, return an error
455                // --------------------------------------------------
456                Err(err)
457            }
458        }
459    };
460    // --------------------------------------------------
461    // if arm not specified, use default (_)
462    // --------------------------------------------------
463    match case {
464        Some(case) => quote! { #case #contents },
465        None => quote! { _ #contents }
466    }
467}