use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Expr, Fields, FnArg,
GenericArgument, Ident, Item, ItemFn, ItemMod, Lit, Meta, MetaList, MetaNameValue, Pat, Path,
PathArguments, ReturnType, Signature, Token, Type, TypePath,
};
/// Attribute macro that turns an `async fn` into a `causal` reactor.
///
/// The attribute arguments are parsed by `parse_effect_args` and the function is expanded by
/// `expand_effect`, which re-emits the function plus a hidden `__causal_effect_<name>()`
/// builder. Any parse or validation failure becomes a `compile_error!` in the caller's crate.
#[proc_macro_attribute]
pub fn reactor(attr: TokenStream, item: TokenStream) -> TokenStream {
    let options = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let function = parse_macro_input!(item as ItemFn);
    let parsed_args = parse_effect_args(&options);
    expand_effect(parsed_args, function)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Attribute macro for an inline module of reactors and projections.
///
/// Delegates to `expand_effects_module`, which adds `handles()`, `reactors()` and, when the
/// module contains projections, `projections()` registry functions to the module.
/// The attribute arguments are ignored.
#[proc_macro_attribute]
pub fn reactors(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let mut inline_mod = parse_macro_input!(item as ItemMod);
    expand_effects_module(&mut inline_mod)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Attribute macro for a single aggregator function.
///
/// Parses the comma-separated attribute options and hands them, together with the annotated
/// function, to `expand_aggregator` (defined elsewhere in this crate). Errors are reported as
/// compile errors at the call site.
#[proc_macro_attribute]
pub fn aggregator(attr: TokenStream, item: TokenStream) -> TokenStream {
    let options = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let function = parse_macro_input!(item as ItemFn);
    let outcome = expand_aggregator(&options, function);
    match outcome {
        Err(error) => error.to_compile_error().into(),
        Ok(expanded) => expanded.into(),
    }
}
/// Attribute macro for an inline module of aggregators.
///
/// Module-level options are parsed from the attribute and passed, with the module itself, to
/// `expand_aggregators_module` (defined elsewhere in this crate). Errors are reported as
/// compile errors at the call site.
#[proc_macro_attribute]
pub fn aggregators(attr: TokenStream, item: TokenStream) -> TokenStream {
    let options = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let mut inline_mod = parse_macro_input!(item as ItemMod);
    let outcome = expand_aggregators_module(&options, &mut inline_mod);
    match outcome {
        Err(error) => error.to_compile_error().into(),
        Ok(expanded) => expanded.into(),
    }
}
/// Attribute macro that turns an `async fn` into a `causal` projection.
///
/// Delegates to `expand_projection`, which re-emits the function plus a hidden
/// `__causal_projection_<name>()` builder. Errors become compile errors at the call site.
#[proc_macro_attribute]
pub fn projection(attr: TokenStream, item: TokenStream) -> TokenStream {
    let options = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let function = parse_macro_input!(item as ItemFn);
    expand_projection(&options, function)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Parsed form of the `on = ...` reactor option (produced by `parse_on_expr`).
#[derive(Clone)]
enum OnSpec {
    /// `on = EventType` — a single concrete event type.
    EventType(Path),
    /// `on = [Enum::A, Enum::B]` — paths that all share the same enum base path.
    Variants(Vec<Path>),
    /// `on = [TypeA, TypeB]` — distinct event types; expanded into one reactor per type.
    MultiType(Vec<Path>),
}
impl OnSpec {
fn event_type(&self) -> syn::Result<Path> {
match self {
OnSpec::EventType(path) => Ok(path.clone()),
OnSpec::Variants(variants) => {
let Some(first) = variants.first() else {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"on = [] is not supported",
));
};
let base = variant_base_path(first)?;
let base_key = path_key(&base);
for variant in variants.iter().skip(1) {
let candidate = variant_base_path(variant)?;
if path_key(&candidate) != base_key {
return Err(syn::Error::new_spanned(
variant,
"all variants in on = [...] must belong to the same enum",
));
}
}
Ok(base)
}
OnSpec::MultiType(_) => Err(syn::Error::new(
proc_macro2::Span::call_site(),
"multi-type on = [...] has no single event type",
)),
}
}
}
/// Options accepted by `#[reactor(...)]`, as collected by `parse_effect_args`.
#[derive(Default)]
struct EffectArgs {
    /// `on = ...`: which event type(s) or enum variants trigger the reactor.
    on: Option<OnSpec>,
    /// Bare `on_any` flag; mutually exclusive with `on`.
    on_any: bool,
    /// Field names from `extract(a, b, ...)`; empty when `extract` was not given.
    extract: Vec<Ident>,
    /// `filter = path`, emitted as `.filter(path)` on the builder.
    filter: Option<Path>,
    /// Bare `queued` flag; emits `.retry(1)` and requires a stable id.
    queued: bool,
    /// `id = "..."`: explicit reactor id.
    id: Option<String>,
    /// `dlq_terminal = path`, called from `.on_failure` with a clone of the event and the failure info.
    dlq_terminal: Option<Path>,
    /// `retry = N`, emitted as `.retry(N)`.
    retry: Option<u32>,
    /// `timeout_secs = N`; mutually exclusive with `timeout_ms`.
    timeout_secs: Option<u64>,
    /// `timeout_ms = N`; mutually exclusive with `timeout_secs`.
    timeout_ms: Option<u64>,
    /// `delay_secs = N`; mutually exclusive with `delay_ms`.
    delay_secs: Option<u64>,
    /// `delay_ms = N`; mutually exclusive with `delay_secs`.
    delay_ms: Option<u64>,
    /// `priority = N`, emitted as `.priority(N)`.
    priority: Option<i32>,
    /// `group = "..."`: when `id` is absent, the id becomes `"{group}::{fn_name}"`.
    group: Option<String>,
    /// `aggregate = Type`; must be given together with `transition`.
    aggregate: Option<Path>,
    /// `transition = |prev, next| ...`, passed to `.transition::<Aggregate, _>(...)`.
    transition: Option<syn::ExprClosure>,
    /// `describe = path`, emitted as `.describe(path)`; requires `filter`.
    describe: Option<Path>,
}
/// A simple `ident: Type` function parameter, as extracted by `collect_params`.
struct ParamInfo {
    /// The parameter's binding name.
    ident: Ident,
    /// The parameter's declared type.
    ty: Type,
}
/// Builds the expression that turns the user function's `__result` binding into
/// `::causal::Events`, according to how its return type was classified.
fn result_to_events(kind: &ReturnKind) -> TokenStream2 {
    match kind {
        // Nothing to emit.
        ReturnKind::Unit => quote! { ::causal::Events::new() },
        // Already an `Events` collection: pass it through untouched.
        ReturnKind::Events => quote! { __result },
        // Anything implementing `IntoEvents` is converted through the trait.
        ReturnKind::Emit => quote! { ::causal::IntoEvents::into_events(__result) },
        // A lone event is wrapped into a fresh one-element `Events`.
        ReturnKind::SingleEvent => quote! {
            {
                let mut __ev = ::causal::Events::new();
                __ev.push(__result);
                __ev
            }
        },
    }
}
/// Expands a `#[reactor]`-annotated async function.
///
/// Emits the original function unchanged plus a hidden `__causal_effect_<name>()` function that
/// builds a `::causal::Reactor<Deps>`. Multi-type `on = [TypeA, TypeB]` is delegated to
/// `expand_multi_type_effect`, whose wrapper returns a `Vec` of reactors instead.
///
/// Supported signatures (always plus exactly one `ctx: Context<Deps>` parameter):
/// * `on_any`: one `AnyEvent` parameter.
/// * `on = EventType` without `extract`: one parameter of type `EventType`. The builder receives
///   an `Arc<EventType>`, which is cloned out before the call.
/// * `extract(a, b, ...)`: parameters named exactly like the extracted fields, in order.
/// * `aggregate = ..., transition = ...`: exactly one extracted field and one parameter.
///
/// # Errors
/// Returns a `syn::Error` for non-async or generic functions, invalid option combinations, or
/// signatures that do not match the chosen shape.
fn expand_effect(args: syn::Result<EffectArgs>, input_fn: ItemFn) -> syn::Result<TokenStream2> {
    let args = args?;
    let fn_ident = input_fn.sig.ident.clone();
    let wrapper_ident = format_ident!("__causal_effect_{}", fn_ident);
    if input_fn.sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(
            &input_fn.sig.fn_token,
            "#[reactor] requires an async function",
        ));
    }
    if !input_fn.sig.generics.params.is_empty() {
        return Err(syn::Error::new_spanned(
            &input_fn.sig.generics,
            "#[reactor] does not support generic functions",
        ));
    }
    let return_kind = classify_return(&input_fn.sig)?;
    let convert_result = result_to_events(&return_kind);

    // Option combinations that can be rejected before looking at the signature.
    if args.timeout_secs.is_some() && args.timeout_ms.is_some() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "cannot specify both timeout_secs and timeout_ms",
        ));
    }
    if args.delay_secs.is_some() && args.delay_ms.is_some() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "cannot specify both delay_secs and delay_ms",
        ));
    }
    if args.filter.is_some() && !args.extract.is_empty() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "filter and extract are mutually exclusive (both consume the filter type-state slot)",
        ));
    }
    if args.describe.is_some() && args.filter.is_none() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "describe requires filter (describe is only available on FilteredReactorBuilder)",
        ));
    }
    // queued / retry > 1 / timeout / delay need an explicit id, or a group to derive one from.
    if effect_requires_stable_id(&args) && args.id.is_none() && args.group.is_none() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "queued/durable #[reactor] requires an explicit id = \"...\" (or group = \"...\")",
        ));
    }

    if args.on_any {
        if args.on.is_some() {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "on_any and on = ... are mutually exclusive",
            ));
        }
        if !args.extract.is_empty()
            || args.filter.is_some()
            || args.transition.is_some()
            || args.aggregate.is_some()
        {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "on_any cannot be combined with extract, filter, transition, or aggregate",
            ));
        }
        let (ctx_idx, deps_ty) = find_effect_context(&input_fn.sig)?;
        let params = collect_params(&input_fn.sig)?;
        let non_ctx_params: Vec<ParamInfo> = params
            .into_iter()
            .enumerate()
            .filter_map(|(idx, param)| if idx == ctx_idx { None } else { Some(param) })
            .collect();
        if non_ctx_params.len() != 1 {
            return Err(syn::Error::new_spanned(
                &input_fn.sig.inputs,
                "on_any reactor requires exactly one AnyEvent parameter plus Context",
            ));
        }
        let event_ident = &non_ctx_params[0].ident;
        let input_builder = quote! { ::causal::on_any() };
        let builder = apply_on_any_config(input_builder, &args, &fn_ident);
        let chain = quote! {
            #builder
            .then::<#deps_ty, _, _>(|#event_ident, __causal_ctx| async move {
                let __result = #fn_ident(#event_ident, __causal_ctx).await?;
                Ok(#convert_result)
            })
        };
        return Ok(quote! {
            #input_fn
            #[doc(hidden)]
            pub fn #wrapper_ident() -> ::causal::Reactor<#deps_ty> {
                #chain
            }
        });
    }

    let on = args.on.clone().ok_or_else(|| {
        syn::Error::new(
            proc_macro2::Span::call_site(),
            "#[reactor] requires on = EventType, on = [Enum::Variant, ...], or on_any",
        )
    })?;
    if let OnSpec::MultiType(ref types) = on {
        return expand_multi_type_effect(&args, &input_fn, types, &return_kind);
    }
    let on_event_type = on.event_type()?;
    let (ctx_idx, deps_ty) = find_effect_context(&input_fn.sig)?;
    let params = collect_params(&input_fn.sig)?;
    let non_ctx_params: Vec<ParamInfo> = params
        .into_iter()
        .enumerate()
        .filter_map(|(idx, param)| if idx == ctx_idx { None } else { Some(param) })
        .collect();
    let input_builder = quote! { ::causal::on::<#on_event_type>() };
    let builder = apply_effect_config(input_builder, &args, &fn_ident);
    if args.aggregate.is_some() != args.transition.is_some() {
        return Err(syn::Error::new(
            proc_macro2::Span::call_site(),
            "aggregate and transition must be specified together",
        ));
    }

    let chain = if let (Some(aggregate_ty), Some(transition_closure)) =
        (&args.aggregate, &args.transition)
    {
        if args.extract.len() != 1 {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "transition reactors require extract() with exactly one Uuid field (the aggregate ID)",
            ));
        }
        if non_ctx_params.len() != 1 {
            return Err(syn::Error::new_spanned(
                &input_fn.sig.inputs,
                "transition reactor function takes (aggregate_id: Uuid, ctx: Context<Deps>)",
            ));
        }
        let param_ident = &non_ctx_params[0].ident;
        let extract_call = effect_extract_call(&on, &args.extract);
        quote! {
            #builder
            #extract_call
            .transition::<#aggregate_ty, _>(#transition_closure)
            .then::<#deps_ty, _, _>(|#param_ident, __causal_ctx| async move {
                let __result = #fn_ident(#param_ident, __causal_ctx).await?;
                Ok(#convert_result)
            })
        }
    } else if !args.extract.is_empty() {
        if non_ctx_params.len() != args.extract.len() {
            return Err(syn::Error::new_spanned(
                &input_fn.sig.inputs,
                "effect function parameters must match extract(...) fields",
            ));
        }
        // Parameters are bound positionally to the extracted fields, so names must line up.
        for (param, extracted) in non_ctx_params.iter().zip(args.extract.iter()) {
            if param.ident != *extracted {
                return Err(syn::Error::new_spanned(
                    &param.ident,
                    format!(
                        "parameter `{}` must match extracted field `{}`",
                        param.ident, extracted
                    ),
                ));
            }
        }
        let fields = &args.extract;
        let extract_call = effect_extract_call(&on, fields);
        if fields.len() == 1 {
            let field = &fields[0];
            let extracted_ty = &non_ctx_params[0].ty;
            quote! {
                #builder
                #extract_call
                .then::<#deps_ty, #extracted_ty, _, _>(|#field, __causal_ctx| async move {
                    let __result = #fn_ident(#field, __causal_ctx).await?;
                    Ok(#convert_result)
                })
            }
        } else {
            let extracted_tys: Vec<&Type> = non_ctx_params.iter().map(|param| &param.ty).collect();
            quote! {
                #builder
                #extract_call
                .then::<#deps_ty, (#(#extracted_tys),*), _, _>(|(#(#fields),*), __causal_ctx| async move {
                    let __result = #fn_ident(#(#fields),*, __causal_ctx).await?;
                    Ok(#convert_result)
                })
            }
        }
    } else {
        if non_ctx_params.len() != 1 {
            return Err(syn::Error::new_spanned(
                &input_fn.sig.inputs,
                "#[reactor] requires exactly one event parameter plus Context when extract(...) is not used",
            ));
        }
        if !matches!(on, OnSpec::EventType(_)) {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "on = [...] requires extract(...)",
            ));
        }
        let event_param = &non_ctx_params[0];
        if type_key(&event_param.ty) != path_key(&on_event_type) {
            return Err(syn::Error::new_spanned(
                &event_param.ty,
                "event parameter type must match on = EventType",
            ));
        }
        let event_ident = &event_param.ident;
        if args.filter.is_some() {
            // Filtered builders take fully-annotated closures instead of turbofish arguments.
            quote! {
                #builder
                .then(|#event_ident: ::std::sync::Arc<#on_event_type>, __causal_ctx: ::causal::Context<#deps_ty>| async move {
                    let __result = #fn_ident((#event_ident).as_ref().clone(), __causal_ctx).await?;
                    Ok(#convert_result)
                })
            }
        } else {
            quote! {
                #builder
                .then::<#deps_ty, ::std::sync::Arc<#on_event_type>, _, _>(|#event_ident, __causal_ctx| async move {
                    let __result = #fn_ident((#event_ident).as_ref().clone(), __causal_ctx).await?;
                    Ok(#convert_result)
                })
            }
        }
    };
    Ok(quote! {
        #input_fn
        #[doc(hidden)]
        pub fn #wrapper_ident() -> ::causal::Reactor<#deps_ty> {
            #chain
        }
    })
}

/// Builds the `.extract(...)` builder call for `fields` under the given `on = ...` spec.
///
/// For `on = EventType` the closure reads the fields directly off the event. For
/// `on = [Enum::Variant, ...]` it destructures each listed variant and yields `None` for any
/// other variant. One field yields `Some(field)`; several yield `Some((a, b, ...))`.
fn effect_extract_call(on: &OnSpec, fields: &[Ident]) -> TokenStream2 {
    match on {
        OnSpec::EventType(_) => {
            if let [field] = fields {
                quote! {
                    .extract(|__causal_event| Some(__causal_event.#field.clone()))
                }
            } else {
                quote! {
                    .extract(|__causal_event| Some((#(__causal_event.#fields.clone()),*)))
                }
            }
        }
        OnSpec::Variants(variants) => {
            let match_arms = variants.iter().map(|variant| {
                if let [field] = fields {
                    quote! {
                        #variant { #field, .. } => Some(#field.clone()),
                    }
                } else {
                    quote! {
                        #variant { #(#fields),*, .. } => Some((#(#fields.clone()),*)),
                    }
                }
            });
            quote! {
                .extract(|__causal_event| match __causal_event {
                    #(#match_arms)*
                    #[allow(unreachable_patterns)]
                    _ => None,
                })
            }
        }
        OnSpec::MultiType(_) => unreachable!("multi-type validated earlier"),
    }
}
fn expand_multi_type_effect(
args: &EffectArgs,
input_fn: &ItemFn,
types: &[Path],
return_kind: &ReturnKind,
) -> syn::Result<TokenStream2> {
let fn_ident = &input_fn.sig.ident;
let wrapper_ident = format_ident!("__causal_effect_{}", fn_ident);
let convert_result = result_to_events(return_kind);
if !args.extract.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"extract is not supported with multi-type on = [TypeA, TypeB]",
));
}
if args.aggregate.is_some() || args.transition.is_some() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"aggregate/transition is not supported with multi-type on = [TypeA, TypeB]",
));
}
if args.describe.is_some() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"describe is not supported with multi-type on = [TypeA, TypeB]",
));
}
let (ctx_idx, deps_ty) = find_effect_context(&input_fn.sig)?;
let params = collect_params(&input_fn.sig)?;
let non_ctx_params: Vec<ParamInfo> = params
.into_iter()
.enumerate()
.filter_map(|(idx, param)| if idx == ctx_idx { None } else { Some(param) })
.collect();
if non_ctx_params.len() != 1 {
return Err(syn::Error::new_spanned(
&input_fn.sig.inputs,
"multi-type reactor requires exactly one AnyEvent parameter plus Context",
));
}
let event_ident = &non_ctx_params[0].ident;
let args_no_filter = EffectArgs {
on: args.on.clone(),
on_any: args.on_any,
extract: args.extract.clone(),
filter: None,
queued: args.queued,
id: args.id.clone(),
dlq_terminal: args.dlq_terminal.clone(),
retry: args.retry,
timeout_secs: args.timeout_secs,
timeout_ms: args.timeout_ms,
delay_secs: args.delay_secs,
delay_ms: args.delay_ms,
priority: args.priority,
group: args.group.clone(),
aggregate: args.aggregate.clone(),
transition: None, describe: None,
};
let handler_exprs: Vec<TokenStream2> = types
.iter()
.map(|event_type| {
let type_suffix = event_type.segments.last()
.map(|s| s.ident.to_string())
.unwrap_or_default();
let mut per_type_args = EffectArgs {
on: args_no_filter.on.clone(),
on_any: args_no_filter.on_any,
extract: args_no_filter.extract.clone(),
filter: None,
queued: args_no_filter.queued,
id: args_no_filter.id.as_ref().map(|id| format!("{id}::{type_suffix}")),
dlq_terminal: args_no_filter.dlq_terminal.clone(),
retry: args_no_filter.retry,
timeout_secs: args_no_filter.timeout_secs,
timeout_ms: args_no_filter.timeout_ms,
delay_secs: args_no_filter.delay_secs,
delay_ms: args_no_filter.delay_ms,
priority: args_no_filter.priority,
group: args_no_filter.group.clone(),
aggregate: None,
transition: None,
describe: None,
};
if per_type_args.id.is_none() && per_type_args.group.is_none() {
per_type_args.id = Some(format!("{}::{}", fn_ident, type_suffix));
}
let base_builder = quote! { ::causal::on::<#event_type>() };
let builder = apply_effect_config(base_builder, &per_type_args, fn_ident);
if let Some(filter_fn) = &args.filter {
quote! {
#builder
.filter(|_: &#event_type, __causal_ctx: &::causal::Context<#deps_ty>| #filter_fn(__causal_ctx))
.then(|#event_ident: ::std::sync::Arc<#event_type>, __causal_ctx: ::causal::Context<#deps_ty>| async move {
let #event_ident = ::causal::AnyEvent {
value: #event_ident as ::std::sync::Arc<dyn ::std::any::Any + Send + Sync>,
type_id: ::std::any::TypeId::of::<#event_type>(),
};
let __result = #fn_ident(#event_ident, __causal_ctx).await?;
Ok(#convert_result)
})
}
} else {
quote! {
#builder
.then::<#deps_ty, ::std::sync::Arc<#event_type>, _, _>(|#event_ident, __causal_ctx| async move {
let #event_ident = ::causal::AnyEvent {
value: #event_ident as ::std::sync::Arc<dyn ::std::any::Any + Send + Sync>,
type_id: ::std::any::TypeId::of::<#event_type>(),
};
let __result = #fn_ident(#event_ident, __causal_ctx).await?;
Ok(#convert_result)
})
}
}
})
.collect();
Ok(quote! {
#input_fn
#[doc(hidden)]
pub fn #wrapper_ident() -> ::std::vec::Vec<::causal::Reactor<#deps_ty>> {
::std::vec![#(#handler_exprs),*]
}
})
}
fn expand_effects_module(module: &mut ItemMod) -> syn::Result<TokenStream2> {
let Some((_, items)) = &mut module.content else {
return Err(syn::Error::new_spanned(
module,
"#[reactors] requires an inline module",
));
};
let mut wrappers = Vec::new();
let mut multi_wrappers = Vec::new(); let mut projection_wrappers = Vec::new();
let mut deps_ty: Option<Type> = None;
let mut inferred_fns = Vec::new();
for item in items.iter() {
let Item::Fn(item_fn) = item else {
continue;
};
let has_handle_attr = has_attr_any(&item_fn.attrs, &["handle", "reactor"]);
let has_projection_attr = has_attr(&item_fn.attrs, "projection");
if has_projection_attr {
let wrapper_ident = format_ident!("__causal_projection_{}", item_fn.sig.ident);
projection_wrappers.push(wrapper_ident);
} else if has_handle_attr {
let wrapper_ident = format_ident!("__causal_effect_{}", item_fn.sig.ident);
if is_multi_type_handle(&item_fn.attrs) {
multi_wrappers.push(wrapper_ident);
} else {
wrappers.push(wrapper_ident);
}
} else if item_fn.sig.asyncness.is_some() {
let wrapper_ident = format_ident!("__causal_effect_{}", item_fn.sig.ident);
wrappers.push(wrapper_ident);
let fn_name = item_fn.sig.ident.to_string();
let (ctx_idx, _) = find_effect_context(&item_fn.sig)?;
let params = collect_params(&item_fn.sig)?;
let non_ctx_params: Vec<&ParamInfo> = params
.iter()
.enumerate()
.filter_map(|(idx, p)| if idx == ctx_idx { None } else { Some(p) })
.collect();
if non_ctx_params.len() != 1 {
return Err(syn::Error::new_spanned(
&item_fn.sig.inputs,
"bare reactor function requires exactly one event parameter plus Context",
));
}
let event_ty = &non_ctx_params[0].ty;
let event_path: Path = syn::parse2(quote! { #event_ty })?;
let mut args = EffectArgs::default();
args.on = Some(OnSpec::EventType(event_path));
args.id = Some(fn_name);
let expanded = expand_effect(Ok(args), item_fn.clone())?;
inferred_fns.push((item_fn.sig.ident.to_string(), expanded));
}
if let Ok((_, deps)) = find_effect_context(&item_fn.sig) {
match &deps_ty {
None => deps_ty = Some(deps),
Some(existing_deps) => {
if type_key(existing_deps) != type_key(&deps) {
return Err(syn::Error::new_spanned(
&item_fn.sig,
"all reactors in a #[reactors] module must use the same Context<Deps>",
));
}
}
}
}
}
if wrappers.is_empty() && multi_wrappers.is_empty() && projection_wrappers.is_empty() {
return Err(syn::Error::new_spanned(
module,
"#[reactors] module must contain at least one reactor or projection function",
));
}
let inferred_names: Vec<&str> = inferred_fns.iter().map(|(n, _)| n.as_str()).collect();
items.retain(|item| {
if let Item::Fn(item_fn) = item {
!inferred_names.contains(&item_fn.sig.ident.to_string().as_str())
} else {
true
}
});
let inferred_tokens: Vec<&TokenStream2> = inferred_fns.iter().map(|(_, t)| t).collect();
let deps_ty = deps_ty.expect("checked above");
let handles_fn: ItemFn = if multi_wrappers.is_empty() {
parse_quote! {
pub fn handles() -> ::std::vec::Vec<::causal::Reactor<#deps_ty>> {
::std::vec![#(#wrappers()),*]
}
}
} else {
parse_quote! {
pub fn handles() -> ::std::vec::Vec<::causal::Reactor<#deps_ty>> {
let mut __h = ::std::vec![#(#wrappers()),*];
#(__h.extend(#multi_wrappers());)*
__h
}
}
};
let handlers_fn: ItemFn = parse_quote! {
pub fn reactors() -> ::std::vec::Vec<::causal::Reactor<#deps_ty>> {
handles()
}
};
items.push(Item::Fn(handles_fn));
items.push(Item::Fn(handlers_fn));
if !projection_wrappers.is_empty() {
let projections_fn: ItemFn = parse_quote! {
pub fn projections() -> ::std::vec::Vec<::causal::Projection<#deps_ty>> {
::std::vec![#(#projection_wrappers()),*]
}
};
items.push(Item::Fn(projections_fn));
}
let expanded = quote! { #module };
Ok(quote! {
#expanded
#(#inferred_tokens)*
})
}
/// Parses the comma-separated `#[reactor(...)]` options into an `EffectArgs`.
///
/// Every option may appear at most once. Cross-option validation (mutual exclusions, required
/// pairs) happens later in `expand_effect`.
///
/// # Errors
/// Unknown options, repeated options, malformed values, and the removed
/// `#[reactor(projection)]` form.
fn parse_effect_args(metas: &Punctuated<Meta, Token![,]>) -> syn::Result<EffectArgs> {
    let mut args = EffectArgs::default();
    for meta in metas {
        match meta {
            Meta::NameValue(nv) if nv.path.is_ident("on") => {
                ensure_unset(&args.on, nv, "on")?;
                args.on = Some(parse_on_expr(&nv.value)?);
            }
            Meta::List(list) if list.path.is_ident("extract") => {
                // `parse_extract_fields` never returns an empty list, so non-empty means "seen".
                if !args.extract.is_empty() {
                    return Err(syn::Error::new_spanned(
                        list,
                        "extract(...) specified more than once",
                    ));
                }
                args.extract = parse_extract_fields(list)?;
            }
            Meta::Path(path) if path.is_ident("projection") => {
                return Err(syn::Error::new_spanned(path, "#[reactor(projection)] is removed in v0.20.0. Use #[projection] instead."));
            }
            Meta::Path(path) if path.is_ident("on_any") => {
                if args.on_any {
                    return Err(syn::Error::new_spanned(
                        path,
                        "on_any specified more than once",
                    ));
                }
                args.on_any = true;
            }
            Meta::Path(path) if path.is_ident("queued") => {
                if args.queued {
                    return Err(syn::Error::new_spanned(
                        path,
                        "queued specified more than once",
                    ));
                }
                args.queued = true;
            }
            Meta::NameValue(nv) if nv.path.is_ident("id") => {
                ensure_unset(&args.id, nv, "id")?;
                args.id = Some(parse_string_lit(&nv.value)?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("dlq_terminal") => {
                ensure_unset(&args.dlq_terminal, nv, "dlq_terminal")?;
                args.dlq_terminal = Some(parse_path_expr(&nv.value, "dlq_terminal")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("retry") => {
                ensure_unset(&args.retry, nv, "retry")?;
                args.retry = Some(parse_int_lit::<u32>(&nv.value, "retry")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("timeout_secs") => {
                ensure_unset(&args.timeout_secs, nv, "timeout_secs")?;
                args.timeout_secs = Some(parse_int_lit::<u64>(&nv.value, "timeout_secs")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("timeout_ms") => {
                ensure_unset(&args.timeout_ms, nv, "timeout_ms")?;
                args.timeout_ms = Some(parse_int_lit::<u64>(&nv.value, "timeout_ms")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("delay_secs") => {
                ensure_unset(&args.delay_secs, nv, "delay_secs")?;
                args.delay_secs = Some(parse_int_lit::<u64>(&nv.value, "delay_secs")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("delay_ms") => {
                ensure_unset(&args.delay_ms, nv, "delay_ms")?;
                args.delay_ms = Some(parse_int_lit::<u64>(&nv.value, "delay_ms")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("priority") => {
                ensure_unset(&args.priority, nv, "priority")?;
                args.priority = Some(parse_int_lit::<i32>(&nv.value, "priority")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("group") => {
                ensure_unset(&args.group, nv, "group")?;
                args.group = Some(parse_string_lit(&nv.value)?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("aggregate") => {
                ensure_unset(&args.aggregate, nv, "aggregate")?;
                args.aggregate = Some(parse_path_expr(&nv.value, "aggregate")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("filter") => {
                ensure_unset(&args.filter, nv, "filter")?;
                args.filter = Some(parse_path_expr(&nv.value, "filter")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("describe") => {
                ensure_unset(&args.describe, nv, "describe")?;
                args.describe = Some(parse_path_expr(&nv.value, "describe")?);
            }
            Meta::NameValue(nv) if nv.path.is_ident("transition") => {
                ensure_unset(&args.transition, nv, "transition")?;
                args.transition = Some(parse_closure_expr(&nv.value)?);
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    meta,
                    "unsupported #[reactor] option",
                ));
            }
        }
    }
    Ok(args)
}
fn parse_on_expr(expr: &Expr) -> syn::Result<OnSpec> {
let expr = strip_expr_groups(expr);
match expr {
Expr::Path(expr_path) => Ok(OnSpec::EventType(expr_path.path.clone())),
Expr::Array(expr_array) => {
let paths = parse_path_array(expr_array)?;
if are_enum_variants(&paths) {
Ok(OnSpec::Variants(paths))
} else {
Ok(OnSpec::MultiType(paths))
}
}
Expr::Binary(binary) if matches!(binary.op, syn::BinOp::BitOr(_)) => {
let left = parse_on_expr(&binary.left)?;
let right = parse_on_expr(&binary.right)?;
match (left, right) {
(OnSpec::EventType(_), OnSpec::Variants(variants))
| (OnSpec::Variants(variants), OnSpec::EventType(_)) => {
Ok(OnSpec::Variants(variants))
}
_ => Err(syn::Error::new_spanned(
binary,
"expected on = EventType | [Enum::Variant, ...]",
)),
}
}
_ => Err(syn::Error::new_spanned(
expr,
"expected on = EventType, on = [Enum::Variant, ...], or on = [TypeA, TypeB]",
)),
}
}
/// Heuristic: `true` when `paths` is non-empty, every path has at least two segments, and all
/// paths share the same prefix (the path minus its last segment), i.e. they look like variants
/// of one enum.
fn are_enum_variants(paths: &[Path]) -> bool {
    let Some((first, rest)) = paths.split_first() else {
        return false;
    };
    if paths.iter().any(|p| p.segments.len() < 2) {
        return false;
    }
    match variant_base_path(first) {
        Ok(base) => {
            let base_key = path_key(&base);
            rest.iter()
                .all(|p| matches!(variant_base_path(p), Ok(b) if path_key(&b) == base_key))
        }
        Err(_) => false,
    }
}
fn parse_path_array(array: &syn::ExprArray) -> syn::Result<Vec<Path>> {
let mut variants = Vec::new();
for element in &array.elems {
let element = strip_expr_groups(element);
match element {
Expr::Path(path) => variants.push(path.path.clone()),
_ => {
return Err(syn::Error::new_spanned(
element,
"on = [...] expects type paths",
));
}
}
}
if variants.is_empty() {
return Err(syn::Error::new_spanned(array, "on = [] is not supported"));
}
Ok(variants)
}
/// Parses `extract(a, b, ...)` into its field identifiers, preserving order.
///
/// # Errors
/// Non-identifier tokens, an empty list, or a field named more than once. A repeated field
/// would otherwise emit duplicate bindings in the generated match pattern and closure,
/// producing a confusing error far from the attribute.
fn parse_extract_fields(list: &MetaList) -> syn::Result<Vec<Ident>> {
    let parser = Punctuated::<Ident, Token![,]>::parse_terminated;
    let idents = parser.parse2(list.tokens.clone())?;
    if idents.is_empty() {
        return Err(syn::Error::new_spanned(
            list,
            "extract(...) requires at least one field",
        ));
    }
    let mut fields: Vec<Ident> = Vec::with_capacity(idents.len());
    for ident in idents {
        if fields.contains(&ident) {
            return Err(syn::Error::new_spanned(
                &ident,
                format!("field `{ident}` is extracted more than once"),
            ));
        }
        fields.push(ident);
    }
    Ok(fields)
}
fn parse_string_lit(expr: &Expr) -> syn::Result<String> {
match strip_expr_groups(expr) {
Expr::Lit(expr_lit) => match &expr_lit.lit {
Lit::Str(value) => Ok(value.value()),
_ => Err(syn::Error::new_spanned(expr, "expected a string literal")),
},
_ => Err(syn::Error::new_spanned(expr, "expected a string literal")),
}
}
fn parse_int_lit<T>(expr: &Expr, label: &str) -> syn::Result<T>
where
T: std::str::FromStr,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
match strip_expr_groups(expr) {
Expr::Lit(expr_lit) => match &expr_lit.lit {
Lit::Int(value) => value
.base10_parse::<T>()
.map_err(|err| syn::Error::new_spanned(expr, format!("invalid {label}: {err}"))),
_ => Err(syn::Error::new_spanned(
expr,
format!("expected integer for {label}"),
)),
},
_ => Err(syn::Error::new_spanned(
expr,
format!("expected integer for {label}"),
)),
}
}
fn parse_path_expr(expr: &Expr, label: &str) -> syn::Result<Path> {
match strip_expr_groups(expr) {
Expr::Path(expr_path) => Ok(expr_path.path.clone()),
_ => Err(syn::Error::new_spanned(
expr,
format!("expected path for {label}, e.g. module::reactor"),
)),
}
}
fn parse_closure_expr(expr: &Expr) -> syn::Result<syn::ExprClosure> {
match strip_expr_groups(expr) {
Expr::Closure(closure) => Ok(closure.clone()),
_ => Err(syn::Error::new_spanned(
expr,
"expected a closure, e.g. |prev, next| prev.x > 0 && next.x == 0",
)),
}
}
/// Locates the single `Context<Deps>` parameter in `sig`.
///
/// Returns its position among `sig.inputs` together with the `Deps` type.
///
/// # Errors
/// When no parameter is a `Context<Deps>`, or when more than one is (the error points at the
/// second one).
fn find_effect_context(sig: &Signature) -> syn::Result<(usize, Type)> {
    let mut contexts = sig
        .inputs
        .iter()
        .enumerate()
        .filter_map(|(index, input)| match input {
            FnArg::Typed(typed) => {
                parse_effect_context_type(&typed.ty).map(|deps| (index, deps, typed))
            }
            FnArg::Receiver(_) => None,
        });
    let Some((index, deps, _)) = contexts.next() else {
        return Err(syn::Error::new_spanned(
            &sig.inputs,
            "effect reactor must include ctx: Context<Deps>",
        ));
    };
    if let Some((_, _, duplicate)) = contexts.next() {
        return Err(syn::Error::new_spanned(
            &duplicate.ty,
            "multiple Context parameters are not supported",
        ));
    }
    Ok((index, deps))
}
/// If `ty` is a path whose last segment is `Context<T>` with exactly one type argument,
/// returns `T`.
///
/// Matching is by the last segment's name only, so `causal::Context<D>` and `Context<D>` both
/// match. Non-type generic arguments such as lifetimes are ignored when counting.
fn parse_effect_context_type(ty: &Type) -> Option<Type> {
    let Type::Path(TypePath { path, .. }) = ty else {
        return None;
    };
    let segment = path.segments.last().filter(|seg| seg.ident == "Context")?;
    let PathArguments::AngleBracketed(generic) = &segment.arguments else {
        return None;
    };
    let mut type_args = generic.args.iter().filter_map(|arg| match arg {
        GenericArgument::Type(inner) => Some(inner),
        _ => None,
    });
    match (type_args.next(), type_args.next()) {
        (Some(only), None) => Some(only.clone()),
        _ => None,
    }
}
fn collect_params(sig: &Signature) -> syn::Result<Vec<ParamInfo>> {
let mut params = Vec::new();
for input in &sig.inputs {
let FnArg::Typed(typed) = input else {
return Err(syn::Error::new_spanned(
input,
"methods with self are not supported",
));
};
let Pat::Ident(pat_ident) = typed.pat.as_ref() else {
return Err(syn::Error::new_spanned(
&typed.pat,
"parameter patterns are not supported; use simple identifiers",
));
};
params.push(ParamInfo {
ident: pat_ident.ident.clone(),
ty: (*typed.ty).clone(),
});
}
Ok(params)
}
/// Strips the last segment from a variant path (`a::Enum::Variant` → `a::Enum`).
///
/// # Errors
/// When the path has fewer than two segments.
fn variant_base_path(path: &Path) -> syn::Result<Path> {
    let segment_count = path.segments.len();
    if segment_count < 2 {
        return Err(syn::Error::new_spanned(
            path,
            "expected variant path like Enum::Variant",
        ));
    }
    let mut base = path.clone();
    base.segments = path
        .segments
        .iter()
        .take(segment_count - 1)
        .cloned()
        .collect();
    Ok(base)
}
/// Appends the builder calls implied by `args` to `base`, in a fixed order.
///
/// The order is `.id`, `.on_failure`, `.retry(1)` (for `queued`), `.retry(n)`, `.timeout`,
/// `.delayed`, `.priority`, `.filter`, `.describe`. The id is `args.id`, or
/// `"{group}::{fn_ident}"` when only a group is set; with neither, no `.id` call is emitted.
fn apply_effect_config(base: TokenStream2, args: &EffectArgs, fn_ident: &Ident) -> TokenStream2 {
    let mut calls: Vec<TokenStream2> = Vec::new();

    let resolved_id = match (&args.id, &args.group) {
        (Some(id), _) => Some(id.clone()),
        (None, Some(group)) => Some(format!("{group}::{}", fn_ident)),
        (None, None) => None,
    };
    if let Some(id) = resolved_id {
        let id_lit = syn::LitStr::new(&id, fn_ident.span());
        calls.push(quote! { .id(#id_lit) });
    }

    if let Some(dlq_terminal) = &args.dlq_terminal {
        // The failure hook receives the event as an `Arc`; the user's handler gets a clone.
        calls.push(quote! {
            .on_failure(|__causal_source: ::std::sync::Arc<_>, __causal_info| {
                #dlq_terminal((__causal_source).as_ref().clone(), __causal_info)
            })
        });
    }

    if args.queued {
        calls.push(quote! { .retry(1) });
    }
    if let Some(retry) = args.retry {
        calls.push(quote! { .retry(#retry) });
    }

    if let Some(secs) = args.timeout_secs {
        calls.push(quote! { .timeout(::std::time::Duration::from_secs(#secs)) });
    }
    if let Some(millis) = args.timeout_ms {
        calls.push(quote! { .timeout(::std::time::Duration::from_millis(#millis)) });
    }
    if let Some(secs) = args.delay_secs {
        calls.push(quote! { .delayed(::std::time::Duration::from_secs(#secs)) });
    }
    if let Some(millis) = args.delay_ms {
        calls.push(quote! { .delayed(::std::time::Duration::from_millis(#millis)) });
    }

    if let Some(priority) = args.priority {
        calls.push(quote! { .priority(#priority) });
    }
    if let Some(filter_fn) = &args.filter {
        calls.push(quote! { .filter(#filter_fn) });
    }
    if let Some(describe_fn) = &args.describe {
        calls.push(quote! { .describe(#describe_fn) });
    }

    quote! { #base #(#calls)* }
}
/// Appends the only builder call used for `on_any` reactors: `.id(...)`.
///
/// The id is `args.id`, or `"{group}::{fn_ident}"` when only a group is set. With neither,
/// `base` is returned unchanged. No other option is applied here.
fn apply_on_any_config(base: TokenStream2, args: &EffectArgs, fn_ident: &Ident) -> TokenStream2 {
    let resolved_id = match (&args.id, &args.group) {
        (Some(id), _) => Some(id.clone()),
        (None, Some(group)) => Some(format!("{group}::{}", fn_ident)),
        (None, None) => None,
    };
    match resolved_id {
        Some(id) => {
            let id_lit = syn::LitStr::new(&id, fn_ident.span());
            quote! { #base .id(#id_lit) }
        }
        None => base,
    }
}
fn expand_projection(
metas: &Punctuated<Meta, Token![,]>,
input_fn: ItemFn,
) -> syn::Result<TokenStream2> {
if input_fn.sig.asyncness.is_none() {
return Err(syn::Error::new_spanned(
&input_fn.sig.fn_token,
"#[projection] requires an async function",
));
}
let mut id: Option<String> = None;
let mut priority: Option<i32> = None;
for meta in metas {
match meta {
Meta::NameValue(nv) if nv.path.is_ident("id") => {
id = Some(parse_string_lit(&nv.value)?);
}
Meta::NameValue(nv) if nv.path.is_ident("priority") => {
if let Expr::Lit(lit) = &nv.value {
if let Lit::Int(int_lit) = &lit.lit {
priority = Some(int_lit.base10_parse()?);
} else {
return Err(syn::Error::new_spanned(&nv.value, "expected integer literal"));
}
} else {
return Err(syn::Error::new_spanned(&nv.value, "expected integer literal"));
}
}
other => {
return Err(syn::Error::new_spanned(
other,
"#[projection] only supports id = \"...\" and priority = N",
));
}
}
}
let fn_ident = input_fn.sig.ident.clone();
let wrapper_ident = format_ident!("__causal_projection_{}", fn_ident);
let (ctx_idx, deps_ty) = find_effect_context(&input_fn.sig)?;
let params = collect_params(&input_fn.sig)?;
let non_ctx_params: Vec<ParamInfo> = params
.into_iter()
.enumerate()
.filter_map(|(idx, param)| if idx == ctx_idx { None } else { Some(param) })
.collect();
if non_ctx_params.len() != 1 {
return Err(syn::Error::new_spanned(
&input_fn.sig.inputs,
"#[projection] reactor requires exactly one AnyEvent parameter plus Context",
));
}
let event_ident = &non_ctx_params[0].ident;
let id_str = id.unwrap_or_else(|| fn_ident.to_string());
let id_lit = syn::LitStr::new(&id_str, fn_ident.span());
let mut builder = quote! { ::causal::project(#id_lit) };
if let Some(p) = priority {
builder = quote! { #builder .priority(#p) };
}
let chain = quote! {
#builder.then::<#deps_ty, _, _>(|#event_ident, __causal_ctx| async move {
#fn_ident(#event_ident, __causal_ctx).await
})
};
Ok(quote! {
#input_fn
#[doc(hidden)]
pub fn #wrapper_ident() -> ::causal::Projection<#deps_ty> {
#chain
}
})
}
/// Reports whether a reactor's configuration makes it durable, in which case
/// it needs a stable id.
///
/// That is the case when it is queued, has any delay or timeout configured,
/// or retries more than once.
fn effect_requires_stable_id(args: &EffectArgs) -> bool {
    let has_delay = args.delay_secs.is_some() || args.delay_ms.is_some();
    let has_timeout = args.timeout_secs.is_some() || args.timeout_ms.is_some();
    // An absent retry count is treated as a single attempt.
    let retries = matches!(args.retry, Some(attempts) if attempts > 1);
    args.queued || has_delay || has_timeout || retries
}
fn has_attr(attrs: &[Attribute], name: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(name))
}
/// Returns `true` if at least one of `names` is present as an attribute (see `has_attr`).
fn has_attr_any(attrs: &[Attribute], names: &[&str]) -> bool {
    for candidate in names {
        if has_attr(attrs, candidate) {
            return true;
        }
    }
    false
}
/// Returns `true` if a `#[handle(...)]` or `#[reactor(...)]` attribute has an
/// `on = ...` value that `parse_on_expr` classifies as `OnSpec::MultiType`.
///
/// Attributes whose arguments fail to parse, and `on` values that fail to
/// parse, are ignored.
fn is_multi_type_handle(attrs: &[Attribute]) -> bool {
    fn declares_multi_type(meta: &Meta) -> bool {
        match meta {
            Meta::NameValue(nv) if nv.path.is_ident("on") => {
                matches!(parse_on_expr(&nv.value), Ok(OnSpec::MultiType(_)))
            }
            _ => false,
        }
    }

    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("handle") || attr.path().is_ident("reactor"))
        .filter_map(|attr| {
            attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
                .ok()
        })
        .any(|metas| metas.iter().any(declares_multi_type))
}
fn path_key(path: &Path) -> String {
path.to_token_stream().to_string()
}
fn type_key(ty: &Type) -> String {
ty.to_token_stream().to_string()
}
/// Peels off any number of nested `Expr::Group` wrappers (invisible groups)
/// and returns the innermost expression.
fn strip_expr_groups(mut expr: &Expr) -> &Expr {
    while let Expr::Group(group) = expr {
        expr = &group.expr;
    }
    expr
}
fn ensure_unset<T>(existing: &Option<T>, meta: &MetaNameValue, name: &str) -> syn::Result<()> {
if existing.is_some() {
return Err(syn::Error::new_spanned(
meta,
format!("{name} specified more than once"),
));
}
Ok(())
}
/// Classification of a handler's return type, as computed by `classify_return`.
enum ReturnKind {
/// No return type at all, or `Result<()>`.
Unit,
/// `Result<T>` where the last path segment of `T` is `Events`.
Events,
/// `Result<T>` where the last path segment of `T` is `Emit`.
Emit,
/// Any other `Result<T>`; presumably `T` is a single event value — confirm with callers.
SingleEvent,
}
/// Classifies a handler signature's return type into a `ReturnKind`.
///
/// A missing return type, or `Result<()>`, is `Unit`. `Result<..::Events>` is
/// `Events`, `Result<..::Emit>` is `Emit`, and any other `Result<T>` is
/// `SingleEvent`.
///
/// # Errors
/// Propagates `result_ok_type`'s error when the declared return type is not
/// `Result<T>`.
fn classify_return(sig: &Signature) -> syn::Result<ReturnKind> {
    let output_ty = match &sig.output {
        ReturnType::Default => return Ok(ReturnKind::Unit),
        ReturnType::Type(_, ty) => ty,
    };
    let kind = match result_ok_type(output_ty)? {
        Type::Tuple(tuple) if tuple.elems.is_empty() => ReturnKind::Unit,
        Type::Path(type_path) => match type_path.path.segments.last() {
            Some(segment) if segment.ident == "Events" => ReturnKind::Events,
            Some(segment) if segment.ident == "Emit" => ReturnKind::Emit,
            _ => ReturnKind::SingleEvent,
        },
        _ => ReturnKind::SingleEvent,
    };
    Ok(kind)
}
fn result_ok_type(ty: &Type) -> syn::Result<Type> {
let Type::Path(type_path) = ty else {
return Err(syn::Error::new_spanned(ty, "expected return type Result<T>"));
};
let Some(last) = type_path.path.segments.last() else {
return Err(syn::Error::new_spanned(ty, "expected return type Result<T>"));
};
if last.ident != "Result" {
return Err(syn::Error::new_spanned(ty, "expected return type Result<T>"));
}
let PathArguments::AngleBracketed(args) = &last.arguments else {
return Err(syn::Error::new_spanned(ty, "expected return type Result<T>"));
};
for arg in &args.args {
if let GenericArgument::Type(inner) = arg {
return Ok(inner.clone());
}
}
Err(syn::Error::new_spanned(ty, "expected return type Result<T>"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse::Parser;

    /// Parses `tokens` as a comma-separated `Meta` list, the same way the
    /// `#[reactor]` attribute arguments are parsed.
    fn parse_effect_meta_list(tokens: TokenStream2) -> Punctuated<Meta, Token![,]> {
        let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
        parser
            .parse2(tokens)
            .expect("effect meta list should parse")
    }

    #[test]
    fn parse_effect_args_accepts_queued_for_backward_compat() {
        let metas = parse_effect_meta_list(quote!(on = MyEvent, queued));
        let args = parse_effect_args(&metas).expect("queued should parse (backward compat)");
        assert!(matches!(args.on, Some(OnSpec::EventType(_))));
    }

    #[test]
    fn apply_effect_config_does_not_emit_queued_builder_call() {
        let args = EffectArgs {
            retry: Some(3),
            ..EffectArgs::default()
        };
        let reactor_ident: Ident = syn::parse_quote!(my_effect_handler);
        let configured = apply_effect_config(
            quote!(::causal::on::<MyEvent>()),
            &args,
            &reactor_ident,
        );
        let configured_text = configured.to_string();
        assert!(
            !configured_text.contains(". queued ()"),
            "queued builder call should not be emitted, got: {}",
            configured_text
        );
        assert!(
            configured_text.contains(". retry (3u32)"),
            "retry builder call should be emitted, got: {}",
            configured_text
        );
    }

    #[test]
    fn parse_effect_args_rejects_delivery_option() {
        let metas = parse_effect_meta_list(quote!(on = MyEvent, delivery = "durable"));
        let error = parse_effect_args(&metas)
            .err()
            .expect("delivery option should remain unsupported");
        assert!(
            error.to_string().contains("unsupported #[reactor] option"),
            "unexpected error: {}",
            error
        );
    }

    #[test]
    fn stable_id_is_required_for_durable_effect_configs() {
        let durable = EffectArgs {
            retry: Some(3),
            ..EffectArgs::default()
        };
        assert!(effect_requires_stable_id(&durable));
    }

    #[test]
    fn apply_effect_config_emits_filter_builder_call() {
        let filter_path: Path = syn::parse_quote!(is_high_value);
        let args = EffectArgs {
            filter: Some(filter_path),
            ..EffectArgs::default()
        };
        let reactor_ident: Ident = syn::parse_quote!(my_effect_handler);
        let configured = apply_effect_config(
            quote!(::causal::on::<MyEvent>()),
            &args,
            &reactor_ident,
        );
        let configured_text = configured.to_string();
        assert!(
            configured_text.contains(". filter (is_high_value)"),
            "filter builder call should be emitted, got: {}",
            configured_text
        );
    }

    /// Only checks that a successful parse keeps both `filter` and `extract`.
    /// It does not assert whether that combination is accepted or rejected.
    /// NOTE(review): pin the intended outcome once it is decided.
    #[test]
    fn parse_effect_args_retains_filter_and_extract_when_combined() {
        let metas =
            parse_effect_meta_list(quote!(on = MyEvent, filter = my_filter, extract(field)));
        if let Ok(parsed) = parse_effect_args(&metas) {
            assert!(parsed.filter.is_some());
            assert!(!parsed.extract.is_empty());
        }
    }

    #[test]
    fn parse_effect_args_supports_on_any_flag() {
        let metas = parse_effect_meta_list(quote!(on_any, id = "logger"));
        let args = parse_effect_args(&metas).expect("on_any should parse");
        assert!(args.on_any);
        assert!(args.on.is_none());
        assert_eq!(args.id.as_deref(), Some("logger"));
    }

    /// Argument parsing accepts `on_any` together with `on = ...` and records
    /// both; it does not report a conflict.
    #[test]
    fn parse_effect_args_accepts_on_any_alongside_on() {
        let metas = parse_effect_meta_list(quote!(on_any, on = MyEvent));
        let args = parse_effect_args(&metas).expect("parsing should succeed");
        assert!(args.on_any);
        assert!(args.on.is_some());
    }
}
/// `#[derive(DistributedSafe)]`: implements the sealed `DistributedSafe`
/// marker for a struct whose fields pass the lock/cell check.
///
/// Individual fields may opt out with `#[allow_non_distributed]`.
#[proc_macro_derive(DistributedSafe, attributes(allow_non_distributed))]
pub fn derive_distributed_safe(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    derive_distributed_safe_impl(derive_input)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn derive_distributed_safe_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
let name = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
match &input.data {
Data::Struct(data) => {
validate_fields(&data.fields)?;
}
Data::Enum(_) => {
return Err(syn::Error::new_spanned(
name,
"DistributedSafe can only be derived for structs, not enums",
));
}
Data::Union(_) => {
return Err(syn::Error::new_spanned(
name,
"DistributedSafe can only be derived for structs, not unions",
));
}
}
Ok(quote! {
impl #impl_generics ::causal::distributed_safe::sealed::Sealed for #name #ty_generics #where_clause {}
impl #impl_generics ::causal::DistributedSafe for #name #ty_generics #where_clause {}
})
}
fn validate_fields(fields: &Fields) -> syn::Result<()> {
let fields_iter = match fields {
Fields::Named(fields) => fields.named.iter(),
Fields::Unnamed(fields) => fields.unnamed.iter(),
Fields::Unit => return Ok(()),
};
for field in fields_iter {
if has_attr(&field.attrs, "allow_non_distributed") {
continue;
}
if is_dangerous_type(&field.ty) {
return Err(syn::Error::new_spanned(
&field.ty,
format!(
"field type may not be distributed-safe (contains Arc<Mutex> or similar). \
Either: (1) use external storage (Database, Redis), \
(2) use event-threaded state, or \
(3) add #[allow_non_distributed] attribute to explicitly opt-out"
),
));
}
}
Ok(())
}
/// Reports whether a field type syntactically contains an in-process lock or cell.
///
/// A type is flagged when any path segment inside it is named `Mutex`,
/// `RwLock` or `RefCell` (checked per path by `type_contains_lock`). The search
/// recurses through:
/// - generic type arguments of every path segment, so `Arc<Mutex<T>>`,
///   `Option<Arc<Mutex<T>>>`, `Rc<RefCell<T>>` and `Vec<Mutex<T>>` are all caught;
/// - references and raw pointers;
/// - arrays, slices and tuples;
/// - parenthesised and invisible-group types.
///
/// Detection is purely syntactic: a type alias that hides a lock is not seen.
fn is_dangerous_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => {
            type_contains_lock(ty)
                || type_path.path.segments.iter().any(|segment| {
                    match &segment.arguments {
                        PathArguments::AngleBracketed(args) => args.args.iter().any(|arg| {
                            matches!(arg, GenericArgument::Type(inner) if is_dangerous_type(inner))
                        }),
                        _ => false,
                    }
                })
        }
        Type::Reference(reference) => is_dangerous_type(&reference.elem),
        Type::Ptr(pointer) => is_dangerous_type(&pointer.elem),
        Type::Array(array) => is_dangerous_type(&array.elem),
        Type::Slice(slice) => is_dangerous_type(&slice.elem),
        Type::Paren(paren) => is_dangerous_type(&paren.elem),
        Type::Group(group) => is_dangerous_type(&group.elem),
        Type::Tuple(tuple) => tuple.elems.iter().any(is_dangerous_type),
        _ => false,
    }
}
/// Returns `true` if `ty` is a path type and one of its own segments is named
/// `Mutex`, `RwLock` or `RefCell`. Generic arguments are not inspected.
fn type_contains_lock(ty: &Type) -> bool {
    const LOCK_NAMES: [&str; 3] = ["Mutex", "RwLock", "RefCell"];
    let Type::Path(type_path) = ty else {
        return false;
    };
    type_path
        .path
        .segments
        .iter()
        .any(|segment| LOCK_NAMES.iter().any(|lock| segment.ident == *lock))
}
/// How a generated aggregator derives the aggregate id from an event.
/// Selected by the `#[aggregator(...)]` / `#[aggregators(...)]` arguments.
enum IdAccess {
/// `id = "field"`: read `event.field` and convert it with `AggregatorIdValue::into_aggregator_id`.
Field(Ident),
/// `id_fn = "method"`: call `event.method()` and convert the result with `into_aggregator_id`.
Method(Ident),
/// `singleton`: every event maps to `Some(Uuid::nil())`.
Singleton,
/// Default when no id option is given: `Aggregator::for_type` with no custom id
/// extractor (presumably keyed by the fact's stream id — confirm in `causal`).
FactStreamId,
}
fn parse_aggregator_id_access(metas: &Punctuated<Meta, Token![,]>) -> syn::Result<IdAccess> {
for meta in metas {
if let Meta::Path(path) = meta {
if path.is_ident("singleton") {
return Ok(IdAccess::Singleton);
}
}
if let Meta::NameValue(nv) = meta {
if nv.path.is_ident("id") {
if let Expr::Lit(expr_lit) = &nv.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return Ok(IdAccess::Field(Ident::new(
&lit_str.value(),
lit_str.span(),
)));
}
}
return Err(syn::Error::new_spanned(
&nv.value,
"expected string literal for `id`, e.g. id = \"order_id\"",
));
}
if nv.path.is_ident("id_fn") {
if let Expr::Lit(expr_lit) = &nv.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return Ok(IdAccess::Method(Ident::new(
&lit_str.value(),
lit_str.span(),
)));
}
}
return Err(syn::Error::new_spanned(
&nv.value,
"expected string literal for `id_fn`, e.g. id_fn = \"run_id\"",
));
}
}
}
Ok(IdAccess::FactStreamId)
}
fn parse_aggregator_params(sig: &Signature) -> syn::Result<(Type, Ident, Type, Ident)> {
let params: Vec<_> = sig.inputs.iter().collect();
if params.len() != 2 {
return Err(syn::Error::new_spanned(
&sig.inputs,
"#[aggregator] function must have exactly 2 parameters: (agg: &mut Aggregate, event: Event)",
));
}
let (agg_ty, agg_ident) = match ¶ms[0] {
FnArg::Typed(pat_type) => {
let ident = match pat_type.pat.as_ref() {
Pat::Ident(pat_ident) => pat_ident.ident.clone(),
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected a simple identifier for aggregate parameter",
))
}
};
match pat_type.ty.as_ref() {
Type::Reference(type_ref) if type_ref.mutability.is_some() => {
(type_ref.elem.as_ref().clone(), ident)
}
_ => {
return Err(syn::Error::new_spanned(
&pat_type.ty,
"first parameter must be `&mut AggregateType`",
))
}
}
}
_ => {
return Err(syn::Error::new_spanned(
params[0],
"first parameter must be a typed parameter",
))
}
};
let (event_ty, event_ident) = match ¶ms[1] {
FnArg::Typed(pat_type) => {
let ident = match pat_type.pat.as_ref() {
Pat::Ident(pat_ident) => pat_ident.ident.clone(),
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected a simple identifier for event parameter",
))
}
};
(pat_type.ty.as_ref().clone(), ident)
}
_ => {
return Err(syn::Error::new_spanned(
params[1],
"second parameter must be a typed parameter",
))
}
};
Ok((agg_ty, agg_ident, event_ty, event_ident))
}
/// Expands a standalone `#[aggregator(...)]` function, choosing the id
/// strategy from the attribute arguments.
fn expand_aggregator(
    metas: &Punctuated<Meta, Token![,]>,
    input_fn: ItemFn,
) -> syn::Result<TokenStream2> {
    parse_aggregator_id_access(metas)
        .and_then(|id_access| expand_aggregator_with_id(&id_access, &input_fn))
}
/// Turns one aggregator function into an `Apply` impl plus a factory function.
///
/// The function body becomes `<Aggregate as ::causal::Apply<Event>>::apply`:
/// the aggregate parameter is bound to `self`, and the event parameter is
/// rebound to an owned clone of the borrowed event. The private factory
/// `__causal_aggregator_<fn>()` builds the `::causal::Aggregator`, using the
/// id extractor chosen by `id_access`. The original function item itself is
/// not emitted.
fn expand_aggregator_with_id(
    id_access: &IdAccess,
    input_fn: &ItemFn,
) -> syn::Result<TokenStream2> {
    let (agg_ty, agg_ident, event_ty, event_ident) = parse_aggregator_params(&input_fn.sig)?;
    let factory_name = format_ident!("__causal_aggregator_{}", input_fn.sig.ident);
    let body = &input_fn.block;

    // `None` selects the plain `for_type` constructor (no custom id extractor).
    let id_extractor: Option<TokenStream2> = match id_access {
        IdAccess::FactStreamId => None,
        IdAccess::Field(field) => Some(quote! {
            |e: &#event_ty| {
                use ::causal::aggregator::AggregatorIdValue;
                e.#field.into_aggregator_id()
            }
        }),
        IdAccess::Method(method) => Some(quote! {
            |e: &#event_ty| {
                use ::causal::aggregator::AggregatorIdValue;
                e.#method().into_aggregator_id()
            }
        }),
        IdAccess::Singleton => Some(quote! {
            |_: &#event_ty| Some(::uuid::Uuid::nil())
        }),
    };
    let factory_body = if let Some(extractor) = id_extractor {
        quote! {
            ::causal::Aggregator::for_type_with_id_fn::<#agg_ty, #event_ty, _>(
                #extractor
            )
        }
    } else {
        quote! {
            ::causal::Aggregator::for_type::<#agg_ty, #event_ty>()
        }
    };

    Ok(quote! {
        impl ::causal::Apply<#event_ty> for #agg_ty {
            fn apply(&mut self, #event_ident: &#event_ty) {
                let #agg_ident = self;
                let #event_ident: #event_ty = #event_ident.clone();
                #body
            }
        }
        fn #factory_name() -> ::causal::Aggregator {
            #factory_body
        }
    })
}
/// Expands `#[aggregators(...)]` on an inline module.
///
/// Every function without its own `#[aggregator]` attribute is replaced *in
/// place* by its expansion: an `Apply` impl plus a
/// `__causal_aggregator_<fn>()` factory. The expansion uses the module-level id
/// strategy, which is `IdAccess::FactStreamId` when no arguments are given.
/// Functions that carry `#[aggregator]` are left for that attribute to expand
/// inside the module. Finally, a
/// `pub fn aggregators() -> Vec<::causal::Aggregator>` listing every factory is
/// appended to the module.
///
/// Keeping the expansions inside the module means they resolve names exactly
/// as the user's functions did. It also means `aggregators()` can call the
/// private factories directly, without requiring `use super::*` in the module.
///
/// # Errors
/// Fails for a non-inline module (`mod foo;`), for invalid module arguments,
/// for a function with an invalid aggregator signature, or when the module
/// contains no functions.
fn expand_aggregators_module(
    module_metas: &Punctuated<Meta, Token![,]>,
    module: &mut ItemMod,
) -> syn::Result<TokenStream2> {
    let Some((_, items)) = &mut module.content else {
        return Err(syn::Error::new_spanned(
            module,
            "#[aggregators] requires an inline module",
        ));
    };
    // An empty argument list already yields `IdAccess::FactStreamId`.
    let module_id_access = parse_aggregator_id_access(module_metas)?;

    let mut factory_names = Vec::new();
    let mut rewritten = Vec::with_capacity(items.len() + 1);
    for item in std::mem::take(items) {
        match item {
            Item::Fn(item_fn) if !has_attr_any(&item_fn.attrs, &["aggregator"]) => {
                factory_names.push(format_ident!("__causal_aggregator_{}", item_fn.sig.ident));
                let expanded = expand_aggregator_with_id(&module_id_access, &item_fn)?;
                rewritten.push(Item::Verbatim(expanded));
            }
            Item::Fn(item_fn) => {
                // Expanded later by its own `#[aggregator]` attribute, which
                // defines the same factory name inside this module.
                factory_names.push(format_ident!("__causal_aggregator_{}", item_fn.sig.ident));
                rewritten.push(Item::Fn(item_fn));
            }
            other => rewritten.push(other),
        }
    }
    *items = rewritten;

    if factory_names.is_empty() {
        return Err(syn::Error::new_spanned(
            module,
            "#[aggregators] module must contain at least one function",
        ));
    }

    items.push(parse_quote! {
        pub fn aggregators() -> ::std::vec::Vec<::causal::Aggregator> {
            ::std::vec![#(#factory_names()),*]
        }
    });
    Ok(quote! { #module })
}
/// `#[event(...)]`: re-emits the annotated struct or enum together with a
/// generated `::causal::Event` impl.
#[proc_macro_attribute]
pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(item as DeriveInput);
    let event_args = parse_event_args(attr.into());
    expand_event(event_args, derive_input)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Options parsed from `#[event(...)]` by `parse_event_args`.
struct EventArgs {
/// `prefix = "..."`: required for enums. For structs it replaces the
/// snake_cased type name as the event type.
prefix: Option<String>,
/// Bare `ephemeral` flag. It is parsed, but `expand_event_enum` and
/// `expand_event_struct` do not use it in the generated impl.
ephemeral: bool,
/// `stream_category = "..."`: overrides `Event::CATEGORY` in the cases each
/// expansion function applies it.
stream_category: Option<String>,
/// `stream_id = "field"`: field returned by `Event::stream_id`.
stream_id: Option<String>,
/// `occurred_at_field = "field"`: field returned by `Event::occurred_at`
/// when `stream_id` is set (defaults to `occurred_at`).
occurred_at_field: Option<String>,
}
fn parse_event_args(tokens: TokenStream2) -> EventArgs {
let mut prefix = None;
let mut ephemeral = false;
let mut stream_category = None;
let mut stream_id = None;
let mut occurred_at_field = None;
if tokens.is_empty() {
return EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
};
}
let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
let metas = match parser.parse2(tokens) {
Ok(m) => m,
Err(_) => {
return EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
};
}
};
for meta in &metas {
match meta {
Meta::NameValue(MetaNameValue { path, value, .. }) if path.is_ident("prefix") => {
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
prefix = Some(lit.value());
}
}
}
Meta::Path(path) if path.is_ident("ephemeral") => {
ephemeral = true;
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("stream_category") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
stream_category = Some(lit.value());
}
}
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("stream_id") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
stream_id = Some(lit.value());
}
}
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("occurred_at_field") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
occurred_at_field = Some(lit.value());
}
}
}
_ => {}
}
}
EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
}
}
/// Dispatches `#[event]` expansion by item kind. Unions are rejected.
fn expand_event(args: EventArgs, input: DeriveInput) -> Result<TokenStream2, syn::Error> {
    match &input.data {
        Data::Struct(_) => expand_event_struct(args, &input),
        Data::Enum(data_enum) => expand_event_enum(args, &input, data_enum),
        Data::Union(_) => Err(syn::Error::new_spanned(
            &input.ident,
            "#[event] cannot be applied to unions",
        )),
    }
}
fn expand_event_enum(
args: EventArgs,
input: &DeriveInput,
data_enum: &syn::DataEnum,
) -> Result<TokenStream2, syn::Error> {
let name = &input.ident;
let prefix = args.prefix.ok_or_else(|| {
syn::Error::new_spanned(
name,
"#[event] on enums requires a prefix: #[event(prefix = \"...\")]",
)
})?;
let serde_info = parse_serde_attrs(&input.attrs)?;
if serde_info.tag.is_none() && !serde_info.untagged {
return Err(syn::Error::new_spanned(
name,
"#[event] on enums requires #[serde(tag = \"...\")] for variant discrimination",
));
}
if serde_info.untagged {
return Err(syn::Error::new_spanned(
name,
"#[event] cannot be applied to #[serde(untagged)] enums — untagged enums have no stable variant discriminator",
));
}
let ephemeral = args.ephemeral;
let mut match_arms = Vec::new();
let mut name_arms = Vec::new();
for variant in &data_enum.variants {
let variant_name = &variant.ident;
let renamed = get_serde_rename(&variant.attrs);
let variant_str = if let Some(rename) = renamed {
rename
} else {
apply_rename_rule(&variant_name.to_string(), serde_info.rename_all.as_deref())
};
let durable = format!("{}:{}", prefix, variant_str);
let bare = variant_str.clone();
let pattern = match &variant.fields {
Fields::Named(_) => quote! { #name::#variant_name { .. } },
Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
Fields::Unit => quote! { #name::#variant_name },
};
match_arms.push(quote! { #pattern => #durable });
name_arms.push(quote! { #pattern => #bare });
}
let fact_impl = if args.stream_id.is_some() {
let category = args.stream_category.as_ref().unwrap_or(&prefix);
let id_field = args.stream_id.as_ref().unwrap();
let id_field_ident = format_ident!("{}", id_field);
let occurred_field = args
.occurred_at_field
.clone()
.unwrap_or_else(|| "occurred_at".to_string());
let occurred_field_ident = format_ident!("{}", occurred_field);
let mut stream_arms = Vec::new();
let mut occurred_arms = Vec::new();
for variant in &data_enum.variants {
let variant_name = &variant.ident;
match &variant.fields {
Fields::Named(fields) => {
let has_id = fields
.named
.iter()
.any(|f| f.ident.as_ref().map(|i| i == &id_field_ident).unwrap_or(false));
let has_occurred = fields.named.iter().any(|f| {
f.ident
.as_ref()
.map(|i| i == &occurred_field_ident)
.unwrap_or(false)
});
if !has_id {
return Err(syn::Error::new_spanned(
variant_name,
format!(
"#[event(stream_id = \"{}\")] requires every variant to have a `{}` field",
id_field, id_field
),
));
}
if !has_occurred {
return Err(syn::Error::new_spanned(
variant_name,
format!(
"#[event] Event generation requires every variant to have an `{}` field (override with `occurred_at_field = \"...\"`)",
occurred_field
),
));
}
stream_arms.push(quote! {
#name::#variant_name { #id_field_ident, .. } => *#id_field_ident
});
occurred_arms.push(quote! {
#name::#variant_name { #occurred_field_ident, .. } => *#occurred_field_ident
});
}
Fields::Unnamed(_) | Fields::Unit => {
return Err(syn::Error::new_spanned(
variant_name,
"#[event] Event generation requires named-fields variants when stream_id/occurred_at_field are used",
));
}
}
}
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
fn event_type(&self) -> &str {
match self {
#(#name_arms,)*
}
}
fn stream_id(&self) -> ::uuid::Uuid {
match self {
#(#stream_arms,)*
}
}
fn occurred_at(&self) -> ::core::option::Option<::chrono::DateTime<::chrono::Utc>> {
::core::option::Option::Some(match self {
#(#occurred_arms,)*
})
}
}
}
} else {
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #prefix;
fn event_type(&self) -> &str {
match self {
#(#name_arms,)*
}
}
fn stream_id(&self) -> ::uuid::Uuid {
::uuid::Uuid::nil()
}
}
}
};
let _ = (match_arms, ephemeral);
Ok(quote! {
#input
#fact_impl
})
}
fn expand_event_struct(
args: EventArgs,
input: &DeriveInput,
) -> Result<TokenStream2, syn::Error> {
let name = &input.ident;
let ephemeral = args.ephemeral;
let durable = if let Some(ref prefix) = args.prefix {
prefix.clone()
} else {
to_snake_case(&name.to_string())
};
let prefix_str = durable.clone();
let bare_name = prefix_str.clone();
let fact_impl = if let Some(id_field) = args.stream_id.as_ref() {
let category = args.stream_category.as_ref().unwrap_or(&prefix_str);
let id_field_ident = format_ident!("{}", id_field);
let occurred_field_ident = format_ident!(
"{}",
args.occurred_at_field
.clone()
.unwrap_or_else(|| "occurred_at".to_string())
);
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
fn event_type(&self) -> &str { #bare_name }
fn stream_id(&self) -> ::uuid::Uuid {
self.#id_field_ident
}
fn occurred_at(&self) -> ::core::option::Option<::chrono::DateTime<::chrono::Utc>> {
::core::option::Option::Some(self.#occurred_field_ident)
}
}
}
} else {
let category = args.stream_category.as_ref().unwrap_or(&prefix_str);
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
fn event_type(&self) -> &str { #bare_name }
fn stream_id(&self) -> ::uuid::Uuid {
::uuid::Uuid::nil()
}
}
}
};
let _ = (durable, prefix_str, ephemeral);
Ok(quote! {
#input
#fact_impl
})
}
/// Container-level `#[serde(...)]` settings that `#[event]` needs for enums.
struct SerdeInfo {
/// String value of `tag = "..."`, if present.
tag: Option<String>,
/// String value of `rename_all = "..."`; applied to variant names via `apply_rename_rule`.
rename_all: Option<String>,
/// Whether the bare `untagged` flag was present.
untagged: bool,
}
fn parse_serde_attrs(attrs: &[Attribute]) -> Result<SerdeInfo, syn::Error> {
let mut info = SerdeInfo {
tag: None,
rename_all: None,
untagged: false,
};
for attr in attrs {
if !attr.path().is_ident("serde") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("tag") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
info.tag = Some(s.value());
}
} else if meta.path.is_ident("rename_all") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
info.rename_all = Some(s.value());
}
} else if meta.path.is_ident("untagged") {
info.untagged = true;
} else if meta.input.peek(Token![=]) {
let _: Token![=] = meta.input.parse()?;
let _: Lit = meta.input.parse()?;
} else {
}
Ok(())
})?;
}
Ok(info)
}
fn get_serde_rename(attrs: &[Attribute]) -> Option<String> {
for attr in attrs {
if !attr.path().is_ident("serde") {
continue;
}
let mut rename = None;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("rename") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
rename = Some(s.value());
}
}
Ok(())
});
if rename.is_some() {
return rename;
}
}
None
}
/// Applies a serde `rename_all` rule to a (PascalCase) variant name.
///
/// Supports every rule name serde accepts:
/// - `lowercase`, `UPPERCASE`, `PascalCase`, `camelCase`;
/// - `snake_case`, `SCREAMING_SNAKE_CASE`;
/// - `kebab-case`, `SCREAMING-KEBAB-CASE`.
///
/// Previously `lowercase`, `UPPERCASE` and `SCREAMING-KEBAB-CASE` fell
/// through. An absent or unknown rule leaves the name unchanged.
///
/// NOTE(review): the snake-based rules use `to_snake_case`, which keeps
/// acronyms together (`HTTPRequest` -> `http_request`); serde's own variant
/// conversion may split them differently — verify for acronym variants.
fn apply_rename_rule(name: &str, rule: Option<&str>) -> String {
    match rule {
        Some("lowercase") => name.to_ascii_lowercase(),
        Some("UPPERCASE") => name.to_ascii_uppercase(),
        Some("PascalCase") => name.to_string(),
        Some("camelCase") => to_camel_case(name),
        Some("snake_case") => to_snake_case(name),
        Some("SCREAMING_SNAKE_CASE") => to_snake_case(name).to_uppercase(),
        Some("kebab-case") => to_snake_case(name).replace('_', "-"),
        Some("SCREAMING-KEBAB-CASE") => to_snake_case(name).to_uppercase().replace('_', "-"),
        _ => name.to_string(),
    }
}
/// Converts a CamelCase/PascalCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase letter when:
/// - it follows a non-uppercase, non-underscore character (`fooBar` -> `foo_bar`); or
/// - it ends an uppercase run and is followed by a lowercase letter
///   (`HTTPRequest` -> `http_request`, `IOError` -> `io_error`).
///
/// Existing underscores are kept and never doubled. Uppercase letters use
/// their full lowercase mapping, which may be more than one `char`.
///
/// Runs in O(n). The previous version re-scanned the string with
/// `chars().nth(i + 1)` for every uppercase letter.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev_was_upper = false;
    // Starts `true` so the first character never gets a leading underscore.
    let mut prev_was_underscore = true;
    let mut chars = s.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch.is_uppercase() {
            if !prev_was_underscore {
                let next_is_lower = chars.peek().map_or(false, |c| c.is_lowercase());
                // Start a word after a lowercase/digit, or at the last capital
                // of an acronym that is followed by a lowercase letter.
                if !prev_was_upper || next_is_lower {
                    result.push('_');
                }
            }
            result.extend(ch.to_lowercase());
            prev_was_upper = true;
            prev_was_underscore = false;
        } else {
            result.push(ch);
            prev_was_upper = false;
            prev_was_underscore = ch == '_';
        }
    }
    result
}
/// Lower-cases the first character of `s` if it is uppercase and leaves the
/// rest untouched (`OrderPlaced` -> `orderPlaced`).
///
/// Only the first `char` of that character's lowercase mapping is kept.
fn to_camel_case(s: &str) -> String {
    let mut rest = s.chars();
    match rest.next() {
        None => String::new(),
        Some(head) => {
            let head = if head.is_uppercase() {
                head.to_lowercase().next().unwrap_or(head)
            } else {
                head
            };
            std::iter::once(head).chain(rest).collect()
        }
    }
}