Skip to main content

conductor_macros/
lib.rs

1// Copyright {{.Year}} Conductor OSS
2// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
3
4use darling::{ast::NestedMeta, FromMeta};
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{format_ident, quote};
8use syn::{parse_macro_input, FnArg, ItemFn, Pat, ReturnType};
9
10/// Configuration options for the `#[worker]` attribute macro
11#[derive(Debug, FromMeta, Default)]
12struct WorkerArgs {
13    /// Task definition name (defaults to function name)
14    #[darling(default)]
15    name: Option<String>,
16
17    /// Poll interval in milliseconds (default: 100)
18    #[darling(default)]
19    poll_interval: Option<u64>,
20
21    /// Maximum concurrent task executions (default: 1)
22    #[darling(default)]
23    thread_count: Option<usize>,
24
25    /// Task routing domain (optional)
26    #[darling(default)]
27    domain: Option<String>,
28
29    /// Worker identity (optional, defaults to hostname-pid)
30    #[darling(default)]
31    identity: Option<String>,
32}
33
34/// Marks an async function as a Conductor worker task.
35///
36/// This macro transforms a regular async function into a Conductor worker that can be
37/// registered with a `TaskHandler`. The function's parameters are automatically extracted
38/// from the task's input data.
39///
40/// # Attributes
41///
42/// - `name` - Task definition name (defaults to function name)
43/// - `poll_interval` - Poll interval in milliseconds (default: 100)
44/// - `thread_count` - Maximum concurrent executions (default: 1)
45/// - `domain` - Task routing domain (optional)
46/// - `identity` - Worker identity (optional)
47///
48/// # Function Signatures
49///
50/// ## Simple parameters extracted from task input
51/// ```rust,ignore
52/// #[worker(name = "greet")]
53/// async fn greet(name: String) -> String {
54///     format!("Hello, {}!", name)
55/// }
56/// ```
57///
58/// ## With Task parameter for full access
59/// ```rust,ignore
60/// #[worker(name = "process")]
61/// async fn process(task: Task) -> WorkerOutput {
62///     let data = task.get_input_string("data").unwrap();
63///     WorkerOutput::completed_with_result(data)
64/// }
65/// ```
66///
67/// ## With TaskContext for metadata access
68/// ```rust,ignore
69/// #[worker(name = "long_running")]
70/// async fn long_running(ctx: TaskContext, batch_size: i32) -> WorkerOutput {
71///     let offset = ctx.poll_count() * batch_size;
72///     // Process batch at offset...
73///     
74///     if ctx.is_first_poll() {
75///         println!("Starting task {}", ctx.task_id());
76///     }
77///     
78///     WorkerOutput::in_progress(30)
79/// }
80/// ```
81///
82/// ## Combined Task and TaskContext
83/// ```rust,ignore
84/// #[worker(name = "advanced")]
85/// async fn advanced(task: Task, ctx: TaskContext) -> WorkerOutput {
86///     // Full task access and convenient context methods
87///     let input: String = task.get_input("data").unwrap_or_default();
88///     println!("Processing task {} (poll #{})", ctx.task_id(), ctx.poll_count());
89///     WorkerOutput::completed_with_result(input)
90/// }
91/// ```
92///
93/// # Return Types
94///
95/// - `String` or any serializable type - Wrapped in `WorkerOutput::completed_with_result()`
96/// - `WorkerOutput` - Used directly  
97/// - `Result<T, E>` - Success wrapped in completed, error converted to failed
98///
99/// # Generated Code
100///
101/// The macro generates a function `{fn_name}_worker()` that returns an `FnWorker`:
102///
103/// ```rust,ignore
104/// #[worker(name = "greet", thread_count = 5)]
105/// async fn greet(name: String) -> String {
106///     format!("Hello, {}!", name)
107/// }
108///
109/// // Generates:
110/// fn greet_worker() -> FnWorker {
111///     FnWorker::new("greet", |task| async move {
112///         let name: String = task.get_input("name").unwrap_or_default();
113///         let result = format!("Hello, {}!", name);
114///         Ok(WorkerOutput::completed_with_result(result))
115///     })
116///     .with_thread_count(5)
117/// }
118/// ```
119///
120/// # Usage
121///
122/// ```rust,ignore
123/// use conductor_macros::worker;
124/// use conductor::{TaskHandler, Configuration};
125///
126/// #[worker(name = "process_order", thread_count = 10, domain = "orders")]
127/// async fn process_order(order_id: String, amount: f64) -> serde_json::Value {
128///     serde_json::json!({
129///         "order_id": order_id,
130///         "amount": amount,
131///         "status": "processed"
132///     })
133/// }
134///
135/// #[tokio::main]
136/// async fn main() {
137///     let config = Configuration::default();
138///     let mut handler = TaskHandler::new(config).unwrap();
139///     
140///     // Use the generated worker function
141///     handler.add_worker(process_order_worker());
142///     
143///     handler.start().await.unwrap();
144/// }
145/// ```
146#[proc_macro_attribute]
147pub fn worker(args: TokenStream, input: TokenStream) -> TokenStream {
148    let attr_args = match NestedMeta::parse_meta_list(args.into()) {
149        Ok(v) => v,
150        Err(e) => {
151            return TokenStream::from(darling::Error::from(e).write_errors());
152        }
153    };
154
155    let args = match WorkerArgs::from_list(&attr_args) {
156        Ok(v) => v,
157        Err(e) => {
158            return TokenStream::from(e.write_errors());
159        }
160    };
161
162    let input_fn = parse_macro_input!(input as ItemFn);
163
164    match generate_worker(args, input_fn) {
165        Ok(tokens) => tokens.into(),
166        Err(e) => e.to_compile_error().into(),
167    }
168}
169
170fn generate_worker(args: WorkerArgs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
171    let fn_name = &input_fn.sig.ident;
172    let fn_vis = &input_fn.vis;
173    let fn_block = &input_fn.block;
174    let fn_inputs = &input_fn.sig.inputs;
175    let fn_output = &input_fn.sig.output;
176
177    // Ensure function is async
178    if input_fn.sig.asyncness.is_none() {
179        return Err(syn::Error::new_spanned(
180            &input_fn.sig,
181            "worker function must be async",
182        ));
183    }
184
185    // Task name defaults to function name
186    let task_name = args.name.unwrap_or_else(|| fn_name.to_string());
187
188    // Worker function name (appends _worker)
189    let worker_fn_name = format_ident!("{}_worker", fn_name);
190
191    // Configuration values
192    let poll_interval = args.poll_interval.unwrap_or(100);
193    let thread_count = args.thread_count.unwrap_or(1);
194
195    // Analyze parameters
196    let mut param_extractions = Vec::new();
197    let mut fn_args = Vec::new();
198    let mut has_task_param = false;
199    let mut has_context_param = false;
200
201    for arg in fn_inputs {
202        match arg {
203            FnArg::Receiver(_) => {
204                return Err(syn::Error::new_spanned(
205                    arg,
206                    "worker functions cannot have self parameter",
207                ));
208            }
209            FnArg::Typed(pat_type) => {
210                let name = match &*pat_type.pat {
211                    Pat::Ident(ident) => ident.ident.clone(),
212                    _ => {
213                        return Err(syn::Error::new_spanned(
214                            &pat_type.pat,
215                            "expected simple identifier pattern",
216                        ));
217                    }
218                };
219
220                let ty = &pat_type.ty;
221                let ty_str = quote!(#ty).to_string().replace(' ', "");
222
223                // Check if this is the TaskContext parameter
224                if ty_str.contains("TaskContext") {
225                    has_context_param = true;
226                    fn_args.push(quote! { __ctx });
227                // Check if this is the Task parameter
228                } else if ty_str.contains("Task") {
229                    has_task_param = true;
230                    fn_args.push(quote! { task.clone() });
231                } else {
232                    // Regular parameter - extract from task input
233                    let name_str = name.to_string();
234                    param_extractions.push(quote! {
235                        let #name: #ty = task.get_input(#name_str).unwrap_or_default();
236                    });
237                    fn_args.push(quote! { #name });
238                }
239            }
240        }
241    }
242
243    // Generate TaskContext extraction if needed
244    let context_extraction = if has_context_param {
245        quote! {
246            let __ctx = task.context();
247        }
248    } else {
249        quote! {}
250    };
251
252    // Generate return handling based on return type
253    let return_handling = match fn_output {
254        ReturnType::Default => {
255            quote! {
256                Ok(::conductor::worker::WorkerOutput::completed_with_result(()))
257            }
258        }
259        ReturnType::Type(_, ty) => {
260            let ty_str = quote!(#ty).to_string().replace(' ', "");
261
262            if ty_str.contains("WorkerOutput") {
263                quote! { Ok(result) }
264            } else if ty_str.starts_with("Result<") || ty_str.contains("::Result<") {
265                quote! {
266                    match result {
267                        Ok(value) => Ok(::conductor::worker::WorkerOutput::completed_with_result(value)),
268                        Err(e) => Ok(::conductor::worker::WorkerOutput::failed(format!("{}", e))),
269                    }
270                }
271            } else {
272                quote! {
273                    Ok(::conductor::worker::WorkerOutput::completed_with_result(result))
274                }
275            }
276        }
277    };
278
279    // Generate domain configuration
280    let domain_config = if let Some(domain) = &args.domain {
281        quote! { .with_domain(#domain) }
282    } else {
283        quote! {}
284    };
285
286    // Generate identity configuration
287    let identity_config = if let Some(identity) = &args.identity {
288        quote! { .with_identity(#identity) }
289    } else {
290        quote! {}
291    };
292
293    // Build the async block body
294    let async_body = if has_task_param || has_context_param {
295        quote! {
296            #context_extraction
297            #(#param_extractions)*
298            let result = (|#fn_inputs| async move #fn_block)(#(#fn_args),*).await;
299            #return_handling
300        }
301    } else if !fn_args.is_empty() {
302        quote! {
303            #(#param_extractions)*
304            let result = (|#fn_inputs| async move #fn_block)(#(#fn_args),*).await;
305            #return_handling
306        }
307    } else {
308        quote! {
309            #(#param_extractions)*
310            let result = (|| async move #fn_block)().await;
311            #return_handling
312        }
313    };
314
315    // Generate the output
316    let output = quote! {
317        /// Creates a worker for the `#task_name` task.
318        ///
319        /// Generated by the `#[worker]` macro from the `#fn_name` function.
320        #fn_vis fn #worker_fn_name() -> ::conductor::worker::FnWorker {
321            ::conductor::worker::FnWorker::new(#task_name, |task: ::conductor::models::Task| async move {
322                #async_body
323            })
324            .with_poll_interval_millis(#poll_interval)
325            .with_thread_count(#thread_count)
326            #domain_config
327            #identity_config
328        }
329    };
330
331    Ok(output)
332}
333
334/// Alias for `#[worker]` - marks a function as a Conductor worker task.
335///
336/// This is provided for familiarity with Python SDK's `@worker_task` decorator.
337#[proc_macro_attribute]
338pub fn worker_task(args: TokenStream, input: TokenStream) -> TokenStream {
339    worker(args, input)
340}