conductor_macros/lib.rs
1// Copyright {{.Year}} Conductor OSS
2// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.
3
4use darling::{ast::NestedMeta, FromMeta};
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::{format_ident, quote};
8use syn::{parse_macro_input, FnArg, ItemFn, Pat, ReturnType};
9
10/// Configuration options for the `#[worker]` attribute macro
11#[derive(Debug, FromMeta, Default)]
12struct WorkerArgs {
13 /// Task definition name (defaults to function name)
14 #[darling(default)]
15 name: Option<String>,
16
17 /// Poll interval in milliseconds (default: 100)
18 #[darling(default)]
19 poll_interval: Option<u64>,
20
21 /// Maximum concurrent task executions (default: 1)
22 #[darling(default)]
23 thread_count: Option<usize>,
24
25 /// Task routing domain (optional)
26 #[darling(default)]
27 domain: Option<String>,
28
29 /// Worker identity (optional, defaults to hostname-pid)
30 #[darling(default)]
31 identity: Option<String>,
32}
33
34/// Marks an async function as a Conductor worker task.
35///
36/// This macro transforms a regular async function into a Conductor worker that can be
37/// registered with a `TaskHandler`. The function's parameters are automatically extracted
38/// from the task's input data.
39///
40/// # Attributes
41///
42/// - `name` - Task definition name (defaults to function name)
43/// - `poll_interval` - Poll interval in milliseconds (default: 100)
44/// - `thread_count` - Maximum concurrent executions (default: 1)
45/// - `domain` - Task routing domain (optional)
46/// - `identity` - Worker identity (optional)
47///
48/// # Function Signatures
49///
50/// ## Simple parameters extracted from task input
51/// ```rust,ignore
52/// #[worker(name = "greet")]
53/// async fn greet(name: String) -> String {
54/// format!("Hello, {}!", name)
55/// }
56/// ```
57///
58/// ## With Task parameter for full access
59/// ```rust,ignore
60/// #[worker(name = "process")]
61/// async fn process(task: Task) -> WorkerOutput {
62/// let data = task.get_input_string("data").unwrap();
63/// WorkerOutput::completed_with_result(data)
64/// }
65/// ```
66///
67/// ## With TaskContext for metadata access
68/// ```rust,ignore
69/// #[worker(name = "long_running")]
70/// async fn long_running(ctx: TaskContext, batch_size: i32) -> WorkerOutput {
71/// let offset = ctx.poll_count() * batch_size;
72/// // Process batch at offset...
73///
74/// if ctx.is_first_poll() {
75/// println!("Starting task {}", ctx.task_id());
76/// }
77///
78/// WorkerOutput::in_progress(30)
79/// }
80/// ```
81///
82/// ## Combined Task and TaskContext
83/// ```rust,ignore
84/// #[worker(name = "advanced")]
85/// async fn advanced(task: Task, ctx: TaskContext) -> WorkerOutput {
86/// // Full task access and convenient context methods
87/// let input: String = task.get_input("data").unwrap_or_default();
88/// println!("Processing task {} (poll #{})", ctx.task_id(), ctx.poll_count());
89/// WorkerOutput::completed_with_result(input)
90/// }
91/// ```
92///
93/// # Return Types
94///
95/// - `String` or any serializable type - Wrapped in `WorkerOutput::completed_with_result()`
96/// - `WorkerOutput` - Used directly
97/// - `Result<T, E>` - Success wrapped in completed, error converted to failed
98///
99/// # Generated Code
100///
101/// The macro generates a function `{fn_name}_worker()` that returns an `FnWorker`:
102///
103/// ```rust,ignore
104/// #[worker(name = "greet", thread_count = 5)]
105/// async fn greet(name: String) -> String {
106/// format!("Hello, {}!", name)
107/// }
108///
109/// // Generates:
110/// fn greet_worker() -> FnWorker {
111/// FnWorker::new("greet", |task| async move {
112/// let name: String = task.get_input("name").unwrap_or_default();
113/// let result = format!("Hello, {}!", name);
114/// Ok(WorkerOutput::completed_with_result(result))
115/// })
116/// .with_thread_count(5)
117/// }
118/// ```
119///
120/// # Usage
121///
122/// ```rust,ignore
123/// use conductor_macros::worker;
124/// use conductor::{TaskHandler, Configuration};
125///
126/// #[worker(name = "process_order", thread_count = 10, domain = "orders")]
127/// async fn process_order(order_id: String, amount: f64) -> serde_json::Value {
128/// serde_json::json!({
129/// "order_id": order_id,
130/// "amount": amount,
131/// "status": "processed"
132/// })
133/// }
134///
135/// #[tokio::main]
136/// async fn main() {
137/// let config = Configuration::default();
138/// let mut handler = TaskHandler::new(config).unwrap();
139///
140/// // Use the generated worker function
141/// handler.add_worker(process_order_worker());
142///
143/// handler.start().await.unwrap();
144/// }
145/// ```
146#[proc_macro_attribute]
147pub fn worker(args: TokenStream, input: TokenStream) -> TokenStream {
148 let attr_args = match NestedMeta::parse_meta_list(args.into()) {
149 Ok(v) => v,
150 Err(e) => {
151 return TokenStream::from(darling::Error::from(e).write_errors());
152 }
153 };
154
155 let args = match WorkerArgs::from_list(&attr_args) {
156 Ok(v) => v,
157 Err(e) => {
158 return TokenStream::from(e.write_errors());
159 }
160 };
161
162 let input_fn = parse_macro_input!(input as ItemFn);
163
164 match generate_worker(args, input_fn) {
165 Ok(tokens) => tokens.into(),
166 Err(e) => e.to_compile_error().into(),
167 }
168}
169
170fn generate_worker(args: WorkerArgs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
171 let fn_name = &input_fn.sig.ident;
172 let fn_vis = &input_fn.vis;
173 let fn_block = &input_fn.block;
174 let fn_inputs = &input_fn.sig.inputs;
175 let fn_output = &input_fn.sig.output;
176
177 // Ensure function is async
178 if input_fn.sig.asyncness.is_none() {
179 return Err(syn::Error::new_spanned(
180 &input_fn.sig,
181 "worker function must be async",
182 ));
183 }
184
185 // Task name defaults to function name
186 let task_name = args.name.unwrap_or_else(|| fn_name.to_string());
187
188 // Worker function name (appends _worker)
189 let worker_fn_name = format_ident!("{}_worker", fn_name);
190
191 // Configuration values
192 let poll_interval = args.poll_interval.unwrap_or(100);
193 let thread_count = args.thread_count.unwrap_or(1);
194
195 // Analyze parameters
196 let mut param_extractions = Vec::new();
197 let mut fn_args = Vec::new();
198 let mut has_task_param = false;
199 let mut has_context_param = false;
200
201 for arg in fn_inputs {
202 match arg {
203 FnArg::Receiver(_) => {
204 return Err(syn::Error::new_spanned(
205 arg,
206 "worker functions cannot have self parameter",
207 ));
208 }
209 FnArg::Typed(pat_type) => {
210 let name = match &*pat_type.pat {
211 Pat::Ident(ident) => ident.ident.clone(),
212 _ => {
213 return Err(syn::Error::new_spanned(
214 &pat_type.pat,
215 "expected simple identifier pattern",
216 ));
217 }
218 };
219
220 let ty = &pat_type.ty;
221 let ty_str = quote!(#ty).to_string().replace(' ', "");
222
223 // Check if this is the TaskContext parameter
224 if ty_str.contains("TaskContext") {
225 has_context_param = true;
226 fn_args.push(quote! { __ctx });
227 // Check if this is the Task parameter
228 } else if ty_str.contains("Task") {
229 has_task_param = true;
230 fn_args.push(quote! { task.clone() });
231 } else {
232 // Regular parameter - extract from task input
233 let name_str = name.to_string();
234 param_extractions.push(quote! {
235 let #name: #ty = task.get_input(#name_str).unwrap_or_default();
236 });
237 fn_args.push(quote! { #name });
238 }
239 }
240 }
241 }
242
243 // Generate TaskContext extraction if needed
244 let context_extraction = if has_context_param {
245 quote! {
246 let __ctx = task.context();
247 }
248 } else {
249 quote! {}
250 };
251
252 // Generate return handling based on return type
253 let return_handling = match fn_output {
254 ReturnType::Default => {
255 quote! {
256 Ok(::conductor::worker::WorkerOutput::completed_with_result(()))
257 }
258 }
259 ReturnType::Type(_, ty) => {
260 let ty_str = quote!(#ty).to_string().replace(' ', "");
261
262 if ty_str.contains("WorkerOutput") {
263 quote! { Ok(result) }
264 } else if ty_str.starts_with("Result<") || ty_str.contains("::Result<") {
265 quote! {
266 match result {
267 Ok(value) => Ok(::conductor::worker::WorkerOutput::completed_with_result(value)),
268 Err(e) => Ok(::conductor::worker::WorkerOutput::failed(format!("{}", e))),
269 }
270 }
271 } else {
272 quote! {
273 Ok(::conductor::worker::WorkerOutput::completed_with_result(result))
274 }
275 }
276 }
277 };
278
279 // Generate domain configuration
280 let domain_config = if let Some(domain) = &args.domain {
281 quote! { .with_domain(#domain) }
282 } else {
283 quote! {}
284 };
285
286 // Generate identity configuration
287 let identity_config = if let Some(identity) = &args.identity {
288 quote! { .with_identity(#identity) }
289 } else {
290 quote! {}
291 };
292
293 // Build the async block body
294 let async_body = if has_task_param || has_context_param {
295 quote! {
296 #context_extraction
297 #(#param_extractions)*
298 let result = (|#fn_inputs| async move #fn_block)(#(#fn_args),*).await;
299 #return_handling
300 }
301 } else if !fn_args.is_empty() {
302 quote! {
303 #(#param_extractions)*
304 let result = (|#fn_inputs| async move #fn_block)(#(#fn_args),*).await;
305 #return_handling
306 }
307 } else {
308 quote! {
309 #(#param_extractions)*
310 let result = (|| async move #fn_block)().await;
311 #return_handling
312 }
313 };
314
315 // Generate the output
316 let output = quote! {
317 /// Creates a worker for the `#task_name` task.
318 ///
319 /// Generated by the `#[worker]` macro from the `#fn_name` function.
320 #fn_vis fn #worker_fn_name() -> ::conductor::worker::FnWorker {
321 ::conductor::worker::FnWorker::new(#task_name, |task: ::conductor::models::Task| async move {
322 #async_body
323 })
324 .with_poll_interval_millis(#poll_interval)
325 .with_thread_count(#thread_count)
326 #domain_config
327 #identity_config
328 }
329 };
330
331 Ok(output)
332}
333
334/// Alias for `#[worker]` - marks a function as a Conductor worker task.
335///
336/// This is provided for familiarity with Python SDK's `@worker_task` decorator.
337#[proc_macro_attribute]
338pub fn worker_task(args: TokenStream, input: TokenStream) -> TokenStream {
339 worker(args, input)
340}