use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Expr, Fields, FnArg, Ident,
Item, ItemFn, ItemMod, Lit, Meta, MetaNameValue, Pat, Signature, Token, Type,
};
/// Attribute macro for a single aggregator function.
///
/// The attribute arguments are parsed as a comma-separated `Meta` list and
/// handed, together with the annotated function, to `expand_aggregator`.
/// Any expansion error is emitted as `compile_error!` tokens.
#[proc_macro_attribute]
pub fn aggregator(attr: TokenStream, item: TokenStream) -> TokenStream {
    let options = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let function = parse_macro_input!(item as ItemFn);
    expand_aggregator(&options, function)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Attribute macro for an inline module whose functions are aggregators.
///
/// The attribute arguments become the module-wide default options and are
/// passed, with the module, to `expand_aggregators_module`. Any expansion
/// error is emitted as `compile_error!` tokens.
#[proc_macro_attribute]
pub fn aggregators(attr: TokenStream, item: TokenStream) -> TokenStream {
    let default_options =
        parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let mut target_module = parse_macro_input!(item as ItemMod);
    expand_aggregators_module(&default_options, &mut target_module)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
/// Returns `true` when some attribute in `attrs` has a path that is exactly
/// the single identifier `name` (per `syn::Path::is_ident`).
fn has_attr(attrs: &[Attribute], name: &str) -> bool {
    for attribute in attrs {
        if attribute.path().is_ident(name) {
            return true;
        }
    }
    false
}
/// Returns `true` when `attrs` contains an attribute named by any entry of
/// `names` (see `has_attr` for the matching rule).
fn has_attr_any(attrs: &[Attribute], names: &[&str]) -> bool {
    for &candidate in names {
        if has_attr(attrs, candidate) {
            return true;
        }
    }
    false
}
/// Strategy a generated aggregator uses to derive the aggregate id from an event.
enum IdAccess {
    /// No id option given: the factory calls `::causal::Aggregator::for_type`
    /// (presumably keyed by the event's stream id — confirm in `causal`).
    FactStreamId,
    /// `id = "field"`: the id is `e.<field>.into_aggregator_id()`.
    Field(Ident),
    /// `id_fn = "method"`: the id is `e.<method>().into_aggregator_id()`.
    Method(Ident),
    /// `singleton`: every event maps to `Some(Uuid::nil())`.
    Singleton,
}
fn parse_aggregator_id_access(metas: &Punctuated<Meta, Token![,]>) -> syn::Result<IdAccess> {
for meta in metas {
if let Meta::Path(path) = meta {
if path.is_ident("singleton") {
return Ok(IdAccess::Singleton);
}
}
if let Meta::NameValue(nv) = meta {
if nv.path.is_ident("id") {
if let Expr::Lit(expr_lit) = &nv.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return Ok(IdAccess::Field(Ident::new(
&lit_str.value(),
lit_str.span(),
)));
}
}
return Err(syn::Error::new_spanned(
&nv.value,
"expected string literal for `id`, e.g. id = \"order_id\"",
));
}
if nv.path.is_ident("id_fn") {
if let Expr::Lit(expr_lit) = &nv.value {
if let Lit::Str(lit_str) = &expr_lit.lit {
return Ok(IdAccess::Method(Ident::new(
&lit_str.value(),
lit_str.span(),
)));
}
}
return Err(syn::Error::new_spanned(
&nv.value,
"expected string literal for `id_fn`, e.g. id_fn = \"run_id\"",
));
}
}
}
Ok(IdAccess::FactStreamId)
}
fn parse_aggregator_params(sig: &Signature) -> syn::Result<(Type, Ident, Type, Ident)> {
let params: Vec<_> = sig.inputs.iter().collect();
if params.len() != 2 {
return Err(syn::Error::new_spanned(
&sig.inputs,
"#[aggregator] function must have exactly 2 parameters: (agg: &mut Aggregate, event: Event)",
));
}
let (agg_ty, agg_ident) = match ¶ms[0] {
FnArg::Typed(pat_type) => {
let ident = match pat_type.pat.as_ref() {
Pat::Ident(pat_ident) => pat_ident.ident.clone(),
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected a simple identifier for aggregate parameter",
))
}
};
match pat_type.ty.as_ref() {
Type::Reference(type_ref) if type_ref.mutability.is_some() => {
(type_ref.elem.as_ref().clone(), ident)
}
_ => {
return Err(syn::Error::new_spanned(
&pat_type.ty,
"first parameter must be `&mut AggregateType`",
))
}
}
}
_ => {
return Err(syn::Error::new_spanned(
params[0],
"first parameter must be a typed parameter",
))
}
};
let (event_ty, event_ident) = match ¶ms[1] {
FnArg::Typed(pat_type) => {
let ident = match pat_type.pat.as_ref() {
Pat::Ident(pat_ident) => pat_ident.ident.clone(),
_ => {
return Err(syn::Error::new_spanned(
&pat_type.pat,
"expected a simple identifier for event parameter",
))
}
};
(pat_type.ty.as_ref().clone(), ident)
}
_ => {
return Err(syn::Error::new_spanned(
params[1],
"second parameter must be a typed parameter",
))
}
};
Ok((agg_ty, agg_ident, event_ty, event_ident))
}
/// Expands a standalone `#[aggregator(...)]` function.
///
/// Resolves the id strategy from `metas`, then delegates to
/// `expand_aggregator_with_id`; the first error from either step is returned.
fn expand_aggregator(
    metas: &Punctuated<Meta, Token![,]>,
    input_fn: ItemFn,
) -> syn::Result<TokenStream2> {
    parse_aggregator_id_access(metas)
        .and_then(|access| expand_aggregator_with_id(&access, &input_fn))
}
/// Generates the code for one aggregator function.
///
/// For `fn name(agg: &mut Agg, event: Ev) { body }` this emits:
/// - `impl ::causal::Apply<Ev> for Agg`, whose `apply` rebinds `agg` to
///   `self`, clones the borrowed event into an owned `event: Ev` (so `Ev`
///   must be `Clone`), and then runs `body`;
/// - a private `fn __causal_aggregator_<name>() -> ::causal::Aggregator` that
///   builds the aggregator using the id strategy in `id_access`.
///
/// # Errors
/// Propagates signature errors from `parse_aggregator_params`.
fn expand_aggregator_with_id(
    id_access: &IdAccess,
    input_fn: &ItemFn,
) -> syn::Result<TokenStream2> {
    let (agg_ty, agg_ident, event_ty, event_ident) = parse_aggregator_params(&input_fn.sig)?;
    let user_body = &input_fn.block;
    let factory_ident = format_ident!("__causal_aggregator_{}", input_fn.sig.ident);

    // Closure mapping `&Ev` to the aggregate id; `None` selects the default
    // `for_type` constructor instead of `for_type_with_id_fn`.
    let id_closure = match id_access {
        IdAccess::FactStreamId => None,
        IdAccess::Field(field) => Some(quote! {
            |e: &#event_ty| {
                use ::causal::aggregator::AggregatorIdValue;
                e.#field.into_aggregator_id()
            }
        }),
        IdAccess::Method(method) => Some(quote! {
            |e: &#event_ty| {
                use ::causal::aggregator::AggregatorIdValue;
                e.#method().into_aggregator_id()
            }
        }),
        IdAccess::Singleton => Some(quote! {
            |_: &#event_ty| Some(::uuid::Uuid::nil())
        }),
    };
    let constructor = if let Some(closure) = id_closure {
        quote! {
            ::causal::Aggregator::for_type_with_id_fn::<#agg_ty, #event_ty, _>(#closure)
        }
    } else {
        quote! {
            ::causal::Aggregator::for_type::<#agg_ty, #event_ty>()
        }
    };

    let apply_impl = quote! {
        impl ::causal::Apply<#event_ty> for #agg_ty {
            fn apply(&mut self, #event_ident: &#event_ty) {
                let #agg_ident = self;
                let #event_ident: #event_ty = #event_ident.clone();
                #user_body
            }
        }
    };
    let factory_fn = quote! {
        fn #factory_ident() -> ::causal::Aggregator {
            #constructor
        }
    };
    Ok(quote! {
        #apply_impl
        #factory_fn
    })
}
/// Expands `#[aggregators(...)] mod m { ... }`.
///
/// Every `fn` in the inline module becomes an aggregator:
/// - a function that already carries a bare `#[aggregator]` attribute is left
///   untouched, so that attribute expands it with its own options;
/// - every other function is expanded here with the module-level options
///   (parsed by `parse_aggregator_id_access`; none means the default strategy).
///
/// Each expansion replaces its function *inside* the module, at the same
/// position. This keeps paths in the signature and body resolving in the
/// module's scope, and keeps the generated factory visible to the appended
/// `pub fn aggregators()`. That function returns one `::causal::Aggregator`
/// per function, in declaration order.
///
/// # Errors
/// Fails in any of these cases:
/// - the module has no inline content;
/// - the module contains no functions;
/// - the module options are invalid;
/// - any function fails to expand.
fn expand_aggregators_module(
    module_metas: &Punctuated<Meta, Token![,]>,
    module: &mut ItemMod,
) -> syn::Result<TokenStream2> {
    let Some((_, items)) = &mut module.content else {
        return Err(syn::Error::new_spanned(
            module,
            "#[aggregators] requires an inline module",
        ));
    };
    // An empty option list yields `IdAccess::FactStreamId`.
    let module_id_access = parse_aggregator_id_access(module_metas)?;

    let mut factory_names = Vec::new();
    let mut rewritten = Vec::with_capacity(items.len() + 1);
    // On error the partially drained module is discarded by the caller, which
    // only emits the error.
    for item in std::mem::take(items) {
        match item {
            Item::Fn(item_fn) => {
                factory_names.push(format_ident!("__causal_aggregator_{}", item_fn.sig.ident));
                if has_attr_any(&item_fn.attrs, &["aggregator"]) {
                    rewritten.push(Item::Fn(item_fn));
                } else {
                    let expanded = expand_aggregator_with_id(&module_id_access, &item_fn)?;
                    rewritten.push(Item::Verbatim(expanded));
                }
            }
            other => rewritten.push(other),
        }
    }
    if factory_names.is_empty() {
        return Err(syn::Error::new_spanned(
            module,
            "#[aggregators] module must contain at least one function",
        ));
    }

    rewritten.push(parse_quote! {
        pub fn aggregators() -> ::std::vec::Vec<::causal::Aggregator> {
            ::std::vec![#(#factory_names()),*]
        }
    });
    *items = rewritten;
    Ok(quote! { #module })
}
#[proc_macro_attribute]
pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let args = parse_event_args(attr.into());
match expand_event(args, input) {
Ok(tokens) => tokens.into(),
Err(err) => err.to_compile_error().into(),
}
}
struct EventArgs {
prefix: Option<String>,
ephemeral: bool,
stream_category: Option<String>,
stream_id: Option<String>,
occurred_at_field: Option<String>,
stream: Option<String>,
}
fn parse_event_args(tokens: TokenStream2) -> EventArgs {
let mut prefix = None;
let mut ephemeral = false;
let mut stream_category = None;
let mut stream_id = None;
let mut occurred_at_field = None;
let mut stream = None;
if tokens.is_empty() {
return EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
stream,
};
}
let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
let metas = match parser.parse2(tokens) {
Ok(m) => m,
Err(_) => {
return EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
stream,
};
}
};
for meta in &metas {
match meta {
Meta::NameValue(MetaNameValue { path, value, .. }) if path.is_ident("prefix") => {
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
prefix = Some(lit.value());
}
}
}
Meta::Path(path) if path.is_ident("ephemeral") => {
ephemeral = true;
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("stream_category") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
stream_category = Some(lit.value());
}
}
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("stream_id") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
stream_id = Some(lit.value());
}
}
}
Meta::NameValue(MetaNameValue { path, value, .. })
if path.is_ident("occurred_at_field") =>
{
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
occurred_at_field = Some(lit.value());
}
}
}
Meta::NameValue(MetaNameValue { path, value, .. }) if path.is_ident("stream") => {
if let Expr::Lit(expr_lit) = value {
if let Lit::Str(lit) = &expr_lit.lit {
stream = Some(lit.value());
}
}
}
_ => {}
}
}
EventArgs {
prefix,
ephemeral,
stream_category,
stream_id,
occurred_at_field,
stream,
}
}
fn expand_event(args: EventArgs, input: DeriveInput) -> Result<TokenStream2, syn::Error> {
let name = &input.ident;
match &input.data {
Data::Enum(data_enum) => expand_event_enum(args, &input, data_enum),
Data::Struct(_) => expand_event_struct(args, &input),
Data::Union(_) => Err(syn::Error::new_spanned(
name,
"#[event] cannot be applied to unions",
)),
}
}
fn expand_event_enum(
args: EventArgs,
input: &DeriveInput,
data_enum: &syn::DataEnum,
) -> Result<TokenStream2, syn::Error> {
let name = &input.ident;
let prefix = args.prefix.ok_or_else(|| {
syn::Error::new_spanned(
name,
"#[event] on enums requires a prefix: #[event(prefix = \"...\")]",
)
})?;
let stream_const = match &args.stream {
Some(s) => quote! { const STREAM_CATEGORY: &'static str = #s; },
None => quote! {},
};
let serde_info = parse_serde_attrs(&input.attrs)?;
if serde_info.tag.is_none() && !serde_info.untagged {
return Err(syn::Error::new_spanned(
name,
"#[event] on enums requires #[serde(tag = \"...\")] for variant discrimination",
));
}
if serde_info.untagged {
return Err(syn::Error::new_spanned(
name,
"#[event] cannot be applied to #[serde(untagged)] enums — untagged enums have no stable variant discriminator",
));
}
let ephemeral = args.ephemeral;
let mut match_arms = Vec::new();
let mut name_arms = Vec::new();
for variant in &data_enum.variants {
let variant_name = &variant.ident;
let renamed = get_serde_rename(&variant.attrs);
let variant_str = if let Some(rename) = renamed {
rename
} else {
apply_rename_rule(&variant_name.to_string(), serde_info.rename_all.as_deref())
};
let durable = format!("{}:{}", prefix, variant_str);
let bare = variant_str.clone();
let pattern = match &variant.fields {
Fields::Named(_) => quote! { #name::#variant_name { .. } },
Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
Fields::Unit => quote! { #name::#variant_name },
};
match_arms.push(quote! { #pattern => #durable });
name_arms.push(quote! { #pattern => #bare });
}
let fact_impl = if args.stream_id.is_some() {
let category = args.stream_category.as_ref().unwrap_or(&prefix);
let id_field = args.stream_id.as_ref().unwrap();
let id_field_ident = format_ident!("{}", id_field);
let occurred_field = args
.occurred_at_field
.clone()
.unwrap_or_else(|| "occurred_at".to_string());
let occurred_field_ident = format_ident!("{}", occurred_field);
let mut stream_arms = Vec::new();
let mut occurred_arms = Vec::new();
for variant in &data_enum.variants {
let variant_name = &variant.ident;
match &variant.fields {
Fields::Named(fields) => {
let has_id = fields
.named
.iter()
.any(|f| f.ident.as_ref().map(|i| i == &id_field_ident).unwrap_or(false));
let has_occurred = fields.named.iter().any(|f| {
f.ident
.as_ref()
.map(|i| i == &occurred_field_ident)
.unwrap_or(false)
});
if !has_id {
return Err(syn::Error::new_spanned(
variant_name,
format!(
"#[event(stream_id = \"{}\")] requires every variant to have a `{}` field",
id_field, id_field
),
));
}
if !has_occurred {
return Err(syn::Error::new_spanned(
variant_name,
format!(
"#[event] Event generation requires every variant to have an `{}` field (override with `occurred_at_field = \"...\"`)",
occurred_field
),
));
}
stream_arms.push(quote! {
#name::#variant_name { #id_field_ident, .. } => *#id_field_ident
});
occurred_arms.push(quote! {
#name::#variant_name { #occurred_field_ident, .. } => *#occurred_field_ident
});
}
Fields::Unnamed(_) | Fields::Unit => {
return Err(syn::Error::new_spanned(
variant_name,
"#[event] Event generation requires named-fields variants when stream_id/occurred_at_field are used",
));
}
}
}
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
#stream_const
fn event_type(&self) -> &str {
match self {
#(#name_arms,)*
}
}
fn stream_id(&self) -> ::uuid::Uuid {
match self {
#(#stream_arms,)*
}
}
fn occurred_at(&self) -> ::core::option::Option<::chrono::DateTime<::chrono::Utc>> {
::core::option::Option::Some(match self {
#(#occurred_arms,)*
})
}
}
}
} else {
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #prefix;
#stream_const
fn event_type(&self) -> &str {
match self {
#(#name_arms,)*
}
}
fn stream_id(&self) -> ::uuid::Uuid {
::uuid::Uuid::nil()
}
}
}
};
let _ = (match_arms, ephemeral);
Ok(quote! {
#input
#fact_impl
})
}
fn expand_event_struct(
args: EventArgs,
input: &DeriveInput,
) -> Result<TokenStream2, syn::Error> {
let name = &input.ident;
let ephemeral = args.ephemeral;
let durable = if let Some(ref prefix) = args.prefix {
prefix.clone()
} else {
to_snake_case(&name.to_string())
};
let prefix_str = durable.clone();
let bare_name = prefix_str.clone();
let stream_const = match &args.stream {
Some(s) => quote! { const STREAM_CATEGORY: &'static str = #s; },
None => quote! {},
};
let fact_impl = if let Some(id_field) = args.stream_id.as_ref() {
let category = args.stream_category.as_ref().unwrap_or(&prefix_str);
let id_field_ident = format_ident!("{}", id_field);
let occurred_field_ident = format_ident!(
"{}",
args.occurred_at_field
.clone()
.unwrap_or_else(|| "occurred_at".to_string())
);
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
#stream_const
fn event_type(&self) -> &str { #bare_name }
fn stream_id(&self) -> ::uuid::Uuid {
self.#id_field_ident
}
fn occurred_at(&self) -> ::core::option::Option<::chrono::DateTime<::chrono::Utc>> {
::core::option::Option::Some(self.#occurred_field_ident)
}
}
}
} else {
let category = args.stream_category.as_ref().unwrap_or(&prefix_str);
quote! {
impl ::causal::Event for #name {
const CATEGORY: &'static str = #category;
#stream_const
fn event_type(&self) -> &str { #bare_name }
fn stream_id(&self) -> ::uuid::Uuid {
::uuid::Uuid::nil()
}
}
}
};
let _ = (durable, prefix_str, ephemeral);
Ok(quote! {
#input
#fact_impl
})
}
/// Container-level `#[serde(...)]` settings that influence variant naming.
struct SerdeInfo {
    /// Whether `#[serde(untagged)]` was present.
    untagged: bool,
    /// Value of `#[serde(tag = "...")]`, if any.
    tag: Option<String>,
    /// Value of `#[serde(rename_all = "...")]`, if any.
    rename_all: Option<String>,
}
fn parse_serde_attrs(attrs: &[Attribute]) -> Result<SerdeInfo, syn::Error> {
let mut info = SerdeInfo {
tag: None,
rename_all: None,
untagged: false,
};
for attr in attrs {
if !attr.path().is_ident("serde") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("tag") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
info.tag = Some(s.value());
}
} else if meta.path.is_ident("rename_all") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
info.rename_all = Some(s.value());
}
} else if meta.path.is_ident("untagged") {
info.untagged = true;
} else if meta.input.peek(Token![=]) {
let _: Token![=] = meta.input.parse()?;
let _: Lit = meta.input.parse()?;
} else {
}
Ok(())
})?;
}
Ok(info)
}
fn get_serde_rename(attrs: &[Attribute]) -> Option<String> {
for attr in attrs {
if !attr.path().is_ident("serde") {
continue;
}
let mut rename = None;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("rename") {
let value = meta.value()?;
let lit: Lit = value.parse()?;
if let Lit::Str(s) = lit {
rename = Some(s.value());
}
}
Ok(())
});
if rename.is_some() {
return rename;
}
}
None
}
fn apply_rename_rule(name: &str, rule: Option<&str>) -> String {
match rule {
Some("snake_case") => to_snake_case(name),
Some("camelCase") => to_camel_case(name),
Some("PascalCase") => name.to_string(),
Some("SCREAMING_SNAKE_CASE") => to_snake_case(name).to_uppercase(),
Some("kebab-case") => to_snake_case(name).replace('_', "-"),
_ => name.to_string(), }
}
/// Converts an identifier such as `OrderPlaced` or `HTTPRequest` to
/// snake_case (`order_placed`, `http_request`).
///
/// An underscore is inserted before an uppercase letter when:
/// - the previous character is neither uppercase nor `_`; or
/// - the letter ends a run of capitals, i.e. the previous character is
///   uppercase and the next one is lowercase.
///
/// Existing underscores are copied, and no extra underscore is added right
/// after one.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len() + chars.len() / 2);
    for (idx, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            out.push(current);
            continue;
        }
        let needs_separator = match idx.checked_sub(1).map(|prev| chars[prev]) {
            None | Some('_') => false,
            Some(prev) if prev.is_uppercase() => chars
                .get(idx + 1)
                .map_or(false, |next| next.is_lowercase()),
            Some(_) => true,
        };
        if needs_separator {
            out.push('_');
        }
        out.push(current.to_lowercase().next().unwrap());
    }
    out
}
/// Lower-cases the first character when it is uppercase (`OrderPlaced` →
/// `orderPlaced`); all remaining characters are copied unchanged.
fn to_camel_case(s: &str) -> String {
    let mut rest = s.chars();
    match rest.next() {
        Some(head) if head.is_uppercase() => {
            let mut out = String::with_capacity(s.len());
            out.push(head.to_lowercase().next().unwrap());
            out.push_str(rest.as_str());
            out
        }
        _ => s.to_string(),
    }
}