use crate::blueprint::BlueprintError;
use crate::{Task, TaskId};
use dashmap::DashMap;
use std::collections::{HashMap, HashSet};
/// One stage of a [`Blueprint`]: a group of tasks that became ready together.
///
/// `Blueprint::from_tasks` places a task in a step only after every one of
/// its dependencies has been placed in an earlier step.
#[derive(Debug)]
pub struct Step {
/// Ids of the tasks belonging to this step; their relative order is unspecified.
pub tasks: Vec<TaskId>,
}
/// A dependency-ordered plan: tasks grouped into consecutive [`Step`]s.
///
/// Steps are stored in order; every task appears in exactly one step, and all
/// of a task's dependencies appear in strictly earlier steps.
#[derive(Debug)]
pub struct Blueprint {
    /// The steps of the plan, earliest first.
    pub steps: Vec<Step>,
}
impl Blueprint {
    /// Builds a blueprint by layering `tasks` into dependency-ordered steps.
    ///
    /// The first step holds every task without dependencies; each following
    /// step holds the tasks whose dependencies were all placed in earlier
    /// steps (Kahn-style topological leveling). The order of ids inside a
    /// step is unspecified.
    ///
    /// The map is read exactly once, up front, into a local snapshot of
    /// `(task id, dependency ids)`. All validation and scheduling then works
    /// on that snapshot, so no lookups are made on `tasks` while an iteration
    /// guard is held and every phase sees the same set of tasks.
    ///
    /// # Errors
    ///
    /// * [`BlueprintError::MissingDependency`] if a task lists a dependency
    ///   id that is not a key of `tasks`; carries the task id and the
    ///   missing dependency id.
    /// * [`BlueprintError::CircularDependency`] if some tasks can never become
    ///   ready (they are part of, or depend on, a cycle); carries the ids of
    ///   all tasks that could not be scheduled.
    pub fn from_tasks<T, E>(tasks: &DashMap<TaskId, Task<T, E>>) -> Result<Self, BlueprintError> {
        // Single pass over the concurrent map: copy out ids and dependencies.
        let graph: Vec<(TaskId, Vec<TaskId>)> = tasks
            .iter()
            .map(|entry| {
                let deps: Vec<TaskId> = entry.value().dependencies().into_iter().collect();
                (*entry.key(), deps)
            })
            .collect();

        // Every dependency must refer to a task that is part of the snapshot.
        let known: HashSet<TaskId> = graph.iter().map(|(id, _)| *id).collect();
        for (task_id, deps) in &graph {
            if let Some(missing) = deps.iter().find(|dep| !known.contains(dep)) {
                return Err(BlueprintError::MissingDependency(*task_id, *missing));
            }
        }

        // `in_degree[t]` = number of dependency edges of `t` not yet satisfied;
        // `dependents[d]` = tasks waiting on `d` (one entry per edge, so a
        // dependency listed twice is counted and released twice).
        let mut in_degree: HashMap<TaskId, usize> = HashMap::with_capacity(graph.len());
        let mut dependents: HashMap<TaskId, Vec<TaskId>> = HashMap::new();
        for (task_id, deps) in &graph {
            in_degree.insert(*task_id, deps.len());
            for dep_id in deps {
                dependents.entry(*dep_id).or_default().push(*task_id);
            }
        }

        // Level-by-level traversal: only the tasks released by the current
        // step are examined for the next one, instead of rescanning all tasks.
        let mut frontier: Vec<TaskId> = graph
            .iter()
            .filter(|(_, deps)| deps.is_empty())
            .map(|(id, _)| *id)
            .collect();
        let mut steps = Vec::new();
        let mut scheduled = 0usize;
        while !frontier.is_empty() {
            let mut next = Vec::new();
            for task_id in &frontier {
                let Some(waiting) = dependents.get(task_id) else {
                    continue;
                };
                for dependent_id in waiting {
                    // Every dependent is a key of `in_degree` by construction.
                    if let Some(degree) = in_degree.get_mut(dependent_id) {
                        *degree -= 1;
                        if *degree == 0 {
                            next.push(*dependent_id);
                        }
                    }
                }
            }
            scheduled += frontier.len();
            // Move the finished frontier into its step without cloning it.
            steps.push(Step {
                tasks: std::mem::replace(&mut frontier, next),
            });
        }

        // Tasks whose in-degree never reached zero are stuck behind a cycle.
        if scheduled != graph.len() {
            let remaining: Vec<TaskId> = graph
                .iter()
                .map(|(id, _)| *id)
                .filter(|id| in_degree.get(id).copied().unwrap_or(0) > 0)
                .collect();
            return Err(BlueprintError::CircularDependency(remaining));
        }

        Ok(Blueprint { steps })
    }

    /// Returns the number of steps in this blueprint.
    pub fn step_count(&self) -> usize {
        self.steps.len()
    }

    /// Returns the task ids of the step at index `step` (zero-based), or
    /// `None` if `step` is out of range.
    pub fn tasks_at_step(&self, step: usize) -> Option<&[TaskId]> {
        self.steps.get(step).map(|s| s.tasks.as_slice())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Task;
    use std::future;

    /// Creates a task without dependencies whose future is already `Ok(())`.
    fn create_dummy_task() -> Task<'static, (), ()> {
        let future = future::ready(Ok(()));
        Task::new_independent(future)
    }

    /// Independent tasks must all land in one single step.
    #[test]
    fn test_simple_blueprint() {
        let tasks = DashMap::new();
        let task1 = create_dummy_task();
        let task2 = create_dummy_task();
        let id1 = *task1.id();
        let id2 = *task2.id();
        tasks.insert(id1, task1);
        tasks.insert(id2, task2);

        let blueprint = Blueprint::from_tasks(&tasks).unwrap();

        assert_eq!(blueprint.step_count(), 1);
        let step = blueprint.tasks_at_step(0).unwrap();
        assert_eq!(step.len(), 2);
        // Order inside a step is unspecified, so only check membership.
        assert!(step.contains(&id1));
        assert!(step.contains(&id2));
    }

    /// A dependent task must be scheduled in the step after its dependency.
    #[test]
    fn test_sequential_blueprint() {
        let tasks = DashMap::new();
        let task1 = create_dummy_task();
        let id1 = *task1.id();
        let task2 = Task::new(future::ready(Ok(())), vec![id1]);
        let id2 = *task2.id();
        tasks.insert(id1, task1);
        tasks.insert(id2, task2);

        let blueprint = Blueprint::from_tasks(&tasks).unwrap();

        assert_eq!(blueprint.step_count(), 2);
        // Check the actual ids, not just the sizes: the dependency comes first.
        assert!(blueprint.tasks_at_step(0) == Some(&[id1][..]));
        assert!(blueprint.tasks_at_step(1) == Some(&[id2][..]));
        assert!(blueprint.tasks_at_step(2).is_none());
    }

    /// A dependency id that is not in the map is reported with both ids.
    #[test]
    fn test_missing_dependency_is_reported() {
        let tasks = DashMap::new();
        // Created only to obtain an id; deliberately never inserted.
        let absent = create_dummy_task();
        let absent_id = *absent.id();
        let dependent: Task<'static, (), ()> =
            Task::new(future::ready(Ok(())), vec![absent_id]);
        let dependent_id = *dependent.id();
        tasks.insert(dependent_id, dependent);

        let result = Blueprint::from_tasks(&tasks);

        assert!(matches!(
            result,
            Err(BlueprintError::MissingDependency(task, dep))
                if task == dependent_id && dep == absent_id
        ));
    }

    /// An empty task map yields a blueprint with no steps.
    #[test]
    fn test_empty_blueprint() {
        let tasks: DashMap<TaskId, Task<'static, (), ()>> = DashMap::new();

        let blueprint = Blueprint::from_tasks(&tasks).unwrap();

        assert_eq!(blueprint.step_count(), 0);
        assert!(blueprint.tasks_at_step(0).is_none());
    }
}