// ceres-dsp/ceres-macros/src/lib.rs (crate `ceres_macros`)
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Data, DeriveInput, Fields};

7#[proc_macro_attribute]
8pub fn parameters(_args: TokenStream, input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10
11    let struct_name = &input.ident;
12    let runtime_name = syn::Ident::new(&format!("{}Runtime", struct_name), struct_name.span());
13    let accessor_name = syn::Ident::new(&format!("{}Accessor", struct_name), struct_name.span());
14    
15    let fields = match &input.data {
16        Data::Struct(data) => match &data.fields {
17            Fields::Named(fields) => &fields.named,
18            _ => return syn::Error::new(struct_name.span(), "Only named fields supported")
19                .to_compile_error().into(),
20        },
21        _ => return syn::Error::new(struct_name.span(), "Only structs supported")
22            .to_compile_error().into(),
23    };
24    
25    // Validate f32 fields
26    for field in fields.iter() {
27        let field_name = field.ident.as_ref().unwrap();
28        if let syn::Type::Path(type_path) = &field.ty {
29            if let Some(segment) = type_path.path.segments.last() {
30                if segment.ident != "f32" {
31                    return syn::Error::new(
32                        field.span(), 
33                        format!("Parameter field '{}' must be f32", field_name)
34                    ).to_compile_error().into();
35                }
36            }
37        } else {
38            return syn::Error::new(
39                field.span(), 
40                format!("Parameter field '{}' must be f32", field_name)
41            ).to_compile_error().into();
42        }
43    }
44    
45    let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
46    
47    // Generate modulation field names
48    let mod_field_names: Vec<_> = field_names.iter().map(|name| {
49        syn::Ident::new(&format!("{}_modulation", name.as_ref().unwrap()), name.span())
50    }).collect();
51    
52    let mod_fields = mod_field_names.iter().map(|mod_name| {
53        quote! { #mod_name: Option<::ceres::ModulationRouting> }
54    });
55    
56    // Generate route methods
57    let route_methods = field_names.iter().zip(mod_field_names.iter()).map(|(name, mod_name)| {
58        let method_name = syn::Ident::new(&format!("route_{}", name.as_ref().unwrap()), name.span());
59        quote! {
60            fn #method_name(&mut self, source_index: usize, amount: f32) {
61                self.#mod_name = Some(::ceres::ModulationRouting { source_index, amount });
62            }
63        }
64    });
65    
66    // Generate route_parameter match arms
67    let route_arms = field_names.iter().zip(mod_field_names.iter()).map(|(name, _)| {
68        let name_str = name.as_ref().unwrap().to_string();
69        let method_name = syn::Ident::new(&format!("route_{}", name.as_ref().unwrap()), name.span());
70        quote! { #name_str => self.#method_name(source_index, amount) }
71    });
72    
73    // Generate update logic
74    let update_fields = field_names.iter().zip(mod_field_names.iter()).map(|(name, mod_name)| {
75        quote! {
76            let #name = self.#mod_name
77                .as_ref()
78                .map(|routing| {
79                    let modulator_value = sources[routing.source_index].get_value(i);
80                    modulator_value * routing.amount
81                })
82                .unwrap_or(0.0);
83            let #name = (self.base.#name + #name).clamp(0.0, 1.0);
84        }
85    });
86    
87    let expanded = quote! {
88        #[derive(Clone, Copy, Default)]
89        #input
90        
91        struct #runtime_name<E> {
92            base: #struct_name,
93            #(#mod_fields,)*
94            computed_values: [#struct_name; ::ceres::BUFFER_SIZE],
95        }
96        
97        impl<E> #runtime_name<E> {
98            fn new() -> Self {
99                let base = #struct_name::default();
100                Self {
101                    base,
102                    #(#mod_field_names: None,)*
103                    computed_values: [base; ::ceres::BUFFER_SIZE],
104                }
105            }
106            
107            #(#route_methods)*
108        }
109        
110        impl<E: Send + 'static> ::ceres::ParameterRuntime<E> for #runtime_name<E> {
111            fn update(&mut self, sources: &[Box<dyn ::ceres::Modulator<E>>]) {
112                for i in 0..::ceres::BUFFER_SIZE {
113                    #(#update_fields)*
114                    self.computed_values[i] = #struct_name {
115                        #(#field_names: #field_names),*
116                    };
117                }
118            }
119            
120            fn route_parameter(&mut self, param_name: &str, source_index: usize, amount: f32) {
121                match param_name {
122                    #(#route_arms,)*
123                    _ => {}
124                }
125            }
126        }
127        
128        struct #accessor_name<'a> {
129            values: &'a [#struct_name; ::ceres::BUFFER_SIZE],
130        }
131        
132        impl<'a> #accessor_name<'a> {
133            fn new(values: &'a [#struct_name; ::ceres::BUFFER_SIZE]) -> Self {
134                Self { values }
135            }
136        }
137        
138        impl<'a> std::ops::Index<usize> for #accessor_name<'a> {
139            type Output = #struct_name;
140            fn index(&self, index: usize) -> &Self::Output {
141                &self.values[index % ::ceres::BUFFER_SIZE]
142            }
143        }
144        
145        impl ::ceres::Parameters for #struct_name {
146            type Runtime<E: Send + 'static> = #runtime_name<E>;
147            type Accessor<'a, E> = #accessor_name<'a> where E: 'a;
148            type Values = #struct_name;
149            
150            fn create_runtime<E: Send>() -> Self::Runtime<E> {
151                #runtime_name::new()
152            }
153            
154            fn create_accessor<E: Send>(runtime: &Self::Runtime<E>) -> Self::Accessor<'_, E> {
155                #accessor_name::new(&runtime.computed_values)
156            }
157        }
158    };
159    
160    TokenStream::from(expanded)
161}