1use proc_macro::TokenStream;
59use quote::{format_ident, quote};
60use std::collections::HashSet;
61use syn::{parse_macro_input, spanned::Spanned};
62
63#[proc_macro_attribute]
77pub fn entity(attr: TokenStream, item: TokenStream) -> TokenStream {
78 let args = parse_macro_input!(attr as EntityArgs);
79 let input = parse_macro_input!(item as syn::ItemStruct);
80 match entity_impl_inner(args, input) {
81 Ok(tokens) => tokens.into(),
82 Err(e) => e.to_compile_error().into(),
83 }
84}
85
86#[proc_macro_attribute]
136pub fn entity_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
137 let args = parse_macro_input!(attr as ImplArgs);
138 let input = parse_macro_input!(item as syn::ItemImpl);
139 match entity_impl_block_inner(args, input) {
140 Ok(tokens) => tokens.into(),
141 Err(e) => e.to_compile_error().into(),
142 }
143}
144
145#[proc_macro_attribute]
149pub fn entity_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
150 let _args = parse_macro_input!(attr as TraitArgs);
151 let input = parse_macro_input!(item as syn::ItemStruct);
152 quote! { #input }.into()
153}
154
155#[proc_macro_attribute]
202pub fn entity_trait_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
203 let args = parse_macro_input!(attr as TraitArgs);
204 let input = parse_macro_input!(item as syn::ItemImpl);
205 match entity_trait_impl_inner(args, input) {
206 Ok(tokens) => tokens.into(),
207 Err(e) => e.to_compile_error().into(),
208 }
209}
210
211#[proc_macro_attribute]
269pub fn state(_attr: TokenStream, item: TokenStream) -> TokenStream {
270 item
273}
274
275#[proc_macro_attribute]
316pub fn rpc(_attr: TokenStream, item: TokenStream) -> TokenStream {
317 item
318}
319
320#[proc_macro_attribute]
390pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream {
391 item
392}
393
394#[proc_macro_attribute]
453pub fn activity(_attr: TokenStream, item: TokenStream) -> TokenStream {
454 item
455}
456
457#[proc_macro_attribute]
483pub fn public(_attr: TokenStream, item: TokenStream) -> TokenStream {
484 item
485}
486
487#[proc_macro_attribute]
515pub fn protected(_attr: TokenStream, item: TokenStream) -> TokenStream {
516 item
517}
518
519#[proc_macro_attribute]
548pub fn private(_attr: TokenStream, item: TokenStream) -> TokenStream {
549 item
550}
551
552#[proc_macro_attribute]
588pub fn method(_attr: TokenStream, item: TokenStream) -> TokenStream {
589 item
590}
591
592struct EntityArgs {
595 name: Option<String>,
596 shard_group: Option<String>,
597 max_idle_time_secs: Option<u64>,
598 mailbox_capacity: Option<usize>,
599 concurrency: Option<usize>,
600 krate: Option<syn::Path>,
601}
602
603impl syn::parse::Parse for EntityArgs {
604 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
605 let mut args = EntityArgs {
606 name: None,
607 shard_group: None,
608 max_idle_time_secs: None,
609 mailbox_capacity: None,
610 concurrency: None,
611 krate: None,
612 };
613
614 while !input.is_empty() {
615 let ident: syn::Ident = input.parse()?;
616 input.parse::<syn::Token![=]>()?;
617
618 match ident.to_string().as_str() {
619 "name" => {
620 let lit: syn::LitStr = input.parse()?;
621 args.name = Some(lit.value());
622 }
623 "shard_group" => {
624 let lit: syn::LitStr = input.parse()?;
625 args.shard_group = Some(lit.value());
626 }
627 "max_idle_time_secs" => {
628 let lit: syn::LitInt = input.parse()?;
629 args.max_idle_time_secs = Some(lit.base10_parse()?);
630 }
631 "mailbox_capacity" => {
632 let lit: syn::LitInt = input.parse()?;
633 args.mailbox_capacity = Some(lit.base10_parse()?);
634 }
635 "concurrency" => {
636 let lit: syn::LitInt = input.parse()?;
637 args.concurrency = Some(lit.base10_parse()?);
638 }
639 "krate" => {
640 let lit: syn::LitStr = input.parse()?;
641 args.krate = Some(lit.parse()?);
642 }
643 other => {
644 return Err(syn::Error::new(
645 ident.span(),
646 format!("unknown entity attribute: {other}"),
647 ));
648 }
649 }
650
651 if !input.is_empty() {
652 input.parse::<syn::Token![,]>()?;
653 }
654 }
655
656 Ok(args)
657 }
658}
659
660struct ImplArgs {
661 krate: Option<syn::Path>,
662 traits: Vec<syn::Path>,
663 deferred_keys: Vec<DeferredKeyDecl>,
664}
665
666struct DeferredKeyDecl {
667 ident: syn::Ident,
668 ty: syn::Type,
669 name: syn::LitStr,
670}
671
672impl syn::parse::Parse for DeferredKeyDecl {
673 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
674 let ident: syn::Ident = input.parse()?;
675 input.parse::<syn::Token![:]>()?;
676 let ty: syn::Type = input.parse()?;
677 if !input.peek(syn::Token![=]) {
678 return Err(syn::Error::new(
679 input.span(),
680 "expected `= \"name\"` for deferred key",
681 ));
682 }
683 input.parse::<syn::Token![=]>()?;
684 let name: syn::LitStr = input.parse()?;
685 Ok(DeferredKeyDecl { ident, ty, name })
686 }
687}
688
689impl syn::parse::Parse for ImplArgs {
690 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
691 let mut args = ImplArgs {
692 krate: None,
693 traits: Vec::new(),
694 deferred_keys: Vec::new(),
695 };
696 while !input.is_empty() {
697 let ident: syn::Ident = input.parse()?;
698 match ident.to_string().as_str() {
699 "krate" => {
700 input.parse::<syn::Token![=]>()?;
701 let lit: syn::LitStr = input.parse()?;
702 args.krate = Some(lit.parse()?);
703 }
704 "traits" => {
705 let content;
706 syn::parenthesized!(content in input);
707 while !content.is_empty() {
708 let path: syn::Path = content.parse()?;
709 args.traits.push(path);
710 if !content.is_empty() {
711 content.parse::<syn::Token![,]>()?;
712 }
713 }
714 }
715 "deferred_keys" => {
716 let content;
717 syn::parenthesized!(content in input);
718 let decls = content.parse_terminated(DeferredKeyDecl::parse, syn::Token![,])?;
719 args.deferred_keys.extend(decls);
720 }
721 other => {
722 return Err(syn::Error::new(
723 ident.span(),
724 format!("unknown entity_impl attribute: {other}"),
725 ));
726 }
727 }
728 if !input.is_empty() {
729 input.parse::<syn::Token![,]>()?;
730 }
731 }
732 Ok(args)
733 }
734}
735
736struct TraitArgs {
737 krate: Option<syn::Path>,
738}
739
740impl syn::parse::Parse for TraitArgs {
741 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
742 let mut args = TraitArgs { krate: None };
743 while !input.is_empty() {
744 let ident: syn::Ident = input.parse()?;
745 input.parse::<syn::Token![=]>()?;
746 match ident.to_string().as_str() {
747 "krate" => {
748 let lit: syn::LitStr = input.parse()?;
749 args.krate = Some(lit.parse()?);
750 }
751 other => {
752 return Err(syn::Error::new(
753 ident.span(),
754 format!("unknown entity_trait attribute: {other}"),
755 ));
756 }
757 }
758 if !input.is_empty() {
759 input.parse::<syn::Token![,]>()?;
760 }
761 }
762 Ok(args)
763 }
764}
765
766fn default_crate_path() -> syn::Path {
767 syn::parse_str("cruster").unwrap()
768}
769
770fn replace_last_segment(path: &syn::Path, ident: syn::Ident) -> syn::Path {
771 let mut new_path = path.clone();
772 if let Some(last) = new_path.segments.last_mut() {
773 last.ident = ident;
774 last.arguments = syn::PathArguments::None;
775 }
776 new_path
777}
778
779struct TraitInfo {
780 path: syn::Path,
781 ident: syn::Ident,
782 field: syn::Ident,
783 wrapper_path: syn::Path,
784 state_info_path: syn::Path,
785 state_init_path: syn::Path,
786 missing_reason: String,
787}
788
789fn trait_infos_from_paths(paths: &[syn::Path]) -> Vec<TraitInfo> {
790 paths
791 .iter()
792 .map(|path| {
793 let ident = path
794 .segments
795 .last()
796 .expect("trait path missing segment")
797 .ident
798 .clone();
799 let field = format_ident!("__trait_{}", to_snake(&ident.to_string()));
800 let wrapper_ident = format_ident!("{}StateWrapper", ident);
801 let state_info_ident = format_ident!("__{}TraitStateInfo", ident);
802 let state_init_ident = format_ident!("__{}TraitStateInit", ident);
803 let wrapper_path = replace_last_segment(path, wrapper_ident);
804 let state_info_path = replace_last_segment(path, state_info_ident);
805 let state_init_path = replace_last_segment(path, state_init_ident);
806 let missing_reason = format!("missing trait dependency: {ident}");
807 TraitInfo {
808 path: path.clone(),
809 ident,
810 field,
811 wrapper_path,
812 state_info_path,
813 state_init_path,
814 missing_reason,
815 }
816 })
817 .collect()
818}
819
820fn entity_impl_inner(
823 args: EntityArgs,
824 input: syn::ItemStruct,
825) -> syn::Result<proc_macro2::TokenStream> {
826 let krate = args.krate.clone().unwrap_or_else(default_crate_path);
827 let struct_name = &input.ident;
828 let entity_name = args.name.unwrap_or_else(|| struct_name.to_string());
829 let shard_group_value = if let Some(sg) = &args.shard_group {
830 quote! { #sg }
831 } else {
832 quote! { "default" }
833 };
834 let max_idle_value = if let Some(secs) = args.max_idle_time_secs {
835 quote! { ::std::option::Option::Some(::std::time::Duration::from_secs(#secs)) }
836 } else {
837 quote! { ::std::option::Option::None }
838 };
839 let mailbox_value = if let Some(cap) = args.mailbox_capacity {
840 quote! { ::std::option::Option::Some(#cap) }
841 } else {
842 quote! { ::std::option::Option::None }
843 };
844 let concurrency_value = if let Some(c) = args.concurrency {
845 quote! { ::std::option::Option::Some(#c) }
846 } else {
847 quote! { ::std::option::Option::None }
848 };
849
850 Ok(quote! {
851 #input
852
853 #[allow(dead_code)]
854 impl #struct_name {
855 #[doc(hidden)]
856 fn __entity_type(&self) -> #krate::types::EntityType {
857 #krate::types::EntityType::new(#entity_name)
858 }
859
860 #[doc(hidden)]
861 fn __shard_group(&self) -> &str {
862 #shard_group_value
863 }
864
865 #[doc(hidden)]
866 fn __shard_group_for(&self, _entity_id: &#krate::types::EntityId) -> &str {
867 self.__shard_group()
868 }
869
870 #[doc(hidden)]
871 fn __max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
872 #max_idle_value
873 }
874
875 #[doc(hidden)]
876 fn __mailbox_capacity(&self) -> ::std::option::Option<usize> {
877 #mailbox_value
878 }
879
880 #[doc(hidden)]
881 fn __concurrency(&self) -> ::std::option::Option<usize> {
882 #concurrency_value
883 }
884 }
885 })
886}
887
/// Kind of an annotated entity method, one variant per `#[rpc]`,
/// `#[workflow]`, `#[activity]` and `#[method]` marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RpcKind {
    /// `#[rpc]`
    Rpc,
    /// `#[workflow]`
    Workflow,
    /// `#[activity]`
    Activity,
    /// `#[method]`
    Method,
}
897
/// Visibility of an annotated entity method, one variant per `#[public]`,
/// `#[protected]` and `#[private]` marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RpcVisibility {
    /// `#[public]`
    Public,
    /// `#[protected]`
    Protected,
    /// `#[private]`
    Private,
}
904
905impl RpcKind {
906 fn is_persisted(&self) -> bool {
907 matches!(self, RpcKind::Workflow | RpcKind::Activity)
908 }
909}
910
911impl RpcVisibility {
912 fn is_public(&self) -> bool {
913 matches!(self, RpcVisibility::Public)
914 }
915
916 fn is_private(&self) -> bool {
917 matches!(self, RpcVisibility::Private)
918 }
919}
920
921struct RpcMethod {
922 name: syn::Ident,
923 tag: String,
924 params: Vec<RpcParam>,
925 response_type: syn::Type,
926 is_mut: bool,
927 kind: RpcKind,
928 visibility: RpcVisibility,
929 persist_key: Option<syn::ExprClosure>,
931 has_durable_context: bool,
933}
934
935impl RpcMethod {
936 fn is_dispatchable(&self) -> bool {
937 self.visibility.is_public() && !matches!(self.kind, RpcKind::Activity | RpcKind::Method)
938 }
939
940 fn is_client_visible(&self) -> bool {
941 self.visibility.is_public() && !matches!(self.kind, RpcKind::Activity | RpcKind::Method)
942 }
943
944 fn is_trait_visible(&self) -> bool {
945 !self.visibility.is_private() && !matches!(self.kind, RpcKind::Method)
946 }
947}
948
949struct RpcParam {
950 name: syn::Ident,
951 ty: syn::Type,
952}
953
954fn entity_impl_block_inner(
955 args: ImplArgs,
956 input: syn::ItemImpl,
957) -> syn::Result<proc_macro2::TokenStream> {
958 let krate = args.krate.unwrap_or_else(default_crate_path);
959 let traits = args.traits;
960 let deferred_keys = args.deferred_keys;
961 let mut input = input;
962 let state_info = parse_state_attr(&mut input.attrs)?;
963 let self_ty = &input.self_ty;
964
965 let struct_name = match self_ty.as_ref() {
966 syn::Type::Path(tp) => tp
967 .path
968 .segments
969 .last()
970 .map(|s| s.ident.clone())
971 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected struct name"))?,
972 _ => return Err(syn::Error::new(self_ty.span(), "expected struct name")),
973 };
974
975 let handler_name = format_ident!("{}Handler", struct_name);
976 let client_name = format_ident!("{}Client", struct_name);
977
978 let state_type: Option<syn::Type> = state_info.as_ref().map(|info| info.ty.clone());
980 let state_persisted = state_info
981 .as_ref()
982 .map(|info| info.persistent)
983 .unwrap_or(false);
984 let mut has_init = false;
985 let mut rpcs = Vec::new();
986 let mut original_methods = Vec::new();
987 let mut init_method: Option<syn::ImplItemFn> = None;
988
989 for item in &input.items {
990 match item {
991 syn::ImplItem::Type(type_item) if type_item.ident == "State" => {
992 return Err(syn::Error::new(
993 type_item.span(),
994 "use #[state(Type)] on the impl block instead of `type State`",
995 ));
996 }
997 syn::ImplItem::Fn(method) => {
998 let has_rpc_attrs = parse_kind_attr(&method.attrs)?.is_some()
999 || parse_visibility_attr(&method.attrs)?.is_some();
1000 if method.sig.ident == "init" && method.sig.asyncness.is_none() {
1001 if has_rpc_attrs {
1002 return Err(syn::Error::new(
1003 method.sig.span(),
1004 "RPC annotations are only valid on async methods",
1005 ));
1006 }
1007 has_init = true;
1008 init_method = Some(method.clone());
1009 } else if method.sig.asyncness.is_some() {
1010 if let Some(rpc) = parse_rpc_method(method)? {
1011 rpcs.push(rpc);
1012 }
1013 } else if has_rpc_attrs {
1014 return Err(syn::Error::new(
1015 method.sig.span(),
1016 "RPC annotations are only valid on async methods",
1017 ));
1018 }
1019 original_methods.push(method.clone());
1020 }
1021 _ => {}
1022 }
1023 }
1024
1025 let is_stateful = state_type.is_some();
1026
1027 if is_stateful && !has_init {
1028 return Err(syn::Error::new(
1029 input.self_ty.span(),
1030 "stateful entities with #[state(...)] must define `fn init(&self, ctx: &EntityContext) -> Result<State, ClusterError>`",
1031 ));
1032 }
1033
1034 if has_init && !is_stateful {
1035 return Err(syn::Error::new(
1036 init_method.as_ref().unwrap().sig.span(),
1037 "`fn init` is only valid when #[state(...)] is also defined",
1038 ));
1039 }
1040
1041 let entity_tokens = if is_stateful {
1042 generate_stateful_entity(
1043 &krate,
1044 &struct_name,
1045 &handler_name,
1046 &client_name,
1047 state_type.as_ref().unwrap(),
1048 state_persisted,
1049 &traits,
1050 &rpcs,
1051 &original_methods,
1052 )?
1053 } else {
1054 generate_stateless_entity(
1055 &krate,
1056 &struct_name,
1057 &handler_name,
1058 &client_name,
1059 &traits,
1060 &rpcs,
1061 &original_methods,
1062 )?
1063 };
1064
1065 let deferred_consts = generate_deferred_key_consts(&krate, &deferred_keys)?;
1066
1067 Ok(quote! {
1068 #entity_tokens
1069 #deferred_consts
1070 })
1071}
1072
1073fn generate_deferred_key_consts(
1074 krate: &syn::Path,
1075 deferred_keys: &[DeferredKeyDecl],
1076) -> syn::Result<proc_macro2::TokenStream> {
1077 if deferred_keys.is_empty() {
1078 return Ok(quote! {});
1079 }
1080
1081 let mut seen_idents = HashSet::new();
1082 let mut seen_names = HashSet::new();
1083 for decl in deferred_keys {
1084 let ident = decl.ident.to_string();
1085 if !seen_idents.insert(ident.clone()) {
1086 return Err(syn::Error::new(
1087 decl.ident.span(),
1088 format!("duplicate deferred key constant: {ident}"),
1089 ));
1090 }
1091 let name = decl.name.value();
1092 if !seen_names.insert(name.clone()) {
1093 return Err(syn::Error::new(
1094 decl.name.span(),
1095 format!("duplicate deferred key name: {name}"),
1096 ));
1097 }
1098 }
1099
1100 let consts: Vec<_> = deferred_keys
1101 .iter()
1102 .map(|decl| {
1103 let ident = &decl.ident;
1104 let ty = &decl.ty;
1105 let name = &decl.name;
1106 quote! {
1107 #[allow(dead_code)]
1108 pub const #ident: #krate::__internal::DeferredKey<#ty> =
1109 #krate::__internal::DeferredKey::new(#name);
1110 }
1111 })
1112 .collect();
1113
1114 Ok(quote! {
1115 #(#consts)*
1116 })
1117}
1118
1119fn entity_trait_impl_inner(
1120 args: TraitArgs,
1121 input: syn::ItemImpl,
1122) -> syn::Result<proc_macro2::TokenStream> {
1123 let krate = args.krate.unwrap_or_else(default_crate_path);
1124 let mut input = input;
1125 let state_info = parse_state_attr(&mut input.attrs)?;
1126 let self_ty = &input.self_ty;
1127
1128 let trait_ident = match self_ty.as_ref() {
1129 syn::Type::Path(tp) => tp
1130 .path
1131 .segments
1132 .last()
1133 .map(|s| s.ident.clone())
1134 .ok_or_else(|| syn::Error::new(self_ty.span(), "expected trait struct name"))?,
1135 _ => {
1136 return Err(syn::Error::new(
1137 self_ty.span(),
1138 "expected trait struct name",
1139 ))
1140 }
1141 };
1142
1143 let mut rpcs = Vec::new();
1144 let state_type: Option<syn::Type> = state_info.as_ref().map(|info| info.ty.clone());
1145 let state_persisted = state_info
1146 .as_ref()
1147 .map(|info| info.persistent)
1148 .unwrap_or(false);
1149 let mut has_init = false;
1150 let mut init_method: Option<syn::ImplItemFn> = None;
1151 let mut original_methods = Vec::new();
1152
1153 for item in &input.items {
1154 match item {
1155 syn::ImplItem::Type(type_item) if type_item.ident == "State" => {
1156 return Err(syn::Error::new(
1157 type_item.span(),
1158 "use #[state(Type, persistent)] on the impl block instead of `type State`",
1159 ));
1160 }
1161 syn::ImplItem::Fn(method) => {
1162 let has_rpc_attrs = parse_kind_attr(&method.attrs)?.is_some()
1163 || parse_visibility_attr(&method.attrs)?.is_some();
1164 if method.sig.ident == "init" && method.sig.asyncness.is_none() {
1165 if has_rpc_attrs {
1166 return Err(syn::Error::new(
1167 method.sig.span(),
1168 "RPC annotations are only valid on async methods",
1169 ));
1170 }
1171 has_init = true;
1172 init_method = Some(method.clone());
1173 } else if method.sig.asyncness.is_some() {
1174 if let Some(rpc) = parse_rpc_method(method)? {
1175 if rpc.has_durable_context {
1176 return Err(syn::Error::new(
1177 method.sig.span(),
1178 "DurableContext parameters are not supported in entity traits",
1179 ));
1180 }
1181 rpcs.push(rpc);
1182 }
1183 } else if has_rpc_attrs {
1184 return Err(syn::Error::new(
1185 method.sig.span(),
1186 "RPC annotations are only valid on async methods",
1187 ));
1188 }
1189 original_methods.push(method.clone());
1190 }
1191 _ => {}
1192 }
1193 }
1194
1195 let is_stateful = state_type.is_some();
1196
1197 if is_stateful && !state_persisted {
1198 return Err(syn::Error::new(
1199 input.self_ty.span(),
1200 "entity trait state must be declared as #[state(Type, persistent)]",
1201 ));
1202 }
1203
1204 if is_stateful && !has_init {
1205 return Err(syn::Error::new(
1206 input.self_ty.span(),
1207 "entity traits with #[state(...)] must define `fn init(&self) -> Result<State, ClusterError>`",
1208 ));
1209 }
1210
1211 if has_init && !is_stateful {
1212 return Err(syn::Error::new(
1213 init_method.as_ref().unwrap().sig.span(),
1214 "`fn init` is only valid when #[state(...)] is also defined",
1215 ));
1216 }
1217
1218 let mut cleaned_impl = input.clone();
1219 cleaned_impl.attrs.retain(|a| !a.path().is_ident("state"));
1220 for item in &mut cleaned_impl.items {
1221 if let syn::ImplItem::Fn(method) = item {
1222 method.attrs.retain(|a| {
1223 !a.path().is_ident("rpc")
1224 && !a.path().is_ident("workflow")
1225 && !a.path().is_ident("activity")
1226 && !a.path().is_ident("method")
1227 && !a.path().is_ident("public")
1228 && !a.path().is_ident("protected")
1229 && !a.path().is_ident("private")
1230 });
1231 }
1232 }
1233 cleaned_impl.items.retain(|item| match item {
1234 syn::ImplItem::Type(_) => false,
1235 syn::ImplItem::Fn(method) => method.sig.asyncness.is_none() && method.sig.ident != "init",
1236 _ => true,
1237 });
1238
1239 let wrapper_name = format_ident!("{}StateWrapper", trait_ident);
1240 let state_type = state_type.unwrap_or_else(|| syn::parse_str("()").unwrap());
1241
1242 let init_takes_ctx = if let Some(init_method) = init_method.as_ref() {
1243 match init_method.sig.inputs.len() {
1244 1 => false,
1245 2 => true,
1246 _ => {
1247 return Err(syn::Error::new(
1248 init_method.sig.span(),
1249 "trait init must take either `&self` or `&self, &EntityContext`",
1250 ))
1251 }
1252 }
1253 } else {
1254 false
1255 };
1256
1257 let init_call = if is_stateful {
1258 if init_takes_ctx {
1259 quote! { self.init(ctx) }
1260 } else {
1261 quote! {
1262 let _ = ctx;
1263 self.init()
1264 }
1265 }
1266 } else {
1267 quote! {
1268 let _ = ctx;
1269 ::std::result::Result::Ok(())
1270 }
1271 };
1272
1273 let init_method_impl = if is_stateful {
1274 let init_method = init_method.as_ref().unwrap();
1275 let init_sig = &init_method.sig;
1276 let init_block = &init_method.block;
1277 let init_attrs = &init_method.attrs;
1278 let init_vis = &init_method.vis;
1279 quote! {
1280 impl #trait_ident {
1281 #(#init_attrs)*
1282 #init_vis #init_sig #init_block
1283 }
1284 }
1285 } else {
1286 quote! {}
1287 };
1288
1289 let read_view_name = format_ident!("__{}ReadView", wrapper_name);
1291 let mut_view_name = format_ident!("__{}MutView", wrapper_name);
1292
1293 let mut read_methods = Vec::new();
1294 let mut mut_methods = Vec::new();
1295 let mut wrapper_methods = Vec::new();
1296
1297 for method in original_methods
1298 .iter()
1299 .filter(|m| m.sig.asyncness.is_some())
1300 {
1301 let method_name = &method.sig.ident;
1302 let rpc_info = rpcs.iter().find(|r| r.name == *method_name);
1303 let is_activity = rpc_info.is_some_and(|r| matches!(r.kind, RpcKind::Activity));
1304
1305 let block = &method.block;
1306 let output = &method.sig.output;
1307 let generics = &method.sig.generics;
1308 let where_clause = &generics.where_clause;
1309 let attrs: Vec<_> = method
1310 .attrs
1311 .iter()
1312 .filter(|a| {
1313 !a.path().is_ident("rpc")
1314 && !a.path().is_ident("workflow")
1315 && !a.path().is_ident("activity")
1316 && !a.path().is_ident("method")
1317 && !a.path().is_ident("public")
1318 && !a.path().is_ident("protected")
1319 && !a.path().is_ident("private")
1320 })
1321 .collect();
1322 let vis = &method.vis;
1323
1324 let params: Vec<_> = method.sig.inputs.iter().skip(1).cloned().collect();
1326 let param_names: Vec<_> = method
1327 .sig
1328 .inputs
1329 .iter()
1330 .skip(1)
1331 .filter_map(|arg| {
1332 if let syn::FnArg::Typed(pat_type) = arg {
1333 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
1334 return Some(pat_ident.ident.clone());
1335 }
1336 }
1337 None
1338 })
1339 .collect();
1340
1341 if is_activity {
1342 mut_methods.push(quote! {
1344 #(#attrs)*
1345 #vis async fn #method_name #generics (&mut self, #(#params),*) #output #where_clause
1346 #block
1347 });
1348
1349 wrapper_methods.push(quote! {
1351 #(#attrs)*
1352 #vis async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
1353 let __lock = self.__write_lock.clone().lock_owned().await;
1354 let mut __guard = #krate::__internal::TraitStateMutGuard::new(
1355 self.__state.clone(),
1356 __lock,
1357 );
1358 let mut __view = #mut_view_name {
1359 __wrapper: self,
1360 state: &mut *__guard,
1361 };
1362 __view.#method_name(#(#param_names),*).await
1363 }
1364 });
1365 } else {
1366 read_methods.push(quote! {
1368 #(#attrs)*
1369 #vis async fn #method_name #generics (&self, #(#params),*) #output #where_clause
1370 #block
1371 });
1372
1373 wrapper_methods.push(quote! {
1375 #(#attrs)*
1376 #vis async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
1377 let __guard = self.__state.load();
1378 let __view = #read_view_name {
1379 __wrapper: self,
1380 state: &**__guard,
1381 };
1382 __view.#method_name(#(#param_names),*).await
1383 }
1384 });
1385 }
1386 }
1387
1388 let activity_delegations: Vec<proc_macro2::TokenStream> = rpcs
1390 .iter()
1391 .filter(|rpc| matches!(rpc.kind, RpcKind::Activity))
1392 .map(|rpc| {
1393 let method_name = &rpc.name;
1394 let method_info = original_methods
1395 .iter()
1396 .find(|m| m.sig.ident == *method_name)
1397 .unwrap();
1398 let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
1399 let param_names: Vec<_> = method_info
1400 .sig
1401 .inputs
1402 .iter()
1403 .skip(1)
1404 .filter_map(|arg| {
1405 if let syn::FnArg::Typed(pat_type) = arg {
1406 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
1407 return Some(&pat_ident.ident);
1408 }
1409 }
1410 None
1411 })
1412 .collect();
1413 let output = &method_info.sig.output;
1414 let generics = &method_info.sig.generics;
1415 let where_clause = &generics.where_clause;
1416
1417 quote! {
1418 #[inline]
1419 async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
1420 self.__wrapper.#method_name(#(#param_names),*).await
1421 }
1422 }
1423 })
1424 .collect();
1425
1426 let view_structs = quote! {
1427 #[doc(hidden)]
1428 #[allow(non_camel_case_types)]
1429 struct #read_view_name<'a> {
1430 #[allow(dead_code)]
1431 __wrapper: &'a #wrapper_name,
1432 state: &'a #state_type,
1433 }
1434
1435 #[doc(hidden)]
1436 #[allow(non_camel_case_types)]
1437 struct #mut_view_name<'a> {
1438 #[allow(dead_code)]
1439 __wrapper: &'a #wrapper_name,
1440 state: &'a mut #state_type,
1441 }
1442
1443 impl ::std::ops::Deref for #read_view_name<'_> {
1445 type Target = #trait_ident;
1446 fn deref(&self) -> &Self::Target {
1447 &self.__wrapper.__trait
1448 }
1449 }
1450
1451 impl ::std::ops::Deref for #mut_view_name<'_> {
1452 type Target = #trait_ident;
1453 fn deref(&self) -> &Self::Target {
1454 &self.__wrapper.__trait
1455 }
1456 }
1457 };
1458
1459 let view_impls = quote! {
1460 impl #read_view_name<'_> {
1461 #(#activity_delegations)*
1462 #(#read_methods)*
1463 }
1464
1465 impl #mut_view_name<'_> {
1466 #(#mut_methods)*
1467 }
1468 };
1469
1470 let dispatch_impl = generate_trait_dispatch_impl(&krate, &wrapper_name, &rpcs);
1471 let client_ext = generate_trait_client_ext(&krate, &trait_ident, &rpcs);
1472 let access_trait_ident = format_ident!("__{}TraitAccess", trait_ident);
1473 let methods_trait_ident = format_ident!("__{}TraitMethods", trait_ident);
1474 let state_info_ident = format_ident!("__{}TraitStateInfo", trait_ident);
1475 let state_init_ident = format_ident!("__{}TraitStateInit", trait_ident);
1476 let missing_reason = format!("missing trait dependency: {trait_ident}");
1477
1478 let methods_impls: Vec<proc_macro2::TokenStream> = rpcs
1480 .iter()
1481 .filter(|rpc| rpc.is_trait_visible())
1482 .map(|rpc| {
1483 let method_name = &rpc.name;
1484 let resp_type = &rpc.response_type;
1485 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
1486 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
1487 let param_defs: Vec<_> = param_names
1488 .iter()
1489 .zip(param_types.iter())
1490 .map(|(name, ty)| quote! { #name: #ty })
1491 .collect();
1492 quote! {
1493 async fn #method_name(
1494 &self,
1495 #(#param_defs),*
1496 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
1497 let __trait = self.__trait_ref().ok_or_else(|| #krate::error::ClusterError::MalformedMessage {
1498 reason: ::std::string::String::from(#missing_reason),
1499 source: ::std::option::Option::None,
1500 })?;
1501 __trait.#method_name(#(#param_names),*).await
1502 }
1503 }
1504 })
1505 .collect();
1506
1507 let trait_helpers = quote! {
1508 #[doc(hidden)]
1509 pub trait #access_trait_ident {
1510 fn __trait_ref(&self) -> ::std::option::Option<&::std::sync::Arc<#wrapper_name>>;
1511 }
1512
1513 #[doc(hidden)]
1514 #[async_trait::async_trait]
1515 pub trait #methods_trait_ident: #access_trait_ident {
1516 #(#methods_impls)*
1517 }
1518
1519 #[async_trait::async_trait]
1520 impl<T> #methods_trait_ident for T where T: #access_trait_ident + ::std::marker::Sync + ::std::marker::Send {}
1521 };
1522
1523 let state_traits = quote! {
1524 #[doc(hidden)]
1525 pub trait #state_info_ident {
1526 type State;
1527 }
1528
1529 impl #state_info_ident for #trait_ident {
1530 type State = #state_type;
1531 }
1532
1533 #[doc(hidden)]
1534 pub trait #state_init_ident {
1535 fn __init_state(
1536 &self,
1537 ctx: &#krate::entity::EntityContext,
1538 ) -> ::std::result::Result<#state_type, #krate::error::ClusterError>;
1539 }
1540
1541 impl #state_init_ident for #trait_ident {
1542 fn __init_state(
1543 &self,
1544 ctx: &#krate::entity::EntityContext,
1545 ) -> ::std::result::Result<#state_type, #krate::error::ClusterError> {
1546 #init_call
1547 }
1548 }
1549 };
1550
1551 let wrapper_def = quote! {
1552 #[doc(hidden)]
1553 pub struct #wrapper_name {
1554 __state: ::std::sync::Arc<arc_swap::ArcSwap<#state_type>>,
1556 __write_lock: ::std::sync::Arc<tokio::sync::Mutex<()>>,
1558 __trait: #trait_ident,
1560 }
1561
1562 impl ::std::ops::Deref for #wrapper_name {
1563 type Target = #trait_ident;
1564 fn deref(&self) -> &Self::Target {
1565 &self.__trait
1566 }
1567 }
1568
1569 #view_structs
1570
1571 #view_impls
1572
1573 impl #wrapper_name {
1574 #[doc(hidden)]
1575 pub fn __new(inner: #trait_ident, state: #state_type) -> Self {
1576 Self {
1577 __state: ::std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(state)),
1578 __write_lock: ::std::sync::Arc::new(tokio::sync::Mutex::new(())),
1579 __trait: inner,
1580 }
1581 }
1582
1583 #[doc(hidden)]
1584 pub fn __state_arc(&self) -> &::std::sync::Arc<arc_swap::ArcSwap<#state_type>> {
1585 &self.__state
1586 }
1587
1588 #(#wrapper_methods)*
1589 }
1590 };
1591
1592 let cleaned_impl_tokens = if cleaned_impl.items.is_empty() {
1595 quote! {}
1596 } else {
1597 quote! { #cleaned_impl }
1598 };
1599
1600 Ok(quote! {
1601 #cleaned_impl_tokens
1602 #init_method_impl
1603 #state_traits
1604 #wrapper_def
1605 #dispatch_impl
1606 #client_ext
1607 #trait_helpers
1608 })
1609}
1610
1611fn generate_stateless_entity(
1613 krate: &syn::Path,
1614 struct_name: &syn::Ident,
1615 handler_name: &syn::Ident,
1616 client_name: &syn::Ident,
1617 traits: &[syn::Path],
1618 rpcs: &[RpcMethod],
1619 original_methods: &[syn::ImplItemFn],
1620) -> syn::Result<proc_macro2::TokenStream> {
1621 let dispatch_arms = generate_dispatch_arms(krate, rpcs, false, None, false);
1622 let client_methods = generate_client_methods(krate, rpcs);
1623 let method_impls = generate_method_impls(original_methods);
1624 let struct_name_str = struct_name.to_string();
1625 let has_durable = rpcs.iter().any(|r| r.has_durable_context);
1626
1627 let trait_infos = trait_infos_from_paths(traits);
1628 let has_traits = !trait_infos.is_empty();
1629 let entity_impl = if has_traits {
1630 quote! {}
1631 } else {
1632 quote! {
1633 #[async_trait::async_trait]
1634 impl #krate::entity::Entity for #struct_name {
1635 fn entity_type(&self) -> #krate::types::EntityType {
1636 self.__entity_type()
1637 }
1638
1639 fn shard_group(&self) -> &str {
1640 self.__shard_group()
1641 }
1642
1643 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
1644 self.__shard_group_for(entity_id)
1645 }
1646
1647 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
1648 self.__max_idle_time()
1649 }
1650
1651 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
1652 self.__mailbox_capacity()
1653 }
1654
1655 fn concurrency(&self) -> ::std::option::Option<usize> {
1656 self.__concurrency()
1657 }
1658
1659 async fn spawn(
1660 &self,
1661 ctx: #krate::entity::EntityContext,
1662 ) -> ::std::result::Result<
1663 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
1664 #krate::error::ClusterError,
1665 > {
1666 let handler = #handler_name::__new(self.clone(), ctx).await?;
1667 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
1668 }
1669 }
1670 }
1671 };
1672 let durable_field = if has_durable {
1673 quote! {
1674 __workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
1675 }
1676 } else {
1677 quote! {}
1678 };
1679 let durable_field_init = if has_durable {
1680 quote! { __workflow_engine: ctx.workflow_engine.clone(), }
1681 } else {
1682 quote! {}
1683 };
1684 let trait_dispatch_checks: Vec<proc_macro2::TokenStream> = trait_infos
1685 .iter()
1686 .map(|info| {
1687 let field = &info.field;
1688 quote! {
1689 if let ::std::option::Option::Some(ref __trait) = self.#field {
1690 if let ::std::option::Option::Some(response) = __trait.__dispatch(tag, payload, headers).await? {
1691 return ::std::result::Result::Ok(response);
1692 }
1693 }
1694 }
1695 })
1696 .collect();
1697
1698 let trait_dispatch_fallback = if has_traits {
1699 quote! {{
1700 #(#trait_dispatch_checks)*
1701 ::std::result::Result::Err(
1702 #krate::error::ClusterError::MalformedMessage {
1703 reason: ::std::format!("unknown RPC tag: {tag}"),
1704 source: ::std::option::Option::None,
1705 }
1706 )
1707 }}
1708 } else {
1709 quote! {{
1710 ::std::result::Result::Err(
1711 #krate::error::ClusterError::MalformedMessage {
1712 reason: ::std::format!("unknown RPC tag: {tag}"),
1713 source: ::std::option::Option::None,
1714 }
1715 )
1716 }}
1717 };
1718
1719 let trait_field_defs: Vec<proc_macro2::TokenStream> = trait_infos
1720 .iter()
1721 .map(|info| {
1722 let field = &info.field;
1723 let wrapper_path = &info.wrapper_path;
1724 quote! { #field: ::std::option::Option<::std::sync::Arc<#wrapper_path>>, }
1725 })
1726 .collect();
1727
1728 let trait_field_init_none: Vec<proc_macro2::TokenStream> = trait_infos
1729 .iter()
1730 .map(|info| {
1731 let field = &info.field;
1732 quote! { #field: ::std::option::Option::None, }
1733 })
1734 .collect();
1735
1736 let trait_params: Vec<proc_macro2::TokenStream> = trait_infos
1737 .iter()
1738 .map(|info| {
1739 let path = &info.path;
1740 let ident = &info.ident;
1741 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
1742 quote! { #param: #path }
1743 })
1744 .collect();
1745
1746 let trait_state_inits: Vec<proc_macro2::TokenStream> = trait_infos
1747 .iter()
1748 .map(|info| {
1749 let path = &info.path;
1750 let ident = &info.ident;
1751 let state_init_path = &info.state_init_path;
1752 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
1753 let state_var = format_ident!("__trait_{}_state", to_snake(&ident.to_string()));
1754 quote! {
1755 let #state_var = <#path as #state_init_path>::__init_state(&#param, &ctx)?;
1756 }
1757 })
1758 .collect();
1759
1760 let trait_field_init_some: Vec<proc_macro2::TokenStream> = trait_infos
1761 .iter()
1762 .map(|info| {
1763 let ident = &info.ident;
1764 let field = &info.field;
1765 let wrapper_path = &info.wrapper_path;
1766 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
1767 let state_var = format_ident!("__trait_{}_state", to_snake(&ident.to_string()));
1768 quote! {
1769 #field: ::std::option::Option::Some(::std::sync::Arc::new(#wrapper_path::__new(
1770 #param,
1771 #state_var,
1772 ))),
1773 }
1774 })
1775 .collect();
1776
1777 let trait_param_idents: Vec<syn::Ident> = trait_infos
1778 .iter()
1779 .map(|info| {
1780 let ident = &info.ident;
1781 format_ident!("__trait_{}", to_snake(&ident.to_string()))
1782 })
1783 .collect();
1784
1785 let with_traits_name = format_ident!("{}WithTraits", struct_name);
1786 let with_trait_trait_name = format_ident!("__{}WithTrait", struct_name);
1787 let trait_use_tokens: Vec<proc_macro2::TokenStream> = trait_infos
1788 .iter()
1789 .map(|info| {
1790 let path = &info.path;
1791 let ident = &info.ident;
1792 let methods_trait_ident = format_ident!("__{}TraitMethods", ident);
1793 let methods_trait_path = replace_last_segment(path, methods_trait_ident);
1794 quote! {
1795 #[allow(unused_imports)]
1796 use #methods_trait_path as _;
1797 }
1798 })
1799 .collect();
1800
1801 let trait_access_impls: Vec<proc_macro2::TokenStream> = trait_infos
1803 .iter()
1804 .map(|info| {
1805 let path = &info.path;
1806 let ident = &info.ident;
1807 let field = &info.field;
1808 let wrapper_path = &info.wrapper_path;
1809 let access_trait_ident = format_ident!("__{}TraitAccess", ident);
1810 let access_trait_path = replace_last_segment(path, access_trait_ident);
1811 quote! {
1812 impl #access_trait_path for #handler_name {
1813 fn __trait_ref(&self) -> ::std::option::Option<&::std::sync::Arc<#wrapper_path>> {
1814 self.#field.as_ref()
1815 }
1816 }
1817 }
1818 })
1819 .collect();
1820
1821 let with_traits_impl = if has_traits {
1822 let trait_option_fields = trait_infos
1823 .iter()
1824 .map(|info| {
1825 let field = &info.field;
1826 let path = &info.path;
1827 quote! { #field: ::std::option::Option<#path>, }
1828 })
1829 .collect::<Vec<_>>();
1830
1831 let trait_setters: Vec<proc_macro2::TokenStream> = trait_infos
1832 .iter()
1833 .map(|info| {
1834 let path = &info.path;
1835 let field = &info.field;
1836 quote! {
1837 impl #with_trait_trait_name<#path> for #with_traits_name {
1838 fn __with_trait(&mut self, value: #path) {
1839 self.#field = ::std::option::Option::Some(value);
1840 }
1841 }
1842 }
1843 })
1844 .collect();
1845
1846 let trait_missing_guards: Vec<proc_macro2::TokenStream> = trait_infos
1847 .iter()
1848 .map(|info| {
1849 let path = &info.path;
1850 let ident = &info.ident;
1851 let field = &info.field;
1852 let missing_reason = &info.missing_reason;
1853 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
1854 quote! {
1855 let #param: #path = self.#field.clone().ok_or_else(|| {
1856 #krate::error::ClusterError::MalformedMessage {
1857 reason: ::std::string::String::from(#missing_reason),
1858 source: ::std::option::Option::None,
1859 }
1860 })?;
1861 }
1862 })
1863 .collect();
1864
1865 let trait_bounds: Vec<proc_macro2::TokenStream> = trait_infos
1866 .iter()
1867 .map(|info| {
1868 let path = &info.path;
1869 quote! { #path: ::std::clone::Clone }
1870 })
1871 .collect();
1872
1873 let trait_field_init_none_tokens = trait_infos
1874 .iter()
1875 .map(|info| {
1876 let field = &info.field;
1877 quote! { #field: ::std::option::Option::None, }
1878 })
1879 .collect::<Vec<_>>();
1880
1881 quote! {
1882 #[doc(hidden)]
1883 pub struct #with_traits_name {
1884 entity: #struct_name,
1885 #(#trait_option_fields)*
1886 }
1887
1888 trait #with_trait_trait_name<T> {
1889 fn __with_trait(&mut self, value: T);
1890 }
1891
1892 impl #struct_name {
1893 pub fn with<T>(self, value: T) -> #with_traits_name
1894 where
1895 #with_traits_name: #with_trait_trait_name<T>,
1896 {
1897 let mut bundle = #with_traits_name {
1898 entity: self,
1899 #(#trait_field_init_none_tokens)*
1900 };
1901 bundle.__with_trait(value);
1902 bundle
1903 }
1904 }
1905
1906 impl #with_traits_name {
1907 pub fn with<T>(mut self, value: T) -> Self
1908 where
1909 Self: #with_trait_trait_name<T>,
1910 {
1911 self.__with_trait(value);
1912 self
1913 }
1914 }
1915
1916 #(#trait_setters)*
1917
1918 #[async_trait::async_trait]
1919 impl #krate::entity::Entity for #with_traits_name
1920 where
1921 #struct_name: ::std::clone::Clone,
1922 #(#trait_bounds,)*
1923 {
1924 fn entity_type(&self) -> #krate::types::EntityType {
1925 self.entity.__entity_type()
1926 }
1927
1928 fn shard_group(&self) -> &str {
1929 self.entity.__shard_group()
1930 }
1931
1932 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
1933 self.entity.__shard_group_for(entity_id)
1934 }
1935
1936 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
1937 self.entity.__max_idle_time()
1938 }
1939
1940 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
1941 self.entity.__mailbox_capacity()
1942 }
1943
1944 fn concurrency(&self) -> ::std::option::Option<usize> {
1945 self.entity.__concurrency()
1946 }
1947
1948 async fn spawn(
1949 &self,
1950 ctx: #krate::entity::EntityContext,
1951 ) -> ::std::result::Result<
1952 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
1953 #krate::error::ClusterError,
1954 > {
1955 #(#trait_missing_guards)*
1956 let handler = #handler_name::__new_with_traits(
1957 self.entity.clone(),
1958 #(#trait_param_idents,)*
1959 ctx,
1960 )
1961 .await?;
1962 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
1963 }
1964 }
1965 }
1966 } else {
1967 quote! {}
1968 };
1969
1970 let register_impl = if has_traits {
1972 let register_trait_params: Vec<proc_macro2::TokenStream> = trait_infos
1974 .iter()
1975 .map(|info| {
1976 let param = &info.field;
1977 let path = &info.path;
1978 quote! { #param: #path }
1979 })
1980 .collect();
1981 let trait_with_calls: Vec<proc_macro2::TokenStream> = trait_infos
1982 .iter()
1983 .map(|info| {
1984 let field = &info.field;
1985 quote! { .with(#field) }
1986 })
1987 .collect();
1988 quote! {
1989 impl #struct_name {
1990 pub async fn register(
1995 self,
1996 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
1997 #(#register_trait_params),*
1998 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
1999 let entity_with_traits = self #(#trait_with_calls)*;
2000 sharding.register_entity(::std::sync::Arc::new(entity_with_traits)).await?;
2001 ::std::result::Result::Ok(#client_name::new(sharding))
2002 }
2003 }
2004 }
2005 } else {
2006 quote! {
2008 impl #struct_name {
2009 pub async fn register(
2014 self,
2015 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
2016 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
2017 sharding.register_entity(::std::sync::Arc::new(self)).await?;
2018 ::std::result::Result::Ok(#client_name::new(sharding))
2019 }
2020 }
2021 }
2022 };
2023
2024 Ok(quote! {
2025 #(#trait_use_tokens)*
2026
2027 impl #struct_name {
2028 #(#method_impls)*
2029 }
2030
2031 #[doc(hidden)]
2033 pub struct #handler_name {
2034 entity: #struct_name,
2035 #[allow(dead_code)]
2036 ctx: #krate::entity::EntityContext,
2037 #(#trait_field_defs)*
2038 #durable_field
2039 }
2040
2041 impl #handler_name {
2042 #[doc(hidden)]
2043 pub async fn __new(entity: #struct_name, ctx: #krate::entity::EntityContext) -> ::std::result::Result<Self, #krate::error::ClusterError> {
2044 ::std::result::Result::Ok(Self {
2045 entity,
2046 #(#trait_field_init_none)*
2047 #durable_field_init
2048 ctx,
2049 })
2050 }
2051
2052 #[doc(hidden)]
2053 pub async fn __new_with_traits(
2054 entity: #struct_name,
2055 #(#trait_params,)*
2056 ctx: #krate::entity::EntityContext,
2057 ) -> ::std::result::Result<Self, #krate::error::ClusterError> {
2058 #(#trait_state_inits)*
2059 ::std::result::Result::Ok(Self {
2060 entity,
2061 #(#trait_field_init_some)*
2062 #durable_field_init
2063 ctx,
2064 })
2065 }
2066 }
2067
2068 #[async_trait::async_trait]
2069 impl #krate::entity::EntityHandler for #handler_name {
2070 async fn handle_request(
2071 &self,
2072 tag: &str,
2073 payload: &[u8],
2074 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
2075 ) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
2076 let _ = headers;
2077 match tag {
2078 #(#dispatch_arms,)*
2079 _ => #trait_dispatch_fallback,
2080 }
2081 }
2082 }
2083
2084 pub struct #client_name {
2086 inner: #krate::entity_client::EntityClient,
2087 }
2088
2089 impl #client_name {
2090 pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
2092 Self {
2093 inner: #krate::entity_client::EntityClient::new(
2094 sharding,
2095 #krate::types::EntityType::new(#struct_name_str),
2096 ),
2097 }
2098 }
2099
2100 pub fn inner(&self) -> &#krate::entity_client::EntityClient {
2102 &self.inner
2103 }
2104
2105 #(#client_methods)*
2106 }
2107
2108 impl #krate::entity_client::EntityClientAccessor for #client_name {
2109 fn entity_client(&self) -> &#krate::entity_client::EntityClient {
2110 &self.inner
2111 }
2112 }
2113
2114 #register_impl
2115 #with_traits_impl
2116 #entity_impl
2117
2118 #(#trait_access_impls)*
2119 })
2120}
2121
2122#[allow(clippy::too_many_arguments)]
2124fn generate_stateful_entity(
2125 krate: &syn::Path,
2126 struct_name: &syn::Ident,
2127 handler_name: &syn::Ident,
2128 client_name: &syn::Ident,
2129 state_type: &syn::Type,
2130 state_persisted: bool,
2131 traits: &[syn::Path],
2132 rpcs: &[RpcMethod],
2133 original_methods: &[syn::ImplItemFn],
2134) -> syn::Result<proc_macro2::TokenStream> {
2135 let trait_infos = trait_infos_from_paths(traits);
2136 let has_traits = !trait_infos.is_empty();
2137 let entity_impl = if has_traits {
2138 quote! {}
2139 } else {
2140 quote! {
2141 #[async_trait::async_trait]
2142 impl #krate::entity::Entity for #struct_name {
2143 fn entity_type(&self) -> #krate::types::EntityType {
2144 self.__entity_type()
2145 }
2146
2147 fn shard_group(&self) -> &str {
2148 self.__shard_group()
2149 }
2150
2151 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
2152 self.__shard_group_for(entity_id)
2153 }
2154
2155 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
2156 self.__max_idle_time()
2157 }
2158
2159 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
2160 self.__mailbox_capacity()
2161 }
2162
2163 fn concurrency(&self) -> ::std::option::Option<usize> {
2164 self.__concurrency()
2165 }
2166
2167 async fn spawn(
2168 &self,
2169 ctx: #krate::entity::EntityContext,
2170 ) -> ::std::result::Result<
2171 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
2172 #krate::error::ClusterError,
2173 > {
2174 let handler = #handler_name::__new(self.clone(), ctx).await?;
2175 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
2176 }
2177 }
2178 }
2179 };
2180 let save_state_code = if state_persisted {
2181 if has_traits {
2182 let composite_ref_name = format_ident!("{}CompositeStateRef", struct_name);
2183 let trait_state_refs: Vec<proc_macro2::TokenStream> = trait_infos
2184 .iter()
2185 .map(|info| {
2186 let field = &info.field;
2187 let missing_reason = &info.missing_reason;
2188 quote! {
2189 let #field = guard.#field.as_ref().ok_or_else(|| {
2190 #krate::error::ClusterError::MalformedMessage {
2191 reason: ::std::string::String::from(#missing_reason),
2192 source: ::std::option::Option::None,
2193 }
2194 })?;
2195 }
2196 })
2197 .collect();
2198 let composite_field_defs: Vec<proc_macro2::TokenStream> = trait_infos
2199 .iter()
2200 .map(|info| {
2201 let field = &info.field;
2202 quote! { #field: #field.__state() }
2203 })
2204 .collect();
2205 quote! {
2206 #(#trait_state_refs)*
2207 let composite = #composite_ref_name {
2208 entity: &guard.state,
2209 #(#composite_field_defs,)*
2210 };
2211 let state_bytes = rmp_serde::to_vec(&composite)
2212 .map_err(|e| #krate::error::ClusterError::PersistenceError {
2213 reason: ::std::format!("failed to serialize state: {e}"),
2214 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
2215 })?;
2216 if let Some(ref storage) = self.__state_storage {
2217 storage.save(&self.__state_key, &state_bytes).await.map_err(|e| {
2218 tracing::warn!(
2219 state_key = %self.__state_key,
2220 error = %e,
2221 "failed to persist entity state"
2222 );
2223 e
2224 })?;
2225 }
2226 }
2227 } else {
2228 quote! {
2229 if let Some(ref storage) = self.__state_storage {
2231 let state_bytes = rmp_serde::to_vec(&guard.state)
2232 .map_err(|e| #krate::error::ClusterError::PersistenceError {
2233 reason: ::std::format!("failed to serialize state: {e}"),
2234 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
2235 })?;
2236 storage.save(&self.__state_key, &state_bytes).await.map_err(|e| {
2237 tracing::warn!(
2238 state_key = %self.__state_key,
2239 error = %e,
2240 "failed to persist entity state"
2241 );
2242 e
2243 })?;
2244 }
2245 }
2246 }
2247 } else {
2248 quote! {}
2249 };
2250
2251 let dispatch_arms = generate_dispatch_arms(
2252 krate,
2253 rpcs,
2254 true,
2255 Some(&save_state_code),
2256 has_traits && state_persisted,
2257 );
2258 let client_methods = generate_client_methods(krate, rpcs);
2259
2260 let read_view_name = format_ident!("__{}ReadView", handler_name);
2268 let mut_view_name = format_ident!("__{}MutView", handler_name);
2269
2270 let mut read_methods: Vec<proc_macro2::TokenStream> = Vec::new();
2272 let mut mut_methods: Vec<proc_macro2::TokenStream> = Vec::new();
2273 let mut wrapper_methods: Vec<proc_macro2::TokenStream> = Vec::new();
2274
2275 for m in original_methods.iter().filter(|m| m.sig.ident != "init") {
2276 let method_name = &m.sig.ident;
2277 let block = &m.block;
2278
2279 let attrs: Vec<_> = m
2281 .attrs
2282 .iter()
2283 .filter(|a| {
2284 !a.path().is_ident("rpc")
2285 && !a.path().is_ident("workflow")
2286 && !a.path().is_ident("activity")
2287 && !a.path().is_ident("method")
2288 && !a.path().is_ident("public")
2289 && !a.path().is_ident("protected")
2290 && !a.path().is_ident("private")
2291 })
2292 .collect();
2293 let vis = &m.vis;
2294 let output = &m.sig.output;
2295 let generics = &m.sig.generics;
2296 let where_clause = &generics.where_clause;
2297
2298 let params: Vec<_> = m.sig.inputs.iter().skip(1).collect();
2300 let param_names: Vec<_> = m
2301 .sig
2302 .inputs
2303 .iter()
2304 .skip(1)
2305 .filter_map(|arg| {
2306 if let syn::FnArg::Typed(pat_type) = arg {
2307 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2308 return Some(&pat_ident.ident);
2309 }
2310 }
2311 None
2312 })
2313 .collect();
2314
2315 let rpc_info = rpcs.iter().find(|rpc| rpc.name == *method_name);
2317 let is_mut = rpc_info.map(|r| r.is_mut).unwrap_or(false);
2318 let is_activity = rpc_info
2319 .map(|r| matches!(r.kind, RpcKind::Activity))
2320 .unwrap_or(false);
2321 let is_async = m.sig.asyncness.is_some();
2322
2323 let async_token = if is_async {
2324 quote! { async }
2325 } else {
2326 quote! {}
2327 };
2328 let await_token = if is_async {
2329 quote! { .await }
2330 } else {
2331 quote! {}
2332 };
2333
2334 if is_mut {
2335 mut_methods.push(quote! {
2337 #(#attrs)*
2338 #vis #async_token fn #method_name #generics (&mut self, #(#params),*) #output #where_clause
2339 #block
2340 });
2341
2342 if is_activity && state_persisted {
2344 wrapper_methods.push(quote! {
2345 #(#attrs)*
2346 #vis async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
2347 let __storage = self.__state_storage.clone()
2348 .ok_or_else(|| #krate::error::ClusterError::PersistenceError {
2349 reason: "activity requires storage but none configured".to_string(),
2350 source: ::std::option::Option::None,
2351 })?;
2352 let __lock = self.__write_lock.clone().lock_owned().await;
2353 let __arc_swap = self.__state.clone();
2356 let __storage_opt = self.__state_storage.clone();
2357 let __storage_key = self.__state_key.clone();
2358 let __handler = self;
2359 #krate::__internal::ActivityScope::run(&__storage, || async move {
2360 let mut __guard = #krate::__internal::StateMutGuard::new(
2361 __arc_swap,
2362 __storage_opt,
2363 __storage_key,
2364 __lock,
2365 );
2366 let mut __view = #mut_view_name {
2367 __handler,
2368 state: &mut *__guard,
2369 };
2370 __view.#method_name(#(#param_names),*) #await_token
2371 }).await
2373 }
2374 });
2375 } else {
2376 wrapper_methods.push(quote! {
2377 #(#attrs)*
2378 #vis async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
2379 let __lock = self.__write_lock.clone().lock_owned().await;
2380 let mut __guard = #krate::__internal::StateMutGuard::new(
2381 self.__state.clone(),
2382 self.__state_storage.clone(),
2383 self.__state_key.clone(),
2384 __lock,
2385 );
2386 let mut __view = #mut_view_name {
2387 __handler: self,
2388 state: &mut *__guard,
2389 };
2390 __view.#method_name(#(#param_names),*) #await_token
2391 }
2392 });
2393 }
2394 } else {
2395 read_methods.push(quote! {
2397 #(#attrs)*
2398 #vis #async_token fn #method_name #generics (&self, #(#params),*) #output #where_clause
2399 #block
2400 });
2401
2402 wrapper_methods.push(quote! {
2403 #(#attrs)*
2404 #vis #async_token fn #method_name #generics (&self, #(#params),*) #output #where_clause {
2405 let __guard = self.__state.load();
2406 let __view = #read_view_name {
2407 __handler: self,
2408 state: &**__guard,
2409 };
2410 __view.#method_name(#(#param_names),*) #await_token
2411 }
2412 });
2413 }
2414 }
2415
2416 let view_structs = quote! {
2418 #[doc(hidden)]
2419 #[allow(non_camel_case_types)]
2420 struct #read_view_name<'a> {
2421 #[allow(dead_code)]
2422 __handler: &'a #handler_name,
2423 state: &'a #state_type,
2424 }
2425
2426 #[doc(hidden)]
2427 #[allow(non_camel_case_types)]
2428 struct #mut_view_name<'a> {
2429 #[allow(dead_code)]
2430 __handler: &'a #handler_name,
2431 state: &'a mut #state_type,
2432 }
2433 };
2434
2435 let view_delegation_methods = if state_persisted {
2438 quote! {
2439 #[inline]
2441 fn entity_id(&self) -> &#krate::types::EntityId {
2442 self.__handler.entity_id()
2443 }
2444
2445 #[inline]
2447 fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
2448 self.__handler.self_client()
2449 }
2450
2451 #[inline]
2453 async fn sleep(&self, name: &str, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
2454 self.__handler.sleep(name, duration).await
2455 }
2456
2457 #[inline]
2459 async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
2460 where
2461 T: serde::Serialize + serde::de::DeserializeOwned,
2462 K: #krate::__internal::DeferredKeyLike<T>,
2463 {
2464 self.__handler.await_deferred(key).await
2465 }
2466
2467 #[inline]
2469 async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
2470 where
2471 T: serde::Serialize,
2472 K: #krate::__internal::DeferredKeyLike<T>,
2473 {
2474 self.__handler.resolve_deferred(key, value).await
2475 }
2476
2477 #[inline]
2479 fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
2480 self.__handler.sharding()
2481 }
2482
2483 #[inline]
2485 fn entity_address(&self) -> &#krate::types::EntityAddress {
2486 self.__handler.entity_address()
2487 }
2488 }
2489 } else {
2490 quote! {}
2491 };
2492
2493 let activity_delegations: Vec<proc_macro2::TokenStream> = rpcs
2495 .iter()
2496 .filter(|rpc| matches!(rpc.kind, RpcKind::Activity))
2497 .map(|rpc| {
2498 let method_name = &rpc.name;
2499 let method_info = original_methods
2500 .iter()
2501 .find(|m| m.sig.ident == *method_name)
2502 .unwrap();
2503 let params: Vec<_> = method_info.sig.inputs.iter().skip(1).collect();
2504 let param_names: Vec<_> = method_info
2505 .sig
2506 .inputs
2507 .iter()
2508 .skip(1)
2509 .filter_map(|arg| {
2510 if let syn::FnArg::Typed(pat_type) = arg {
2511 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
2512 return Some(&pat_ident.ident);
2513 }
2514 }
2515 None
2516 })
2517 .collect();
2518 let output = &method_info.sig.output;
2519 let generics = &method_info.sig.generics;
2520 let where_clause = &generics.where_clause;
2521
2522 quote! {
2523 #[inline]
2524 async fn #method_name #generics (&self, #(#params),*) #output #where_clause {
2525 self.__handler.#method_name(#(#param_names),*).await
2526 }
2527 }
2528 })
2529 .collect();
2530
2531 let view_trait_access_impls: Vec<proc_macro2::TokenStream> = trait_infos
2533 .iter()
2534 .map(|info| {
2535 let path = &info.path;
2536 let ident = &info.ident;
2537 let wrapper_path = &info.wrapper_path;
2538 let access_trait_ident = format_ident!("__{}TraitAccess", ident);
2539 let access_trait_path = replace_last_segment(path, access_trait_ident);
2540 quote! {
2541 impl #access_trait_path for #read_view_name<'_> {
2542 fn __trait_ref(&self) -> ::std::option::Option<&::std::sync::Arc<#wrapper_path>> {
2543 #access_trait_path::__trait_ref(self.__handler)
2544 }
2545 }
2546
2547 impl #access_trait_path for #mut_view_name<'_> {
2548 fn __trait_ref(&self) -> ::std::option::Option<&::std::sync::Arc<#wrapper_path>> {
2549 #access_trait_path::__trait_ref(self.__handler)
2550 }
2551 }
2552 }
2553 })
2554 .collect();
2555
2556 let view_impls = quote! {
2558 impl #read_view_name<'_> {
2559 #view_delegation_methods
2560 #(#activity_delegations)*
2561 #(#read_methods)*
2562 }
2563
2564 impl #mut_view_name<'_> {
2565 #view_delegation_methods
2566 #(#mut_methods)*
2567 }
2568
2569 #(#view_trait_access_impls)*
2570 };
2571
2572 let init_method = original_methods
2574 .iter()
2575 .find(|m| m.sig.ident == "init")
2576 .unwrap();
2577 let init_sig = &init_method.sig;
2578 let init_block = &init_method.block;
2579 let init_attrs = &init_method.attrs;
2580 let init_vis = &init_method.vis;
2581
2582 let struct_name_str = struct_name.to_string();
2583 let _state_wrapper_name = format_ident!("{}StateWrapper", struct_name);
2584
2585 let trait_field_defs: Vec<proc_macro2::TokenStream> = trait_infos
2586 .iter()
2587 .map(|info| {
2588 let field = &info.field;
2589 let wrapper_path = &info.wrapper_path;
2590 quote! { #field: ::std::option::Option<::std::sync::Arc<#wrapper_path>>, }
2591 })
2592 .collect();
2593
2594 let trait_field_init_none: Vec<proc_macro2::TokenStream> = trait_infos
2595 .iter()
2596 .map(|info| {
2597 let field = &info.field;
2598 quote! { #field: ::std::option::Option::None, }
2599 })
2600 .collect();
2601
2602 let trait_params: Vec<proc_macro2::TokenStream> = trait_infos
2603 .iter()
2604 .map(|info| {
2605 let path = &info.path;
2606 let ident = &info.ident;
2607 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
2608 quote! { #param: #path }
2609 })
2610 .collect();
2611
2612 let trait_state_inits: Vec<proc_macro2::TokenStream> = trait_infos
2613 .iter()
2614 .map(|info| {
2615 let path = &info.path;
2616 let ident = &info.ident;
2617 let state_init_path = &info.state_init_path;
2618 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
2619 let state_var = format_ident!("__trait_{}_state", to_snake(&ident.to_string()));
2620 quote! {
2621 let #state_var = <#path as #state_init_path>::__init_state(&#param, &ctx)?;
2622 }
2623 })
2624 .collect();
2625
2626 let trait_field_init_some: Vec<proc_macro2::TokenStream> = trait_infos
2627 .iter()
2628 .map(|info| {
2629 let ident = &info.ident;
2630 let field = &info.field;
2631 let wrapper_path = &info.wrapper_path;
2632 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
2633 let state_var = format_ident!("__trait_{}_state", to_snake(&ident.to_string()));
2634 quote! {
2635 #field: ::std::option::Option::Some(::std::sync::Arc::new(#wrapper_path::__new(
2636 #param,
2637 #state_var,
2638 ))),
2639 }
2640 })
2641 .collect();
2642
2643 let trait_param_idents: Vec<syn::Ident> = trait_infos
2644 .iter()
2645 .map(|info| {
2646 let ident = &info.ident;
2647 format_ident!("__trait_{}", to_snake(&ident.to_string()))
2648 })
2649 .collect();
2650
2651 let trait_state_vars: Vec<syn::Ident> = trait_infos
2652 .iter()
2653 .map(|info| format_ident!("__trait_{}_state", to_snake(&info.ident.to_string())))
2654 .collect();
2655
2656 let composite_state_name = format_ident!("{}CompositeState", struct_name);
2657 let composite_ref_name = format_ident!("{}CompositeStateRef", struct_name);
2658
2659 let composite_state_defs = if state_persisted && has_traits {
2660 let composite_fields: Vec<proc_macro2::TokenStream> = trait_infos
2661 .iter()
2662 .map(|info| {
2663 let field = &info.field;
2664 let path = &info.path;
2665 let state_info_path = &info.state_info_path;
2666 quote! { #field: <#path as #state_info_path>::State, }
2667 })
2668 .collect();
2669 let composite_ref_fields: Vec<proc_macro2::TokenStream> = trait_infos
2670 .iter()
2671 .map(|info| {
2672 let field = &info.field;
2673 let path = &info.path;
2674 let state_info_path = &info.state_info_path;
2675 quote! { #field: &'a <#path as #state_info_path>::State, }
2676 })
2677 .collect();
2678 quote! {
2679 #[derive(serde::Serialize, serde::Deserialize)]
2680 struct #composite_state_name {
2681 entity: #state_type,
2682 #(#composite_fields)*
2683 }
2684
2685 #[derive(serde::Serialize)]
2686 struct #composite_ref_name<'a> {
2687 entity: &'a #state_type,
2688 #(#composite_ref_fields)*
2689 }
2690 }
2691 } else {
2692 quote! {}
2693 };
2694
2695 let state_init_code = if state_persisted {
2697 quote! {
2698 let state: #state_type = if let Some(ref storage) = ctx.state_storage {
2700 let key = ::std::format!(
2701 "entity/{}/{}/state",
2702 ctx.address.entity_type.0,
2703 ctx.address.entity_id.0,
2704 );
2705 match storage.load(&key).await {
2706 Ok(Some(bytes)) => {
2707 rmp_serde::from_slice(&bytes).map_err(|e| {
2708 #krate::error::ClusterError::PersistenceError {
2709 reason: ::std::format!("failed to deserialize persisted state: {e}"),
2710 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
2711 }
2712 })?
2713 }
2714 Ok(None) => entity.init(&ctx)?,
2715 Err(e) => {
2716 tracing::warn!(
2717 entity_type = %ctx.address.entity_type.0,
2718 entity_id = %ctx.address.entity_id.0,
2719 error = %e,
2720 "failed to load persisted state, falling back to init"
2721 );
2722 entity.init(&ctx)?
2723 }
2724 }
2725 } else {
2726 entity.init(&ctx)?
2727 };
2728 }
2729 } else {
2730 quote! {
2731 let state: #state_type = entity.init(&ctx)?;
2732 }
2733 };
2734
2735 let state_init_with_traits_code = if has_traits {
2736 if state_persisted {
2737 let composite_fields: Vec<proc_macro2::TokenStream> = trait_infos
2738 .iter()
2739 .map(|info| {
2740 let field = &info.field;
2741 quote! { composite.#field }
2742 })
2743 .collect();
2744 let trait_state_inits = &trait_state_inits;
2745 quote! {
2746 let (state, #(#trait_state_vars),*) = if let Some(ref storage) = ctx.state_storage {
2747 let key = ::std::format!(
2748 "entity/{}/{}/state",
2749 ctx.address.entity_type.0,
2750 ctx.address.entity_id.0,
2751 );
2752 match storage.load(&key).await {
2753 Ok(Some(bytes)) => {
2754 match rmp_serde::from_slice::<#composite_state_name>(&bytes) {
2755 Ok(composite) => {
2756 (composite.entity, #(#composite_fields),*)
2757 }
2758 Err(composite_err) => match rmp_serde::from_slice::<#state_type>(&bytes) {
2759 Ok(state_only) => {
2760 let state = state_only;
2761 #(#trait_state_inits)*
2762 (state, #(#trait_state_vars),*)
2763 }
2764 Err(state_err) => {
2765 return ::std::result::Result::Err(
2766 #krate::error::ClusterError::PersistenceError {
2767 reason: ::std::format!(
2768 "failed to deserialize persisted state: composite={composite_err}; state={state_err}"
2769 ),
2770 source: ::std::option::Option::Some(::std::boxed::Box::new(state_err)),
2771 }
2772 );
2773 }
2774 },
2775 }
2776 }
2777 Ok(None) => {
2778 let state = entity.init(&ctx)?;
2779 #(#trait_state_inits)*
2780 (state, #(#trait_state_vars),*)
2781 }
2782 Err(e) => {
2783 tracing::warn!(
2784 entity_type = %ctx.address.entity_type.0,
2785 entity_id = %ctx.address.entity_id.0,
2786 error = %e,
2787 "failed to load persisted state, falling back to init"
2788 );
2789 let state = entity.init(&ctx)?;
2790 #(#trait_state_inits)*
2791 (state, #(#trait_state_vars),*)
2792 }
2793 }
2794 } else {
2795 let state = entity.init(&ctx)?;
2796 #(#trait_state_inits)*
2797 (state, #(#trait_state_vars),*)
2798 };
2799 }
2800 } else {
2801 let trait_state_inits = &trait_state_inits;
2802 quote! {
2803 let state: #state_type = entity.init(&ctx)?;
2804 #(#trait_state_inits)*
2805 }
2806 }
2807 } else {
2808 quote! { #state_init_code }
2809 };
2810
2811 let has_durable = state_persisted || rpcs.iter().any(|r| r.has_durable_context);
2814
2815 let handler_storage_field = if state_persisted {
2817 quote! {
2818 __state_storage: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowStorage>>,
2819 __state_key: ::std::string::String,
2820 }
2821 } else {
2822 quote! {}
2823 };
2824
2825 let handler_storage_init = if state_persisted {
2826 quote! {
2827 let __state_key = ::std::format!(
2828 "entity/{}/{}/state",
2829 ctx.address.entity_type.0,
2830 ctx.address.entity_id.0,
2831 );
2832 let __state_storage = ctx.state_storage.clone();
2833 }
2834 } else {
2835 quote! {}
2836 };
2837
2838 let handler_storage_fields_init = if state_persisted {
2839 quote! {
2840 __state_storage,
2841 __state_key,
2842 }
2843 } else {
2844 quote! {}
2845 };
2846
2847 let durable_field = if has_durable {
2848 quote! {
2849 __workflow_engine: ::std::option::Option<::std::sync::Arc<dyn #krate::__internal::WorkflowEngine>>,
2850 }
2851 } else {
2852 quote! {}
2853 };
2854 let durable_field_init = if has_durable {
2855 quote! { __workflow_engine: ctx.workflow_engine.clone(), }
2856 } else {
2857 quote! {}
2858 };
2859
2860 let durable_builtin_impls = if state_persisted {
2863 quote! {
2864 pub async fn sleep(&self, name: &str, duration: ::std::time::Duration) -> ::std::result::Result<(), #krate::error::ClusterError> {
2870 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
2871 #krate::error::ClusterError::MalformedMessage {
2872 reason: "sleep() requires a workflow engine — ensure EntityContext has workflow_engine set".into(),
2873 source: ::std::option::Option::None,
2874 }
2875 })?;
2876 let ctx = #krate::__internal::DurableContext::new(
2877 ::std::sync::Arc::clone(engine),
2878 self.ctx.address.entity_type.0.clone(),
2879 self.ctx.address.entity_id.0.clone(),
2880 );
2881 ctx.sleep(name, duration).await
2882 }
2883
2884 pub async fn await_deferred<T, K>(&self, key: K) -> ::std::result::Result<T, #krate::error::ClusterError>
2889 where
2890 T: serde::Serialize + serde::de::DeserializeOwned,
2891 K: #krate::__internal::DeferredKeyLike<T>,
2892 {
2893 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
2894 #krate::error::ClusterError::MalformedMessage {
2895 reason: "await_deferred() requires a workflow engine — ensure EntityContext has workflow_engine set".into(),
2896 source: ::std::option::Option::None,
2897 }
2898 })?;
2899 let ctx = #krate::__internal::DurableContext::new(
2900 ::std::sync::Arc::clone(engine),
2901 self.ctx.address.entity_type.0.clone(),
2902 self.ctx.address.entity_id.0.clone(),
2903 );
2904 ctx.await_deferred(key).await
2905 }
2906
2907 pub async fn resolve_deferred<T, K>(&self, key: K, value: &T) -> ::std::result::Result<(), #krate::error::ClusterError>
2909 where
2910 T: serde::Serialize,
2911 K: #krate::__internal::DeferredKeyLike<T>,
2912 {
2913 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
2914 #krate::error::ClusterError::MalformedMessage {
2915 reason: "resolve_deferred() requires a workflow engine — ensure EntityContext has workflow_engine set".into(),
2916 source: ::std::option::Option::None,
2917 }
2918 })?;
2919 let ctx = #krate::__internal::DurableContext::new(
2920 ::std::sync::Arc::clone(engine),
2921 self.ctx.address.entity_type.0.clone(),
2922 self.ctx.address.entity_id.0.clone(),
2923 );
2924 ctx.resolve_deferred(key, value).await
2925 }
2926
2927 pub async fn on_interrupt(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
2931 let engine = self.__workflow_engine.as_ref().ok_or_else(|| {
2932 #krate::error::ClusterError::MalformedMessage {
2933 reason: "on_interrupt() requires a workflow engine — ensure EntityContext has workflow_engine set".into(),
2934 source: ::std::option::Option::None,
2935 }
2936 })?;
2937 let ctx = #krate::__internal::DurableContext::new(
2938 ::std::sync::Arc::clone(engine),
2939 self.ctx.address.entity_type.0.clone(),
2940 self.ctx.address.entity_id.0.clone(),
2941 );
2942 ctx.on_interrupt().await
2943 }
2944 }
2945 } else {
2946 quote! {}
2947 };
2948
2949 let sharding_builtin_impls = if state_persisted {
2952 quote! {
2953 pub fn sharding(&self) -> ::std::option::Option<&::std::sync::Arc<dyn #krate::sharding::Sharding>> {
2958 self.__sharding.as_ref()
2959 }
2960
2961 pub fn entity_address(&self) -> &#krate::types::EntityAddress {
2965 &self.__entity_address
2966 }
2967
2968 pub fn entity_id(&self) -> &#krate::types::EntityId {
2970 &self.__entity_address.entity_id
2971 }
2972
2973 pub fn self_client(&self) -> ::std::option::Option<#krate::entity_client::EntityClient> {
2977 self.__sharding.as_ref().map(|s| {
2978 ::std::sync::Arc::clone(s).make_client(self.__entity_address.entity_type.clone())
2979 })
2980 }
2981 }
2982 } else {
2983 quote! {}
2984 };
2985
2986 let _durable_ctx_wrapper_field = if state_persisted {
2989 quote! {
2990 __durable_ctx: ::std::option::Option<#krate::__internal::DurableContext>,
2992 }
2993 } else {
2994 quote! {}
2995 };
2996
2997 let _durable_ctx_wrapper_init = if state_persisted {
2998 quote! {
2999 let __durable_ctx = ctx.workflow_engine.as_ref().map(|engine| {
3000 #krate::__internal::DurableContext::new(
3001 ::std::sync::Arc::clone(engine),
3002 ctx.address.entity_type.0.clone(),
3003 ctx.address.entity_id.0.clone(),
3004 )
3005 });
3006 }
3007 } else {
3008 quote! {}
3009 };
3010
3011 let _durable_ctx_wrapper_field_init = if state_persisted {
3012 quote! { __durable_ctx, }
3013 } else {
3014 quote! {}
3015 };
3016
3017 let _sharding_ctx_wrapper_field = if state_persisted {
3020 quote! {
3021 __sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
3023 __entity_address: #krate::types::EntityAddress,
3025 }
3026 } else {
3027 quote! {}
3028 };
3029
3030 let _sharding_ctx_wrapper_init = if state_persisted {
3031 quote! {
3032 let __sharding = ctx.sharding.clone();
3033 let __entity_address = ctx.address.clone();
3034 }
3035 } else {
3036 quote! {}
3037 };
3038
3039 let _sharding_ctx_wrapper_field_init = if state_persisted {
3040 quote! { __sharding, __entity_address, }
3041 } else {
3042 quote! {}
3043 };
3044
3045 let sharding_handler_field = if state_persisted {
3047 quote! {
3048 __sharding: ::std::option::Option<::std::sync::Arc<dyn #krate::sharding::Sharding>>,
3050 __entity_address: #krate::types::EntityAddress,
3052 }
3053 } else {
3054 quote! {}
3055 };
3056
3057 let new_fn = if state_persisted {
3059 quote! {
3060 #[doc(hidden)]
3061 pub async fn __new(entity: #struct_name, ctx: #krate::entity::EntityContext) -> ::std::result::Result<Self, #krate::error::ClusterError> {
3062 #state_init_code
3063 #handler_storage_init
3064 let __sharding = ctx.sharding.clone();
3065 let __entity_address = ctx.address.clone();
3066 ::std::result::Result::Ok(Self {
3067 __state: ::std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(state)),
3068 __write_lock: ::std::sync::Arc::new(tokio::sync::Mutex::new(())),
3069 __entity: entity,
3070 #durable_field_init
3071 ctx,
3072 #handler_storage_fields_init
3073 __sharding,
3074 __entity_address,
3075 #(#trait_field_init_none)*
3076 })
3077 }
3078 }
3079 } else {
3080 quote! {
3081 #[doc(hidden)]
3082 pub async fn __new(entity: #struct_name, ctx: #krate::entity::EntityContext) -> ::std::result::Result<Self, #krate::error::ClusterError> {
3083 #state_init_code
3084 ::std::result::Result::Ok(Self {
3085 __state: ::std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(state)),
3086 __write_lock: ::std::sync::Arc::new(tokio::sync::Mutex::new(())),
3087 __entity: entity,
3088 #durable_field_init
3089 ctx,
3090 #(#trait_field_init_none)*
3091 })
3092 }
3093 }
3094 };
3095
3096 let new_with_traits_fn = if has_traits {
3097 if state_persisted {
3098 quote! {
3099 #[doc(hidden)]
3100 pub async fn __new_with_traits(
3101 entity: #struct_name,
3102 #(#trait_params,)*
3103 ctx: #krate::entity::EntityContext,
3104 ) -> ::std::result::Result<Self, #krate::error::ClusterError> {
3105 #state_init_with_traits_code
3106 #handler_storage_init
3107 let __sharding = ctx.sharding.clone();
3108 let __entity_address = ctx.address.clone();
3109 ::std::result::Result::Ok(Self {
3110 __state: ::std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(state)),
3111 __write_lock: ::std::sync::Arc::new(tokio::sync::Mutex::new(())),
3112 __entity: entity,
3113 #durable_field_init
3114 ctx,
3115 #handler_storage_fields_init
3116 __sharding,
3117 __entity_address,
3118 #(#trait_field_init_some)*
3119 })
3120 }
3121 }
3122 } else {
3123 quote! {
3124 #[doc(hidden)]
3125 pub async fn __new_with_traits(
3126 entity: #struct_name,
3127 #(#trait_params,)*
3128 ctx: #krate::entity::EntityContext,
3129 ) -> ::std::result::Result<Self, #krate::error::ClusterError> {
3130 #state_init_with_traits_code
3131 ::std::result::Result::Ok(Self {
3132 __state: ::std::sync::Arc::new(arc_swap::ArcSwap::from_pointee(state)),
3133 __write_lock: ::std::sync::Arc::new(tokio::sync::Mutex::new(())),
3134 __entity: entity,
3135 #durable_field_init
3136 ctx,
3137 #(#trait_field_init_some)*
3138 })
3139 }
3140 }
3141 }
3142 } else {
3143 quote! {}
3144 };
3145
3146 let save_composite_state_method = if has_traits && state_persisted {
3148 let composite_ref_name = format_ident!("{}CompositeStateRef", struct_name);
3149 let trait_state_loads: Vec<proc_macro2::TokenStream> = trait_infos
3150 .iter()
3151 .map(|info| {
3152 let field = &info.field;
3153 quote! {
3154 let #field = self.#field.as_ref().expect("trait field should be set");
3155 let #field = &**#field.__state_arc().load();
3156 }
3157 })
3158 .collect();
3159 let composite_field_refs: Vec<proc_macro2::TokenStream> = trait_infos
3160 .iter()
3161 .map(|info| {
3162 let field = &info.field;
3163 quote! { #field: #field }
3164 })
3165 .collect();
3166 quote! {
3167 #[doc(hidden)]
3169 async fn __save_composite_state(&self) -> ::std::result::Result<(), #krate::error::ClusterError> {
3170 if let Some(ref storage) = self.__state_storage {
3171 let entity_state = &**self.__state.load();
3172 #(#trait_state_loads)*
3173 let composite = #composite_ref_name {
3174 entity: entity_state,
3175 #(#composite_field_refs,)*
3176 };
3177 let bytes = rmp_serde::to_vec(&composite)
3178 .map_err(|e| #krate::error::ClusterError::PersistenceError {
3179 reason: ::std::format!("failed to serialize composite state: {e}"),
3180 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3181 })?;
3182 storage.save(&self.__state_key, &bytes).await?;
3183 }
3184 ::std::result::Result::Ok(())
3185 }
3186 }
3187 } else {
3188 quote! {}
3189 };
3190
3191 let trait_dispatch_checks: Vec<proc_macro2::TokenStream> = trait_infos
3192 .iter()
3193 .map(|info| {
3194 let field = &info.field;
3195 let save_after_dispatch = if state_persisted {
3197 quote! {
3198 self.__save_composite_state().await?;
3199 }
3200 } else {
3201 quote! {}
3202 };
3203 quote! {
3204 if let ::std::option::Option::Some(ref __trait) = self.#field {
3205 if let ::std::option::Option::Some(response) = __trait.__dispatch(tag, payload, headers).await? {
3206 #save_after_dispatch
3207 return ::std::result::Result::Ok(response);
3208 }
3209 }
3210 }
3211 })
3212 .collect();
3213
3214 let trait_dispatch_fallback = if has_traits {
3215 quote! {{
3216 #(#trait_dispatch_checks)*
3217 ::std::result::Result::Err(
3218 #krate::error::ClusterError::MalformedMessage {
3219 reason: ::std::format!("unknown RPC tag: {tag}"),
3220 source: ::std::option::Option::None,
3221 }
3222 )
3223 }}
3224 } else {
3225 quote! {{
3226 ::std::result::Result::Err(
3227 #krate::error::ClusterError::MalformedMessage {
3228 reason: ::std::format!("unknown RPC tag: {tag}"),
3229 source: ::std::option::Option::None,
3230 }
3231 )
3232 }}
3233 };
3234
3235 let register_impl = if has_traits {
3237 let register_trait_params: Vec<proc_macro2::TokenStream> = trait_infos
3239 .iter()
3240 .map(|info| {
3241 let param = &info.field;
3242 let path = &info.path;
3243 quote! { #param: #path }
3244 })
3245 .collect();
3246 let trait_with_calls: Vec<proc_macro2::TokenStream> = trait_infos
3247 .iter()
3248 .map(|info| {
3249 let field = &info.field;
3250 quote! { .with(#field) }
3251 })
3252 .collect();
3253 quote! {
3254 impl #struct_name {
3255 pub async fn register(
3260 self,
3261 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
3262 #(#register_trait_params),*
3263 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
3264 let entity_with_traits = self #(#trait_with_calls)*;
3265 sharding.register_entity(::std::sync::Arc::new(entity_with_traits)).await?;
3266 ::std::result::Result::Ok(#client_name::new(sharding))
3267 }
3268 }
3269 }
3270 } else {
3271 quote! {
3273 impl #struct_name {
3274 pub async fn register(
3279 self,
3280 sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>,
3281 ) -> ::std::result::Result<#client_name, #krate::error::ClusterError> {
3282 sharding.register_entity(::std::sync::Arc::new(self)).await?;
3283 ::std::result::Result::Ok(#client_name::new(sharding))
3284 }
3285 }
3286 }
3287 };
3288
3289 let with_traits_name = format_ident!("{}WithTraits", struct_name);
3290 let with_trait_trait_name = format_ident!("__{}WithTrait", struct_name);
3291 let trait_use_tokens: Vec<proc_macro2::TokenStream> = trait_infos
3292 .iter()
3293 .map(|info| {
3294 let path = &info.path;
3295 let ident = &info.ident;
3296 let methods_trait_ident = format_ident!("__{}TraitMethods", ident);
3297 let methods_trait_path = replace_last_segment(path, methods_trait_ident);
3298 quote! {
3299 #[allow(unused_imports)]
3300 use #methods_trait_path as _;
3301 }
3302 })
3303 .collect();
3304
3305 let trait_access_impls: Vec<proc_macro2::TokenStream> = trait_infos
3307 .iter()
3308 .map(|info| {
3309 let path = &info.path;
3310 let ident = &info.ident;
3311 let field = &info.field;
3312 let wrapper_path = &info.wrapper_path;
3313 let access_trait_ident = format_ident!("__{}TraitAccess", ident);
3314 let access_trait_path = replace_last_segment(path, access_trait_ident);
3315 quote! {
3316 impl #access_trait_path for #handler_name {
3317 fn __trait_ref(&self) -> ::std::option::Option<&::std::sync::Arc<#wrapper_path>> {
3318 self.#field.as_ref()
3319 }
3320 }
3321 }
3322 })
3323 .collect();
3324
3325 let with_traits_impl = if has_traits {
3326 let trait_option_fields: Vec<proc_macro2::TokenStream> = trait_infos
3327 .iter()
3328 .map(|info| {
3329 let field = &info.field;
3330 let path = &info.path;
3331 quote! { #field: ::std::option::Option<#path>, }
3332 })
3333 .collect();
3334 let trait_setters: Vec<proc_macro2::TokenStream> = trait_infos
3335 .iter()
3336 .map(|info| {
3337 let path = &info.path;
3338 let field = &info.field;
3339 quote! {
3340 impl #with_trait_trait_name<#path> for #with_traits_name {
3341 fn __with_trait(&mut self, value: #path) {
3342 self.#field = ::std::option::Option::Some(value);
3343 }
3344 }
3345 }
3346 })
3347 .collect();
3348
3349 let trait_missing_guards: Vec<proc_macro2::TokenStream> = trait_infos
3350 .iter()
3351 .map(|info| {
3352 let path = &info.path;
3353 let ident = &info.ident;
3354 let field = &info.field;
3355 let missing_reason = &info.missing_reason;
3356 let param = format_ident!("__trait_{}", to_snake(&ident.to_string()));
3357 quote! {
3358 let #param: #path = self.#field.clone().ok_or_else(|| {
3359 #krate::error::ClusterError::MalformedMessage {
3360 reason: ::std::string::String::from(#missing_reason),
3361 source: ::std::option::Option::None,
3362 }
3363 })?;
3364 }
3365 })
3366 .collect();
3367
3368 let trait_bounds: Vec<proc_macro2::TokenStream> = trait_infos
3369 .iter()
3370 .map(|info| {
3371 let path = &info.path;
3372 quote! { #path: ::std::clone::Clone }
3373 })
3374 .collect();
3375
3376 let trait_field_init_none_tokens = &trait_field_init_none;
3377
3378 quote! {
3379 #[doc(hidden)]
3380 pub struct #with_traits_name {
3381 entity: #struct_name,
3382 #(#trait_option_fields)*
3383 }
3384
3385 trait #with_trait_trait_name<T> {
3386 fn __with_trait(&mut self, value: T);
3387 }
3388
3389 impl #struct_name {
3390 pub fn with<T>(self, value: T) -> #with_traits_name
3391 where
3392 #with_traits_name: #with_trait_trait_name<T>,
3393 {
3394 let mut bundle = #with_traits_name {
3395 entity: self,
3396 #(#trait_field_init_none_tokens)*
3397 };
3398 bundle.__with_trait(value);
3399 bundle
3400 }
3401 }
3402
3403 impl #with_traits_name {
3404 pub fn with<T>(mut self, value: T) -> Self
3405 where
3406 Self: #with_trait_trait_name<T>,
3407 {
3408 self.__with_trait(value);
3409 self
3410 }
3411 }
3412
3413 #(#trait_setters)*
3414
3415 #[async_trait::async_trait]
3416 impl #krate::entity::Entity for #with_traits_name
3417 where
3418 #struct_name: ::std::clone::Clone,
3419 #(#trait_bounds,)*
3420 {
3421 fn entity_type(&self) -> #krate::types::EntityType {
3422 self.entity.__entity_type()
3423 }
3424
3425 fn shard_group(&self) -> &str {
3426 self.entity.__shard_group()
3427 }
3428
3429 fn shard_group_for(&self, entity_id: &#krate::types::EntityId) -> &str {
3430 self.entity.__shard_group_for(entity_id)
3431 }
3432
3433 fn max_idle_time(&self) -> ::std::option::Option<::std::time::Duration> {
3434 self.entity.__max_idle_time()
3435 }
3436
3437 fn mailbox_capacity(&self) -> ::std::option::Option<usize> {
3438 self.entity.__mailbox_capacity()
3439 }
3440
3441 fn concurrency(&self) -> ::std::option::Option<usize> {
3442 self.entity.__concurrency()
3443 }
3444
3445 async fn spawn(
3446 &self,
3447 ctx: #krate::entity::EntityContext,
3448 ) -> ::std::result::Result<
3449 ::std::boxed::Box<dyn #krate::entity::EntityHandler>,
3450 #krate::error::ClusterError,
3451 > {
3452 #(#trait_missing_guards)*
3453 let handler = #handler_name::__new_with_traits(
3454 self.entity.clone(),
3455 #(#trait_param_idents,)*
3456 ctx,
3457 )
3458 .await?;
3459 ::std::result::Result::Ok(::std::boxed::Box::new(handler))
3460 }
3461 }
3462 }
3463 } else {
3464 quote! {}
3465 };
3466
3467 Ok(quote! {
3468 #(#trait_use_tokens)*
3469
3470 impl #struct_name {
3472 #(#init_attrs)*
3473 #init_vis #init_sig #init_block
3474 }
3475
3476 #composite_state_defs
3477
3478 #with_traits_impl
3479 #entity_impl
3480
3481 #view_structs
3483
3484 #view_impls
3486
3487 #[doc(hidden)]
3489 pub struct #handler_name {
3490 __state: ::std::sync::Arc<arc_swap::ArcSwap<#state_type>>,
3492 __write_lock: ::std::sync::Arc<tokio::sync::Mutex<()>>,
3494 #[allow(dead_code)]
3496 __entity: #struct_name,
3497 #[allow(dead_code)]
3499 ctx: #krate::entity::EntityContext,
3500 #handler_storage_field
3501 #durable_field
3502 #sharding_handler_field
3503 #(#trait_field_defs)*
3504 }
3505
3506 impl #handler_name {
3507 #new_fn
3508 #new_with_traits_fn
3509
3510 #save_composite_state_method
3511
3512 #durable_builtin_impls
3513 #sharding_builtin_impls
3514
3515 #(#wrapper_methods)*
3517 }
3518
3519 #[async_trait::async_trait]
3520 impl #krate::entity::EntityHandler for #handler_name {
3521 async fn handle_request(
3522 &self,
3523 tag: &str,
3524 payload: &[u8],
3525 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
3526 ) -> ::std::result::Result<::std::vec::Vec<u8>, #krate::error::ClusterError> {
3527 let _ = headers;
3528 match tag {
3529 #(#dispatch_arms,)*
3530 _ => #trait_dispatch_fallback,
3531 }
3532 }
3533 }
3534
3535 #register_impl
3536
3537 pub struct #client_name {
3539 inner: #krate::entity_client::EntityClient,
3540 }
3541
3542 impl #client_name {
3543 pub fn new(sharding: ::std::sync::Arc<dyn #krate::sharding::Sharding>) -> Self {
3545 Self {
3546 inner: #krate::entity_client::EntityClient::new(
3547 sharding,
3548 #krate::types::EntityType::new(#struct_name_str),
3549 ),
3550 }
3551 }
3552
3553 pub fn inner(&self) -> &#krate::entity_client::EntityClient {
3555 &self.inner
3556 }
3557
3558 #(#client_methods)*
3559 }
3560
3561 impl #krate::entity_client::EntityClientAccessor for #client_name {
3562 fn entity_client(&self) -> &#krate::entity_client::EntityClient {
3563 &self.inner
3564 }
3565 }
3566
3567 #(#trait_access_impls)*
3568 })
3569}
3570
3571fn generate_dispatch_arms(
3572 krate: &syn::Path,
3573 rpcs: &[RpcMethod],
3574 stateful: bool,
3575 save_state_code: Option<&proc_macro2::TokenStream>,
3576 save_composite_state: bool,
3577) -> Vec<proc_macro2::TokenStream> {
3578 rpcs
3579 .iter()
3580 .filter(|rpc| rpc.is_dispatchable())
3581 .map(|rpc| {
3582 let tag = &rpc.tag;
3583 let method_name = &rpc.name;
3584 let param_count = rpc.params.len();
3585 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3586 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3587
3588 let deserialize_request = match param_count {
3589 0 => quote! {},
3590 1 => {
3591 let name = ¶m_names[0];
3592 let ty = ¶m_types[0];
3593 quote! {
3594 let #name: #ty = rmp_serde::from_slice(payload)
3595 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3596 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3597 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3598 })?;
3599 }
3600 }
3601 _ => quote! {
3602 let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
3603 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3604 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3605 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3606 })?;
3607 },
3608 };
3609
3610 let durable_ctx_code = if rpc.has_durable_context {
3612 quote! {
3613 let __durable_engine = self.__workflow_engine.as_ref().ok_or_else(|| {
3614 #krate::error::ClusterError::MalformedMessage {
3615 reason: ::std::format!("method '{}' requires a DurableContext but no workflow engine was provided", #tag),
3616 source: ::std::option::Option::None,
3617 }
3618 })?;
3619 let __durable_ctx = #krate::__internal::DurableContext::new(
3620 ::std::sync::Arc::clone(__durable_engine),
3621 self.ctx.address.entity_type.0.clone(),
3622 self.ctx.address.entity_id.0.clone(),
3623 );
3624 }
3625 } else {
3626 quote! {}
3627 };
3628
3629 let mut call_args = Vec::new();
3630 if rpc.has_durable_context {
3631 call_args.push(quote! { &__durable_ctx });
3632 }
3633 match param_count {
3634 0 => {}
3635 1 => {
3636 let name = ¶m_names[0];
3637 call_args.push(quote! { #name });
3638 }
3639 _ => {
3640 for name in ¶m_names {
3641 call_args.push(quote! { #name });
3642 }
3643 }
3644 }
3645 let call_args = quote! { #(#call_args),* };
3646
3647 let _ = save_state_code; let _ = rpc.is_mut; let post_call_save = if save_composite_state {
3653 quote! { self.__save_composite_state().await?; }
3654 } else {
3655 quote! {}
3656 };
3657
3658 if stateful {
3659 quote! {
3661 #tag => {
3662 #deserialize_request
3663 #durable_ctx_code
3664 let response = self.#method_name(#call_args).await?;
3665 #post_call_save
3666 rmp_serde::to_vec(&response)
3667 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3668 reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
3669 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3670 })
3671 }
3672 }
3673 } else {
3674 quote! {
3676 #tag => {
3677 #deserialize_request
3678 #durable_ctx_code
3679 let response = self.entity.#method_name(#call_args).await?;
3680 rmp_serde::to_vec(&response)
3681 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3682 reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
3683 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3684 })
3685 }
3686 }
3687 }
3688 })
3689 .collect()
3690}
3691
3692fn generate_client_methods(krate: &syn::Path, rpcs: &[RpcMethod]) -> Vec<proc_macro2::TokenStream> {
3693 rpcs.iter()
3694 .filter(|rpc| rpc.is_client_visible())
3695 .map(|rpc| {
3696 let method_name = &rpc.name;
3697 let tag = &rpc.tag;
3698 let resp_type = &rpc.response_type;
3699 let persist_key = rpc.persist_key.as_ref();
3700 let param_count = rpc.params.len();
3701 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3702 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3703 let param_defs: Vec<_> = param_names
3704 .iter()
3705 .zip(param_types.iter())
3706 .map(|(name, ty)| quote! { #name: &#ty })
3707 .collect();
3708 if rpc.kind.is_persisted() {
3709 match (persist_key, param_count) {
3710 (Some(persist_key), 0) => quote! {
3711 pub async fn #method_name(
3712 &self,
3713 entity_id: &#krate::types::EntityId,
3714 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3715 let key = (#persist_key)();
3716 let key_bytes = rmp_serde::to_vec(&key)
3717 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3718 reason: ::std::format!(
3719 "failed to serialize persist key for '{}': {e}",
3720 #tag
3721 ),
3722 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3723 })?;
3724 self.inner
3725 .send_persisted_with_key(
3726 entity_id,
3727 #tag,
3728 &(),
3729 ::std::option::Option::Some(key_bytes),
3730 #krate::schema::Uninterruptible::No,
3731 )
3732 .await
3733 }
3734 },
3735 (Some(persist_key), 1) => {
3736 let name = ¶m_names[0];
3737 let def = ¶m_defs[0];
3738 quote! {
3739 pub async fn #method_name(
3740 &self,
3741 entity_id: &#krate::types::EntityId,
3742 #def,
3743 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3744 let key = (#persist_key)(#name);
3745 let key_bytes = rmp_serde::to_vec(&key)
3746 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3747 reason: ::std::format!(
3748 "failed to serialize persist key for '{}': {e}",
3749 #tag
3750 ),
3751 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3752 })?;
3753 self.inner
3754 .send_persisted_with_key(
3755 entity_id,
3756 #tag,
3757 #name,
3758 ::std::option::Option::Some(key_bytes),
3759 #krate::schema::Uninterruptible::No,
3760 )
3761 .await
3762 }
3763 }
3764 }
3765 (Some(persist_key), _) => quote! {
3766 pub async fn #method_name(
3767 &self,
3768 entity_id: &#krate::types::EntityId,
3769 #(#param_defs),*
3770 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3771 let key = (#persist_key)(#(#param_names),*);
3772 let key_bytes = rmp_serde::to_vec(&key)
3773 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3774 reason: ::std::format!(
3775 "failed to serialize persist key for '{}': {e}",
3776 #tag
3777 ),
3778 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3779 })?;
3780 let request = (#(#param_names),*);
3781 self.inner
3782 .send_persisted_with_key(
3783 entity_id,
3784 #tag,
3785 &request,
3786 ::std::option::Option::Some(key_bytes),
3787 #krate::schema::Uninterruptible::No,
3788 )
3789 .await
3790 }
3791 },
3792 (None, 0) => quote! {
3793 pub async fn #method_name(
3794 &self,
3795 entity_id: &#krate::types::EntityId,
3796 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3797 self.inner
3798 .send_persisted(entity_id, #tag, &(), #krate::schema::Uninterruptible::No)
3799 .await
3800 }
3801 },
3802 (None, 1) => {
3803 let name = ¶m_names[0];
3804 let def = ¶m_defs[0];
3805 quote! {
3806 pub async fn #method_name(
3807 &self,
3808 entity_id: &#krate::types::EntityId,
3809 #def,
3810 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3811 self.inner
3812 .send_persisted(entity_id, #tag, #name, #krate::schema::Uninterruptible::No)
3813 .await
3814 }
3815 }
3816 }
3817 (None, _) => quote! {
3818 pub async fn #method_name(
3819 &self,
3820 entity_id: &#krate::types::EntityId,
3821 #(#param_defs),*
3822 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3823 let request = (#(#param_names),*);
3824 self.inner
3825 .send_persisted(entity_id, #tag, &request, #krate::schema::Uninterruptible::No)
3826 .await
3827 }
3828 },
3829 }
3830 } else {
3831 match param_count {
3832 0 => quote! {
3833 pub async fn #method_name(
3834 &self,
3835 entity_id: &#krate::types::EntityId,
3836 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3837 self.inner.send(entity_id, #tag, &()).await
3838 }
3839 },
3840 1 => {
3841 let def = ¶m_defs[0];
3842 let name = ¶m_names[0];
3843 quote! {
3844 pub async fn #method_name(
3845 &self,
3846 entity_id: &#krate::types::EntityId,
3847 #def,
3848 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3849 self.inner.send(entity_id, #tag, #name).await
3850 }
3851 }
3852 }
3853 _ => quote! {
3854 pub async fn #method_name(
3855 &self,
3856 entity_id: &#krate::types::EntityId,
3857 #(#param_defs),*
3858 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
3859 let request = (#(#param_names),*);
3860 self.inner.send(entity_id, #tag, &request).await
3861 }
3862 },
3863 }
3864 }
3865 })
3866 .collect()
3867}
3868
3869fn generate_method_impls(original_methods: &[syn::ImplItemFn]) -> Vec<proc_macro2::TokenStream> {
3870 original_methods
3871 .iter()
3872 .map(|m| {
3873 let sig = &m.sig;
3874 let block = &m.block;
3875 let attrs: Vec<_> = m
3877 .attrs
3878 .iter()
3879 .filter(|a| {
3880 !a.path().is_ident("rpc")
3881 && !a.path().is_ident("workflow")
3882 && !a.path().is_ident("activity")
3883 && !a.path().is_ident("method")
3884 && !a.path().is_ident("public")
3885 && !a.path().is_ident("protected")
3886 && !a.path().is_ident("private")
3887 })
3888 .collect();
3889 let vis = &m.vis;
3890 quote! {
3891 #(#attrs)*
3892 #vis #sig #block
3893 }
3894 })
3895 .collect()
3896}
3897
3898fn generate_trait_dispatch_impl(
3899 krate: &syn::Path,
3900 trait_ident: &syn::Ident,
3901 rpcs: &[RpcMethod],
3902) -> proc_macro2::TokenStream {
3903 let dispatch_arms: Vec<_> = rpcs
3904 .iter()
3905 .filter(|rpc| rpc.is_dispatchable())
3906 .map(|rpc| {
3907 let tag = &rpc.tag;
3908 let method_name = &rpc.name;
3909 let param_count = rpc.params.len();
3910 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
3911 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
3912
3913 let deserialize_request = match param_count {
3914 0 => quote! {},
3915 1 => {
3916 let name = ¶m_names[0];
3917 let ty = ¶m_types[0];
3918 quote! {
3919 let #name: #ty = rmp_serde::from_slice(payload)
3920 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3921 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3922 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3923 })?;
3924 }
3925 }
3926 _ => quote! {
3927 let (#(#param_names),*): (#(#param_types),*) = rmp_serde::from_slice(payload)
3928 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3929 reason: ::std::format!("failed to deserialize request for '{}': {e}", #tag),
3930 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3931 })?;
3932 },
3933 };
3934
3935 let mut call_args = Vec::new();
3936 match param_count {
3937 0 => {}
3938 1 => {
3939 let name = ¶m_names[0];
3940 call_args.push(quote! { #name });
3941 }
3942 _ => {
3943 for name in ¶m_names {
3944 call_args.push(quote! { #name });
3945 }
3946 }
3947 }
3948 let call_args = quote! { #(#call_args),* };
3949
3950 quote! {
3952 #tag => {
3953 #deserialize_request
3954 let response = self.#method_name(#call_args).await?;
3955 let bytes = rmp_serde::to_vec(&response)
3956 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
3957 reason: ::std::format!("failed to serialize response for '{}': {e}", #tag),
3958 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
3959 })?;
3960 ::std::result::Result::Ok(::std::option::Option::Some(bytes))
3961 }
3962 }
3963 })
3964 .collect();
3965
3966 quote! {
3967 impl #trait_ident {
3968 #[doc(hidden)]
3969 pub async fn __dispatch(
3970 &self,
3971 tag: &str,
3972 payload: &[u8],
3973 headers: &::std::collections::HashMap<::std::string::String, ::std::string::String>,
3974 ) -> ::std::result::Result<::std::option::Option<::std::vec::Vec<u8>>, #krate::error::ClusterError> {
3975 let _ = headers;
3976 match tag {
3977 #(#dispatch_arms,)*
3978 _ => ::std::result::Result::Ok(::std::option::Option::None),
3979 }
3980 }
3981 }
3982 }
3983}
3984
3985fn generate_trait_client_ext(
3986 krate: &syn::Path,
3987 trait_ident: &syn::Ident,
3988 rpcs: &[RpcMethod],
3989) -> proc_macro2::TokenStream {
3990 let client_methods: Vec<_> = rpcs
3991 .iter()
3992 .filter(|rpc| rpc.is_client_visible())
3993 .map(|rpc| {
3994 let method_name = &rpc.name;
3995 let tag = &rpc.tag;
3996 let resp_type = &rpc.response_type;
3997 let persist_key = rpc.persist_key.as_ref();
3998 let param_count = rpc.params.len();
3999 let param_names: Vec<_> = rpc.params.iter().map(|p| &p.name).collect();
4000 let param_types: Vec<_> = rpc.params.iter().map(|p| &p.ty).collect();
4001 let param_defs: Vec<_> = param_names
4002 .iter()
4003 .zip(param_types.iter())
4004 .map(|(name, ty)| quote! { #name: &#ty })
4005 .collect();
4006
4007 if rpc.kind.is_persisted() {
4008 match (persist_key, param_count) {
4009 (Some(persist_key), 0) => quote! {
4010 async fn #method_name(
4011 &self,
4012 entity_id: &#krate::types::EntityId,
4013 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4014 let key = (#persist_key)();
4015 let key_bytes = rmp_serde::to_vec(&key)
4016 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
4017 reason: ::std::format!(
4018 "failed to serialize persist key for '{}': {e}",
4019 #tag
4020 ),
4021 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4022 })?;
4023 self.entity_client()
4024 .send_persisted_with_key(
4025 entity_id,
4026 #tag,
4027 &(),
4028 ::std::option::Option::Some(key_bytes),
4029 #krate::schema::Uninterruptible::No,
4030 )
4031 .await
4032 }
4033 },
4034 (Some(persist_key), 1) => {
4035 let name = ¶m_names[0];
4036 let def = ¶m_defs[0];
4037 quote! {
4038 async fn #method_name(
4039 &self,
4040 entity_id: &#krate::types::EntityId,
4041 #def,
4042 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4043 let key = (#persist_key)(#name);
4044 let key_bytes = rmp_serde::to_vec(&key)
4045 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
4046 reason: ::std::format!(
4047 "failed to serialize persist key for '{}': {e}",
4048 #tag
4049 ),
4050 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4051 })?;
4052 self.entity_client()
4053 .send_persisted_with_key(
4054 entity_id,
4055 #tag,
4056 #name,
4057 ::std::option::Option::Some(key_bytes),
4058 #krate::schema::Uninterruptible::No,
4059 )
4060 .await
4061 }
4062 }
4063 }
4064 (Some(persist_key), _) => quote! {
4065 async fn #method_name(
4066 &self,
4067 entity_id: &#krate::types::EntityId,
4068 #(#param_defs),*
4069 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4070 let key = (#persist_key)(#(#param_names),*);
4071 let key_bytes = rmp_serde::to_vec(&key)
4072 .map_err(|e| #krate::error::ClusterError::MalformedMessage {
4073 reason: ::std::format!(
4074 "failed to serialize persist key for '{}': {e}",
4075 #tag
4076 ),
4077 source: ::std::option::Option::Some(::std::boxed::Box::new(e)),
4078 })?;
4079 let request = (#(#param_names),*);
4080 self.entity_client()
4081 .send_persisted_with_key(
4082 entity_id,
4083 #tag,
4084 &request,
4085 ::std::option::Option::Some(key_bytes),
4086 #krate::schema::Uninterruptible::No,
4087 )
4088 .await
4089 }
4090 },
4091 (None, 0) => quote! {
4092 async fn #method_name(
4093 &self,
4094 entity_id: &#krate::types::EntityId,
4095 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4096 self.entity_client()
4097 .send_persisted(entity_id, #tag, &(), #krate::schema::Uninterruptible::No)
4098 .await
4099 }
4100 },
4101 (None, 1) => {
4102 let def = ¶m_defs[0];
4103 let name = ¶m_names[0];
4104 quote! {
4105 async fn #method_name(
4106 &self,
4107 entity_id: &#krate::types::EntityId,
4108 #def,
4109 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4110 self.entity_client()
4111 .send_persisted(entity_id, #tag, #name, #krate::schema::Uninterruptible::No)
4112 .await
4113 }
4114 }
4115 }
4116 (None, _) => quote! {
4117 async fn #method_name(
4118 &self,
4119 entity_id: &#krate::types::EntityId,
4120 #(#param_defs),*
4121 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4122 let request = (#(#param_names),*);
4123 self.entity_client()
4124 .send_persisted(entity_id, #tag, &request, #krate::schema::Uninterruptible::No)
4125 .await
4126 }
4127 },
4128 }
4129 } else {
4130 match param_count {
4131 0 => quote! {
4132 async fn #method_name(
4133 &self,
4134 entity_id: &#krate::types::EntityId,
4135 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4136 self.entity_client().send(entity_id, #tag, &()).await
4137 }
4138 },
4139 1 => {
4140 let def = ¶m_defs[0];
4141 let name = ¶m_names[0];
4142 quote! {
4143 async fn #method_name(
4144 &self,
4145 entity_id: &#krate::types::EntityId,
4146 #def,
4147 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4148 self.entity_client().send(entity_id, #tag, #name).await
4149 }
4150 }
4151 }
4152 _ => quote! {
4153 async fn #method_name(
4154 &self,
4155 entity_id: &#krate::types::EntityId,
4156 #(#param_defs),*
4157 ) -> ::std::result::Result<#resp_type, #krate::error::ClusterError> {
4158 let request = (#(#param_names),*);
4159 self.entity_client().send(entity_id, #tag, &request).await
4160 }
4161 },
4162 }
4163 }
4164 })
4165 .collect();
4166
4167 let client_ext_name = format_ident!("{}ClientExt", trait_ident);
4168 quote! {
4169 #[async_trait::async_trait]
4170 pub trait #client_ext_name: #krate::entity_client::EntityClientAccessor {
4171 #(#client_methods)*
4172 }
4173
4174 impl<T> #client_ext_name for T where T: #krate::entity_client::EntityClientAccessor {}
4175 }
4176}
4177
4178fn is_durable_context_type(ty: &syn::Type) -> bool {
4180 match ty {
4181 syn::Type::Reference(r) => is_durable_context_type(&r.elem),
4182 syn::Type::Path(tp) => tp
4183 .path
4184 .segments
4185 .last()
4186 .map(|s| s.ident == "DurableContext")
4187 .unwrap_or(false),
4188 _ => false,
4189 }
4190}
4191
4192struct StateArgs {
4193 ty: syn::Type,
4194 persistent: bool,
4195}
4196
4197impl syn::parse::Parse for StateArgs {
4198 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
4199 let ty: syn::Type = input.parse()?;
4200
4201 if !input.is_empty() {
4206 return Err(syn::Error::new(
4207 input.span(),
4208 "unexpected tokens in #[state(...)]; state is always persistent",
4209 ));
4210 }
4211
4212 Ok(StateArgs {
4213 ty,
4214 persistent: true,
4215 })
4216 }
4217}
4218
4219fn parse_state_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<StateArgs>> {
4220 let mut state_attr: Option<StateArgs> = None;
4221 let mut i = 0;
4222 while i < attrs.len() {
4223 if attrs[i].path().is_ident("state") {
4224 if state_attr.is_some() {
4225 return Err(syn::Error::new(
4226 attrs[i].span(),
4227 "duplicate #[state(...)] attribute",
4228 ));
4229 }
4230 let args = attrs[i].parse_args::<StateArgs>()?;
4231 state_attr = Some(args);
4232 attrs.remove(i);
4233 continue;
4234 }
4235 i += 1;
4236 }
4237 Ok(state_attr)
4238}
4239
4240struct KeyArgs {
4241 key: Option<syn::ExprClosure>,
4242}
4243
4244impl syn::parse::Parse for KeyArgs {
4245 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
4246 if input.is_empty() {
4247 return Ok(KeyArgs { key: None });
4248 }
4249
4250 let ident: syn::Ident = input.parse()?;
4251 if ident != "key" {
4252 return Err(syn::Error::new(
4253 ident.span(),
4254 "expected `key` in #[workflow(key(...))] or #[activity(key(...))]",
4255 ));
4256 }
4257
4258 if input.peek(syn::Token![=]) {
4259 input.parse::<syn::Token![=]>()?;
4260 }
4261
4262 let expr: syn::Expr = if input.peek(syn::token::Paren) {
4263 let content;
4264 syn::parenthesized!(content in input);
4265 content.parse()?
4266 } else {
4267 input.parse()?
4268 };
4269
4270 if !input.is_empty() {
4271 return Err(syn::Error::new(
4272 input.span(),
4273 "unexpected tokens in #[workflow(...)] or #[activity(...)]",
4274 ));
4275 }
4276
4277 match expr {
4278 syn::Expr::Closure(closure) => Ok(KeyArgs { key: Some(closure) }),
4279 _ => Err(syn::Error::new(
4280 expr.span(),
4281 "key must be a closure, e.g. #[workflow(key(|req| ...))]",
4282 )),
4283 }
4284 }
4285}
4286
4287fn parse_kind_attr(
4288 attrs: &[syn::Attribute],
4289) -> syn::Result<Option<(RpcKind, Option<syn::ExprClosure>)>> {
4290 let mut kind: Option<RpcKind> = None;
4291 let mut key: Option<syn::ExprClosure> = None;
4292
4293 for attr in attrs {
4294 if attr.path().is_ident("rpc") {
4295 if kind.is_some() {
4296 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
4297 }
4298 match &attr.meta {
4299 syn::Meta::Path(_) => {
4300 kind = Some(RpcKind::Rpc);
4301 }
4302 _ => {
4303 return Err(syn::Error::new(
4304 attr.span(),
4305 "#[rpc] does not take arguments",
4306 ))
4307 }
4308 }
4309 }
4310
4311 if attr.path().is_ident("workflow") {
4312 if kind.is_some() {
4313 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
4314 }
4315 let args = match &attr.meta {
4316 syn::Meta::Path(_) => KeyArgs { key: None },
4317 syn::Meta::List(_) => attr.parse_args::<KeyArgs>()?,
4318 syn::Meta::NameValue(_) => {
4319 return Err(syn::Error::new(
4320 attr.span(),
4321 "expected #[workflow] or #[workflow(key(...))]",
4322 ))
4323 }
4324 };
4325 kind = Some(RpcKind::Workflow);
4326 if args.key.is_some() {
4327 key = args.key;
4328 }
4329 }
4330
4331 if attr.path().is_ident("activity") {
4332 if kind.is_some() {
4333 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
4334 }
4335 let args = match &attr.meta {
4336 syn::Meta::Path(_) => KeyArgs { key: None },
4337 syn::Meta::List(_) => attr.parse_args::<KeyArgs>()?,
4338 syn::Meta::NameValue(_) => {
4339 return Err(syn::Error::new(
4340 attr.span(),
4341 "expected #[activity] or #[activity(key(...))]",
4342 ))
4343 }
4344 };
4345 kind = Some(RpcKind::Activity);
4346 if args.key.is_some() {
4347 key = args.key;
4348 }
4349 }
4350
4351 if attr.path().is_ident("method") {
4352 if kind.is_some() {
4353 return Err(syn::Error::new(attr.span(), "multiple RPC kind attributes"));
4354 }
4355 match &attr.meta {
4356 syn::Meta::Path(_) => {
4357 kind = Some(RpcKind::Method);
4358 }
4359 _ => {
4360 return Err(syn::Error::new(
4361 attr.span(),
4362 "#[method] does not take arguments",
4363 ))
4364 }
4365 }
4366 }
4367 }
4368
4369 Ok(kind.map(|kind| (kind, key)))
4370}
4371
4372fn parse_visibility_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<RpcVisibility>> {
4373 let mut visibility: Option<RpcVisibility> = None;
4374
4375 for attr in attrs {
4376 let next = if attr.path().is_ident("public") {
4377 Some(RpcVisibility::Public)
4378 } else if attr.path().is_ident("protected") {
4379 Some(RpcVisibility::Protected)
4380 } else if attr.path().is_ident("private") {
4381 Some(RpcVisibility::Private)
4382 } else {
4383 None
4384 };
4385
4386 if let Some(next) = next {
4387 match &attr.meta {
4388 syn::Meta::Path(_) => {}
4389 _ => {
4390 return Err(syn::Error::new(
4391 attr.span(),
4392 "visibility attributes do not take arguments",
4393 ))
4394 }
4395 }
4396 if visibility.is_some() {
4397 return Err(syn::Error::new(
4398 attr.span(),
4399 "multiple visibility modifiers are not allowed",
4400 ));
4401 }
4402 visibility = Some(next);
4403 }
4404 }
4405
4406 Ok(visibility)
4407}
4408
4409fn parse_rpc_method(method: &syn::ImplItemFn) -> syn::Result<Option<RpcMethod>> {
4410 let name = method.sig.ident.clone();
4411 let tag = name.to_string();
4412
4413 let kind_info = parse_kind_attr(&method.attrs)?;
4414 let visibility_attr = parse_visibility_attr(&method.attrs)?;
4415
4416 let (kind, persist_key) = match kind_info {
4417 Some(info) => info,
4418 None => {
4419 if visibility_attr.is_some() {
4420 return Err(syn::Error::new(
4421 method.sig.span(),
4422 "visibility modifiers require #[rpc], #[workflow], #[activity], or #[method]",
4423 ));
4424 }
4425 return Ok(None);
4426 }
4427 };
4428
4429 if method.sig.asyncness.is_none() && !matches!(kind, RpcKind::Method) {
4431 return Err(syn::Error::new(
4432 method.sig.span(),
4433 "#[rpc]/#[workflow]/#[activity] can only be applied to async methods",
4434 ));
4435 }
4436
4437 if matches!(kind, RpcKind::Rpc | RpcKind::Method) && persist_key.is_some() {
4438 return Err(syn::Error::new(
4439 method.sig.span(),
4440 "#[rpc] and #[method] do not support key(...) — use #[workflow(key(...))] or #[activity(key(...))]",
4441 ));
4442 }
4443
4444 let visibility = match (kind, visibility_attr) {
4445 (_, Some(RpcVisibility::Public)) if matches!(kind, RpcKind::Activity | RpcKind::Method) => {
4447 return Err(syn::Error::new(
4448 method.sig.span(),
4449 "#[activity] and #[method] cannot be #[public]",
4450 ))
4451 }
4452 (RpcKind::Activity | RpcKind::Method, None) => RpcVisibility::Private,
4453 (RpcKind::Rpc | RpcKind::Workflow, None) => RpcVisibility::Public,
4454 (_, Some(vis)) => vis,
4455 };
4456
4457 let is_mut = method
4459 .sig
4460 .inputs
4461 .first()
4462 .map(|arg| match arg {
4463 syn::FnArg::Receiver(r) => r.mutability.is_some(),
4464 _ => false,
4465 })
4466 .unwrap_or(false);
4467
4468 if is_mut && !matches!(kind, RpcKind::Activity) {
4470 return Err(syn::Error::new(
4471 method.sig.span(),
4472 "only #[activity] methods can use `&mut self` for state mutation; use `&self` for read-only access",
4473 ));
4474 }
4475
4476 let mut params = Vec::new();
4477 let mut has_durable_context = false;
4478 let mut saw_non_ctx_param = false;
4479 let mut param_index = 0usize;
4480 for arg in method.sig.inputs.iter().skip(1) {
4481 match arg {
4482 syn::FnArg::Typed(pat_type) => {
4483 if is_durable_context_type(&pat_type.ty) {
4484 if has_durable_context {
4485 return Err(syn::Error::new(
4486 arg.span(),
4487 "duplicate DurableContext parameter",
4488 ));
4489 }
4490 if saw_non_ctx_param {
4491 return Err(syn::Error::new(
4492 arg.span(),
4493 "DurableContext must be the first parameter after &self",
4494 ));
4495 }
4496 has_durable_context = true;
4497 continue; }
4499 saw_non_ctx_param = true;
4500 let name = match &*pat_type.pat {
4501 syn::Pat::Ident(ident) => ident.ident.clone(),
4502 syn::Pat::Wild(_) => {
4503 let ident = format_ident!("__arg{param_index}");
4504 ident
4505 }
4506 _ => {
4507 return Err(syn::Error::new(
4508 pat_type.pat.span(),
4509 "entity RPC parameters must be simple identifiers",
4510 ))
4511 }
4512 };
4513 param_index += 1;
4514 params.push(RpcParam {
4515 name,
4516 ty: (*pat_type.ty).clone(),
4517 });
4518 }
4519 syn::FnArg::Receiver(_) => {}
4520 }
4521 }
4522
4523 if has_durable_context && matches!(kind, RpcKind::Rpc | RpcKind::Method) {
4525 return Err(syn::Error::new(
4526 method.sig.span(),
4527 "methods with `&DurableContext` must be marked #[workflow] or #[activity]",
4528 ));
4529 }
4530
4531 let response_type = match &method.sig.output {
4533 syn::ReturnType::Type(_, ty) => {
4534 if matches!(kind, RpcKind::Method) {
4535 (**ty).clone()
4537 } else {
4538 extract_result_ok_type(ty)?
4539 }
4540 }
4541 syn::ReturnType::Default => {
4542 if matches!(kind, RpcKind::Method) {
4543 syn::parse_quote!(())
4545 } else {
4546 return Err(syn::Error::new(
4547 method.sig.span(),
4548 "entity RPC methods must return Result<T, ClusterError>",
4549 ));
4550 }
4551 }
4552 };
4553
4554 Ok(Some(RpcMethod {
4555 name,
4556 tag,
4557 params,
4558 response_type,
4559 is_mut,
4560 kind,
4561 visibility,
4562 persist_key,
4563 has_durable_context,
4564 }))
4565}
4566
/// Converts an identifier-like string to `snake_case`.
///
/// An underscore is inserted before an uppercase character when the previous
/// character was lowercase, or when the previous character was uppercase and
/// the next one is lowercase (so `HTTPServer` becomes `http_server`).
/// Uppercase characters are lowercased; other alphanumeric characters and `_`
/// are copied unchanged; every remaining character is dropped.
fn to_snake(input: &str) -> String {
    let mut snake = String::new();
    let mut after_upper = false;
    let mut after_lower = false;

    let mut rest = input.chars().peekable();
    while let Some(ch) = rest.next() {
        let upper = ch.is_uppercase();

        if upper {
            let before_lower = rest.peek().map_or(false, |next| next.is_lowercase());
            if after_lower || (after_upper && before_lower) {
                snake.push('_');
            }
            snake.extend(ch.to_lowercase());
        } else if ch.is_alphanumeric() || ch == '_' {
            snake.push(ch);
        }

        after_upper = upper;
        after_lower = ch.is_lowercase();
    }

    snake
}
4593
4594fn extract_result_ok_type(ty: &syn::Type) -> syn::Result<syn::Type> {
4595 if let syn::Type::Path(type_path) = ty {
4596 if let Some(segment) = type_path.path.segments.last() {
4597 if segment.ident == "Result" {
4598 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
4599 if let Some(syn::GenericArgument::Type(ok_type)) = args.args.first() {
4600 return Ok(ok_type.clone());
4601 }
4602 }
4603 }
4604 }
4605 }
4606 Err(syn::Error::new(
4607 ty.span(),
4608 "expected Result<T, ClusterError> return type",
4609 ))
4610}