Skip to main content

flowgentra_ai_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Data, DeriveInput, Fields, ItemFn};
4
5/// Derive the `State` trait for a struct.
6///
7/// Generates:
8/// - A `{Name}Update` struct with all fields wrapped in `Option<T>`
9/// - Builder methods on the update struct for ergonomic partial updates
10/// - `State` trait impl with `type Update` and `apply_update` using per-field reducers
11///
12/// # Attributes
13///
14/// Two equivalent syntaxes are supported for configuring per-field reducers:
15///
16/// **Ident syntax** — `#[reducer(Kind)]`:
17/// - `Overwrite` (default) — replaces the value
18/// - `Append` — extends `Vec<T>` fields
19/// - `Sum` — adds numeric fields
20/// - `MergeMap` — merges `HashMap` fields
21///
22/// **String syntax** — `#[state(reducer = "kind")]` (LangGraph-style):
23/// - `"overwrite"` / `"replace"` / `"last_value"` → `Overwrite`
24/// - `"append"` / `"topic"` → `Append`
25/// - `"sum"` → `Sum`
26/// - `"merge_map"` / `"merge"` → `MergeMap`
27///
28/// # Example
29///
30/// ```ignore
31/// use flowgentra_ai::prelude::*;
32/// use serde::{Serialize, Deserialize};
33///
34/// #[derive(State, Clone, Debug, Serialize, Deserialize)]
35/// struct AgentState {
36///     query: String,
37///
38///     // Ident syntax
39///     #[reducer(Append)]
40///     messages: Vec<Message>,
41///
42///     result: Option<String>,
43///
44///     // String syntax
45///     #[state(reducer = "sum")]
46///     retry_count: i32,
47/// }
48///
49/// // Nodes return partial updates:
50/// let update = AgentStateUpdate::new()
51///     .result(Some("done".into()));
52/// ```
53#[proc_macro_derive(State, attributes(reducer, state))]
54pub fn derive_state(input: TokenStream) -> TokenStream {
55    let input = parse_macro_input!(input as DeriveInput);
56    let struct_name = &input.ident;
57    let vis = &input.vis;
58
59    // Use the absolute crate path. `crate::` in generated code resolves to the
60    // *calling* crate (where the derive is used), not to `flowgentra_ai` itself,
61    // so we must use `::flowgentra_ai::` to reference library types correctly.
62    let flowgentra_ai_crate = quote! { ::flowgentra_ai };
63
64    let fields = match &input.data {
65        Data::Struct(data_struct) => match &data_struct.fields {
66            Fields::Named(fields_named) => &fields_named.named,
67            _ => panic!("#[derive(State)] requires a struct with named fields"),
68        },
69        _ => panic!("#[derive(State)] requires a struct"),
70    };
71
72    let update_name = format_ident!("{}Update", struct_name);
73
74    let mut update_field_defs = Vec::new();
75    let mut update_defaults = Vec::new();
76    let mut setters = Vec::new();
77    let mut apply_arms = Vec::new();
78
79    for field in fields {
80        let name = field.ident.as_ref().unwrap();
81        let ty = &field.ty;
82
83        // Parse reducer from either #[reducer(Kind)] or #[state(reducer = "kind")].
84        // Default to Overwrite.
85        let mut reducer_path = quote! { #flowgentra_ai_crate::core::reducer::Overwrite };
86        for attr in &field.attrs {
87            // ── #[reducer(Kind)] ─────────────────────────────────────────────
88            if attr.path().is_ident("reducer") {
89                if let Ok(ident) = attr.parse_args::<syn::Ident>() {
90                    reducer_path = quote! { #flowgentra_ai_crate::core::reducer::#ident };
91                }
92                break;
93            }
94            // ── #[state(reducer = "string")] ─────────────────────────────────
95            if attr.path().is_ident("state") {
96                // parse as a key = "value" meta
97                let _ = attr.parse_nested_meta(|meta| {
98                    if meta.path.is_ident("reducer") {
99                        let value: syn::LitStr = meta.value()?.parse()?;
100                        let kind = match value.value().to_lowercase().as_str() {
101                            "append" | "topic" => "Append",
102                            "sum" => "Sum",
103                            "merge_map" | "merge" => "MergeMap",
104                            // "overwrite" | "replace" | "last_value" | _ → Overwrite (default)
105                            _ => "Overwrite",
106                        };
107                        let kind_ident = format_ident!("{}", kind);
108                        reducer_path = quote! { #flowgentra_ai_crate::core::reducer::#kind_ident };
109                    }
110                    Ok(())
111                });
112                break;
113            }
114        }
115
116        // Update struct: all fields are Option<T>
117        update_field_defs.push(quote! {
118            #[serde(default, skip_serializing_if = "Option::is_none")]
119            pub #name: Option<#ty>
120        });
121        update_defaults.push(quote! { #name: None });
122
123        // Builder setter
124        setters.push(quote! {
125            pub fn #name(mut self, value: #ty) -> Self {
126                self.#name = Some(value);
127                self
128            }
129        });
130
131        // apply_update arm: apply reducer if field is Some
132        apply_arms.push(quote! {
133            if let Some(value) = update.#name {
134                <#reducer_path as #flowgentra_ai_crate::core::reducer::Reducer<#ty>>::merge(
135                    &mut self.#name,
136                    value,
137                );
138            }
139        });
140    }
141
142    let expanded = quote! {
143        /// Partial update struct — all fields are `Option<T>`.
144        ///
145        /// Nodes return this to indicate which fields changed.
146        /// Only `Some` fields are applied via their configured reducer.
147        #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
148        #vis struct #update_name {
149            #(#update_field_defs,)*
150        }
151
152        impl #update_name {
153            /// Create an empty update (all fields `None`).
154            pub fn new() -> Self {
155                Self {
156                    #(#update_defaults,)*
157                }
158            }
159
160            #(#setters)*
161        }
162
163        impl #flowgentra_ai_crate::core::state::State for #struct_name {
164            type Update = #update_name;
165
166            fn apply_update(&mut self, update: Self::Update) {
167                #(#apply_arms)*
168            }
169        }
170    };
171
172    TokenStream::from(expanded)
173}
174
175/// Register a handler function for automatic discovery.
176///
177/// This attribute macro registers the decorated async function to a global inventory,
178/// enabling automatic discovery when creating agents with `from_config_path()`.
179///
180/// The handler name in the inventory is automatically derived from the function name.
181/// Make sure your `config.yaml` references the handler by its function name.
182///
183/// # Requirements
184///
185/// Decorated function must be:
186/// - `pub async fn`
187/// - Takes `(&S, &Context)` parameters
188/// - Returns `Result<S::Update>`
189///
190/// # Crate name requirement
191///
192/// The generated code references `::flowgentra_ai` as an absolute crate path.
193/// Your `Cargo.toml` dependency **must** be named `flowgentra-ai` (the default).
194/// If you alias it (e.g. `flowgentra_ai = { package = "flowgentra-ai" }`), the
195/// generated code will still work because the crate name in code is `flowgentra_ai`.
196///
197/// # Example
198///
199/// ```ignore
200/// #[register_handler]
201/// pub async fn validate_input(state: &MyState, ctx: &Context) -> Result<MyStateUpdate> {
202///     Ok(MyStateUpdate::new().valid(!state.input.is_empty()))
203/// }
204/// ```
205#[proc_macro_attribute]
206pub fn register_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
207    let input = parse_macro_input!(item as ItemFn);
208
209    let func_name = &input.sig.ident;
210    let handler_name = func_name.to_string();
211
212    let expanded = quote! {
213        #input
214
215        inventory::submit! {
216            // Use the absolute path (leading `::`) so this compiles correctly
217            // regardless of how the calling crate imports `flowgentra_ai`.
218            ::flowgentra_ai::core::agent::HandlerEntry::new(
219                #handler_name,
220                ::std::sync::Arc::new(|state, ctx| ::std::boxed::Box::pin(#func_name(state, ctx)))
221            )
222        }
223    };
224
225    TokenStream::from(expanded)
226}
227
228/// Attribute macro to turn an `async fn` into a `FunctionNode`.
229///
230/// Generates a factory function `{name}_node()` that returns
231/// `Arc<FunctionNode<S, _>>` for use with `StateGraphBuilder::add_node`.
232///
233/// If the function has **no return type**, the macro infers it from the first
234/// parameter (`state: &MyState` → `Result<MyStateUpdate>`) and wraps the last
235/// expression in `Ok(...)` automatically.
236///
237/// # Example — minimal form (no return type, no `Ok`)
238///
239/// ```ignore
240/// use flowgentra_ai_macros::node;
241///
242/// #[node]
243/// async fn summarize(state: &MyState, _ctx: &Context) {
244///     update! { summary: "done".into() }
245/// }
246///
247/// // Use: graph.add_node("summarize", summarize_node())
248/// ```
249///
250/// # Example — explicit form (full control)
251///
252/// ```ignore
253/// #[node]
254/// async fn summarize(state: &MyState, _ctx: &Context) -> Result<MyStateUpdate> {
255///     Ok(update! { summary: "done".into() })
256/// }
257/// ```
258#[proc_macro_attribute]
259pub fn node(_attr: TokenStream, item: TokenStream) -> TokenStream {
260    let mut input = parse_macro_input!(item as ItemFn);
261    let func_name = &input.sig.ident;
262    let node_factory_name = format_ident!("{}_node", func_name);
263    let node_name_str = func_name.to_string();
264
265    // If no return type is written, infer it from `state: &MyState` → `Result<MyStateUpdate>`
266    // and wrap the last expression in `Ok(...)`.
267    if matches!(input.sig.output, syn::ReturnType::Default) {
268        if let Some(update_ident) = extract_update_ident(&input.sig.inputs) {
269            input.sig.output = syn::parse_quote! {
270                -> flowgentra_ai::core::state_graph::error::Result<#update_ident>
271            };
272
273            // Wrap the tail expression (last stmt with no semicolon) in Ok(...)
274            if let Some(syn::Stmt::Expr(expr, None)) = input.block.stmts.last_mut() {
275                let inner = expr.clone();
276                *expr = syn::parse_quote! { Ok(#inner) };
277            }
278        }
279    }
280
281    let expanded = quote! {
282        #input
283
284        /// Auto-generated node factory for use with `StateGraphBuilder::add_node`.
285        pub fn #node_factory_name() -> ::std::sync::Arc<
286            ::flowgentra_ai::core::state_graph::node::FunctionNode<
287                _,
288                impl Fn(&_, &::flowgentra_ai::core::state::Context) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::flowgentra_ai::core::state_graph::error::Result<_>> + Send>> + Send + Sync,
289            >
290        > {
291            ::std::sync::Arc::new(::flowgentra_ai::core::state_graph::node::FunctionNode::new(
292                #node_name_str,
293                |state, ctx| ::std::boxed::Box::pin(#func_name(state, ctx)),
294            ))
295        }
296    };
297
298    TokenStream::from(expanded)
299}
300
301/// Extract `MyStateUpdate` ident from `state: &MyState` (first function parameter).
302fn extract_update_ident(
303    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
304) -> Option<syn::Ident> {
305    let first = inputs.first()?;
306    let syn::FnArg::Typed(pat_type) = first else {
307        return None;
308    };
309    let syn::Type::Reference(type_ref) = pat_type.ty.as_ref() else {
310        return None;
311    };
312    let syn::Type::Path(type_path) = type_ref.elem.as_ref() else {
313        return None;
314    };
315    let state_ident = &type_path.path.segments.last()?.ident;
316    Some(format_ident!("{}Update", state_ident))
317}