#![allow(dead_code)]
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
/// Identifier of a task inside a task graph, wrapping a raw `u64`.
///
/// Formats with `Display` as `Task(<n>)`, e.g. `TaskId(42)` prints `Task(42)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(pub u64);
impl fmt::Display for TaskId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure the newtype so the format string names the raw number directly.
        let TaskId(raw) = *self;
        f.write_fmt(format_args!("Task({raw})"))
    }
}
/// Execution backend requested for a task.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; nothing in this file dispatches on the backend, it is only stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskBackend {
    /// Presumably: execute on a GPU.
    Gpu,
    /// Presumably: execute on the CPU.
    Cpu,
    /// Presumably: let the executor pick the backend.
    Auto,
}
/// Lifecycle state of a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskState {
    /// Initial state of a newly created task.
    Pending,
    /// Eligible to start (accepted by `start_task` alongside `Pending`).
    Ready,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished unsuccessfully.
    Failed,
    /// Abandoned before completion.
    Cancelled,
}
impl TaskState {
    /// Returns `true` for states a task never leaves: `Completed`, `Failed`, `Cancelled`.
    #[must_use]
    pub fn is_terminal(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Completed | Self::Failed | Self::Cancelled => true,
            Self::Pending | Self::Ready | Self::Running => false,
        }
    }
    /// Returns `true` only for `Completed`.
    #[must_use]
    pub fn is_success(self) -> bool {
        self == Self::Completed
    }
}
/// One unit of work together with its edges in the dependency graph.
///
/// Edges are kept in both directions: `dependencies` holds the tasks this one
/// waits on, `dependents` holds the tasks that wait on this one.
#[derive(Debug, Clone)]
pub struct TaskNode {
    /// Identifier of this task.
    pub id: TaskId,
    /// Human-readable label.
    pub label: String,
    /// Requested execution backend.
    pub backend: TaskBackend,
    /// Current lifecycle state; `new` starts every node as `Pending`.
    pub state: TaskState,
    /// Relative cost estimate (defaults to 1); summed by graph cost metrics.
    pub estimated_cost: u64,
    /// Tasks that must complete before this one may run.
    pub dependencies: HashSet<TaskId>,
    /// Tasks that list this one among their dependencies.
    pub dependents: HashSet<TaskId>,
}
impl TaskNode {
    /// Builds a `Pending` node with an estimated cost of 1 and no edges.
    #[must_use]
    pub fn new(id: TaskId, label: String, backend: TaskBackend) -> Self {
        let (dependencies, dependents) = (HashSet::new(), HashSet::new());
        Self {
            id,
            label,
            backend,
            state: TaskState::Pending,
            estimated_cost: 1,
            dependencies,
            dependents,
        }
    }
    /// Builder-style setter that replaces the estimated cost.
    #[must_use]
    pub fn with_cost(self, cost: u64) -> Self {
        Self {
            estimated_cost: cost,
            ..self
        }
    }
    /// Returns `true` when every dependency id appears in `completed`
    /// (trivially `true` for a node without dependencies).
    #[must_use]
    pub fn dependencies_satisfied(&self, completed: &HashSet<TaskId>) -> bool {
        self.dependencies.is_subset(completed)
    }
    /// Counts the dependency ids that are not yet in `completed`.
    #[must_use]
    pub fn pending_dependency_count(&self, completed: &HashSet<TaskId>) -> usize {
        self.dependencies.difference(completed).count()
    }
}
/// Errors returned by task-graph operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskGraphError {
    /// A task with this id already exists (not constructed anywhere in this file).
    DuplicateTask(TaskId),
    /// The referenced task id is not present in the graph.
    TaskNotFound(TaskId),
    /// Adding an edge would make the graph cyclic, or a cycle was found while sorting.
    CycleDetected,
    /// A state transition was requested from a state that does not allow it;
    /// the string describes the offending state.
    InvalidState(String),
}
impl fmt::Display for TaskGraphError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DuplicateTask(id) => write!(f, "duplicate task: {id}"),
            Self::TaskNotFound(id) => write!(f, "task not found: {id}"),
            Self::CycleDetected => write!(f, "cycle detected in task graph"),
            Self::InvalidState(msg) => write!(f, "invalid state: {msg}"),
        }
    }
}
// Implementing the standard error trait lets callers box this error
// (`Box<dyn std::error::Error>`) and propagate it with `?` alongside other errors.
impl std::error::Error for TaskGraphError {}
/// A directed acyclic graph of tasks connected by dependency edges.
#[derive(Debug)]
pub struct TaskGraph {
    /// Every node in the graph, keyed by its id.
    tasks: HashMap<TaskId, TaskNode>,
    /// Raw value handed out as the next `TaskId`; incremented on each insertion.
    next_id: u64,
}
impl TaskGraph {
    /// Creates an empty graph; the first task added receives `TaskId(0)`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tasks: HashMap::new(),
            next_id: 0,
        }
    }

    /// Allocates the next sequential id, builds the node for it and stores it.
    ///
    /// Single place for id allocation, shared by `add_task` and `add_task_with_cost`.
    fn insert_with_next_id(&mut self, build: impl FnOnce(TaskId) -> TaskNode) -> TaskId {
        let id = TaskId(self.next_id);
        self.next_id += 1;
        self.tasks.insert(id, build(id));
        id
    }

    /// Adds a new `Pending` task with the default cost set by [`TaskNode::new`]
    /// and returns its freshly allocated id.
    pub fn add_task(&mut self, label: &str, backend: TaskBackend) -> TaskId {
        self.insert_with_next_id(|id| TaskNode::new(id, label.to_string(), backend))
    }

    /// Adds a new `Pending` task with an explicit estimated `cost` and returns its id.
    pub fn add_task_with_cost(&mut self, label: &str, backend: TaskBackend, cost: u64) -> TaskId {
        self.insert_with_next_id(|id| TaskNode::new(id, label.to_string(), backend).with_cost(cost))
    }

    /// Records that `task` must wait for `dependency` to complete.
    ///
    /// The edge is stored on both nodes (`task.dependencies` and
    /// `dependency.dependents`). Re-adding an existing edge is a no-op.
    ///
    /// # Errors
    ///
    /// - [`TaskGraphError::TaskNotFound`] if `task` or `dependency` is unknown
    ///   (`task` is checked first).
    /// - [`TaskGraphError::CycleDetected`] if the edge would close a cycle,
    ///   including a task depending on itself.
    pub fn add_dependency(
        &mut self,
        task: TaskId,
        dependency: TaskId,
    ) -> Result<(), TaskGraphError> {
        if !self.tasks.contains_key(&task) {
            return Err(TaskGraphError::TaskNotFound(task));
        }
        if !self.tasks.contains_key(&dependency) {
            return Err(TaskGraphError::TaskNotFound(dependency));
        }
        if self.would_create_cycle(task, dependency) {
            return Err(TaskGraphError::CycleDetected);
        }
        // Both ids were verified above, so these lookups always succeed.
        if let Some(node) = self.tasks.get_mut(&task) {
            node.dependencies.insert(dependency);
        }
        if let Some(node) = self.tasks.get_mut(&dependency) {
            node.dependents.insert(task);
        }
        Ok(())
    }

    /// Returns `true` if making `task` depend on `dep` would introduce a cycle.
    ///
    /// That happens when `dep == task`, or when `dep` is reachable from `task`
    /// by following `dependents` edges (i.e. `dep` already transitively waits on `task`).
    fn would_create_cycle(&self, task: TaskId, dep: TaskId) -> bool {
        if task == dep {
            return true;
        }
        let mut visited = HashSet::new();
        let mut queue = VecDeque::from([task]);
        while let Some(current) = queue.pop_front() {
            if current == dep {
                return true;
            }
            if !visited.insert(current) {
                continue;
            }
            if let Some(node) = self.tasks.get(&current) {
                queue.extend(node.dependents.iter().copied());
            }
        }
        false
    }

    /// Number of tasks in the graph.
    #[must_use]
    pub fn task_count(&self) -> usize {
        self.tasks.len()
    }

    /// Looks up a task by id.
    #[must_use]
    pub fn get_task(&self, id: TaskId) -> Option<&TaskNode> {
        self.tasks.get(&id)
    }

    /// Looks up a task by id for mutation.
    ///
    /// NOTE(review): editing `dependencies`/`dependents` through this handle can
    /// desynchronise the two edge directions; prefer `add_dependency`.
    pub fn get_task_mut(&mut self, id: TaskId) -> Option<&mut TaskNode> {
        self.tasks.get_mut(&id)
    }

    /// Returns every `Pending` task whose dependencies are all `Completed`.
    ///
    /// Order is unspecified (follows `HashMap` iteration).
    #[must_use]
    pub fn ready_tasks(&self) -> Vec<TaskId> {
        let completed: HashSet<TaskId> = self
            .tasks
            .values()
            .filter(|t| t.state == TaskState::Completed)
            .map(|t| t.id)
            .collect();
        self.tasks
            .values()
            .filter(|t| t.state == TaskState::Pending && t.dependencies_satisfied(&completed))
            .map(|t| t.id)
            .collect()
    }

    /// Moves a `Pending` or `Ready` task to `Running`.
    ///
    /// # Errors
    ///
    /// [`TaskGraphError::TaskNotFound`] for an unknown id;
    /// [`TaskGraphError::InvalidState`] if the task is in any other state.
    pub fn start_task(&mut self, id: TaskId) -> Result<(), TaskGraphError> {
        let task = self
            .tasks
            .get_mut(&id)
            .ok_or(TaskGraphError::TaskNotFound(id))?;
        if task.state != TaskState::Pending && task.state != TaskState::Ready {
            return Err(TaskGraphError::InvalidState(format!(
                "cannot start task in state {:?}",
                task.state
            )));
        }
        task.state = TaskState::Running;
        Ok(())
    }

    /// Moves a `Running` task to `Completed`.
    ///
    /// # Errors
    ///
    /// [`TaskGraphError::TaskNotFound`] for an unknown id;
    /// [`TaskGraphError::InvalidState`] if the task is not `Running`.
    pub fn complete_task(&mut self, id: TaskId) -> Result<(), TaskGraphError> {
        let task = self
            .tasks
            .get_mut(&id)
            .ok_or(TaskGraphError::TaskNotFound(id))?;
        if task.state != TaskState::Running {
            return Err(TaskGraphError::InvalidState(format!(
                "cannot complete task in state {:?}",
                task.state
            )));
        }
        task.state = TaskState::Completed;
        Ok(())
    }

    /// Marks a non-terminal task (`Pending`, `Ready` or `Running`) as `Failed`.
    ///
    /// # Errors
    ///
    /// [`TaskGraphError::TaskNotFound`] for an unknown id;
    /// [`TaskGraphError::InvalidState`] if the task is already terminal. Rewriting
    /// a `Completed` task as `Failed` would retract readiness from dependents
    /// that may already have started.
    pub fn fail_task(&mut self, id: TaskId) -> Result<(), TaskGraphError> {
        let task = self
            .tasks
            .get_mut(&id)
            .ok_or(TaskGraphError::TaskNotFound(id))?;
        if task.state.is_terminal() {
            return Err(TaskGraphError::InvalidState(format!(
                "cannot fail task in state {:?}",
                task.state
            )));
        }
        task.state = TaskState::Failed;
        Ok(())
    }

    /// Returns `true` when every task is terminal (vacuously `true` when empty).
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.tasks.values().all(|t| t.state.is_terminal())
    }

    /// Returns `true` if any task is `Failed`.
    #[must_use]
    pub fn has_failures(&self) -> bool {
        self.tasks.values().any(|t| t.state == TaskState::Failed)
    }

    /// Sum of all tasks' estimated costs, saturating at `u64::MAX` instead of overflowing.
    #[must_use]
    pub fn total_cost(&self) -> u64 {
        self.tasks
            .values()
            .map(|t| t.estimated_cost)
            .fold(0, u64::saturating_add)
    }

    /// Orders tasks so every task appears after all of its dependencies
    /// (Kahn's algorithm). Order among independent tasks is unspecified.
    ///
    /// # Errors
    ///
    /// [`TaskGraphError::CycleDetected`] if not every task could be ordered.
    pub fn topological_sort(&self) -> Result<Vec<TaskId>, TaskGraphError> {
        // In-degree of a node = number of predecessors that list it in `dependents`.
        let mut in_degree: HashMap<TaskId, usize> = HashMap::with_capacity(self.tasks.len());
        for task in self.tasks.values() {
            in_degree.entry(task.id).or_insert(0);
            for &dependent in &task.dependents {
                *in_degree.entry(dependent).or_insert(0) += 1;
            }
        }
        let mut queue: VecDeque<TaskId> = in_degree
            .iter()
            .filter(|&(_, &deg)| deg == 0)
            .map(|(&id, _)| id)
            .collect();
        let mut result = Vec::with_capacity(self.tasks.len());
        while let Some(id) = queue.pop_front() {
            result.push(id);
            if let Some(node) = self.tasks.get(&id) {
                for &dependent in &node.dependents {
                    if let Some(deg) = in_degree.get_mut(&dependent) {
                        *deg -= 1;
                        if *deg == 0 {
                            queue.push_back(dependent);
                        }
                    }
                }
            }
        }
        // Nodes on a cycle never reach in-degree 0, so they are missing from `result`.
        if result.len() != self.tasks.len() {
            return Err(TaskGraphError::CycleDetected);
        }
        Ok(result)
    }

    /// Largest total `estimated_cost` along any dependency chain.
    ///
    /// Returns 0 for an empty graph or when a cycle prevents ordering; sums
    /// saturate at `u64::MAX`.
    #[must_use]
    pub fn critical_path_cost(&self) -> u64 {
        let sorted = match self.topological_sort() {
            Ok(sorted) => sorted,
            Err(_) => return 0,
        };
        // Longest finishing cost per task; dependencies are visited first thanks to the sort.
        let mut longest: HashMap<TaskId, u64> = HashMap::with_capacity(sorted.len());
        for task in sorted.iter().filter_map(|id| self.tasks.get(id)) {
            let max_dep = task
                .dependencies
                .iter()
                .filter_map(|d| longest.get(d))
                .copied()
                .max()
                .unwrap_or(0);
            longest.insert(task.id, max_dep.saturating_add(task.estimated_cost));
        }
        longest.values().copied().max().unwrap_or(0)
    }
}
impl Default for TaskGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_display() {
        let id = TaskId(42);
        assert_eq!(format!("{id}"), "Task(42)");
    }

    #[test]
    fn test_task_state_terminal() {
        assert!(!TaskState::Pending.is_terminal());
        assert!(!TaskState::Ready.is_terminal());
        assert!(!TaskState::Running.is_terminal());
        assert!(TaskState::Completed.is_terminal());
        assert!(TaskState::Failed.is_terminal());
        assert!(TaskState::Cancelled.is_terminal());
    }

    #[test]
    fn test_task_state_success() {
        assert!(TaskState::Completed.is_success());
        assert!(!TaskState::Failed.is_success());
        assert!(!TaskState::Pending.is_success());
    }

    #[test]
    fn test_add_task() {
        let mut g = TaskGraph::new();
        let id = g.add_task("resize", TaskBackend::Gpu);
        assert_eq!(g.task_count(), 1);
        let task = g.get_task(id).expect("task should be valid");
        assert_eq!(task.label, "resize");
        assert_eq!(task.backend, TaskBackend::Gpu);
        assert_eq!(task.state, TaskState::Pending);
    }

    #[test]
    fn test_add_task_with_cost() {
        let mut g = TaskGraph::new();
        let id = g.add_task_with_cost("heavy", TaskBackend::Cpu, 100);
        assert_eq!(
            g.get_task(id)
                .expect("get_task should succeed")
                .estimated_cost,
            100
        );
    }

    #[test]
    fn test_add_dependency() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        let node_b = g.get_task(b).expect("node_b should be valid");
        assert!(node_b.dependencies.contains(&a));
        let node_a = g.get_task(a).expect("node_a should be valid");
        assert!(node_a.dependents.contains(&b));
    }

    #[test]
    fn test_cycle_detection_self() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let result = g.add_dependency(a, a);
        assert_eq!(result, Err(TaskGraphError::CycleDetected));
    }

    #[test]
    fn test_cycle_detection_indirect() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        let c = g.add_task("c", TaskBackend::Auto);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        g.add_dependency(c, b)
            .expect("add_dependency should succeed");
        let result = g.add_dependency(a, c);
        assert_eq!(result, Err(TaskGraphError::CycleDetected));
    }

    #[test]
    fn test_ready_tasks_no_deps() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Gpu);
        let b = g.add_task("b", TaskBackend::Cpu);
        let ready = g.ready_tasks();
        assert!(ready.contains(&a));
        assert!(ready.contains(&b));
    }

    #[test]
    fn test_ready_tasks_with_deps() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        let ready = g.ready_tasks();
        assert!(ready.contains(&a));
        assert!(!ready.contains(&b));
        g.start_task(a).expect("start_task should succeed");
        g.complete_task(a).expect("complete_task should succeed");
        let ready2 = g.ready_tasks();
        assert!(ready2.contains(&b));
    }

    #[test]
    fn test_start_and_complete_task() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        g.start_task(a).expect("start_task should succeed");
        assert_eq!(
            g.get_task(a).expect("get_task should succeed").state,
            TaskState::Running
        );
        g.complete_task(a).expect("complete_task should succeed");
        assert_eq!(
            g.get_task(a).expect("get_task should succeed").state,
            TaskState::Completed
        );
    }

    #[test]
    fn test_fail_task() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        g.fail_task(a).expect("fail_task should succeed");
        assert_eq!(
            g.get_task(a).expect("get_task should succeed").state,
            TaskState::Failed
        );
        assert!(g.has_failures());
    }

    #[test]
    fn test_is_complete() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        assert!(!g.is_complete());
        g.start_task(a).expect("start_task should succeed");
        g.complete_task(a).expect("complete_task should succeed");
        assert!(g.is_complete());
    }

    #[test]
    fn test_total_cost() {
        let mut g = TaskGraph::new();
        g.add_task_with_cost("a", TaskBackend::Gpu, 10);
        g.add_task_with_cost("b", TaskBackend::Cpu, 20);
        assert_eq!(g.total_cost(), 30);
    }

    #[test]
    fn test_topological_sort() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        let c = g.add_task("c", TaskBackend::Auto);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        g.add_dependency(c, b)
            .expect("add_dependency should succeed");
        let sorted = g.topological_sort().expect("sorted should be valid");
        let pos_a = sorted
            .iter()
            .position(|&x| x == a)
            .expect("pos_a should be valid");
        let pos_b = sorted
            .iter()
            .position(|&x| x == b)
            .expect("pos_b should be valid");
        let pos_c = sorted
            .iter()
            .position(|&x| x == c)
            .expect("pos_c should be valid");
        assert!(pos_a < pos_b);
        assert!(pos_b < pos_c);
    }

    #[test]
    fn test_critical_path_cost() {
        let mut g = TaskGraph::new();
        let a = g.add_task_with_cost("a", TaskBackend::Auto, 5);
        let b = g.add_task_with_cost("b", TaskBackend::Auto, 10);
        let c = g.add_task_with_cost("c", TaskBackend::Auto, 3);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        g.add_dependency(c, a)
            .expect("add_dependency should succeed");
        assert_eq!(g.critical_path_cost(), 15);
    }

    #[test]
    fn test_critical_path_empty() {
        let g = TaskGraph::new();
        assert_eq!(g.critical_path_cost(), 0);
    }

    #[test]
    fn test_dependency_not_found() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let bad = TaskId(999);
        assert_eq!(
            g.add_dependency(a, bad),
            Err(TaskGraphError::TaskNotFound(bad))
        );
    }

    #[test]
    fn test_start_task_not_found() {
        let mut g = TaskGraph::new();
        assert!(g.start_task(TaskId(999)).is_err());
    }

    // --- Additional coverage: error paths and edge cases. ---

    #[test]
    fn test_error_display() {
        assert_eq!(
            TaskGraphError::DuplicateTask(TaskId(1)).to_string(),
            "duplicate task: Task(1)"
        );
        assert_eq!(
            TaskGraphError::TaskNotFound(TaskId(3)).to_string(),
            "task not found: Task(3)"
        );
        assert_eq!(
            TaskGraphError::CycleDetected.to_string(),
            "cycle detected in task graph"
        );
        assert_eq!(
            TaskGraphError::InvalidState("x".to_string()).to_string(),
            "invalid state: x"
        );
    }

    #[test]
    fn test_dependency_unknown_task_checked_first() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let bad = TaskId(999);
        assert_eq!(
            g.add_dependency(bad, a),
            Err(TaskGraphError::TaskNotFound(bad))
        );
    }

    #[test]
    fn test_add_dependency_twice_is_noop() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        g.add_dependency(b, a).expect("first add should succeed");
        g.add_dependency(b, a).expect("repeated add should succeed");
        assert_eq!(g.get_task(b).expect("b exists").dependencies.len(), 1);
        assert_eq!(g.get_task(a).expect("a exists").dependents.len(), 1);
    }

    #[test]
    fn test_start_running_task_rejected() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        g.start_task(a).expect("start_task should succeed");
        assert!(matches!(
            g.start_task(a),
            Err(TaskGraphError::InvalidState(_))
        ));
    }

    #[test]
    fn test_complete_pending_task_rejected() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        assert!(matches!(
            g.complete_task(a),
            Err(TaskGraphError::InvalidState(_))
        ));
        assert_eq!(
            g.get_task(a).expect("get_task should succeed").state,
            TaskState::Pending
        );
    }

    #[test]
    fn test_failed_dependency_blocks_dependent() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        g.fail_task(a).expect("fail_task should succeed");
        assert!(g.ready_tasks().is_empty());
        assert!(!g.is_complete());
    }

    #[test]
    fn test_pending_dependency_count() {
        let mut g = TaskGraph::new();
        let a = g.add_task("a", TaskBackend::Auto);
        let b = g.add_task("b", TaskBackend::Auto);
        let c = g.add_task("c", TaskBackend::Auto);
        g.add_dependency(c, a)
            .expect("add_dependency should succeed");
        g.add_dependency(c, b)
            .expect("add_dependency should succeed");
        let node_c = g.get_task(c).expect("node_c should be valid");
        let mut done = HashSet::new();
        assert_eq!(node_c.pending_dependency_count(&done), 2);
        assert!(!node_c.dependencies_satisfied(&done));
        done.insert(a);
        assert_eq!(node_c.pending_dependency_count(&done), 1);
        done.insert(b);
        assert_eq!(node_c.pending_dependency_count(&done), 0);
        assert!(node_c.dependencies_satisfied(&done));
    }

    #[test]
    fn test_topological_sort_empty() {
        assert_eq!(TaskGraph::new().topological_sort(), Ok(Vec::new()));
    }

    #[test]
    fn test_critical_path_chain() {
        let mut g = TaskGraph::new();
        let a = g.add_task_with_cost("a", TaskBackend::Auto, 5);
        let b = g.add_task_with_cost("b", TaskBackend::Auto, 10);
        let c = g.add_task_with_cost("c", TaskBackend::Auto, 3);
        g.add_dependency(b, a)
            .expect("add_dependency should succeed");
        g.add_dependency(c, b)
            .expect("add_dependency should succeed");
        assert_eq!(g.critical_path_cost(), 18);
    }

    #[test]
    fn test_empty_graph_is_complete() {
        let g = TaskGraph::default();
        assert!(g.is_complete());
        assert!(!g.has_failures());
        assert_eq!(g.total_cost(), 0);
    }
}
}