use std::marker::PhantomData;
use crate::deps::DepsTuple;
use crate::runner::DagRunner;
use crate::task::Task;
#[cfg(feature = "tracing")]
use tracing::debug;
/// Pending wiring step for a task whose `Input` is a non-empty tuple.
///
/// Values of this type are produced by the tuple implementations of
/// [`TaskWire::new_from_dag`] in this module; the node's dependencies are
/// supplied via [`TaskBuilder::depends_on`], which consumes the builder and
/// yields a [`TaskHandle`].
#[must_use]
pub struct TaskBuilder<'a, Input: Send + Sync + 'static, Tk: Task<Input>> {
    /// Id of the node being configured.
    pub(crate) id: NodeId,
    /// Runner whose dependency maps `depends_on` updates.
    pub(crate) dag: &'a mut DagRunner,
    /// Records the task and input types without storing values of either.
    pub(crate) _phantom: PhantomData<(Tk, Input)>,
}
impl<'a, Input, Tk> TaskBuilder<'a, Input, Tk>
where
    Tk: Task<Input>,
    Input: Send + Sync + 'static,
{
    /// Wires this node to the upstream tasks in `deps` and returns a handle to it.
    ///
    /// For every id yielded by `deps.to_node_ids()` (in order):
    /// * the id is appended to this node's entry in the runner's `edges` map;
    /// * this node's id is appended to that dependency's entry in the runner's
    ///   `dependents` map.
    ///
    /// Ids with no existing entry in the respective map are skipped.
    /// NOTE(review): this silently ignores unregistered nodes; presumably every
    /// node is inserted into both maps when it is added to the runner — confirm.
    // `DepsTuple` is less visible than this `pub` method, which trips the
    // `private_bounds` lint; presumably intentional so external code cannot
    // implement it — confirm before exporting the trait.
    #[allow(private_bounds)]
    pub fn depends_on<D>(self, deps: D) -> TaskHandle<Tk::Output>
    where
        D: DepsTuple<Input>,
    {
        let dep_ids = deps.to_node_ids();
        #[cfg(feature = "tracing")]
        debug!(
            task_id = self.id.0,
            dependency_ids = ?dep_ids.iter().map(|id| id.0).collect::<Vec<_>>(),
            dependency_count = dep_ids.len(),
            "wiring task dependencies"
        );

        // This node's edge list is the same for every dependency, so look it
        // up once instead of once per dependency.
        if let Some(node_edges) = self.dag.edges.get_mut(&self.id) {
            node_edges.extend(dep_ids.iter().copied());
        }

        // Each dependency gets a reverse link back to this node.
        for &dep_id in &dep_ids {
            if let Some(node_dependents) = self.dag.dependents.get_mut(&dep_id) {
                node_dependents.push(self.id);
            }
        }

        TaskHandle {
            id: self.id,
            _phantom: PhantomData,
        }
    }
}
/// Identifier of a node (task) in a [`DagRunner`].
///
/// A thin `u32` newtype that is `Copy` and hashable, so it can be used
/// directly as a key in the runner's dependency maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
/// Typed reference to a task in the DAG whose output type is `T`.
///
/// Only the node id is stored; `T` exists purely at the type level.
pub struct TaskHandle<T> {
    /// Id of the node this handle refers to.
    pub(crate) id: NodeId,
    /// `fn() -> T` keeps the handle covariant in `T` and makes its `Send`/`Sync`
    /// independent of `T`, because no `T` value is ever held.
    pub(crate) _phantom: PhantomData<fn() -> T>,
}
/// Selects what registering a task yields, based on the shape of its `Input`.
///
/// In this module, `()`-input tasks produce a [`TaskHandle`] directly, while
/// tuple-input tasks produce a [`TaskBuilder`] that borrows the runner until
/// its dependencies are supplied.
pub trait TaskWire<Input: Send + Sync + 'static>: Task<Input> + Sync + 'static {
    /// Result of wiring; may borrow the runner for `'dag`.
    type Retval<'dag>;

    /// Produces the wiring result for node `id` within `dag`.
    fn new_from_dag<'dag>(id: NodeId, dag: &'dag mut DagRunner) -> Self::Retval<'dag>;
}
impl<Tk> TaskWire<()> for Tk
where
Tk: Task<()> + Sync + 'static,
{
type Retval<'dag> = TaskHandle<Tk::Output>;
fn new_from_dag(id: NodeId, _dag: &mut DagRunner) -> Self::Retval<'static> {
Self::Retval {
id,
_phantom: PhantomData,
}
}
}
macro_rules! impl_wire_tuple {
($($T:ident),+) => {
impl<Tk, $($T: Send + Sync + 'static),+> TaskWire<($($T,)+)> for Tk
where
Tk: Task<($($T,)+)> + Sync + 'static
{
type Retval<'dag> = TaskBuilder<'dag, ($($T,)+), Tk>;
fn new_from_dag<'dag>(id: NodeId, dag: &'dag mut DagRunner) -> Self::Retval<'dag> {
Self::Retval {
id,
dag,
_phantom: PhantomData,
}
}
}
};
}
impl_wire_tuple!(T1);
impl_wire_tuple!(T1, T2);
impl_wire_tuple!(T1, T2, T3);
impl_wire_tuple!(T1, T2, T3, T4);
impl_wire_tuple!(T1, T2, T3, T4, T5);
impl_wire_tuple!(T1, T2, T3, T4, T5, T6);
impl_wire_tuple!(T1, T2, T3, T4, T5, T6, T7);
impl_wire_tuple!(T1, T2, T3, T4, T5, T6, T7, T8);
#[cfg(test)]
mod tests;
#[cfg(test)]
mod coverage_tests;