nodo_derive/
lib.rs

1// Copyright 2023 David Weikersdorfer
2
3use proc_macro::{Span, TokenStream};
4use quote::quote;
5use syn::{parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Meta};
6
7/// Derive macro to implement the RxBundle trait for a custom struct with Rx fields
8#[proc_macro_derive(RxBundleDerive)]
9pub fn rx_bundle_derive(input: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11    impl_rx_bundle_derive(&input)
12}
13
14fn impl_rx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
15    let name = &input.ident;
16    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
17    let name_str = name.to_string();
18
19    let fields = match &input.data {
20        Data::Struct(DataStruct {
21            fields: Fields::Named(fields),
22            ..
23        }) => &fields.named,
24        _ => panic!("expected a struct with named fields"),
25    };
26
27    let fields_count = fields.len();
28    let field_index = (0..fields.len()).collect::<Vec<_>>();
29    let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
30    let field_name_str = fields
31        .iter()
32        .map(|f| f.ident.as_ref().unwrap().to_string())
33        .collect::<Vec<_>>();
34
35    let gen = quote! {
36        impl #impl_generics nodo::channels::RxBundle for #name #type_generics #where_clause {
37            fn channel_count(&self) -> usize {
38                #fields_count
39            }
40
41            fn name(&self, index: usize) -> &str {
42                match index {
43                    #(
44                        #field_index => #field_name_str,
45                    )*
46                    _ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
47                }
48            }
49
50            fn inbox_message_count(&self, index: usize) -> usize {
51                match index {
52                    #(#field_index => self.#field_name.len(),)*
53                    _ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
54                }
55            }
56
57            fn sync_all(&mut self, results: &mut [nodo::channels::SyncResult]) {
58                use nodo::channels::Rx;
59
60                #(results[#field_index] = self.#field_name.sync();)*
61            }
62
63            fn check_connection(&self) -> nodo::channels::ConnectionCheck {
64                use nodo::channels::Rx;
65
66                let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
67                #(cc.mark(#field_index, self.#field_name.is_connected());)*
68                cc
69            }
70        }
71    };
72    gen.into()
73}
74
75/// Derive macro to implement the TxBundle trait for a custom struct with Tx fields
76#[proc_macro_derive(TxBundleDerive)]
77pub fn tx_bundle_derive(input: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79    impl_tx_bundle_derive(&input)
80}
81
82fn impl_tx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
83    let name = &input.ident;
84    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
85    let name_str = name.to_string();
86
87    let fields = match &input.data {
88        Data::Struct(DataStruct {
89            fields: Fields::Named(fields),
90            ..
91        }) => &fields.named,
92        _ => panic!("expected a struct with named fields"),
93    };
94
95    let fields_count = fields.len();
96    let field_index = (0..fields.len()).collect::<Vec<_>>();
97    let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
98    let field_name_str = fields
99        .iter()
100        .map(|f| f.ident.as_ref().unwrap().to_string())
101        .collect::<Vec<_>>();
102
103    let gen = quote! {
104        impl #impl_generics nodo::channels::TxBundle for #name #type_generics #where_clause {
105            fn channel_count(&self) -> usize {
106                #fields_count
107            }
108
109            fn name(&self, index: usize) -> &str {
110                match index {
111                    #(
112                        #field_index => #field_name_str,
113                    )*
114                    _ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
115                }
116            }
117
118            fn outbox_message_count(&self, index: usize) -> usize {
119                match index {
120                    #(#field_index => self.#field_name.len(),)*
121                    _ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
122                }
123            }
124
125            fn flush_all(&mut self, results: &mut [nodo::channels::FlushResult]) {
126                use nodo::channels::Tx;
127
128                #(results[#field_index] = self.#field_name.flush();)*
129            }
130
131            fn check_connection(&self) -> nodo::channels::ConnectionCheck {
132                use nodo::channels::Tx;
133
134                let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
135                #(cc.mark(#field_index, self.#field_name.is_connected());;)*
136                cc
137            }
138        }
139    };
140    gen.into()
141}
142
143#[proc_macro_derive(Status, attributes(label, default, skipped))]
144pub fn derive_status(input: TokenStream) -> TokenStream {
145    // Parse the input token stream (the enum)
146    let input = parse_macro_input!(input as DeriveInput);
147
148    // Get the enum name
149    let enum_name = input.ident.clone();
150
151    // Ensure we have an enum
152    let data = if let Data::Enum(DataEnum { variants, .. }) = input.data {
153        variants
154    } else {
155        return syn::Error::new_spanned(input, "Status can only be derived for enums")
156            .to_compile_error()
157            .into();
158    };
159
160    let mut default_variant = None;
161    let mut match_arms_status = Vec::new();
162    let mut match_arms_label = Vec::new();
163
164    // Iterate over each variant
165    for variant in data {
166        let variant_name = &variant.ident;
167        let mut label = None;
168        let mut is_default = false;
169        let mut is_skipped = false;
170
171        // Parse the attributes on each variant
172        for attr in variant.attrs {
173            if attr.path.is_ident("label") {
174                if let Ok(Meta::NameValue(meta_name_value)) = attr.parse_meta() {
175                    if let syn::Lit::Str(lit_str) = &meta_name_value.lit {
176                        label = Some(lit_str.value());
177                    }
178                }
179            } else if attr.path.is_ident("default") {
180                is_default = true;
181            } else if attr.path.is_ident("skipped") {
182                is_skipped = true;
183            }
184        }
185
186        // Handle different variant types (unit, tuple, and struct)
187        let pattern = match &variant.fields {
188            Fields::Unit => quote! { #enum_name::#variant_name },
189            Fields::Unnamed(_) => quote! { #enum_name::#variant_name(..) },
190            Fields::Named(_) => quote! { #enum_name::#variant_name { .. } },
191        };
192
193        // Generate match arms for as_default_status
194        let default_status = if is_skipped {
195            quote! { DefaultStatus::Skipped }
196        } else {
197            quote! { DefaultStatus::Running }
198        };
199        match_arms_status.push(quote! {
200            #pattern => #default_status,
201        });
202
203        // Generate match arms for label, defaulting to the variant's name if no label is provided
204        let label = label.unwrap_or_else(|| variant_name.to_string());
205        match_arms_label.push(quote! {
206            #pattern => #label,
207        });
208
209        // Set the default variant
210        if is_default {
211            default_variant = Some(quote! {
212                fn default_implementation_status() -> Self {
213                    #enum_name::#variant_name
214                }
215            });
216        }
217    }
218
219    // Generate the default implementation status function
220    let default_implementation_status = default_variant.unwrap_or_else(|| {
221        quote! {
222            fn default_implementation_status() -> Self {
223                compile_error!("No default status was specified. Use #[default] to choose one.");
224            }
225        }
226    });
227
228    // Generate the final implementation
229    let expanded = quote! {
230        impl CodeletStatus for #enum_name {
231            #default_implementation_status
232
233            fn is_default_status(&self) -> bool {
234                false
235            }
236
237            fn as_default_status(&self) -> DefaultStatus {
238                match self {
239                    #(#match_arms_status)*
240                }
241            }
242
243            fn label(&self) -> &'static str {
244                match self {
245                    #(#match_arms_label)*
246                }
247            }
248        }
249    };
250
251    // Convert the generated code into a token stream
252    TokenStream::from(expanded)
253}
254
/// Converts a `snake_case` name into an upper camel case name.
///
/// The input is split at underscores; the first character of every non-empty
/// segment is ASCII-uppercased and the rest of the segment is copied verbatim.
/// Underscores themselves never appear in the output, and non-ASCII characters
/// are left untouched.
fn to_camel_case(snake: &str) -> String {
    let mut camel = String::new();

    for segment in snake.split('_') {
        let mut rest = segment.chars();
        // Empty segments (leading, trailing or doubled underscores) contribute nothing.
        if let Some(head) = rest.next() {
            camel.push(head.to_ascii_uppercase());
            camel.push_str(rest.as_str());
        }
    }

    camel
}
271
272#[proc_macro_derive(Config, attributes(mutable, hidden))]
273pub fn derive_config(input: TokenStream) -> TokenStream {
274    let input = parse_macro_input!(input as DeriveInput);
275    let struct_name = input.ident;
276    let generics = input.generics;
277    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
278
279    let pk_enum_name = format!("{}ParameterKind", struct_name);
280    let pk_enum_ident = syn::Ident::new(&pk_enum_name, struct_name.span());
281
282    let aux_name = format!("{}Aux", struct_name);
283    let aux_ident = syn::Ident::new(&aux_name, struct_name.span());
284
285    let mut parameters = Vec::new();
286    let mut parameters_with_value = Vec::new();
287    let mut pk_variants = Vec::new();
288    let mut pk_variants_doc = Vec::new();
289    let mut match_arms_set = Vec::new();
290    let mut aux_match_arms = Vec::new();
291    let mut aux_fields_decl = Vec::new();
292    let mut aux_fields = Vec::new();
293    let mut pk_field_names = Vec::new();
294
295    if let Data::Struct(data_struct) = input.data {
296        if let Fields::Named(fields) = data_struct.fields {
297            for field in fields.named {
298                let field_name = field.ident.unwrap();
299                let field_name_str = field_name.to_string();
300                let field_type = field.ty;
301                let field_type_str = quote!(#field_type).to_string();
302
303                // Skip parameters with the #[hidden] attributes
304                let is_hidden = field.attrs.iter().any(|attr| attr.path.is_ident("hidden"));
305                if is_hidden {
306                    continue;
307                }
308
309                // Do not allow modification to #[mutable] attributes
310                let is_mutable = field.attrs.iter().any(|attr| attr.path.is_ident("mutable"));
311
312                // Determine if we need to wrap this type in Parameter<T>
313                let config_kind = match field_type_str.as_str() {
314                    "bool" => Some(quote!(Bool)),
315                    "i64" => Some(quote!(Int64)),
316                    "usize" => Some(quote!(Usize)),
317                    "f64" => Some(quote!(Float64)),
318                    "String" => Some(quote!(String)),
319                    "Vec < f64 >" => Some(quote!(VecFloat64)),
320                    s if s.starts_with("[f64;") => Some(quote!(VecFloat64)),
321                    _ => None,
322                };
323
324                let pk_name = to_camel_case(&field_name.to_string());
325                let pk_ident = syn::Ident::new(&pk_name, field_name.span());
326
327                // Add conversion for this field
328                if config_kind.is_some() {
329                    if is_mutable {
330                        aux_fields_decl.push(quote! {
331                            pub #field_name: ParameterAux
332                        });
333
334                        aux_fields.push(quote! {
335                            #field_name
336                        });
337                    }
338
339                    pk_variants.push(quote! {
340                        #pk_ident
341                    });
342
343                    let doc_string =
344                        format!("Parameter `{}` of type {}", field_name_str, field_type_str);
345                    pk_variants_doc.push(quote! {
346                        #doc_string
347                    });
348
349                    pk_field_names.push(quote!(
350                        #field_name_str
351                    ));
352                }
353
354                if let Some(kind) = config_kind {
355                    parameters.push(quote! {
356                        (
357                            #pk_enum_ident::#pk_ident,
358                            ParameterProperties {
359                                dtype: ParameterDataType::#kind,
360                                is_mutable: #is_mutable,
361                            }
362                        )
363                    });
364
365                    parameters_with_value.push(quote! {
366                        (
367                            #pk_enum_ident::#pk_ident,
368                            self.#field_name.clone().into(),
369                        )
370                    });
371
372                    if is_mutable {
373                        let match_arm_set = quote! {
374                            #pk_enum_ident::#pk_ident => {
375                                match value {
376                                    ParameterValue::#kind(val) => {
377                                        Ok((&mut self.#field_name, val).assign()?)
378                                    }
379                                    actual => Err(ConfigSetParameterError::InvalidType {
380                                        expected: ParameterDataType::#kind,
381                                        actual: actual.dtype(),
382                                    })
383                                }
384                            }
385                        };
386                        match_arms_set.push(match_arm_set);
387                    } else {
388                        let match_arm_set = quote! {
389                            #pk_enum_ident::#pk_ident => {
390                                Err(ConfigSetParameterError::Immutable)
391                            }
392                        };
393                        match_arms_set.push(match_arm_set);
394                    }
395
396                    if is_mutable {
397                        let aux_match_arm = quote! {
398                            #pk_enum_ident::#pk_ident => {
399                                self.#field_name.on_set_parameter(now);
400                            }
401                        };
402                        aux_match_arms.push(aux_match_arm);
403                    }
404                }
405            }
406        }
407    }
408
409    let expanded = quote! {
410        #[automatically_derived]
411        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
412        #[allow(missing_docs)]
413        pub enum #pk_enum_ident {
414            #(
415                # [doc = #pk_variants_doc]
416                #pk_variants,
417            )*
418        }
419
420        impl ConfigKind for #pk_enum_ident {
421            #[inline]
422            fn from_str(id: &str) -> Option<Self> {
423                match id {
424                    #(#pk_field_names => Some(#pk_enum_ident::#pk_variants),)*
425                    _ => None,
426                }
427            }
428
429            #[inline]
430            fn as_str(self) -> &'static str {
431                match self {
432                    #(#pk_enum_ident::#pk_variants => #pk_field_names,)*
433                }
434            }
435        }
436
437        impl #impl_generics Config for #struct_name #ty_generics #where_clause {
438            type Kind = #pk_enum_ident;
439
440            type Aux =  #aux_ident;
441
442            fn list_parameters() -> &'static [(Self::Kind, ParameterProperties)] {
443                &[#(#parameters),*]
444            }
445
446            fn set_parameter(&mut self, kind: Self::Kind, value: ParameterValue)
447                -> Result<(), ConfigSetParameterError>
448            {
449                match kind {
450                    #(#match_arms_set)*
451                }
452            }
453
454            fn get_parameters(&self) -> Vec<(Self::Kind, ParameterValue)>{
455                vec![#(#parameters_with_value),*]
456            }
457
458        }
459
460        #[automatically_derived]
461        #[derive(Default)]
462        #[allow(dead_code)]
463        #[allow(missing_docs)]
464        pub struct #aux_ident {
465            _dirty: Vec<#pk_enum_ident>,
466            #(#aux_fields_decl,)*
467        }
468
469        impl ConfigAux for #aux_ident {
470            type Kind = #pk_enum_ident;
471
472            #[inline]
473            fn dirty(&self) -> &[Self::Kind] {
474                &self._dirty
475            }
476
477            #[inline]
478            fn is_dirty(&self) -> bool {
479                !self._dirty.is_empty()
480            }
481
482            #[allow(unreachable_code)]
483            fn on_set_parameter(&mut self, kind: Self::Kind, now: Pubtime) {
484                match kind {
485                    #(#aux_match_arms)*
486                    _ => unreachable!()
487                }
488                self._dirty.push(kind);
489            }
490
491            fn on_post_step(&mut self) {
492                #(self.#aux_fields.on_post_step();)*
493                self._dirty.clear();
494            }
495        }
496    };
497
498    TokenStream::from(expanded)
499}
500
501#[proc_macro]
502pub fn signals(input: TokenStream) -> TokenStream {
503    let input_str = input.to_string();
504
505    // Basic parsing of the input: extract name and fields
506    let binding = input_str.trim();
507    let parts: Vec<_> = binding.split('{').collect();
508
509    if parts.len() != 2 {
510        return quote! {
511            compile_error!(concat!(
512                "Invalid signals! syntax. Expected: signals! { Name { field1: type1, field2: type2, ... } }"
513            ))
514        }
515        .into();
516    }
517
518    let name = parts[0].trim();
519    let mut fields_str = parts[1].trim();
520
521    // Remove the trailing }
522    assert!(fields_str.ends_with('}'));
523    fields_str = &fields_str[0..fields_str.len() - 1];
524
525    // Split by "," to extract sections for each field
526    let parts: Vec<_> = fields_str.split(',').collect();
527
528    let mut field_def = Vec::new();
529    for part in parts {
530        let mut doc_comment = String::new();
531        let mut found_field = false;
532
533        for line in part.lines() {
534            let line = line.trim();
535            if line.is_empty() {
536                continue;
537            }
538
539            if found_field {
540                eprintln!("{part:?}");
541                return quote! {
542                    compile_error!(concat!(
543                        "found line after field definition: '",
544                        #line,
545                        "'. Expected: field_name: field_type"
546                    ))
547                }
548                .into();
549            }
550
551            // Check if this is a doc comment
552            if line.starts_with("///") {
553                if !doc_comment.is_empty() {
554                    doc_comment.push('\n');
555                }
556                doc_comment.push_str(line);
557            }
558            // Check for regular comment
559            else if line.starts_with("//") {
560                // ignore
561            }
562            // Check for field definition
563            else if line.contains(':') {
564                let field_parts: Vec<&str> = line.split(':').collect();
565                if field_parts.len() != 2 {
566                    eprintln!("{part:?}");
567                    return quote! {
568                        compile_error!(concat!(
569                            "Invalid field syntax: '",
570                            #line,
571                            "'. Expected: field_name: field_type"
572                        ))
573                    }
574                    .into();
575                }
576
577                let field_name_str = field_parts[0].trim();
578                let field_type_str = field_parts[1].trim();
579
580                field_def.push((doc_comment.clone(), field_name_str, field_type_str));
581                found_field = true;
582            } else {
583                eprintln!("{part:?}");
584                return quote! {
585                    compile_error!(concat!(
586                        "Invalid field syntax: '",
587                        #line,
588                        "'. Expected: field_name: field_type"
589                    ))
590                }
591                .into();
592            }
593        }
594    }
595
596    let name_ident = syn::Ident::new(name, Span::call_site().into());
597    let pk_enum_name = format!("{}Kind", name);
598    let pk_enum_ident = syn::Ident::new(&pk_enum_name, Span::call_site().into());
599
600    // Process fields
601    let mut field_defs = Vec::new();
602    let mut signal_kinds = Vec::new();
603    let mut signal_kinds_doc = Vec::new();
604    let mut signal_name_str = Vec::new();
605    let mut signal_names = Vec::new();
606    let mut signal_kind_dtypes = Vec::new();
607
608    for (doc_comment_with_slashes, field_name_str, field_type_str) in field_def.iter() {
609        let doc_comment = if doc_comment_with_slashes.is_empty() {
610            String::new()
611        } else {
612            // Remove the "///" prefix from each line
613            doc_comment_with_slashes
614                .lines()
615                .map(|line| line.trim_start_matches("///").trim())
616                .collect::<Vec<_>>()
617                .join("\n")
618        };
619
620        let field_name = syn::Ident::new(field_name_str, Span::call_site().into());
621        let field_type = syn::parse_str::<syn::Type>(field_type_str).unwrap_or_else(|_| {
622            panic!("Could not parse type: {}", field_type_str);
623        });
624
625        field_defs.push(quote! {
626            #[doc = #doc_comment]
627            pub #field_name: SignalCell<#field_type>
628        });
629
630        // Determine signal data type
631        let signal_dtype = match *field_type_str {
632            "bool" => quote!(Bool),
633            "i64" => quote!(Int64),
634            "usize" => quote!(Usize),
635            "f64" => quote!(Float64),
636            "String" => quote!(String),
637            _ => {
638                return quote! {
639                    compile_error!(concat!(
640                        "unsupported nodo signal field type: '",
641                        #field_type_str,
642                        "'. Supported types are: bool, i64, usize, f64, String."
643                    ))
644                }
645                .into();
646            }
647        };
648
649        signal_kind_dtypes.push(signal_dtype);
650
651        let signal_kind_name = to_camel_case(field_name_str);
652        let signal_kind_ident = syn::Ident::new(&signal_kind_name, Span::call_site().into());
653        signal_kinds.push(quote! { #signal_kind_ident });
654
655        signal_kinds_doc.push(quote! { #doc_comment });
656
657        signal_name_str.push(quote! { #field_name_str });
658        signal_names.push(field_name);
659    }
660
661    // Generate the struct and implementations
662    let expanded = quote! {
663        #[automatically_derived]
664        #[allow(missing_docs)]
665        pub struct #name_ident {
666            #(#field_defs,)*
667        }
668
669        #[automatically_derived]
670        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
671        #[allow(missing_docs)]
672        pub enum #pk_enum_ident {
673            #(
674                #[doc = #signal_kinds_doc]
675                #signal_kinds,
676            )*
677        }
678
679        impl SignalKind for #pk_enum_ident {
680            #[inline]
681            fn list() -> &'static [Self] {
682                &[
683                    #(
684                        #pk_enum_ident::#signal_kinds,
685                    )*
686                ]
687            }
688
689            #[inline]
690            fn dtype(&self) -> SignalDataType {
691                match self {
692                    #(
693                        #pk_enum_ident::#signal_kinds => SignalDataType::#signal_kind_dtypes,
694                    )*
695                }
696            }
697
698            #[inline]
699            fn from_str(id: &str) -> Option<Self> {
700                match id {
701                    #(
702                        #signal_name_str => Some(#pk_enum_ident::#signal_kinds),
703                    )*
704                    _ => None,
705                }
706            }
707
708            #[inline]
709            fn as_str(&self) -> &'static str {
710                match self {
711                    #(
712                        #pk_enum_ident::#signal_kinds => #signal_name_str,
713                    )*
714                }
715            }
716        }
717
718        impl Signals for #name_ident {
719            type Kind = #pk_enum_ident;
720
721            #[inline]
722            fn as_time_value_iter(
723                    &self
724            ) -> impl Iterator<Item = Option<SignalTimeValue>> + ExactSizeIterator {
725                [
726                    #(
727                        self.#signal_names.anon_time_value(),
728                    )*
729                ].into_iter()
730            }
731
732            #[inline]
733            fn on_post_execute(&mut self, step_time: Pubtime) {
734                #(
735                    self.#signal_names.on_post_execute(step_time);
736                )*
737            }
738        }
739
740        impl Default for #name_ident {
741            fn default() -> Self {
742                Self {
743                    #(
744                        #signal_names: Default::default(),
745                    )*
746                }
747            }
748        }
749    };
750
751    TokenStream::from(expanded)
752}