Skip to main content

durable_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, parse_macro_input};
4
5/// Marks a function as a durable workflow with automatic crash recovery.
6///
7/// Supports two signatures:
8///
9/// **No input** — `fn(Ctx)`:
10/// ```ignore
11/// #[durable::workflow]
12/// async fn daily_etl(ctx: Ctx) -> Result<EtlReport, DurableError> {
13///     let data = ctx.step("extract", || async { fetch().await }).await?;
14///     ctx.complete(&data).await?;
15///     Ok(data)
16/// }
17///
18/// // Generated: durable_daily_etl(&db, "my-etl").await?
19/// ```
20///
21/// **Typed input** — `fn(Ctx, T)` where `T: Serialize + DeserializeOwned`:
22/// ```ignore
23/// #[derive(Serialize, Deserialize)]
24/// struct IngestInput { crawl_id: String, shard_count: u32 }
25///
26/// #[durable::workflow]
27/// async fn ingest(ctx: Ctx, input: IngestInput) -> Result<(), DurableError> {
28///     for shard in 0..input.shard_count {
29///         ctx.step(&format!("shard-{shard}"), || async {
30///             process_shard(&input.crawl_id, shard).await
31///         }).await?;
32///     }
33///     ctx.complete(&"done").await?;
34///     Ok(())
35/// }
36///
37/// // Generated: durable_ingest(&db, "my-ingest", IngestInput { crawl_id: "CC-2026".into(), shard_count: 72 }).await?
38/// // On crash recovery: input is automatically deserialized and passed to ingest()
39/// ```
40#[proc_macro_attribute]
41pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream {
42    let input_fn = parse_macro_input!(item as ItemFn);
43    let fn_name = &input_fn.sig.ident;
44    let fn_name_str = fn_name.to_string();
45    let vis = &input_fn.vis;
46    let sig = &input_fn.sig;
47    let body = &input_fn.block;
48    let attrs = &input_fn.attrs;
49
50    let is_ctx_type = |ty: &syn::Type| -> bool {
51        matches!(ty, syn::Type::Path(tp) if tp.path.segments.last().is_some_and(|s| s.ident == "Ctx"))
52    };
53
54    // Detect signature shape
55    let first_is_ctx = !input_fn.sig.inputs.is_empty()
56        && matches!(&input_fn.sig.inputs[0], syn::FnArg::Typed(pat) if is_ctx_type(&pat.ty));
57
58    let registration = if first_is_ctx && input_fn.sig.inputs.len() == 1 {
59        // fn(Ctx) — no input, simple case
60        let start_fn_name = syn::Ident::new(&format!("durable_{}", fn_name), fn_name.span());
61        quote! {
62            ::durable::inventory::submit! {
63                ::durable::WorkflowRegistration {
64                    name: #fn_name_str,
65                    resume_fn: |ctx| ::std::boxed::Box::pin(async move {
66                        let _ = #fn_name(ctx).await?;
67                        Ok(())
68                    }),
69                }
70            }
71
72            /// Auto-generated by `#[durable::workflow]`. Starts (or idempotently
73            /// attaches to) a root workflow, recording the handler name for crash
74            /// recovery.
75            #vis async fn #start_fn_name(
76                db: &::sea_orm::DatabaseConnection,
77                name: &str,
78            ) -> ::std::result::Result<::durable::StartResult, ::durable::DurableError> {
79                ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::None, ::std::option::Option::Some(#fn_name_str)).await
80            }
81        }
82    } else if first_is_ctx && input_fn.sig.inputs.len() == 2 {
83        // fn(Ctx, T) — typed input
84        let input_type = match &input_fn.sig.inputs[1] {
85            syn::FnArg::Typed(pat) => &pat.ty,
86            _ => panic!("#[durable::workflow] second parameter must be a typed argument"),
87        };
88        let start_fn_name = syn::Ident::new(&format!("durable_{}", fn_name), fn_name.span());
89        quote! {
90            ::durable::inventory::submit! {
91                ::durable::WorkflowRegistration {
92                    name: #fn_name_str,
93                    resume_fn: |ctx| ::std::boxed::Box::pin(async move {
94                        let input: #input_type = ctx.input().await?;
95                        let _ = #fn_name(ctx, input).await?;
96                        Ok(())
97                    }),
98                }
99            }
100
101            /// Auto-generated by `#[durable::workflow]`. Starts (or idempotently
102            /// attaches to) a root workflow with typed input, recording the handler
103            /// name for crash recovery.
104            #vis async fn #start_fn_name(
105                db: &::sea_orm::DatabaseConnection,
106                name: &str,
107                input: #input_type,
108            ) -> ::std::result::Result<::durable::StartResult, ::durable::DurableError> {
109                let input_json = ::serde_json::to_value(&input)
110                    .map_err(|e| ::durable::DurableError::custom(format!("failed to serialize workflow input: {e}")))?;
111                ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::Some(input_json), ::std::option::Option::Some(#fn_name_str)).await
112            }
113        }
114    } else {
115        quote! {}
116    };
117
118    let expanded = quote! {
119        #(#attrs)*
120        #vis #sig {
121            let _workflow_name = #fn_name_str;
122            tracing::info!(workflow = _workflow_name, "workflow started");
123            let _result: Result<_, _> = async { #body }.await;
124            match &_result {
125                Ok(_) => tracing::info!(workflow = _workflow_name, "workflow completed"),
126                Err(e) => tracing::error!(workflow = _workflow_name, error = %e, "workflow failed"),
127            }
128            _result
129        }
130
131        #registration
132    };
133
134    TokenStream::from(expanded)
135}
136
137/// Marks a function as a durable step.
138///
139/// The first parameter must be `ctx: &Ctx`. The macro wraps the function
140/// to check for saved output, execute if needed, and save the result.
141///
142/// ```ignore
143/// #[durable::step]
144/// async fn fetch_data(ctx: &Ctx, url: &str) -> Result<Data> {
145///     // ...
146/// }
147/// ```
148#[proc_macro_attribute]
149pub fn step(_attr: TokenStream, item: TokenStream) -> TokenStream {
150    let input_fn = parse_macro_input!(item as ItemFn);
151    let fn_name = &input_fn.sig.ident;
152    let fn_name_str = fn_name.to_string();
153    let vis = &input_fn.vis;
154    let sig = &input_fn.sig;
155    let body = &input_fn.block;
156    let attrs = &input_fn.attrs;
157
158    let expanded = quote! {
159        #(#attrs)*
160        #vis #sig {
161            let _step_name = #fn_name_str;
162            tracing::debug!(step = _step_name, "step executing");
163            let _result: Result<_, _> = async { #body }.await;
164            match &_result {
165                Ok(_) => tracing::debug!(step = _step_name, "step completed"),
166                Err(e) => tracing::warn!(step = _step_name, error = %e, "step failed"),
167            }
168            _result
169        }
170    };
171
172    TokenStream::from(expanded)
173}