use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashSet;
use syn::{parse_macro_input, spanned::Spanned};
#[proc_macro_attribute]
pub fn entity(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as EntityArgs);
let input = parse_macro_input!(item as syn::ItemStruct);
match entity_impl_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro_attribute]
pub fn entity_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ImplArgs);
let input = parse_macro_input!(item as syn::ItemImpl);
match entity_impl_block_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro_attribute]
pub fn entity_trait(_attr: TokenStream, item: TokenStream) -> TokenStream {
let _ = item;
syn::Error::new(
proc_macro2::Span::call_site(),
"entity traits have been removed; use #[rpc_group] instead",
)
.to_compile_error()
.into()
}
#[proc_macro_attribute]
pub fn entity_trait_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let _ = item;
syn::Error::new(
proc_macro2::Span::call_site(),
"entity trait impls have been removed; use #[rpc_group_impl] instead",
)
.to_compile_error()
.into()
}
/// Marker attribute `#[state(...)]`; expands to the item unchanged.
///
/// `#[entity_impl]` looks for this attribute on the impl block, strips it and
/// rejects it, because entities are stateless.
#[proc_macro_attribute]
pub fn state(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// Marker attribute `#[rpc]` / `#[rpc(persisted)]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from the methods it re-emits.
#[proc_macro_attribute]
pub fn rpc(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
#[proc_macro_attribute]
pub fn workflow(attr: TokenStream, item: TokenStream) -> TokenStream {
let item_clone = item.clone();
if syn::parse::<syn::ItemStruct>(item_clone).is_ok() {
let args = parse_macro_input!(attr as WorkflowStructArgs);
let input = parse_macro_input!(item as syn::ItemStruct);
match workflow_struct_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
} else {
item
}
}
#[proc_macro_attribute]
pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as WorkflowImplArgs);
let input = parse_macro_input!(item as syn::ItemImpl);
match workflow_impl_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Marker attribute `#[activity]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from re-emitted methods and rejects activities on
/// entities. Workflow impl macros presumably interpret it (not visible here).
#[proc_macro_attribute]
pub fn activity(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// Visibility marker `#[public]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from the methods it re-emits.
#[proc_macro_attribute]
pub fn public(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// Visibility marker `#[protected]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from the methods it re-emits.
#[proc_macro_attribute]
pub fn protected(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// Visibility marker `#[private]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from the methods it re-emits.
#[proc_macro_attribute]
pub fn private(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// Marker attribute `#[method]`; expands to the item unchanged.
///
/// `#[entity_impl]` strips it from the methods it re-emits.
#[proc_macro_attribute]
pub fn method(_attr: TokenStream, item: TokenStream) -> TokenStream {
    std::convert::identity(item)
}
/// `#[rpc_group(...)]` on a struct: validates the arguments and re-emits the struct.
///
/// Arguments are parsed as `TraitArgs` only so unknown keys are rejected; the
/// parsed `krate` is not used by this expansion.
#[proc_macro_attribute]
pub fn rpc_group(attr: TokenStream, item: TokenStream) -> TokenStream {
    let _args = parse_macro_input!(attr as TraitArgs);
    let group_struct = parse_macro_input!(item as syn::ItemStruct);
    TokenStream::from(quote!(#group_struct))
}
#[proc_macro_attribute]
pub fn rpc_group_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as RpcGroupImplArgs);
let input = parse_macro_input!(item as syn::ItemImpl);
match rpc_group_impl_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// `#[activity_group(...)]` on a struct: validates the arguments and re-emits the struct.
///
/// Arguments are parsed as `TraitArgs` only so unknown keys are rejected; the
/// parsed `krate` is not used by this expansion.
#[proc_macro_attribute]
pub fn activity_group(attr: TokenStream, item: TokenStream) -> TokenStream {
    let _args = parse_macro_input!(attr as TraitArgs);
    let group_struct = parse_macro_input!(item as syn::ItemStruct);
    TokenStream::from(quote!(#group_struct))
}
#[proc_macro_attribute]
pub fn activity_group_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ActivityGroupImplArgs);
let input = parse_macro_input!(item as syn::ItemImpl);
match activity_group_impl_inner(args, input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Parsed arguments of `#[entity(...)]`; each field is `None` when its key is omitted.
struct EntityArgs {
/// `name = "..."`: entity type name; `entity_impl_inner` falls back to the struct name.
name: Option<String>,
/// `shard_group = "..."`: `entity_impl_inner` falls back to `"default"`.
shard_group: Option<String>,
/// `max_idle_time_secs = N`: emitted as `Duration::from_secs(N)`.
max_idle_time_secs: Option<u64>,
/// `mailbox_capacity = N`.
mailbox_capacity: Option<usize>,
/// `concurrency = N`.
concurrency: Option<usize>,
/// `krate = "path"`: path to the runtime crate; `default_crate_path` is used when absent.
krate: Option<syn::Path>,
}
impl syn::parse::Parse for EntityArgs {
    /// Parses a comma-separated list of `key = literal` pairs.
    ///
    /// String keys: `name`, `shard_group`, `krate` (the latter parsed as a path).
    /// Integer keys: `max_idle_time_secs`, `mailbox_capacity`, `concurrency`.
    /// A repeated key overwrites the earlier value; an unknown key is an error.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut parsed = EntityArgs {
            name: None,
            shard_group: None,
            max_idle_time_secs: None,
            mailbox_capacity: None,
            concurrency: None,
            krate: None,
        };
        while !input.is_empty() {
            let key: syn::Ident = input.parse()?;
            input.parse::<syn::Token![=]>()?;
            let key_text = key.to_string();
            match key_text.as_str() {
                "name" => parsed.name = Some(input.parse::<syn::LitStr>()?.value()),
                "shard_group" => {
                    parsed.shard_group = Some(input.parse::<syn::LitStr>()?.value())
                }
                "max_idle_time_secs" => {
                    parsed.max_idle_time_secs =
                        Some(input.parse::<syn::LitInt>()?.base10_parse()?)
                }
                "mailbox_capacity" => {
                    parsed.mailbox_capacity = Some(input.parse::<syn::LitInt>()?.base10_parse()?)
                }
                "concurrency" => {
                    parsed.concurrency = Some(input.parse::<syn::LitInt>()?.base10_parse()?)
                }
                "krate" => parsed.krate = Some(input.parse::<syn::LitStr>()?.parse()?),
                other => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!("unknown entity attribute: {other}"),
                    ));
                }
            }
            // A trailing comma is optional; anything else after a value must be a comma.
            if input.is_empty() {
                break;
            }
            input.parse::<syn::Token![,]>()?;
        }
        Ok(parsed)
    }
}
/// Parsed arguments of `#[entity_impl(...)]`.
struct ImplArgs {
/// `krate = "path"`: path to the runtime crate.
krate: Option<syn::Path>,
/// `traits(...)`: legacy entity traits; collected only so `entity_impl_block_inner` can reject them.
traits: Vec<syn::Path>,
/// `rpc_groups(...)`: paths of RPC groups composed into the entity.
rpc_groups: Vec<syn::Path>,
/// `deferred_keys(...)`: typed key constants emitted next to the entity.
deferred_keys: Vec<DeferredKeyDecl>,
}
/// One `IDENT: Type = "name"` entry of `deferred_keys(...)`.
struct DeferredKeyDecl {
/// Name of the generated `pub const`.
ident: syn::Ident,
/// Type argument of the generated `DeferredKey<T>`.
ty: syn::Type,
/// String passed to `DeferredKey::new`; must be unique within one impl block.
name: syn::LitStr,
}
impl syn::parse::Parse for DeferredKeyDecl {
    /// Parses a single `IDENT: Type = "name"` declaration.
    ///
    /// A missing `=` is reported with a targeted message rather than syn's
    /// generic "expected `=`".
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let ident = input.parse::<syn::Ident>()?;
        let _colon: syn::Token![:] = input.parse()?;
        let ty = input.parse::<syn::Type>()?;
        // `Option<Token![=]>` peeks and only consumes the `=` when present.
        if input.parse::<Option<syn::Token![=]>>()?.is_none() {
            return Err(syn::Error::new(
                input.span(),
                "expected `= \"name\"` for deferred key",
            ));
        }
        let name = input.parse::<syn::LitStr>()?;
        Ok(Self { ident, ty, name })
    }
}
impl syn::parse::Parse for ImplArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut args = ImplArgs {
krate: None,
traits: Vec::new(),
rpc_groups: Vec::new(),
deferred_keys: Vec::new(),
};
while !input.is_empty() {
let ident: syn::Ident = input.parse()?;
match ident.to_string().as_str() {
"krate" => {
input.parse::<syn::Token![=]>()?;
let lit: syn::LitStr = input.parse()?;
args.krate = Some(lit.parse()?);
}
"traits" => {
let content;
syn::parenthesized!(content in input);
while !content.is_empty() {
let path: syn::Path = content.parse()?;
args.traits.push(path);
if !content.is_empty() {
content.parse::<syn::Token![,]>()?;
}
}
}
"rpc_groups" => {
let content;
syn::parenthesized!(content in input);
while !content.is_empty() {
let path: syn::Path = content.parse()?;
args.rpc_groups.push(path);
if !content.is_empty() {
content.parse::<syn::Token![,]>()?;
}
}
}
"deferred_keys" => {
let content;
syn::parenthesized!(content in input);
let decls = content.parse_terminated(DeferredKeyDecl::parse, syn::Token![,])?;
args.deferred_keys.extend(decls);
}
other => {
return Err(syn::Error::new(
ident.span(),
format!("unknown entity_impl attribute: {other}"),
));
}
}
if !input.is_empty() {
input.parse::<syn::Token![,]>()?;
}
}
Ok(args)
}
}
/// Arguments accepted by `#[rpc_group(...)]` and `#[activity_group(...)]`: only `krate = "path"`.
struct TraitArgs {
/// Path to the runtime crate; parsed for validation, not read by `rpc_group` / `activity_group`.
krate: Option<syn::Path>,
}
impl syn::parse::Parse for TraitArgs {
    /// Parses comma-separated `key = literal` pairs; only `krate = "path"` is accepted.
    ///
    /// A repeated `krate` overwrites the earlier value.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut krate: Option<syn::Path> = None;
        while !input.is_empty() {
            let key: syn::Ident = input.parse()?;
            input.parse::<syn::Token![=]>()?;
            if key != "krate" {
                return Err(syn::Error::new(
                    key.span(),
                    format!("unknown attribute: {key}"),
                ));
            }
            krate = Some(input.parse::<syn::LitStr>()?.parse()?);
            if input.is_empty() {
                break;
            }
            input.parse::<syn::Token![,]>()?;
        }
        Ok(TraitArgs { krate })
    }
}
fn default_crate_path() -> syn::Path {
syn::parse_str("cruster").unwrap()
}
/// Returns a copy of `path` whose final segment is `ident` with no generic arguments.
///
/// An empty path is returned unchanged.
fn replace_last_segment(path: &syn::Path, ident: syn::Ident) -> syn::Path {
    let mut renamed = path.clone();
    if let Some(tail) = renamed.segments.last_mut() {
        *tail = syn::PathSegment {
            ident,
            arguments: syn::PathArguments::None,
        };
    }
    renamed
}
/// Generated item names derived from one `rpc_groups(...)` path by `rpc_group_infos_from_paths`.
// `dead_code` is allowed because not every field (e.g. `ident`) is read by the generators in this section of the file.
#[allow(dead_code)]
struct RpcGroupInfo {
/// The group path exactly as written in `rpc_groups(...)`.
path: syn::Path,
/// Last segment identifier of `path`.
ident: syn::Ident,
/// Handler field holding the group wrapper: `__rpc_group_<snake_case ident>`.
field: syn::Ident,
/// `path` with its last segment replaced by `__<Ident>RpcGroupWrapper`.
wrapper_path: syn::Path,
/// `path` with its last segment replaced by `__<Ident>RpcGroupAccess`.
access_trait_path: syn::Path,
/// `path` with its last segment replaced by `__<Ident>RpcGroupMethods`.
methods_trait_path: syn::Path,
}
/// Derives the generated item names for each `rpc_groups(...)` path, in order.
///
/// For a path ending in `Foo`, the handler field is `__rpc_group_<to_snake("Foo")>`.
/// The wrapper, access-trait and methods-trait paths are `path` with its last
/// segment replaced by `__FooRpcGroupWrapper`, `__FooRpcGroupAccess` and
/// `__FooRpcGroupMethods`.
///
/// # Panics
/// Panics if a path has no segments (presumably unreachable for parsed paths).
fn rpc_group_infos_from_paths(paths: &[syn::Path]) -> Vec<RpcGroupInfo> {
    let mut infos = Vec::with_capacity(paths.len());
    for path in paths {
        let ident = path
            .segments
            .last()
            .map(|segment| segment.ident.clone())
            .expect("rpc group path missing segment");
        let snake = to_snake(&ident.to_string());
        let field = format_ident!("__rpc_group_{}", snake);
        // Same path, last segment swapped for `__<Ident><suffix>`.
        let sibling = |suffix: &str| replace_last_segment(path, format_ident!("__{}{}", ident, suffix));
        let wrapper_path = sibling("RpcGroupWrapper");
        let access_trait_path = sibling("RpcGroupAccess");
        let methods_trait_path = sibling("RpcGroupMethods");
        infos.push(RpcGroupInfo {
            path: path.clone(),
            ident,
            field,
            wrapper_path,
            access_trait_path,
            methods_trait_path,
        });
    }
    infos
}
/// Core of `#[entity]`: re-emits the struct and adds hidden inherent
/// configuration methods that the generated `Entity` impls call.
///
/// The methods are `__entity_type`, `__shard_group`, `__shard_group_for`,
/// `__max_idle_time`, `__mailbox_capacity` and `__concurrency`. Their values
/// come from the attribute arguments, with these fallbacks:
/// - `name` defaults to the struct name;
/// - `shard_group` defaults to `"default"`;
/// - the optional settings default to `None`.
///
/// # Errors
/// Currently always returns `Ok`.
fn entity_impl_inner(
    args: EntityArgs,
    input: syn::ItemStruct,
) -> syn::Result<proc_macro2::TokenStream> {
    let krate = args.krate.unwrap_or_else(default_crate_path);
    let struct_name = &input.ident;
    let entity_name = args.name.unwrap_or_else(|| struct_name.to_string());
    let shard_group_value = args.shard_group.as_deref().unwrap_or("default");

    // Wraps an optional expression as `Option::Some(expr)` / `Option::None` tokens.
    let optional = |value: Option<proc_macro2::TokenStream>| match value {
        Some(inner) => quote! { ::std::option::Option::Some(#inner) },
        None => quote! { ::std::option::Option::None },
    };
    let max_idle_value = optional(
        args.max_idle_time_secs
            .map(|secs| quote! { ::std::time::Duration::from_secs(#secs) }),
    );
    let mailbox_value = optional(args.mailbox_capacity.map(|cap| quote! { #cap }));
    let concurrency_value = optional(args.concurrency.map(|c| quote! { #c }));

    Ok(quote! {
        #input
        #[allow(dead_code)]
        impl #struct_name {
            #[doc(hidden)]
            fn __entity_type(&self) -> #krate::types::EntityType {
                #krate::types::EntityType::new(#entity_name)
            }
            #[doc(hidden)]
            fn __shard_group(&self) -> &str {
                #shard_group_value
            }
            #[doc(hidden)]
            fn __shard_group_for(&self, _entity_id: &#krate::types::EntityId) -> &str {
                self.__shard_group()
            }
            #[doc(hidden)]
            fn __max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
                #max_idle_value
            }
            #[doc(hidden)]
            fn __mailbox_capacity(&self) -> ::std::option::Option<usize> {
                #mailbox_value
            }
            #[doc(hidden)]
            fn __concurrency(&self) -> ::std::option::Option<usize> {
                #concurrency_value
            }
        }
    })
}
/// Kind of an annotated method; presumably derived from its marker attribute by `parse_kind_attr` (not visible here).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcKind {
/// Plain request/response RPC.
Rpc,
/// Workflow method; persisted, and rejected on entities.
Workflow,
/// Activity method; persisted, and rejected on entities.
Activity,
/// Helper method: never dispatched, client-visible or trait-visible.
Method,
}
/// Visibility of an annotated method; presumably set from `#[public]` / `#[protected]` / `#[private]` by `parse_visibility_attr` (not visible here).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcVisibility {
/// Only public methods can be dispatched and exposed on clients.
Public,
/// Not dispatched or exposed on clients, but still trait-visible.
Protected,
/// Hidden everywhere, including trait visibility.
Private,
}
impl RpcKind {
    /// `true` for kinds whose invocations are persisted: workflows and activities.
    fn is_persisted(&self) -> bool {
        match self {
            RpcKind::Workflow | RpcKind::Activity => true,
            RpcKind::Rpc | RpcKind::Method => false,
        }
    }
}
impl RpcMethod {
    /// `true` when the client must send this call through persisted delivery.
    ///
    /// That is the case for an explicit `#[rpc(persisted)]` or for an
    /// inherently persisted kind (workflow or activity).
    fn uses_persisted_delivery(&self) -> bool {
        self.rpc_persisted || self.kind.is_persisted()
    }
}
impl RpcVisibility {
    /// `true` only for `RpcVisibility::Public`.
    fn is_public(&self) -> bool {
        *self == RpcVisibility::Public
    }
    /// `true` only for `RpcVisibility::Private`.
    fn is_private(&self) -> bool {
        *self == RpcVisibility::Private
    }
}
/// One async method collected from an impl block, with everything needed to dispatch it and to generate its client method.
struct RpcMethod {
/// Method identifier; called on the RPC view and reused as the client method name.
name: syn::Ident,
/// Wire tag matched by `handle_request` and sent by the client.
tag: String,
/// Request parameters in declaration order; serialized as the request payload.
params: Vec<RpcParam>,
/// Client return type (`Result<response_type, ClusterError>`).
response_type: syn::Type,
/// Presumably `true` for a `&mut self` receiver; rejected on entities.
is_mut: bool,
kind: RpcKind,
visibility: RpcVisibility,
/// Optional closure computing the persisted-delivery key from the parameter references.
persist_key: Option<syn::ExprClosure>,
// NOTE(review): presumably whether the method takes a `DurableContext`; not read in this section.
#[allow(dead_code)]
has_durable_context: bool,
/// `#[rpc(persisted)]`: forces persisted delivery for a plain RPC.
rpc_persisted: bool,
// Not read in this section; kept for other generators.
#[allow(dead_code)]
retries: Option<u32>,
// Not read in this section; kept for other generators.
#[allow(dead_code)]
backoff: Option<String>,
}
impl RpcMethod {
    /// `true` when `handle_request` should route this method's tag.
    ///
    /// Only public methods qualify, and activities and helper methods never do.
    fn is_dispatchable(&self) -> bool {
        let internal_only = matches!(self.kind, RpcKind::Activity | RpcKind::Method);
        !internal_only && self.visibility.is_public()
    }
    /// `true` when the generated client gets a method for this RPC.
    ///
    /// Currently the same rule as `is_dispatchable`.
    fn is_client_visible(&self) -> bool {
        self.is_dispatchable()
    }
    /// `true` for every non-private method except helper methods.
    fn is_trait_visible(&self) -> bool {
        !(self.visibility.is_private() || self.kind == RpcKind::Method)
    }
}
/// A single request parameter of an `RpcMethod`.
struct RpcParam {
/// Binding name used in the dispatch arm and as the client method parameter.
name: syn::Ident,
/// Declared type; the client takes it by reference (`&ty`).
ty: syn::Type,
}
fn entity_impl_block_inner(
args: ImplArgs,
input: syn::ItemImpl,
) -> syn::Result<proc_macro2::TokenStream> {
let krate = args.krate.unwrap_or_else(default_crate_path);
let traits = args.traits;
let rpc_groups = args.rpc_groups;
let deferred_keys = args.deferred_keys;
let mut input = input;
let self_ty = &input.self_ty;
let state_info = parse_state_attr(&mut input.attrs)?;
if let Some(ref info) = state_info {
return Err(syn::Error::new(
info.span,
"entities are stateless; use a database for state management",
));
}
if !traits.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"entity traits have been replaced by #[rpc_group]; use #[entity_impl(rpc_groups(...))] instead",
));
}
let struct_name = match self_ty.as_ref() {
syn::Type::Path(tp) => tp
.path
.segments
.last()
.map(|s| s.ident.clone())
.ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
_ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
};
let handler_name = format_ident!("{}Handler", struct_name);
let client_name = format_ident!("{}Client", struct_name);
let mut rpcs = Vec::new();
let mut original_methods = Vec::new();
for item in &input.items {
match item {
syn::ImplItem::Type(type_item) if type_item.ident == "State" => {
return Err(syn::Error::new(
type_item.span(),
"entities are stateless; use a database for state management",
));
}
syn::ImplItem::Fn(method) => {
let has_rpc_attrs = parse_kind_attr(&method.attrs)?.is_some()
|| parse_visibility_attr(&method.attrs)?.is_some();
if method.sig.ident == "init" && method.sig.asyncness.is_none() {
return Err(syn::Error::new(
method.sig.span(),
"entities are stateless; `fn init` is no longer needed",
));
}
if method.sig.asyncness.is_some() {
if let Some(rpc) = parse_rpc_method(method)? {
if matches!(rpc.kind, RpcKind::Workflow) {
return Err(syn::Error::new(
method.sig.span(),
"use standalone #[workflow] for durable orchestration; \
entities only support #[rpc] and #[rpc(persisted)]",
));
}
if matches!(rpc.kind, RpcKind::Activity) {
return Err(syn::Error::new(
method.sig.span(),
"activities belong on workflows, not entities; \
use standalone #[workflow] with #[activity] methods",
));
}
if rpc.is_mut {
return Err(syn::Error::new(
method.sig.span(),
"entity methods must use `&self`; \
entities are stateless and do not support `&mut self`",
));
}
rpcs.push(rpc);
}
} else if has_rpc_attrs {
return Err(syn::Error::new(
method.sig.span(),
"RPC annotations are only valid on async methods",
));
}
original_methods.push(method.clone());
}
_ => {}
}
}
let rpc_group_infos = rpc_group_infos_from_paths(&rpc_groups);
let entity_tokens = generate_pure_rpc_entity(
&krate,
&struct_name,
&handler_name,
&client_name,
&rpc_group_infos,
&rpcs,
&original_methods,
)?;
let deferred_consts = generate_deferred_key_consts(&krate, &deferred_keys)?;
Ok(quote! {
#entity_tokens
#deferred_consts
})
}
fn generate_deferred_key_consts(
krate: &syn::Path,
deferred_keys: &[DeferredKeyDecl],
) -> syn::Result<proc_macro2::TokenStream> {
if deferred_keys.is_empty() {
return Ok(quote! {});
}
let mut seen_idents = HashSet::new();
let mut seen_names = HashSet::new();
for decl in deferred_keys {
let ident = decl.ident.to_string();
if !seen_idents.insert(ident.clone()) {
return Err(syn::Error::new(
decl.ident.span(),
format!("duplicate deferred key constant: {ident}"),
));
}
let name = decl.name.value();
if !seen_names.insert(name.clone()) {
return Err(syn::Error::new(
decl.name.span(),
format!("duplicate deferred key name: {name}"),
));
}
}
let consts: Vec<_> = deferred_keys
.iter()
.map(|decl| {
let ident = &decl.ident;
let ty = &decl.ty;
let name = &decl.name;
quote! {
#[allow(dead_code)]
pub const #ident: #krate::__internal::DeferredKey<#ty> =
#krate::__internal::DeferredKey::new(#name);
}
})
.collect();
Ok(quote! {
#(#consts)*
})
}
#[allow(clippy::too_many_arguments)]
fn generate_pure_rpc_entity(
krate: &syn::Path,
struct_name: &syn::Ident,
handler_name: &syn::Ident,
client_name: &syn::Ident,
rpc_group_infos: &[RpcGroupInfo],
rpcs: &[RpcMethod],
original_methods: &[syn::ImplItemFn],
) -> syn::Result<proc_macro2::TokenStream> {
let has_rpc_groups = !rpc_group_infos.is_empty();
let struct_name_str = struct_name.to_string();
let rpc_view_name = format_ident!("__{}RpcView", struct_name);
let entity_impl = if has_rpc_groups {
quote! {}
} else {
quote! {
#[async_trait::async_trait]
impl #krate::entity::Entity for #struct_name {
fn entity_type(&self) -> #krate::types::EntityType {
self.__entity_type()
}
fn shard_group(&self) -> &str {
self.__shard_group()
}
fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
self.__shard_group_for(entity_id)
}
fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
self.__max_idle_time()
}
fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
self.__mailbox_capacity()
}
fn concurrency(&self) -> ::std::option::Option<usize> {
self.__concurrency()
}
async fn spawn(
&self,
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<
::std::boxed::Box<dyn #krate::entity::EntityHandler>,
#krate::error::ClusterError,
> {
let handler = #handler_name::__new(self.clone(), ctx).await?;
::std::result::Result::Ok(::std::boxed::Box::new(handler))
}
}
}
};
let dispatch_arms: Vec<proc_macro2::TokenStream> = rpcs
.iter()
.filter(|rpc| rpc.is_dispatchable())
.map(|rpc| {
let tag = &rpc.tag;
let method_name = &rpc.name;
let param_count = rpc.params.len();
let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
let deserialize_request = match param_count {
0 => quote! {},
1 => {
let name = ¶m_names[0];
let ty = ¶m_types[0];
quote! {
let #name: #ty = rmp_serde::from_slice(payload)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
}
}
_ => quote! {
let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
},
};
let mut call_args: Vec<proc_macro2::TokenStream> = Vec::new();
for name in ¶m_names {
call_args.push(quote! { #name });
}
let call_args = quote! { #(#call_args),* };
let method_call = quote! { __view.#method_name(#call_args).await? };
quote! {
#tag => {
let __view = #rpc_view_name { __handler: self };
#deserialize_request
let response = { #method_call };
rmp_serde::to_vec(&response)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})
}
}
})
.collect();
let client_methods = generate_client_methods(krate, rpcs);
let rpc_group_field_defs: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let field = &info.field;
let wrapper_path = &info.wrapper_path;
quote! { #field: #wrapper_path, }
})
.collect();
let rpc_group_params: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let path = &info.path;
let field = &info.field;
quote! { #field: #path }
})
.collect();
let rpc_group_field_inits: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let field = &info.field;
let wrapper_path = &info.wrapper_path;
quote! { #field: #wrapper_path::new(#field, __entity_address.clone()), }
})
.collect();
let rpc_group_dispatch_checks: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let field = &info.field;
quote! {
if let ::std::option::Option::Some(response) = self.#field.__dispatch(tag, payload, headers).await? {
return ::std::result::Result::Ok(response);
}
}
})
.collect();
let rpc_group_handler_access_impls: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let wrapper_path = &info.wrapper_path;
let access_trait_path = &info.access_trait_path;
let field = &info.field;
quote! {
impl #access_trait_path for #handler_name {
fn __rpc_group_wrapper(&self) -> &#wrapper_path {
&self.#field
}
}
}
})
.collect();
let rpc_group_use_tokens: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let methods_trait_path = &info.methods_trait_path;
quote! {
#[allow(unused_imports)]
use #methods_trait_path as _;
}
})
.collect();
let dispatch_fallback = if has_rpc_groups {
quote! {{
#(#rpc_group_dispatch_checks)*
::std::result::Result::Err(
#krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("unknown RPC tag: {tag}"),
source: ::std::option::Option::None,
}
)
}}
} else {
quote! {{
::std::result::Result::Err(
#krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("unknown RPC tag: {tag}"),
source: ::std::option::Option::None,
}
)
}}
};
let method_impls: Vec<proc_macro2::TokenStream> = original_methods
.iter()
.map(|m| {
let attrs: Vec<_> = m
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("rpc")
&& !a.path().is_ident("workflow")
&& !a.path().is_ident("activity")
&& !a.path().is_ident("method")
&& !a.path().is_ident("public")
&& !a.path().is_ident("protected")
&& !a.path().is_ident("private")
})
.collect();
let vis = &m.vis;
let sig = &m.sig;
let block = &m.block;
quote! {
#(#attrs)*
#vis #sig #block
}
})
.collect();
let new_fn = if has_rpc_groups {
quote! {}
} else {
quote! {
#[doc(hidden)]
pub async fn __new(entity: #struct_name, ctx: #krate::entity::EntityContext) -> ::std::result::Result<Self, #krate::error::ClusterError> {
let __sharding = ctx.sharding.clone();
let __entity_address = ctx.address.clone();
let __message_storage = ctx.message_storage.clone();
::std::result::Result::Ok(Self {
__entity: entity,
ctx,
__sharding,
__entity_address,
__message_storage,
})
}
}
};
let new_with_rpc_groups_fn = if has_rpc_groups {
quote! {
#[doc(hidden)]
pub async fn __new_with_rpc_groups(
entity: #struct_name,
#(#rpc_group_params,)*
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<Self, #krate::error::ClusterError> {
let __sharding = ctx.sharding.clone();
let __entity_address = ctx.address.clone();
let __message_storage = ctx.message_storage.clone();
::std::result::Result::Ok(Self {
__entity: entity,
ctx,
__sharding,
__message_storage,
#(#rpc_group_field_inits)*
__entity_address,
})
}
}
} else {
quote! {}
};
let sharding_builtin_impls = quote! {
pub fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
self.__sharding.as_ref()
}
pub fn entity_address(&self) -> &#krate::types::EntityAddress {
&self.__entity_address
}
pub fn entity_id(&self) -> &#krate::types::EntityId {
&self.__entity_address.entity_id
}
pub fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
self.__sharding.as_ref().map(|s| {
::std::sync::Arc::clone(s).make_client(self.__entity_address.entity_type.clone())
})
}
};
let with_rpc_groups_impl = if has_rpc_groups {
let with_groups_name = format_ident!("{}WithRpcGroups", struct_name);
let rpc_group_option_fields: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let field = &info.field;
let path = &info.path;
quote! { #field: #path, }
})
.collect();
let register_rpc_group_params: Vec<proc_macro2::TokenStream> = rpc_group_infos
.iter()
.map(|info| {
let field = &info.field;
let path = &info.path;
quote! { #field: #path }
})
.collect();
let register_rpc_group_fields: Vec<_> =
rpc_group_infos.iter().map(|info| &info.field).collect();
quote! {
#[doc(hidden)]
pub struct #with_groups_name {
pub entity: #struct_name,
#(pub #rpc_group_option_fields)*
}
#[async_trait::async_trait]
impl #krate::entity::Entity for #with_groups_name
where
#struct_name: ::std::clone::Clone,
{
fn entity_type(&self) -> #krate::types::EntityType {
self.entity.__entity_type()
}
fn shard_group(&self) -> &str {
self.entity.__shard_group()
}
fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
self.entity.__shard_group_for(entity_id)
}
fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
self.entity.__max_idle_time()
}
fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
self.entity.__mailbox_capacity()
}
fn concurrency(&self) -> ::std::option::Option<usize> {
self.entity.__concurrency()
}
async fn spawn(
&self,
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<
::std::boxed::Box<dyn #krate::entity::EntityHandler>,
#krate::error::ClusterError,
> {
let handler = #handler_name::__new_with_rpc_groups(
self.entity.clone(),
#(self.#register_rpc_group_fields.clone(),)*
ctx,
)
.await?;
::std::result::Result::Ok(::std::boxed::Box::new(handler))
}
}
impl #struct_name {
pub async fn register(
self,
sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
#(#register_rpc_group_params),*
) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
let entity_with_groups = #with_groups_name {
entity: self,
#(#register_rpc_group_fields,)*
};
sharding.register_entity(::std::sync::Arc::new(entity_with_groups)).await?;
::std::result::Result::Ok(#client_name::new(sharding))
}
}
}
} else {
quote! {}
};
let register_impl = if has_rpc_groups {
quote! {} } else {
quote! {
impl #struct_name {
pub async fn register(
self,
sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
sharding.register_entity(::std::sync::Arc::new(self)).await?;
::std::result::Result::Ok(#client_name::new(sharding))
}
}
}
};
Ok(quote! {
#(#rpc_group_use_tokens)*
#with_rpc_groups_impl
#entity_impl
#[doc(hidden)]
pub struct #handler_name {
#[allow(dead_code)]
__entity: #struct_name,
#[allow(dead_code)]
ctx: #krate::entity::EntityContext,
__sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
__entity_address: #krate::types::EntityAddress,
#[allow(dead_code)]
__message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
#(#rpc_group_field_defs)*
}
impl #handler_name {
#new_fn
#new_with_rpc_groups_fn
#sharding_builtin_impls
}
#[doc(hidden)]
#[allow(non_camel_case_types)]
struct #rpc_view_name<'a> {
__handler: &'a #handler_name,
}
impl ::std::ops::Deref for #rpc_view_name<'_> {
type Target = #struct_name;
fn deref(&self) -> &Self::Target {
&self.__handler.__entity
}
}
impl #rpc_view_name<'_> {
#[inline]
fn entity_id(&self) -> &str {
&self.__handler.__entity_address.entity_id.0
}
#[inline]
fn entity_address(&self) -> &#krate::types::EntityAddress {
&self.__handler.__entity_address
}
#[inline]
fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
self.__handler.__sharding.as_ref()
}
#[inline]
fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
self.__handler.__sharding.as_ref().map(|s| {
::std::sync::Arc::clone(s).make_client(self.__handler.__entity_address.entity_type.clone())
})
}
#(#method_impls)*
}
#[async_trait::async_trait]
impl #krate::entity::EntityHandler for #handler_name {
async fn handle_request(
&self,
tag: &str,
payload: &[u8],
headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
#[allow(unused_variables)]
let headers = headers;
match tag {
#(#dispatch_arms,)*
_ => #dispatch_fallback,
}
}
}
#register_impl
#[derive(Clone)]
pub struct #client_name {
inner: #krate::entity_client::EntityClient,
}
impl #client_name {
pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
Self {
inner: #krate::entity_client::EntityClient::new(
sharding,
#krate::types::EntityType::new(#struct_name_str),
),
}
}
pub fn inner(&self) -> &#krate::entity_client::EntityClient {
&self.inner
}
#(#client_methods)*
}
impl #krate::entity_client::EntityClientAccessor for #client_name {
fn entity_client(&self) -> &#krate::entity_client::EntityClient {
&self.inner
}
}
#(#rpc_group_handler_access_impls)*
})
}
/// Generates one async method on `<Struct>Client` for each client-visible RPC.
///
/// Each method takes `entity_id` followed by every RPC parameter by reference and
/// sends the request under the RPC's tag.
///
/// Request payload, by parameter count:
/// - zero parameters send `&()`;
/// - one parameter sends that reference directly;
/// - more than one send a tuple of the references, which matches the tuple
///   decoding in the handler's dispatch arm.
///
/// Transport:
/// - RPCs with `uses_persisted_delivery()` go through `send_persisted`;
/// - when a `persist_key` closure is configured they use
///   `send_persisted_with_key` instead. The closure is called with the parameter
///   references, and its result is encoded with `rmp_serde` into the key bytes.
/// - all other RPCs use plain `send`.
fn generate_client_methods(krate: &syn::Path, rpcs: &[RpcMethod]) -> Vec<proc_macro2::TokenStream> {
    rpcs.iter()
        .filter(|rpc| rpc.is_client_visible())
        .map(|rpc| {
            let method_name = &rpc.name;
            let tag = &rpc.tag;
            let resp_type = &rpc.response_type;
            let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
            let param_defs: Vec<_> = rpc
                .params
                .iter()
                .map(|p| {
                    let name = &p.name;
                    let ty = &p.ty;
                    quote! { #name: &#ty }
                })
                .collect();

            // Extra signature parameters, an optional statement binding the request
            // tuple, and the expression passed as the request payload.
            let (param_list, request_setup, request_arg) = match param_names.as_slice() {
                [] => (quote! {}, quote! {}, quote! { &() }),
                [name] => {
                    let def = &param_defs[0];
                    (quote! { #def, }, quote! {}, quote! { #name })
                }
                _ => (
                    quote! { #(#param_defs),* },
                    quote! { let request = (#(#param_names),*); },
                    quote! { &request },
                ),
            };

            let send_call = if rpc.uses_persisted_delivery() {
                match rpc.persist_key.as_ref() {
                    Some(persist_key) => quote! {
                        let key = (#persist_key)(#(#param_names),*);
                        let key_bytes = rmp_serde::to_vec(&key)
                            .map_err(|e| #krate::error::ClusterError::MalformedMessage {
                                reason: ::std::format!(
                                    "failed to serialize persist key for '{}': {e}",
                                    #tag
                                ),
                                source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                            })?;
                        #request_setup
                        self.inner
                            .send_persisted_with_key(
                                entity_id,
                                #tag,
                                #request_arg,
                                ::std::option::Option::Some(key_bytes),
                                #krate::schema::Uninterruptible::No,
                            )
                            .await
                    },
                    None => quote! {
                        #request_setup
                        self.inner
                            .send_persisted(entity_id, #tag, #request_arg, #krate::schema::Uninterruptible::No)
                            .await
                    },
                }
            } else {
                quote! {
                    #request_setup
                    self.inner.send(entity_id, #tag, #request_arg).await
                }
            };

            quote! {
                pub async fn #method_name(
                    &self,
                    entity_id: &#krate::types::EntityId,
                    #param_list
                ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
                    #send_call
                }
            }
        })
        .collect()
}
/// Returns `true` when `ty`, after peeling off any number of `&` / `&mut`
/// layers, is a path type whose final segment is named `DurableContext`.
///
/// Only the last path segment is compared, so `DurableContext`,
/// `crate::DurableContext` and `&mut some::path::DurableContext` all match.
fn is_durable_context_type(ty: &syn::Type) -> bool {
    let mut inner = ty;
    // Strip reference layers iteratively instead of recursing.
    while let syn::Type::Reference(reference) = inner {
        inner = &*reference.elem;
    }
    match inner {
        syn::Type::Path(type_path) => matches!(
            type_path.path.segments.last(),
            Some(segment) if segment.ident == "DurableContext"
        ),
        _ => false,
    }
}
struct StateArgs {
#[allow(dead_code)]
ty: syn::Type,
span: proc_macro2::Span,
}
impl syn::parse::Parse for StateArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let ty: syn::Type = input.parse()?;
if !input.is_empty() {
return Err(syn::Error::new(
input.span(),
"unexpected tokens in #[state(...)]; state is always persistent",
));
}
Ok(StateArgs {
ty,
span: proc_macro2::Span::call_site(),
})
}
}
/// Finds, parses and removes the `#[state(...)]` attribute from `attrs`.
///
/// Returns `Ok(None)` when no such attribute is present. The returned
/// [`StateArgs::span`] is set to the span of the attribute itself.
///
/// # Errors
/// Fails if the attribute's arguments do not parse as [`StateArgs`], or if a
/// second `#[state(...)]` attribute is encountered.
fn parse_state_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<StateArgs>> {
    let mut found: Option<StateArgs> = None;
    let mut idx = 0;
    while idx < attrs.len() {
        if !attrs[idx].path().is_ident("state") {
            idx += 1;
            continue;
        }
        let attr_span = attrs[idx].span();
        if found.is_some() {
            return Err(syn::Error::new(
                attr_span,
                "duplicate #[state(...)] attribute",
            ));
        }
        let mut parsed: StateArgs = attrs[idx].parse_args()?;
        parsed.span = attr_span;
        found = Some(parsed);
        // Removing shifts the next element into `idx`, so do not advance.
        attrs.remove(idx);
    }
    Ok(found)
}
/// Arguments accepted inside `#[rpc(...)]`.
struct RpcArgs {
    /// `true` when the `persisted` flag was given.
    persisted: bool,
}
impl syn::parse::Parse for RpcArgs {
    /// Parses a comma-separated list of flags; `persisted` is the only one
    /// recognised. A trailing comma is allowed.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut persisted = false;
        while !input.is_empty() {
            let flag = input.parse::<syn::Ident>()?;
            if flag != "persisted" {
                return Err(syn::Error::new(
                    flag.span(),
                    format!("unknown rpc attribute: {flag}; expected `persisted`"),
                ));
            }
            persisted = true;
            if input.is_empty() {
                break;
            }
            input.parse::<syn::Token![,]>()?;
        }
        Ok(Self { persisted })
    }
}
struct KeyArgs {
key: Option<syn::ExprClosure>,
}
impl syn::parse::Parse for KeyArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(KeyArgs { key: None });
}
let ident: syn::Ident = input.parse()?;
if ident != "key" {
return Err(syn::Error::new(
ident.span(),
"expected `key` in #[workflow(key(...))] or #[activity(key(...))]",
));
}
if input.peek(syn::Token![=]) {
input.parse::<syn::Token![=]>()?;
}
let expr: syn::Expr = if input.peek(syn::token::Paren) {
let content;
syn::parenthesized!(content in input);
content.parse()?
} else {
input.parse()?
};
if !input.is_empty() {
return Err(syn::Error::new(
input.span(),
"unexpected tokens in #[workflow(...)] or #[activity(...)]",
));
}
match expr {
syn::Expr::Closure(closure) => Ok(KeyArgs { key: Some(closure) }),
_ => Err(syn::Error::new(
expr.span(),
"key must be a closure, e.g. #[workflow(key(|req| ...))]",
)),
}
}
}
/// Arguments accepted inside `#[activity(...)]`.
///
/// Recognised arguments (comma-separated, each at most once, trailing comma allowed):
/// - `key = |..| ...` or `key(|..| ...)`: closure deriving the journal key.
/// - `retries = N`: maximum number of retry attempts (`u32`).
/// - `backoff = "exponential" | "constant"`: retry backoff strategy; requires `retries`.
struct ActivityAttrArgs {
    /// Journal-key closure, if given.
    key: Option<syn::ExprClosure>,
    /// Maximum number of retries, if given.
    retries: Option<u32>,
    /// Backoff strategy; validated to be `"exponential"` or `"constant"`.
    backoff: Option<String>,
}
impl syn::parse::Parse for ActivityAttrArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut key = None;
        let mut retries = None;
        let mut backoff = None;
        // Span of the `backoff` literal, used to anchor the "requires retries" error
        // on the offending argument rather than on the outer macro invocation.
        let mut backoff_span = None;
        while !input.is_empty() {
            let ident: syn::Ident = input.parse()?;
            if ident == "key" {
                if key.is_some() {
                    return Err(syn::Error::new(ident.span(), "duplicate `key` argument"));
                }
                if input.peek(syn::Token![=]) {
                    input.parse::<syn::Token![=]>()?;
                }
                let expr: syn::Expr = if input.peek(syn::token::Paren) {
                    let content;
                    syn::parenthesized!(content in input);
                    content.parse()?
                } else {
                    input.parse()?
                };
                match expr {
                    syn::Expr::Closure(closure) => key = Some(closure),
                    _ => {
                        return Err(syn::Error::new(
                            expr.span(),
                            "key must be a closure, e.g. #[activity(key = |req| ...)]",
                        ))
                    }
                }
            } else if ident == "retries" {
                if retries.is_some() {
                    return Err(syn::Error::new(
                        ident.span(),
                        "duplicate `retries` argument",
                    ));
                }
                input.parse::<syn::Token![=]>()?;
                let lit: syn::LitInt = input.parse()?;
                retries = Some(lit.base10_parse::<u32>()?);
            } else if ident == "backoff" {
                if backoff.is_some() {
                    return Err(syn::Error::new(
                        ident.span(),
                        "duplicate `backoff` argument",
                    ));
                }
                input.parse::<syn::Token![=]>()?;
                let lit: syn::LitStr = input.parse()?;
                let value = lit.value();
                if value != "exponential" && value != "constant" {
                    return Err(syn::Error::new(
                        lit.span(),
                        "backoff must be \"exponential\" or \"constant\"",
                    ));
                }
                backoff_span = Some(lit.span());
                backoff = Some(value);
            } else {
                return Err(syn::Error::new(
                    ident.span(),
                    "expected `key`, `retries`, or `backoff` in #[activity(...)]",
                ));
            }
            // Arguments must be separated by commas (a trailing comma is fine).
            // Previously a missing separator was silently accepted.
            if !input.is_empty() {
                input.parse::<syn::Token![,]>()?;
            }
        }
        if let (Some(span), None) = (backoff_span, retries) {
            return Err(syn::Error::new(
                span,
                "`backoff` requires `retries` to be specified",
            ));
        }
        Ok(ActivityAttrArgs {
            key,
            retries,
            backoff,
        })
    }
}
/// Kind-related data extracted from a method's attributes by [`parse_kind_attr`].
struct KindAttrInfo {
    /// Which of `#[rpc]`, `#[workflow]`, `#[activity]` or `#[method]` was present.
    kind: RpcKind,
    /// `key` closure from `#[workflow(...)]` or `#[activity(...)]`.
    key: Option<syn::ExprClosure>,
    /// `true` for `#[rpc(persisted)]`.
    rpc_persisted: bool,
    /// `retries = N` from `#[activity(...)]`.
    retries: Option<u32>,
    /// `backoff = "..."` from `#[activity(...)]`.
    backoff: Option<String>,
}
/// Looks for at most one kind attribute (`rpc`, `workflow`, `activity`, `method`)
/// in `attrs` and parses its arguments.
///
/// Returns `Ok(None)` when none of the four attributes is present.
///
/// # Errors
/// Fails when more than one kind attribute is present, when an attribute uses an
/// unsupported form (e.g. `#[rpc = ...]`, `#[method(...)]`), or when its
/// argument list fails to parse.
fn parse_kind_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<KindAttrInfo>> {
    let mut found: Option<KindAttrInfo> = None;
    for attr in attrs {
        // `get_ident` is `Some` exactly when `is_ident` could match a bare name.
        let owned_name = match attr.path().get_ident() {
            Some(ident) => ident.to_string(),
            None => continue,
        };
        let name = owned_name.as_str();
        if !matches!(name, "rpc" | "workflow" | "activity" | "method") {
            continue;
        }
        // Checked before argument parsing so duplicates are reported first.
        if found.is_some() {
            return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
        }
        let info = match name {
            "rpc" => {
                let rpc_persisted = match &attr.meta {
                    syn::Meta::Path(_) => false,
                    syn::Meta::List(_) => attr.parse_args::<RpcArgs>()?.persisted,
                    syn::Meta::NameValue(_) => {
                        return Err(syn::Error::new(
                            attr.span(),
                            "expected #[rpc] or #[rpc(persisted)]",
                        ))
                    }
                };
                KindAttrInfo {
                    kind: RpcKind::Rpc,
                    key: None,
                    rpc_persisted,
                    retries: None,
                    backoff: None,
                }
            }
            "workflow" => {
                let workflow_args = match &attr.meta {
                    syn::Meta::Path(_) => KeyArgs { key: None },
                    syn::Meta::List(_) => attr.parse_args::<KeyArgs>()?,
                    syn::Meta::NameValue(_) => {
                        return Err(syn::Error::new(
                            attr.span(),
                            "expected #[workflow] or #[workflow(key(...))]",
                        ))
                    }
                };
                KindAttrInfo {
                    kind: RpcKind::Workflow,
                    key: workflow_args.key,
                    rpc_persisted: false,
                    retries: None,
                    backoff: None,
                }
            }
            "activity" => {
                let activity_args = match &attr.meta {
                    syn::Meta::Path(_) => ActivityAttrArgs {
                        key: None,
                        retries: None,
                        backoff: None,
                    },
                    syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
                    syn::Meta::NameValue(_) => {
                        return Err(syn::Error::new(
                            attr.span(),
                            "expected #[activity] or #[activity(...)]",
                        ))
                    }
                };
                KindAttrInfo {
                    kind: RpcKind::Activity,
                    key: activity_args.key,
                    rpc_persisted: false,
                    retries: activity_args.retries,
                    backoff: activity_args.backoff,
                }
            }
            _ => {
                // Only `method` reaches this arm after the filter above.
                if !matches!(attr.meta, syn::Meta::Path(_)) {
                    return Err(syn::Error::new(
                        attr.span(),
                        "#[method] does not take arguments",
                    ));
                }
                KindAttrInfo {
                    kind: RpcKind::Method,
                    key: None,
                    rpc_persisted: false,
                    retries: None,
                    backoff: None,
                }
            }
        };
        found = Some(info);
    }
    Ok(found)
}
/// Reads an optional `#[public]`, `#[protected]` or `#[private]` marker from `attrs`.
///
/// Returns `Ok(None)` when no visibility marker is present.
///
/// # Errors
/// Fails if a marker carries arguments, or if more than one marker is present.
fn parse_visibility_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<RpcVisibility>> {
    let mut chosen: Option<RpcVisibility> = None;
    for attr in attrs {
        let path = attr.path();
        let candidate = if path.is_ident("public") {
            RpcVisibility::Public
        } else if path.is_ident("protected") {
            RpcVisibility::Protected
        } else if path.is_ident("private") {
            RpcVisibility::Private
        } else {
            continue;
        };
        if !matches!(attr.meta, syn::Meta::Path(_)) {
            return Err(syn::Error::new(
                attr.span(),
                "visibility attributes do not take arguments",
            ));
        }
        if chosen.is_some() {
            return Err(syn::Error::new(
                attr.span(),
                "multiple visibility modifiers are not allowed",
            ));
        }
        chosen = Some(candidate);
    }
    Ok(chosen)
}
/// Converts one impl method into an [`RpcMethod`] descriptor.
///
/// Returns `Ok(None)` when the method has none of `#[rpc]`, `#[workflow]`,
/// `#[activity]` or `#[method]`, i.e. it is a plain helper.
///
/// The RPC tag is the method name. Parameters after the receiver become wire
/// parameters, except a `DurableContext` parameter, which must come first if
/// present. `_` patterns are named `__arg{index}`. For everything but `#[method]`,
/// the response type is the `Ok` type of the returned `Result`.
///
/// # Errors
/// Rejects, among others:
/// - visibility markers without a kind attribute;
/// - non-async RPC/workflow/activity methods;
/// - `key` on `#[rpc]`/`#[method]`, and `persisted` outside `#[rpc]`;
/// - `#[public]` on activities/methods;
/// - methods without a receiver, and `&mut self` outside `#[activity]`;
/// - misplaced or duplicate `DurableContext` parameters;
/// - non-identifier parameter patterns;
/// - a missing `Result` return type on non-`#[method]` kinds.
fn parse_rpc_method(method: &syn::ImplItemFn) -> syn::Result<Option<RpcMethod>> {
    let name = method.sig.ident.clone();
    let tag = name.to_string();
    let kind_info = parse_kind_attr(&method.attrs)?;
    let visibility_attr = parse_visibility_attr(&method.attrs)?;
    let KindAttrInfo {
        kind,
        key: persist_key,
        rpc_persisted,
        retries,
        backoff,
    } = match kind_info {
        Some(info) => info,
        None => {
            if visibility_attr.is_some() {
                return Err(syn::Error::new(
                    method.sig.span(),
                    "visibility modifiers require #[rpc], #[workflow], #[activity], or #[method]",
                ));
            }
            return Ok(None);
        }
    };
    if method.sig.asyncness.is_none() && !matches!(kind, RpcKind::Method) {
        return Err(syn::Error::new(
            method.sig.span(),
            "#[rpc]/#[workflow]/#[activity] can only be applied to async methods",
        ));
    }
    if matches!(kind, RpcKind::Rpc | RpcKind::Method) && persist_key.is_some() {
        return Err(syn::Error::new(
            method.sig.span(),
            "#[rpc] and #[method] do not support key(...) — use #[workflow(key(...))] or #[activity(key(...))]",
        ));
    }
    if rpc_persisted && !matches!(kind, RpcKind::Rpc) {
        return Err(syn::Error::new(
            method.sig.span(),
            "persisted flag is only valid on #[rpc(persisted)]",
        ));
    }
    let visibility = match (kind, visibility_attr) {
        (_, Some(RpcVisibility::Public)) if matches!(kind, RpcKind::Activity | RpcKind::Method) => {
            return Err(syn::Error::new(
                method.sig.span(),
                "#[activity] and #[method] cannot be #[public]",
            ))
        }
        (RpcKind::Activity | RpcKind::Method, None) => RpcVisibility::Private,
        (RpcKind::Rpc | RpcKind::Workflow, None) => RpcVisibility::Public,
        (_, Some(vis)) => vis,
    };
    // The parameter scan below skips the first input, assuming it is the
    // receiver. Without this check, a receiver-less method would silently lose
    // its first real parameter.
    let receiver = match method.sig.inputs.first() {
        Some(syn::FnArg::Receiver(receiver)) => receiver,
        _ => {
            return Err(syn::Error::new(
                method.sig.span(),
                "#[rpc]/#[workflow]/#[activity]/#[method] methods must take `self` as their first parameter",
            ))
        }
    };
    let is_mut = receiver.mutability.is_some();
    if is_mut && !matches!(kind, RpcKind::Activity) {
        return Err(syn::Error::new(
            method.sig.span(),
            "only #[activity] methods can use `&mut self` for state mutation; use `&self` for read-only access",
        ));
    }
    let mut params = Vec::new();
    let mut has_durable_context = false;
    let mut saw_non_ctx_param = false;
    for arg in method.sig.inputs.iter().skip(1) {
        let pat_type = match arg {
            syn::FnArg::Typed(pat_type) => pat_type,
            syn::FnArg::Receiver(_) => continue,
        };
        if is_durable_context_type(&pat_type.ty) {
            if has_durable_context {
                return Err(syn::Error::new(
                    arg.span(),
                    "duplicate DurableContext parameter",
                ));
            }
            if saw_non_ctx_param {
                return Err(syn::Error::new(
                    arg.span(),
                    "DurableContext must be the first parameter after &self",
                ));
            }
            // The context is injected by the runtime, so it is not a wire parameter.
            has_durable_context = true;
            continue;
        }
        saw_non_ctx_param = true;
        let param_name = match &*pat_type.pat {
            syn::Pat::Ident(ident) => ident.ident.clone(),
            // Wildcards get a synthetic name indexed by wire-parameter position.
            syn::Pat::Wild(_) => format_ident!("__arg{}", params.len()),
            _ => {
                return Err(syn::Error::new(
                    pat_type.pat.span(),
                    "entity RPC parameters must be simple identifiers",
                ))
            }
        };
        params.push(RpcParam {
            name: param_name,
            ty: (*pat_type.ty).clone(),
        });
    }
    if has_durable_context && matches!(kind, RpcKind::Rpc | RpcKind::Method) {
        return Err(syn::Error::new(
            method.sig.span(),
            "methods with `&DurableContext` must be marked #[workflow] or #[activity]",
        ));
    }
    let response_type = match &method.sig.output {
        syn::ReturnType::Type(_, ty) => {
            if matches!(kind, RpcKind::Method) {
                (**ty).clone()
            } else {
                extract_result_ok_type(ty)?
            }
        }
        syn::ReturnType::Default => {
            if matches!(kind, RpcKind::Method) {
                syn::parse_quote!(())
            } else {
                return Err(syn::Error::new(
                    method.sig.span(),
                    "entity RPC methods must return Result<T, ClusterError>",
                ));
            }
        }
    };
    if retries.is_some() && !matches!(kind, RpcKind::Activity) {
        return Err(syn::Error::new(
            method.sig.span(),
            "`retries` is only valid on #[activity(retries = N)]",
        ));
    }
    Ok(Some(RpcMethod {
        name,
        tag,
        params,
        response_type,
        is_mut,
        kind,
        visibility,
        persist_key,
        has_durable_context,
        rpc_persisted,
        retries,
        backoff,
    }))
}
/// Converts a CamelCase / PascalCase identifier to snake_case.
///
/// An uppercase character is lowercased and preceded by `_` when the previous
/// character is lowercase (`fooBar` → `foo_bar`). It is also preceded by `_`
/// when it ends an acronym: the previous character is uppercase and the next is
/// lowercase (`HTTPServer` → `http_server`).
///
/// Other alphanumeric characters and `_` are copied unchanged; all remaining
/// characters are dropped. Dropped characters and digits still count as "the
/// previous character", so `a-B` becomes `ab` and `Http2Server` becomes
/// `http2server`.
fn to_snake(input: &str) -> String {
    let mut snake = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();
    let mut previous: Option<char> = None;
    while let Some(current) = chars.next() {
        if current.is_uppercase() {
            let after_lower = matches!(previous, Some(p) if p.is_lowercase());
            let ends_acronym = matches!(previous, Some(p) if p.is_uppercase())
                && matches!(chars.peek(), Some(n) if n.is_lowercase());
            if after_lower || ends_acronym {
                snake.push('_');
            }
            snake.extend(current.to_lowercase());
        } else if current == '_' || current.is_alphanumeric() {
            snake.push(current);
        }
        previous = Some(current);
    }
    snake
}
fn extract_result_ok_type(ty: &syn::Type) -> syn::Result<syn::Type> {
if let syn::Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Result" {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(syn::GenericArgument::Type(ok_type)) = args.args.first() {
return Ok(ok_type.clone());
}
}
}
}
}
Err(syn::Error::new(
ty.span(),
"expected Result<T, ClusterError> return type",
))
}
/// Arguments accepted by `#[activity_group_impl(...)]`.
struct ActivityGroupImplArgs {
    /// Override for the runtime crate path, given as `krate = "path::to::crate"`.
    krate: Option<syn::Path>,
}
impl syn::parse::Parse for ActivityGroupImplArgs {
    /// Parses a comma-separated list whose only supported entry is `krate = "..."`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut krate: Option<syn::Path> = None;
        loop {
            if input.is_empty() {
                break;
            }
            let option: syn::Ident = input.parse()?;
            if option != "krate" {
                return Err(syn::Error::new(
                    option.span(),
                    format!("unknown activity_group_impl attribute: {option}"),
                ));
            }
            input.parse::<syn::Token![=]>()?;
            let path_literal: syn::LitStr = input.parse()?;
            krate = Some(path_literal.parse()?);
            if !input.is_empty() {
                input.parse::<syn::Token![,]>()?;
            }
        }
        Ok(Self { krate })
    }
}
/// Per-activity data collected by `activity_group_impl_inner` for every
/// `#[activity]` method of an activity group.
struct ActivityGroupActivityInfo {
/// Method identifier; also used as the journal name of the activity.
name: syn::Ident,
/// Parameters after the receiver (wire parameters), in declaration order.
params: Vec<RpcParam>,
// Filled from the `Ok` type of the return value but not read afterwards,
// hence the `dead_code` allowance.
#[allow(dead_code)]
response_type: syn::Type,
/// Optional `key` closure from `#[activity(key ...)]`.
persist_key: Option<syn::ExprClosure>,
/// Unmodified copy of the method, used to reproduce its signature.
original_method: syn::ImplItemFn,
/// `retries = N` from the attribute, if any.
retries: Option<u32>,
/// `backoff = "..."` from the attribute, if any.
backoff: Option<String>,
}
/// Expands `#[activity_group_impl]` on the inherent `impl` block of a stateless
/// activity group.
///
/// `Name` below is the last path segment of the impl's self type. All generated
/// items are `#[doc(hidden)]`:
/// - `__NameActivityGroupView<'a>`: borrows the group and derefs to it. It carries
///   an `ActivityTx` (`tx`) and a `sqlx::PgPool` (`pool`), and hosts a copy of
///   every method of the impl (activities and helpers). The `activity` and
///   visibility marker attributes are stripped from these copies.
/// - `__NameActivityGroupWrapper`: owns the group, the optional workflow engine,
///   message-storage and workflow-storage handles, and the entity type/id. It
///   exposes one journaled delegation method per `#[activity]`.
/// - `__NameActivityGroupAccess` / `__NameActivityGroupMethods`: an accessor trait
///   plus a blanket-implemented trait that forwards each activity to the wrapper.
///
/// Each delegation method works as follows:
/// 1. Derive journal-key bytes, from the `key` closure when given, otherwise from
///    the serialized parameters.
/// 2. Prefix them with the current `WorkflowScope` id when one is set.
/// 3. Return a cached journal result if present.
/// 4. Otherwise run the activity inside a fresh SQL transaction, and save the
///    journal entry in that same transaction before committing.
///
/// With `retries = N`, a failed attempt is retried up to `N` times, with a durable
/// `sleep` between attempts. Each attempt uses its own journal key (the base key
/// followed by the attempt number).
///
/// # Errors
/// Rejects `#[state(...)]` on the impl or its methods, `#[rpc]` / `#[workflow]` on
/// methods, `&mut self` receivers, non-async activities, malformed
/// `#[activity(...)]` arguments, non-identifier activity parameters, and
/// activities without a `Result` return type.
///
/// # Panics (generated code)
/// The generated delegation methods panic when the workflow engine or either
/// storage handle is absent, or when the workflow storage has no SQL pool.
fn activity_group_impl_inner(
    args: ActivityGroupImplArgs,
    input: syn::ItemImpl,
) -> syn::Result<proc_macro2::TokenStream> {
    let krate = args.krate.unwrap_or_else(default_crate_path);
    let self_ty = &input.self_ty;
    let struct_name = match self_ty.as_ref() {
        syn::Type::Path(tp) => tp
            .path
            .segments
            .last()
            .map(|s| s.ident.clone())
            .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
        _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
    };
    let wrapper_name = format_ident!("__{}ActivityGroupWrapper", struct_name);
    let access_trait_name = format_ident!("__{}ActivityGroupAccess", struct_name);
    let methods_trait_name = format_ident!("__{}ActivityGroupMethods", struct_name);
    let activity_view_name = format_ident!("__{}ActivityGroupView", struct_name);
    for attr in &input.attrs {
        if attr.path().is_ident("state") {
            return Err(syn::Error::new(
                attr.span(),
                "activity groups are stateless; remove #[state(...)]",
            ));
        }
    }
    // Pass 1: validate every method and collect per-activity metadata.
    let mut activities: Vec<ActivityGroupActivityInfo> = Vec::new();
    let mut all_methods: Vec<syn::ImplItemFn> = Vec::new();
    for item in &input.items {
        if let syn::ImplItem::Fn(method) = item {
            for attr in &method.attrs {
                if attr.path().is_ident("state") {
                    return Err(syn::Error::new(
                        attr.span(),
                        "activity groups are stateless; remove #[state(...)]",
                    ));
                }
                if attr.path().is_ident("rpc") {
                    return Err(syn::Error::new(
                        attr.span(),
                        "activity groups use #[activity], not #[rpc]",
                    ));
                }
                if attr.path().is_ident("workflow") {
                    return Err(syn::Error::new(
                        attr.span(),
                        "activity groups use #[activity], not #[workflow]",
                    ));
                }
            }
            let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
            if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
                if r.mutability.is_some() {
                    return Err(syn::Error::new(
                        r.span(),
                        "activity group methods must use &self, not &mut self",
                    ));
                }
            }
            if is_activity {
                if method.sig.asyncness.is_none() {
                    return Err(syn::Error::new(
                        method.sig.span(),
                        "#[activity] methods must be async",
                    ));
                }
                let (persist_key, act_retries, act_backoff) = {
                    let mut key = None;
                    let mut retries = None;
                    let mut backoff = None;
                    for attr in &method.attrs {
                        if attr.path().is_ident("activity") {
                            let args = match &attr.meta {
                                syn::Meta::Path(_) => ActivityAttrArgs {
                                    key: None,
                                    retries: None,
                                    backoff: None,
                                },
                                syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
                                _ => {
                                    return Err(syn::Error::new(
                                        attr.span(),
                                        "expected #[activity] or #[activity(...)]",
                                    ))
                                }
                            };
                            key = args.key;
                            retries = args.retries;
                            backoff = args.backoff;
                        }
                    }
                    (key, retries, backoff)
                };
                let mut params = Vec::new();
                for arg in method.sig.inputs.iter().skip(1) {
                    if let syn::FnArg::Typed(pat_type) = arg {
                        let name = match &*pat_type.pat {
                            syn::Pat::Ident(ident) => ident.ident.clone(),
                            _ => {
                                return Err(syn::Error::new(
                                    pat_type.pat.span(),
                                    "activity parameters must be simple identifiers",
                                ))
                            }
                        };
                        params.push(RpcParam {
                            name,
                            ty: (*pat_type.ty).clone(),
                        });
                    }
                }
                let response_type = extract_result_ok_type(match &method.sig.output {
                    syn::ReturnType::Type(_, ty) => ty,
                    syn::ReturnType::Default => {
                        return Err(syn::Error::new(
                            method.sig.span(),
                            "#[activity] must return Result<T, ClusterError>",
                        ))
                    }
                })?;
                activities.push(ActivityGroupActivityInfo {
                    name: method.sig.ident.clone(),
                    params,
                    response_type,
                    persist_key,
                    original_method: method.clone(),
                    retries: act_retries,
                    backoff: act_backoff,
                });
            }
            all_methods.push(method.clone());
        }
    }
    // Pass 2: re-emit every method body on the view type, with the activity
    // and visibility marker attributes removed.
    let mut activity_view_methods = Vec::new();
    let mut helper_view_methods = Vec::new();
    for method in &all_methods {
        let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
        let block = &method.block;
        let output = &method.sig.output;
        let name = &method.sig.ident;
        let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
        let attrs: Vec<_> = method
            .attrs
            .iter()
            .filter(|a| {
                !a.path().is_ident("activity")
                    && !a.path().is_ident("public")
                    && !a.path().is_ident("protected")
                    && !a.path().is_ident("private")
            })
            .collect();
        let vis = &method.vis;
        if is_activity {
            activity_view_methods.push(quote! {
                #(#attrs)*
                #vis async fn #name(&self, #(#params),*) #output
                #block
            });
        } else {
            let async_token = if method.sig.asyncness.is_some() {
                quote! { async }
            } else {
                quote! {}
            };
            helper_view_methods.push(quote! {
                #(#attrs)*
                #vis #async_token fn #name(&self, #(#params),*) #output
                #block
            });
        }
    }
    let view_struct = quote! {
        #[doc(hidden)]
        #[allow(non_camel_case_types)]
        pub struct #activity_view_name<'a> {
            __group: &'a #struct_name,
            pub tx: #krate::__internal::ActivityTx,
            pub pool: sqlx::PgPool,
        }
        impl ::std::ops::Deref for #activity_view_name<'_> {
            type Target = #struct_name;
            fn deref(&self) -> &Self::Target {
                self.__group
            }
        }
        impl #activity_view_name<'_> {
            #(#activity_view_methods)*
            #(#helper_view_methods)*
        }
    };
    // Pass 3: journaled delegation methods on the wrapper, one per activity.
    let wrapper_delegation_methods: Vec<proc_macro2::TokenStream> = activities
        .iter()
        .map(|act| {
            let method_name = &act.name;
            let method_name_str = method_name.to_string();
            let method_info = &act.original_method;
            let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
            let param_names: Vec<_> = method_info
                .sig
                .inputs
                .iter()
                .skip(1)
                .filter_map(|arg| {
                    if let syn::FnArg::Typed(pat_type) = arg {
                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                            return Some(&pat_ident.ident);
                        }
                    }
                    None
                })
                .collect();
            let output = &method_info.sig.output;
            let wire_param_names: Vec<_> = act.params.iter().map(|p| &p.name).collect();
            let wire_param_count = wire_param_names.len();
            // Code binding `__journal_key_bytes`: serialized output of the user's
            // key closure when present, otherwise the serialized parameters.
            let key_bytes_code = if let Some(persist_key) = &act.persist_key {
                match wire_param_count {
                    0 => quote! {
                        let __journal_key = (#persist_key)();
                        let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
                            .map_err(|e| #krate::error::ClusterError::PersistenceError {
                                reason: ::std::format!("failed to serialize journal key: {e}"),
                                source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                            })?;
                    },
                    1 => {
                        let name = &wire_param_names[0];
                        quote! {
                            let __journal_key = (#persist_key)(&#name);
                            let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
                                .map_err(|e| #krate::error::ClusterError::PersistenceError {
                                    reason: ::std::format!("failed to serialize journal key: {e}"),
                                    source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                                })?;
                        }
                    }
                    _ => quote! {
                        let __journal_key = (#persist_key)(#(&#wire_param_names),*);
                        let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
                            .map_err(|e| #krate::error::ClusterError::PersistenceError {
                                reason: ::std::format!("failed to serialize journal key: {e}"),
                                source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                            })?;
                    },
                }
            } else {
                match wire_param_count {
                    0 => quote! {
                        let __journal_key_bytes = rmp_serde::to_vec(&()).unwrap_or_default();
                    },
                    1 => {
                        let name = &wire_param_names[0];
                        quote! {
                            let __journal_key_bytes = rmp_serde::to_vec(&#name)
                                .map_err(|e| #krate::error::ClusterError::PersistenceError {
                                    reason: ::std::format!("failed to serialize journal key: {e}"),
                                    source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                                })?;
                        }
                    }
                    _ => quote! {
                        let __journal_key_bytes = rmp_serde::to_vec(&(#(&#wire_param_names),*))
                            .map_err(|e| #krate::error::ClusterError::PersistenceError {
                                reason: ::std::format!("failed to serialize journal key: {e}"),
                                source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                            })?;
                    },
                }
            };
            // Generates: open a SQL transaction, run the activity on a view bound to
            // it, and on success journal the result in the same transaction before
            // committing. On error the view (and its uncommitted transaction) is
            // dropped. In retry mode the error is yielded as a value so the caller's
            // `match` can decide whether to retry; otherwise it is returned directly.
            let gen_execute_and_journal = |key_bytes_var: &str, param_list: &[proc_macro2::TokenStream], in_retry: bool| -> proc_macro2::TokenStream {
                let key_var = format_ident!("{}", key_bytes_var);
                let error_handling = if in_retry {
                    quote! {
                        if __act_result.is_err() {
                            drop(__activity_view);
                            __act_result
                        }
                    }
                } else {
                    quote! {
                        if __act_result.is_err() {
                            drop(__activity_view);
                            return __act_result;
                        }
                    }
                };
                quote! {
                    let __sql_pool = __wf_storage.sql_pool().cloned();
                    let __pool = __sql_pool.expect("SQL storage is required for workflow activities");
                    let __sql_tx = __pool.begin().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
                        reason: ::std::format!("failed to begin activity transaction: {e}"),
                        source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                    })?;
                    let __activity_view = #activity_view_name {
                        __group: &self.__group,
                        tx: #krate::__internal::ActivityTx::new(__sql_tx),
                        pool: __pool,
                    };
                    let __act_result = __activity_view.#method_name(#(#param_list),*).await;
                    #error_handling
                    else {
                        let __storage_key = #krate::__internal::DurableContext::journal_storage_key(
                            #method_name_str,
                            &#key_var,
                            __journal_ctx.entity_type(),
                            __journal_ctx.entity_id(),
                        );
                        let __journal_bytes = #krate::__internal::DurableContext::serialize_journal_result(&__act_result)?;
                        #krate::__internal::WorkflowScope::register_journal_key(__storage_key.clone());
                        let mut __tx_back = __activity_view.tx.into_inner().await;
                        #krate::__internal::save_journal_entry(&mut *__tx_back, &__storage_key, &__journal_bytes).await?;
                        __tx_back.commit().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
                            reason: ::std::format!("activity transaction commit failed: {e}"),
                            source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
                        })?;
                        __act_result
                    }
                }
            };
            let param_tokens: Vec<proc_macro2::TokenStream> = param_names
                .iter()
                .map(|name| quote! { #name.clone() })
                .collect();
            let max_retries = act.retries.unwrap_or(0);
            let journal_body = if max_retries == 0 {
                // Fixed: this argument was the mis-encoded `¶m_tokens` (a compile
                // error); it must borrow the token list.
                let exec_body = gen_execute_and_journal("__journal_key_bytes", &param_tokens, false);
                quote! {
                    let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
                        ::std::sync::Arc::clone(__engine),
                        self.__entity_type.clone(),
                        self.__entity_id.clone(),
                        ::std::sync::Arc::clone(__msg_storage),
                        ::std::sync::Arc::clone(__wf_storage),
                    );
                    if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal(#method_name_str, &__journal_key_bytes).await? {
                        return ::std::result::Result::Ok(__cached);
                    }
                    #exec_body
                }
            } else {
                let backoff_str = act.backoff.as_deref().unwrap_or("exponential");
                // Parameters are re-cloned on every attempt so each retry gets fresh values.
                let param_clones: Vec<proc_macro2::TokenStream> = param_names
                    .iter()
                    .map(|name| {
                        let clone_name = format_ident!("__{}_clone", name);
                        quote! { let #clone_name = #name.clone(); }
                    })
                    .collect();
                let cloned_param_names: Vec<syn::Ident> = param_names
                    .iter()
                    .map(|name| format_ident!("__{}_clone", name))
                    .collect();
                let exec_body = gen_execute_and_journal("__retry_key_bytes", &{
                    cloned_param_names.iter().map(|n| quote! { #n.clone() }).collect::<Vec<_>>()
                }, true);
                quote! {
                    let mut __attempt = 0u32;
                    loop {
                        #(#param_clones)*
                        // Each attempt is journaled under the base key plus the attempt number.
                        let __retry_key_bytes = {
                            let mut __k = __journal_key_bytes.clone();
                            __k.extend_from_slice(&__attempt.to_le_bytes());
                            __k
                        };
                        let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
                            ::std::sync::Arc::clone(__engine),
                            self.__entity_type.clone(),
                            self.__entity_id.clone(),
                            ::std::sync::Arc::clone(__msg_storage),
                            ::std::sync::Arc::clone(__wf_storage),
                        );
                        if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal::<_>(#method_name_str, &__retry_key_bytes).await? {
                            break ::std::result::Result::Ok(__cached);
                        }
                        match { #exec_body } {
                            ::std::result::Result::Ok(__val) => {
                                break ::std::result::Result::Ok(__val);
                            }
                            ::std::result::Result::Err(__e) if __attempt < #max_retries => {
                                let __delay = #krate::__internal::compute_retry_backoff(
                                    __attempt, #backoff_str, 1,
                                );
                                let __sleep_name = ::std::format!(
                                    "{}/retry/{}", #method_name_str, __attempt
                                );
                                __engine.sleep(
                                    &self.__entity_type,
                                    &self.__entity_id,
                                    &__sleep_name,
                                    __delay,
                                ).await?;
                                __attempt += 1;
                            }
                            ::std::result::Result::Err(__e) => {
                                break ::std::result::Result::Err(__e);
                            }
                        }
                    }
                }
            };
            quote! {
                pub async fn #method_name(&self, #(#params),*) #output {
                    if let (
                        ::std::option::Option::Some(__engine),
                        ::std::option::Option::Some(__msg_storage),
                        ::std::option::Option::Some(__wf_storage),
                    ) = (
                        self.__workflow_engine.as_ref(),
                        self.__message_storage.as_ref(),
                        self.__workflow_storage.as_ref(),
                    ) {
                        #key_bytes_code
                        // Scope the key to the enclosing workflow (when any) so equal
                        // arguments in different workflows do not share journal entries.
                        let __journal_key_bytes = {
                            let mut __scoped = ::std::vec::Vec::new();
                            if let ::std::option::Option::Some(__wf_id) = #krate::__internal::WorkflowScope::current() {
                                __scoped.extend_from_slice(&__wf_id.to_le_bytes());
                            }
                            __scoped.extend_from_slice(&__journal_key_bytes);
                            __scoped
                        };
                        #journal_body
                    } else {
                        panic!("SQL storage is required for workflow activities; configure SqlWorkflowStorage")
                    }
                }
            }
        })
        .collect();
    let wrapper_struct = quote! {
        #[doc(hidden)]
        pub struct #wrapper_name {
            __group: #struct_name,
            __workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
            __message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
            __workflow_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
            __entity_type: ::std::string::String,
            __entity_id: ::std::string::String,
        }
        impl #wrapper_name {
            pub fn new(
                group: #struct_name,
                workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
                message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
                workflow_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
                entity_type: ::std::string::String,
                entity_id: ::std::string::String,
            ) -> Self {
                Self {
                    __group: group,
                    __workflow_engine: workflow_engine,
                    __message_storage: message_storage,
                    __workflow_storage: workflow_storage,
                    __entity_type: entity_type,
                    __entity_id: entity_id,
                }
            }
            pub fn group(&self) -> &#struct_name {
                &self.__group
            }
            #(#wrapper_delegation_methods)*
        }
    };
    let access_trait = quote! {
        #[doc(hidden)]
        pub trait #access_trait_name {
            fn __activity_group_wrapper(&self) -> &#wrapper_name;
        }
    };
    // Default-bodied trait methods forwarding each activity to the wrapper.
    let blanket_methods: Vec<proc_macro2::TokenStream> = activities
        .iter()
        .map(|act| {
            let method_name = &act.name;
            let method_info = &act.original_method;
            let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
            let param_names: Vec<_> = method_info
                .sig
                .inputs
                .iter()
                .skip(1)
                .filter_map(|arg| {
                    if let syn::FnArg::Typed(pat_type) = arg {
                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
                            return Some(&pat_ident.ident);
                        }
                    }
                    None
                })
                .collect();
            let output = &method_info.sig.output;
            quote! {
                async fn #method_name(&self, #(#params),*) #output {
                    self.__activity_group_wrapper().#method_name(#(#param_names),*).await
                }
            }
        })
        .collect();
    let methods_trait = quote! {
        #[doc(hidden)]
        #[allow(async_fn_in_trait)]
        pub trait #methods_trait_name: #access_trait_name {
            #(#blanket_methods)*
        }
        impl<T: #access_trait_name> #methods_trait_name for T {}
    };
    Ok(quote! {
        #view_struct
        #wrapper_struct
        #access_trait
        #methods_trait
    })
}
/// Arguments accepted by `#[rpc_group_impl(...)]`.
struct RpcGroupImplArgs {
    /// Override for the runtime crate path, given as `krate = "path::to::crate"`.
    krate: Option<syn::Path>,
}
impl syn::parse::Parse for RpcGroupImplArgs {
    /// Parses a comma-separated list whose only supported entry is `krate = "..."`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut parsed = RpcGroupImplArgs { krate: None };
        while !input.is_empty() {
            let option: syn::Ident = input.parse()?;
            if option == "krate" {
                let _eq: syn::Token![=] = input.parse()?;
                let path_literal: syn::LitStr = input.parse()?;
                parsed.krate = Some(path_literal.parse()?);
            } else {
                return Err(syn::Error::new(
                    option.span(),
                    format!("unknown rpc_group_impl attribute: {option}"),
                ));
            }
            if input.is_empty() {
                break;
            }
            let _comma: syn::Token![,] = input.parse()?;
        }
        Ok(parsed)
    }
}
fn rpc_group_impl_inner(
args: RpcGroupImplArgs,
input: syn::ItemImpl,
) -> syn::Result<proc_macro2::TokenStream> {
let krate = args.krate.unwrap_or_else(default_crate_path);
let self_ty = &input.self_ty;
let struct_name = match self_ty.as_ref() {
syn::Type::Path(tp) => tp
.path
.segments
.last()
.map(|s| s.ident.clone())
.ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
_ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
};
let wrapper_name = format_ident!("__{}RpcGroupWrapper", struct_name);
let access_trait_name = format_ident!("__{}RpcGroupAccess", struct_name);
let methods_trait_name = format_ident!("__{}RpcGroupMethods", struct_name);
let rpc_view_name = format_ident!("__{}RpcGroupView", struct_name);
let client_ext_name = format_ident!("{}ClientExt", struct_name);
for attr in &input.attrs {
if attr.path().is_ident("state") {
return Err(syn::Error::new(
attr.span(),
"RPC groups are stateless; remove #[state(...)]",
));
}
}
let mut rpcs: Vec<RpcMethod> = Vec::new();
let mut all_methods: Vec<syn::ImplItemFn> = Vec::new();
for item in &input.items {
if let syn::ImplItem::Fn(method) = item {
for attr in &method.attrs {
if attr.path().is_ident("state") {
return Err(syn::Error::new(
attr.span(),
"RPC groups are stateless; remove #[state(...)]",
));
}
if attr.path().is_ident("activity") {
return Err(syn::Error::new(
attr.span(),
"RPC groups use #[rpc], not #[activity]",
));
}
if attr.path().is_ident("workflow") {
return Err(syn::Error::new(
attr.span(),
"RPC groups use #[rpc], not #[workflow]",
));
}
}
if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
if r.mutability.is_some() {
return Err(syn::Error::new(
r.span(),
"RPC group methods must use &self, not &mut self",
));
}
}
let is_rpc = method.attrs.iter().any(|a| a.path().is_ident("rpc"));
if is_rpc {
if method.sig.asyncness.is_none() {
return Err(syn::Error::new(
method.sig.span(),
"#[rpc] methods must be async",
));
}
if let Some(rpc) = parse_rpc_method(method)? {
rpcs.push(rpc);
}
}
all_methods.push(method.clone());
}
}
let mut rpc_view_methods = Vec::new();
let mut helper_view_methods = Vec::new();
for method in &all_methods {
let is_rpc = method.attrs.iter().any(|a| a.path().is_ident("rpc"));
let block = &method.block;
let output = &method.sig.output;
let name = &method.sig.ident;
let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
let attrs: Vec<_> = method
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("rpc")
&& !a.path().is_ident("public")
&& !a.path().is_ident("protected")
&& !a.path().is_ident("private")
})
.collect();
let vis = &method.vis;
if is_rpc {
rpc_view_methods.push(quote! {
#(#attrs)*
#vis async fn #name(&self, #(#params),*) #output
#block
});
} else {
let async_token = if method.sig.asyncness.is_some() {
quote! { async }
} else {
quote! {}
};
helper_view_methods.push(quote! {
#(#attrs)*
#vis #async_token fn #name(&self, #(#params),*) #output
#block
});
}
}
let view_struct = quote! {
#[doc(hidden)]
#[allow(non_camel_case_types)]
pub struct #rpc_view_name<'a> {
__group: &'a #struct_name,
__entity_address: &'a #krate::types::EntityAddress,
}
impl ::std::ops::Deref for #rpc_view_name<'_> {
type Target = #struct_name;
fn deref(&self) -> &Self::Target {
self.__group
}
}
impl #rpc_view_name<'_> {
#[inline]
fn entity_id(&self) -> &str {
&self.__entity_address.entity_id.0
}
#[inline]
fn entity_address(&self) -> &#krate::types::EntityAddress {
self.__entity_address
}
#(#rpc_view_methods)*
#(#helper_view_methods)*
}
};
let wrapper_delegation_methods: Vec<proc_macro2::TokenStream> = rpcs
.iter()
.filter(|rpc| rpc.is_trait_visible())
.map(|rpc| {
let method_name = &rpc.name;
let resp_type = &rpc.response_type;
let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
let param_defs: Vec<_> = param_names
.iter()
.zip(param_types.iter())
.map(|(name, ty)| quote! { #name: #ty })
.collect();
quote! {
pub async fn #method_name(
&self,
#(#param_defs),*
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
let __view = #rpc_view_name { __group: &self.__group, __entity_address: &self.__entity_address };
__view.#method_name(#(#param_names),*).await
}
}
})
.collect();
let dispatch_arms: Vec<proc_macro2::TokenStream> = rpcs
.iter()
.filter(|rpc| rpc.is_dispatchable())
.map(|rpc| {
let tag = &rpc.tag;
let method_name = &rpc.name;
let param_count = rpc.params.len();
let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
let deserialize_request = match param_count {
0 => quote! {},
1 => {
let name = ¶m_names[0];
let ty = ¶m_types[0];
quote! {
let #name: #ty = rmp_serde::from_slice(payload)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
}
}
_ => quote! {
let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
},
};
let mut call_args = Vec::new();
match param_count {
0 => {}
1 => {
let name = ¶m_names[0];
call_args.push(quote! { #name });
}
_ => {
for name in ¶m_names {
call_args.push(quote! { #name });
}
}
}
let call_args = quote! { #(#call_args),* };
quote! {
#tag => {
#deserialize_request
let response = self.#method_name(#call_args).await?;
let bytes = rmp_serde::to_vec(&response)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
::std::result::Result::Ok(::std::option::Option::Some(bytes))
}
}
})
.collect();
let wrapper_struct = quote! {
#[doc(hidden)]
pub struct #wrapper_name {
__group: #struct_name,
__entity_address: #krate::types::EntityAddress,
}
impl #wrapper_name {
pub fn new(group: #struct_name, entity_address: #krate::types::EntityAddress) -> Self {
Self { __group: group, __entity_address: entity_address }
}
pub fn group(&self) -> &#struct_name {
&self.__group
}
#[doc(hidden)]
pub async fn __dispatch(
&self,
tag: &str,
payload: &[u8],
headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
) -> ::std::result::Result<::std::option::Option<::std::vec::Vec<u8>>, #krate::error::ClusterError> {
let _ = headers;
match tag {
#(#dispatch_arms,)*
_ => ::std::result::Result::Ok(::std::option::Option::None),
}
}
#(#wrapper_delegation_methods)*
}
};
let access_trait = quote! {
#[doc(hidden)]
pub trait #access_trait_name {
fn __rpc_group_wrapper(&self) -> &#wrapper_name;
}
};
let blanket_methods: Vec<proc_macro2::TokenStream> = rpcs
.iter()
.filter(|rpc| rpc.is_trait_visible())
.map(|rpc| {
let method_name = &rpc.name;
let resp_type = &rpc.response_type;
let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
let param_defs: Vec<_> = param_names
.iter()
.zip(param_types.iter())
.map(|(name, ty)| quote! { #name: #ty })
.collect();
quote! {
async fn #method_name(
&self,
#(#param_defs),*
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
self.__rpc_group_wrapper().#method_name(#(#param_names),*).await
}
}
})
.collect();
let methods_trait = quote! {
#[doc(hidden)]
#[allow(async_fn_in_trait)]
pub trait #methods_trait_name: #access_trait_name {
#(#blanket_methods)*
}
impl<T: #access_trait_name> #methods_trait_name for T {}
};
let client_ext_methods: Vec<proc_macro2::TokenStream> = rpcs
.iter()
.filter(|rpc| rpc.is_client_visible())
.map(|rpc| {
let method_name = &rpc.name;
let tag = &rpc.tag;
let resp_type = &rpc.response_type;
let param_count = rpc.params.len();
let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
let param_defs: Vec<_> = param_names
.iter()
.zip(param_types.iter())
.map(|(name, ty)| quote! { #name: &#ty })
.collect();
if rpc.uses_persisted_delivery() {
match param_count {
0 => quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
self.entity_client()
.send_persisted(entity_id, #tag, &(), #krate::schema::Uninterruptible::No)
.await
}
},
1 => {
let def = ¶m_defs[0];
let name = ¶m_names[0];
quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
#def,
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
self.entity_client()
.send_persisted(entity_id, #tag, #name, #krate::schema::Uninterruptible::No)
.await
}
}
}
_ => quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
#(#param_defs),*
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
let request = (#(#param_names),*);
self.entity_client()
.send_persisted(entity_id, #tag, &request, #krate::schema::Uninterruptible::No)
.await
}
},
}
} else {
match param_count {
0 => quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
self.entity_client().send(entity_id, #tag, &()).await
}
},
1 => {
let def = ¶m_defs[0];
let name = ¶m_names[0];
quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
#def,
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
self.entity_client().send(entity_id, #tag, #name).await
}
}
}
_ => quote! {
async fn #method_name(
&self,
entity_id: &#krate::types::EntityId,
#(#param_defs),*
) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
let request = (#(#param_names),*);
self.entity_client().send(entity_id, #tag, &request).await
}
},
}
}
})
.collect();
let client_ext = quote! {
#[async_trait::async_trait]
pub trait #client_ext_name: #krate::entity_client::EntityClientAccessor {
#(#client_ext_methods)*
}
impl<T> #client_ext_name for T where T: #krate::entity_client::EntityClientAccessor {}
};
Ok(quote! {
#view_struct
#wrapper_struct
#access_trait
#methods_trait
#client_ext
})
}
struct WorkflowStructArgs {
key: Option<syn::ExprClosure>,
hash: bool,
krate: Option<syn::Path>,
}
impl syn::parse::Parse for WorkflowStructArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut args = WorkflowStructArgs {
key: None,
hash: true,
krate: None,
};
while !input.is_empty() {
let ident: syn::Ident = input.parse()?;
match ident.to_string().as_str() {
"key" => {
input.parse::<syn::Token![=]>()?;
let expr: syn::Expr = if input.peek(syn::token::Paren) {
let content;
syn::parenthesized!(content in input);
content.parse()?
} else {
input.parse()?
};
match expr {
syn::Expr::Closure(closure) => args.key = Some(closure),
_ => {
return Err(syn::Error::new(
expr.span(),
"key must be a closure, e.g. #[workflow(key = |req| ...)]",
))
}
}
}
"hash" => {
input.parse::<syn::Token![=]>()?;
let lit: syn::LitBool = input.parse()?;
args.hash = lit.value;
}
"krate" => {
input.parse::<syn::Token![=]>()?;
let lit: syn::LitStr = input.parse()?;
args.krate = Some(lit.parse()?);
}
other => {
return Err(syn::Error::new(
ident.span(),
format!("unknown workflow attribute: {other}"),
));
}
}
if !input.is_empty() {
input.parse::<syn::Token![,]>()?;
}
}
Ok(args)
}
}
struct WorkflowImplArgs {
krate: Option<syn::Path>,
activity_groups: Vec<syn::Path>,
key: Option<syn::ExprClosure>,
hash: bool,
}
impl syn::parse::Parse for WorkflowImplArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut args = WorkflowImplArgs {
krate: None,
activity_groups: Vec::new(),
key: None,
hash: true,
};
while !input.is_empty() {
let ident: syn::Ident = input.parse()?;
match ident.to_string().as_str() {
"krate" => {
input.parse::<syn::Token![=]>()?;
let lit: syn::LitStr = input.parse()?;
args.krate = Some(lit.parse()?);
}
"activity_groups" => {
let content;
syn::parenthesized!(content in input);
while !content.is_empty() {
let path: syn::Path = content.parse()?;
args.activity_groups.push(path);
if !content.is_empty() {
content.parse::<syn::Token![,]>()?;
}
}
}
"key" => {
input.parse::<syn::Token![=]>()?;
let expr: syn::Expr = if input.peek(syn::token::Paren) {
let content;
syn::parenthesized!(content in input);
content.parse()?
} else {
input.parse()?
};
match expr {
syn::Expr::Closure(closure) => args.key = Some(closure),
_ => {
return Err(syn::Error::new(
expr.span(),
"key must be a closure, e.g. #[workflow_impl(key = |req| ...)]",
))
}
}
}
"hash" => {
input.parse::<syn::Token![=]>()?;
let lit: syn::LitBool = input.parse()?;
args.hash = lit.value;
}
other => {
return Err(syn::Error::new(
ident.span(),
format!("unknown workflow_impl attribute: {other}"),
));
}
}
if !input.is_empty() {
input.parse::<syn::Token![,]>()?;
}
}
Ok(args)
}
}
fn workflow_struct_inner(
args: WorkflowStructArgs,
input: syn::ItemStruct,
) -> syn::Result<proc_macro2::TokenStream> {
let krate = args.krate.unwrap_or_else(default_crate_path);
let struct_name = &input.ident;
let entity_name = format!("Workflow/{}", struct_name);
let key_derivation_info = if let Some(_key_closure) = &args.key {
let hash_val = args.hash;
quote! {
#[doc(hidden)]
fn __workflow_key_closure() -> bool { true }
#[doc(hidden)]
fn __workflow_hash() -> bool { #hash_val }
#[doc(hidden)]
fn __extract_key<__Req>(req: &__Req) -> ::std::string::String
where __Req: serde::Serialize,
{
let _ = req;
unreachable!("key extraction is generated by workflow_impl")
}
}
} else {
quote! {
#[doc(hidden)]
fn __workflow_key_closure() -> bool { false }
#[doc(hidden)]
fn __workflow_hash() -> bool { true }
}
};
let _ = key_derivation_info;
let _ = args.key;
Ok(quote! {
#input
#[allow(dead_code)]
impl #struct_name {
#[doc(hidden)]
fn __entity_type(&self) -> #krate::types::EntityType {
#krate::types::EntityType::new(#entity_name)
}
#[doc(hidden)]
fn __shard_group(&self) -> &str {
"default"
}
#[doc(hidden)]
fn __shard_group_for(&self, _entity_id: &#krate::types::EntityId) -> &str {
self.__shard_group()
}
#[doc(hidden)]
fn __max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
::std::option::Option::None
}
#[doc(hidden)]
fn __mailbox_capacity(&self) -> ::std::option::Option<usize> {
::std::option::Option::None
}
#[doc(hidden)]
fn __concurrency(&self) -> ::std::option::Option<usize> {
::std::option::Option::None
}
}
})
}
/// Metadata collected for a single `#[activity]` method of a workflow impl.
struct WorkflowActivityInfo {
    // --- identity ---
    /// The activity method's identifier.
    name: syn::Ident,
    /// String tag for the activity (populated from the method name).
    #[allow(dead_code)]
    tag: String,
    /// The method exactly as written by the user (attributes, signature, body).
    original_method: syn::ImplItemFn,

    // --- signature ---
    /// Typed parameters that follow the `&self` receiver.
    params: Vec<RpcParam>,
    /// `T` extracted from the method's `Result<T, ClusterError>` return type.
    #[allow(dead_code)]
    response_type: syn::Type,

    // --- `#[activity(...)]` options ---
    /// Optional closure used to derive the journal key instead of serializing
    /// the parameters directly.
    persist_key: Option<syn::ExprClosure>,
    /// Retry limit; `None` is treated as zero retries by the code generator.
    retries: Option<u32>,
    /// Backoff strategy name; `None` falls back to `"exponential"` at generation time.
    backoff: Option<String>,
}
/// Metadata collected for a workflow's single `execute` entry point.
struct WorkflowExecuteInfo {
    /// The `execute` method exactly as written by the user.
    original_method: syn::ImplItemFn,
    /// Typed parameters after `&self`; validated to contain exactly one entry.
    params: Vec<RpcParam>,
    /// Type of that single request parameter.
    request_type: syn::Type,
    /// `R` extracted from the `Result<R, ClusterError>` return type.
    response_type: syn::Type,
}
fn workflow_impl_inner(
args: WorkflowImplArgs,
input: syn::ItemImpl,
) -> syn::Result<proc_macro2::TokenStream> {
let krate = args.krate.unwrap_or_else(default_crate_path);
let self_ty = &input.self_ty;
let struct_name = match self_ty.as_ref() {
syn::Type::Path(tp) => tp
.path
.segments
.last()
.map(|s| s.ident.clone())
.ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
_ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
};
let handler_name = format_ident!("__{}WorkflowHandler", struct_name);
let client_name = format_ident!("{}Client", struct_name);
let execute_view_name = format_ident!("__{}ExecuteView", struct_name);
let activity_view_name = format_ident!("__{}ActivityView", struct_name);
let entity_name = format!("Workflow/{}", struct_name);
let has_activity_groups = !args.activity_groups.is_empty();
let with_groups_name = format_ident!("__{}WithGroups", struct_name);
for attr in &input.attrs {
if attr.path().is_ident("state") {
return Err(syn::Error::new(
attr.span(),
"workflows are stateless; remove #[state(...)]",
));
}
}
let mut execute_info: Option<WorkflowExecuteInfo> = None;
let mut activities: Vec<WorkflowActivityInfo> = Vec::new();
let mut original_methods: Vec<syn::ImplItemFn> = Vec::new();
for item in &input.items {
if let syn::ImplItem::Fn(method) = item {
for attr in &method.attrs {
if attr.path().is_ident("state") {
return Err(syn::Error::new(
attr.span(),
"workflows are stateless; remove #[state(...)]",
));
}
if attr.path().is_ident("rpc") {
return Err(syn::Error::new(
attr.span(),
"workflows use #[activity], not #[rpc]",
));
}
if attr.path().is_ident("workflow") {
return Err(syn::Error::new(
attr.span(),
"workflows have a single execute entry point; use client calls for cross-workflow interaction",
));
}
}
if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
if r.mutability.is_some() {
return Err(syn::Error::new(
r.span(),
"workflow methods must use &self, not &mut self",
));
}
}
if method.sig.ident == "execute" {
if execute_info.is_some() {
return Err(syn::Error::new(
method.sig.span(),
"workflow must have exactly one execute method",
));
}
if method.sig.asyncness.is_none() {
return Err(syn::Error::new(method.sig.span(), "execute must be async"));
}
let mut params = Vec::new();
for arg in method.sig.inputs.iter().skip(1) {
if let syn::FnArg::Typed(pat_type) = arg {
let name = match &*pat_type.pat {
syn::Pat::Ident(ident) => ident.ident.clone(),
_ => {
return Err(syn::Error::new(
pat_type.pat.span(),
"execute parameters must be simple identifiers",
))
}
};
params.push(RpcParam {
name,
ty: (*pat_type.ty).clone(),
});
}
}
if params.len() != 1 {
return Err(syn::Error::new(
method.sig.span(),
"execute must take exactly one request parameter (after &self)",
));
}
let request_type = params[0].ty.clone();
let response_type = extract_result_ok_type(match &method.sig.output {
syn::ReturnType::Type(_, ty) => ty,
syn::ReturnType::Default => {
return Err(syn::Error::new(
method.sig.span(),
"execute must return Result<T, ClusterError>",
))
}
})?;
execute_info = Some(WorkflowExecuteInfo {
params,
request_type,
response_type,
original_method: method.clone(),
});
} else {
let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
if is_activity {
if method.sig.asyncness.is_none() {
return Err(syn::Error::new(
method.sig.span(),
"#[activity] methods must be async",
));
}
let (persist_key, act_retries, act_backoff) = {
let mut key = None;
let mut retries = None;
let mut backoff = None;
for attr in &method.attrs {
if attr.path().is_ident("activity") {
let args = match &attr.meta {
syn::Meta::Path(_) => ActivityAttrArgs {
key: None,
retries: None,
backoff: None,
},
syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
_ => {
return Err(syn::Error::new(
attr.span(),
"expected #[activity] or #[activity(...)]",
))
}
};
key = args.key;
retries = args.retries;
backoff = args.backoff;
}
}
(key, retries, backoff)
};
let mut params = Vec::new();
for arg in method.sig.inputs.iter().skip(1) {
if let syn::FnArg::Typed(pat_type) = arg {
let name = match &*pat_type.pat {
syn::Pat::Ident(ident) => ident.ident.clone(),
_ => {
return Err(syn::Error::new(
pat_type.pat.span(),
"activity parameters must be simple identifiers",
))
}
};
params.push(RpcParam {
name,
ty: (*pat_type.ty).clone(),
});
}
}
let response_type = extract_result_ok_type(match &method.sig.output {
syn::ReturnType::Type(_, ty) => ty,
syn::ReturnType::Default => {
return Err(syn::Error::new(
method.sig.span(),
"#[activity] must return Result<T, ClusterError>",
))
}
})?;
activities.push(WorkflowActivityInfo {
name: method.sig.ident.clone(),
tag: method.sig.ident.to_string(),
params,
response_type,
persist_key,
original_method: method.clone(),
retries: act_retries,
backoff: act_backoff,
});
}
}
original_methods.push(method.clone());
}
}
let execute = execute_info.ok_or_else(|| {
syn::Error::new(
input.self_ty.span(),
"workflow must define an `async fn execute(&self, request: T) -> Result<R, ClusterError>` method",
)
})?;
let request_type = &execute.request_type;
let response_type = &execute.response_type;
#[allow(dead_code)]
struct ActivityGroupInfo {
path: syn::Path,
ident: syn::Ident,
field: syn::Ident,
wrapper_ident: syn::Ident,
wrapper_path: syn::Path,
access_trait_ident: syn::Ident,
access_trait_path: syn::Path,
methods_trait_ident: syn::Ident,
methods_trait_path: syn::Path,
}
let group_infos: Vec<ActivityGroupInfo> = args
.activity_groups
.iter()
.map(|path| {
let ident = path
.segments
.last()
.map(|s| s.ident.clone())
.expect("activity group path must have an ident");
let snake = to_snake(&ident.to_string());
let field = format_ident!("__group_{}", snake);
let wrapper_ident = format_ident!("__{}ActivityGroupWrapper", ident);
let wrapper_path = replace_last_segment(path, wrapper_ident.clone());
let access_trait_ident = format_ident!("__{}ActivityGroupAccess", ident);
let access_trait_path = replace_last_segment(path, access_trait_ident.clone());
let methods_trait_ident = format_ident!("__{}ActivityGroupMethods", ident);
let methods_trait_path = replace_last_segment(path, methods_trait_ident.clone());
ActivityGroupInfo {
path: path.clone(),
ident,
field,
wrapper_ident,
wrapper_path,
access_trait_ident,
access_trait_path,
methods_trait_ident,
methods_trait_path,
}
})
.collect();
let execute_method = &execute.original_method;
let execute_block = &execute_method.block;
let execute_output = &execute_method.sig.output;
let execute_param_name = &execute.params[0].name;
let execute_param_type = &execute.params[0].ty;
let execute_attrs: Vec<_> = execute_method
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("rpc")
&& !a.path().is_ident("workflow")
&& !a.path().is_ident("activity")
})
.collect();
let mut activity_view_methods = Vec::new();
for act in &activities {
let method = &act.original_method;
let block = &method.block;
let output = &method.sig.output;
let name = &act.name;
let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
let attrs: Vec<_> = method
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("activity")
&& !a.path().is_ident("public")
&& !a.path().is_ident("protected")
&& !a.path().is_ident("private")
})
.collect();
let vis = &method.vis;
activity_view_methods.push(quote! {
#(#attrs)*
#vis async fn #name(&self, #(#params),*) #output
#block
});
}
let mut helper_execute_methods = Vec::new();
let mut helper_activity_methods = Vec::new();
for method in &original_methods {
let name = &method.sig.ident;
if name == "execute" {
continue;
}
let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
if is_activity {
continue;
}
let block = &method.block;
let output = &method.sig.output;
let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
let attrs: Vec<_> = method
.attrs
.iter()
.filter(|a| {
!a.path().is_ident("rpc")
&& !a.path().is_ident("workflow")
&& !a.path().is_ident("activity")
&& !a.path().is_ident("method")
&& !a.path().is_ident("public")
&& !a.path().is_ident("protected")
&& !a.path().is_ident("private")
})
.collect();
let vis = &method.vis;
let async_token = if method.sig.asyncness.is_some() {
quote! { async }
} else {
quote! {}
};
let method_tokens = quote! {
#(#attrs)*
#vis #async_token fn #name(&self, #(#params),*) #output
#block
};
helper_execute_methods.push(method_tokens.clone());
helper_activity_methods.push(method_tokens);
}
let activity_delegations: Vec<proc_macro2::TokenStream> = activities
.iter()
.map(|act| {
let method_name = &act.name;
let method_name_str = method_name.to_string();
let method_info = &act.original_method;
let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
let param_names: Vec<_> = method_info
.sig
.inputs
.iter()
.skip(1)
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
return Some(&pat_ident.ident);
}
}
None
})
.collect();
let output = &method_info.sig.output;
let wire_param_names: Vec<_> = act.params.iter().map(|p| &p.name).collect();
let wire_param_count = wire_param_names.len();
let key_bytes_code = if let Some(persist_key) = &act.persist_key {
match wire_param_count {
0 => quote! {
let __journal_key = (#persist_key)();
let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to serialize journal key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
},
1 => {
let name = &wire_param_names[0];
quote! {
let __journal_key = (#persist_key)(&#name);
let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to serialize journal key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
}
}
_ => quote! {
let __journal_key = (#persist_key)(#(&#wire_param_names),*);
let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to serialize journal key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
},
}
} else {
match wire_param_count {
0 => quote! {
let __journal_key_bytes = rmp_serde::to_vec(&()).unwrap_or_default();
},
1 => {
let name = &wire_param_names[0];
quote! {
let __journal_key_bytes = rmp_serde::to_vec(&#name)
.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to serialize journal key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
}
}
_ => quote! {
let __journal_key_bytes = rmp_serde::to_vec(&(#(&#wire_param_names),*))
.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to serialize journal key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
},
}
};
let gen_execute_and_journal = |key_bytes_var: &str, param_list: &[proc_macro2::TokenStream], in_retry: bool| -> proc_macro2::TokenStream {
let key_var = format_ident!("{}", key_bytes_var);
let error_handling = if in_retry {
quote! {
if __act_result.is_err() {
drop(__activity_view);
__act_result
}
}
} else {
quote! {
if __act_result.is_err() {
drop(__activity_view);
return __act_result;
}
}
};
quote! {
let __sql_pool = __wf_storage.sql_pool().cloned();
let __pool = __sql_pool.expect("SQL storage is required for workflow activities");
let __sql_tx = __pool.begin().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("failed to begin activity transaction: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
let __activity_view = #activity_view_name {
__handler: self.__handler,
tx: #krate::__internal::ActivityTx::new(__sql_tx),
pool: __pool,
};
let __act_result = __activity_view.#method_name(#(#param_list),*).await;
#error_handling
else {
let __storage_key = #krate::__internal::DurableContext::journal_storage_key(
#method_name_str,
&#key_var,
__journal_ctx.entity_type(),
__journal_ctx.entity_id(),
);
let __journal_bytes = #krate::__internal::DurableContext::serialize_journal_result(&__act_result)?;
#krate::__internal::WorkflowScope::register_journal_key(__storage_key.clone());
let mut __tx_back = __activity_view.tx.into_inner().await;
#krate::__internal::save_journal_entry(&mut *__tx_back, &__storage_key, &__journal_bytes).await?;
__tx_back.commit().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
reason: ::std::format!("activity transaction commit failed: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
__act_result
}
}
};
let param_tokens: Vec<proc_macro2::TokenStream> = param_names
.iter()
.map(|name| quote! { #name.clone() })
.collect();
let max_retries = act.retries.unwrap_or(0);
let journal_body = if max_retries == 0 {
let exec_body = gen_execute_and_journal("__journal_key_bytes", ¶m_tokens, false);
quote! {
let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
::std::sync::Arc::clone(__engine),
self.__handler.ctx.address.entity_type.0.clone(),
self.__handler.ctx.address.entity_id.0.clone(),
::std::sync::Arc::clone(__msg_storage),
::std::sync::Arc::clone(__wf_storage),
);
if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal(#method_name_str, &__journal_key_bytes).await? {
return ::std::result::Result::Ok(__cached);
}
#exec_body
}
} else {
let backoff_str = act.backoff.as_deref().unwrap_or("exponential");
let param_clones: Vec<proc_macro2::TokenStream> = param_names
.iter()
.map(|name| {
let clone_name = format_ident!("__{}_clone", name);
quote! { let #clone_name = #name.clone(); }
})
.collect();
let cloned_param_names: Vec<syn::Ident> = param_names
.iter()
.map(|name| format_ident!("__{}_clone", name))
.collect();
let exec_body = gen_execute_and_journal("__retry_key_bytes", &{
cloned_param_names.iter().map(|n| quote! { #n.clone() }).collect::<Vec<_>>()
}, true);
quote! {
let mut __attempt = 0u32;
loop {
#(#param_clones)*
let __retry_key_bytes = {
let mut __k = __journal_key_bytes.clone();
__k.extend_from_slice(&__attempt.to_le_bytes());
__k
};
let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
::std::sync::Arc::clone(__engine),
self.__handler.ctx.address.entity_type.0.clone(),
self.__handler.ctx.address.entity_id.0.clone(),
::std::sync::Arc::clone(__msg_storage),
::std::sync::Arc::clone(__wf_storage),
);
if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal::<_>(#method_name_str, &__retry_key_bytes).await? {
break ::std::result::Result::Ok(__cached);
}
match { #exec_body } {
::std::result::Result::Ok(__val) => {
break ::std::result::Result::Ok(__val);
}
::std::result::Result::Err(__e) if __attempt < #max_retries => {
let __delay = #krate::__internal::compute_retry_backoff(
__attempt, #backoff_str, 1,
);
let __sleep_name = ::std::format!(
"{}/retry/{}", #method_name_str, __attempt
);
__engine.sleep(
&self.__handler.ctx.address.entity_type.0,
&self.__handler.ctx.address.entity_id.0,
&__sleep_name,
__delay,
).await?;
__attempt += 1;
}
::std::result::Result::Err(__e) => {
break ::std::result::Result::Err(__e);
}
}
}
}
};
quote! {
#[inline]
async fn #method_name(&self, #(#params),*) #output {
if let (
::std::option::Option::Some(__engine),
::std::option::Option::Some(__msg_storage),
::std::option::Option::Some(__wf_storage),
) = (
self.__handler.__workflow_engine.as_ref(),
self.__handler.__message_storage.as_ref(),
self.__handler.__state_storage.as_ref(),
) {
#key_bytes_code
let __journal_key_bytes = {
let mut __scoped = ::std::vec::Vec::new();
if let ::std::option::Option::Some(__wf_id) = #krate::__internal::WorkflowScope::current() {
__scoped.extend_from_slice(&__wf_id.to_le_bytes());
}
__scoped.extend_from_slice(&__journal_key_bytes);
__scoped
};
#journal_body
} else {
panic!("SQL storage is required for workflow activities; configure SqlWorkflowStorage")
}
}
}
})
.collect();
let group_handler_fields: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
let wrapper_path = &info.wrapper_path;
quote! {
#field: #wrapper_path,
}
})
.collect();
let group_new_params: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
let path = &info.path;
quote! {
#field: #path,
}
})
.collect();
let group_field_inits: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
let wrapper_path = &info.wrapper_path;
quote! {
#field: #wrapper_path::new(
#field,
ctx.workflow_engine.clone(),
ctx.message_storage.clone(),
ctx.state_storage.clone(),
ctx.address.entity_type.0.clone(),
ctx.address.entity_id.0.clone(),
),
}
})
.collect();
let handler_def = quote! {
#[doc(hidden)]
pub struct #handler_name {
__workflow: #struct_name,
#[allow(dead_code)]
ctx: #krate::entity::EntityContext,
__state_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
__workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
__message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
__sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
__entity_address: #krate::types::EntityAddress,
#(#group_handler_fields)*
}
impl #handler_name {
#[doc(hidden)]
pub async fn __new(
workflow: #struct_name,
#(#group_new_params)*
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<Self, #krate::error::ClusterError> {
let __state_storage = ctx.state_storage.clone();
let __sharding = ctx.sharding.clone();
let __entity_address = ctx.address.clone();
::std::result::Result::Ok(Self {
__workflow: workflow,
__workflow_engine: ctx.workflow_engine.clone(),
__message_storage: ctx.message_storage.clone(),
#(#group_field_inits)*
ctx,
__state_storage,
__sharding,
__entity_address,
})
}
pub async fn sleep(&self, name: &str, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
#krate::error::ClusterError::MalformedMessage {
reason: "sleep() requires a workflow engine".into(),
source: ::std::option::Option::None,
}
})?;
let ctx = #krate::__internal::DurableContext::new(
::std::sync::Arc::clone(engine),
self.ctx.address.entity_type.0.clone(),
self.ctx.address.entity_id.0.clone(),
);
ctx.sleep(name, duration).await
}
pub async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
where
T: serde::Serialize + serde::de::DeserializeOwned,
K: #krate::__internal::DeferredKeyLike<T>,
{
let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
#krate::error::ClusterError::MalformedMessage {
reason: "await_deferred() requires a workflow engine".into(),
source: ::std::option::Option::None,
}
})?;
let ctx = #krate::__internal::DurableContext::new(
::std::sync::Arc::clone(engine),
self.ctx.address.entity_type.0.clone(),
self.ctx.address.entity_id.0.clone(),
);
ctx.await_deferred(key).await
}
pub async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
where
T: serde::Serialize,
K: #krate::__internal::DeferredKeyLike<T>,
{
let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
#krate::error::ClusterError::MalformedMessage {
reason: "resolve_deferred() requires a workflow engine".into(),
source: ::std::option::Option::None,
}
})?;
let ctx = #krate::__internal::DurableContext::new(
::std::sync::Arc::clone(engine),
self.ctx.address.entity_type.0.clone(),
self.ctx.address.entity_id.0.clone(),
);
ctx.resolve_deferred(key, value).await
}
pub async fn on_interrupt(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
#krate::error::ClusterError::MalformedMessage {
reason: "on_interrupt() requires a workflow engine".into(),
source: ::std::option::Option::None,
}
})?;
let ctx = #krate::__internal::DurableContext::new(
::std::sync::Arc::clone(engine),
self.ctx.address.entity_type.0.clone(),
self.ctx.address.entity_id.0.clone(),
);
ctx.on_interrupt().await
}
pub fn execution_id(&self) -> &str {
&self.__entity_address.entity_id.0
}
pub fn entity_id(&self) -> &#krate::types::EntityId {
&self.__entity_address.entity_id
}
pub fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
self.__sharding.as_ref()
}
pub fn entity_address(&self) -> &#krate::types::EntityAddress {
&self.__entity_address
}
}
};
let view_structs = quote! {
#[doc(hidden)]
#[allow(non_camel_case_types)]
struct #execute_view_name<'a> {
__handler: &'a #handler_name,
}
#[doc(hidden)]
#[allow(non_camel_case_types)]
struct #activity_view_name<'a> {
__handler: &'a #handler_name,
pub tx: #krate::__internal::ActivityTx,
pub pool: sqlx::PgPool,
}
impl ::std::ops::Deref for #execute_view_name<'_> {
type Target = #struct_name;
fn deref(&self) -> &Self::Target {
&self.__handler.__workflow
}
}
impl ::std::ops::Deref for #activity_view_name<'_> {
type Target = #struct_name;
fn deref(&self) -> &Self::Target {
&self.__handler.__workflow
}
}
};
let group_access_impls: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let access_trait_path = &info.access_trait_path;
let wrapper_path = &info.wrapper_path;
let field = &info.field;
quote! {
impl #access_trait_path for #execute_view_name<'_> {
fn __activity_group_wrapper(&self) -> &#wrapper_path {
&self.__handler.#field
}
}
}
})
.collect();
let group_use_methods: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let methods_trait_path = &info.methods_trait_path;
quote! {
#[allow(unused_imports)]
use #methods_trait_path as _;
}
})
.collect();
let execute_view_impl = quote! {
#(#group_use_methods)*
impl #execute_view_name<'_> {
#[inline]
async fn sleep(&self, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
self.__handler.sleep("__wf_sleep", duration).await
}
#[inline]
async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
where
T: serde::Serialize + serde::de::DeserializeOwned,
K: #krate::__internal::DeferredKeyLike<T>,
{
self.__handler.await_deferred(key).await
}
#[inline]
async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
where
T: serde::Serialize,
K: #krate::__internal::DeferredKeyLike<T>,
{
self.__handler.resolve_deferred(key, value).await
}
#[inline]
async fn on_interrupt(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
self.__handler.on_interrupt().await
}
#[inline]
fn execution_id(&self) -> &str {
self.__handler.execution_id()
}
#[inline]
fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
self.__handler.sharding()
}
#[inline]
fn client<T: #krate::entity_client::WorkflowClientFactory>(&self) -> T::Client {
let sharding = self.__handler.__sharding.clone()
.expect("client() requires a sharding interface");
T::workflow_client(sharding)
}
#(#activity_delegations)*
#(#execute_attrs)*
async fn execute(&self, #execute_param_name: #execute_param_type) #execute_output
#execute_block
#(#helper_execute_methods)*
}
};
let activity_view_impl = quote! {
impl #activity_view_name<'_> {
#(#activity_view_methods)*
#(#helper_activity_methods)*
}
};
let dispatch_impl = quote! {
#[async_trait::async_trait]
impl #krate::entity::EntityHandler for #handler_name {
async fn handle_request(
&self,
tag: &str,
payload: &[u8],
headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
#[allow(unused_variables)]
let headers = headers;
match tag {
"execute" => {
let __request: #request_type = rmp_serde::from_slice(payload)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to deserialize workflow request: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
let __request_id = headers
.get(#krate::__internal::REQUEST_ID_HEADER_KEY)
.and_then(|v| v.parse::<i64>().ok())
.unwrap_or(0);
let (__wf_result, __journal_keys) = #krate::__internal::WorkflowScope::run(__request_id, || async {
let __view = #execute_view_name { __handler: self };
__view.execute(__request).await
}).await;
let response = __wf_result?;
if let ::std::option::Option::Some(ref __wf_storage) = self.__state_storage {
for __key in &__journal_keys {
let _ = __wf_storage.mark_completed(__key).await;
}
}
rmp_serde::to_vec(&response)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to serialize workflow response: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})
}
_ => ::std::result::Result::Err(
#krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("unknown workflow tag: {tag}"),
source: ::std::option::Option::None,
}
),
}
}
}
};
let (entity_impl, register_impl) = if has_activity_groups {
let group_struct_fields: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
let path = &info.path;
quote! { #field: #path, }
})
.collect();
let group_spawn_args: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
quote! { self.#field.clone(), }
})
.collect();
let group_register_params: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
let path = &info.path;
quote! { #field: #path, }
})
.collect();
let group_register_field_inits: Vec<proc_macro2::TokenStream> = group_infos
.iter()
.map(|info| {
let field = &info.field;
quote! { #field, }
})
.collect();
let entity_impl_tokens = quote! {
#[doc(hidden)]
#[derive(Clone)]
pub struct #with_groups_name {
__workflow: #struct_name,
#(#group_struct_fields)*
}
#[async_trait::async_trait]
impl #krate::entity::Entity for #with_groups_name {
fn entity_type(&self) -> #krate::types::EntityType {
self.__workflow.__entity_type()
}
fn shard_group(&self) -> &str {
self.__workflow.__shard_group()
}
fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
self.__workflow.__shard_group_for(entity_id)
}
fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
self.__workflow.__max_idle_time()
}
fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
self.__workflow.__mailbox_capacity()
}
fn concurrency(&self) -> ::std::option::Option<usize> {
self.__workflow.__concurrency()
}
async fn spawn(
&self,
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<
::std::boxed::Box<dyn #krate::entity::EntityHandler>,
#krate::error::ClusterError,
> {
let handler = #handler_name::__new(
self.__workflow.clone(),
#(#group_spawn_args)*
ctx,
).await?;
::std::result::Result::Ok(::std::boxed::Box::new(handler))
}
}
};
let register_impl_tokens = quote! {
impl #struct_name {
pub async fn register(
self,
sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
#(#group_register_params)*
) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
let bundle = #with_groups_name {
__workflow: self,
#(#group_register_field_inits)*
};
sharding.register_entity(::std::sync::Arc::new(bundle)).await?;
::std::result::Result::Ok(#client_name::new(sharding))
}
}
};
(entity_impl_tokens, register_impl_tokens)
} else {
let entity_impl_tokens = quote! {
#[async_trait::async_trait]
impl #krate::entity::Entity for #struct_name {
fn entity_type(&self) -> #krate::types::EntityType {
self.__entity_type()
}
fn shard_group(&self) -> &str {
self.__shard_group()
}
fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
self.__shard_group_for(entity_id)
}
fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
self.__max_idle_time()
}
fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
self.__mailbox_capacity()
}
fn concurrency(&self) -> ::std::option::Option<usize> {
self.__concurrency()
}
async fn spawn(
&self,
ctx: #krate::entity::EntityContext,
) -> ::std::result::Result<
::std::boxed::Box<dyn #krate::entity::EntityHandler>,
#krate::error::ClusterError,
> {
let handler = #handler_name::__new(self.clone(), ctx).await?;
::std::result::Result::Ok(::std::boxed::Box::new(handler))
}
}
};
let register_impl_tokens = quote! {
impl #struct_name {
pub async fn register(
self,
sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
sharding.register_entity(::std::sync::Arc::new(self)).await?;
::std::result::Result::Ok(#client_name::new(sharding))
}
}
};
(entity_impl_tokens, register_impl_tokens)
};
let struct_name_str = entity_name;
let client_with_key_name = format_ident!("{}ClientWithKey", struct_name);
let derive_entity_id_fn = if let Some(ref key_closure) = args.key {
if args.hash {
quote! {
fn derive_entity_id(
request: &#request_type,
) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
let key_value = (#key_closure)(request);
let key_bytes = rmp_serde::to_vec(&key_value)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to serialize workflow key: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
::std::result::Result::Ok(#krate::types::EntityId::new(
#krate::hash::sha256_hex(&key_bytes)
))
}
}
} else {
quote! {
fn derive_entity_id(
request: &#request_type,
) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
let key_value = (#key_closure)(request);
::std::result::Result::Ok(#krate::types::EntityId::new(
key_value.to_string()
))
}
}
}
} else {
quote! {
fn derive_entity_id(
request: &#request_type,
) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
let key_bytes = rmp_serde::to_vec(request)
.map_err(|e| #krate::error::ClusterError::MalformedMessage {
reason: ::std::format!("failed to serialize workflow request: {e}"),
source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
})?;
::std::result::Result::Ok(#krate::types::EntityId::new(
#krate::hash::sha256_hex(&key_bytes)
))
}
}
};
let client_impl = quote! {
#[derive(Clone)]
pub struct #client_name {
inner: #krate::entity_client::EntityClient,
}
impl #client_name {
pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
Self {
inner: #krate::entity_client::EntityClient::new(
sharding,
#krate::types::EntityType::new(#struct_name_str),
),
}
}
pub fn inner(&self) -> &#krate::entity_client::EntityClient {
&self.inner
}
pub fn with_key(&self, key: impl ::std::fmt::Display) -> #client_with_key_name<'_> {
let key_str = key.to_string();
let entity_id = #krate::types::EntityId::new(
#krate::hash::sha256_hex(key_str.as_bytes())
);
#client_with_key_name {
inner: &self.inner,
entity_id,
}
}
pub fn with_key_raw(&self, key: impl ::std::string::ToString) -> #client_with_key_name<'_> {
#client_with_key_name {
inner: &self.inner,
entity_id: #krate::types::EntityId::new(key.to_string()),
}
}
#derive_entity_id_fn
pub async fn execute(
&self,
request: &#request_type,
) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
let entity_id = Self::derive_entity_id(request)?;
let key_bytes = entity_id.0.as_bytes().to_vec();
self.inner.send_persisted_with_key(
&entity_id,
"execute",
request,
::std::option::Option::Some(key_bytes),
#krate::schema::Uninterruptible::No,
).await
}
pub async fn start(
&self,
request: &#request_type,
) -> ::std::result::Result<::std::string::String, #krate::error::ClusterError> {
let entity_id = Self::derive_entity_id(request)?;
let key_bytes = entity_id.0.as_bytes().to_vec();
self.inner.notify_persisted_with_key(
&entity_id,
"execute",
request,
::std::option::Option::Some(key_bytes),
).await?;
::std::result::Result::Ok(entity_id.0)
}
pub async fn poll(
&self,
execution_id: &str,
) -> ::std::result::Result<::std::option::Option<#response_type>, #krate::error::ClusterError> {
let entity_id = #krate::types::EntityId::new(execution_id);
let key_bytes = entity_id.0.as_bytes();
self.inner.poll_reply::<#response_type>(
&entity_id,
"execute",
key_bytes,
).await
}
pub async fn join(
&self,
execution_id: &str,
) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
let entity_id = #krate::types::EntityId::new(execution_id);
let key_bytes = entity_id.0.as_bytes();
self.inner.join_reply::<#response_type>(
&entity_id,
"execute",
key_bytes,
).await
}
}
impl #krate::entity_client::EntityClientAccessor for #client_name {
fn entity_client(&self) -> &#krate::entity_client::EntityClient {
&self.inner
}
}
pub struct #client_with_key_name<'a> {
inner: &'a #krate::entity_client::EntityClient,
entity_id: #krate::types::EntityId,
}
impl #client_with_key_name<'_> {
pub async fn execute(
&self,
request: &#request_type,
) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
let key_bytes = self.entity_id.0.as_bytes().to_vec();
self.inner.send_persisted_with_key(
&self.entity_id,
"execute",
request,
::std::option::Option::Some(key_bytes),
#krate::schema::Uninterruptible::No,
).await
}
pub async fn start(
&self,
request: &#request_type,
) -> ::std::result::Result<::std::string::String, #krate::error::ClusterError> {
let key_bytes = self.entity_id.0.as_bytes().to_vec();
self.inner.notify_persisted_with_key(
&self.entity_id,
"execute",
request,
::std::option::Option::Some(key_bytes),
).await?;
::std::result::Result::Ok(self.entity_id.0.clone())
}
pub async fn poll(
&self,
) -> ::std::result::Result<::std::option::Option<#response_type>, #krate::error::ClusterError> {
let key_bytes = self.entity_id.0.as_bytes();
self.inner.poll_reply::<#response_type>(
&self.entity_id,
"execute",
key_bytes,
).await
}
pub async fn join(
&self,
) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
let key_bytes = self.entity_id.0.as_bytes();
self.inner.join_reply::<#response_type>(
&self.entity_id,
"execute",
key_bytes,
).await
}
}
};
let client_factory_impl = quote! {
impl #krate::entity_client::WorkflowClientFactory for #struct_name {
type Client = #client_name;
fn workflow_client(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> #client_name {
#client_name::new(sharding)
}
}
};
Ok(quote! {
#handler_def
#view_structs
#(#group_access_impls)*
#execute_view_impl
#activity_view_impl
#dispatch_impl
#entity_impl
#register_impl
#client_impl
#client_factory_impl
})
}