Skip to main content

durable_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{ItemFn, parse_macro_input};
4
5/// Marks a function as a durable workflow with automatic crash recovery.
6///
7/// Supports two signatures:
8///
9/// **No input** — `fn(Ctx)`:
10/// ```ignore
11/// #[durable::workflow]
12/// async fn daily_etl(ctx: Ctx) -> Result<EtlReport, DurableError> {
13///     let data = ctx.step("extract", || async { fetch().await }).await?;
14///     ctx.complete(&data).await?;
15///     Ok(data)
16/// }
17///
18/// // Generated: durable_daily_etl(&db, "my-etl").await?
19/// ```
20///
21/// **Typed input** — `fn(Ctx, T)` where `T: Serialize + DeserializeOwned`:
22/// ```ignore
23/// #[derive(Serialize, Deserialize)]
24/// struct IngestInput { crawl_id: String, shard_count: u32 }
25///
26/// #[durable::workflow]
27/// async fn ingest(ctx: Ctx, input: IngestInput) -> Result<(), DurableError> {
28///     for shard in 0..input.shard_count {
29///         ctx.step(&format!("shard-{shard}"), || async {
30///             process_shard(&input.crawl_id, shard).await
31///         }).await?;
32///     }
33///     ctx.complete(&"done").await?;
34///     Ok(())
35/// }
36///
37/// // Generated: durable_ingest(&db, "my-ingest", IngestInput { crawl_id: "CC-2026".into(), shard_count: 72 }).await?
38/// // On crash recovery: input is automatically deserialized and passed to ingest()
39/// ```
40#[proc_macro_attribute]
41pub fn workflow(_attr: TokenStream, item: TokenStream) -> TokenStream {
42    let input_fn = parse_macro_input!(item as ItemFn);
43    let fn_name = &input_fn.sig.ident;
44    let fn_name_str = fn_name.to_string();
45    let vis = &input_fn.vis;
46    let sig = &input_fn.sig;
47    let body = &input_fn.block;
48    let attrs = &input_fn.attrs;
49
50    let is_ctx_type = |ty: &syn::Type| -> bool {
51        matches!(ty, syn::Type::Path(tp) if tp.path.segments.last().is_some_and(|s| s.ident == "Ctx"))
52    };
53
54    // Detect signature shape
55    let first_is_ctx = input_fn.sig.inputs.len() >= 1
56        && matches!(&input_fn.sig.inputs[0], syn::FnArg::Typed(pat) if is_ctx_type(&pat.ty));
57
58    let registration = if first_is_ctx && input_fn.sig.inputs.len() == 1 {
59        // fn(Ctx) — no input, simple case
60        let start_fn_name = syn::Ident::new(
61            &format!("durable_{}", fn_name),
62            fn_name.span(),
63        );
64        quote! {
65            ::durable::inventory::submit! {
66                ::durable::WorkflowRegistration {
67                    name: #fn_name_str,
68                    resume_fn: |ctx| ::std::boxed::Box::pin(async move {
69                        let _ = #fn_name(ctx).await?;
70                        Ok(())
71                    }),
72                }
73            }
74
75            /// Auto-generated by `#[durable::workflow]`. Starts (or idempotently
76            /// attaches to) a root workflow, recording the handler name for crash
77            /// recovery.
78            #vis async fn #start_fn_name(
79                db: &::sea_orm::DatabaseConnection,
80                name: &str,
81            ) -> ::std::result::Result<::durable::Ctx, ::durable::DurableError> {
82                ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::None, ::std::option::Option::Some(#fn_name_str)).await
83            }
84        }
85    } else if first_is_ctx && input_fn.sig.inputs.len() == 2 {
86        // fn(Ctx, T) — typed input
87        let input_type = match &input_fn.sig.inputs[1] {
88            syn::FnArg::Typed(pat) => &pat.ty,
89            _ => panic!("#[durable::workflow] second parameter must be a typed argument"),
90        };
91        let start_fn_name = syn::Ident::new(
92            &format!("durable_{}", fn_name),
93            fn_name.span(),
94        );
95        quote! {
96            ::durable::inventory::submit! {
97                ::durable::WorkflowRegistration {
98                    name: #fn_name_str,
99                    resume_fn: |ctx| ::std::boxed::Box::pin(async move {
100                        let input: #input_type = ctx.input().await?;
101                        let _ = #fn_name(ctx, input).await?;
102                        Ok(())
103                    }),
104                }
105            }
106
107            /// Auto-generated by `#[durable::workflow]`. Starts (or idempotently
108            /// attaches to) a root workflow with typed input, recording the handler
109            /// name for crash recovery.
110            #vis async fn #start_fn_name(
111                db: &::sea_orm::DatabaseConnection,
112                name: &str,
113                input: #input_type,
114            ) -> ::std::result::Result<::durable::Ctx, ::durable::DurableError> {
115                let input_json = ::serde_json::to_value(&input)
116                    .map_err(|e| ::durable::DurableError::custom(format!("failed to serialize workflow input: {e}")))?;
117                ::durable::Ctx::start_with_handler(db, name, ::std::option::Option::Some(input_json), ::std::option::Option::Some(#fn_name_str)).await
118            }
119        }
120    } else {
121        quote! {}
122    };
123
124    let expanded = quote! {
125        #(#attrs)*
126        #vis #sig {
127            let _workflow_name = #fn_name_str;
128            tracing::info!(workflow = _workflow_name, "workflow started");
129            let _result: Result<_, _> = async { #body }.await;
130            match &_result {
131                Ok(_) => tracing::info!(workflow = _workflow_name, "workflow completed"),
132                Err(e) => tracing::error!(workflow = _workflow_name, error = %e, "workflow failed"),
133            }
134            _result
135        }
136
137        #registration
138    };
139
140    TokenStream::from(expanded)
141}
142
143/// Marks a function as a durable step.
144///
145/// The first parameter must be `ctx: &Ctx`. The macro wraps the function
146/// to check for saved output, execute if needed, and save the result.
147///
148/// ```ignore
149/// #[durable::step]
150/// async fn fetch_data(ctx: &Ctx, url: &str) -> Result<Data> {
151///     // ...
152/// }
153/// ```
154#[proc_macro_attribute]
155pub fn step(_attr: TokenStream, item: TokenStream) -> TokenStream {
156    let input_fn = parse_macro_input!(item as ItemFn);
157    let fn_name = &input_fn.sig.ident;
158    let fn_name_str = fn_name.to_string();
159    let vis = &input_fn.vis;
160    let sig = &input_fn.sig;
161    let body = &input_fn.block;
162    let attrs = &input_fn.attrs;
163
164    let expanded = quote! {
165        #(#attrs)*
166        #vis #sig {
167            let _step_name = #fn_name_str;
168            tracing::debug!(step = _step_name, "step executing");
169            let _result: Result<_, _> = async { #body }.await;
170            match &_result {
171                Ok(_) => tracing::debug!(step = _step_name, "step completed"),
172                Err(e) => tracing::warn!(step = _step_name, error = %e, "step failed"),
173            }
174            _result
175        }
176    };
177
178    TokenStream::from(expanded)
179}