Skip to main content

operese_dagx/
builder.rs

1//! Task builder with type-state pattern for dependency tracking.
2//!
3//! TaskBuilder uses compile-time type checking to ensure dependencies are wired correctly.
4
5use std::marker::PhantomData;
6
7use crate::deps::DepsTuple;
8use crate::runner::DagRunner;
9use crate::task::Task;
10
11#[cfg(feature = "tracing")]
12use tracing::debug;
13
14/// Node builder that tracks dependency completion via type state.
15///
16/// A `TaskBuilder<Tk, Input>` is returned from [`DagRunner::add_task`] for tasks with non-unit inputs.
17/// After calling [`TaskBuilder::depends_on`], it becomes a [`TaskHandle`].
18///
19/// # Examples
20///
21/// ```no_run
22/// # use operese_dagx::{task, DagRunner, Task};
23/// #
24/// // Using tuple struct for simple constants
25/// struct Constant(i32);
26///
27/// #[task]
28/// impl Constant {
29///     async fn run(&mut self) -> i32 { self.0 }
30/// }
31///
32/// // Task with state constructed via ::new()
33/// struct Multiplier { factor: i32 }
34///
35/// impl Multiplier {
36///     fn new(factor: i32) -> Self { Self { factor } }
37/// }
38///
39/// #[task]
40/// impl Multiplier {
41///     async fn run(&mut self, x: &i32) -> i32 { x * self.factor }
42/// }
43///
44/// # async {
45/// let mut dag = DagRunner::new();
46///
47/// let a = dag.add_task(Constant(10));
48/// // a is TaskHandle since no dependencies needed (Input = ())
49///
50/// let b = dag.add_task(Multiplier::new(2));
51/// // b is TaskBuilder until we call depends_on()
52///
53/// let b = b.depends_on(&a);
54/// // Now b is a TaskHandle<i32>
55///
56///let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap();
57/// assert_eq!(output.get(b), 20);
58/// # };
59/// ```
60#[must_use]
61pub struct TaskBuilder<'a, Input, Tk>
62where
63    Tk: Task<Input>,
64    Input: Send + Sync + 'static,
65{
66    pub(crate) id: NodeId,
67    pub(crate) dag: &'a mut DagRunner,
68    pub(crate) _phantom: PhantomData<(Tk, Input)>,
69}
70
71impl<'a, Input, Tk> TaskBuilder<'a, Input, Tk>
72where
73    Tk: Task<Input>,
74    Input: Send + Sync + 'static,
75{
76    /// Provide all dependencies exactly once as a tuple.
77    ///
78    /// The dependencies must match the task's `Input` type exactly:
79    /// - `Input = A`: Pass `&TaskHandle<A>`
80    /// - `Input = (A, B, ...)`: Pass `(&TaskHandle<A>, &TaskHandle<B>, ...)`
81    ///
82    /// The order of dependencies in the tuple must match the order in `Input`.
83    ///
84    /// # Panics
85    ///
86    /// This function will cause a panic in [`DagRunner::run`] if called with a [`TaskHandle`] from a `DagRunner` instance
87    /// other than the one which created this TaskBuilder.
88    ///
89    /// # Examples
90    ///
91    /// ```no_run
92    /// # use operese_dagx::{task, DagRunner, Task};
93    /// #
94    /// // Tuple struct
95    /// struct Value(i32);
96    ///
97    /// #[task]
98    /// impl Value {
99    ///     async fn run(&mut self) -> i32 { self.0 }
100    /// }
101    ///
102    /// // Tuple struct with multiplier
103    /// struct Scale(i32);
104    ///
105    /// #[task]
106    /// impl Scale {
107    ///     async fn run(&mut self, x: &i32) -> i32 { x * self.0 }
108    /// }
109    ///
110    /// // Unit struct
111    /// struct Add;
112    ///
113    /// #[task]
114    /// impl Add {
115    ///     async fn run(&mut self, a: &i32, b: &i32) -> i32 { a + b }
116    /// }
117    ///
118    /// # async {
119    /// let mut dag = DagRunner::new();
120    ///
121    /// let x = dag.add_task(Value(2));
122    /// let y = dag.add_task(Value(3));
123    ///
124    /// // Single dependency
125    /// let double = dag.add_task(Scale(2)).depends_on(&x);
126    ///
127    /// // Multiple dependencies: tuple form
128    /// let sum = dag.add_task(Add).depends_on((&x, &y));
129    ///
130    ///let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap();
131    /// # };
132    /// ```
133    #[allow(private_bounds)]
134    pub fn depends_on<D>(self, deps: D) -> TaskHandle<Tk::Output>
135    where
136        D: DepsTuple<Input>,
137    {
138        // Register dependencies in the DAG
139        let dep_ids = deps.to_node_ids();
140
141        #[cfg(feature = "tracing")]
142        debug!(
143            task_id = self.id.0,
144            dependency_ids = ?dep_ids.iter().map(|id| id.0).collect::<Vec<_>>(),
145            dependency_count = dep_ids.len(),
146            "wiring task dependencies"
147        );
148
149        for &dep_id in &dep_ids {
150            // Add edge from this node to dependency
151            if let Some(node_edges) = self.dag.edges.get_mut(&self.id) {
152                node_edges.push(dep_id);
153            }
154
155            // Add this node as dependent of the dependency
156            if let Some(node_dependents) = self.dag.dependents.get_mut(&dep_id) {
157                node_dependents.push(self.id);
158            }
159        }
160
161        TaskHandle {
162            id: self.id,
163            _phantom: PhantomData,
164        }
165    }
166}
167
/// Numeric identifier of a task node in a DAG.
///
/// Cheap to copy, comparable, and usable as a hash-map key. The wrapped `u32` is public,
/// but callers should treat the value as an opaque token rather than rely on how ids are
/// assigned.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
171
172/// Opaque, typed token for a node's output.
173///
174/// A `TaskHandle<T>` provides compile-time type-safe access to a task's output.
175/// You can:
176/// 1. Pass it to [`crate::TaskBuilder::depends_on`] to wire up dependencies
177/// 2. Use it with [`crate::DagOutput::get`] to retrieve the output after [`crate::DagRunner::run`]
178///
179/// # Examples
180///
181/// ```no_run
182/// # use operese_dagx::{task, DagRunner, Task};
183/// #
184/// # struct LoadValue { value: i32 }
185/// # impl LoadValue { pub fn new(v: i32) -> Self { Self { value: v } } }
186/// # #[task]
187/// # impl LoadValue {
188/// #     async fn run(&mut self) -> i32 { self.value }
189/// # }
190/// # async {
191/// let mut dag = DagRunner::new();
192/// let node = dag.add_task(LoadValue::new(42));
193///
194///let mut output = dag.run(|fut| async move { tokio::spawn(fut).await.unwrap() }).await.unwrap();
195///
196/// assert_eq!(output.get(node), 42);
197/// # };
198/// ```
199pub struct TaskHandle<T> {
200    pub(crate) id: NodeId,
201    pub(crate) _phantom: PhantomData<fn() -> T>,
202}
203
204/// Takes a task and converts it to either a TaskBuilder or a TaskHandle,
205/// depending on whether it has inputs or not.
206///
207/// This is useful to enforce at compile time that a TaskBuilder is never created for task with unit input,
208/// and that it can be used directly as a dependency without converting it manually to a TaskHandle.
209pub trait TaskWire<Input>: Task<Input> + Sync + 'static
210where
211    Input: Send + Sync + 'static,
212{
213    type Retval<'dag>;
214
215    fn new_from_dag<'dag>(id: NodeId, dag: &'dag mut DagRunner) -> Self::Retval<'dag>;
216}
217
218impl<Tk> TaskWire<()> for Tk
219where
220    Tk: Task<()> + Sync + 'static,
221{
222    type Retval<'dag> = TaskHandle<Tk::Output>;
223
224    fn new_from_dag(id: NodeId, _dag: &mut DagRunner) -> Self::Retval<'static> {
225        Self::Retval {
226            id,
227            _phantom: PhantomData,
228        }
229    }
230}
231
232/// Macro to implement TaskWire for different tuple sizes.
233macro_rules! impl_wire_tuple {
234    ($($T:ident),+) => {
235        impl<Tk, $($T: Send + Sync + 'static),+> TaskWire<($($T,)+)> for Tk
236        where
237            Tk: Task<($($T,)+)> + Sync + 'static
238        {
239            type Retval<'dag> = TaskBuilder<'dag, ($($T,)+), Tk>;
240
241            fn new_from_dag<'dag>(id: NodeId, dag: &'dag mut DagRunner) -> Self::Retval<'dag> {
242                Self::Retval {
243                    id,
244                    dag,
245                    _phantom: PhantomData,
246                }
247            }
248        }
249    };
250}
251
252impl_wire_tuple!(T1);
253impl_wire_tuple!(T1, T2);
254impl_wire_tuple!(T1, T2, T3);
255impl_wire_tuple!(T1, T2, T3, T4);
256impl_wire_tuple!(T1, T2, T3, T4, T5);
257impl_wire_tuple!(T1, T2, T3, T4, T5, T6);
258impl_wire_tuple!(T1, T2, T3, T4, T5, T6, T7);
259impl_wire_tuple!(T1, T2, T3, T4, T5, T6, T7, T8);
260
261#[cfg(test)]
262mod tests;
263
264#[cfg(test)]
265mod coverage_tests;