1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, parse_macro_input};
4
5#[proc_macro_attribute]
41pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream {
42 let input_fn = parse_macro_input!(item as ItemFn);
43 let fn_name = &input_fn.sig.ident;
44 let fn_name_str = fn_name.to_string();
45 let vis = &input_fn.vis;
46 let sig = &input_fn.sig;
47 let body = &input_fn.block;
48 let attrs = &input_fn.attrs;
49
50 let is_ctx_type = |ty: &syn::Type| -> bool {
51 matches!(ty, syn::Type::Path(tp) if tp.path.segments.last().is_some_and(|s| s.ident == "Ctx"))
52 };
53
54 let first_is_ctx = input_fn.sig.inputs.len() >= 1
56 && matches!(&input_fn.sig.inputs[0], syn::FnArg::Typed(pat) if is_ctx_type(&pat.ty));
57
58 let registration = if first_is_ctx && input_fn.sig.inputs.len() == 1 {
59 let start_fn_name = syn::Ident::new(
61 &format!("durable_{}", fn_name),
62 fn_name.span(),
63 );
64 quote! {
65 ::durable::inventory::submit! {
66 ::durable::WorkflowRegistration {
67 name: #fn_name_str,
68 resume_fn: |ctx| ::std::boxed::Box::pin(async move {
69 let _ = #fn_name(ctx).await?;
70 Ok(())
71 }),
72 }
73 }
74
75 #vis async fn #start_fn_name(
79 db: &::sea_orm::DatabaseConnection,
80 name: &str,
81 ) -> ::std::result::Result<::durable::Ctx, ::durable::DurableError> {
82 ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::None, ::std::option::Option::Some(#fn_name_str)).await
83 }
84 }
85 } else if first_is_ctx && input_fn.sig.inputs.len() == 2 {
86 let input_type = match &input_fn.sig.inputs[1] {
88 syn::FnArg::Typed(pat) => &pat.ty,
89 _ => panic!("#[durable::workflow] second parameter must be a typed argument"),
90 };
91 let start_fn_name = syn::Ident::new(
92 &format!("durable_{}", fn_name),
93 fn_name.span(),
94 );
95 quote! {
96 ::durable::inventory::submit! {
97 ::durable::WorkflowRegistration {
98 name: #fn_name_str,
99 resume_fn: |ctx| ::std::boxed::Box::pin(async move {
100 let input: #input_type = ctx.input().await?;
101 let _ = #fn_name(ctx, input).await?;
102 Ok(())
103 }),
104 }
105 }
106
107 #vis async fn #start_fn_name(
111 db: &::sea_orm::DatabaseConnection,
112 name: &str,
113 input: #input_type,
114 ) -> ::std::result::Result<::durable::Ctx, ::durable::DurableError> {
115 let input_json = ::serde_json::to_value(&input)
116 .map_err(|e| ::durable::DurableError::custom(format!("failed to serialize workflow input: {e}")))?;
117 ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::Some(input_json), ::std::option::Option::Some(#fn_name_str)).await
118 }
119 }
120 } else {
121 quote! {}
122 };
123
124 let expanded = quote! {
125 #(#attrs)*
126 #vis #sig {
127 let _workflow_name = #fn_name_str;
128 tracing::info!(workflow = _workflow_name, "workflow started");
129 let _result: Result<_, _> = async { #body }.await;
130 match &_result {
131 Ok(_) => tracing::info!(workflow = _workflow_name, "workflow completed"),
132 Err(e) => tracing::error!(workflow = _workflow_name, error = %e, "workflow failed"),
133 }
134 _result
135 }
136
137 #registration
138 };
139
140 TokenStream::from(expanded)
141}
142
143#[proc_macro_attribute]
155pub fn step(_attr: TokenStream, item: TokenStream) -> TokenStream {
156 let input_fn = parse_macro_input!(item as ItemFn);
157 let fn_name = &input_fn.sig.ident;
158 let fn_name_str = fn_name.to_string();
159 let vis = &input_fn.vis;
160 let sig = &input_fn.sig;
161 let body = &input_fn.block;
162 let attrs = &input_fn.attrs;
163
164 let expanded = quote! {
165 #(#attrs)*
166 #vis #sig {
167 let _step_name = #fn_name_str;
168 tracing::debug!(step = _step_name, "step executing");
169 let _result: Result<_, _> = async { #body }.await;
170 match &_result {
171 Ok(_) => tracing::debug!(step = _step_name, "step completed"),
172 Err(e) => tracing::warn!(step = _step_name, error = %e, "step failed"),
173 }
174 _result
175 }
176 };
177
178 TokenStream::from(expanded)
179}