// dagx_macros/lib.rs
//! Procedural macros for dagx
//!
//! This crate provides the `#[task]` attribute macro that automatically implements
//! the `Task` trait by deriving `Input` and `Output` types from the `run()` method signature.

use proc_macro::TokenStream;
use quote::quote;
use syn::{
    parse_macro_input, FnArg, ImplItem, ImplItemFn, ItemImpl, Pat, PatType, ReturnType, Type,
};

10/// Attribute macro to automatically implement the `Task` trait.
11///
12/// Apply this to an `impl` block containing a `run()` method (sync or async). The macro:
13/// - Derives `Input` and `Output` types from the `run()` signature
14/// - Automatically implements the `Task` trait
15/// - **Generates type-specific extraction logic** - works with ANY type (Clone + Send + Sync)!
16/// - Supports both sync and async run methods
17/// - Supports stateless (no self) and stateful (&self, &mut self) tasks
18/// - Handles various input patterns (no inputs, single input, multiple inputs)
19///
20/// **Key Feature**: Custom types work automatically without implementing any traits!
21/// The macro generates inline extraction logic in `extract_and_run()` specific to your
22/// task's parameter types. Just derive `Clone` on your types and they'll work seamlessly.
23///
24/// # Task Patterns
25///
26/// The `#[task]` macro supports three patterns based on state requirements:
27///
28/// ## 1. Stateless Tasks (No State)
29///
30/// Unit structs for pure computations. Use **no `self` parameter**:
31///
32/// ```ignore
33/// use dagx::{task, Task};
34///
35/// struct Add;
36///
37/// #[task]
38/// impl Add {
39/// async fn run(a: &i32, b: &i32) -> i32 {
40/// a + b // Pure function, no state
41/// }
42/// }
43/// ```
44///
45/// ## 2. Read-Only State Tasks
46///
47/// Tasks that read configuration or constant data. Use **`&self`**:
48///
49/// ```ignore
50/// use dagx::{task, Task};
51///
52/// struct Multiplier {
53/// factor: i32,
54/// }
55///
56/// #[task]
57/// impl Multiplier {
58/// async fn run(&self, input: &i32) -> i32 {
59/// input * self.factor // Read-only access
60/// }
61/// }
62/// ```
63///
64/// ## 3. Mutable State Tasks
65///
66/// Tasks that accumulate or modify state. Use **`&mut self`**:
67///
68/// ```ignore
69/// use dagx::{task, Task};
70///
71/// struct Counter {
72/// count: i32,
73/// }
74///
75/// #[task]
76/// impl Counter {
77/// async fn run(&mut self, increment: &i32) -> i32 {
78/// self.count += increment; // Modifies state
79/// self.count
80/// }
81/// }
82/// ```
83///
84/// # Input Patterns
85///
86/// ## No Inputs (Source Tasks)
87///
88/// ```ignore
89/// use dagx::{task, Task};
90///
91/// struct LoadData {
92/// value: i32,
93/// }
94///
95/// #[task]
96/// impl LoadData {
97/// async fn run(&mut self) -> i32 {
98/// self.value
99/// }
100/// }
101/// ```ignore
102///
103/// ## Single Input
104///
105/// ```ignore
106/// use dagx::{task, Task};
107///
108/// struct Double;
109///
110/// #[task]
111/// impl Double {
112/// async fn run(&mut self, input: &i32) -> i32 {
113/// input * 2
114/// }
115/// }
116/// ```ignore
117///
118/// ## Multiple Inputs (up to 8)
119///
120/// ```ignore
121/// use dagx::{task, Task};
122///
123/// struct Combine;
124///
125/// #[task]
126/// impl Combine {
127/// async fn run(&mut self, a: &i32, b: &String, c: &bool) -> String {
128/// format!("{}: {} ({})", b, a, c)
129/// }
130/// }
131/// ```ignore
132///
133/// # Requirements
134///
135/// - The impl block must contain exactly one `async fn run()` method
136/// - The `run()` method can be stateless (no self parameter) or stateful (`&mut self`)
137/// - All input parameters must be references (e.g., `&i32`, not `i32`)
138/// - The macro requires `Task` to be in scope: `use dagx::Task;`
139/// - For stateless tasks, the struct must implement `Default` (e.g., unit structs)
140///
141/// # Generated Code
142///
143/// The macro transforms your implementation into a full `Task` trait implementation.
144///
145/// For stateless tasks (no self parameter):
146///
147/// ```ignore
148/// // Your code:
149/// #[task]
150/// impl Add {
151/// async fn run(a: &i32, b: &i32) -> i32 {
152/// a + b
153/// }
154/// }
155///
156/// // Generated:
157/// impl Task for Add {
158/// type Input = (i32, i32);
159/// type Output = i32;
160///
161/// async fn run(&mut self, input: Self::Input) -> Self::Output {
162/// let (a, b) = input;
163/// Self::run_impl(&a, &b).await
164/// }
165/// }
166///
167/// impl Add {
168/// #[inline]
169/// async fn run_impl(a: &i32, b: &i32) -> i32 {
170/// a + b
171/// }
172/// }
173/// ```ignore
174///
175/// For stateful tasks (with &mut self):
176///
177/// ```ignore
178/// // Your code:
179/// #[task]
180/// impl Counter {
181/// async fn run(&mut self, inc: &i32) -> i32 {
182/// self.count += inc;
183/// self.count
184/// }
185/// }
186///
187/// // Generated:
188/// impl Task for Counter {
189/// type Input = i32;
190/// type Output = i32;
191///
192/// async fn run(&mut self, input: Self::Input) -> Self::Output {
193/// let inc = input;
194/// self.run_impl(&inc).await
195/// }
196/// }
197///
198/// impl Counter {
199/// async fn run_impl(&mut self, inc: &i32) -> i32 {
200/// self.count += inc;
201/// self.count
202/// }
203/// }
204/// ```ignore
205#[proc_macro_attribute]
206pub fn task(_attr: TokenStream, item: TokenStream) -> TokenStream {
207 let impl_block = parse_macro_input!(item as ItemImpl);
208
209 // Extract the struct name
210 let struct_name = &impl_block.self_ty;
211
212 // Find the run() method
213 let run_method = match impl_block.items.iter().find_map(|item| {
214 if let ImplItem::Fn(method) = item {
215 if method.sig.ident == "run" {
216 return Some(method);
217 }
218 }
219 None
220 }) {
221 Some(method) => method,
222 None => {
223 return syn::Error::new_spanned(
224 &impl_block,
225 "impl block must contain a run() method\n\n\
226 Expected signature: fn run(&mut self, ...) -> OutputType\n\
227 or: async fn run(&mut self, ...) -> OutputType\n\
228 The #[task] macro requires a run() method to implement the Task trait.",
229 )
230 .to_compile_error()
231 .into();
232 }
233 };
234
235 // Check if the method is async or sync
236 let is_async = run_method.sig.asyncness.is_some();
237
238 // Extract parameters (excluding self)
239 let params_result: Result<Vec<_>, _> = run_method
240 .sig
241 .inputs
242 .iter()
243 .filter_map(|arg| {
244 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
245 // Extract the parameter name
246 let param_name = if let Pat::Ident(pat_ident) = &**pat {
247 &pat_ident.ident
248 } else {
249 return Some(Err(syn::Error::new_spanned(
250 pat,
251 "Unsupported parameter pattern\n\n\
252 Parameters must be simple identifiers like 'input: &T' or 'a: &i32'.",
253 )));
254 };
255
256 // Extract the inner type from &Type
257 let inner_type = if let Type::Reference(type_ref) = &**ty {
258 &type_ref.elem
259 } else {
260 return Some(Err(syn::Error::new_spanned(
261 ty,
262 "All parameters must be references (&T)\n\n\
263 Task inputs must be references to allow sharing data between tasks.\n\
264 Change this parameter from 'T' to '&T'.",
265 )));
266 };
267
268 Some(Ok((param_name.clone(), inner_type.clone())))
269 } else {
270 None // Skip self parameter
271 }
272 })
273 .collect();
274
275 let params = match params_result {
276 Ok(p) => p,
277 Err(e) => return e.to_compile_error().into(),
278 };
279
280 // Extract return type
281 let output_type = match &run_method.sig.output {
282 ReturnType::Default => {
283 return syn::Error::new_spanned(
284 &run_method.sig,
285 "run() method must have an explicit return type\n\n\
286 Specify the output type: async fn run(...) -> OutputType\n\
287 For tasks that don't return a value, use '-> ()'.",
288 )
289 .to_compile_error()
290 .into();
291 }
292 ReturnType::Type(_, ty) => ty.clone(),
293 };
294
295 // Build Input type based on parameter count
296 let input_type = match params.len() {
297 0 => quote! { () },
298 1 => {
299 let (_name, ty) = ¶ms[0];
300 quote! { #ty }
301 }
302 _ => {
303 let types: Vec<_> = params.iter().map(|(_, ty)| ty).collect();
304 quote! { ( #(#types),* ) }
305 }
306 };
307
308 // Generate parameter destructuring for the wrapper run() method
309 let (param_destructure, param_refs) = if params.is_empty() {
310 (quote! { _ }, quote! {})
311 } else if params.len() == 1 {
312 let (name, _) = ¶ms[0];
313 (quote! { #name }, quote! { &#name })
314 } else {
315 let names: Vec<_> = params.iter().map(|(name, _)| name).collect();
316 let refs: Vec<_> = params.iter().map(|(name, _)| quote! { &#name }).collect();
317 (quote! { ( #(#names),* ) }, quote! { #(#refs),* })
318 };
319
320 // Clone the run method and rename it to run_impl
321 let mut run_impl_method = run_method.clone();
322 run_impl_method.sig.ident = syn::Ident::new("run_impl", run_method.sig.ident.span());
323
324 // Check if the method has a self receiver
325 let has_self_receiver = run_method
326 .sig
327 .inputs
328 .iter()
329 .any(|arg| matches!(arg, FnArg::Receiver(_)));
330
331 // Generate extract_and_run implementation based on parameter count
332 let extract_and_run_impl = match params.len() {
333 0 => {
334 // Zero parameters - no extraction needed
335 quote! {
336 fn extract_and_run(
337 self,
338 _receivers: Vec<Box<dyn std::any::Any + Send>>,
339 ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
340 async move {
341 let input = ();
342 Ok(self.run(input).await)
343 }
344 }
345 }
346 }
347 1 => {
348 // Single parameter
349 let param_type = ¶ms[0].1;
350 quote! {
351 fn extract_and_run(
352 self,
353 mut receivers: Vec<Box<dyn std::any::Any + Send>>,
354 ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
355 async move {
356 use futures::channel::oneshot;
357 use std::sync::Arc;
358
359 if receivers.len() != 1 {
360 return Err(format!("Expected 1 dependency, got {}", receivers.len()));
361 }
362
363 let rx = *receivers.pop()
364 .unwrap()
365 .downcast::<oneshot::Receiver<Arc<#param_type>>>()
366 .map_err(|_| format!("Type mismatch: expected Arc<{}>", std::any::type_name::<#param_type>()))?;
367
368 let arc_value = rx.await
369 .map_err(|_| "Channel closed before receiving value".to_string())?;
370
371 let input = (*arc_value).clone();
372 Ok(self.run(input).await)
373 }
374 }
375 }
376 }
377 _ => {
378 // Multiple parameters
379 let param_count = params.len();
380 let param_types: Vec<_> = params.iter().map(|(_, ty)| ty).collect();
381 let indices: Vec<_> = (0..param_count).collect();
382
383 // Generate unique receiver variable names
384 let rx_vars: Vec<_> = (0..param_count)
385 .map(|i| syn::Ident::new(&format!("rx_{}", i), proc_macro2::Span::call_site()))
386 .collect();
387
388 // Create syn::Index for tuple field access (avoids the suffix warning)
389 let syn_indices: Vec<_> = (0..param_count).map(syn::Index::from).collect();
390
391 quote! {
392 fn extract_and_run(
393 self,
394 receivers: Vec<Box<dyn std::any::Any + Send>>,
395 ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
396 async move {
397 use futures::channel::oneshot;
398 use std::sync::Arc;
399
400 let expected_count = #param_count;
401 if receivers.len() != expected_count {
402 return Err(format!("Expected {} dependencies, got {}", expected_count, receivers.len()));
403 }
404
405 let mut iter = receivers.into_iter();
406
407 // Extract each receiver
408 #(
409 let #rx_vars = *iter.next()
410 .ok_or_else(|| format!("Missing receiver at index {}", #indices))?
411 .downcast::<oneshot::Receiver<Arc<#param_types>>>()
412 .map_err(|_| format!("Type mismatch at index {}: expected Arc<{}>",
413 #indices, std::any::type_name::<#param_types>()))?;
414 )*
415
416 // Await all channels concurrently
417 let arc_results = futures::join!(
418 #(
419 async move {
420 #rx_vars.await.map_err(|_| format!("Channel {} closed", #indices))
421 }
422 ),*
423 );
424
425 // Clone inner values and build tuple
426 let input = (#(
427 (*arc_results.#syn_indices?).clone()
428 ),*);
429
430 Ok(self.run(input).await)
431 }
432 }
433 }
434 }
435 };
436
437 // Generate the Task trait implementation based on whether we have self and async/sync
438 let expanded = if has_self_receiver {
439 // Stateful task - consumes self but delegates to a method that borrows
440 if is_async {
441 // Async with self
442 quote! {
443 impl Task for #struct_name {
444 type Input = #input_type;
445 type Output = #output_type;
446
447 async fn run(mut self, input: Self::Input) -> Self::Output {
448 let #param_destructure = input;
449 self.run_impl(#param_refs).await
450 }
451
452 #extract_and_run_impl
453 }
454
455 impl #struct_name {
456 #run_impl_method
457 }
458 }
459 } else {
460 // Sync with self - wrap in async block
461 quote! {
462 impl Task for #struct_name {
463 type Input = #input_type;
464 type Output = #output_type;
465
466 async fn run(mut self, input: Self::Input) -> Self::Output {
467 let #param_destructure = input;
468 self.run_impl(#param_refs)
469 }
470
471 #extract_and_run_impl
472 }
473
474 impl #struct_name {
475 #run_impl_method
476 }
477 }
478 }
479 } else {
480 // Stateless task
481 if is_async {
482 // Async stateless
483 quote! {
484 impl Task for #struct_name {
485 type Input = #input_type;
486 type Output = #output_type;
487
488 async fn run(self, input: Self::Input) -> Self::Output {
489 let #param_destructure = input;
490 Self::run_impl(#param_refs).await
491 }
492
493 #extract_and_run_impl
494 }
495
496 impl #struct_name {
497 #[inline]
498 #run_impl_method
499 }
500 }
501 } else {
502 // Sync stateless
503 quote! {
504 impl Task for #struct_name {
505 type Input = #input_type;
506 type Output = #output_type;
507
508 async fn run(self, input: Self::Input) -> Self::Output {
509 let #param_destructure = input;
510 Self::run_impl(#param_refs)
511 }
512
513 #extract_and_run_impl
514 }
515
516 impl #struct_name {
517 #[inline]
518 #run_impl_method
519 }
520 }
521 }
522 };
523
524 TokenStream::from(expanded)
525}