use std::collections::{HashMap, HashSet};
use super::graph::{TaskGraph, TaskId};
/// Running success/failure tally for a single DAG region.
#[derive(Debug, Clone)]
pub struct RegionHealth {
    /// Number of outcomes recorded so far.
    pub total_tasks: usize,
    /// Number of recorded outcomes that were failures.
    pub failed_tasks: usize,
    /// `failed_tasks / total_tasks`, refreshed on every recorded outcome.
    /// Stays at `0.0` until the first outcome is recorded.
    pub failure_rate: f32,
}

impl RegionHealth {
    /// Creates an empty tally: zero outcomes and a failure rate of `0.0`.
    fn new() -> Self {
        Self {
            failure_rate: 0.0,
            failed_tasks: 0,
            total_tasks: 0,
        }
    }

    /// Adds one outcome to the tally and recomputes `failure_rate`.
    ///
    /// `failed` is `true` when the outcome being recorded was a failure.
    fn record(&mut self, failed: bool) {
        self.total_tasks += 1;
        // `usize::from(bool)` is 1 for `true` and 0 for `false`.
        self.failed_tasks += usize::from(failed);
        // `usize -> f32` can lose precision for very large counts; the lint is
        // silenced deliberately for this ratio.
        #[allow(clippy::cast_precision_loss)]
        let rate = self.failed_tasks as f32 / self.total_tasks as f32;
        self.failure_rate = rate;
    }
}
/// Tuning parameters for [`CascadeDetector`].
#[derive(Debug, Clone)]
pub struct CascadeConfig {
/// Failure-rate cutoff: a region is treated as cascading once its recorded
/// `failure_rate` is strictly greater than this value.
pub failure_threshold: f32,
}
/// Tracks per-region task outcomes and flags regions whose failure rate
/// exceeds the configured threshold.
#[derive(Debug)]
pub struct CascadeDetector {
// Threshold settings supplied at construction time.
config: CascadeConfig,
// Health tally per region, keyed by the region's primary root task id.
region_health: HashMap<TaskId, RegionHealth>,
}
impl CascadeDetector {
    /// Builds a detector with the given configuration and no recorded outcomes.
    #[must_use]
    pub fn new(config: CascadeConfig) -> Self {
        let region_health = HashMap::new();
        Self {
            config,
            region_health,
        }
    }

    /// Records the outcome of `task_id` against the region it belongs to.
    ///
    /// The region is identified by the task's primary root in `graph`; a fresh
    /// tally is created the first time a region is seen.
    pub fn record_outcome(&mut self, task_id: TaskId, succeeded: bool, graph: &TaskGraph) {
        let failed = !succeeded;
        let region = primary_root(task_id, graph);
        let health = self
            .region_health
            .entry(region)
            .or_insert_with(RegionHealth::new);
        health.record(failed);
    }

    /// Reports whether the region containing `task_id` is currently cascading.
    ///
    /// Returns `true` only when the region has recorded outcomes and its failure
    /// rate is strictly above `failure_threshold`; a region with no recorded
    /// outcomes is never cascading.
    #[must_use]
    pub fn is_cascading(&self, task_id: TaskId, graph: &TaskGraph) -> bool {
        let region = primary_root(task_id, graph);
        let Some(health) = self.region_health.get(&region) else {
            return false;
        };
        health.failure_rate > self.config.failure_threshold
    }

    /// Returns the ids of every task in `graph` whose region is cascading.
    ///
    /// The result is empty when no region is cascading, and also when every
    /// region with recorded outcomes is cascading (a warning is logged in that
    /// case).
    ///
    /// NOTE(review): "every region" here means every region that has at least
    /// one recorded outcome, not every region present in `graph` — confirm this
    /// is the intended comparison.
    #[must_use]
    pub fn deprioritized_tasks(&self, graph: &TaskGraph) -> HashSet<TaskId> {
        let threshold = self.config.failure_threshold;

        let mut cascading: HashSet<TaskId> = HashSet::new();
        for (&root, health) in &self.region_health {
            if health.failure_rate > threshold {
                cascading.insert(root);
            }
        }

        if cascading.is_empty() {
            return HashSet::new();
        }

        let total_regions = self.region_health.len();
        if total_regions > 0 && cascading.len() == total_regions {
            tracing::warn!(
                cascading_regions = total_regions,
                "all DAG regions are in cascade failure state; \
                 deprioritisation has no effect — falling back to default ordering"
            );
            return HashSet::new();
        }

        let mut deprioritized = HashSet::new();
        for task in &graph.tasks {
            if cascading.contains(&primary_root(task.id, graph)) {
                deprioritized.insert(task.id);
            }
        }
        deprioritized
    }

    /// Forgets every recorded outcome.
    pub fn reset(&mut self) {
        self.region_health.clear();
    }

    /// Test-only view of the per-region tallies.
    #[cfg(test)]
    #[must_use]
    pub fn region_health(&self) -> &HashMap<TaskId, RegionHealth> {
        &self.region_health
    }
}
/// Picks the single root that identifies the region `task_id` belongs to.
///
/// - No ancestor root found: `task_id` itself is returned.
/// - Exactly one ancestor root: that root is returned.
/// - Several ancestor roots: the one with the largest `descendant_count` wins,
///   and ties are broken in favour of the smallest id.
fn primary_root(task_id: TaskId, graph: &TaskGraph) -> TaskId {
    let candidates = ancestor_roots(task_id, graph);
    match candidates.as_slice() {
        [] => task_id,
        [only] => *only,
        _ => candidates
            .iter()
            .copied()
            // `Reverse` makes a smaller id compare as the larger key, so the
            // maximum prefers the lowest id among equally sized roots.
            .max_by_key(|&root| {
                (
                    descendant_count(root, graph),
                    std::cmp::Reverse(root.as_u32()),
                )
            })
            .unwrap_or(task_id),
    }
}
/// Collects every root task reachable from `task_id` by following
/// `depends_on` edges breadth-first.
///
/// A root is a task whose `depends_on` list is empty. If `task_id` itself has
/// no dependencies it is returned as its own root. Each task is visited at
/// most once, so shared ancestors (e.g. in a diamond) are reported once.
///
/// Task ids that do not resolve to an entry in `graph.tasks` are skipped
/// instead of panicking, so they contribute no root. The returned vector may
/// therefore be empty when `task_id` is unknown or all of its dependency
/// chains end in unknown ids.
///
/// NOTE(review): the lookup assumes a task is stored at position
/// `id.index()` in `graph.tasks` — confirm against `TaskGraph`.
fn ancestor_roots(task_id: TaskId, graph: &TaskGraph) -> Vec<TaskId> {
    let mut visited = HashSet::new();
    let mut queue = std::collections::VecDeque::new();
    queue.push_back(task_id);
    visited.insert(task_id);
    let mut roots = Vec::new();
    while let Some(id) = queue.pop_front() {
        // Checked lookup: an id outside the graph must not take the caller
        // down with an out-of-bounds panic.
        let Some(task) = graph.tasks.get(id.index()) else {
            continue;
        };
        if task.depends_on.is_empty() {
            roots.push(id);
        } else {
            for &dep in &task.depends_on {
                // `insert` returns false for already-seen ids, which keeps the
                // walk from revisiting shared ancestors.
                if visited.insert(dep) {
                    queue.push_back(dep);
                }
            }
        }
    }
    roots
}
/// Counts the tasks reachable from `root` along dependency edges in the
/// forward direction (dependency -> dependent), including `root` itself.
fn descendant_count(root: TaskId, graph: &TaskGraph) -> usize {
    // Invert `depends_on` into a dependency -> dependents adjacency map.
    let mut dependents: HashMap<TaskId, Vec<TaskId>> = HashMap::new();
    for task in &graph.tasks {
        for &dep in &task.depends_on {
            dependents.entry(dep).or_default().push(task.id);
        }
    }

    // Only the size of the reached set is returned, so traversal order is
    // irrelevant; a plain stack is enough.
    let mut reached = HashSet::from([root]);
    let mut pending = vec![root];
    while let Some(current) = pending.pop() {
        let Some(children) = dependents.get(&current) else {
            continue;
        };
        for &child in children {
            if reached.insert(child) {
                pending.push(child);
            }
        }
    }
    reached.len()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Builds a task node with the given id and dependency ids.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps the nodes in a graph named "test".
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    /// Shorthand for a config with the given failure threshold.
    fn cfg(threshold: f32) -> CascadeConfig {
        CascadeConfig {
            failure_threshold: threshold,
        }
    }

    #[test]
    fn root_task_returns_self() {
        let g = graph_from(vec![make_node(0, &[])]);
        let roots = ancestor_roots(TaskId(0), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn linear_chain_root_is_task_zero() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let roots = ancestor_roots(TaskId(2), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    // Both branches of the diamond lead back to task 0, so it must be reported
    // exactly once.
    #[test]
    fn diamond_has_single_shared_root() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let roots = ancestor_roots(TaskId(3), &g);
        assert_eq!(roots, vec![TaskId(0)]);
    }

    #[test]
    fn fan_in_has_multiple_roots() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let mut roots = ancestor_roots(TaskId(3), &g);
        // Discovery order is not part of the contract being tested.
        roots.sort_by_key(|r| r.as_u32());
        assert_eq!(roots, vec![TaskId(0), TaskId(1), TaskId(2)]);
    }

    #[test]
    fn no_failures_not_cascading() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(!det.is_cascading(TaskId(1), &g));
    }

    // Two failures out of three outcomes (~0.67) is above the 0.5 threshold,
    // and every task under the shared root sees the same verdict.
    #[test]
    fn failure_rate_exceeds_threshold() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(3), true, &g);
        assert!(det.is_cascading(TaskId(1), &g));
        assert!(det.is_cascading(TaskId(2), &g));
        assert!(det.is_cascading(TaskId(3), &g));
    }

    #[test]
    fn reset_clears_all_regions() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        det.record_outcome(TaskId(1), false, &g);
        det.reset();
        assert!(!det.is_cascading(TaskId(1), &g));
        assert!(det.region_health().is_empty());
    }

    #[test]
    fn deprioritized_tasks_empty_when_healthy() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.5));
        det.record_outcome(TaskId(1), true, &g);
        assert!(det.deprioritized_tasks(&g).is_empty());
    }

    // Region rooted at 0 is failing, region rooted at 3 is healthy: only the
    // former (root included) is deprioritised.
    #[test]
    fn deprioritized_tasks_returns_failing_subtree() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[]),
            make_node(4, &[3]),
        ]);
        let mut det = CascadeDetector::new(cfg(0.4));
        det.record_outcome(TaskId(1), false, &g);
        det.record_outcome(TaskId(2), false, &g);
        det.record_outcome(TaskId(4), true, &g);
        let dp = det.deprioritized_tasks(&g);
        assert!(dp.contains(&TaskId(0)));
        assert!(dp.contains(&TaskId(1)));
        assert!(dp.contains(&TaskId(2)));
        assert!(!dp.contains(&TaskId(3)));
        assert!(!dp.contains(&TaskId(4)));
    }

    #[test]
    fn all_regions_cascading_returns_empty_for_safe_fallback() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let mut det = CascadeDetector::new(cfg(0.3));
        det.record_outcome(TaskId(1), false, &g);
        let dp = det.deprioritized_tasks(&g);
        assert!(
            dp.is_empty(),
            "all-regions-cascading should return empty to allow forward progress"
        );
    }
}