use std::collections::{HashMap, HashSet};
use apalis_core::task::task_id::TaskId;
use petgraph::graph::NodeIndex;
use serde::{Deserialize, Serialize};
/// Progress state for walking a graph of petgraph [`NodeIndex`] nodes, with the
/// task id recorded for each visited node.
///
/// NOTE(review): the derived `Default` produces `is_initial: false` (and requires
/// `IdType: Default`). `DagFlowContext::new` produces `is_initial: true`. Confirm
/// which starting state callers of `Default::default()` actually expect.
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct DagFlowContext<IdType> {
/// Node visited before `current_node`, if any (`None` in a context built by `new`).
pub prev_node: Option<NodeIndex>,
/// Node the context currently points at (index 0 in a context built by `new`).
pub current_node: NodeIndex,
/// Nodes recorded as completed.
pub completed_nodes: HashSet<NodeIndex>,
/// Task id recorded per node; consulted by `get_dependency_task_ids`.
pub node_task_ids: HashMap<NodeIndex, TaskId<IdType>>,
/// Position counter, starting at 0 in `new`.
/// NOTE(review): what it indexes is not visible in this file — confirm against its users.
pub current_position: usize,
/// `true` for a context freshly built by `new`.
pub is_initial: bool,
/// Optional root task id, as supplied to `new`.
pub root_task_id: Option<TaskId<IdType>>,
}
/// Field-by-field clone of a [`DagFlowContext`].
///
/// Node indices are copied; the set, the map and the optional root task id are
/// cloned. Only `IdType: Clone` is required.
impl<IdType: Clone> Clone for DagFlowContext<IdType> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct turns into a
        // compile error here instead of a silently un-cloned field.
        let Self {
            prev_node,
            current_node,
            completed_nodes,
            node_task_ids,
            current_position,
            is_initial,
            root_task_id,
        } = self;

        Self {
            prev_node: *prev_node,
            current_node: *current_node,
            completed_nodes: completed_nodes.clone(),
            node_task_ids: node_task_ids.clone(),
            current_position: *current_position,
            is_initial: *is_initial,
            root_task_id: root_task_id.clone(),
        }
    }
}
impl<IdType: Clone> DagFlowContext<IdType> {
    /// Builds a fresh context.
    ///
    /// The context points at node index 0 and has no previous node, no
    /// completed nodes and no recorded task ids. Its position is 0 and
    /// `is_initial` is `true`. `root_task_id` is stored unchanged.
    pub fn new(root_task_id: Option<TaskId<IdType>>) -> Self {
        let start = NodeIndex::new(0);
        Self {
            root_task_id,
            is_initial: true,
            current_position: 0,
            node_task_ids: HashMap::new(),
            completed_nodes: HashSet::new(),
            current_node: start,
            prev_node: None,
        }
    }

    /// Collects the recorded task id for each node listed in `dependencies`.
    ///
    /// Nodes without an entry in `node_task_ids` are skipped, so the result can
    /// be smaller than `dependencies`. Repeated nodes collapse to one entry.
    /// Returned ids are clones; `self` is not modified.
    pub fn get_dependency_task_ids(
        &self,
        dependencies: &[NodeIndex],
    ) -> HashMap<NodeIndex, TaskId<IdType>> {
        let mut found = HashMap::new();
        for &dep in dependencies {
            if let Some(task_id) = self.node_task_ids.get(&dep) {
                found.insert(dep, task_id.clone());
            }
        }
        found
    }
}