1use proc_macro::TokenStream;
22use proc_macro2::TokenStream as TokenStream2;
23use quote::quote;
24use syn::{
25 parse_macro_input, Attribute, Data, DeriveInput, Error, Fields, Ident,
26 Result as SynResult,
27};
28
/// The kind of channel a state field maps to, as selected by the field's attributes.
#[derive(Clone, Debug, PartialEq)]
enum ChannelType {
    /// Emitted as `regula_core::ChannelSpec::LastValue`; this is also the default kind.
    LastValue,
    /// Selected by `append`; emitted as `ChannelSpec::Reducer(ReducerType::Append)`.
    Append,
    /// Selected by `add`; emitted as `ChannelSpec::Reducer(ReducerType::Add)`.
    Add,
    /// Selected by `#[channel(ephemeral)]`; emitted as `ChannelSpec::Ephemeral`.
    Ephemeral,
    /// Selected by `#[channel(any_value)]`; emitted as `ChannelSpec::AnyValue`.
    AnyValue,
    /// A user-named reducer taken from `#[reducer(...)]`; the string is that name.
    Custom(String),
}
45
46impl Default for ChannelType {
47 fn default() -> Self {
48 Self::LastValue
49 }
50}
51
52struct FieldConfig {
54 name: Ident,
55 channel_type: ChannelType,
56}
57
58impl FieldConfig {
59 fn from_field(field: &syn::Field) -> SynResult<Option<Self>> {
60 let name = match &field.ident {
61 Some(ident) => ident.clone(),
62 None => return Ok(None), };
64
65 let mut channel_type = ChannelType::default();
66
67 for attr in &field.attrs {
69 if attr.path().is_ident("reducer") {
70 channel_type = Self::parse_reducer_attr(attr)?;
71 } else if attr.path().is_ident("channel") {
72 channel_type = Self::parse_channel_attr(attr)?;
73 }
74 }
75
76 Ok(Some(Self { name, channel_type }))
77 }
78
79 fn parse_reducer_attr(attr: &Attribute) -> SynResult<ChannelType> {
80 let mut result = ChannelType::LastValue;
81
82 attr.parse_nested_meta(|meta| {
83 if meta.path.is_ident("append") {
84 result = ChannelType::Append;
85 } else if meta.path.is_ident("add") {
86 result = ChannelType::Add;
87 } else {
88 let fn_name = meta
90 .path
91 .get_ident()
92 .map(|i| i.to_string())
93 .unwrap_or_else(|| "custom".to_string());
94 result = ChannelType::Custom(fn_name);
95 }
96 Ok(())
97 })?;
98
99 Ok(result)
100 }
101
102 fn parse_channel_attr(attr: &Attribute) -> SynResult<ChannelType> {
103 let mut result = ChannelType::LastValue;
104
105 attr.parse_nested_meta(|meta| {
106 if meta.path.is_ident("ephemeral") {
107 result = ChannelType::Ephemeral;
108 } else if meta.path.is_ident("last_value") {
109 result = ChannelType::LastValue;
110 } else if meta.path.is_ident("any_value") {
111 result = ChannelType::AnyValue;
112 } else if meta.path.is_ident("append") {
113 result = ChannelType::Append;
114 } else if meta.path.is_ident("add") {
115 result = ChannelType::Add;
116 } else {
117 return Err(meta.error(
118 "unknown channel type. Use: ephemeral, last_value, any_value, append, or add",
119 ));
120 }
121 Ok(())
122 })?;
123
124 Ok(result)
125 }
126
127 fn to_channel_spec_tokens(&self) -> TokenStream2 {
128 match &self.channel_type {
129 ChannelType::LastValue => {
130 quote! { regula_core::ChannelSpec::LastValue }
131 }
132 ChannelType::Append => {
133 quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Append) }
134 }
135 ChannelType::Add => {
136 quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Add) }
137 }
138 ChannelType::Ephemeral => {
139 quote! { regula_core::ChannelSpec::Ephemeral }
140 }
141 ChannelType::AnyValue => {
142 quote! { regula_core::ChannelSpec::AnyValue }
143 }
144 ChannelType::Custom(name) => {
145 quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Custom(#name.to_string())) }
146 }
147 }
148 }
149}
150
151#[proc_macro_derive(GraphState, attributes(reducer, channel))]
187pub fn derive_graph_state(input: TokenStream) -> TokenStream {
188 let input = parse_macro_input!(input as DeriveInput);
189
190 match derive_graph_state_impl(input) {
191 Ok(tokens) => TokenStream::from(tokens),
192 Err(err) => TokenStream::from(err.to_compile_error()),
193 }
194}
195
196fn derive_graph_state_impl(input: DeriveInput) -> SynResult<TokenStream2> {
197 let name = &input.ident;
198 let generics = &input.generics;
199 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
200
201 let fields = match &input.data {
203 Data::Struct(data_struct) => match &data_struct.fields {
204 Fields::Named(fields_named) => &fields_named.named,
205 Fields::Unnamed(_) => {
206 return Err(Error::new_spanned(
207 &input.ident,
208 "GraphState can only be derived for structs with named fields",
209 ));
210 }
211 Fields::Unit => {
212 return Err(Error::new_spanned(
213 &input.ident,
214 "GraphState cannot be derived for unit structs",
215 ));
216 }
217 },
218 Data::Enum(_) => {
219 return Err(Error::new_spanned(
220 &input.ident,
221 "GraphState can only be derived for structs, not enums",
222 ));
223 }
224 Data::Union(_) => {
225 return Err(Error::new_spanned(
226 &input.ident,
227 "GraphState can only be derived for structs, not unions",
228 ));
229 }
230 };
231
232 let mut field_configs = Vec::new();
234 for field in fields {
235 if let Some(config) = FieldConfig::from_field(field)? {
236 field_configs.push(config);
237 }
238 }
239
240 let channel_insertions: Vec<TokenStream2> = field_configs
242 .iter()
243 .map(|config| {
244 let field_name = config.name.to_string();
245 let channel_spec = config.to_channel_spec_tokens();
246 quote! {
247 channels.insert(#field_name.to_string(), #channel_spec);
248 }
249 })
250 .collect();
251
252 let field_name_literals: Vec<TokenStream2> = field_configs
254 .iter()
255 .map(|config| {
256 let field_name = config.name.to_string();
257 quote! { #field_name }
258 })
259 .collect();
260
261 let expanded = quote! {
262 impl #impl_generics regula_core::GraphState for #name #ty_generics #where_clause {
263 fn channels() -> std::collections::HashMap<String, regula_core::ChannelSpec> {
264 let mut channels = std::collections::HashMap::new();
265 #(#channel_insertions)*
266 channels
267 }
268
269 fn field_names() -> Vec<&'static str> {
270 vec![#(#field_name_literals),*]
271 }
272 }
273 };
274
275 Ok(expanded)
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `FieldConfig::from_field` on the first field of a struct and returns its channel type.
    fn first_field_channel(input: DeriveInput) -> SynResult<ChannelType> {
        let fields = match input.data {
            Data::Struct(data) => data.fields,
            _ => panic!("test input must be a struct"),
        };
        let field = fields
            .iter()
            .next()
            .expect("test struct must have at least one field");
        let config = FieldConfig::from_field(field)?.expect("named fields always yield a config");
        Ok(config.channel_type)
    }

    #[test]
    fn test_channel_type_default() {
        assert_eq!(ChannelType::default(), ChannelType::LastValue);
    }

    #[test]
    fn unannotated_field_is_last_value() {
        let input: DeriveInput = syn::parse_quote! { struct S { a: i32 } };
        assert_eq!(first_field_channel(input).unwrap(), ChannelType::LastValue);
    }

    #[test]
    fn reducer_attribute_selects_builtin_reducers() {
        let append: DeriveInput = syn::parse_quote! { struct S { #[reducer(append)] a: Vec<i32> } };
        assert_eq!(first_field_channel(append).unwrap(), ChannelType::Append);

        let add: DeriveInput = syn::parse_quote! { struct S { #[reducer(add)] a: i32 } };
        assert_eq!(first_field_channel(add).unwrap(), ChannelType::Add);
    }

    #[test]
    fn reducer_attribute_accepts_custom_function_name() {
        let input: DeriveInput = syn::parse_quote! { struct S { #[reducer(merge_items)] a: i32 } };
        assert_eq!(
            first_field_channel(input).unwrap(),
            ChannelType::Custom("merge_items".to_string())
        );
    }

    #[test]
    fn channel_attribute_selects_channel_kind() {
        let ephemeral: DeriveInput = syn::parse_quote! { struct S { #[channel(ephemeral)] a: i32 } };
        assert_eq!(first_field_channel(ephemeral).unwrap(), ChannelType::Ephemeral);

        let any: DeriveInput = syn::parse_quote! { struct S { #[channel(any_value)] a: i32 } };
        assert_eq!(first_field_channel(any).unwrap(), ChannelType::AnyValue);

        let last: DeriveInput = syn::parse_quote! { struct S { #[channel(last_value)] a: i32 } };
        assert_eq!(first_field_channel(last).unwrap(), ChannelType::LastValue);
    }

    #[test]
    fn unknown_channel_kind_is_rejected() {
        let input: DeriveInput = syn::parse_quote! { struct S { #[channel(bogus)] a: i32 } };
        let err = first_field_channel(input).unwrap_err();
        assert!(err.to_string().contains("unknown channel type"));
    }

    #[test]
    fn non_struct_inputs_are_rejected() {
        let as_enum: DeriveInput = syn::parse_quote! { enum E { A } };
        assert!(derive_graph_state_impl(as_enum).is_err());

        let as_tuple: DeriveInput = syn::parse_quote! { struct T(i32); };
        assert!(derive_graph_state_impl(as_tuple).is_err());

        let as_unit: DeriveInput = syn::parse_quote! { struct U; };
        assert!(derive_graph_state_impl(as_unit).is_err());
    }

    #[test]
    fn named_struct_expands_successfully() {
        let input: DeriveInput = syn::parse_quote! {
            struct State {
                #[reducer(append)]
                messages: Vec<String>,
                count: i32,
            }
        };
        assert!(derive_graph_state_impl(input).is_ok());
    }
}