use proc_macro::{Span, TokenStream};
use quote::quote;
use syn::{parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Meta};
/// Derive entry point for `RxBundleDerive`.
///
/// Parses the annotated item as a [`DeriveInput`] and hands it to
/// `impl_rx_bundle_derive`, which produces the generated tokens.
#[proc_macro_derive(RxBundleDerive)]
pub fn rx_bundle_derive(input: TokenStream) -> TokenStream {
    impl_rx_bundle_derive(&parse_macro_input!(input as DeriveInput))
}
fn impl_rx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
let name = &input.ident;
let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let name_str = name.to_string();
let fields = match &input.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("expected a struct with named fields"),
};
let fields_count = fields.len();
let field_index = (0..fields.len()).collect::<Vec<_>>();
let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
let field_name_str = fields
.iter()
.map(|f| f.ident.as_ref().unwrap().to_string())
.collect::<Vec<_>>();
let gen = quote! {
impl #impl_generics nodo::channels::RxBundle for #name #type_generics #where_clause {
fn channel_count(&self) -> usize {
#fields_count
}
fn name(&self, index: usize) -> &str {
match index {
#(
#field_index => #field_name_str,
)*
_ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
}
}
fn inbox_message_count(&self, index: usize) -> usize {
match index {
#(#field_index => self.#field_name.len(),)*
_ => panic!("invalid rx bundle index {index} for `{}`", #name_str),
}
}
fn sync_all(&mut self, results: &mut [nodo::channels::SyncResult]) {
use nodo::channels::Rx;
#(results[#field_index] = self.#field_name.sync();)*
}
fn check_connection(&self) -> nodo::channels::ConnectionCheck {
use nodo::channels::Rx;
let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
#(cc.mark(#field_index, self.#field_name.is_connected());)*
cc
}
}
};
gen.into()
}
/// Derive entry point for `TxBundleDerive`.
///
/// Parses the annotated item as a [`DeriveInput`] and hands it to
/// `impl_tx_bundle_derive`, which produces the generated tokens.
#[proc_macro_derive(TxBundleDerive)]
pub fn tx_bundle_derive(input: TokenStream) -> TokenStream {
    impl_tx_bundle_derive(&parse_macro_input!(input as DeriveInput))
}
fn impl_tx_bundle_derive(input: &syn::DeriveInput) -> TokenStream {
let name = &input.ident;
let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
let name_str = name.to_string();
let fields = match &input.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("expected a struct with named fields"),
};
let fields_count = fields.len();
let field_index = (0..fields.len()).collect::<Vec<_>>();
let field_name = fields.iter().map(|field| &field.ident).collect::<Vec<_>>();
let field_name_str = fields
.iter()
.map(|f| f.ident.as_ref().unwrap().to_string())
.collect::<Vec<_>>();
let gen = quote! {
impl #impl_generics nodo::channels::TxBundle for #name #type_generics #where_clause {
fn channel_count(&self) -> usize {
#fields_count
}
fn name(&self, index: usize) -> &str {
match index {
#(
#field_index => #field_name_str,
)*
_ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
}
}
fn outbox_message_count(&self, index: usize) -> usize {
match index {
#(#field_index => self.#field_name.len(),)*
_ => panic!("invalid tx bundle index {index} for `{}`", #name_str),
}
}
fn flush_all(&mut self, results: &mut [nodo::channels::FlushResult]) {
use nodo::channels::Tx;
#(results[#field_index] = self.#field_name.flush();)*
}
fn check_connection(&self) -> nodo::channels::ConnectionCheck {
use nodo::channels::Tx;
let mut cc = nodo::channels::ConnectionCheck::new(#fields_count);
#(cc.mark(#field_index, self.#field_name.is_connected());;)*
cc
}
}
};
gen.into()
}
#[proc_macro_derive(Status, attributes(label, default, skipped))]
pub fn derive_status(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let enum_name = input.ident.clone();
let data = if let Data::Enum(DataEnum { variants, .. }) = input.data {
variants
} else {
return syn::Error::new_spanned(input, "Status can only be derived for enums")
.to_compile_error()
.into();
};
let mut default_variant = None;
let mut match_arms_status = Vec::new();
let mut match_arms_label = Vec::new();
for variant in data {
let variant_name = &variant.ident;
let mut label = None;
let mut is_default = false;
let mut is_skipped = false;
for attr in variant.attrs {
if attr.path.is_ident("label") {
if let Ok(Meta::NameValue(meta_name_value)) = attr.parse_meta() {
if let syn::Lit::Str(lit_str) = &meta_name_value.lit {
label = Some(lit_str.value());
}
}
} else if attr.path.is_ident("default") {
is_default = true;
} else if attr.path.is_ident("skipped") {
is_skipped = true;
}
}
let pattern = match &variant.fields {
Fields::Unit => quote! { #enum_name::#variant_name },
Fields::Unnamed(_) => quote! { #enum_name::#variant_name(..) },
Fields::Named(_) => quote! { #enum_name::#variant_name { .. } },
};
let default_status = if is_skipped {
quote! { DefaultStatus::Skipped }
} else {
quote! { DefaultStatus::Running }
};
match_arms_status.push(quote! {
#pattern => #default_status,
});
let label = label.unwrap_or_else(|| variant_name.to_string());
match_arms_label.push(quote! {
#pattern => #label,
});
if is_default {
default_variant = Some(quote! {
fn default_implementation_status() -> Self {
#enum_name::#variant_name
}
});
}
}
let default_implementation_status = default_variant.unwrap_or_else(|| {
quote! {
fn default_implementation_status() -> Self {
compile_error!("No default status was specified. Use #[default] to choose one.");
}
}
});
let expanded = quote! {
impl CodeletStatus for #enum_name {
#default_implementation_status
fn is_default_status(&self) -> bool {
false
}
fn as_default_status(&self) -> DefaultStatus {
match self {
#(#match_arms_status)*
}
}
fn label(&self) -> &'static str {
match self {
#(#match_arms_label)*
}
}
}
};
TokenStream::from(expanded)
}
/// Converts a `snake_case` name into `UpperCamelCase`.
///
/// Every `_` is dropped. The first character of the input and the first character
/// following an underscore are converted with `to_ascii_uppercase`; all remaining
/// characters are copied unchanged.
fn to_camel_case(snake: &str) -> String {
    let mut camel = String::new();
    for word in snake.split('_') {
        let mut rest = word.chars();
        // Empty segments come from leading, trailing or doubled underscores.
        if let Some(initial) = rest.next() {
            camel.push(initial.to_ascii_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
/// Derives `Config` for a struct and generates two companion types.
///
/// For a struct `Foo` the expansion contains:
///
/// * `FooParameterKind`: an enum with one variant per exposed field (the field name
///   converted by `to_camel_case`), plus a `ConfigKind` impl mapping between variants and
///   field names.
/// * `impl Config for Foo`: `list_parameters`, `set_parameter` and `get_parameters`.
/// * `FooAux`: a struct with a `_dirty` list and one `ParameterAux` per mutable field,
///   plus a `ConfigAux` impl.
///
/// A field is exposed when it is not marked `#[hidden]` and its type is one of `bool`,
/// `i64`, `usize`, `f64`, `String`, `Vec<f64>` or `[f64; N]`. Fields of any other type are
/// silently skipped. Fields marked `#[mutable]` can be changed through `set_parameter`;
/// for all other exposed fields it returns `ConfigSetParameterError::Immutable`.
///
/// Input that is not a struct with named fields produces the same expansion with no
/// parameters.
///
/// The expansion refers to `Config`, `ConfigKind`, `ConfigAux`, `ParameterAux`,
/// `ParameterProperties`, `ParameterDataType`, `ParameterValue`,
/// `ConfigSetParameterError` and `Pubtime` unqualified, so they must be in scope where the
/// derive is used.
#[proc_macro_derive(Config, attributes(mutable, hidden))]
pub fn derive_config(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = input.ident;
    let generics = input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let pk_enum_ident = syn::Ident::new(
        &format!("{}ParameterKind", struct_name),
        struct_name.span(),
    );
    let aux_ident = syn::Ident::new(&format!("{}Aux", struct_name), struct_name.span());

    let mut parameters = Vec::new();
    let mut parameters_with_value = Vec::new();
    let mut pk_variants = Vec::new();
    let mut pk_variants_doc = Vec::new();
    let mut match_arms_set = Vec::new();
    let mut aux_match_arms = Vec::new();
    let mut aux_fields_decl = Vec::new();
    let mut aux_fields = Vec::new();
    let mut pk_field_names = Vec::new();

    if let Data::Struct(data_struct) = input.data {
        if let Fields::Named(fields) = data_struct.fields {
            for field in fields.named {
                if field.attrs.iter().any(|attr| attr.path.is_ident("hidden")) {
                    continue;
                }
                let is_mutable = field.attrs.iter().any(|attr| attr.path.is_ident("mutable"));

                let field_name = field
                    .ident
                    .expect("named fields always have an identifier");
                let field_name_str = field_name.to_string();
                let field_type = field.ty;
                let field_type_str = quote!(#field_type).to_string();

                // The whitespace that `to_string` places between tokens is not
                // guaranteed, so strip it before comparing against known type spellings.
                let field_type_key: String = field_type_str.split_whitespace().collect();
                let kind = match field_type_key.as_str() {
                    "bool" => quote!(Bool),
                    "i64" => quote!(Int64),
                    "usize" => quote!(Usize),
                    "f64" => quote!(Float64),
                    "String" => quote!(String),
                    "Vec<f64>" => quote!(VecFloat64),
                    s if s.starts_with("[f64;") => quote!(VecFloat64),
                    // Unsupported field types are not exposed as parameters.
                    _ => continue,
                };

                let pk_ident = syn::Ident::new(&to_camel_case(&field_name_str), field_name.span());

                if is_mutable {
                    aux_fields_decl.push(quote! {
                        pub #field_name: ParameterAux
                    });
                    aux_fields.push(quote! {
                        #field_name
                    });
                    aux_match_arms.push(quote! {
                        #pk_enum_ident::#pk_ident => {
                            self.#field_name.on_set_parameter(now);
                        }
                    });
                    match_arms_set.push(quote! {
                        #pk_enum_ident::#pk_ident => {
                            match value {
                                ParameterValue::#kind(val) => {
                                    Ok((&mut self.#field_name, val).assign()?)
                                }
                                actual => Err(ConfigSetParameterError::InvalidType {
                                    expected: ParameterDataType::#kind,
                                    actual: actual.dtype(),
                                })
                            }
                        }
                    });
                } else {
                    match_arms_set.push(quote! {
                        #pk_enum_ident::#pk_ident => {
                            Err(ConfigSetParameterError::Immutable)
                        }
                    });
                }

                pk_variants.push(quote! {
                    #pk_ident
                });
                let doc_string =
                    format!("Parameter `{}` of type {}", field_name_str, field_type_str);
                pk_variants_doc.push(quote! {
                    #doc_string
                });
                pk_field_names.push(quote!(
                    #field_name_str
                ));

                parameters.push(quote! {
                    (
                        #pk_enum_ident::#pk_ident,
                        ParameterProperties {
                            dtype: ParameterDataType::#kind,
                            is_mutable: #is_mutable,
                        }
                    )
                });
                parameters_with_value.push(quote! {
                    (
                        #pk_enum_ident::#pk_ident,
                        self.#field_name.clone().into(),
                    )
                });
            }
        }
    }

    let expanded = quote! {
        #[automatically_derived]
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        #[allow(missing_docs)]
        pub enum #pk_enum_ident {
            #(
                #[doc = #pk_variants_doc]
                #pk_variants,
            )*
        }

        impl ConfigKind for #pk_enum_ident {
            #[inline]
            fn from_str(id: &str) -> Option<Self> {
                match id {
                    #(#pk_field_names => Some(#pk_enum_ident::#pk_variants),)*
                    _ => None,
                }
            }

            #[inline]
            fn as_str(self) -> &'static str {
                match self {
                    #(#pk_enum_ident::#pk_variants => #pk_field_names,)*
                }
            }
        }

        impl #impl_generics Config for #struct_name #ty_generics #where_clause {
            type Kind = #pk_enum_ident;
            type Aux = #aux_ident;

            fn list_parameters() -> &'static [(Self::Kind, ParameterProperties)] {
                &[#(#parameters),*]
            }

            fn set_parameter(&mut self, kind: Self::Kind, value: ParameterValue)
                -> Result<(), ConfigSetParameterError>
            {
                match kind {
                    #(#match_arms_set)*
                }
            }

            fn get_parameters(&self) -> Vec<(Self::Kind, ParameterValue)>{
                vec![#(#parameters_with_value),*]
            }
        }

        #[automatically_derived]
        #[derive(Default)]
        #[allow(dead_code)]
        #[allow(missing_docs)]
        pub struct #aux_ident {
            _dirty: Vec<#pk_enum_ident>,
            #(#aux_fields_decl,)*
        }

        impl ConfigAux for #aux_ident {
            type Kind = #pk_enum_ident;

            #[inline]
            fn dirty(&self) -> &[Self::Kind] {
                &self._dirty
            }

            #[inline]
            fn is_dirty(&self) -> bool {
                !self._dirty.is_empty()
            }

            // Only mutable parameters get an arm. With no mutable parameter the code after
            // the match is unreachable; when every parameter is mutable the catch-all arm
            // is an unreachable pattern. Both are expected for generated code.
            #[allow(unreachable_code, unreachable_patterns)]
            fn on_set_parameter(&mut self, kind: Self::Kind, now: Pubtime) {
                match kind {
                    #(#aux_match_arms)*
                    _ => unreachable!()
                }
                self._dirty.push(kind);
            }

            fn on_post_step(&mut self) {
                #(self.#aux_fields.on_post_step();)*
                self._dirty.clear();
            }
        }
    };
    TokenStream::from(expanded)
}
#[proc_macro]
pub fn signals(input: TokenStream) -> TokenStream {
let input_str = input.to_string();
let binding = input_str.trim();
let parts: Vec<_> = binding.split('{').collect();
if parts.len() != 2 {
return quote! {
compile_error!(concat!(
"Invalid signals! syntax. Expected: signals! { Name { field1: type1, field2: type2, ... } }"
))
}
.into();
}
let name = parts[0].trim();
let mut fields_str = parts[1].trim();
assert!(fields_str.ends_with('}'));
fields_str = &fields_str[0..fields_str.len() - 1];
let parts: Vec<_> = fields_str.split(',').collect();
let mut field_def = Vec::new();
for part in parts {
let mut doc_comment = String::new();
let mut found_field = false;
for line in part.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
if found_field {
eprintln!("{part:?}");
return quote! {
compile_error!(concat!(
"found line after field definition: '",
#line,
"'. Expected: field_name: field_type"
))
}
.into();
}
if line.starts_with("///") {
if !doc_comment.is_empty() {
doc_comment.push('\n');
}
doc_comment.push_str(line);
}
else if line.starts_with("//") {
}
else if line.contains(':') {
let field_parts: Vec<&str> = line.split(':').collect();
if field_parts.len() != 2 {
eprintln!("{part:?}");
return quote! {
compile_error!(concat!(
"Invalid field syntax: '",
#line,
"'. Expected: field_name: field_type"
))
}
.into();
}
let field_name_str = field_parts[0].trim();
let field_type_str = field_parts[1].trim();
field_def.push((doc_comment.clone(), field_name_str, field_type_str));
found_field = true;
} else {
eprintln!("{part:?}");
return quote! {
compile_error!(concat!(
"Invalid field syntax: '",
#line,
"'. Expected: field_name: field_type"
))
}
.into();
}
}
}
let name_ident = syn::Ident::new(name, Span::call_site().into());
let pk_enum_name = format!("{}Kind", name);
let pk_enum_ident = syn::Ident::new(&pk_enum_name, Span::call_site().into());
let mut field_defs = Vec::new();
let mut signal_kinds = Vec::new();
let mut signal_kinds_doc = Vec::new();
let mut signal_name_str = Vec::new();
let mut signal_names = Vec::new();
let mut signal_kind_dtypes = Vec::new();
for (doc_comment_with_slashes, field_name_str, field_type_str) in field_def.iter() {
let doc_comment = if doc_comment_with_slashes.is_empty() {
String::new()
} else {
doc_comment_with_slashes
.lines()
.map(|line| line.trim_start_matches("///").trim())
.collect::<Vec<_>>()
.join("\n")
};
let field_name = syn::Ident::new(field_name_str, Span::call_site().into());
let field_type = syn::parse_str::<syn::Type>(field_type_str).unwrap_or_else(|_| {
panic!("Could not parse type: {}", field_type_str);
});
field_defs.push(quote! {
#[doc = #doc_comment]
pub #field_name: SignalCell<#field_type>
});
let signal_dtype = match *field_type_str {
"bool" => quote!(Bool),
"i64" => quote!(Int64),
"usize" => quote!(Usize),
"f64" => quote!(Float64),
"String" => quote!(String),
_ => {
return quote! {
compile_error!(concat!(
"unsupported nodo signal field type: '",
#field_type_str,
"'. Supported types are: bool, i64, usize, f64, String."
))
}
.into();
}
};
signal_kind_dtypes.push(signal_dtype);
let signal_kind_name = to_camel_case(field_name_str);
let signal_kind_ident = syn::Ident::new(&signal_kind_name, Span::call_site().into());
signal_kinds.push(quote! { #signal_kind_ident });
signal_kinds_doc.push(quote! { #doc_comment });
signal_name_str.push(quote! { #field_name_str });
signal_names.push(field_name);
}
let expanded = quote! {
#[automatically_derived]
#[allow(missing_docs)]
pub struct #name_ident {
#(#field_defs,)*
}
#[automatically_derived]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
pub enum #pk_enum_ident {
#(
#[doc = #signal_kinds_doc]
#signal_kinds,
)*
}
impl SignalKind for #pk_enum_ident {
#[inline]
fn list() -> &'static [Self] {
&[
#(
#pk_enum_ident::#signal_kinds,
)*
]
}
#[inline]
fn dtype(&self) -> SignalDataType {
match self {
#(
#pk_enum_ident::#signal_kinds => SignalDataType::#signal_kind_dtypes,
)*
}
}
#[inline]
fn from_str(id: &str) -> Option<Self> {
match id {
#(
#signal_name_str => Some(#pk_enum_ident::#signal_kinds),
)*
_ => None,
}
}
#[inline]
fn as_str(&self) -> &'static str {
match self {
#(
#pk_enum_ident::#signal_kinds => #signal_name_str,
)*
}
}
}
impl Signals for #name_ident {
type Kind = #pk_enum_ident;
#[inline]
fn as_time_value_iter(
&self
) -> impl Iterator<Item = Option<SignalTimeValue>> + ExactSizeIterator {
[
#(
self.#signal_names.anon_time_value(),
)*
].into_iter()
}
#[inline]
fn on_post_execute(&mut self, step_time: Pubtime) {
#(
self.#signal_names.on_post_execute(step_time);
)*
}
}
impl Default for #name_ident {
fn default() -> Self {
Self {
#(
#signal_names: Default::default(),
)*
}
}
}
};
TokenStream::from(expanded)
}