1use proc_macro::{Span, TokenStream};
4use quote::quote;
5use syn::{parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Meta};
6
7#[proc_macro_derive(RxBundleDerive)]
9pub fn rx_bundle_derive(input: TokenStream) -> TokenStream {
10 let input = parse_macro_input!(input as DeriveInput);
11 impl_rx_bundle_derive(&input)
12}
13
14fn impl_rx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
15 let name = &input.ident;
16 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
17 let name_str = name.to_string();
18
19 let fields = match &input.data {
20 Data::Struct(DataStruct {
21 fields: Fields::Named(fields),
22 ..
23 }) => &fields.named,
24 _ => panic!("expected a struct with named fields"),
25 };
26
27 let fields_count = fields.len();
28 let field_index = (0..fields.len()).collect::<Vec<_>>();
29 let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
30 let field_name_str = fields
31 .iter()
32 .map(|f| f.ident.as_ref().unwrap().to_string())
33 .collect::<Vec<_>>();
34
35 let gen = quote! {
36 impl #impl_generics nodo::channels::RxBundle for #name #type_generics #where_clause {
37 fn channel_count(&self) -> usize {
38 #fields_count
39 }
40
41 fn name(&self, index: usize) -> &str {
42 match index {
43 #(
44 #field_index => #field_name_str,
45 )*
46 _ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
47 }
48 }
49
50 fn inbox_message_count(&self, index: usize) -> usize {
51 match index {
52 #(#field_index => self.#field_name.len(),)*
53 _ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
54 }
55 }
56
57 fn sync_all(&mut self, results: &mut [nodo::channels::SyncResult]) {
58 use nodo::channels::Rx;
59
60 #(results[#field_index] = self.#field_name.sync();)*
61 }
62
63 fn check_connection(&self) -> nodo::channels::ConnectionCheck {
64 use nodo::channels::Rx;
65
66 let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
67 #(cc.mark(#field_index, self.#field_name.is_connected());)*
68 cc
69 }
70 }
71 };
72 gen.into()
73}
74
75#[proc_macro_derive(TxBundleDerive)]
77pub fn tx_bundle_derive(input: TokenStream) -> TokenStream {
78 let input = parse_macro_input!(input as DeriveInput);
79 impl_tx_bundle_derive(&input)
80}
81
82fn impl_tx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
83 let name = &input.ident;
84 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
85 let name_str = name.to_string();
86
87 let fields = match &input.data {
88 Data::Struct(DataStruct {
89 fields: Fields::Named(fields),
90 ..
91 }) => &fields.named,
92 _ => panic!("expected a struct with named fields"),
93 };
94
95 let fields_count = fields.len();
96 let field_index = (0..fields.len()).collect::<Vec<_>>();
97 let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
98 let field_name_str = fields
99 .iter()
100 .map(|f| f.ident.as_ref().unwrap().to_string())
101 .collect::<Vec<_>>();
102
103 let gen = quote! {
104 impl #impl_generics nodo::channels::TxBundle for #name #type_generics #where_clause {
105 fn channel_count(&self) -> usize {
106 #fields_count
107 }
108
109 fn name(&self, index: usize) -> &str {
110 match index {
111 #(
112 #field_index => #field_name_str,
113 )*
114 _ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
115 }
116 }
117
118 fn outbox_message_count(&self, index: usize) -> usize {
119 match index {
120 #(#field_index => self.#field_name.len(),)*
121 _ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
122 }
123 }
124
125 fn flush_all(&mut self, results: &mut [nodo::channels::FlushResult]) {
126 use nodo::channels::Tx;
127
128 #(results[#field_index] = self.#field_name.flush();)*
129 }
130
131 fn check_connection(&self) -> nodo::channels::ConnectionCheck {
132 use nodo::channels::Tx;
133
134 let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
135 #(cc.mark(#field_index, self.#field_name.is_connected());;)*
136 cc
137 }
138 }
139 };
140 gen.into()
141}
142
143#[proc_macro_derive(Status, attributes(label, default, skipped))]
144pub fn derive_status(input: TokenStream) -> TokenStream {
145 let input = parse_macro_input!(input as DeriveInput);
147
148 let enum_name = input.ident.clone();
150
151 let data = if let Data::Enum(DataEnum { variants, .. }) = input.data {
153 variants
154 } else {
155 return syn::Error::new_spanned(input, "Status can only be derived for enums")
156 .to_compile_error()
157 .into();
158 };
159
160 let mut default_variant = None;
161 let mut match_arms_status = Vec::new();
162 let mut match_arms_label = Vec::new();
163
164 for variant in data {
166 let variant_name = &variant.ident;
167 let mut label = None;
168 let mut is_default = false;
169 let mut is_skipped = false;
170
171 for attr in variant.attrs {
173 if attr.path.is_ident("label") {
174 if let Ok(Meta::NameValue(meta_name_value)) = attr.parse_meta() {
175 if let syn::Lit::Str(lit_str) = &meta_name_value.lit {
176 label = Some(lit_str.value());
177 }
178 }
179 } else if attr.path.is_ident("default") {
180 is_default = true;
181 } else if attr.path.is_ident("skipped") {
182 is_skipped = true;
183 }
184 }
185
186 let pattern = match &variant.fields {
188 Fields::Unit => quote! { #enum_name::#variant_name },
189 Fields::Unnamed(_) => quote! { #enum_name::#variant_name(..) },
190 Fields::Named(_) => quote! { #enum_name::#variant_name { .. } },
191 };
192
193 let default_status = if is_skipped {
195 quote! { DefaultStatus::Skipped }
196 } else {
197 quote! { DefaultStatus::Running }
198 };
199 match_arms_status.push(quote! {
200 #pattern => #default_status,
201 });
202
203 let label = label.unwrap_or_else(|| variant_name.to_string());
205 match_arms_label.push(quote! {
206 #pattern => #label,
207 });
208
209 if is_default {
211 default_variant = Some(quote! {
212 fn default_implementation_status() -> Self {
213 #enum_name::#variant_name
214 }
215 });
216 }
217 }
218
219 let default_implementation_status = default_variant.unwrap_or_else(|| {
221 quote! {
222 fn default_implementation_status() -> Self {
223 compile_error!("No default status was specified. Use #[default] to choose one.");
224 }
225 }
226 });
227
228 let expanded = quote! {
230 impl CodeletStatus for #enum_name {
231 #default_implementation_status
232
233 fn is_default_status(&self) -> bool {
234 false
235 }
236
237 fn as_default_status(&self) -> DefaultStatus {
238 match self {
239 #(#match_arms_status)*
240 }
241 }
242
243 fn label(&self) -> &'static str {
244 match self {
245 #(#match_arms_label)*
246 }
247 }
248 }
249 };
250
251 TokenStream::from(expanded)
253}
254
/// Converts a `snake_case` identifier into `CamelCase`.
///
/// The first character, and every character that follows one or more underscores, is
/// upper-cased. Underscores are dropped and all other characters are kept as-is, so
/// `fooBar_baz` becomes `FooBarBaz`.
///
/// A leading raw-identifier marker `r#` is stripped first, so `r#type` yields `Type`. Without
/// this, the result `R#type` is not a valid identifier and `syn::Ident::new` panics on it.
/// Upper-casing is Unicode-aware and may expand one character into several (`ß` -> `SS`).
fn to_camel_case(snake: &str) -> String {
    // `#` cannot appear inside an identifier, so this prefix only ever comes from raw idents.
    let snake = snake.strip_prefix("r#").unwrap_or(snake);

    let mut result = String::with_capacity(snake.len());
    let mut capitalize_next = true;

    for c in snake.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
271
272#[proc_macro_derive(Config, attributes(mutable, hidden))]
273pub fn derive_config(input: TokenStream) -> TokenStream {
274 let input = parse_macro_input!(input as DeriveInput);
275 let struct_name = input.ident;
276 let generics = input.generics;
277 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
278
279 let pk_enum_name = format!("{}ParameterKind", struct_name);
280 let pk_enum_ident = syn::Ident::new(&pk_enum_name, struct_name.span());
281
282 let aux_name = format!("{}Aux", struct_name);
283 let aux_ident = syn::Ident::new(&aux_name, struct_name.span());
284
285 let mut parameters = Vec::new();
286 let mut parameters_with_value = Vec::new();
287 let mut pk_variants = Vec::new();
288 let mut pk_variants_doc = Vec::new();
289 let mut match_arms_set = Vec::new();
290 let mut aux_match_arms = Vec::new();
291 let mut aux_fields_decl = Vec::new();
292 let mut aux_fields = Vec::new();
293 let mut pk_field_names = Vec::new();
294
295 if let Data::Struct(data_struct) = input.data {
296 if let Fields::Named(fields) = data_struct.fields {
297 for field in fields.named {
298 let field_name = field.ident.unwrap();
299 let field_name_str = field_name.to_string();
300 let field_type = field.ty;
301 let field_type_str = quote!(#field_type).to_string();
302
303 let is_hidden = field.attrs.iter().any(|attr| attr.path.is_ident("hidden"));
305 if is_hidden {
306 continue;
307 }
308
309 let is_mutable = field.attrs.iter().any(|attr| attr.path.is_ident("mutable"));
311
312 let config_kind = match field_type_str.as_str() {
314 "bool" => Some(quote!(Bool)),
315 "i64" => Some(quote!(Int64)),
316 "usize" => Some(quote!(Usize)),
317 "f64" => Some(quote!(Float64)),
318 "String" => Some(quote!(String)),
319 "Vec < f64 >" => Some(quote!(VecFloat64)),
320 s if s.starts_with("[f64;") => Some(quote!(VecFloat64)),
321 _ => None,
322 };
323
324 let pk_name = to_camel_case(&field_name.to_string());
325 let pk_ident = syn::Ident::new(&pk_name, field_name.span());
326
327 if config_kind.is_some() {
329 if is_mutable {
330 aux_fields_decl.push(quote! {
331 pub #field_name: ParameterAux
332 });
333
334 aux_fields.push(quote! {
335 #field_name
336 });
337 }
338
339 pk_variants.push(quote! {
340 #pk_ident
341 });
342
343 let doc_string =
344 format!("Parameter `{}` of type {}", field_name_str, field_type_str);
345 pk_variants_doc.push(quote! {
346 #doc_string
347 });
348
349 pk_field_names.push(quote!(
350 #field_name_str
351 ));
352 }
353
354 if let Some(kind) = config_kind {
355 parameters.push(quote! {
356 (
357 #pk_enum_ident::#pk_ident,
358 ParameterProperties {
359 dtype: ParameterDataType::#kind,
360 is_mutable: #is_mutable,
361 }
362 )
363 });
364
365 parameters_with_value.push(quote! {
366 (
367 #pk_enum_ident::#pk_ident,
368 self.#field_name.clone().into(),
369 )
370 });
371
372 if is_mutable {
373 let match_arm_set = quote! {
374 #pk_enum_ident::#pk_ident => {
375 match value {
376 ParameterValue::#kind(val) => {
377 Ok((&mut self.#field_name, val).assign()?)
378 }
379 actual => Err(ConfigSetParameterError::InvalidType {
380 expected: ParameterDataType::#kind,
381 actual: actual.dtype(),
382 })
383 }
384 }
385 };
386 match_arms_set.push(match_arm_set);
387 } else {
388 let match_arm_set = quote! {
389 #pk_enum_ident::#pk_ident => {
390 Err(ConfigSetParameterError::Immutable)
391 }
392 };
393 match_arms_set.push(match_arm_set);
394 }
395
396 if is_mutable {
397 let aux_match_arm = quote! {
398 #pk_enum_ident::#pk_ident => {
399 self.#field_name.on_set_parameter(now);
400 }
401 };
402 aux_match_arms.push(aux_match_arm);
403 }
404 }
405 }
406 }
407 }
408
409 let expanded = quote! {
410 #[automatically_derived]
411 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
412 #[allow(missing_docs)]
413 pub enum #pk_enum_ident {
414 #(
415 # [doc = #pk_variants_doc]
416 #pk_variants,
417 )*
418 }
419
420 impl ConfigKind for #pk_enum_ident {
421 #[inline]
422 fn from_str(id: &str) -> Option<Self> {
423 match id {
424 #(#pk_field_names => Some(#pk_enum_ident::#pk_variants),)*
425 _ => None,
426 }
427 }
428
429 #[inline]
430 fn as_str(self) -> &'static str {
431 match self {
432 #(#pk_enum_ident::#pk_variants => #pk_field_names,)*
433 }
434 }
435 }
436
437 impl #impl_generics Config for #struct_name #ty_generics #where_clause {
438 type Kind = #pk_enum_ident;
439
440 type Aux = #aux_ident;
441
442 fn list_parameters() -> &'static [(Self::Kind, ParameterProperties)] {
443 &[#(#parameters),*]
444 }
445
446 fn set_parameter(&mut self, kind: Self::Kind, value: ParameterValue)
447 -> Result<(), ConfigSetParameterError>
448 {
449 match kind {
450 #(#match_arms_set)*
451 }
452 }
453
454 fn get_parameters(&self) -> Vec<(Self::Kind, ParameterValue)>{
455 vec![#(#parameters_with_value),*]
456 }
457
458 }
459
460 #[automatically_derived]
461 #[derive(Default)]
462 #[allow(dead_code)]
463 #[allow(missing_docs)]
464 pub struct #aux_ident {
465 _dirty: Vec<#pk_enum_ident>,
466 #(#aux_fields_decl,)*
467 }
468
469 impl ConfigAux for #aux_ident {
470 type Kind = #pk_enum_ident;
471
472 #[inline]
473 fn dirty(&self) -> &[Self::Kind] {
474 &self._dirty
475 }
476
477 #[inline]
478 fn is_dirty(&self) -> bool {
479 !self._dirty.is_empty()
480 }
481
482 #[allow(unreachable_code)]
483 fn on_set_parameter(&mut self, kind: Self::Kind, now: Pubtime) {
484 match kind {
485 #(#aux_match_arms)*
486 _ => unreachable!()
487 }
488 self._dirty.push(kind);
489 }
490
491 fn on_post_step(&mut self) {
492 #(self.#aux_fields.on_post_step();)*
493 self._dirty.clear();
494 }
495 }
496 };
497
498 TokenStream::from(expanded)
499}
500
501#[proc_macro]
502pub fn signals(input: TokenStream) -> TokenStream {
503 let input_str = input.to_string();
504
505 let binding = input_str.trim();
507 let parts: Vec<_> = binding.split('{').collect();
508
509 if parts.len() != 2 {
510 return quote! {
511 compile_error!(concat!(
512 "Invalid signals! syntax. Expected: signals! { Name { field1: type1, field2: type2, ... } }"
513 ))
514 }
515 .into();
516 }
517
518 let name = parts[0].trim();
519 let mut fields_str = parts[1].trim();
520
521 assert!(fields_str.ends_with('}'));
523 fields_str = &fields_str[0..fields_str.len() - 1];
524
525 let parts: Vec<_> = fields_str.split(',').collect();
527
528 let mut field_def = Vec::new();
529 for part in parts {
530 let mut doc_comment = String::new();
531 let mut found_field = false;
532
533 for line in part.lines() {
534 let line = line.trim();
535 if line.is_empty() {
536 continue;
537 }
538
539 if found_field {
540 eprintln!("{part:?}");
541 return quote! {
542 compile_error!(concat!(
543 "found line after field definition: '",
544 #line,
545 "'. Expected: field_name: field_type"
546 ))
547 }
548 .into();
549 }
550
551 if line.starts_with("///") {
553 if !doc_comment.is_empty() {
554 doc_comment.push('\n');
555 }
556 doc_comment.push_str(line);
557 }
558 else if line.starts_with("//") {
560 }
562 else if line.contains(':') {
564 let field_parts: Vec<&str> = line.split(':').collect();
565 if field_parts.len() != 2 {
566 eprintln!("{part:?}");
567 return quote! {
568 compile_error!(concat!(
569 "Invalid field syntax: '",
570 #line,
571 "'. Expected: field_name: field_type"
572 ))
573 }
574 .into();
575 }
576
577 let field_name_str = field_parts[0].trim();
578 let field_type_str = field_parts[1].trim();
579
580 field_def.push((doc_comment.clone(), field_name_str, field_type_str));
581 found_field = true;
582 } else {
583 eprintln!("{part:?}");
584 return quote! {
585 compile_error!(concat!(
586 "Invalid field syntax: '",
587 #line,
588 "'. Expected: field_name: field_type"
589 ))
590 }
591 .into();
592 }
593 }
594 }
595
596 let name_ident = syn::Ident::new(name, Span::call_site().into());
597 let pk_enum_name = format!("{}Kind", name);
598 let pk_enum_ident = syn::Ident::new(&pk_enum_name, Span::call_site().into());
599
600 let mut field_defs = Vec::new();
602 let mut signal_kinds = Vec::new();
603 let mut signal_kinds_doc = Vec::new();
604 let mut signal_name_str = Vec::new();
605 let mut signal_names = Vec::new();
606 let mut signal_kind_dtypes = Vec::new();
607
608 for (doc_comment_with_slashes, field_name_str, field_type_str) in field_def.iter() {
609 let doc_comment = if doc_comment_with_slashes.is_empty() {
610 String::new()
611 } else {
612 doc_comment_with_slashes
614 .lines()
615 .map(|line| line.trim_start_matches("///").trim())
616 .collect::<Vec<_>>()
617 .join("\n")
618 };
619
620 let field_name = syn::Ident::new(field_name_str, Span::call_site().into());
621 let field_type = syn::parse_str::<syn::Type>(field_type_str).unwrap_or_else(|_| {
622 panic!("Could not parse type: {}", field_type_str);
623 });
624
625 field_defs.push(quote! {
626 #[doc = #doc_comment]
627 pub #field_name: SignalCell<#field_type>
628 });
629
630 let signal_dtype = match *field_type_str {
632 "bool" => quote!(Bool),
633 "i64" => quote!(Int64),
634 "usize" => quote!(Usize),
635 "f64" => quote!(Float64),
636 "String" => quote!(String),
637 _ => {
638 return quote! {
639 compile_error!(concat!(
640 "unsupported nodo signal field type: '",
641 #field_type_str,
642 "'. Supported types are: bool, i64, usize, f64, String."
643 ))
644 }
645 .into();
646 }
647 };
648
649 signal_kind_dtypes.push(signal_dtype);
650
651 let signal_kind_name = to_camel_case(field_name_str);
652 let signal_kind_ident = syn::Ident::new(&signal_kind_name, Span::call_site().into());
653 signal_kinds.push(quote! { #signal_kind_ident });
654
655 signal_kinds_doc.push(quote! { #doc_comment });
656
657 signal_name_str.push(quote! { #field_name_str });
658 signal_names.push(field_name);
659 }
660
661 let expanded = quote! {
663 #[automatically_derived]
664 #[allow(missing_docs)]
665 pub struct #name_ident {
666 #(#field_defs,)*
667 }
668
669 #[automatically_derived]
670 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
671 #[allow(missing_docs)]
672 pub enum #pk_enum_ident {
673 #(
674 #[doc = #signal_kinds_doc]
675 #signal_kinds,
676 )*
677 }
678
679 impl SignalKind for #pk_enum_ident {
680 #[inline]
681 fn list() -> &'static [Self] {
682 &[
683 #(
684 #pk_enum_ident::#signal_kinds,
685 )*
686 ]
687 }
688
689 #[inline]
690 fn dtype(&self) -> SignalDataType {
691 match self {
692 #(
693 #pk_enum_ident::#signal_kinds => SignalDataType::#signal_kind_dtypes,
694 )*
695 }
696 }
697
698 #[inline]
699 fn from_str(id: &str) -> Option<Self> {
700 match id {
701 #(
702 #signal_name_str => Some(#pk_enum_ident::#signal_kinds),
703 )*
704 _ => None,
705 }
706 }
707
708 #[inline]
709 fn as_str(&self) -> &'static str {
710 match self {
711 #(
712 #pk_enum_ident::#signal_kinds => #signal_name_str,
713 )*
714 }
715 }
716 }
717
718 impl Signals for #name_ident {
719 type Kind = #pk_enum_ident;
720
721 #[inline]
722 fn as_time_value_iter(
723 &self
724 ) -> impl Iterator<Item = Option<SignalTimeValue>> + ExactSizeIterator {
725 [
726 #(
727 self.#signal_names.anon_time_value(),
728 )*
729 ].into_iter()
730 }
731
732 #[inline]
733 fn on_post_execute(&mut self, step_time: Pubtime) {
734 #(
735 self.#signal_names.on_post_execute(step_time);
736 )*
737 }
738 }
739
740 impl Default for #name_ident {
741 fn default() -> Self {
742 Self {
743 #(
744 #signal_names: Default::default(),
745 )*
746 }
747 }
748 }
749 };
750
751 TokenStream::from(expanded)
752}