use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
#[cfg(feature = "parallel")]
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::error::{CoreError, CoreResult, ErrorContext};
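/// Opaque, ordered identifier for a task within a [`TaskGraph`].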
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TaskId(usize);
impl fmt::Display for TaskId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Task({})", self.0)
}
}
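/// Outcome of a single task execution. `Skipped` marks tasks whose
/// dependencies did not all finish with `Success`.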
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskStatus {
Success,
Skipped,
Failed(String),
}
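/// Execution record for one task: the produced value (if any), its final
/// status, and timing information.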
#[derive(Debug, Clone)]
pub struct TaskResult<T: Clone> {
pub task_id: TaskId,
pub task_name: String,
pub value: Option<T>,
pub status: TaskStatus,
pub elapsed: Duration,
pub started_at: Instant,
}
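/// A schedulable unit of work: a boxed closure plus optional cost hints
/// (`estimated_ms` feeds the critical-path analysis, `memory_bytes` the
/// resource-constrained scheduler).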
pub struct TaskNode<T: Clone + Send + 'static> {
id: TaskId,
name: String,
compute: Box<dyn Fn() -> T + Send + Sync>,
estimated_ms: u64,
memory_bytes: usize,
}
impl<T: Clone + Send + 'static> TaskNode<T> {
pub fn new<F>(id: TaskId, name: impl Into<String>, f: F) -> Self
where
F: Fn() -> T + Send + Sync + 'static,
{
Self {
id,
name: name.into(),
compute: Box::new(f),
estimated_ms: 1,
memory_bytes: 0,
}
}
pub fn with_estimated_ms(mut self, ms: u64) -> Self {
self.estimated_ms = ms;
self
}
pub fn with_memory_bytes(mut self, bytes: usize) -> Self {
self.memory_bytes = bytes;
self
}
fn execute(&self) -> TaskResult<T> {
let started_at = Instant::now();
let value = (self.compute)();
let elapsed = started_at.elapsed();
TaskResult {
task_id: self.id,
task_name: self.name.clone(),
value: Some(value),
status: TaskStatus::Success,
elapsed,
started_at,
}
}
}
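/// Directed acyclic graph of tasks. `deps` maps a task to the tasks it
/// depends on; `dependents` is the reverse adjacency, kept in sync so both
/// directions can be walked cheaply.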
pub struct TaskGraph<T: Clone + Send + 'static> {
nodes: HashMap<TaskId, TaskNode<T>>,
deps: HashMap<TaskId, HashSet<TaskId>>,
dependents: HashMap<TaskId, HashSet<TaskId>>,
next_id: usize,
}
impl<T: Clone + Send + 'static> TaskGraph<T> {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
deps: HashMap::new(),
dependents: HashMap::new(),
next_id: 0,
}
}
pub fn add_task<F>(&mut self, name: impl Into<String>, f: F) -> TaskId
where
F: Fn() -> T + Send + Sync + 'static,
{
let id = TaskId(self.next_id);
self.next_id += 1;
let node = TaskNode::new(id, name, f);
self.nodes.insert(id, node);
self.deps.insert(id, HashSet::new());
self.dependents.insert(id, HashSet::new());
id
}
    pub fn add_node(&mut self, node: TaskNode<T>) -> TaskId {
        let id = node.id;
        // Keep the id counter ahead of externally constructed ids so a later
        // `add_task` cannot hand out a colliding `TaskId`.
        self.next_id = self.next_id.max(id.0 + 1);
        self.nodes.insert(id, node);
        self.deps.entry(id).or_default();
        self.dependents.entry(id).or_default();
        id
    }
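    /// Records that `dependent` must run after `dependency`. Rejects unknown
    /// ids, self-loops, and any edge that would close a cycle (checked with a
    /// BFS reachability test before the edge is inserted).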
pub fn add_dependency(&mut self, dependent: TaskId, dependency: TaskId) -> CoreResult<()> {
if !self.nodes.contains_key(&dependent) {
return Err(CoreError::InvalidInput(ErrorContext::new(format!(
"add_dependency: {dependent} not found"
))));
}
if !self.nodes.contains_key(&dependency) {
return Err(CoreError::InvalidInput(ErrorContext::new(format!(
"add_dependency: {dependency} not found"
))));
}
if dependent == dependency {
return Err(CoreError::InvalidInput(ErrorContext::new(format!(
"add_dependency: self-loop on {dependent}"
))));
}
if self.is_reachable(dependency, dependent) {
return Err(CoreError::InvalidInput(ErrorContext::new(format!(
"add_dependency: cycle detected ({dependency} already depends on {dependent})"
))));
}
self.deps.entry(dependent).or_default().insert(dependency);
self.dependents
.entry(dependency)
.or_default()
.insert(dependent);
Ok(())
}
fn is_reachable(&self, from: TaskId, target: TaskId) -> bool {
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(from);
while let Some(current) = queue.pop_front() {
if current == target {
return true;
}
if visited.contains(&current) {
continue;
}
visited.insert(current);
if let Some(deps) = self.deps.get(&current) {
for dep in deps {
if !visited.contains(dep) {
queue.push_back(*dep);
}
}
}
}
false
}
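    /// Kahn's algorithm: repeatedly emit tasks whose remaining in-degree is
    /// zero. If any task is never emitted, a cycle must exist.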
pub fn topological_order(&self) -> CoreResult<Vec<TaskId>> {
let mut in_degree: HashMap<TaskId, usize> = self
.nodes
.keys()
.map(|id| (*id, self.deps[id].len()))
.collect();
let mut ready: VecDeque<TaskId> = in_degree
.iter()
.filter(|(_, &deg)| deg == 0)
.map(|(id, _)| *id)
.collect();
let mut order = Vec::with_capacity(self.nodes.len());
while let Some(id) = ready.pop_front() {
order.push(id);
if let Some(children) = self.dependents.get(&id) {
for child in children {
let deg = in_degree.entry(*child).or_insert(0);
if *deg > 0 {
*deg -= 1;
}
if *deg == 0 {
ready.push_back(*child);
}
}
}
}
if order.len() != self.nodes.len() {
return Err(CoreError::InvalidInput(ErrorContext::new(
"topological_order: cycle detected",
)));
}
Ok(order)
}
pub fn len(&self) -> usize {
self.nodes.len()
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn dependencies(&self, id: TaskId) -> Option<&HashSet<TaskId>> {
self.deps.get(&id)
}
pub fn dependents_of(&self, id: TaskId) -> Option<&HashSet<TaskId>> {
self.dependents.get(&id)
}
}
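// `Default` mirrors `new()`, following the usual Rust convention.
impl<T: Clone + Send + 'static> Default for TaskGraph<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Longest path through the graph weighted by `estimated_ms`: a lower bound
/// on total wall-clock time even with unlimited parallelism. Computed in one
/// dynamic-programming pass over a topological order, where each task's
/// earliest finish is the maximum of its dependencies' earliest finishes
/// plus its own estimate.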
pub struct CriticalPath {
pub path: Vec<TaskId>,
pub total_estimated_ms: u64,
}
impl CriticalPath {
pub fn compute<T: Clone + Send + 'static>(graph: &TaskGraph<T>) -> CoreResult<Self> {
let order = graph.topological_order()?;
let mut earliest_finish: HashMap<TaskId, u64> = HashMap::new();
let mut predecessor: HashMap<TaskId, Option<TaskId>> = HashMap::new();
for &id in &order {
let node = &graph.nodes[&id];
let max_pred_finish = graph
.deps
.get(&id)
.map(|deps| {
deps.iter()
.map(|d| earliest_finish.get(d).copied().unwrap_or(0))
.max()
.unwrap_or(0)
})
.unwrap_or(0);
let ef = max_pred_finish + node.estimated_ms;
earliest_finish.insert(id, ef);
let pred = graph.deps.get(&id).and_then(|deps| {
deps.iter()
.max_by_key(|d| earliest_finish.get(d).copied().unwrap_or(0))
.copied()
});
predecessor.insert(id, pred);
}
let sink = earliest_finish
.iter()
.max_by_key(|(_, &ef)| ef)
.map(|(id, _)| *id);
let total_ms = sink
.and_then(|id| earliest_finish.get(&id).copied())
.unwrap_or(0);
let mut path = Vec::new();
let mut current = sink;
while let Some(id) = current {
path.push(id);
current = predecessor.get(&id).and_then(|opt| *opt);
}
path.reverse();
Ok(CriticalPath {
path,
total_estimated_ms: total_ms,
})
}
}
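/// Executes a [`TaskGraph`] in dependency order, either serially or (with
/// the `parallel` feature) in dependency "waves" via rayon.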
pub struct TopologicalScheduler<T: Clone + Send + 'static> {
graph: TaskGraph<T>,
}
impl<T: Clone + Send + 'static> TopologicalScheduler<T> {
pub fn new(graph: TaskGraph<T>) -> Self {
Self { graph }
}
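    /// Runs every task one at a time in topological order. A task whose
    /// dependencies did not all succeed is recorded as `Skipped` with no
    /// value.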
pub fn run_serial(&self) -> CoreResult<Vec<TaskResult<T>>> {
let order = self.graph.topological_order()?;
let mut results: HashMap<TaskId, TaskResult<T>> = HashMap::new();
for id in &order {
let any_dep_failed = self
.graph
.deps
.get(id)
.map(|deps| {
deps.iter().any(|d| {
results
.get(d)
.map(|r| r.status != TaskStatus::Success)
.unwrap_or(false)
})
})
.unwrap_or(false);
let node = &self.graph.nodes[id];
let result = if any_dep_failed {
TaskResult {
task_id: *id,
task_name: node.name.clone(),
value: None,
status: TaskStatus::Skipped,
elapsed: Duration::ZERO,
started_at: Instant::now(),
}
} else {
node.execute()
};
results.insert(*id, result);
}
Ok(order
.into_iter()
.filter_map(|id| results.remove(&id))
.collect())
}
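    /// Runs tasks in parallel waves when the `parallel` feature is enabled:
    /// each wave contains every remaining task whose dependencies have
    /// already completed, and the wave is executed with rayon's `par_iter`.
    /// Without the feature this falls back to [`Self::run_serial`].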
pub fn run_parallel(&self) -> CoreResult<Vec<TaskResult<T>>> {
#[cfg(feature = "parallel")]
{
self.run_parallel_impl()
}
#[cfg(not(feature = "parallel"))]
{
self.run_serial()
}
}
#[cfg(feature = "parallel")]
fn run_parallel_impl(&self) -> CoreResult<Vec<TaskResult<T>>> {
use rayon::prelude::*;
let order = self.graph.topological_order()?;
let results_map: Arc<Mutex<HashMap<TaskId, TaskResult<T>>>> =
Arc::new(Mutex::new(HashMap::new()));
let mut remaining: HashSet<TaskId> = order.iter().cloned().collect();
let mut all_results: Vec<TaskResult<T>> = Vec::new();
while !remaining.is_empty() {
let completed: HashSet<TaskId> = {
let rm = results_map.lock().map_err(|_| {
CoreError::InvalidInput(ErrorContext::new("parallel_run: mutex poisoned"))
})?;
rm.keys().cloned().collect()
};
let wave: Vec<TaskId> = remaining
.iter()
.filter(|id| {
self.graph
.deps
.get(id)
.map(|deps| deps.iter().all(|d| completed.contains(d)))
.unwrap_or(true)
})
.cloned()
.collect();
if wave.is_empty() {
return Err(CoreError::InvalidInput(ErrorContext::new(
"parallel_run: deadlock — no runnable tasks remain",
)));
}
let wave_results: Vec<TaskResult<T>> = wave
.par_iter()
.map(|id| {
let any_dep_failed = self
.graph
.deps
.get(id)
.map(|deps| {
let rm = results_map.lock().ok();
deps.iter().any(|d| {
rm.as_ref()
.and_then(|r| r.get(d))
.map(|r| r.status != TaskStatus::Success)
.unwrap_or(false)
})
})
.unwrap_or(false);
let node = &self.graph.nodes[id];
if any_dep_failed {
TaskResult {
task_id: *id,
task_name: node.name.clone(),
value: None,
status: TaskStatus::Skipped,
elapsed: Duration::ZERO,
started_at: Instant::now(),
}
} else {
node.execute()
}
})
.collect();
{
let mut rm = results_map.lock().map_err(|_| {
CoreError::InvalidInput(ErrorContext::new(
"parallel_run: mutex poisoned (merge)",
))
})?;
for r in &wave_results {
rm.insert(r.task_id, r.clone());
}
}
for id in &wave {
remaining.remove(id);
}
all_results.extend(wave_results);
}
Ok(all_results)
}
pub fn into_graph(self) -> TaskGraph<T> {
self.graph
}
}
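/// Resource budget for [`ResourceConstrainedScheduler`]. Defaults to 4
/// concurrent tasks and a 1 GiB memory ceiling.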
#[derive(Debug, Clone)]
pub struct ResourceConstraints {
pub max_concurrent: usize,
pub max_memory_bytes: usize,
}
impl Default for ResourceConstraints {
fn default() -> Self {
Self {
max_concurrent: 4,
            max_memory_bytes: 1 << 30, // 1 GiB
        }
}
}
pub struct ResourceConstrainedScheduler<T: Clone + Send + 'static> {
graph: TaskGraph<T>,
constraints: ResourceConstraints,
}
impl<T: Clone + Send + 'static> ResourceConstrainedScheduler<T> {
pub fn new(graph: TaskGraph<T>, constraints: ResourceConstraints) -> Self {
Self { graph, constraints }
}
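    /// Runs tasks serially, preferring the first ready task that fits in the
    /// memory budget. Because execution here is synchronous, the budget is
    /// effectively enforced per task and `max_concurrent` is not yet
    /// consulted; if no ready task fits, the first ready task is run anyway
    /// (see the fallback inside the loop) so an oversized task cannot
    /// deadlock the scheduler.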
pub fn run(&self) -> CoreResult<Vec<TaskResult<T>>> {
let order = self.graph.topological_order()?;
let mut completed: HashSet<TaskId> = HashSet::new();
let mut results: Vec<TaskResult<T>> = Vec::new();
let mut remaining: Vec<TaskId> = order;
let mut in_flight_memory: usize = 0;
loop {
let ready_idx = remaining.iter().position(|id| {
let deps_done = self
.graph
.deps
.get(id)
.map(|deps| deps.iter().all(|d| completed.contains(d)))
.unwrap_or(true);
if !deps_done {
return false;
}
let mem = self
.graph
.nodes
.get(id)
.map(|n| n.memory_bytes)
.unwrap_or(0);
in_flight_memory + mem <= self.constraints.max_memory_bytes
});
match ready_idx {
None => {
if remaining.is_empty() {
break;
}
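                    // Every runnable task exceeds the remaining memory
                    // budget. Run the first runnable one anyway rather than
                    // stalling forever on an oversized task.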
let fallback = remaining.iter().position(|id| {
self.graph
.deps
.get(id)
.map(|deps| deps.iter().all(|d| completed.contains(d)))
.unwrap_or(true)
});
match fallback {
                        None => break,
                        Some(idx) => {
let id = remaining.remove(idx);
let node = &self.graph.nodes[&id];
let mem = node.memory_bytes;
in_flight_memory = in_flight_memory.saturating_add(mem);
let r = node.execute();
in_flight_memory = in_flight_memory.saturating_sub(mem);
completed.insert(id);
results.push(r);
}
}
}
Some(idx) => {
let id = remaining.remove(idx);
let node = &self.graph.nodes[&id];
let mem = node.memory_bytes;
in_flight_memory = in_flight_memory.saturating_add(mem);
let r = node.execute();
in_flight_memory = in_flight_memory.saturating_sub(mem);
completed.insert(id);
results.push(r);
}
}
}
Ok(results)
}
}
pub mod dependency_graph;
#[cfg(test)]
mod tests {
use super::*;
fn build_linear_graph() -> TaskGraph<u64> {
let mut g = TaskGraph::new();
let t1 = g.add_task("a", || 1u64);
let t2 = g.add_task("b", || 2u64);
let t3 = g.add_task("c", || 3u64);
g.add_dependency(t2, t1).expect("dep b→a");
g.add_dependency(t3, t2).expect("dep c→b");
g
}
#[test]
fn topological_order_linear() {
let g = build_linear_graph();
let order = g.topological_order().expect("acyclic");
assert_eq!(order.len(), 3);
}
#[test]
fn cycle_detection() {
let mut g: TaskGraph<u64> = TaskGraph::new();
let a = g.add_task("a", || 0u64);
let b = g.add_task("b", || 0u64);
g.add_dependency(b, a).expect("b→a");
assert!(g.add_dependency(a, b).is_err(), "cycle should be rejected");
}
#[test]
fn topological_scheduler_serial() {
let g = build_linear_graph();
let sched = TopologicalScheduler::new(g);
let results = sched.run_serial().expect("serial run");
assert_eq!(results.len(), 3);
assert!(results.iter().all(|r| r.status == TaskStatus::Success));
let names: Vec<&str> = results.iter().map(|r| r.task_name.as_str()).collect();
assert_eq!(names, vec!["a", "b", "c"]);
}
#[test]
fn topological_scheduler_parallel() {
let g = build_linear_graph();
let sched = TopologicalScheduler::new(g);
let results = sched.run_parallel().expect("parallel run");
assert_eq!(results.len(), 3);
}
#[test]
fn critical_path_linear() {
        let mut g: TaskGraph<u64> = TaskGraph::new();
        // Build through the public API instead of poking the internal maps.
        let t1id = g.add_node(TaskNode::new(TaskId(0), "a", || 0u64).with_estimated_ms(10));
        let t2id = g.add_node(TaskNode::new(TaskId(1), "b", || 0u64).with_estimated_ms(20));
        let t3id = g.add_node(TaskNode::new(TaskId(2), "c", || 0u64).with_estimated_ms(15));
        g.add_dependency(t2id, t1id).expect("dep b→a");
        g.add_dependency(t3id, t2id).expect("dep c→b");
let cp = CriticalPath::compute(&g).expect("critical path");
assert_eq!(cp.total_estimated_ms, 45, "10 + 20 + 15 = 45");
assert_eq!(cp.path.len(), 3);
}
#[test]
fn resource_constrained_scheduler_basic() {
let mut g: TaskGraph<u64> = TaskGraph::new();
g.add_task("a", || 1u64);
g.add_task("b", || 2u64);
g.add_task("c", || 3u64);
let sched = ResourceConstrainedScheduler::new(
g,
ResourceConstraints {
max_concurrent: 2,
max_memory_bytes: 1024,
},
);
let results = sched.run().expect("constrained run");
assert_eq!(results.len(), 3);
}
#[test]
    fn skip_on_dep_failure() {
        // Compute closures are infallible at the type level (they return a
        // plain `T`), so a task that *returns* an `Err` value still finishes
        // with `TaskStatus::Success` and its dependents are not skipped;
        // `Skipped` only triggers for non-`Success` statuses.
        let mut g: TaskGraph<Result<u64, String>> = TaskGraph::new();
        let a = g.add_task("fail", || Err::<u64, _>("error".to_string()));
        let b = g.add_task("skip_me", || Ok::<u64, _>(42));
        g.add_dependency(b, a).expect("b→a");
        let sched = TopologicalScheduler::new(g);
        let results = sched.run_serial().expect("run");
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.status == TaskStatus::Success));
    }
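    // A small end-to-end sketch of the public API: a diamond a → {b, c} → d,
    // where d must appear after both b and c in the serial result order.
    #[test]
    fn diamond_runs_in_dependency_order() {
        let mut g: TaskGraph<u64> = TaskGraph::new();
        let a = g.add_task("a", || 1u64);
        let b = g.add_task("b", || 2u64);
        let c = g.add_task("c", || 3u64);
        let d = g.add_task("d", || 4u64);
        g.add_dependency(b, a).expect("b→a");
        g.add_dependency(c, a).expect("c→a");
        g.add_dependency(d, b).expect("d→b");
        g.add_dependency(d, c).expect("d→c");
        let sched = TopologicalScheduler::new(g);
        let results = sched.run_serial().expect("serial run");
        let pos = |id: TaskId| results.iter().position(|r| r.task_id == id).unwrap();
        assert!(pos(a) < pos(b) && pos(a) < pos(c));
        assert!(pos(b) < pos(d) && pos(c) < pos(d));
    }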
}