1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, parse_macro_input};
4
5#[proc_macro_attribute]
41pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream {
42 let input_fn = parse_macro_input!(item as ItemFn);
43 let fn_name = &input_fn.sig.ident;
44 let fn_name_str = fn_name.to_string();
45 let vis = &input_fn.vis;
46 let sig = &input_fn.sig;
47 let body = &input_fn.block;
48 let attrs = &input_fn.attrs;
49
50 let is_ctx_type = |ty: &syn::Type| -> bool {
51 matches!(ty, syn::Type::Path(tp) if tp.path.segments.last().is_some_and(|s| s.ident == "Ctx"))
52 };
53
54 let first_is_ctx = !input_fn.sig.inputs.is_empty()
56 && matches!(&input_fn.sig.inputs[0], syn::FnArg::Typed(pat) if is_ctx_type(&pat.ty));
57
58 let registration = if first_is_ctx && input_fn.sig.inputs.len() == 1 {
59 let start_fn_name = syn::Ident::new(&format!("durable_{}", fn_name), fn_name.span());
61 quote! {
62 ::durable::inventory::submit! {
63 ::durable::WorkflowRegistration {
64 name: #fn_name_str,
65 resume_fn: |ctx| ::std::boxed::Box::pin(async move {
66 let _ = #fn_name(ctx).await?;
67 Ok(())
68 }),
69 }
70 }
71
72 #vis async fn #start_fn_name(
76 db: &::sea_orm::DatabaseConnection,
77 name: &str,
78 ) -> ::std::result::Result<::durable::StartResult, ::durable::DurableError> {
79 ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::None, ::std::option::Option::Some(#fn_name_str)).await
80 }
81 }
82 } else if first_is_ctx && input_fn.sig.inputs.len() == 2 {
83 let input_type = match &input_fn.sig.inputs[1] {
85 syn::FnArg::Typed(pat) => &pat.ty,
86 _ => panic!("#[durable::workflow] second parameter must be a typed argument"),
87 };
88 let start_fn_name = syn::Ident::new(&format!("durable_{}", fn_name), fn_name.span());
89 quote! {
90 ::durable::inventory::submit! {
91 ::durable::WorkflowRegistration {
92 name: #fn_name_str,
93 resume_fn: |ctx| ::std::boxed::Box::pin(async move {
94 let input: #input_type = ctx.input().await?;
95 let _ = #fn_name(ctx, input).await?;
96 Ok(())
97 }),
98 }
99 }
100
101 #vis async fn #start_fn_name(
105 db: &::sea_orm::DatabaseConnection,
106 name: &str,
107 input: #input_type,
108 ) -> ::std::result::Result<::durable::StartResult, ::durable::DurableError> {
109 let input_json = ::serde_json::to_value(&input)
110 .map_err(|e| ::durable::DurableError::custom(format!("failed to serialize workflow input: {e}")))?;
111 ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::Some(input_json), ::std::option::Option::Some(#fn_name_str)).await
112 }
113 }
114 } else {
115 quote! {}
116 };
117
118 let expanded = quote! {
119 #(#attrs)*
120 #vis #sig {
121 let _workflow_name = #fn_name_str;
122 tracing::info!(workflow = _workflow_name, "workflow started");
123 let _result: Result<_, _> = async { #body }.await;
124 match &_result {
125 Ok(_) => tracing::info!(workflow = _workflow_name, "workflow completed"),
126 Err(e) => tracing::error!(workflow = _workflow_name, error = %e, "workflow failed"),
127 }
128 _result
129 }
130
131 #registration
132 };
133
134 TokenStream::from(expanded)
135}
136
137#[proc_macro_attribute]
149pub fn step(_attr: TokenStream, item: TokenStream) -> TokenStream {
150 let input_fn = parse_macro_input!(item as ItemFn);
151 let fn_name = &input_fn.sig.ident;
152 let fn_name_str = fn_name.to_string();
153 let vis = &input_fn.vis;
154 let sig = &input_fn.sig;
155 let body = &input_fn.block;
156 let attrs = &input_fn.attrs;
157
158 let expanded = quote! {
159 #(#attrs)*
160 #vis #sig {
161 let _step_name = #fn_name_str;
162 tracing::debug!(step = _step_name, "step executing");
163 let _result: Result<_, _> = async { #body }.await;
164 match &_result {
165 Ok(_) => tracing::debug!(step = _step_name, "step completed"),
166 Err(e) => tracing::warn!(step = _step_name, error = %e, "step failed"),
167 }
168 _result
169 }
170 };
171
172 TokenStream::from(expanded)
173}