dagx_macros/
lib.rs

1//! Procedural macros for dagx
2//!
3//! This crate provides the `#[task]` attribute macro that automatically implements
4//! the `Task` trait by deriving Input and Output types from the `run()` method signature.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, FnArg, ImplItem, ItemImpl, Pat, PatType, ReturnType, Type};
9
10/// Attribute macro to automatically implement the `Task` trait.
11///
12/// Apply this to an `impl` block containing a `run()` method (sync or async). The macro:
13/// - Derives `Input` and `Output` types from the `run()` signature
14/// - Automatically implements the `Task` trait
15/// - **Generates type-specific extraction logic** - works with ANY type (Clone + Send + Sync)!
16/// - Supports both sync and async run methods
17/// - Supports stateless (no self) and stateful (&self, &mut self) tasks
18/// - Handles various input patterns (no inputs, single input, multiple inputs)
19///
20/// **Key Feature**: Custom types work automatically without implementing any traits!
21/// The macro generates inline extraction logic in `extract_and_run()` specific to your
22/// task's parameter types. Just derive `Clone` on your types and they'll work seamlessly.
23///
24/// # Task Patterns
25///
26/// The `#[task]` macro supports three patterns based on state requirements:
27///
28/// ## 1. Stateless Tasks (No State)
29///
30/// Unit structs for pure computations. Use **no `self` parameter**:
31///
32/// ```ignore
33/// use dagx::{task, Task};
34///
35/// struct Add;
36///
37/// #[task]
38/// impl Add {
39///     async fn run(a: &i32, b: &i32) -> i32 {
40///         a + b  // Pure function, no state
41///     }
42/// }
43/// ```
44///
45/// ## 2. Read-Only State Tasks
46///
47/// Tasks that read configuration or constant data. Use **`&self`**:
48///
49/// ```ignore
50/// use dagx::{task, Task};
51///
52/// struct Multiplier {
53///     factor: i32,
54/// }
55///
56/// #[task]
57/// impl Multiplier {
58///     async fn run(&self, input: &i32) -> i32 {
59///         input * self.factor  // Read-only access
60///     }
61/// }
62/// ```
63///
64/// ## 3. Mutable State Tasks
65///
66/// Tasks that accumulate or modify state. Use **`&mut self`**:
67///
68/// ```ignore
69/// use dagx::{task, Task};
70///
71/// struct Counter {
72///     count: i32,
73/// }
74///
75/// #[task]
76/// impl Counter {
77///     async fn run(&mut self, increment: &i32) -> i32 {
78///         self.count += increment;  // Modifies state
79///         self.count
80///     }
81/// }
82/// ```
83///
84/// # Input Patterns
85///
86/// ## No Inputs (Source Tasks)
87///
88/// ```ignore
89/// use dagx::{task, Task};
90///
91/// struct LoadData {
92///     value: i32,
93/// }
94///
95/// #[task]
96/// impl LoadData {
97///     async fn run(&mut self) -> i32 {
98///         self.value
99///     }
100/// }
/// ```
102///
103/// ## Single Input
104///
105/// ```ignore
106/// use dagx::{task, Task};
107///
108/// struct Double;
109///
110/// #[task]
111/// impl Double {
112///     async fn run(&mut self, input: &i32) -> i32 {
113///         input * 2
114///     }
115/// }
/// ```
117///
118/// ## Multiple Inputs (up to 8)
119///
120/// ```ignore
121/// use dagx::{task, Task};
122///
123/// struct Combine;
124///
125/// #[task]
126/// impl Combine {
127///     async fn run(&mut self, a: &i32, b: &String, c: &bool) -> String {
128///         format!("{}: {} ({})", b, a, c)
129///     }
130/// }
/// ```
132///
133/// # Requirements
134///
/// - The impl block must contain exactly one `run()` method, either `fn run()` or `async fn run()`
/// - The `run()` method can be stateless (no self parameter) or stateful (`&self` or `&mut self`)
/// - All input parameters must be shared references (e.g., `&i32`, not `i32` or `&mut i32`)
138/// - The macro requires `Task` to be in scope: `use dagx::Task;`
139/// - For stateless tasks, the struct must implement `Default` (e.g., unit structs)
140///
141/// # Generated Code
142///
143/// The macro transforms your implementation into a full `Task` trait implementation.
144///
145/// For stateless tasks (no self parameter):
146///
147/// ```ignore
148/// // Your code:
149/// #[task]
150/// impl Add {
151///     async fn run(a: &i32, b: &i32) -> i32 {
152///         a + b
153///     }
154/// }
155///
156/// // Generated:
157/// impl Task for Add {
158///     type Input = (i32, i32);
159///     type Output = i32;
160///
///     async fn run(self, input: Self::Input) -> Self::Output {
///         let (a, b) = input;
///         Self::run_impl(&a, &b).await
///     }
///
///     // ...plus a generated `extract_and_run()` that gathers dependency values into `Self::Input`
/// }
166///
167/// impl Add {
168///     #[inline]
169///     async fn run_impl(a: &i32, b: &i32) -> i32 {
170///         a + b
171///     }
172/// }
/// ```
174///
175/// For stateful tasks (with &mut self):
176///
177/// ```ignore
178/// // Your code:
179/// #[task]
180/// impl Counter {
181///     async fn run(&mut self, inc: &i32) -> i32 {
182///         self.count += inc;
183///         self.count
184///     }
185/// }
186///
187/// // Generated:
188/// impl Task for Counter {
189///     type Input = i32;
190///     type Output = i32;
191///
///     async fn run(mut self, input: Self::Input) -> Self::Output {
///         let inc = input;
///         self.run_impl(&inc).await
///     }
///
///     // ...plus a generated `extract_and_run()` that gathers dependency values into `Self::Input`
/// }
197///
198/// impl Counter {
199///     async fn run_impl(&mut self, inc: &i32) -> i32 {
200///         self.count += inc;
201///         self.count
202///     }
203/// }
/// ```
205#[proc_macro_attribute]
206pub fn task(_attr: TokenStream, item: TokenStream) -> TokenStream {
207    let impl_block = parse_macro_input!(item as ItemImpl);
208
209    // Extract the struct name
210    let struct_name = &impl_block.self_ty;
211
212    // Find the run() method
213    let run_method = match impl_block.items.iter().find_map(|item| {
214        if let ImplItem::Fn(method) = item {
215            if method.sig.ident == "run" {
216                return Some(method);
217            }
218        }
219        None
220    }) {
221        Some(method) => method,
222        None => {
223            return syn::Error::new_spanned(
224                &impl_block,
225                "impl block must contain a run() method\n\n\
226                 Expected signature: fn run(&mut self, ...) -> OutputType\n\
227                              or: async fn run(&mut self, ...) -> OutputType\n\
228                 The #[task] macro requires a run() method to implement the Task trait.",
229            )
230            .to_compile_error()
231            .into();
232        }
233    };
234
235    // Check if the method is async or sync
236    let is_async = run_method.sig.asyncness.is_some();
237
238    // Extract parameters (excluding self)
239    let params_result: Result<Vec<_>, _> = run_method
240        .sig
241        .inputs
242        .iter()
243        .filter_map(|arg| {
244            if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
245                // Extract the parameter name
246                let param_name = if let Pat::Ident(pat_ident) = &**pat {
247                    &pat_ident.ident
248                } else {
249                    return Some(Err(syn::Error::new_spanned(
250                        pat,
251                        "Unsupported parameter pattern\n\n\
252                         Parameters must be simple identifiers like 'input: &T' or 'a: &i32'.",
253                    )));
254                };
255
256                // Extract the inner type from &Type
257                let inner_type = if let Type::Reference(type_ref) = &**ty {
258                    &type_ref.elem
259                } else {
260                    return Some(Err(syn::Error::new_spanned(
261                        ty,
262                        "All parameters must be references (&T)\n\n\
263                         Task inputs must be references to allow sharing data between tasks.\n\
264                         Change this parameter from 'T' to '&T'.",
265                    )));
266                };
267
268                Some(Ok((param_name.clone(), inner_type.clone())))
269            } else {
270                None // Skip self parameter
271            }
272        })
273        .collect();
274
275    let params = match params_result {
276        Ok(p) => p,
277        Err(e) => return e.to_compile_error().into(),
278    };
279
280    // Extract return type
281    let output_type = match &run_method.sig.output {
282        ReturnType::Default => {
283            return syn::Error::new_spanned(
284                &run_method.sig,
285                "run() method must have an explicit return type\n\n\
286                 Specify the output type: async fn run(...) -> OutputType\n\
287                 For tasks that don't return a value, use '-> ()'.",
288            )
289            .to_compile_error()
290            .into();
291        }
292        ReturnType::Type(_, ty) => ty.clone(),
293    };
294
295    // Build Input type based on parameter count
296    let input_type = match params.len() {
297        0 => quote! { () },
298        1 => {
299            let (_name, ty) = &params[0];
300            quote! { #ty }
301        }
302        _ => {
303            let types: Vec<_> = params.iter().map(|(_, ty)| ty).collect();
304            quote! { ( #(#types),* ) }
305        }
306    };
307
308    // Generate parameter destructuring for the wrapper run() method
309    let (param_destructure, param_refs) = if params.is_empty() {
310        (quote! { _ }, quote! {})
311    } else if params.len() == 1 {
312        let (name, _) = &params[0];
313        (quote! { #name }, quote! { &#name })
314    } else {
315        let names: Vec<_> = params.iter().map(|(name, _)| name).collect();
316        let refs: Vec<_> = params.iter().map(|(name, _)| quote! { &#name }).collect();
317        (quote! { ( #(#names),* ) }, quote! { #(#refs),* })
318    };
319
320    // Clone the run method and rename it to run_impl
321    let mut run_impl_method = run_method.clone();
322    run_impl_method.sig.ident = syn::Ident::new("run_impl", run_method.sig.ident.span());
323
324    // Check if the method has a self receiver
325    let has_self_receiver = run_method
326        .sig
327        .inputs
328        .iter()
329        .any(|arg| matches!(arg, FnArg::Receiver(_)));
330
331    // Generate extract_and_run implementation based on parameter count
332    let extract_and_run_impl = match params.len() {
333        0 => {
334            // Zero parameters - no extraction needed
335            quote! {
336                fn extract_and_run(
337                    self,
338                    _receivers: Vec<Box<dyn std::any::Any + Send>>,
339                ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
340                    async move {
341                        let input = ();
342                        Ok(self.run(input).await)
343                    }
344                }
345            }
346        }
347        1 => {
348            // Single parameter
349            let param_type = &params[0].1;
350            quote! {
351                fn extract_and_run(
352                    self,
353                    mut receivers: Vec<Box<dyn std::any::Any + Send>>,
354                ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
355                    async move {
356                        use futures::channel::oneshot;
357                        use std::sync::Arc;
358
359                        if receivers.len() != 1 {
360                            return Err(format!("Expected 1 dependency, got {}", receivers.len()));
361                        }
362
363                        let rx = *receivers.pop()
364                            .unwrap()
365                            .downcast::<oneshot::Receiver<Arc<#param_type>>>()
366                            .map_err(|_| format!("Type mismatch: expected Arc<{}>", std::any::type_name::<#param_type>()))?;
367
368                        let arc_value = rx.await
369                            .map_err(|_| "Channel closed before receiving value".to_string())?;
370
371                        let input = (*arc_value).clone();
372                        Ok(self.run(input).await)
373                    }
374                }
375            }
376        }
377        _ => {
378            // Multiple parameters
379            let param_count = params.len();
380            let param_types: Vec<_> = params.iter().map(|(_, ty)| ty).collect();
381            let indices: Vec<_> = (0..param_count).collect();
382
383            // Generate unique receiver variable names
384            let rx_vars: Vec<_> = (0..param_count)
385                .map(|i| syn::Ident::new(&format!("rx_{}", i), proc_macro2::Span::call_site()))
386                .collect();
387
388            // Create syn::Index for tuple field access (avoids the suffix warning)
389            let syn_indices: Vec<_> = (0..param_count).map(syn::Index::from).collect();
390
391            quote! {
392                fn extract_and_run(
393                    self,
394                    receivers: Vec<Box<dyn std::any::Any + Send>>,
395                ) -> impl std::future::Future<Output = Result<Self::Output, String>> + Send {
396                    async move {
397                        use futures::channel::oneshot;
398                        use std::sync::Arc;
399
400                        let expected_count = #param_count;
401                        if receivers.len() != expected_count {
402                            return Err(format!("Expected {} dependencies, got {}", expected_count, receivers.len()));
403                        }
404
405                        let mut iter = receivers.into_iter();
406
407                        // Extract each receiver
408                        #(
409                            let #rx_vars = *iter.next()
410                                .ok_or_else(|| format!("Missing receiver at index {}", #indices))?
411                                .downcast::<oneshot::Receiver<Arc<#param_types>>>()
412                                .map_err(|_| format!("Type mismatch at index {}: expected Arc<{}>",
413                                    #indices, std::any::type_name::<#param_types>()))?;
414                        )*
415
416                        // Await all channels concurrently
417                        let arc_results = futures::join!(
418                            #(
419                                async move {
420                                    #rx_vars.await.map_err(|_| format!("Channel {} closed", #indices))
421                                }
422                            ),*
423                        );
424
425                        // Clone inner values and build tuple
426                        let input = (#(
427                            (*arc_results.#syn_indices?).clone()
428                        ),*);
429
430                        Ok(self.run(input).await)
431                    }
432                }
433            }
434        }
435    };
436
437    // Generate the Task trait implementation based on whether we have self and async/sync
438    let expanded = if has_self_receiver {
439        // Stateful task - consumes self but delegates to a method that borrows
440        if is_async {
441            // Async with self
442            quote! {
443                impl Task for #struct_name {
444                    type Input = #input_type;
445                    type Output = #output_type;
446
447                    async fn run(mut self, input: Self::Input) -> Self::Output {
448                        let #param_destructure = input;
449                        self.run_impl(#param_refs).await
450                    }
451
452                    #extract_and_run_impl
453                }
454
455                impl #struct_name {
456                    #run_impl_method
457                }
458            }
459        } else {
460            // Sync with self - wrap in async block
461            quote! {
462                impl Task for #struct_name {
463                    type Input = #input_type;
464                    type Output = #output_type;
465
466                    async fn run(mut self, input: Self::Input) -> Self::Output {
467                        let #param_destructure = input;
468                        self.run_impl(#param_refs)
469                    }
470
471                    #extract_and_run_impl
472                }
473
474                impl #struct_name {
475                    #run_impl_method
476                }
477            }
478        }
479    } else {
480        // Stateless task
481        if is_async {
482            // Async stateless
483            quote! {
484                impl Task for #struct_name {
485                    type Input = #input_type;
486                    type Output = #output_type;
487
488                    async fn run(self, input: Self::Input) -> Self::Output {
489                        let #param_destructure = input;
490                        Self::run_impl(#param_refs).await
491                    }
492
493                    #extract_and_run_impl
494                }
495
496                impl #struct_name {
497                    #[inline]
498                    #run_impl_method
499                }
500            }
501        } else {
502            // Sync stateless
503            quote! {
504                impl Task for #struct_name {
505                    type Input = #input_type;
506                    type Output = #output_type;
507
508                    async fn run(self, input: Self::Input) -> Self::Output {
509                        let #param_destructure = input;
510                        Self::run_impl(#param_refs)
511                    }
512
513                    #extract_and_run_impl
514                }
515
516                impl #struct_name {
517                    #[inline]
518                    #run_impl_method
519                }
520            }
521        }
522    };
523
524    TokenStream::from(expanded)
525}