1use proc_macro::TokenStream;
59use quote::{format_ident, quote};
60use std::collections::HashSet;
61use syn::{parse_macro_input, spanned::Spanned};
62
63#[proc_macro_attribute]
77pub fn entity(attr: TokenStream, item: TokenStream) -> TokenStream {
78 let args = parse_macro_input!(attr as EntityArgs);
79 let input = parse_macro_input!(item as syn::ItemStruct);
80 match entity_impl_inner(args, input) {
81 Ok(tokens) => tokens.into(),
82 Err(e) => e.to_compile_error().into(),
83 }
84}
85
86#[proc_macro_attribute]
117pub fn entity_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
118 let args = parse_macro_input!(attr as ImplArgs);
119 let input = parse_macro_input!(item as syn::ItemImpl);
120 match entity_impl_block_inner(args, input) {
121 Ok(tokens) => tokens.into(),
122 Err(e) => e.to_compile_error().into(),
123 }
124}
125
126#[proc_macro_attribute]
130pub fn entity_trait(_attr: TokenStream, item: TokenStream) -> TokenStream {
131 let _ = item;
132 syn::Error::new(
133 proc_macro2::Span::call_site(),
134 "entity traits have been removed; use #[rpc_group] instead",
135 )
136 .to_compile_error()
137 .into()
138}
139
140#[proc_macro_attribute]
144pub fn entity_trait_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
145 let _ = item;
146 syn::Error::new(
147 proc_macro2::Span::call_site(),
148 "entity trait impls have been removed; use #[rpc_group_impl] instead",
149 )
150 .to_compile_error()
151 .into()
152}
153
154#[proc_macro_attribute]
212pub fn state(_attr: TokenStream, item: TokenStream) -> TokenStream {
213 item
216}
217
218#[proc_macro_attribute]
260pub fn rpc(_attr: TokenStream, item: TokenStream) -> TokenStream {
261 item
262}
263
264#[proc_macro_attribute]
320pub fn workflow(attr: TokenStream, item: TokenStream) -> TokenStream {
321 let item_clone = item.clone();
324 if syn::parse::<syn::ItemStruct>(item_clone).is_ok() {
325 let args = parse_macro_input!(attr as WorkflowStructArgs);
327 let input = parse_macro_input!(item as syn::ItemStruct);
328 match workflow_struct_inner(args, input) {
329 Ok(tokens) => tokens.into(),
330 Err(e) => e.to_compile_error().into(),
331 }
332 } else {
333 item
335 }
336}
337
338#[proc_macro_attribute]
365pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
366 let args = parse_macro_input!(attr as WorkflowImplArgs);
367 let input = parse_macro_input!(item as syn::ItemImpl);
368 match workflow_impl_inner(args, input) {
369 Ok(tokens) => tokens.into(),
370 Err(e) => e.to_compile_error().into(),
371 }
372}
373
374#[proc_macro_attribute]
433pub fn activity(_attr: TokenStream, item: TokenStream) -> TokenStream {
434 item
435}
436
437#[proc_macro_attribute]
463pub fn public(_attr: TokenStream, item: TokenStream) -> TokenStream {
464 item
465}
466
467#[proc_macro_attribute]
495pub fn protected(_attr: TokenStream, item: TokenStream) -> TokenStream {
496 item
497}
498
499#[proc_macro_attribute]
528pub fn private(_attr: TokenStream, item: TokenStream) -> TokenStream {
529 item
530}
531
532#[proc_macro_attribute]
568pub fn method(_attr: TokenStream, item: TokenStream) -> TokenStream {
569 item
570}
571
572#[proc_macro_attribute]
607pub fn rpc_group(attr: TokenStream, item: TokenStream) -> TokenStream {
608 let _args = parse_macro_input!(attr as TraitArgs);
609 let input = parse_macro_input!(item as syn::ItemStruct);
610 quote! { #input }.into()
612}
613
614#[proc_macro_attribute]
644pub fn rpc_group_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
645 let args = parse_macro_input!(attr as RpcGroupImplArgs);
646 let input = parse_macro_input!(item as syn::ItemImpl);
647 match rpc_group_impl_inner(args, input) {
648 Ok(tokens) => tokens.into(),
649 Err(e) => e.to_compile_error().into(),
650 }
651}
652
653#[proc_macro_attribute]
697pub fn activity_group(attr: TokenStream, item: TokenStream) -> TokenStream {
698 let _args = parse_macro_input!(attr as TraitArgs);
699 let input = parse_macro_input!(item as syn::ItemStruct);
700 quote! { #input }.into()
702}
703
704#[proc_macro_attribute]
733pub fn activity_group_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
734 let args = parse_macro_input!(attr as ActivityGroupImplArgs);
735 let input = parse_macro_input!(item as syn::ItemImpl);
736 match activity_group_impl_inner(args, input) {
737 Ok(tokens) => tokens.into(),
738 Err(e) => e.to_compile_error().into(),
739 }
740}
741
742struct EntityArgs {
745 name: Option<String>,
746 shard_group: Option<String>,
747 max_idle_time_secs: Option<u64>,
748 mailbox_capacity: Option<usize>,
749 concurrency: Option<usize>,
750 krate: Option<syn::Path>,
751}
752
753impl syn::parse::Parse for EntityArgs {
754 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
755 let mut args = EntityArgs {
756 name: None,
757 shard_group: None,
758 max_idle_time_secs: None,
759 mailbox_capacity: None,
760 concurrency: None,
761 krate: None,
762 };
763
764 while !input.is_empty() {
765 let ident: syn::Ident = input.parse()?;
766 input.parse::<syn::Token![=]>()?;
767
768 match ident.to_string().as_str() {
769 "name" => {
770 let lit: syn::LitStr = input.parse()?;
771 args.name = Some(lit.value());
772 }
773 "shard_group" => {
774 let lit: syn::LitStr = input.parse()?;
775 args.shard_group = Some(lit.value());
776 }
777 "max_idle_time_secs" => {
778 let lit: syn::LitInt = input.parse()?;
779 args.max_idle_time_secs = Some(lit.base10_parse()?);
780 }
781 "mailbox_capacity" => {
782 let lit: syn::LitInt = input.parse()?;
783 args.mailbox_capacity = Some(lit.base10_parse()?);
784 }
785 "concurrency" => {
786 let lit: syn::LitInt = input.parse()?;
787 args.concurrency = Some(lit.base10_parse()?);
788 }
789 "krate" => {
790 let lit: syn::LitStr = input.parse()?;
791 args.krate = Some(lit.parse()?);
792 }
793 other => {
794 return Err(syn::Error::new(
795 ident.span(),
796 format!("unknown entity attribute: {other}"),
797 ));
798 }
799 }
800
801 if !input.is_empty() {
802 input.parse::<syn::Token![,]>()?;
803 }
804 }
805
806 Ok(args)
807 }
808}
809
810struct ImplArgs {
811 krate: Option<syn::Path>,
812 traits: Vec<syn::Path>,
813 rpc_groups: Vec<syn::Path>,
814 deferred_keys: Vec<DeferredKeyDecl>,
815}
816
817struct DeferredKeyDecl {
818 ident: syn::Ident,
819 ty: syn::Type,
820 name: syn::LitStr,
821}
822
823impl syn::parse::Parse for DeferredKeyDecl {
824 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
825 let ident: syn::Ident = input.parse()?;
826 input.parse::<syn::Token![:]>()?;
827 let ty: syn::Type = input.parse()?;
828 if !input.peek(syn::Token![=]) {
829 return Err(syn::Error::new(
830 input.span(),
831 "expected `= \"name\"` for deferred key",
832 ));
833 }
834 input.parse::<syn::Token![=]>()?;
835 let name: syn::LitStr = input.parse()?;
836 Ok(DeferredKeyDecl { ident, ty, name })
837 }
838}
839
840impl syn::parse::Parse for ImplArgs {
841 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
842 let mut args = ImplArgs {
843 krate: None,
844 traits: Vec::new(),
845 rpc_groups: Vec::new(),
846 deferred_keys: Vec::new(),
847 };
848 while !input.is_empty() {
849 let ident: syn::Ident = input.parse()?;
850 match ident.to_string().as_str() {
851 "krate" => {
852 input.parse::<syn::Token![=]>()?;
853 let lit: syn::LitStr = input.parse()?;
854 args.krate = Some(lit.parse()?);
855 }
856 "traits" => {
857 let content;
858 syn::parenthesized!(content in input);
859 while !content.is_empty() {
860 let path: syn::Path = content.parse()?;
861 args.traits.push(path);
862 if !content.is_empty() {
863 content.parse::<syn::Token![,]>()?;
864 }
865 }
866 }
867 "rpc_groups" => {
868 let content;
869 syn::parenthesized!(content in input);
870 while !content.is_empty() {
871 let path: syn::Path = content.parse()?;
872 args.rpc_groups.push(path);
873 if !content.is_empty() {
874 content.parse::<syn::Token![,]>()?;
875 }
876 }
877 }
878 "deferred_keys" => {
879 let content;
880 syn::parenthesized!(content in input);
881 let decls = content.parse_terminated(DeferredKeyDecl::parse, syn::Token![,])?;
882 args.deferred_keys.extend(decls);
883 }
884 other => {
885 return Err(syn::Error::new(
886 ident.span(),
887 format!("unknown entity_impl attribute: {other}"),
888 ));
889 }
890 }
891 if !input.is_empty() {
892 input.parse::<syn::Token![,]>()?;
893 }
894 }
895 Ok(args)
896 }
897}
898
899struct TraitArgs {
900 krate: Option<syn::Path>,
901}
902
903impl syn::parse::Parse for TraitArgs {
904 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
905 let mut args = TraitArgs { krate: None };
906 while !input.is_empty() {
907 let ident: syn::Ident = input.parse()?;
908 input.parse::<syn::Token![=]>()?;
909 match ident.to_string().as_str() {
910 "krate" => {
911 let lit: syn::LitStr = input.parse()?;
912 args.krate = Some(lit.parse()?);
913 }
914 other => {
915 return Err(syn::Error::new(
916 ident.span(),
917 format!("unknown attribute: {other}"),
918 ));
919 }
920 }
921 if !input.is_empty() {
922 input.parse::<syn::Token![,]>()?;
923 }
924 }
925 Ok(args)
926 }
927}
928
929fn default_crate_path() -> syn::Path {
930 syn::parse_str("cruster").unwrap()
931}
932
933fn replace_last_segment(path: &syn::Path, ident: syn::Ident) -> syn::Path {
934 let mut new_path = path.clone();
935 if let Some(last) = new_path.segments.last_mut() {
936 last.ident = ident;
937 last.arguments = syn::PathArguments::None;
938 }
939 new_path
940}
941
942#[allow(dead_code)]
943struct RpcGroupInfo {
944 path: syn::Path,
945 ident: syn::Ident,
946 field: syn::Ident,
947 wrapper_path: syn::Path,
948 access_trait_path: syn::Path,
949 methods_trait_path: syn::Path,
950}
951
952fn rpc_group_infos_from_paths(paths: &[syn::Path]) -> Vec<RpcGroupInfo> {
953 paths
954 .iter()
955 .map(|path| {
956 let ident = path
957 .segments
958 .last()
959 .expect("rpc group path missing segment")
960 .ident
961 .clone();
962 let snake = to_snake(&ident.to_string());
963 let field = format_ident!("__rpc_group_{}", snake);
964 let wrapper_ident = format_ident!("__{}RpcGroupWrapper", ident);
965 let access_trait_ident = format_ident!("__{}RpcGroupAccess", ident);
966 let methods_trait_ident = format_ident!("__{}RpcGroupMethods", ident);
967 let wrapper_path = replace_last_segment(path, wrapper_ident);
968 let access_trait_path = replace_last_segment(path, access_trait_ident);
969 let methods_trait_path = replace_last_segment(path, methods_trait_ident);
970 RpcGroupInfo {
971 path: path.clone(),
972 ident,
973 field,
974 wrapper_path,
975 access_trait_path,
976 methods_trait_path,
977 }
978 })
979 .collect()
980}
981
982fn entity_impl_inner(
985 args: EntityArgs,
986 input: syn::ItemStruct,
987) -> syn::Result<proc_macro2::TokenStream> {
988 let krate = args.krate.clone().unwrap_or_else(default_crate_path);
989 let struct_name = &input.ident;
990 let entity_name = args.name.unwrap_or_else(|| struct_name.to_string());
991 let shard_group_value = if let Some(sg) = &args.shard_group {
992 quote! { #sg }
993 } else {
994 quote! { "default" }
995 };
996 let max_idle_value = if let Some(secs) = args.max_idle_time_secs {
997 quote! { ::std::option::Option::Some(::std::time::Duration::from_secs(#secs)) }
998 } else {
999 quote! { ::std::option::Option::None }
1000 };
1001 let mailbox_value = if let Some(cap) = args.mailbox_capacity {
1002 quote! { ::std::option::Option::Some(#cap) }
1003 } else {
1004 quote! { ::std::option::Option::None }
1005 };
1006 let concurrency_value = if let Some(c) = args.concurrency {
1007 quote! { ::std::option::Option::Some(#c) }
1008 } else {
1009 quote! { ::std::option::Option::None }
1010 };
1011
1012 Ok(quote! {
1013 #input
1014
1015 #[allow(dead_code)]
1016 impl #struct_name {
1017 #[doc(hidden)]
1018 fn __entity_type(&self) -> #krate::types::EntityType {
1019 #krate::types::EntityType::new(#entity_name)
1020 }
1021
1022 #[doc(hidden)]
1023 fn __shard_group(&self) -> &str {
1024 #shard_group_value
1025 }
1026
1027 #[doc(hidden)]
1028 fn __shard_group_for(&self, _entity_id: &#krate::types::EntityId) -> &str {
1029 self.__shard_group()
1030 }
1031
1032 #[doc(hidden)]
1033 fn __max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
1034 #max_idle_value
1035 }
1036
1037 #[doc(hidden)]
1038 fn __mailbox_capacity(&self) -> ::std::option::Option<usize> {
1039 #mailbox_value
1040 }
1041
1042 #[doc(hidden)]
1043 fn __concurrency(&self) -> ::std::option::Option<usize> {
1044 #concurrency_value
1045 }
1046 }
1047 })
1048}
1049
/// Kind of an annotated method.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcKind {
    /// Plain RPC.
    Rpc,
    /// Workflow; counted as persisted by `RpcKind::is_persisted`.
    Workflow,
    /// Activity; counted as persisted by `RpcKind::is_persisted`.
    Activity,
    /// Plain method; excluded from dispatch, client and trait visibility.
    Method,
}
1059
/// Visibility level attached to an annotated method.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcVisibility {
    /// Required for a method to be dispatchable and client-visible.
    Public,
    /// Neither public nor private; still trait-visible.
    Protected,
    /// Hidden from traits as well.
    Private,
}
1066
1067impl RpcKind {
1068 fn is_persisted(&self) -> bool {
1069 matches!(self, RpcKind::Workflow | RpcKind::Activity)
1070 }
1071}
1072
1073impl RpcMethod {
1074 fn uses_persisted_delivery(&self) -> bool {
1076 self.kind.is_persisted() || self.rpc_persisted
1077 }
1078}
1079
1080impl RpcVisibility {
1081 fn is_public(&self) -> bool {
1082 matches!(self, RpcVisibility::Public)
1083 }
1084
1085 fn is_private(&self) -> bool {
1086 matches!(self, RpcVisibility::Private)
1087 }
1088}
1089
1090struct RpcMethod {
1091 name: syn::Ident,
1092 tag: String,
1093 params: Vec<RpcParam>,
1094 response_type: syn::Type,
1095 is_mut: bool,
1096 kind: RpcKind,
1097 visibility: RpcVisibility,
1098 persist_key: Option<syn::ExprClosure>,
1100 #[allow(dead_code)]
1102 has_durable_context: bool,
1103 rpc_persisted: bool,
1105 #[allow(dead_code)]
1108 retries: Option<u32>,
1109 #[allow(dead_code)]
1112 backoff: Option<String>,
1113}
1114
1115impl RpcMethod {
1116 fn is_dispatchable(&self) -> bool {
1117 self.visibility.is_public() && !matches!(self.kind, RpcKind::Activity | RpcKind::Method)
1118 }
1119
1120 fn is_client_visible(&self) -> bool {
1121 self.visibility.is_public() && !matches!(self.kind, RpcKind::Activity | RpcKind::Method)
1122 }
1123
1124 fn is_trait_visible(&self) -> bool {
1125 !self.visibility.is_private() && !matches!(self.kind, RpcKind::Method)
1126 }
1127}
1128
1129struct RpcParam {
1130 name: syn::Ident,
1131 ty: syn::Type,
1132}
1133
1134fn entity_impl_block_inner(
1135 args: ImplArgs,
1136 input: syn::ItemImpl,
1137) -> syn::Result<proc_macro2::TokenStream> {
1138 let krate = args.krate.unwrap_or_else(default_crate_path);
1139 let traits = args.traits;
1140 let rpc_groups = args.rpc_groups;
1141 let deferred_keys = args.deferred_keys;
1142 let mut input = input;
1143 let self_ty = &input.self_ty;
1144
1145 let state_info = parse_state_attr(&mut input.attrs)?;
1147 if let Some(ref info) = state_info {
1148 return Err(syn::Error::new(
1149 info.span,
1150 "entities are stateless; use a database for state management",
1151 ));
1152 }
1153
1154 if !traits.is_empty() {
1156 return Err(syn::Error::new(
1157 proc_macro2::Span::call_site(),
1158 "entity traits have been replaced by #[rpc_group]; use #[entity_impl(rpc_groups(...))] instead",
1159 ));
1160 }
1161
1162 let struct_name = match self_ty.as_ref() {
1163 syn::Type::Path(tp) => tp
1164 .path
1165 .segments
1166 .last()
1167 .map(|s| s.ident.clone())
1168 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
1169 _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
1170 };
1171
1172 let handler_name = format_ident!("{}Handler", struct_name);
1173 let client_name = format_ident!("{}Client", struct_name);
1174
1175 let mut rpcs = Vec::new();
1176 let mut original_methods = Vec::new();
1177
1178 for item in &input.items {
1179 match item {
1180 syn::ImplItem::Type(type_item) if type_item.ident == "State" => {
1181 return Err(syn::Error::new(
1182 type_item.span(),
1183 "entities are stateless; use a database for state management",
1184 ));
1185 }
1186 syn::ImplItem::Fn(method) => {
1187 let has_rpc_attrs = parse_kind_attr(&method.attrs)?.is_some()
1188 || parse_visibility_attr(&method.attrs)?.is_some();
1189
1190 if method.sig.ident == "init" && method.sig.asyncness.is_none() {
1192 return Err(syn::Error::new(
1193 method.sig.span(),
1194 "entities are stateless; `fn init` is no longer needed",
1195 ));
1196 }
1197
1198 if method.sig.asyncness.is_some() {
1199 if let Some(rpc) = parse_rpc_method(method)? {
1200 if matches!(rpc.kind, RpcKind::Workflow) {
1202 return Err(syn::Error::new(
1203 method.sig.span(),
1204 "use standalone #[workflow] for durable orchestration; \
1205 entities only support #[rpc] and #[rpc(persisted)]",
1206 ));
1207 }
1208 if matches!(rpc.kind, RpcKind::Activity) {
1210 return Err(syn::Error::new(
1211 method.sig.span(),
1212 "activities belong on workflows, not entities; \
1213 use standalone #[workflow] with #[activity] methods",
1214 ));
1215 }
1216 if rpc.is_mut {
1218 return Err(syn::Error::new(
1219 method.sig.span(),
1220 "entity methods must use `&self`; \
1221 entities are stateless and do not support `&mut self`",
1222 ));
1223 }
1224 rpcs.push(rpc);
1225 }
1226 } else if has_rpc_attrs {
1227 return Err(syn::Error::new(
1228 method.sig.span(),
1229 "RPC annotations are only valid on async methods",
1230 ));
1231 }
1232 original_methods.push(method.clone());
1233 }
1234 _ => {}
1235 }
1236 }
1237
1238 let rpc_group_infos = rpc_group_infos_from_paths(&rpc_groups);
1239
1240 let entity_tokens = generate_pure_rpc_entity(
1241 &krate,
1242 &struct_name,
1243 &handler_name,
1244 &client_name,
1245 &rpc_group_infos,
1246 &rpcs,
1247 &original_methods,
1248 )?;
1249
1250 let deferred_consts = generate_deferred_key_consts(&krate, &deferred_keys)?;
1251
1252 Ok(quote! {
1253 #entity_tokens
1254 #deferred_consts
1255 })
1256}
1257
1258fn generate_deferred_key_consts(
1259 krate: &syn::Path,
1260 deferred_keys: &[DeferredKeyDecl],
1261) -> syn::Result<proc_macro2::TokenStream> {
1262 if deferred_keys.is_empty() {
1263 return Ok(quote! {});
1264 }
1265
1266 let mut seen_idents = HashSet::new();
1267 let mut seen_names = HashSet::new();
1268 for decl in deferred_keys {
1269 let ident = decl.ident.to_string();
1270 if !seen_idents.insert(ident.clone()) {
1271 return Err(syn::Error::new(
1272 decl.ident.span(),
1273 format!("duplicate deferred key constant: {ident}"),
1274 ));
1275 }
1276 let name = decl.name.value();
1277 if !seen_names.insert(name.clone()) {
1278 return Err(syn::Error::new(
1279 decl.name.span(),
1280 format!("duplicate deferred key name: {name}"),
1281 ));
1282 }
1283 }
1284
1285 let consts: Vec<_> = deferred_keys
1286 .iter()
1287 .map(|decl| {
1288 let ident = &decl.ident;
1289 let ty = &decl.ty;
1290 let name = &decl.name;
1291 quote! {
1292 #[allow(dead_code)]
1293 pub const #ident: #krate::__internal::DeferredKey<#ty> =
1294 #krate::__internal::DeferredKey::new(#name);
1295 }
1296 })
1297 .collect();
1298
1299 Ok(quote! {
1300 #(#consts)*
1301 })
1302}
1303
1304#[allow(clippy::too_many_arguments)]
1316fn generate_pure_rpc_entity(
1317 krate: &syn::Path,
1318 struct_name: &syn::Ident,
1319 handler_name: &syn::Ident,
1320 client_name: &syn::Ident,
1321 rpc_group_infos: &[RpcGroupInfo],
1322 rpcs: &[RpcMethod],
1323 original_methods: &[syn::ImplItemFn],
1324) -> syn::Result<proc_macro2::TokenStream> {
1325 let has_rpc_groups = !rpc_group_infos.is_empty();
1326 let struct_name_str = struct_name.to_string();
1327 let rpc_view_name = format_ident!("__{}RpcView", struct_name);
1328
1329 let entity_impl = if has_rpc_groups {
1331 quote! {}
1332 } else {
1333 quote! {
1334 #[async_trait::async_trait]
1335 impl #krate::entity::Entity for #struct_name {
1336 fn entity_type(&self) -> #krate::types::EntityType {
1337 self.__entity_type()
1338 }
1339
1340 fn shard_group(&self) -> &str {
1341 self.__shard_group()
1342 }
1343
1344 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
1345 self.__shard_group_for(entity_id)
1346 }
1347
1348 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
1349 self.__max_idle_time()
1350 }
1351
1352 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
1353 self.__mailbox_capacity()
1354 }
1355
1356 fn concurrency(&self) -> ::std::option::Option<usize> {
1357 self.__concurrency()
1358 }
1359
1360 async fn spawn(
1361 &self,
1362 ctx: #krate::entity::EntityContext,
1363 ) -> ::std::result::Result<
1364 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
1365 #krate::error::ClusterError,
1366 > {
1367 let handler = #handler_name::__new(self.clone(), ctx).await?;
1368 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
1369 }
1370 }
1371 }
1372 };
1373
1374 let dispatch_arms: Vec<proc_macro2::TokenStream> = rpcs
1376 .iter()
1377 .filter(|rpc| rpc.is_dispatchable())
1378 .map(|rpc| {
1379 let tag = &rpc.tag;
1380 let method_name = &rpc.name;
1381 let param_count = rpc.params.len();
1382 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
1383 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
1384
1385 let deserialize_request = match param_count {
1386 0 => quote! {},
1387 1 => {
1388 let name = ¶m_names[0];
1389 let ty = ¶m_types[0];
1390 quote! {
1391 let #name: #ty = rmp_serde::from_slice(payload)
1392 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1393 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
1394 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1395 })?;
1396 }
1397 }
1398 _ => quote! {
1399 let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
1400 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1401 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
1402 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1403 })?;
1404 },
1405 };
1406
1407 let mut call_args: Vec<proc_macro2::TokenStream> = Vec::new();
1408 for name in ¶m_names {
1409 call_args.push(quote! { #name });
1410 }
1411 let call_args = quote! { #(#call_args),* };
1412 let method_call = quote! { __view.#method_name(#call_args).await? };
1413
1414 quote! {
1415 #tag => {
1416 let __view = #rpc_view_name { __handler: self };
1417 #deserialize_request
1418 let response = { #method_call };
1419 rmp_serde::to_vec(&response)
1420 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1421 reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
1422 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1423 })
1424 }
1425 }
1426 })
1427 .collect();
1428
1429 let client_methods = generate_client_methods(krate, rpcs);
1431
1432 let rpc_group_field_defs: Vec<proc_macro2::TokenStream> = rpc_group_infos
1434 .iter()
1435 .map(|info| {
1436 let field = &info.field;
1437 let wrapper_path = &info.wrapper_path;
1438 quote! { #field: #wrapper_path, }
1439 })
1440 .collect();
1441
1442 let rpc_group_params: Vec<proc_macro2::TokenStream> = rpc_group_infos
1443 .iter()
1444 .map(|info| {
1445 let path = &info.path;
1446 let field = &info.field;
1447 quote! { #field: #path }
1448 })
1449 .collect();
1450
1451 let rpc_group_field_inits: Vec<proc_macro2::TokenStream> = rpc_group_infos
1452 .iter()
1453 .map(|info| {
1454 let field = &info.field;
1455 let wrapper_path = &info.wrapper_path;
1456 quote! { #field: #wrapper_path::new(#field, __entity_address.clone()), }
1457 })
1458 .collect();
1459
1460 let rpc_group_dispatch_checks: Vec<proc_macro2::TokenStream> = rpc_group_infos
1461 .iter()
1462 .map(|info| {
1463 let field = &info.field;
1464 quote! {
1465 if let ::std::option::Option::Some(response) = self.#field.__dispatch(tag, payload, headers).await? {
1466 return ::std::result::Result::Ok(response);
1467 }
1468 }
1469 })
1470 .collect();
1471
1472 let rpc_group_handler_access_impls: Vec<proc_macro2::TokenStream> = rpc_group_infos
1473 .iter()
1474 .map(|info| {
1475 let wrapper_path = &info.wrapper_path;
1476 let access_trait_path = &info.access_trait_path;
1477 let field = &info.field;
1478 quote! {
1479 impl #access_trait_path for #handler_name {
1480 fn __rpc_group_wrapper(&self) -> &#wrapper_path {
1481 &self.#field
1482 }
1483 }
1484 }
1485 })
1486 .collect();
1487
1488 let rpc_group_use_tokens: Vec<proc_macro2::TokenStream> = rpc_group_infos
1489 .iter()
1490 .map(|info| {
1491 let methods_trait_path = &info.methods_trait_path;
1492 quote! {
1493 #[allow(unused_imports)]
1494 use #methods_trait_path as _;
1495 }
1496 })
1497 .collect();
1498
1499 let dispatch_fallback = if has_rpc_groups {
1500 quote! {{
1501 #(#rpc_group_dispatch_checks)*
1502 ::std::result::Result::Err(
1503 #krate::error::ClusterError::MalformedMessage {
1504 reason: ::std::format!("unknown RPC tag: {tag}"),
1505 source: ::std::option::Option::None,
1506 }
1507 )
1508 }}
1509 } else {
1510 quote! {{
1511 ::std::result::Result::Err(
1512 #krate::error::ClusterError::MalformedMessage {
1513 reason: ::std::format!("unknown RPC tag: {tag}"),
1514 source: ::std::option::Option::None,
1515 }
1516 )
1517 }}
1518 };
1519
1520 let method_impls: Vec<proc_macro2::TokenStream> = original_methods
1522 .iter()
1523 .map(|m| {
1524 let attrs: Vec<_> = m
1525 .attrs
1526 .iter()
1527 .filter(|a| {
1528 !a.path().is_ident("rpc")
1529 && !a.path().is_ident("workflow")
1530 && !a.path().is_ident("activity")
1531 && !a.path().is_ident("method")
1532 && !a.path().is_ident("public")
1533 && !a.path().is_ident("protected")
1534 && !a.path().is_ident("private")
1535 })
1536 .collect();
1537 let vis = &m.vis;
1538 let sig = &m.sig;
1539 let block = &m.block;
1540 quote! {
1541 #(#attrs)*
1542 #vis #sig #block
1543 }
1544 })
1545 .collect();
1546
1547 let new_fn = if has_rpc_groups {
1549 quote! {}
1550 } else {
1551 quote! {
1552 #[doc(hidden)]
1553 pub async fn __new(entity: #struct_name, ctx: #krate::entity::EntityContext) -> ::std::result::Result<Self, #krate::error::ClusterError> {
1554 let __sharding = ctx.sharding.clone();
1555 let __entity_address = ctx.address.clone();
1556 let __message_storage = ctx.message_storage.clone();
1557 ::std::result::Result::Ok(Self {
1558 __entity: entity,
1559 ctx,
1560 __sharding,
1561 __entity_address,
1562 __message_storage,
1563 })
1564 }
1565 }
1566 };
1567
1568 let new_with_rpc_groups_fn = if has_rpc_groups {
1569 quote! {
1570 #[doc(hidden)]
1571 pub async fn __new_with_rpc_groups(
1572 entity: #struct_name,
1573 #(#rpc_group_params,)*
1574 ctx: #krate::entity::EntityContext,
1575 ) -> ::std::result::Result<Self, #krate::error::ClusterError> {
1576 let __sharding = ctx.sharding.clone();
1577 let __entity_address = ctx.address.clone();
1578 let __message_storage = ctx.message_storage.clone();
1579 ::std::result::Result::Ok(Self {
1580 __entity: entity,
1581 ctx,
1582 __sharding,
1583 __message_storage,
1584 #(#rpc_group_field_inits)*
1585 __entity_address,
1586 })
1587 }
1588 }
1589 } else {
1590 quote! {}
1591 };
1592
1593 let sharding_builtin_impls = quote! {
1595 pub fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
1597 self.__sharding.as_ref()
1598 }
1599
1600 pub fn entity_address(&self) -> &#krate::types::EntityAddress {
1602 &self.__entity_address
1603 }
1604
1605 pub fn entity_id(&self) -> &#krate::types::EntityId {
1607 &self.__entity_address.entity_id
1608 }
1609
1610 pub fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
1612 self.__sharding.as_ref().map(|s| {
1613 ::std::sync::Arc::clone(s).make_client(self.__entity_address.entity_type.clone())
1614 })
1615 }
1616 };
1617
1618 let with_rpc_groups_impl = if has_rpc_groups {
1620 let with_groups_name = format_ident!("{}WithRpcGroups", struct_name);
1621 let rpc_group_option_fields: Vec<proc_macro2::TokenStream> = rpc_group_infos
1622 .iter()
1623 .map(|info| {
1624 let field = &info.field;
1625 let path = &info.path;
1626 quote! { #field: #path, }
1627 })
1628 .collect();
1629 let register_rpc_group_params: Vec<proc_macro2::TokenStream> = rpc_group_infos
1630 .iter()
1631 .map(|info| {
1632 let field = &info.field;
1633 let path = &info.path;
1634 quote! { #field: #path }
1635 })
1636 .collect();
1637 let register_rpc_group_fields: Vec<_> =
1638 rpc_group_infos.iter().map(|info| &info.field).collect();
1639 quote! {
1640 #[doc(hidden)]
1642 pub struct #with_groups_name {
1643 pub entity: #struct_name,
1644 #(pub #rpc_group_option_fields)*
1645 }
1646
1647 #[async_trait::async_trait]
1648 impl #krate::entity::Entity for #with_groups_name
1649 where
1650 #struct_name: ::std::clone::Clone,
1651 {
1652 fn entity_type(&self) -> #krate::types::EntityType {
1653 self.entity.__entity_type()
1654 }
1655
1656 fn shard_group(&self) -> &str {
1657 self.entity.__shard_group()
1658 }
1659
1660 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
1661 self.entity.__shard_group_for(entity_id)
1662 }
1663
1664 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
1665 self.entity.__max_idle_time()
1666 }
1667
1668 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
1669 self.entity.__mailbox_capacity()
1670 }
1671
1672 fn concurrency(&self) -> ::std::option::Option<usize> {
1673 self.entity.__concurrency()
1674 }
1675
1676 async fn spawn(
1677 &self,
1678 ctx: #krate::entity::EntityContext,
1679 ) -> ::std::result::Result<
1680 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
1681 #krate::error::ClusterError,
1682 > {
1683 let handler = #handler_name::__new_with_rpc_groups(
1684 self.entity.clone(),
1685 #(self.#register_rpc_group_fields.clone(),)*
1686 ctx,
1687 )
1688 .await?;
1689 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
1690 }
1691 }
1692
1693 impl #struct_name {
1694 pub async fn register(
1696 self,
1697 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
1698 #(#register_rpc_group_params),*
1699 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
1700 let entity_with_groups = #with_groups_name {
1701 entity: self,
1702 #(#register_rpc_group_fields,)*
1703 };
1704 sharding.register_entity(::std::sync::Arc::new(entity_with_groups)).await?;
1705 ::std::result::Result::Ok(#client_name::new(sharding))
1706 }
1707 }
1708 }
1709 } else {
1710 quote! {}
1711 };
1712
1713 let register_impl = if has_rpc_groups {
1715 quote! {} } else {
1717 quote! {
1718 impl #struct_name {
1719 pub async fn register(
1721 self,
1722 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
1723 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
1724 sharding.register_entity(::std::sync::Arc::new(self)).await?;
1725 ::std::result::Result::Ok(#client_name::new(sharding))
1726 }
1727 }
1728 }
1729 };
1730
1731 Ok(quote! {
1732 #(#rpc_group_use_tokens)*
1733
1734 #with_rpc_groups_impl
1735 #entity_impl
1736
1737 #[doc(hidden)]
1739 pub struct #handler_name {
1740 #[allow(dead_code)]
1742 __entity: #struct_name,
1743 #[allow(dead_code)]
1745 ctx: #krate::entity::EntityContext,
1746 __sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
1748 __entity_address: #krate::types::EntityAddress,
1750 #[allow(dead_code)]
1752 __message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
1753 #(#rpc_group_field_defs)*
1754 }
1755
1756 impl #handler_name {
1757 #new_fn
1758 #new_with_rpc_groups_fn
1759
1760 #sharding_builtin_impls
1761 }
1762
1763 #[doc(hidden)]
1768 #[allow(non_camel_case_types)]
1769 struct #rpc_view_name<'a> {
1770 __handler: &'a #handler_name,
1771 }
1772
1773 impl ::std::ops::Deref for #rpc_view_name<'_> {
1774 type Target = #struct_name;
1775 fn deref(&self) -> &Self::Target {
1776 &self.__handler.__entity
1777 }
1778 }
1779
1780 impl #rpc_view_name<'_> {
1781 #[inline]
1783 fn entity_id(&self) -> &str {
1784 &self.__handler.__entity_address.entity_id.0
1785 }
1786
1787 #[inline]
1789 fn entity_address(&self) -> &#krate::types::EntityAddress {
1790 &self.__handler.__entity_address
1791 }
1792
1793 #[inline]
1795 fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
1796 self.__handler.__sharding.as_ref()
1797 }
1798
1799 #[inline]
1801 fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
1802 self.__handler.__sharding.as_ref().map(|s| {
1803 ::std::sync::Arc::clone(s).make_client(self.__handler.__entity_address.entity_type.clone())
1804 })
1805 }
1806
1807 #(#method_impls)*
1809 }
1810
1811 #[async_trait::async_trait]
1812 impl #krate::entity::EntityHandler for #handler_name {
1813 async fn handle_request(
1814 &self,
1815 tag: &str,
1816 payload: &[u8],
1817 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
1818 ) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
1819 #[allow(unused_variables)]
1820 let headers = headers;
1821 match tag {
1822 #(#dispatch_arms,)*
1823 _ => #dispatch_fallback,
1824 }
1825 }
1826 }
1827
1828 #register_impl
1829
1830 #[derive(Clone)]
1834 pub struct #client_name {
1835 inner: #krate::entity_client::EntityClient,
1836 }
1837
1838 impl #client_name {
1839 pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
1841 Self {
1842 inner: #krate::entity_client::EntityClient::new(
1843 sharding,
1844 #krate::types::EntityType::new(#struct_name_str),
1845 ),
1846 }
1847 }
1848
1849 pub fn inner(&self) -> &#krate::entity_client::EntityClient {
1851 &self.inner
1852 }
1853
1854 #(#client_methods)*
1855 }
1856
1857 impl #krate::entity_client::EntityClientAccessor for #client_name {
1858 fn entity_client(&self) -> &#krate::entity_client::EntityClient {
1859 &self.inner
1860 }
1861 }
1862
1863 #(#rpc_group_handler_access_impls)*
1864 })
1865}
1866
1867fn generate_client_methods(krate: &syn::Path, rpcs: &[RpcMethod]) -> Vec<proc_macro2::TokenStream> {
1868 rpcs.iter()
1869 .filter(|rpc| rpc.is_client_visible())
1870 .map(|rpc| {
1871 let method_name = &rpc.name;
1872 let tag = &rpc.tag;
1873 let resp_type = &rpc.response_type;
1874 let persist_key = rpc.persist_key.as_ref();
1875 let param_count = rpc.params.len();
1876 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
1877 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
1878 let param_defs: Vec<_> = param_names
1879 .iter()
1880 .zip(param_types.iter())
1881 .map(|(name, ty)| quote! { #name: &#ty })
1882 .collect();
1883 if rpc.uses_persisted_delivery() {
1884 match (persist_key, param_count) {
1885 (Some(persist_key), 0) => quote! {
1886 pub async fn #method_name(
1887 &self,
1888 entity_id: &#krate::types::EntityId,
1889 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1890 let key = (#persist_key)();
1891 let key_bytes = rmp_serde::to_vec(&key)
1892 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1893 reason: ::std::format!(
1894 "failed to serialize persist key for '{}': {e}",
1895 #tag
1896 ),
1897 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1898 })?;
1899 self.inner
1900 .send_persisted_with_key(
1901 entity_id,
1902 #tag,
1903 &(),
1904 ::std::option::Option::Some(key_bytes),
1905 #krate::schema::Uninterruptible::No,
1906 )
1907 .await
1908 }
1909 },
1910 (Some(persist_key), 1) => {
1911 let name = ¶m_names[0];
1912 let def = ¶m_defs[0];
1913 quote! {
1914 pub async fn #method_name(
1915 &self,
1916 entity_id: &#krate::types::EntityId,
1917 #def,
1918 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1919 let key = (#persist_key)(#name);
1920 let key_bytes = rmp_serde::to_vec(&key)
1921 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1922 reason: ::std::format!(
1923 "failed to serialize persist key for '{}': {e}",
1924 #tag
1925 ),
1926 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1927 })?;
1928 self.inner
1929 .send_persisted_with_key(
1930 entity_id,
1931 #tag,
1932 #name,
1933 ::std::option::Option::Some(key_bytes),
1934 #krate::schema::Uninterruptible::No,
1935 )
1936 .await
1937 }
1938 }
1939 }
1940 (Some(persist_key), _) => quote! {
1941 pub async fn #method_name(
1942 &self,
1943 entity_id: &#krate::types::EntityId,
1944 #(#param_defs),*
1945 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1946 let key = (#persist_key)(#(#param_names),*);
1947 let key_bytes = rmp_serde::to_vec(&key)
1948 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
1949 reason: ::std::format!(
1950 "failed to serialize persist key for '{}': {e}",
1951 #tag
1952 ),
1953 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
1954 })?;
1955 let request = (#(#param_names),*);
1956 self.inner
1957 .send_persisted_with_key(
1958 entity_id,
1959 #tag,
1960 &request,
1961 ::std::option::Option::Some(key_bytes),
1962 #krate::schema::Uninterruptible::No,
1963 )
1964 .await
1965 }
1966 },
1967 (None, 0) => quote! {
1968 pub async fn #method_name(
1969 &self,
1970 entity_id: &#krate::types::EntityId,
1971 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1972 self.inner
1973 .send_persisted(entity_id, #tag, &(), #krate::schema::Uninterruptible::No)
1974 .await
1975 }
1976 },
1977 (None, 1) => {
1978 let name = ¶m_names[0];
1979 let def = ¶m_defs[0];
1980 quote! {
1981 pub async fn #method_name(
1982 &self,
1983 entity_id: &#krate::types::EntityId,
1984 #def,
1985 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1986 self.inner
1987 .send_persisted(entity_id, #tag, #name, #krate::schema::Uninterruptible::No)
1988 .await
1989 }
1990 }
1991 }
1992 (None, _) => quote! {
1993 pub async fn #method_name(
1994 &self,
1995 entity_id: &#krate::types::EntityId,
1996 #(#param_defs),*
1997 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1998 let request = (#(#param_names),*);
1999 self.inner
2000 .send_persisted(entity_id, #tag, &request, #krate::schema::Uninterruptible::No)
2001 .await
2002 }
2003 },
2004 }
2005 } else {
2006 match param_count {
2007 0 => quote! {
2008 pub async fn #method_name(
2009 &self,
2010 entity_id: &#krate::types::EntityId,
2011 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
2012 self.inner.send(entity_id, #tag, &()).await
2013 }
2014 },
2015 1 => {
2016 let def = ¶m_defs[0];
2017 let name = ¶m_names[0];
2018 quote! {
2019 pub async fn #method_name(
2020 &self,
2021 entity_id: &#krate::types::EntityId,
2022 #def,
2023 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
2024 self.inner.send(entity_id, #tag, #name).await
2025 }
2026 }
2027 }
2028 _ => quote! {
2029 pub async fn #method_name(
2030 &self,
2031 entity_id: &#krate::types::EntityId,
2032 #(#param_defs),*
2033 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
2034 let request = (#(#param_names),*);
2035 self.inner.send(entity_id, #tag, &request).await
2036 }
2037 },
2038 }
2039 }
2040 })
2041 .collect()
2042}
2043
2044fn is_durable_context_type(ty: &syn::Type) -> bool {
2046 match ty {
2047 syn::Type::Reference(r) => is_durable_context_type(&r.elem),
2048 syn::Type::Path(tp) => tp
2049 .path
2050 .segments
2051 .last()
2052 .map(|s| s.ident == "DurableContext")
2053 .unwrap_or(false),
2054 _ => false,
2055 }
2056}
2057
2058struct StateArgs {
2059 #[allow(dead_code)]
2060 ty: syn::Type,
2061 span: proc_macro2::Span,
2063}
2064
2065impl syn::parse::Parse for StateArgs {
2066 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
2067 let ty: syn::Type = input.parse()?;
2068
2069 if !input.is_empty() {
2074 return Err(syn::Error::new(
2075 input.span(),
2076 "unexpected tokens in #[state(...)]; state is always persistent",
2077 ));
2078 }
2079
2080 Ok(StateArgs {
2081 ty,
2082 span: proc_macro2::Span::call_site(),
2083 })
2084 }
2085}
2086
2087fn parse_state_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<StateArgs>> {
2088 let mut state_attr: Option<StateArgs> = None;
2089 let mut i = 0;
2090 while i < attrs.len() {
2091 if attrs[i].path().is_ident("state") {
2092 if state_attr.is_some() {
2093 return Err(syn::Error::new(
2094 attrs[i].span(),
2095 "duplicate #[state(...)] attribute",
2096 ));
2097 }
2098 let attr_span = attrs[i].span();
2099 let mut args = attrs[i].parse_args::<StateArgs>()?;
2100 args.span = attr_span;
2101 state_attr = Some(args);
2102 attrs.remove(i);
2103 continue;
2104 }
2105 i += 1;
2106 }
2107 Ok(state_attr)
2108}
2109
2110struct RpcArgs {
2111 persisted: bool,
2112}
2113
2114impl syn::parse::Parse for RpcArgs {
2115 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
2116 let mut args = RpcArgs { persisted: false };
2117 while !input.is_empty() {
2118 let ident: syn::Ident = input.parse()?;
2119 match ident.to_string().as_str() {
2120 "persisted" => {
2121 args.persisted = true;
2122 }
2123 other => {
2124 return Err(syn::Error::new(
2125 ident.span(),
2126 format!("unknown rpc attribute: {other}; expected `persisted`"),
2127 ));
2128 }
2129 }
2130 if !input.is_empty() {
2131 input.parse::<syn::Token![,]>()?;
2132 }
2133 }
2134 Ok(args)
2135 }
2136}
2137
2138struct KeyArgs {
2139 key: Option<syn::ExprClosure>,
2140}
2141
2142impl syn::parse::Parse for KeyArgs {
2143 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
2144 if input.is_empty() {
2145 return Ok(KeyArgs { key: None });
2146 }
2147
2148 let ident: syn::Ident = input.parse()?;
2149 if ident != "key" {
2150 return Err(syn::Error::new(
2151 ident.span(),
2152 "expected `key` in #[workflow(key(...))] or #[activity(key(...))]",
2153 ));
2154 }
2155
2156 if input.peek(syn::Token![=]) {
2157 input.parse::<syn::Token![=]>()?;
2158 }
2159
2160 let expr: syn::Expr = if input.peek(syn::token::Paren) {
2161 let content;
2162 syn::parenthesized!(content in input);
2163 content.parse()?
2164 } else {
2165 input.parse()?
2166 };
2167
2168 if !input.is_empty() {
2169 return Err(syn::Error::new(
2170 input.span(),
2171 "unexpected tokens in #[workflow(...)] or #[activity(...)]",
2172 ));
2173 }
2174
2175 match expr {
2176 syn::Expr::Closure(closure) => Ok(KeyArgs { key: Some(closure) }),
2177 _ => Err(syn::Error::new(
2178 expr.span(),
2179 "key must be a closure, e.g. #[workflow(key(|req| ...))]",
2180 )),
2181 }
2182 }
2183}
2184
2185struct ActivityAttrArgs {
2195 key: Option<syn::ExprClosure>,
2196 retries: Option<u32>,
2197 backoff: Option<String>,
2198}
2199
2200impl syn::parse::Parse for ActivityAttrArgs {
2201 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
2202 let mut key = None;
2203 let mut retries = None;
2204 let mut backoff = None;
2205
2206 if input.is_empty() {
2207 return Ok(ActivityAttrArgs {
2208 key,
2209 retries,
2210 backoff,
2211 });
2212 }
2213
2214 loop {
2215 if input.is_empty() {
2216 break;
2217 }
2218
2219 let ident: syn::Ident = input.parse()?;
2220 if ident == "key" {
2221 if key.is_some() {
2222 return Err(syn::Error::new(ident.span(), "duplicate `key` argument"));
2223 }
2224 if input.peek(syn::Token![=]) {
2225 input.parse::<syn::Token![=]>()?;
2226 }
2227 let expr: syn::Expr = if input.peek(syn::token::Paren) {
2228 let content;
2229 syn::parenthesized!(content in input);
2230 content.parse()?
2231 } else {
2232 input.parse()?
2233 };
2234 match expr {
2235 syn::Expr::Closure(closure) => key = Some(closure),
2236 _ => {
2237 return Err(syn::Error::new(
2238 expr.span(),
2239 "key must be a closure, e.g. #[activity(key = |req| ...)]",
2240 ))
2241 }
2242 }
2243 } else if ident == "retries" {
2244 if retries.is_some() {
2245 return Err(syn::Error::new(
2246 ident.span(),
2247 "duplicate `retries` argument",
2248 ));
2249 }
2250 input.parse::<syn::Token![=]>()?;
2251 let lit: syn::LitInt = input.parse()?;
2252 retries = Some(lit.base10_parse::<u32>()?);
2253 } else if ident == "backoff" {
2254 if backoff.is_some() {
2255 return Err(syn::Error::new(
2256 ident.span(),
2257 "duplicate `backoff` argument",
2258 ));
2259 }
2260 input.parse::<syn::Token![=]>()?;
2261 let lit: syn::LitStr = input.parse()?;
2262 let value = lit.value();
2263 if value != "exponential" && value != "constant" {
2264 return Err(syn::Error::new(
2265 lit.span(),
2266 "backoff must be \"exponential\" or \"constant\"",
2267 ));
2268 }
2269 backoff = Some(value);
2270 } else {
2271 return Err(syn::Error::new(
2272 ident.span(),
2273 "expected `key`, `retries`, or `backoff` in #[activity(...)]",
2274 ));
2275 }
2276
2277 if input.peek(syn::Token![,]) {
2279 input.parse::<syn::Token![,]>()?;
2280 }
2281 }
2282
2283 if backoff.is_some() && retries.is_none() {
2285 return Err(syn::Error::new(
2286 proc_macro2::Span::call_site(),
2287 "`backoff` requires `retries` to be specified",
2288 ));
2289 }
2290
2291 Ok(ActivityAttrArgs {
2292 key,
2293 retries,
2294 backoff,
2295 })
2296 }
2297}
2298
2299struct KindAttrInfo {
2301 kind: RpcKind,
2302 key: Option<syn::ExprClosure>,
2303 rpc_persisted: bool,
2304 retries: Option<u32>,
2306 backoff: Option<String>,
2308}
2309
2310fn parse_kind_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<KindAttrInfo>> {
2311 let mut kind: Option<RpcKind> = None;
2312 let mut key: Option<syn::ExprClosure> = None;
2313 let mut rpc_persisted = false;
2314 let mut retries: Option<u32> = None;
2315 let mut backoff: Option<String> = None;
2316
2317 for attr in attrs {
2318 if attr.path().is_ident("rpc") {
2319 if kind.is_some() {
2320 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
2321 }
2322 match &attr.meta {
2323 syn::Meta::Path(_) => {
2324 kind = Some(RpcKind::Rpc);
2325 }
2326 syn::Meta::List(_) => {
2327 let args = attr.parse_args::<RpcArgs>()?;
2328 kind = Some(RpcKind::Rpc);
2329 rpc_persisted = args.persisted;
2330 }
2331 _ => {
2332 return Err(syn::Error::new(
2333 attr.span(),
2334 "expected #[rpc] or #[rpc(persisted)]",
2335 ))
2336 }
2337 }
2338 }
2339
2340 if attr.path().is_ident("workflow") {
2341 if kind.is_some() {
2342 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
2343 }
2344 let args = match &attr.meta {
2345 syn::Meta::Path(_) => KeyArgs { key: None },
2346 syn::Meta::List(_) => attr.parse_args::<KeyArgs>()?,
2347 syn::Meta::NameValue(_) => {
2348 return Err(syn::Error::new(
2349 attr.span(),
2350 "expected #[workflow] or #[workflow(key(...))]",
2351 ))
2352 }
2353 };
2354 kind = Some(RpcKind::Workflow);
2355 if args.key.is_some() {
2356 key = args.key;
2357 }
2358 }
2359
2360 if attr.path().is_ident("activity") {
2361 if kind.is_some() {
2362 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
2363 }
2364 let args = match &attr.meta {
2365 syn::Meta::Path(_) => ActivityAttrArgs {
2366 key: None,
2367 retries: None,
2368 backoff: None,
2369 },
2370 syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
2371 syn::Meta::NameValue(_) => {
2372 return Err(syn::Error::new(
2373 attr.span(),
2374 "expected #[activity] or #[activity(...)]",
2375 ))
2376 }
2377 };
2378 kind = Some(RpcKind::Activity);
2379 if args.key.is_some() {
2380 key = args.key;
2381 }
2382 retries = args.retries;
2383 backoff = args.backoff;
2384 }
2385
2386 if attr.path().is_ident("method") {
2387 if kind.is_some() {
2388 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
2389 }
2390 match &attr.meta {
2391 syn::Meta::Path(_) => {
2392 kind = Some(RpcKind::Method);
2393 }
2394 _ => {
2395 return Err(syn::Error::new(
2396 attr.span(),
2397 "#[method] does not take arguments",
2398 ))
2399 }
2400 }
2401 }
2402 }
2403
2404 Ok(kind.map(|kind| KindAttrInfo {
2405 kind,
2406 key,
2407 rpc_persisted,
2408 retries,
2409 backoff,
2410 }))
2411}
2412
2413fn parse_visibility_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<RpcVisibility>> {
2414 let mut visibility: Option<RpcVisibility> = None;
2415
2416 for attr in attrs {
2417 let next = if attr.path().is_ident("public") {
2418 Some(RpcVisibility::Public)
2419 } else if attr.path().is_ident("protected") {
2420 Some(RpcVisibility::Protected)
2421 } else if attr.path().is_ident("private") {
2422 Some(RpcVisibility::Private)
2423 } else {
2424 None
2425 };
2426
2427 if let Some(next) = next {
2428 match &attr.meta {
2429 syn::Meta::Path(_) => {}
2430 _ => {
2431 return Err(syn::Error::new(
2432 attr.span(),
2433 "visibility attributes do not take arguments",
2434 ))
2435 }
2436 }
2437 if visibility.is_some() {
2438 return Err(syn::Error::new(
2439 attr.span(),
2440 "multiple visibility modifiers are not allowed",
2441 ));
2442 }
2443 visibility = Some(next);
2444 }
2445 }
2446
2447 Ok(visibility)
2448}
2449
2450fn parse_rpc_method(method: &syn::ImplItemFn) -> syn::Result<Option<RpcMethod>> {
2451 let name = method.sig.ident.clone();
2452 let tag = name.to_string();
2453
2454 let kind_info = parse_kind_attr(&method.attrs)?;
2455 let visibility_attr = parse_visibility_attr(&method.attrs)?;
2456
2457 let KindAttrInfo {
2458 kind,
2459 key: persist_key,
2460 rpc_persisted,
2461 retries,
2462 backoff,
2463 } = match kind_info {
2464 Some(info) => info,
2465 None => {
2466 if visibility_attr.is_some() {
2467 return Err(syn::Error::new(
2468 method.sig.span(),
2469 "visibility modifiers require #[rpc], #[workflow], #[activity], or #[method]",
2470 ));
2471 }
2472 return Ok(None);
2473 }
2474 };
2475
2476 if method.sig.asyncness.is_none() && !matches!(kind, RpcKind::Method) {
2478 return Err(syn::Error::new(
2479 method.sig.span(),
2480 "#[rpc]/#[workflow]/#[activity] can only be applied to async methods",
2481 ));
2482 }
2483
2484 if matches!(kind, RpcKind::Rpc | RpcKind::Method) && persist_key.is_some() {
2485 return Err(syn::Error::new(
2486 method.sig.span(),
2487 "#[rpc] and #[method] do not support key(...) — use #[workflow(key(...))] or #[activity(key(...))]",
2488 ));
2489 }
2490
2491 if rpc_persisted && !matches!(kind, RpcKind::Rpc) {
2493 return Err(syn::Error::new(
2494 method.sig.span(),
2495 "persisted flag is only valid on #[rpc(persisted)]",
2496 ));
2497 }
2498
2499 let visibility = match (kind, visibility_attr) {
2500 (_, Some(RpcVisibility::Public)) if matches!(kind, RpcKind::Activity | RpcKind::Method) => {
2502 return Err(syn::Error::new(
2503 method.sig.span(),
2504 "#[activity] and #[method] cannot be #[public]",
2505 ))
2506 }
2507 (RpcKind::Activity | RpcKind::Method, None) => RpcVisibility::Private,
2508 (RpcKind::Rpc | RpcKind::Workflow, None) => RpcVisibility::Public,
2509 (_, Some(vis)) => vis,
2510 };
2511
2512 let is_mut = method
2514 .sig
2515 .inputs
2516 .first()
2517 .map(|arg| match arg {
2518 syn::FnArg::Receiver(r) => r.mutability.is_some(),
2519 _ => false,
2520 })
2521 .unwrap_or(false);
2522
2523 if is_mut && !matches!(kind, RpcKind::Activity) {
2525 return Err(syn::Error::new(
2526 method.sig.span(),
2527 "only #[activity] methods can use `&mut self` for state mutation; use `&self` for read-only access",
2528 ));
2529 }
2530
2531 let mut params = Vec::new();
2532 let mut has_durable_context = false;
2533 let mut saw_non_ctx_param = false;
2534 let mut param_index = 0usize;
2535 for arg in method.sig.inputs.iter().skip(1) {
2536 match arg {
2537 syn::FnArg::Typed(pat_type) => {
2538 if is_durable_context_type(&pat_type.ty) {
2539 if has_durable_context {
2540 return Err(syn::Error::new(
2541 arg.span(),
2542 "duplicate DurableContext parameter",
2543 ));
2544 }
2545 if saw_non_ctx_param {
2546 return Err(syn::Error::new(
2547 arg.span(),
2548 "DurableContext must be the first parameter after &self",
2549 ));
2550 }
2551 has_durable_context = true;
2552 continue; }
2554 saw_non_ctx_param = true;
2555 let name = match &*pat_type.pat {
2556 syn::Pat::Ident(ident) => ident.ident.clone(),
2557 syn::Pat::Wild(_) => {
2558 let ident = format_ident!("__arg{param_index}");
2559 ident
2560 }
2561 _ => {
2562 return Err(syn::Error::new(
2563 pat_type.pat.span(),
2564 "entity RPC parameters must be simple identifiers",
2565 ))
2566 }
2567 };
2568 param_index += 1;
2569 params.push(RpcParam {
2570 name,
2571 ty: (*pat_type.ty).clone(),
2572 });
2573 }
2574 syn::FnArg::Receiver(_) => {}
2575 }
2576 }
2577
2578 if has_durable_context && matches!(kind, RpcKind::Rpc | RpcKind::Method) {
2580 return Err(syn::Error::new(
2581 method.sig.span(),
2582 "methods with `&DurableContext` must be marked #[workflow] or #[activity]",
2583 ));
2584 }
2585
2586 let response_type = match &method.sig.output {
2588 syn::ReturnType::Type(_, ty) => {
2589 if matches!(kind, RpcKind::Method) {
2590 (**ty).clone()
2592 } else {
2593 extract_result_ok_type(ty)?
2594 }
2595 }
2596 syn::ReturnType::Default => {
2597 if matches!(kind, RpcKind::Method) {
2598 syn::parse_quote!(())
2600 } else {
2601 return Err(syn::Error::new(
2602 method.sig.span(),
2603 "entity RPC methods must return Result<T, ClusterError>",
2604 ));
2605 }
2606 }
2607 };
2608
2609 if retries.is_some() && !matches!(kind, RpcKind::Activity) {
2611 return Err(syn::Error::new(
2612 method.sig.span(),
2613 "`retries` is only valid on #[activity(retries = N)]",
2614 ));
2615 }
2616
2617 Ok(Some(RpcMethod {
2618 name,
2619 tag,
2620 params,
2621 response_type,
2622 is_mut,
2623 kind,
2624 visibility,
2625 persist_key,
2626 has_durable_context,
2627 rpc_persisted,
2628 retries,
2629 backoff,
2630 }))
2631}
2632
/// Converts an identifier-like string to `snake_case`.
///
/// An underscore is inserted before an uppercase character that follows a
/// lowercase one, or that follows an uppercase one and precedes a lowercase
/// one (so `HTTPServer` becomes `http_server`). Uppercase characters are
/// lowercased; other alphanumerics and `_` are copied; everything else is
/// dropped.
fn to_snake(input: &str) -> String {
    let mut snake = String::with_capacity(input.len());
    let mut previous: Option<char> = None;
    let mut remaining = input.chars().peekable();

    while let Some(current) = remaining.next() {
        if current.is_uppercase() {
            let after_lower = previous.map_or(false, char::is_lowercase);
            let after_upper = previous.map_or(false, char::is_uppercase);
            let before_lower = remaining.peek().map_or(false, |next| next.is_lowercase());
            if after_lower || (after_upper && before_lower) {
                snake.push('_');
            }
            snake.extend(current.to_lowercase());
        } else if current.is_alphanumeric() || current == '_' {
            snake.push(current);
        }
        // Dropped characters still count as "previous", so they break up
        // runs of letters for the boundary checks above.
        previous = Some(current);
    }

    snake
}
2659
2660fn extract_result_ok_type(ty: &syn::Type) -> syn::Result<syn::Type> {
2661 if let syn::Type::Path(type_path) = ty {
2662 if let Some(segment) = type_path.path.segments.last() {
2663 if segment.ident == "Result" {
2664 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
2665 if let Some(syn::GenericArgument::Type(ok_type)) = args.args.first() {
2666 return Ok(ok_type.clone());
2667 }
2668 }
2669 }
2670 }
2671 }
2672 Err(syn::Error::new(
2673 ty.span(),
2674 "expected Result<T, ClusterError> return type",
2675 ))
2676}
2677
2678struct ActivityGroupImplArgs {
2683 krate: Option<syn::Path>,
2684}
2685
2686impl syn::parse::Parse for ActivityGroupImplArgs {
2687 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
2688 let mut args = ActivityGroupImplArgs { krate: None };
2689 while !input.is_empty() {
2690 let ident: syn::Ident = input.parse()?;
2691 match ident.to_string().as_str() {
2692 "krate" => {
2693 input.parse::<syn::Token![=]>()?;
2694 let lit: syn::LitStr = input.parse()?;
2695 args.krate = Some(lit.parse()?);
2696 }
2697 other => {
2698 return Err(syn::Error::new(
2699 ident.span(),
2700 format!("unknown activity_group_impl attribute: {other}"),
2701 ));
2702 }
2703 }
2704 if !input.is_empty() {
2705 input.parse::<syn::Token![,]>()?;
2706 }
2707 }
2708 Ok(args)
2709 }
2710}
2711
2712struct ActivityGroupActivityInfo {
2714 name: syn::Ident,
2715 params: Vec<RpcParam>,
2716 #[allow(dead_code)]
2717 response_type: syn::Type,
2718 persist_key: Option<syn::ExprClosure>,
2719 original_method: syn::ImplItemFn,
2720 retries: Option<u32>,
2722 backoff: Option<String>,
2724}
2725
2726fn activity_group_impl_inner(
2734 args: ActivityGroupImplArgs,
2735 input: syn::ItemImpl,
2736) -> syn::Result<proc_macro2::TokenStream> {
2737 let krate = args.krate.unwrap_or_else(default_crate_path);
2738 let self_ty = &input.self_ty;
2739
2740 let struct_name = match self_ty.as_ref() {
2741 syn::Type::Path(tp) => tp
2742 .path
2743 .segments
2744 .last()
2745 .map(|s| s.ident.clone())
2746 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
2747 _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
2748 };
2749
2750 let wrapper_name = format_ident!("__{}ActivityGroupWrapper", struct_name);
2751 let access_trait_name = format_ident!("__{}ActivityGroupAccess", struct_name);
2752 let methods_trait_name = format_ident!("__{}ActivityGroupMethods", struct_name);
2753 let activity_view_name = format_ident!("__{}ActivityGroupView", struct_name);
2754
2755 for attr in &input.attrs {
2757 if attr.path().is_ident("state") {
2758 return Err(syn::Error::new(
2759 attr.span(),
2760 "activity groups are stateless; remove #[state(...)]",
2761 ));
2762 }
2763 }
2764
2765 let mut activities: Vec<ActivityGroupActivityInfo> = Vec::new();
2767 let mut all_methods: Vec<syn::ImplItemFn> = Vec::new();
2768
2769 for item in &input.items {
2770 if let syn::ImplItem::Fn(method) = item {
2771 for attr in &method.attrs {
2773 if attr.path().is_ident("state") {
2774 return Err(syn::Error::new(
2775 attr.span(),
2776 "activity groups are stateless; remove #[state(...)]",
2777 ));
2778 }
2779 if attr.path().is_ident("rpc") {
2780 return Err(syn::Error::new(
2781 attr.span(),
2782 "activity groups use #[activity], not #[rpc]",
2783 ));
2784 }
2785 if attr.path().is_ident("workflow") {
2786 return Err(syn::Error::new(
2787 attr.span(),
2788 "activity groups use #[activity], not #[workflow]",
2789 ));
2790 }
2791 }
2792
2793 let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
2794
2795 if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
2797 if r.mutability.is_some() {
2798 return Err(syn::Error::new(
2799 r.span(),
2800 "activity group methods must use &self, not &mut self",
2801 ));
2802 }
2803 }
2804
2805 if is_activity {
2806 if method.sig.asyncness.is_none() {
2807 return Err(syn::Error::new(
2808 method.sig.span(),
2809 "#[activity] methods must be async",
2810 ));
2811 }
2812
2813 let (persist_key, act_retries, act_backoff) = {
2815 let mut key = None;
2816 let mut retries = None;
2817 let mut backoff = None;
2818 for attr in &method.attrs {
2819 if attr.path().is_ident("activity") {
2820 let args = match &attr.meta {
2821 syn::Meta::Path(_) => ActivityAttrArgs {
2822 key: None,
2823 retries: None,
2824 backoff: None,
2825 },
2826 syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
2827 _ => {
2828 return Err(syn::Error::new(
2829 attr.span(),
2830 "expected #[activity] or #[activity(...)]",
2831 ))
2832 }
2833 };
2834 key = args.key;
2835 retries = args.retries;
2836 backoff = args.backoff;
2837 }
2838 }
2839 (key, retries, backoff)
2840 };
2841
2842 let mut params = Vec::new();
2844 for arg in method.sig.inputs.iter().skip(1) {
2845 if let syn::FnArg::Typed(pat_type) = arg {
2846 let name = match &*pat_type.pat {
2847 syn::Pat::Ident(ident) => ident.ident.clone(),
2848 _ => {
2849 return Err(syn::Error::new(
2850 pat_type.pat.span(),
2851 "activity parameters must be simple identifiers",
2852 ))
2853 }
2854 };
2855 params.push(RpcParam {
2856 name,
2857 ty: (*pat_type.ty).clone(),
2858 });
2859 }
2860 }
2861
2862 let response_type = extract_result_ok_type(match &method.sig.output {
2863 syn::ReturnType::Type(_, ty) => ty,
2864 syn::ReturnType::Default => {
2865 return Err(syn::Error::new(
2866 method.sig.span(),
2867 "#[activity] must return Result<T, ClusterError>",
2868 ))
2869 }
2870 })?;
2871
2872 activities.push(ActivityGroupActivityInfo {
2873 name: method.sig.ident.clone(),
2874 params,
2875 response_type,
2876 persist_key,
2877 original_method: method.clone(),
2878 retries: act_retries,
2879 backoff: act_backoff,
2880 });
2881 }
2882
2883 all_methods.push(method.clone());
2884 }
2885 }
2886
2887 let mut activity_view_methods = Vec::new();
2891 let mut helper_view_methods = Vec::new();
2892
2893 for method in &all_methods {
2894 let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
2895 let block = &method.block;
2896 let output = &method.sig.output;
2897 let name = &method.sig.ident;
2898 let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
2899 let attrs: Vec<_> = method
2900 .attrs
2901 .iter()
2902 .filter(|a| {
2903 !a.path().is_ident("activity")
2904 && !a.path().is_ident("public")
2905 && !a.path().is_ident("protected")
2906 && !a.path().is_ident("private")
2907 })
2908 .collect();
2909 let vis = &method.vis;
2910
2911 if is_activity {
2912 activity_view_methods.push(quote! {
2913 #(#attrs)*
2914 #vis async fn #name(&self, #(#params),*) #output
2915 #block
2916 });
2917 } else {
2918 let async_token = if method.sig.asyncness.is_some() {
2919 quote! { async }
2920 } else {
2921 quote! {}
2922 };
2923 helper_view_methods.push(quote! {
2924 #(#attrs)*
2925 #vis #async_token fn #name(&self, #(#params),*) #output
2926 #block
2927 });
2928 }
2929 }
2930
2931 let view_struct = quote! {
2932 #[doc(hidden)]
2933 #[allow(non_camel_case_types)]
2934 pub struct #activity_view_name<'a> {
2935 __group: &'a #struct_name,
2936 pub tx: #krate::__internal::ActivityTx,
2939 pub pool: sqlx::PgPool,
2942 }
2943
2944 impl ::std::ops::Deref for #activity_view_name<'_> {
2945 type Target = #struct_name;
2946 fn deref(&self) -> &Self::Target {
2947 self.__group
2948 }
2949 }
2950
2951 impl #activity_view_name<'_> {
2952 #(#activity_view_methods)*
2953 #(#helper_view_methods)*
2954 }
2955 };
2956
2957 let wrapper_delegation_methods: Vec<proc_macro2::TokenStream> = activities
2962 .iter()
2963 .map(|act| {
2964 let method_name = &act.name;
2965 let method_name_str = method_name.to_string();
2966 let method_info = &act.original_method;
2967 let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
2968 let param_names: Vec<_> = method_info
2969 .sig
2970 .inputs
2971 .iter()
2972 .skip(1)
2973 .filter_map(|arg| {
2974 if let syn::FnArg::Typed(pat_type) = arg {
2975 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2976 return Some(&pat_ident.ident);
2977 }
2978 }
2979 None
2980 })
2981 .collect();
2982 let output = &method_info.sig.output;
2983
2984 let wire_param_names: Vec<_> = act.params.iter().map(|p| &p.name).collect();
2985 let wire_param_count = wire_param_names.len();
2986
2987 let key_bytes_code = if let Some(persist_key) = &act.persist_key {
2989 match wire_param_count {
2990 0 => quote! {
2991 let __journal_key = (#persist_key)();
2992 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
2993 .map_err(|e| #krate::error::ClusterError::PersistenceError {
2994 reason: ::std::format!("failed to serialize journal key: {e}"),
2995 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
2996 })?;
2997 },
2998 1 => {
2999 let name = &wire_param_names[0];
3000 quote! {
3001 let __journal_key = (#persist_key)(&#name);
3002 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
3003 .map_err(|e| #krate::error::ClusterError::PersistenceError {
3004 reason: ::std::format!("failed to serialize journal key: {e}"),
3005 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3006 })?;
3007 }
3008 }
3009 _ => quote! {
3010 let __journal_key = (#persist_key)(#(&#wire_param_names),*);
3011 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
3012 .map_err(|e| #krate::error::ClusterError::PersistenceError {
3013 reason: ::std::format!("failed to serialize journal key: {e}"),
3014 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3015 })?;
3016 },
3017 }
3018 } else {
3019 match wire_param_count {
3020 0 => quote! {
3021 let __journal_key_bytes = rmp_serde::to_vec(&()).unwrap_or_default();
3022 },
3023 1 => {
3024 let name = &wire_param_names[0];
3025 quote! {
3026 let __journal_key_bytes = rmp_serde::to_vec(&#name)
3027 .map_err(|e| #krate::error::ClusterError::PersistenceError {
3028 reason: ::std::format!("failed to serialize journal key: {e}"),
3029 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3030 })?;
3031 }
3032 }
3033 _ => quote! {
3034 let __journal_key_bytes = rmp_serde::to_vec(&(#(&#wire_param_names),*))
3035 .map_err(|e| #krate::error::ClusterError::PersistenceError {
3036 reason: ::std::format!("failed to serialize journal key: {e}"),
3037 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3038 })?;
3039 },
3040 }
3041 };
3042
3043 let gen_execute_and_journal = |key_bytes_var: &str, param_list: &[proc_macro2::TokenStream], in_retry: bool| -> proc_macro2::TokenStream {
3048 let key_var = format_ident!("{}", key_bytes_var);
3049 let error_handling = if in_retry {
3050 quote! {
3052 if __act_result.is_err() {
3053 drop(__activity_view);
3054 __act_result
3055 }
3056 }
3057 } else {
3058 quote! {
3060 if __act_result.is_err() {
3061 drop(__activity_view);
3062 return __act_result;
3063 }
3064 }
3065 };
3066 quote! {
3067 let __sql_pool = __wf_storage.sql_pool().cloned();
3069
3070 let __pool = __sql_pool.expect("SQL storage is required for workflow activities");
3071 let __sql_tx = __pool.begin().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
3073 reason: ::std::format!("failed to begin activity transaction: {e}"),
3074 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3075 })?;
3076 let __activity_view = #activity_view_name {
3077 __group: &self.__group,
3078 tx: #krate::__internal::ActivityTx::new(__sql_tx),
3079 pool: __pool,
3080 };
3081 let __act_result = __activity_view.#method_name(#(#param_list),*).await;
3082
3083 #error_handling
3086 else {
3087 let __storage_key = #krate::__internal::DurableContext::journal_storage_key(
3089 #method_name_str,
3090 &#key_var,
3091 __journal_ctx.entity_type(),
3092 __journal_ctx.entity_id(),
3093 );
3094 let __journal_bytes = #krate::__internal::DurableContext::serialize_journal_result(&__act_result)?;
3095 #krate::__internal::WorkflowScope::register_journal_key(__storage_key.clone());
3096
3097 let mut __tx_back = __activity_view.tx.into_inner().await;
3099 #krate::__internal::save_journal_entry(&mut *__tx_back, &__storage_key, &__journal_bytes).await?;
3100 __tx_back.commit().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
3101 reason: ::std::format!("activity transaction commit failed: {e}"),
3102 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3103 })?;
3104
3105 __act_result
3106 }
3107 }
3108 };
3109
3110 let param_tokens: Vec<proc_macro2::TokenStream> = param_names
3112 .iter()
3113 .map(|name| quote! { #name.clone() })
3114 .collect();
3115
3116 let max_retries = act.retries.unwrap_or(0);
3118 let journal_body = if max_retries == 0 {
3119 let exec_body = gen_execute_and_journal("__journal_key_bytes", ¶m_tokens, false);
3121 quote! {
3122 let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
3124 ::std::sync::Arc::clone(__engine),
3125 self.__entity_type.clone(),
3126 self.__entity_id.clone(),
3127 ::std::sync::Arc::clone(__msg_storage),
3128 ::std::sync::Arc::clone(__wf_storage),
3129 );
3130 if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal(#method_name_str, &__journal_key_bytes).await? {
3131 return ::std::result::Result::Ok(__cached);
3132 }
3133
3134 #exec_body
3136 }
3137 } else {
3138 let backoff_str = act.backoff.as_deref().unwrap_or("exponential");
3139 let param_clones: Vec<proc_macro2::TokenStream> = param_names
3141 .iter()
3142 .map(|name| {
3143 let clone_name = format_ident!("__{}_clone", name);
3144 quote! { let #clone_name = #name.clone(); }
3145 })
3146 .collect();
3147 let cloned_param_names: Vec<syn::Ident> = param_names
3148 .iter()
3149 .map(|name| format_ident!("__{}_clone", name))
3150 .collect();
3151
3152 let exec_body = gen_execute_and_journal("__retry_key_bytes", &{
3153 cloned_param_names.iter().map(|n| quote! { #n.clone() }).collect::<Vec<_>>()
3154 }, true);
3155
3156 quote! {
3157 let mut __attempt = 0u32;
3158 loop {
3159 #(#param_clones)*
3161 let __retry_key_bytes = {
3163 let mut __k = __journal_key_bytes.clone();
3164 __k.extend_from_slice(&__attempt.to_le_bytes());
3165 __k
3166 };
3167 let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
3169 ::std::sync::Arc::clone(__engine),
3170 self.__entity_type.clone(),
3171 self.__entity_id.clone(),
3172 ::std::sync::Arc::clone(__msg_storage),
3173 ::std::sync::Arc::clone(__wf_storage),
3174 );
3175 if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal::<_>(#method_name_str, &__retry_key_bytes).await? {
3176 break ::std::result::Result::Ok(__cached);
3177 }
3178 match { #exec_body } {
3180 ::std::result::Result::Ok(__val) => {
3181 break ::std::result::Result::Ok(__val);
3182 }
3183 ::std::result::Result::Err(__e) if __attempt < #max_retries => {
3184 let __delay = #krate::__internal::compute_retry_backoff(
3185 __attempt, #backoff_str, 1,
3186 );
3187 let __sleep_name = ::std::format!(
3188 "{}/retry/{}", #method_name_str, __attempt
3189 );
3190 __engine.sleep(
3191 &self.__entity_type,
3192 &self.__entity_id,
3193 &__sleep_name,
3194 __delay,
3195 ).await?;
3196 __attempt += 1;
3197 }
3198 ::std::result::Result::Err(__e) => {
3199 break ::std::result::Result::Err(__e);
3200 }
3201 }
3202 }
3203 }
3204 };
3205
3206 quote! {
3207 pub async fn #method_name(&self, #(#params),*) #output {
3208 if let (
3209 ::std::option::Option::Some(__engine),
3210 ::std::option::Option::Some(__msg_storage),
3211 ::std::option::Option::Some(__wf_storage),
3212 ) = (
3213 self.__workflow_engine.as_ref(),
3214 self.__message_storage.as_ref(),
3215 self.__workflow_storage.as_ref(),
3216 ) {
3217 #key_bytes_code
3218 let __journal_key_bytes = {
3220 let mut __scoped = ::std::vec::Vec::new();
3221 if let ::std::option::Option::Some(__wf_id) = #krate::__internal::WorkflowScope::current() {
3222 __scoped.extend_from_slice(&__wf_id.to_le_bytes());
3223 }
3224 __scoped.extend_from_slice(&__journal_key_bytes);
3225 __scoped
3226 };
3227 #journal_body
3228 } else {
3229 panic!("SQL storage is required for workflow activities; configure SqlWorkflowStorage")
3230 }
3231 }
3232 }
3233 })
3234 .collect();
3235
3236 let wrapper_struct = quote! {
3237 #[doc(hidden)]
3242 pub struct #wrapper_name {
3243 __group: #struct_name,
3244 __workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
3245 __message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
3246 __workflow_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
3247 __entity_type: ::std::string::String,
3248 __entity_id: ::std::string::String,
3249 }
3250
3251 impl #wrapper_name {
3252 pub fn new(
3254 group: #struct_name,
3255 workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
3256 message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
3257 workflow_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
3258 entity_type: ::std::string::String,
3259 entity_id: ::std::string::String,
3260 ) -> Self {
3261 Self {
3262 __group: group,
3263 __workflow_engine: workflow_engine,
3264 __message_storage: message_storage,
3265 __workflow_storage: workflow_storage,
3266 __entity_type: entity_type,
3267 __entity_id: entity_id,
3268 }
3269 }
3270
3271 pub fn group(&self) -> &#struct_name {
3273 &self.__group
3274 }
3275
3276 #(#wrapper_delegation_methods)*
3278 }
3279 };
3280
3281 let access_trait = quote! {
3283 #[doc(hidden)]
3284 pub trait #access_trait_name {
3285 fn __activity_group_wrapper(&self) -> &#wrapper_name;
3286 }
3287 };
3288
3289 let blanket_methods: Vec<proc_macro2::TokenStream> = activities
3293 .iter()
3294 .map(|act| {
3295 let method_name = &act.name;
3296 let method_info = &act.original_method;
3297 let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
3298 let param_names: Vec<_> = method_info
3299 .sig
3300 .inputs
3301 .iter()
3302 .skip(1)
3303 .filter_map(|arg| {
3304 if let syn::FnArg::Typed(pat_type) = arg {
3305 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
3306 return Some(&pat_ident.ident);
3307 }
3308 }
3309 None
3310 })
3311 .collect();
3312 let output = &method_info.sig.output;
3313
3314 quote! {
3315 async fn #method_name(&self, #(#params),*) #output {
3316 self.__activity_group_wrapper().#method_name(#(#param_names),*).await
3317 }
3318 }
3319 })
3320 .collect();
3321
3322 let methods_trait = quote! {
3323 #[doc(hidden)]
3327 #[allow(async_fn_in_trait)]
3328 pub trait #methods_trait_name: #access_trait_name {
3329 #(#blanket_methods)*
3330 }
3331
3332 impl<T: #access_trait_name> #methods_trait_name for T {}
3334 };
3335
3336 Ok(quote! {
3337 #view_struct
3338 #wrapper_struct
3339 #access_trait
3340 #methods_trait
3341 })
3342}
3343
/// Arguments accepted by the `rpc_group_impl` attribute.
struct RpcGroupImplArgs {
    /// Optional crate path override, given as `krate = "path"`; when absent the expansion
    /// falls back to `default_crate_path()`.
    krate: Option<syn::Path>,
}
3351
3352impl syn::parse::Parse for RpcGroupImplArgs {
3353 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
3354 let mut args = RpcGroupImplArgs { krate: None };
3355 while !input.is_empty() {
3356 let ident: syn::Ident = input.parse()?;
3357 match ident.to_string().as_str() {
3358 "krate" => {
3359 input.parse::<syn::Token![=]>()?;
3360 let lit: syn::LitStr = input.parse()?;
3361 args.krate = Some(lit.parse()?);
3362 }
3363 other => {
3364 return Err(syn::Error::new(
3365 ident.span(),
3366 format!("unknown rpc_group_impl attribute: {other}"),
3367 ));
3368 }
3369 }
3370 if !input.is_empty() {
3371 input.parse::<syn::Token![,]>()?;
3372 }
3373 }
3374 Ok(args)
3375 }
3376}
3377
3378fn rpc_group_impl_inner(
3388 args: RpcGroupImplArgs,
3389 input: syn::ItemImpl,
3390) -> syn::Result<proc_macro2::TokenStream> {
3391 let krate = args.krate.unwrap_or_else(default_crate_path);
3392 let self_ty = &input.self_ty;
3393
3394 let struct_name = match self_ty.as_ref() {
3395 syn::Type::Path(tp) => tp
3396 .path
3397 .segments
3398 .last()
3399 .map(|s| s.ident.clone())
3400 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
3401 _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
3402 };
3403
3404 let wrapper_name = format_ident!("__{}RpcGroupWrapper", struct_name);
3405 let access_trait_name = format_ident!("__{}RpcGroupAccess", struct_name);
3406 let methods_trait_name = format_ident!("__{}RpcGroupMethods", struct_name);
3407 let rpc_view_name = format_ident!("__{}RpcGroupView", struct_name);
3408 let client_ext_name = format_ident!("{}ClientExt", struct_name);
3409
3410 for attr in &input.attrs {
3412 if attr.path().is_ident("state") {
3413 return Err(syn::Error::new(
3414 attr.span(),
3415 "RPC groups are stateless; remove #[state(...)]",
3416 ));
3417 }
3418 }
3419
3420 let mut rpcs: Vec<RpcMethod> = Vec::new();
3422 let mut all_methods: Vec<syn::ImplItemFn> = Vec::new();
3423
3424 for item in &input.items {
3425 if let syn::ImplItem::Fn(method) = item {
3426 for attr in &method.attrs {
3428 if attr.path().is_ident("state") {
3429 return Err(syn::Error::new(
3430 attr.span(),
3431 "RPC groups are stateless; remove #[state(...)]",
3432 ));
3433 }
3434 if attr.path().is_ident("activity") {
3435 return Err(syn::Error::new(
3436 attr.span(),
3437 "RPC groups use #[rpc], not #[activity]",
3438 ));
3439 }
3440 if attr.path().is_ident("workflow") {
3441 return Err(syn::Error::new(
3442 attr.span(),
3443 "RPC groups use #[rpc], not #[workflow]",
3444 ));
3445 }
3446 }
3447
3448 if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
3450 if r.mutability.is_some() {
3451 return Err(syn::Error::new(
3452 r.span(),
3453 "RPC group methods must use &self, not &mut self",
3454 ));
3455 }
3456 }
3457
3458 let is_rpc = method.attrs.iter().any(|a| a.path().is_ident("rpc"));
3459
3460 if is_rpc {
3461 if method.sig.asyncness.is_none() {
3462 return Err(syn::Error::new(
3463 method.sig.span(),
3464 "#[rpc] methods must be async",
3465 ));
3466 }
3467
3468 if let Some(rpc) = parse_rpc_method(method)? {
3470 rpcs.push(rpc);
3471 }
3472 }
3473
3474 all_methods.push(method.clone());
3475 }
3476 }
3477
3478 let mut rpc_view_methods = Vec::new();
3482 let mut helper_view_methods = Vec::new();
3483
3484 for method in &all_methods {
3485 let is_rpc = method.attrs.iter().any(|a| a.path().is_ident("rpc"));
3486 let block = &method.block;
3487 let output = &method.sig.output;
3488 let name = &method.sig.ident;
3489 let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
3490 let attrs: Vec<_> = method
3491 .attrs
3492 .iter()
3493 .filter(|a| {
3494 !a.path().is_ident("rpc")
3495 && !a.path().is_ident("public")
3496 && !a.path().is_ident("protected")
3497 && !a.path().is_ident("private")
3498 })
3499 .collect();
3500 let vis = &method.vis;
3501
3502 if is_rpc {
3503 rpc_view_methods.push(quote! {
3504 #(#attrs)*
3505 #vis async fn #name(&self, #(#params),*) #output
3506 #block
3507 });
3508 } else {
3509 let async_token = if method.sig.asyncness.is_some() {
3510 quote! { async }
3511 } else {
3512 quote! {}
3513 };
3514 helper_view_methods.push(quote! {
3515 #(#attrs)*
3516 #vis #async_token fn #name(&self, #(#params),*) #output
3517 #block
3518 });
3519 }
3520 }
3521
3522 let view_struct = quote! {
3523 #[doc(hidden)]
3524 #[allow(non_camel_case_types)]
3525 pub struct #rpc_view_name<'a> {
3526 __group: &'a #struct_name,
3527 __entity_address: &'a #krate::types::EntityAddress,
3528 }
3529
3530 impl ::std::ops::Deref for #rpc_view_name<'_> {
3531 type Target = #struct_name;
3532 fn deref(&self) -> &Self::Target {
3533 self.__group
3534 }
3535 }
3536
3537 impl #rpc_view_name<'_> {
3538 #[inline]
3540 fn entity_id(&self) -> &str {
3541 &self.__entity_address.entity_id.0
3542 }
3543
3544 #[inline]
3546 fn entity_address(&self) -> &#krate::types::EntityAddress {
3547 self.__entity_address
3548 }
3549
3550 #(#rpc_view_methods)*
3551 #(#helper_view_methods)*
3552 }
3553 };
3554
3555 let wrapper_delegation_methods: Vec<proc_macro2::TokenStream> = rpcs
3558 .iter()
3559 .filter(|rpc| rpc.is_trait_visible())
3560 .map(|rpc| {
3561 let method_name = &rpc.name;
3562 let resp_type = &rpc.response_type;
3563 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3564 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3565 let param_defs: Vec<_> = param_names
3566 .iter()
3567 .zip(param_types.iter())
3568 .map(|(name, ty)| quote! { #name: #ty })
3569 .collect();
3570 quote! {
3571 pub async fn #method_name(
3572 &self,
3573 #(#param_defs),*
3574 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3575 let __view = #rpc_view_name { __group: &self.__group, __entity_address: &self.__entity_address };
3576 __view.#method_name(#(#param_names),*).await
3577 }
3578 }
3579 })
3580 .collect();
3581
3582 let dispatch_arms: Vec<proc_macro2::TokenStream> = rpcs
3584 .iter()
3585 .filter(|rpc| rpc.is_dispatchable())
3586 .map(|rpc| {
3587 let tag = &rpc.tag;
3588 let method_name = &rpc.name;
3589 let param_count = rpc.params.len();
3590 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3591 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3592
3593 let deserialize_request = match param_count {
3594 0 => quote! {},
3595 1 => {
3596 let name = ¶m_names[0];
3597 let ty = ¶m_types[0];
3598 quote! {
3599 let #name: #ty = rmp_serde::from_slice(payload)
3600 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3601 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3602 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3603 })?;
3604 }
3605 }
3606 _ => quote! {
3607 let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
3608 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3609 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3610 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3611 })?;
3612 },
3613 };
3614
3615 let mut call_args = Vec::new();
3616 match param_count {
3617 0 => {}
3618 1 => {
3619 let name = ¶m_names[0];
3620 call_args.push(quote! { #name });
3621 }
3622 _ => {
3623 for name in ¶m_names {
3624 call_args.push(quote! { #name });
3625 }
3626 }
3627 }
3628 let call_args = quote! { #(#call_args),* };
3629
3630 quote! {
3631 #tag => {
3632 #deserialize_request
3633 let response = self.#method_name(#call_args).await?;
3634 let bytes = rmp_serde::to_vec(&response)
3635 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3636 reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
3637 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3638 })?;
3639 ::std::result::Result::Ok(::std::option::Option::Some(bytes))
3640 }
3641 }
3642 })
3643 .collect();
3644
3645 let wrapper_struct = quote! {
3646 #[doc(hidden)]
3650 pub struct #wrapper_name {
3651 __group: #struct_name,
3652 __entity_address: #krate::types::EntityAddress,
3653 }
3654
3655 impl #wrapper_name {
3656 pub fn new(group: #struct_name, entity_address: #krate::types::EntityAddress) -> Self {
3658 Self { __group: group, __entity_address: entity_address }
3659 }
3660
3661 pub fn group(&self) -> &#struct_name {
3663 &self.__group
3664 }
3665
3666 #[doc(hidden)]
3668 pub async fn __dispatch(
3669 &self,
3670 tag: &str,
3671 payload: &[u8],
3672 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
3673 ) -> ::std::result::Result<::std::option::Option<::std::vec::Vec<u8>>, #krate::error::ClusterError> {
3674 let _ = headers;
3675 match tag {
3676 #(#dispatch_arms,)*
3677 _ => ::std::result::Result::Ok(::std::option::Option::None),
3678 }
3679 }
3680
3681 #(#wrapper_delegation_methods)*
3683 }
3684 };
3685
3686 let access_trait = quote! {
3688 #[doc(hidden)]
3689 pub trait #access_trait_name {
3690 fn __rpc_group_wrapper(&self) -> &#wrapper_name;
3691 }
3692 };
3693
3694 let blanket_methods: Vec<proc_macro2::TokenStream> = rpcs
3696 .iter()
3697 .filter(|rpc| rpc.is_trait_visible())
3698 .map(|rpc| {
3699 let method_name = &rpc.name;
3700 let resp_type = &rpc.response_type;
3701 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3702 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3703 let param_defs: Vec<_> = param_names
3704 .iter()
3705 .zip(param_types.iter())
3706 .map(|(name, ty)| quote! { #name: #ty })
3707 .collect();
3708 quote! {
3709 async fn #method_name(
3710 &self,
3711 #(#param_defs),*
3712 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3713 self.__rpc_group_wrapper().#method_name(#(#param_names),*).await
3714 }
3715 }
3716 })
3717 .collect();
3718
3719 let methods_trait = quote! {
3720 #[doc(hidden)]
3724 #[allow(async_fn_in_trait)]
3725 pub trait #methods_trait_name: #access_trait_name {
3726 #(#blanket_methods)*
3727 }
3728
3729 impl<T: #access_trait_name> #methods_trait_name for T {}
3731 };
3732
3733 let client_ext_methods: Vec<proc_macro2::TokenStream> = rpcs
3736 .iter()
3737 .filter(|rpc| rpc.is_client_visible())
3738 .map(|rpc| {
3739 let method_name = &rpc.name;
3740 let tag = &rpc.tag;
3741 let resp_type = &rpc.response_type;
3742 let param_count = rpc.params.len();
3743 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3744 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3745 let param_defs: Vec<_> = param_names
3746 .iter()
3747 .zip(param_types.iter())
3748 .map(|(name, ty)| quote! { #name: &#ty })
3749 .collect();
3750
3751 if rpc.uses_persisted_delivery() {
3752 match param_count {
3753 0 => quote! {
3754 async fn #method_name(
3755 &self,
3756 entity_id: &#krate::types::EntityId,
3757 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3758 self.entity_client()
3759 .send_persisted(entity_id, #tag, &(), #krate::schema::Uninterruptible::No)
3760 .await
3761 }
3762 },
3763 1 => {
3764 let def = ¶m_defs[0];
3765 let name = ¶m_names[0];
3766 quote! {
3767 async fn #method_name(
3768 &self,
3769 entity_id: &#krate::types::EntityId,
3770 #def,
3771 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3772 self.entity_client()
3773 .send_persisted(entity_id, #tag, #name, #krate::schema::Uninterruptible::No)
3774 .await
3775 }
3776 }
3777 }
3778 _ => quote! {
3779 async fn #method_name(
3780 &self,
3781 entity_id: &#krate::types::EntityId,
3782 #(#param_defs),*
3783 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3784 let request = (#(#param_names),*);
3785 self.entity_client()
3786 .send_persisted(entity_id, #tag, &request, #krate::schema::Uninterruptible::No)
3787 .await
3788 }
3789 },
3790 }
3791 } else {
3792 match param_count {
3793 0 => quote! {
3794 async fn #method_name(
3795 &self,
3796 entity_id: &#krate::types::EntityId,
3797 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3798 self.entity_client().send(entity_id, #tag, &()).await
3799 }
3800 },
3801 1 => {
3802 let def = ¶m_defs[0];
3803 let name = ¶m_names[0];
3804 quote! {
3805 async fn #method_name(
3806 &self,
3807 entity_id: &#krate::types::EntityId,
3808 #def,
3809 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3810 self.entity_client().send(entity_id, #tag, #name).await
3811 }
3812 }
3813 }
3814 _ => quote! {
3815 async fn #method_name(
3816 &self,
3817 entity_id: &#krate::types::EntityId,
3818 #(#param_defs),*
3819 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3820 let request = (#(#param_names),*);
3821 self.entity_client().send(entity_id, #tag, &request).await
3822 }
3823 },
3824 }
3825 }
3826 })
3827 .collect();
3828
3829 let client_ext = quote! {
3830 #[async_trait::async_trait]
3831 pub trait #client_ext_name: #krate::entity_client::EntityClientAccessor {
3832 #(#client_ext_methods)*
3833 }
3834
3835 impl<T> #client_ext_name for T where T: #krate::entity_client::EntityClientAccessor {}
3836 };
3837
3838 Ok(quote! {
3839 #view_struct
3840 #wrapper_struct
3841 #access_trait
3842 #methods_trait
3843 #client_ext
3844 })
3845}
3846
/// Arguments accepted by `#[workflow(...)]` when it is applied to a struct.
struct WorkflowStructArgs {
    /// Optional `key = |req| ...` closure; the parser rejects any non-closure expression.
    key: Option<syn::ExprClosure>,
    /// Value of the `hash = <bool>` option; the parser initialises it to `true`.
    hash: bool,
    /// Optional crate path override, given as `krate = "path"`.
    krate: Option<syn::Path>,
}
3856
3857impl syn::parse::Parse for WorkflowStructArgs {
3858 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
3859 let mut args = WorkflowStructArgs {
3860 key: None,
3861 hash: true,
3862 krate: None,
3863 };
3864
3865 while !input.is_empty() {
3866 let ident: syn::Ident = input.parse()?;
3867
3868 match ident.to_string().as_str() {
3869 "key" => {
3870 input.parse::<syn::Token![=]>()?;
3871 let expr: syn::Expr = if input.peek(syn::token::Paren) {
3872 let content;
3873 syn::parenthesized!(content in input);
3874 content.parse()?
3875 } else {
3876 input.parse()?
3877 };
3878 match expr {
3879 syn::Expr::Closure(closure) => args.key = Some(closure),
3880 _ => {
3881 return Err(syn::Error::new(
3882 expr.span(),
3883 "key must be a closure, e.g. #[workflow(key = |req| ...)]",
3884 ))
3885 }
3886 }
3887 }
3888 "hash" => {
3889 input.parse::<syn::Token![=]>()?;
3890 let lit: syn::LitBool = input.parse()?;
3891 args.hash = lit.value;
3892 }
3893 "krate" => {
3894 input.parse::<syn::Token![=]>()?;
3895 let lit: syn::LitStr = input.parse()?;
3896 args.krate = Some(lit.parse()?);
3897 }
3898 other => {
3899 return Err(syn::Error::new(
3900 ident.span(),
3901 format!("unknown workflow attribute: {other}"),
3902 ));
3903 }
3904 }
3905
3906 if !input.is_empty() {
3907 input.parse::<syn::Token![,]>()?;
3908 }
3909 }
3910
3911 Ok(args)
3912 }
3913}
3914
/// Arguments accepted by the `workflow_impl` attribute.
struct WorkflowImplArgs {
    /// Optional crate path override, given as `krate = "path"`.
    krate: Option<syn::Path>,
    /// Paths listed in `activity_groups(A, B, ...)`, in the order they were written.
    activity_groups: Vec<syn::Path>,
    /// Optional `key = |req| ...` closure; the parser rejects any non-closure expression.
    key: Option<syn::ExprClosure>,
    /// Value of the `hash = <bool>` option; the parser initialises it to `true`.
    hash: bool,
}
3923
3924impl syn::parse::Parse for WorkflowImplArgs {
3925 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
3926 let mut args = WorkflowImplArgs {
3927 krate: None,
3928 activity_groups: Vec::new(),
3929 key: None,
3930 hash: true,
3931 };
3932 while !input.is_empty() {
3933 let ident: syn::Ident = input.parse()?;
3934 match ident.to_string().as_str() {
3935 "krate" => {
3936 input.parse::<syn::Token![=]>()?;
3937 let lit: syn::LitStr = input.parse()?;
3938 args.krate = Some(lit.parse()?);
3939 }
3940 "activity_groups" => {
3941 let content;
3942 syn::parenthesized!(content in input);
3943 while !content.is_empty() {
3944 let path: syn::Path = content.parse()?;
3945 args.activity_groups.push(path);
3946 if !content.is_empty() {
3947 content.parse::<syn::Token![,]>()?;
3948 }
3949 }
3950 }
3951 "key" => {
3952 input.parse::<syn::Token![=]>()?;
3953 let expr: syn::Expr = if input.peek(syn::token::Paren) {
3954 let content;
3955 syn::parenthesized!(content in input);
3956 content.parse()?
3957 } else {
3958 input.parse()?
3959 };
3960 match expr {
3961 syn::Expr::Closure(closure) => args.key = Some(closure),
3962 _ => {
3963 return Err(syn::Error::new(
3964 expr.span(),
3965 "key must be a closure, e.g. #[workflow_impl(key = |req| ...)]",
3966 ))
3967 }
3968 }
3969 }
3970 "hash" => {
3971 input.parse::<syn::Token![=]>()?;
3972 let lit: syn::LitBool = input.parse()?;
3973 args.hash = lit.value;
3974 }
3975 other => {
3976 return Err(syn::Error::new(
3977 ident.span(),
3978 format!("unknown workflow_impl attribute: {other}"),
3979 ));
3980 }
3981 }
3982 if !input.is_empty() {
3983 input.parse::<syn::Token![,]>()?;
3984 }
3985 }
3986 Ok(args)
3987 }
3988}
3989
3990fn workflow_struct_inner(
3993 args: WorkflowStructArgs,
3994 input: syn::ItemStruct,
3995) -> syn::Result<proc_macro2::TokenStream> {
3996 let krate = args.krate.unwrap_or_else(default_crate_path);
3997 let struct_name = &input.ident;
3998 let entity_name = format!("Workflow/{}", struct_name);
3999
4000 let key_derivation_info = if let Some(_key_closure) = &args.key {
4003 let hash_val = args.hash;
4004 quote! {
4005 #[doc(hidden)]
4006 fn __workflow_key_closure() -> bool { true }
4007 #[doc(hidden)]
4008 fn __workflow_hash() -> bool { #hash_val }
4009 #[doc(hidden)]
4010 fn __extract_key<__Req>(req: &__Req) -> ::std::string::String
4011 where __Req: serde::Serialize,
4012 {
4013 let _ = req;
4016 unreachable!("key extraction is generated by workflow_impl")
4017 }
4018 }
4019 } else {
4020 quote! {
4021 #[doc(hidden)]
4022 fn __workflow_key_closure() -> bool { false }
4023 #[doc(hidden)]
4024 fn __workflow_hash() -> bool { true }
4025 }
4026 };
4027 let _ = key_derivation_info;
4030 let _ = args.key;
4031
4032 Ok(quote! {
4033 #input
4034
4035 #[allow(dead_code)]
4036 impl #struct_name {
4037 #[doc(hidden)]
4038 fn __entity_type(&self) -> #krate::types::EntityType {
4039 #krate::types::EntityType::new(#entity_name)
4040 }
4041
4042 #[doc(hidden)]
4043 fn __shard_group(&self) -> &str {
4044 "default"
4045 }
4046
4047 #[doc(hidden)]
4048 fn __shard_group_for(&self, _entity_id: &#krate::types::EntityId) -> &str {
4049 self.__shard_group()
4050 }
4051
4052 #[doc(hidden)]
4053 fn __max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
4054 ::std::option::Option::None
4055 }
4056
4057 #[doc(hidden)]
4058 fn __mailbox_capacity(&self) -> ::std::option::Option<usize> {
4059 ::std::option::Option::None
4060 }
4061
4062 #[doc(hidden)]
4063 fn __concurrency(&self) -> ::std::option::Option<usize> {
4064 ::std::option::Option::None
4065 }
4066 }
4067 })
4068}
4069
/// Information collected about one `#[activity]` method of a workflow impl block.
struct WorkflowActivityInfo {
    /// Identifier of the activity method.
    name: syn::Ident,
    /// The method name as a string.
    #[allow(dead_code)]
    tag: String,
    /// Typed parameters after the receiver; each must be a simple identifier pattern.
    params: Vec<RpcParam>,
    /// Type produced by `extract_result_ok_type` from the method's return type.
    #[allow(dead_code)]
    response_type: syn::Type,
    /// `key` closure taken from the `#[activity(...)]` arguments, if any.
    persist_key: Option<syn::ExprClosure>,
    /// Clone of the method exactly as written in the impl block.
    original_method: syn::ImplItemFn,
    /// `retries` value taken from the `#[activity(...)]` arguments, if any.
    retries: Option<u32>,
    /// `backoff` value taken from the `#[activity(...)]` arguments, if any.
    backoff: Option<String>,
}
4086
/// Information collected about a workflow's `execute` method.
struct WorkflowExecuteInfo {
    /// Parameters after the receiver; `workflow_impl_inner` requires exactly one.
    params: Vec<RpcParam>,
    /// Type of that single request parameter.
    request_type: syn::Type,
    /// Type produced by `extract_result_ok_type` from the method's return type.
    response_type: syn::Type,
    /// Clone of the `execute` method exactly as written in the impl block.
    original_method: syn::ImplItemFn,
}
4093
4094fn workflow_impl_inner(
4097 args: WorkflowImplArgs,
4098 input: syn::ItemImpl,
4099) -> syn::Result<proc_macro2::TokenStream> {
4100 let krate = args.krate.unwrap_or_else(default_crate_path);
4101 let self_ty = &input.self_ty;
4102
4103 let struct_name = match self_ty.as_ref() {
4104 syn::Type::Path(tp) => tp
4105 .path
4106 .segments
4107 .last()
4108 .map(|s| s.ident.clone())
4109 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
4110 _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
4111 };
4112
4113 let handler_name = format_ident!("__{}WorkflowHandler", struct_name);
4114 let client_name = format_ident!("{}Client", struct_name);
4115 let execute_view_name = format_ident!("__{}ExecuteView", struct_name);
4116 let activity_view_name = format_ident!("__{}ActivityView", struct_name);
4117 let entity_name = format!("Workflow/{}", struct_name);
4118 let has_activity_groups = !args.activity_groups.is_empty();
4119 let with_groups_name = format_ident!("__{}WithGroups", struct_name);
4120
4121 for attr in &input.attrs {
4123 if attr.path().is_ident("state") {
4124 return Err(syn::Error::new(
4125 attr.span(),
4126 "workflows are stateless; remove #[state(...)]",
4127 ));
4128 }
4129 }
4130
4131 let mut execute_info: Option<WorkflowExecuteInfo> = None;
4133 let mut activities: Vec<WorkflowActivityInfo> = Vec::new();
4134 let mut original_methods: Vec<syn::ImplItemFn> = Vec::new();
4135
4136 for item in &input.items {
4137 if let syn::ImplItem::Fn(method) = item {
4138 for attr in &method.attrs {
4140 if attr.path().is_ident("state") {
4141 return Err(syn::Error::new(
4142 attr.span(),
4143 "workflows are stateless; remove #[state(...)]",
4144 ));
4145 }
4146 if attr.path().is_ident("rpc") {
4147 return Err(syn::Error::new(
4148 attr.span(),
4149 "workflows use #[activity], not #[rpc]",
4150 ));
4151 }
4152 if attr.path().is_ident("workflow") {
4153 return Err(syn::Error::new(
4154 attr.span(),
4155 "workflows have a single execute entry point; use client calls for cross-workflow interaction",
4156 ));
4157 }
4158 }
4159
4160 if let Some(syn::FnArg::Receiver(r)) = method.sig.inputs.first() {
4162 if r.mutability.is_some() {
4163 return Err(syn::Error::new(
4164 r.span(),
4165 "workflow methods must use &self, not &mut self",
4166 ));
4167 }
4168 }
4169
4170 if method.sig.ident == "execute" {
4171 if execute_info.is_some() {
4173 return Err(syn::Error::new(
4174 method.sig.span(),
4175 "workflow must have exactly one execute method",
4176 ));
4177 }
4178
4179 if method.sig.asyncness.is_none() {
4180 return Err(syn::Error::new(method.sig.span(), "execute must be async"));
4181 }
4182
4183 let mut params = Vec::new();
4185 for arg in method.sig.inputs.iter().skip(1) {
4186 if let syn::FnArg::Typed(pat_type) = arg {
4187 let name = match &*pat_type.pat {
4188 syn::Pat::Ident(ident) => ident.ident.clone(),
4189 _ => {
4190 return Err(syn::Error::new(
4191 pat_type.pat.span(),
4192 "execute parameters must be simple identifiers",
4193 ))
4194 }
4195 };
4196 params.push(RpcParam {
4197 name,
4198 ty: (*pat_type.ty).clone(),
4199 });
4200 }
4201 }
4202
4203 if params.len() != 1 {
4204 return Err(syn::Error::new(
4205 method.sig.span(),
4206 "execute must take exactly one request parameter (after &self)",
4207 ));
4208 }
4209
4210 let request_type = params[0].ty.clone();
4211 let response_type = extract_result_ok_type(match &method.sig.output {
4212 syn::ReturnType::Type(_, ty) => ty,
4213 syn::ReturnType::Default => {
4214 return Err(syn::Error::new(
4215 method.sig.span(),
4216 "execute must return Result<T, ClusterError>",
4217 ))
4218 }
4219 })?;
4220
4221 execute_info = Some(WorkflowExecuteInfo {
4222 params,
4223 request_type,
4224 response_type,
4225 original_method: method.clone(),
4226 });
4227 } else {
4228 let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
4230
4231 if is_activity {
4232 if method.sig.asyncness.is_none() {
4233 return Err(syn::Error::new(
4234 method.sig.span(),
4235 "#[activity] methods must be async",
4236 ));
4237 }
4238
4239 let (persist_key, act_retries, act_backoff) = {
4241 let mut key = None;
4242 let mut retries = None;
4243 let mut backoff = None;
4244 for attr in &method.attrs {
4245 if attr.path().is_ident("activity") {
4246 let args = match &attr.meta {
4247 syn::Meta::Path(_) => ActivityAttrArgs {
4248 key: None,
4249 retries: None,
4250 backoff: None,
4251 },
4252 syn::Meta::List(_) => attr.parse_args::<ActivityAttrArgs>()?,
4253 _ => {
4254 return Err(syn::Error::new(
4255 attr.span(),
4256 "expected #[activity] or #[activity(...)]",
4257 ))
4258 }
4259 };
4260 key = args.key;
4261 retries = args.retries;
4262 backoff = args.backoff;
4263 }
4264 }
4265 (key, retries, backoff)
4266 };
4267
4268 let mut params = Vec::new();
4270 for arg in method.sig.inputs.iter().skip(1) {
4271 if let syn::FnArg::Typed(pat_type) = arg {
4272 let name = match &*pat_type.pat {
4273 syn::Pat::Ident(ident) => ident.ident.clone(),
4274 _ => {
4275 return Err(syn::Error::new(
4276 pat_type.pat.span(),
4277 "activity parameters must be simple identifiers",
4278 ))
4279 }
4280 };
4281 params.push(RpcParam {
4282 name,
4283 ty: (*pat_type.ty).clone(),
4284 });
4285 }
4286 }
4287
4288 let response_type = extract_result_ok_type(match &method.sig.output {
4289 syn::ReturnType::Type(_, ty) => ty,
4290 syn::ReturnType::Default => {
4291 return Err(syn::Error::new(
4292 method.sig.span(),
4293 "#[activity] must return Result<T, ClusterError>",
4294 ))
4295 }
4296 })?;
4297
4298 activities.push(WorkflowActivityInfo {
4299 name: method.sig.ident.clone(),
4300 tag: method.sig.ident.to_string(),
4301 params,
4302 response_type,
4303 persist_key,
4304 original_method: method.clone(),
4305 retries: act_retries,
4306 backoff: act_backoff,
4307 });
4308 }
4309 }
4310 original_methods.push(method.clone());
4311 }
4312 }
4313
4314 let execute = execute_info.ok_or_else(|| {
4315 syn::Error::new(
4316 input.self_ty.span(),
4317 "workflow must define an `async fn execute(&self, request: T) -> Result<R, ClusterError>` method",
4318 )
4319 })?;
4320
4321 let request_type = &execute.request_type;
4322 let response_type = &execute.response_type;
4323
4324 #[allow(dead_code)]
4326 struct ActivityGroupInfo {
4327 path: syn::Path,
4328 ident: syn::Ident,
4329 field: syn::Ident,
4330 wrapper_ident: syn::Ident,
4331 wrapper_path: syn::Path,
4332 access_trait_ident: syn::Ident,
4333 access_trait_path: syn::Path,
4334 methods_trait_ident: syn::Ident,
4335 methods_trait_path: syn::Path,
4336 }
4337
4338 let group_infos: Vec<ActivityGroupInfo> = args
4339 .activity_groups
4340 .iter()
4341 .map(|path| {
4342 let ident = path
4343 .segments
4344 .last()
4345 .map(|s| s.ident.clone())
4346 .expect("activity group path must have an ident");
4347 let snake = to_snake(&ident.to_string());
4348 let field = format_ident!("__group_{}", snake);
4349 let wrapper_ident = format_ident!("__{}ActivityGroupWrapper", ident);
4350 let wrapper_path = replace_last_segment(path, wrapper_ident.clone());
4351 let access_trait_ident = format_ident!("__{}ActivityGroupAccess", ident);
4352 let access_trait_path = replace_last_segment(path, access_trait_ident.clone());
4353 let methods_trait_ident = format_ident!("__{}ActivityGroupMethods", ident);
4354 let methods_trait_path = replace_last_segment(path, methods_trait_ident.clone());
4355 ActivityGroupInfo {
4356 path: path.clone(),
4357 ident,
4358 field,
4359 wrapper_ident,
4360 wrapper_path,
4361 access_trait_ident,
4362 access_trait_path,
4363 methods_trait_ident,
4364 methods_trait_path,
4365 }
4366 })
4367 .collect();
4368
4369 let execute_method = &execute.original_method;
4372 let execute_block = &execute_method.block;
4373 let execute_output = &execute_method.sig.output;
4374 let execute_param_name = &execute.params[0].name;
4375 let execute_param_type = &execute.params[0].ty;
4376 let execute_attrs: Vec<_> = execute_method
4377 .attrs
4378 .iter()
4379 .filter(|a| {
4380 !a.path().is_ident("rpc")
4381 && !a.path().is_ident("workflow")
4382 && !a.path().is_ident("activity")
4383 })
4384 .collect();
4385
4386 let mut activity_view_methods = Vec::new();
4388 for act in &activities {
4389 let method = &act.original_method;
4390 let block = &method.block;
4391 let output = &method.sig.output;
4392 let name = &act.name;
4393 let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
4394 let attrs: Vec<_> = method
4395 .attrs
4396 .iter()
4397 .filter(|a| {
4398 !a.path().is_ident("activity")
4399 && !a.path().is_ident("public")
4400 && !a.path().is_ident("protected")
4401 && !a.path().is_ident("private")
4402 })
4403 .collect();
4404 let vis = &method.vis;
4405
4406 activity_view_methods.push(quote! {
4407 #(#attrs)*
4408 #vis async fn #name(&self, #(#params),*) #output
4409 #block
4410 });
4411 }
4412
4413 let mut helper_execute_methods = Vec::new();
4415 let mut helper_activity_methods = Vec::new();
4416
4417 for method in &original_methods {
4418 let name = &method.sig.ident;
4419 if name == "execute" {
4420 continue;
4421 }
4422 let is_activity = method.attrs.iter().any(|a| a.path().is_ident("activity"));
4423 if is_activity {
4424 continue;
4425 }
4426
4427 let block = &method.block;
4429 let output = &method.sig.output;
4430 let params: Vec<_> = method.sig.inputs.iter().skip(1).collect();
4431 let attrs: Vec<_> = method
4432 .attrs
4433 .iter()
4434 .filter(|a| {
4435 !a.path().is_ident("rpc")
4436 && !a.path().is_ident("workflow")
4437 && !a.path().is_ident("activity")
4438 && !a.path().is_ident("method")
4439 && !a.path().is_ident("public")
4440 && !a.path().is_ident("protected")
4441 && !a.path().is_ident("private")
4442 })
4443 .collect();
4444 let vis = &method.vis;
4445 let async_token = if method.sig.asyncness.is_some() {
4446 quote! { async }
4447 } else {
4448 quote! {}
4449 };
4450
4451 let method_tokens = quote! {
4452 #(#attrs)*
4453 #vis #async_token fn #name(&self, #(#params),*) #output
4454 #block
4455 };
4456 helper_execute_methods.push(method_tokens.clone());
4457 helper_activity_methods.push(method_tokens);
4458 }
4459
4460 let activity_delegations: Vec<proc_macro2::TokenStream> = activities
4463 .iter()
4464 .map(|act| {
4465 let method_name = &act.name;
4466 let method_name_str = method_name.to_string();
4467 let method_info = &act.original_method;
4468 let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
4469 let param_names: Vec<_> = method_info
4470 .sig
4471 .inputs
4472 .iter()
4473 .skip(1)
4474 .filter_map(|arg| {
4475 if let syn::FnArg::Typed(pat_type) = arg {
4476 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
4477 return Some(&pat_ident.ident);
4478 }
4479 }
4480 None
4481 })
4482 .collect();
4483 let output = &method_info.sig.output;
4484
4485 let wire_param_names: Vec<_> = act.params.iter().map(|p| &p.name).collect();
4486 let wire_param_count = wire_param_names.len();
4487
4488 let key_bytes_code = if let Some(persist_key) = &act.persist_key {
4490 match wire_param_count {
4491 0 => quote! {
4492 let __journal_key = (#persist_key)();
4493 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
4494 .map_err(|e| #krate::error::ClusterError::PersistenceError {
4495 reason: ::std::format!("failed to serialize journal key: {e}"),
4496 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4497 })?;
4498 },
4499 1 => {
4500 let name = &wire_param_names[0];
4501 quote! {
4502 let __journal_key = (#persist_key)(&#name);
4503 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
4504 .map_err(|e| #krate::error::ClusterError::PersistenceError {
4505 reason: ::std::format!("failed to serialize journal key: {e}"),
4506 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4507 })?;
4508 }
4509 }
4510 _ => quote! {
4511 let __journal_key = (#persist_key)(#(&#wire_param_names),*);
4512 let __journal_key_bytes = rmp_serde::to_vec(&__journal_key)
4513 .map_err(|e| #krate::error::ClusterError::PersistenceError {
4514 reason: ::std::format!("failed to serialize journal key: {e}"),
4515 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4516 })?;
4517 },
4518 }
4519 } else {
4520 match wire_param_count {
4521 0 => quote! {
4522 let __journal_key_bytes = rmp_serde::to_vec(&()).unwrap_or_default();
4523 },
4524 1 => {
4525 let name = &wire_param_names[0];
4526 quote! {
4527 let __journal_key_bytes = rmp_serde::to_vec(&#name)
4528 .map_err(|e| #krate::error::ClusterError::PersistenceError {
4529 reason: ::std::format!("failed to serialize journal key: {e}"),
4530 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4531 })?;
4532 }
4533 }
4534 _ => quote! {
4535 let __journal_key_bytes = rmp_serde::to_vec(&(#(&#wire_param_names),*))
4536 .map_err(|e| #krate::error::ClusterError::PersistenceError {
4537 reason: ::std::format!("failed to serialize journal key: {e}"),
4538 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4539 })?;
4540 },
4541 }
4542 };
4543
4544 let gen_execute_and_journal = |key_bytes_var: &str, param_list: &[proc_macro2::TokenStream], in_retry: bool| -> proc_macro2::TokenStream {
4549 let key_var = format_ident!("{}", key_bytes_var);
4550 let error_handling = if in_retry {
4551 quote! {
4553 if __act_result.is_err() {
4554 drop(__activity_view);
4555 __act_result
4556 }
4557 }
4558 } else {
4559 quote! {
4561 if __act_result.is_err() {
4562 drop(__activity_view);
4563 return __act_result;
4564 }
4565 }
4566 };
4567 quote! {
4568 let __sql_pool = __wf_storage.sql_pool().cloned();
4570
4571 let __pool = __sql_pool.expect("SQL storage is required for workflow activities");
4572 let __sql_tx = __pool.begin().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
4574 reason: ::std::format!("failed to begin activity transaction: {e}"),
4575 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4576 })?;
4577 let __activity_view = #activity_view_name {
4578 __handler: self.__handler,
4579 tx: #krate::__internal::ActivityTx::new(__sql_tx),
4580 pool: __pool,
4581 };
4582 let __act_result = __activity_view.#method_name(#(#param_list),*).await;
4583
4584 #error_handling
4587 else {
4588 let __storage_key = #krate::__internal::DurableContext::journal_storage_key(
4590 #method_name_str,
4591 &#key_var,
4592 __journal_ctx.entity_type(),
4593 __journal_ctx.entity_id(),
4594 );
4595 let __journal_bytes = #krate::__internal::DurableContext::serialize_journal_result(&__act_result)?;
4596 #krate::__internal::WorkflowScope::register_journal_key(__storage_key.clone());
4597
4598 let mut __tx_back = __activity_view.tx.into_inner().await;
4600 #krate::__internal::save_journal_entry(&mut *__tx_back, &__storage_key, &__journal_bytes).await?;
4601 __tx_back.commit().await.map_err(|e| #krate::error::ClusterError::PersistenceError {
4602 reason: ::std::format!("activity transaction commit failed: {e}"),
4603 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4604 })?;
4605
4606 __act_result
4607 }
4608 }
4609 };
4610
4611 let param_tokens: Vec<proc_macro2::TokenStream> = param_names
4613 .iter()
4614 .map(|name| quote! { #name.clone() })
4615 .collect();
4616
4617 let max_retries = act.retries.unwrap_or(0);
4619 let journal_body = if max_retries == 0 {
4620 let exec_body = gen_execute_and_journal("__journal_key_bytes", ¶m_tokens, false);
4622 quote! {
4623 let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
4625 ::std::sync::Arc::clone(__engine),
4626 self.__handler.ctx.address.entity_type.0.clone(),
4627 self.__handler.ctx.address.entity_id.0.clone(),
4628 ::std::sync::Arc::clone(__msg_storage),
4629 ::std::sync::Arc::clone(__wf_storage),
4630 );
4631 if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal(#method_name_str, &__journal_key_bytes).await? {
4632 return ::std::result::Result::Ok(__cached);
4633 }
4634
4635 #exec_body
4637 }
4638 } else {
4639 let backoff_str = act.backoff.as_deref().unwrap_or("exponential");
4640 let param_clones: Vec<proc_macro2::TokenStream> = param_names
4642 .iter()
4643 .map(|name| {
4644 let clone_name = format_ident!("__{}_clone", name);
4645 quote! { let #clone_name = #name.clone(); }
4646 })
4647 .collect();
4648 let cloned_param_names: Vec<syn::Ident> = param_names
4649 .iter()
4650 .map(|name| format_ident!("__{}_clone", name))
4651 .collect();
4652
4653 let exec_body = gen_execute_and_journal("__retry_key_bytes", &{
4654 cloned_param_names.iter().map(|n| quote! { #n.clone() }).collect::<Vec<_>>()
4655 }, true);
4656
4657 quote! {
4658 let mut __attempt = 0u32;
4659 loop {
4660 #(#param_clones)*
4662 let __retry_key_bytes = {
4664 let mut __k = __journal_key_bytes.clone();
4665 __k.extend_from_slice(&__attempt.to_le_bytes());
4666 __k
4667 };
4668 let __journal_ctx = #krate::__internal::DurableContext::with_journal_storage(
4670 ::std::sync::Arc::clone(__engine),
4671 self.__handler.ctx.address.entity_type.0.clone(),
4672 self.__handler.ctx.address.entity_id.0.clone(),
4673 ::std::sync::Arc::clone(__msg_storage),
4674 ::std::sync::Arc::clone(__wf_storage),
4675 );
4676 if let ::std::option::Option::Some(__cached) = __journal_ctx.check_journal::<_>(#method_name_str, &__retry_key_bytes).await? {
4677 break ::std::result::Result::Ok(__cached);
4678 }
4679 match { #exec_body } {
4681 ::std::result::Result::Ok(__val) => {
4682 break ::std::result::Result::Ok(__val);
4683 }
4684 ::std::result::Result::Err(__e) if __attempt < #max_retries => {
4685 let __delay = #krate::__internal::compute_retry_backoff(
4687 __attempt, #backoff_str, 1,
4688 );
4689 let __sleep_name = ::std::format!(
4690 "{}/retry/{}", #method_name_str, __attempt
4691 );
4692 __engine.sleep(
4693 &self.__handler.ctx.address.entity_type.0,
4694 &self.__handler.ctx.address.entity_id.0,
4695 &__sleep_name,
4696 __delay,
4697 ).await?;
4698 __attempt += 1;
4699 }
4700 ::std::result::Result::Err(__e) => {
4701 break ::std::result::Result::Err(__e);
4702 }
4703 }
4704 }
4705 }
4706 };
4707
4708 quote! {
4709 #[inline]
4710 async fn #method_name(&self, #(#params),*) #output {
4711 if let (
4712 ::std::option::Option::Some(__engine),
4713 ::std::option::Option::Some(__msg_storage),
4714 ::std::option::Option::Some(__wf_storage),
4715 ) = (
4716 self.__handler.__workflow_engine.as_ref(),
4717 self.__handler.__message_storage.as_ref(),
4718 self.__handler.__state_storage.as_ref(),
4719 ) {
4720 #key_bytes_code
4721 let __journal_key_bytes = {
4723 let mut __scoped = ::std::vec::Vec::new();
4724 if let ::std::option::Option::Some(__wf_id) = #krate::__internal::WorkflowScope::current() {
4725 __scoped.extend_from_slice(&__wf_id.to_le_bytes());
4726 }
4727 __scoped.extend_from_slice(&__journal_key_bytes);
4728 __scoped
4729 };
4730 #journal_body
4731 } else {
4732 panic!("SQL storage is required for workflow activities; configure SqlWorkflowStorage")
4733 }
4734 }
4735 }
4736 })
4737 .collect();
4738
4739 let group_handler_fields: Vec<proc_macro2::TokenStream> = group_infos
4741 .iter()
4742 .map(|info| {
4743 let field = &info.field;
4744 let wrapper_path = &info.wrapper_path;
4745 quote! {
4746 #field: #wrapper_path,
4747 }
4748 })
4749 .collect();
4750
4751 let group_new_params: Vec<proc_macro2::TokenStream> = group_infos
4752 .iter()
4753 .map(|info| {
4754 let field = &info.field;
4755 let path = &info.path;
4756 quote! {
4757 #field: #path,
4758 }
4759 })
4760 .collect();
4761
4762 let group_field_inits: Vec<proc_macro2::TokenStream> = group_infos
4763 .iter()
4764 .map(|info| {
4765 let field = &info.field;
4766 let wrapper_path = &info.wrapper_path;
4767 quote! {
4768 #field: #wrapper_path::new(
4769 #field,
4770 ctx.workflow_engine.clone(),
4771 ctx.message_storage.clone(),
4772 ctx.state_storage.clone(),
4773 ctx.address.entity_type.0.clone(),
4774 ctx.address.entity_id.0.clone(),
4775 ),
4776 }
4777 })
4778 .collect();
4779
4780 let handler_def = quote! {
4782 #[doc(hidden)]
4783 pub struct #handler_name {
4784 __workflow: #struct_name,
4786 #[allow(dead_code)]
4788 ctx: #krate::entity::EntityContext,
4789 __state_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
4791 __workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
4793 __message_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::MessageStorage>>,
4795 __sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
4797 __entity_address: #krate::types::EntityAddress,
4799 #(#group_handler_fields)*
4800 }
4801
4802 impl #handler_name {
4803 #[doc(hidden)]
4804 pub async fn __new(
4805 workflow: #struct_name,
4806 #(#group_new_params)*
4807 ctx: #krate::entity::EntityContext,
4808 ) -> ::std::result::Result<Self, #krate::error::ClusterError> {
4809 let __state_storage = ctx.state_storage.clone();
4810 let __sharding = ctx.sharding.clone();
4811 let __entity_address = ctx.address.clone();
4812 ::std::result::Result::Ok(Self {
4813 __workflow: workflow,
4814 __workflow_engine: ctx.workflow_engine.clone(),
4815 __message_storage: ctx.message_storage.clone(),
4816 #(#group_field_inits)*
4817 ctx,
4818 __state_storage,
4819 __sharding,
4820 __entity_address,
4821 })
4822 }
4823
4824 pub async fn sleep(&self, name: &str, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
4826 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
4827 #krate::error::ClusterError::MalformedMessage {
4828 reason: "sleep() requires a workflow engine".into(),
4829 source: ::std::option::Option::None,
4830 }
4831 })?;
4832 let ctx = #krate::__internal::DurableContext::new(
4833 ::std::sync::Arc::clone(engine),
4834 self.ctx.address.entity_type.0.clone(),
4835 self.ctx.address.entity_id.0.clone(),
4836 );
4837 ctx.sleep(name, duration).await
4838 }
4839
4840 pub async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
4842 where
4843 T: serde::Serialize + serde::de::DeserializeOwned,
4844 K: #krate::__internal::DeferredKeyLike<T>,
4845 {
4846 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
4847 #krate::error::ClusterError::MalformedMessage {
4848 reason: "await_deferred() requires a workflow engine".into(),
4849 source: ::std::option::Option::None,
4850 }
4851 })?;
4852 let ctx = #krate::__internal::DurableContext::new(
4853 ::std::sync::Arc::clone(engine),
4854 self.ctx.address.entity_type.0.clone(),
4855 self.ctx.address.entity_id.0.clone(),
4856 );
4857 ctx.await_deferred(key).await
4858 }
4859
4860 pub async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
4862 where
4863 T: serde::Serialize,
4864 K: #krate::__internal::DeferredKeyLike<T>,
4865 {
4866 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
4867 #krate::error::ClusterError::MalformedMessage {
4868 reason: "resolve_deferred() requires a workflow engine".into(),
4869 source: ::std::option::Option::None,
4870 }
4871 })?;
4872 let ctx = #krate::__internal::DurableContext::new(
4873 ::std::sync::Arc::clone(engine),
4874 self.ctx.address.entity_type.0.clone(),
4875 self.ctx.address.entity_id.0.clone(),
4876 );
4877 ctx.resolve_deferred(key, value).await
4878 }
4879
4880 pub async fn on_interrupt(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
4882 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
4883 #krate::error::ClusterError::MalformedMessage {
4884 reason: "on_interrupt() requires a workflow engine".into(),
4885 source: ::std::option::Option::None,
4886 }
4887 })?;
4888 let ctx = #krate::__internal::DurableContext::new(
4889 ::std::sync::Arc::clone(engine),
4890 self.ctx.address.entity_type.0.clone(),
4891 self.ctx.address.entity_id.0.clone(),
4892 );
4893 ctx.on_interrupt().await
4894 }
4895
4896 pub fn execution_id(&self) -> &str {
4898 &self.__entity_address.entity_id.0
4899 }
4900
4901 pub fn entity_id(&self) -> &#krate::types::EntityId {
4903 &self.__entity_address.entity_id
4904 }
4905
4906 pub fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
4908 self.__sharding.as_ref()
4909 }
4910
4911 pub fn entity_address(&self) -> &#krate::types::EntityAddress {
4913 &self.__entity_address
4914 }
4915 }
4916 };
4917
4918 let view_structs = quote! {
4920 #[doc(hidden)]
4922 #[allow(non_camel_case_types)]
4923 struct #execute_view_name<'a> {
4924 __handler: &'a #handler_name,
4925 }
4926
4927 #[doc(hidden)]
4932 #[allow(non_camel_case_types)]
4933 struct #activity_view_name<'a> {
4934 __handler: &'a #handler_name,
4935 pub tx: #krate::__internal::ActivityTx,
4938 pub pool: sqlx::PgPool,
4941 }
4942
4943 impl ::std::ops::Deref for #execute_view_name<'_> {
4945 type Target = #struct_name;
4946 fn deref(&self) -> &Self::Target {
4947 &self.__handler.__workflow
4948 }
4949 }
4950
4951 impl ::std::ops::Deref for #activity_view_name<'_> {
4952 type Target = #struct_name;
4953 fn deref(&self) -> &Self::Target {
4954 &self.__handler.__workflow
4955 }
4956 }
4957 };
4958
4959 let group_access_impls: Vec<proc_macro2::TokenStream> = group_infos
4961 .iter()
4962 .map(|info| {
4963 let access_trait_path = &info.access_trait_path;
4964 let wrapper_path = &info.wrapper_path;
4965 let field = &info.field;
4966 quote! {
4967 impl #access_trait_path for #execute_view_name<'_> {
4968 fn __activity_group_wrapper(&self) -> &#wrapper_path {
4969 &self.__handler.#field
4970 }
4971 }
4972 }
4973 })
4974 .collect();
4975
4976 let group_use_methods: Vec<proc_macro2::TokenStream> = group_infos
4978 .iter()
4979 .map(|info| {
4980 let methods_trait_path = &info.methods_trait_path;
4981 quote! {
4982 #[allow(unused_imports)]
4983 use #methods_trait_path as _;
4984 }
4985 })
4986 .collect();
4987
4988 let execute_view_impl = quote! {
4990 #(#group_use_methods)*
4991 impl #execute_view_name<'_> {
4992 #[inline]
4994 async fn sleep(&self, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
4995 self.__handler.sleep("__wf_sleep", duration).await
4997 }
4998
4999 #[inline]
5001 async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
5002 where
5003 T: serde::Serialize + serde::de::DeserializeOwned,
5004 K: #krate::__internal::DeferredKeyLike<T>,
5005 {
5006 self.__handler.await_deferred(key).await
5007 }
5008
5009 #[inline]
5011 async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
5012 where
5013 T: serde::Serialize,
5014 K: #krate::__internal::DeferredKeyLike<T>,
5015 {
5016 self.__handler.resolve_deferred(key, value).await
5017 }
5018
5019 #[inline]
5021 async fn on_interrupt(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
5022 self.__handler.on_interrupt().await
5023 }
5024
5025 #[inline]
5027 fn execution_id(&self) -> &str {
5028 self.__handler.execution_id()
5029 }
5030
5031 #[inline]
5033 fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
5034 self.__handler.sharding()
5035 }
5036
5037 #[inline]
5042 fn client<T: #krate::entity_client::WorkflowClientFactory>(&self) -> T::Client {
5043 let sharding = self.__handler.__sharding.clone()
5044 .expect("client() requires a sharding interface");
5045 T::workflow_client(sharding)
5046 }
5047
5048 #(#activity_delegations)*
5050
5051 #(#execute_attrs)*
5053 async fn execute(&self, #execute_param_name: #execute_param_type) #execute_output
5054 #execute_block
5055
5056 #(#helper_execute_methods)*
5058 }
5059 };
5060
5061 let activity_view_impl = quote! {
5062 impl #activity_view_name<'_> {
5063 #(#activity_view_methods)*
5064 #(#helper_activity_methods)*
5065 }
5066 };
5067
5068 let dispatch_impl = quote! {
5070 #[async_trait::async_trait]
5071 impl #krate::entity::EntityHandler for #handler_name {
5072 async fn handle_request(
5073 &self,
5074 tag: &str,
5075 payload: &[u8],
5076 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
5077 ) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
5078 #[allow(unused_variables)]
5079 let headers = headers;
5080 match tag {
5081 "execute" => {
5082 let __request: #request_type = rmp_serde::from_slice(payload)
5083 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
5084 reason: ::std::format!("failed to deserialize workflow request: {e}"),
5085 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
5086 })?;
5087 let __request_id = headers
5088 .get(#krate::__internal::REQUEST_ID_HEADER_KEY)
5089 .and_then(|v| v.parse::<i64>().ok())
5090 .unwrap_or(0);
5091 let (__wf_result, __journal_keys) = #krate::__internal::WorkflowScope::run(__request_id, || async {
5092 let __view = #execute_view_name { __handler: self };
5093 __view.execute(__request).await
5094 }).await;
5095 let response = __wf_result?;
5096 if let ::std::option::Option::Some(ref __wf_storage) = self.__state_storage {
5098 for __key in &__journal_keys {
5099 let _ = __wf_storage.mark_completed(__key).await;
5100 }
5101 }
5102 rmp_serde::to_vec(&response)
5103 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
5104 reason: ::std::format!("failed to serialize workflow response: {e}"),
5105 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
5106 })
5107 }
5108 _ => ::std::result::Result::Err(
5109 #krate::error::ClusterError::MalformedMessage {
5110 reason: ::std::format!("unknown workflow tag: {tag}"),
5111 source: ::std::option::Option::None,
5112 }
5113 ),
5114 }
5115 }
5116 }
5117 };
5118
5119 let (entity_impl, register_impl) = if has_activity_groups {
5123 let group_struct_fields: Vec<proc_macro2::TokenStream> = group_infos
5124 .iter()
5125 .map(|info| {
5126 let field = &info.field;
5127 let path = &info.path;
5128 quote! { #field: #path, }
5129 })
5130 .collect();
5131
5132 let group_spawn_args: Vec<proc_macro2::TokenStream> = group_infos
5133 .iter()
5134 .map(|info| {
5135 let field = &info.field;
5136 quote! { self.#field.clone(), }
5137 })
5138 .collect();
5139
5140 let group_register_params: Vec<proc_macro2::TokenStream> = group_infos
5141 .iter()
5142 .map(|info| {
5143 let field = &info.field;
5144 let path = &info.path;
5145 quote! { #field: #path, }
5146 })
5147 .collect();
5148
5149 let group_register_field_inits: Vec<proc_macro2::TokenStream> = group_infos
5150 .iter()
5151 .map(|info| {
5152 let field = &info.field;
5153 quote! { #field, }
5154 })
5155 .collect();
5156
5157 let entity_impl_tokens = quote! {
5158 #[doc(hidden)]
5160 #[derive(Clone)]
5161 pub struct #with_groups_name {
5162 __workflow: #struct_name,
5163 #(#group_struct_fields)*
5164 }
5165
5166 #[async_trait::async_trait]
5167 impl #krate::entity::Entity for #with_groups_name {
5168 fn entity_type(&self) -> #krate::types::EntityType {
5169 self.__workflow.__entity_type()
5170 }
5171
5172 fn shard_group(&self) -> &str {
5173 self.__workflow.__shard_group()
5174 }
5175
5176 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
5177 self.__workflow.__shard_group_for(entity_id)
5178 }
5179
5180 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
5181 self.__workflow.__max_idle_time()
5182 }
5183
5184 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
5185 self.__workflow.__mailbox_capacity()
5186 }
5187
5188 fn concurrency(&self) -> ::std::option::Option<usize> {
5189 self.__workflow.__concurrency()
5190 }
5191
5192 async fn spawn(
5193 &self,
5194 ctx: #krate::entity::EntityContext,
5195 ) -> ::std::result::Result<
5196 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
5197 #krate::error::ClusterError,
5198 > {
5199 let handler = #handler_name::__new(
5200 self.__workflow.clone(),
5201 #(#group_spawn_args)*
5202 ctx,
5203 ).await?;
5204 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
5205 }
5206 }
5207 };
5208
5209 let register_impl_tokens = quote! {
5210 impl #struct_name {
5211 pub async fn register(
5215 self,
5216 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
5217 #(#group_register_params)*
5218 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
5219 let bundle = #with_groups_name {
5220 __workflow: self,
5221 #(#group_register_field_inits)*
5222 };
5223 sharding.register_entity(::std::sync::Arc::new(bundle)).await?;
5224 ::std::result::Result::Ok(#client_name::new(sharding))
5225 }
5226 }
5227 };
5228
5229 (entity_impl_tokens, register_impl_tokens)
5230 } else {
5231 let entity_impl_tokens = quote! {
5232 #[async_trait::async_trait]
5233 impl #krate::entity::Entity for #struct_name {
5234 fn entity_type(&self) -> #krate::types::EntityType {
5235 self.__entity_type()
5236 }
5237
5238 fn shard_group(&self) -> &str {
5239 self.__shard_group()
5240 }
5241
5242 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
5243 self.__shard_group_for(entity_id)
5244 }
5245
5246 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
5247 self.__max_idle_time()
5248 }
5249
5250 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
5251 self.__mailbox_capacity()
5252 }
5253
5254 fn concurrency(&self) -> ::std::option::Option<usize> {
5255 self.__concurrency()
5256 }
5257
5258 async fn spawn(
5259 &self,
5260 ctx: #krate::entity::EntityContext,
5261 ) -> ::std::result::Result<
5262 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
5263 #krate::error::ClusterError,
5264 > {
5265 let handler = #handler_name::__new(self.clone(), ctx).await?;
5266 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
5267 }
5268 }
5269 };
5270
5271 let register_impl_tokens = quote! {
5272 impl #struct_name {
5273 pub async fn register(
5275 self,
5276 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
5277 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
5278 sharding.register_entity(::std::sync::Arc::new(self)).await?;
5279 ::std::result::Result::Ok(#client_name::new(sharding))
5280 }
5281 }
5282 };
5283
5284 (entity_impl_tokens, register_impl_tokens)
5285 };
5286
5287 let struct_name_str = entity_name;
5289 let client_with_key_name = format_ident!("{}ClientWithKey", struct_name);
5290
5291 let derive_entity_id_fn = if let Some(ref key_closure) = args.key {
5293 if args.hash {
5294 quote! {
5296 fn derive_entity_id(
5297 request: &#request_type,
5298 ) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
5299 let key_value = (#key_closure)(request);
5300 let key_bytes = rmp_serde::to_vec(&key_value)
5301 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
5302 reason: ::std::format!("failed to serialize workflow key: {e}"),
5303 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
5304 })?;
5305 ::std::result::Result::Ok(#krate::types::EntityId::new(
5306 #krate::hash::sha256_hex(&key_bytes)
5307 ))
5308 }
5309 }
5310 } else {
5311 quote! {
5313 fn derive_entity_id(
5314 request: &#request_type,
5315 ) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
5316 let key_value = (#key_closure)(request);
5317 ::std::result::Result::Ok(#krate::types::EntityId::new(
5318 key_value.to_string()
5319 ))
5320 }
5321 }
5322 }
5323 } else {
5324 quote! {
5326 fn derive_entity_id(
5328 request: &#request_type,
5329 ) -> ::std::result::Result<#krate::types::EntityId, #krate::error::ClusterError> {
5330 let key_bytes = rmp_serde::to_vec(request)
5331 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
5332 reason: ::std::format!("failed to serialize workflow request: {e}"),
5333 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
5334 })?;
5335 ::std::result::Result::Ok(#krate::types::EntityId::new(
5336 #krate::hash::sha256_hex(&key_bytes)
5337 ))
5338 }
5339 }
5340 };
5341
    // Token template for the generated workflow client API: an owning client
    // struct (`#client_name`) and a borrowed, key-bound view
    // (`#client_with_key_name`). Both forward every call to the wrapped
    // `EntityClient` using the fixed message tag "execute".
    let client_impl = quote! {
        /// Typed client for this workflow.
        ///
        /// Wraps an untyped entity client and derives the target entity id
        /// from each request, so callers only deal with the request and
        /// response types.
        #[derive(Clone)]
        pub struct #client_name {
            /// Untyped entity client that all calls are forwarded to.
            inner: #krate::entity_client::EntityClient,
        }

        impl #client_name {
            /// Creates a client bound to the given sharding handle.
            ///
            /// The entity type is set to the workflow struct's name.
            pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
                Self {
                    inner: #krate::entity_client::EntityClient::new(
                        sharding,
                        #krate::types::EntityType::new(#struct_name_str),
                    ),
                }
            }

            /// Returns the underlying untyped entity client.
            pub fn inner(&self) -> &#krate::entity_client::EntityClient {
                &self.inner
            }

            /// Returns a view of this client pinned to an explicit key.
            ///
            /// The key is rendered with `Display`, and the SHA-256 hex digest
            /// of that string is used as the entity id, instead of an id
            /// derived from the request.
            pub fn with_key(&self, key: impl ::std::fmt::Display) -> #client_with_key_name<'_> {
                let key_str = key.to_string();
                let entity_id = #krate::types::EntityId::new(
                    #krate::hash::sha256_hex(key_str.as_bytes())
                );
                #client_with_key_name {
                    inner: &self.inner,
                    entity_id,
                }
            }

            /// Returns a view of this client pinned to an explicit key, used
            /// verbatim.
            ///
            /// Unlike `with_key`, the key's string form becomes the entity id
            /// directly, without hashing.
            pub fn with_key_raw(&self, key: impl ::std::string::ToString) -> #client_with_key_name<'_> {
                #client_with_key_name {
                    inner: &self.inner,
                    entity_id: #krate::types::EntityId::new(key.to_string()),
                }
            }

            // Associated `derive_entity_id(request)` function; its body is
            // selected earlier in this macro function.
            #derive_entity_id_fn

            /// Sends the request to the workflow and awaits its response.
            ///
            /// The entity id is derived from the request via
            /// `derive_entity_id`; the id's bytes are also passed as the key
            /// of the persisted message.
            ///
            /// # Errors
            ///
            /// Returns an error if the entity id cannot be derived or if the
            /// underlying send fails.
            pub async fn execute(
                &self,
                request: &#request_type,
            ) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
                let entity_id = Self::derive_entity_id(request)?;
                let key_bytes = entity_id.0.as_bytes().to_vec();
                self.inner.send_persisted_with_key(
                    &entity_id,
                    "execute",
                    request,
                    ::std::option::Option::Some(key_bytes),
                    #krate::schema::Uninterruptible::No,
                ).await
            }

            /// Submits the request as a persisted notification and returns
            /// without the workflow's response.
            ///
            /// The returned string is the derived entity id; it is the value
            /// that `poll` and `join` accept as `execution_id`.
            ///
            /// # Errors
            ///
            /// Returns an error if the entity id cannot be derived or if the
            /// underlying notify call fails.
            pub async fn start(
                &self,
                request: &#request_type,
            ) -> ::std::result::Result<::std::string::String, #krate::error::ClusterError> {
                let entity_id = Self::derive_entity_id(request)?;
                let key_bytes = entity_id.0.as_bytes().to_vec();
                self.inner.notify_persisted_with_key(
                    &entity_id,
                    "execute",
                    request,
                    ::std::option::Option::Some(key_bytes),
                ).await?;
                ::std::result::Result::Ok(entity_id.0)
            }

            /// Looks up the reply for the execution identified by
            /// `execution_id`.
            ///
            /// `execution_id` is used verbatim as the entity id and as the
            /// message key. The result is whatever the entity client's
            /// `poll_reply` returns; `None` presumably means no reply is
            /// available yet (TODO confirm against `poll_reply`).
            pub async fn poll(
                &self,
                execution_id: &str,
            ) -> ::std::result::Result<::std::option::Option<#response_type>, #krate::error::ClusterError> {
                let entity_id = #krate::types::EntityId::new(execution_id);
                let key_bytes = entity_id.0.as_bytes();
                self.inner.poll_reply::<#response_type>(
                    &entity_id,
                    "execute",
                    key_bytes,
                ).await
            }

            /// Obtains the reply for the execution identified by
            /// `execution_id`.
            ///
            /// `execution_id` is used verbatim as the entity id and as the
            /// message key. Forwards to the entity client's `join_reply`,
            /// which presumably waits until the reply exists (TODO confirm
            /// against `join_reply`).
            pub async fn join(
                &self,
                execution_id: &str,
            ) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
                let entity_id = #krate::types::EntityId::new(execution_id);
                let key_bytes = entity_id.0.as_bytes();
                self.inner.join_reply::<#response_type>(
                    &entity_id,
                    "execute",
                    key_bytes,
                ).await
            }
        }

        // Exposes the wrapped entity client through the crate's accessor trait.
        impl #krate::entity_client::EntityClientAccessor for #client_name {
            fn entity_client(&self) -> &#krate::entity_client::EntityClient {
                &self.inner
            }
        }

        /// View of the workflow client bound to one fixed entity id.
        ///
        /// Created by the client's `with_key` / `with_key_raw` methods; it
        /// borrows the parent client's entity client for the lifetime `'a`.
        pub struct #client_with_key_name<'a> {
            /// Entity client borrowed from the parent workflow client.
            inner: &'a #krate::entity_client::EntityClient,
            /// Entity id every call on this view is addressed to.
            entity_id: #krate::types::EntityId,
        }

        impl #client_with_key_name<'_> {
            /// Sends the request to the bound entity id and awaits the
            /// response. The id's bytes are passed as the message key.
            pub async fn execute(
                &self,
                request: &#request_type,
            ) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
                let key_bytes = self.entity_id.0.as_bytes().to_vec();
                self.inner.send_persisted_with_key(
                    &self.entity_id,
                    "execute",
                    request,
                    ::std::option::Option::Some(key_bytes),
                    #krate::schema::Uninterruptible::No,
                ).await
            }

            /// Submits the request to the bound entity id as a persisted
            /// notification and returns a copy of that id as a string.
            pub async fn start(
                &self,
                request: &#request_type,
            ) -> ::std::result::Result<::std::string::String, #krate::error::ClusterError> {
                let key_bytes = self.entity_id.0.as_bytes().to_vec();
                self.inner.notify_persisted_with_key(
                    &self.entity_id,
                    "execute",
                    request,
                    ::std::option::Option::Some(key_bytes),
                ).await?;
                ::std::result::Result::Ok(self.entity_id.0.clone())
            }

            /// Looks up the reply for the bound entity id.
            ///
            /// Forwards to the entity client's `poll_reply`; `None`
            /// presumably means no reply is available yet (TODO confirm
            /// against `poll_reply`).
            pub async fn poll(
                &self,
                ) -> ::std::result::Result<::std::option::Option<#response_type>, #krate::error::ClusterError> {
                let key_bytes = self.entity_id.0.as_bytes();
                self.inner.poll_reply::<#response_type>(
                    &self.entity_id,
                    "execute",
                    key_bytes,
                ).await
            }

            /// Obtains the reply for the bound entity id.
            ///
            /// Forwards to the entity client's `join_reply`, which presumably
            /// waits until the reply exists (TODO confirm against
            /// `join_reply`).
            pub async fn join(
                &self,
                ) -> ::std::result::Result<#response_type, #krate::error::ClusterError> {
                let key_bytes = self.entity_id.0.as_bytes();
                self.inner.join_reply::<#response_type>(
                    &self.entity_id,
                    "execute",
                    key_bytes,
                ).await
            }
        }
    };
5552
    // Token template implementing `WorkflowClientFactory` for the workflow
    // struct, so the generated client type can be obtained from the struct
    // type itself.
    let client_factory_impl = quote! {
        // The factory simply delegates to the generated client's `new`.
        impl #krate::entity_client::WorkflowClientFactory for #struct_name {
            type Client = #client_name;

            fn workflow_client(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> #client_name {
                #client_name::new(sharding)
            }
        }
    };
5563
    // Final macro output: every generated fragment concatenated into a single
    // token stream. `group_access_impls` is a collection and is spliced with
    // `#(...)*`; the other fragments are interpolated once each. The client
    // and its factory impl (built just above) are emitted last.
    Ok(quote! {
        #handler_def
        #view_structs
        #(#group_access_impls)*
        #execute_view_impl
        #activity_view_impl
        #dispatch_impl
        #entity_impl
        #register_impl
        #client_impl
        #client_factory_impl
    })
5576}