regula_macros/
lib.rs

1//! REGULA Macros - Procedural macros for the REGULA framework.
2//!
3//! This crate provides the `#[derive(GraphState)]` macro for automatically
4//! implementing the `GraphState` trait on structs.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! use regula_macros::GraphState;
10//! use serde::{Serialize, Deserialize};
11//!
12//! #[derive(Clone, GraphState, Serialize, Deserialize)]
13//! struct MyState {
14//!     messages: Vec<String>,
15//!     
16//!     #[reducer(append)]
17//!     history: Vec<String>,
18//! }
19//! ```
20
21use proc_macro::TokenStream;
22use proc_macro2::TokenStream as TokenStream2;
23use quote::quote;
24use syn::{
25    parse_macro_input, Attribute, Data, DeriveInput, Error, Fields, Ident,
26    Result as SynResult,
27};
28
/// Channel type specification parsed from a field's `#[reducer(...)]` or
/// `#[channel(...)]` attribute.
///
/// Fields carrying neither attribute use [`ChannelType::LastValue`], the
/// derived default.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
enum ChannelType {
    /// Default: last value semantics.
    #[default]
    LastValue,
    /// Append reducer for `Vec` types.
    Append,
    /// Add reducer for numeric types.
    Add,
    /// Ephemeral: cleared after each step.
    Ephemeral,
    /// Any value: last writer wins.
    AnyValue,
    /// Custom reducer, identified by the name written in the attribute.
    Custom(String),
}
51
52/// Parse a single field and extract its channel configuration.
53struct FieldConfig {
54    name: Ident,
55    channel_type: ChannelType,
56}
57
58impl FieldConfig {
59    fn from_field(field: &syn::Field) -> SynResult<Option<Self>> {
60        let name = match &field.ident {
61            Some(ident) => ident.clone(),
62            None => return Ok(None), // Skip unnamed fields (tuple structs)
63        };
64
65        let mut channel_type = ChannelType::default();
66
67        // Process attributes
68        for attr in &field.attrs {
69            if attr.path().is_ident("reducer") {
70                channel_type = Self::parse_reducer_attr(attr)?;
71            } else if attr.path().is_ident("channel") {
72                channel_type = Self::parse_channel_attr(attr)?;
73            }
74        }
75
76        Ok(Some(Self { name, channel_type }))
77    }
78
79    fn parse_reducer_attr(attr: &Attribute) -> SynResult<ChannelType> {
80        let mut result = ChannelType::LastValue;
81        
82        attr.parse_nested_meta(|meta| {
83            if meta.path.is_ident("append") {
84                result = ChannelType::Append;
85            } else if meta.path.is_ident("add") {
86                result = ChannelType::Add;
87            } else {
88                // Custom reducer function name
89                let fn_name = meta
90                    .path
91                    .get_ident()
92                    .map(|i| i.to_string())
93                    .unwrap_or_else(|| "custom".to_string());
94                result = ChannelType::Custom(fn_name);
95            }
96            Ok(())
97        })?;
98
99        Ok(result)
100    }
101
102    fn parse_channel_attr(attr: &Attribute) -> SynResult<ChannelType> {
103        let mut result = ChannelType::LastValue;
104
105        attr.parse_nested_meta(|meta| {
106            if meta.path.is_ident("ephemeral") {
107                result = ChannelType::Ephemeral;
108            } else if meta.path.is_ident("last_value") {
109                result = ChannelType::LastValue;
110            } else if meta.path.is_ident("any_value") {
111                result = ChannelType::AnyValue;
112            } else if meta.path.is_ident("append") {
113                result = ChannelType::Append;
114            } else if meta.path.is_ident("add") {
115                result = ChannelType::Add;
116            } else {
117                return Err(meta.error(
118                    "unknown channel type. Use: ephemeral, last_value, any_value, append, or add",
119                ));
120            }
121            Ok(())
122        })?;
123
124        Ok(result)
125    }
126
127    fn to_channel_spec_tokens(&self) -> TokenStream2 {
128        match &self.channel_type {
129            ChannelType::LastValue => {
130                quote! { regula_core::ChannelSpec::LastValue }
131            }
132            ChannelType::Append => {
133                quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Append) }
134            }
135            ChannelType::Add => {
136                quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Add) }
137            }
138            ChannelType::Ephemeral => {
139                quote! { regula_core::ChannelSpec::Ephemeral }
140            }
141            ChannelType::AnyValue => {
142                quote! { regula_core::ChannelSpec::AnyValue }
143            }
144            ChannelType::Custom(name) => {
145                quote! { regula_core::ChannelSpec::Reducer(regula_core::channel::ReducerType::Custom(#name.to_string())) }
146            }
147        }
148    }
149}
150
151/// Derive macro for implementing `GraphState` on structs.
152///
153/// This macro generates the `channels()` method based on the struct fields
154/// and any `#[reducer(...)]` or `#[channel(...)]` attributes.
155///
156/// # Attributes
157///
158/// - `#[reducer(append)]`: Use append reducer for Vec fields
159/// - `#[reducer(add)]`: Use add reducer for numeric fields
160/// - `#[reducer(fn_name)]`: Use a custom reducer function
161/// - `#[channel(ephemeral)]`: Mark field as ephemeral (cleared each step)
162/// - `#[channel(last_value)]`: Use last value semantics (default)
163/// - `#[channel(any_value)]`: Allow multiple writes, last writer wins
164///
165/// # Example
166///
167/// ```ignore
168/// #[derive(Clone, GraphState, Serialize, Deserialize)]
169/// struct AgentState {
170///     /// Messages in the conversation (last value semantics)
171///     messages: Vec<Message>,
172///     
173///     /// Tool call history (appended across steps)
174///     #[reducer(append)]
175///     tool_calls: Vec<ToolCall>,
176///     
177///     /// Running total
178///     #[reducer(add)]
179///     total: i32,
180///     
181///     /// Temporary scratch space (cleared each step)
182///     #[channel(ephemeral)]
183///     scratch: Option<String>,
184/// }
185/// ```
186#[proc_macro_derive(GraphState, attributes(reducer, channel))]
187pub fn derive_graph_state(input: TokenStream) -> TokenStream {
188    let input = parse_macro_input!(input as DeriveInput);
189
190    match derive_graph_state_impl(input) {
191        Ok(tokens) => TokenStream::from(tokens),
192        Err(err) => TokenStream::from(err.to_compile_error()),
193    }
194}
195
196fn derive_graph_state_impl(input: DeriveInput) -> SynResult<TokenStream2> {
197    let name = &input.ident;
198    let generics = &input.generics;
199    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
200
201    // Ensure it's a struct with named fields
202    let fields = match &input.data {
203        Data::Struct(data_struct) => match &data_struct.fields {
204            Fields::Named(fields_named) => &fields_named.named,
205            Fields::Unnamed(_) => {
206                return Err(Error::new_spanned(
207                    &input.ident,
208                    "GraphState can only be derived for structs with named fields",
209                ));
210            }
211            Fields::Unit => {
212                return Err(Error::new_spanned(
213                    &input.ident,
214                    "GraphState cannot be derived for unit structs",
215                ));
216            }
217        },
218        Data::Enum(_) => {
219            return Err(Error::new_spanned(
220                &input.ident,
221                "GraphState can only be derived for structs, not enums",
222            ));
223        }
224        Data::Union(_) => {
225            return Err(Error::new_spanned(
226                &input.ident,
227                "GraphState can only be derived for structs, not unions",
228            ));
229        }
230    };
231
232    // Parse each field
233    let mut field_configs = Vec::new();
234    for field in fields {
235        if let Some(config) = FieldConfig::from_field(field)? {
236            field_configs.push(config);
237        }
238    }
239
240    // Generate channel insertions
241    let channel_insertions: Vec<TokenStream2> = field_configs
242        .iter()
243        .map(|config| {
244            let field_name = config.name.to_string();
245            let channel_spec = config.to_channel_spec_tokens();
246            quote! {
247                channels.insert(#field_name.to_string(), #channel_spec);
248            }
249        })
250        .collect();
251
252    // Generate field names for field_names() method
253    let field_name_literals: Vec<TokenStream2> = field_configs
254        .iter()
255        .map(|config| {
256            let field_name = config.name.to_string();
257            quote! { #field_name }
258        })
259        .collect();
260
261    let expanded = quote! {
262        impl #impl_generics regula_core::GraphState for #name #ty_generics #where_clause {
263            fn channels() -> std::collections::HashMap<String, regula_core::ChannelSpec> {
264                let mut channels = std::collections::HashMap::new();
265                #(#channel_insertions)*
266                channels
267            }
268
269            fn field_names() -> Vec<&'static str> {
270                vec![#(#field_name_literals),*]
271            }
272        }
273    };
274
275    Ok(expanded)
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the first field of a struct `DeriveInput`.
    fn first_field(input: &DeriveInput) -> &syn::Field {
        match &input.data {
            Data::Struct(data_struct) => data_struct
                .fields
                .iter()
                .next()
                .expect("test struct must declare at least one field"),
            _ => panic!("test input must be a struct"),
        }
    }

    /// Parses the first field of `input` and returns its channel type.
    fn channel_type_of(input: DeriveInput) -> ChannelType {
        FieldConfig::from_field(first_field(&input))
            .expect("attribute should parse")
            .expect("field should be named")
            .channel_type
    }

    #[test]
    fn test_channel_type_default() {
        assert_eq!(ChannelType::default(), ChannelType::LastValue);
    }

    #[test]
    fn field_without_attribute_uses_last_value() {
        let input: DeriveInput = syn::parse_quote! {
            struct S { value: i32 }
        };
        assert_eq!(channel_type_of(input), ChannelType::LastValue);
    }

    #[test]
    fn reducer_attributes_are_parsed() {
        let append: DeriveInput = syn::parse_quote! {
            struct S { #[reducer(append)] items: Vec<i32> }
        };
        assert_eq!(channel_type_of(append), ChannelType::Append);

        let add: DeriveInput = syn::parse_quote! {
            struct S { #[reducer(add)] total: i32 }
        };
        assert_eq!(channel_type_of(add), ChannelType::Add);

        let custom: DeriveInput = syn::parse_quote! {
            struct S { #[reducer(merge_fn)] value: i32 }
        };
        assert_eq!(
            channel_type_of(custom),
            ChannelType::Custom("merge_fn".to_string())
        );
    }

    #[test]
    fn channel_attributes_are_parsed() {
        let ephemeral: DeriveInput = syn::parse_quote! {
            struct S { #[channel(ephemeral)] scratch: Option<String> }
        };
        assert_eq!(channel_type_of(ephemeral), ChannelType::Ephemeral);

        let any_value: DeriveInput = syn::parse_quote! {
            struct S { #[channel(any_value)] value: i32 }
        };
        assert_eq!(channel_type_of(any_value), ChannelType::AnyValue);
    }

    #[test]
    fn unknown_channel_type_is_rejected() {
        let attr: Attribute = syn::parse_quote!(#[channel(bogus)]);
        assert!(FieldConfig::parse_channel_attr(&attr).is_err());
    }

    #[test]
    fn unsupported_shapes_are_rejected() {
        let enum_input: DeriveInput = syn::parse_quote! { enum E { A } };
        assert!(derive_graph_state_impl(enum_input).is_err());

        let tuple_input: DeriveInput = syn::parse_quote! { struct T(i32); };
        assert!(derive_graph_state_impl(tuple_input).is_err());

        let unit_input: DeriveInput = syn::parse_quote! { struct U; };
        assert!(derive_graph_state_impl(unit_input).is_err());
    }

    #[test]
    fn named_struct_expands() {
        let input: DeriveInput = syn::parse_quote! {
            struct S { #[reducer(append)] items: Vec<i32> }
        };
        let tokens = derive_graph_state_impl(input)
            .expect("expansion should succeed")
            .to_string();
        assert!(tokens.contains("GraphState"));
        assert!(tokens.contains("\"items\""));
    }
}