use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::feature::task_parser::TaskEntry;
/// Strategy used by [`select_runnable`] to choose the next tasks.
///
/// Serialized by serde in lowercase, i.e. `"dag"` or `"sequential"`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SchedulerMode {
    /// Honour task dependencies (the default mode).
    #[default]
    Dag,
    /// Ignore dependencies and hand back every uncompleted task.
    Sequential,
}
/// Explanation of why the DAG scheduler could not hand out any task.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can copy and compare diagnostics
/// directly instead of picking them apart field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockedDiagnostic {
    /// The uncompleted tasks, each with the dependencies it is still waiting on.
    pub blocked_tasks: Vec<BlockedTask>,
    /// Whether a dependency cycle was found (set when `cycle_members` is non-empty).
    pub has_cycle: bool,
    /// Numbers of the tasks that lie on a dependency cycle.
    pub cycle_members: Vec<String>,
}

/// One uncompleted task together with its unsatisfied dependencies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BlockedTask {
    /// Task number as it appears in the task list.
    pub number: String,
    /// Dependency numbers that do not belong to any completed task.
    pub missing_deps: Vec<String>,
}
/// Outcome of one scheduling pass over a task list.
///
/// The `'t` lifetime ties runnable entries to the slice that was scheduled.
#[derive(Debug)]
pub enum ScheduleResult<'t> {
    /// Tasks that may start now, borrowed from the input slice.
    Runnable(Vec<&'t TaskEntry>),
    /// Every task is already completed.
    AllDone,
    /// Uncompleted tasks remain but none of them can start yet.
    Blocked(BlockedDiagnostic),
}
pub fn select_runnable<'a>(tasks: &'a [TaskEntry], mode: SchedulerMode) -> ScheduleResult<'a> {
match mode {
SchedulerMode::Dag => select_runnable_dag(tasks),
SchedulerMode::Sequential => select_runnable_sequential(tasks),
}
}
fn select_runnable_dag<'a>(tasks: &'a [TaskEntry]) -> ScheduleResult<'a> {
let completed_set: HashSet<&str> = tasks
.iter()
.filter(|t| t.completed)
.map(|t| t.number.as_str())
.collect();
let uncompleted: Vec<&TaskEntry> = tasks.iter().filter(|t| !t.completed).collect();
if uncompleted.is_empty() {
return ScheduleResult::AllDone;
}
let mut runnable = Vec::new();
let mut blocked_tasks = Vec::new();
for task in &uncompleted {
let missing_deps: Vec<String> = task
.dependencies
.iter()
.filter(|dep| !completed_set.contains(dep.as_str()))
.cloned()
.collect();
if missing_deps.is_empty() {
runnable.push(*task);
} else {
blocked_tasks.push(BlockedTask {
number: task.number.clone(),
missing_deps,
});
}
}
if !runnable.is_empty() {
ScheduleResult::Runnable(runnable)
} else {
let uncompleted_set: HashSet<&str> =
uncompleted.iter().map(|t| t.number.as_str()).collect();
let mut adj: HashMap<&str, Vec<&str>> = HashMap::new();
let mut in_degree: HashMap<&str, usize> = HashMap::new();
for t in &uncompleted {
let node = t.number.as_str();
adj.entry(node).or_default();
in_degree.entry(node).or_insert(0);
let unique_deps: HashSet<&str> = t
.dependencies
.iter()
.map(|d| d.as_str())
.filter(|d| uncompleted_set.contains(d) && !completed_set.contains(d))
.collect();
for dep in &unique_deps {
adj.entry(dep).or_default().push(node);
*in_degree.entry(node).or_insert(0) += 1;
}
}
let mut queue: Vec<&str> = in_degree
.iter()
.filter(|(_, °)| deg == 0)
.map(|(&node, _)| node)
.collect();
let mut removed = HashSet::new();
while let Some(node) = queue.pop() {
removed.insert(node);
for &dependent in adj.get(node).unwrap_or(&Vec::new()) {
if let Some(deg) = in_degree.get_mut(dependent) {
*deg = deg.saturating_sub(1);
if *deg == 0 && !removed.contains(dependent) {
queue.push(dependent);
}
}
}
}
let remainder: HashSet<&str> = uncompleted
.iter()
.map(|t| t.number.as_str())
.filter(|n| !removed.contains(n))
.collect();
let cycle_members: Vec<String> = remainder
.iter()
.filter(|&&node| {
let mut visited = HashSet::new();
let mut stack = Vec::new();
for &next in adj.get(node).unwrap_or(&Vec::new()) {
if remainder.contains(next) {
stack.push(next);
}
}
while let Some(current) = stack.pop() {
if current == node {
return true;
}
if !visited.insert(current) {
continue;
}
for &next in adj.get(current).unwrap_or(&Vec::new()) {
if remainder.contains(next) {
stack.push(next);
}
}
}
false
})
.map(|s| s.to_string())
.collect();
let has_cycle = !cycle_members.is_empty();
ScheduleResult::Blocked(BlockedDiagnostic {
blocked_tasks,
has_cycle,
cycle_members,
})
}
}
/// Sequential-mode selection: dependencies are ignored entirely.
///
/// Returns `AllDone` when every task is completed (including an empty list). Otherwise it
/// returns all uncompleted tasks, in their original order, as `Runnable`.
fn select_runnable_sequential<'a>(tasks: &'a [TaskEntry]) -> ScheduleResult<'a> {
    if tasks.iter().all(|entry| entry.completed) {
        return ScheduleResult::AllDone;
    }
    let pending = tasks.iter().filter(|entry| !entry.completed).collect();
    ScheduleResult::Runnable(pending)
}
// Unit tests for the DAG and sequential scheduling strategies exposed by `select_runnable`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a TaskEntry with title "Task <number>", indent level 0 and the given dependency numbers.
fn task(number: &str, completed: bool, deps: &[&str]) -> TaskEntry {
TaskEntry {
number: number.to_string(),
title: format!("Task {}", number),
completed,
indent_level: 0,
dependencies: deps.iter().map(|s| s.to_string()).collect(),
}
}
// In a linear chain only the head (no dependencies) is ready.
#[test]
fn dag_simple_chain() {
let tasks = vec![
task("1", false, &[]),
task("2", false, &["1"]),
task("3", false, &["2"]),
];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Runnable(r) => {
assert_eq!(
r.len(),
1,
"dag_simple_chain: only task 1 should be runnable"
);
assert_eq!(r[0].number, "1");
}
other => panic!("dag_simple_chain: expected Runnable, got {:?}", other),
}
}
// Once a shared dependency completes, all of its dependents become ready together.
#[test]
fn dag_parallel_after_common_dep() {
let tasks = vec![
task("1", true, &[]),
task("2", false, &["1"]),
task("3", false, &["1"]),
];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Runnable(r) => {
assert_eq!(
r.len(),
2,
"dag_parallel_after_common_dep: both tasks 2 and 3 should be runnable"
);
let numbers: Vec<&str> = r.iter().map(|t| t.number.as_str()).collect();
assert!(numbers.contains(&"2"));
assert!(numbers.contains(&"3"));
}
other => panic!(
"dag_parallel_after_common_dep: expected Runnable, got {:?}",
other
),
}
}
// NOTE(review): despite its name, this uses a two-task cycle (the same input as
// dag_cycle_detected), so it exercises the cycle diagnostic rather than a plain
// "dependency not yet completed" case.
#[test]
fn dag_blocked_when_deps_incomplete() {
let tasks = vec![task("1", false, &["2"]), task("2", false, &["1"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert_eq!(
diag.blocked_tasks.len(),
2,
"dag_blocked_when_deps_incomplete: both tasks should be blocked"
);
assert!(
diag.has_cycle,
"dag_blocked_when_deps_incomplete: should detect cycle"
);
assert_eq!(
diag.cycle_members.len(),
2,
"dag_blocked_when_deps_incomplete: cycle_members should contain both tasks"
);
}
other => panic!(
"dag_blocked_when_deps_incomplete: expected Blocked, got {:?}",
other
),
}
}
// Mutual dependency between two tasks must be reported as a cycle.
#[test]
fn dag_cycle_detected() {
let tasks = vec![task("1", false, &["2"]), task("2", false, &["1"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
diag.has_cycle,
"dag_cycle_detected: should detect cycle when all deps are known"
);
}
other => panic!("dag_cycle_detected: expected Blocked, got {:?}", other),
}
}
// A dependency on a number no task carries blocks the task, but is not a cycle.
#[test]
fn dag_missing_dep_blocks_task() {
let tasks = vec![task("1", false, &["99"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert_eq!(diag.blocked_tasks.len(), 1);
assert_eq!(diag.blocked_tasks[0].missing_deps, vec!["99"]);
assert!(
!diag.has_cycle,
"dag_missing_dep_blocks_task: should not detect cycle for external dep"
);
}
other => panic!(
"dag_missing_dep_blocks_task: expected Blocked, got {:?}",
other
),
}
}
// Tasks without dependencies are always ready.
#[test]
fn dag_no_deps_always_runnable() {
let tasks = vec![
task("1", false, &[]),
task("2", false, &[]),
task("3", false, &[]),
];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Runnable(r) => {
assert_eq!(
r.len(),
3,
"dag_no_deps_always_runnable: all tasks should be runnable"
);
}
other => panic!(
"dag_no_deps_always_runnable: expected Runnable, got {:?}",
other
),
}
}
#[test]
fn dag_all_done() {
let tasks = vec![task("1", true, &[]), task("2", true, &["1"])];
assert!(
matches!(
select_runnable(&tasks, SchedulerMode::Dag),
ScheduleResult::AllDone
),
"dag_all_done: should return AllDone when all tasks completed"
);
}
// Sequential mode returns every uncompleted task in input order, even when their
// dependencies form a cycle.
#[test]
fn sequential_ignores_deps() {
let tasks = vec![
task("1", false, &["2"]),
task("2", false, &["1"]),
task("3", false, &[]),
];
match select_runnable(&tasks, SchedulerMode::Sequential) {
ScheduleResult::Runnable(r) => {
assert_eq!(
r.len(),
3,
"sequential_ignores_deps: all uncompleted tasks should be returned"
);
assert_eq!(r[0].number, "1");
assert_eq!(r[1].number, "2");
assert_eq!(r[2].number, "3");
}
other => panic!(
"sequential_ignores_deps: expected Runnable, got {:?}",
other
),
}
}
#[test]
fn sequential_all_done() {
let tasks = vec![task("1", true, &[]), task("2", true, &["1"])];
assert!(
matches!(
select_runnable(&tasks, SchedulerMode::Sequential),
ScheduleResult::AllDone
),
"sequential_all_done: should return AllDone when all tasks completed"
);
}
// Membership is checked as a set, so the reported order of cycle_members does not matter here.
#[test]
fn dag_cycle_reports_cycle_members() {
let tasks = vec![task("1", false, &["2"]), task("2", false, &["1"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
diag.has_cycle,
"dag_cycle_reports_cycle_members: should detect cycle"
);
let members: std::collections::HashSet<&str> =
diag.cycle_members.iter().map(|s| s.as_str()).collect();
assert!(
members.contains("1") && members.contains("2"),
"dag_cycle_reports_cycle_members: cycle_members should contain both tasks, got: {:?}",
diag.cycle_members
);
}
other => panic!(
"dag_cycle_reports_cycle_members: expected Blocked, got {:?}",
other
),
}
}
// Cycle 1 -> 2 -> 3 -> 1 (edges run from dependency to dependent): every task is a member.
#[test]
fn dag_three_node_cycle_reports_all_members() {
let tasks = vec![
task("1", false, &["3"]),
task("2", false, &["1"]),
task("3", false, &["2"]),
];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
diag.has_cycle,
"dag_three_node_cycle_reports_all_members: should detect cycle"
);
let members: std::collections::HashSet<&str> =
diag.cycle_members.iter().map(|s| s.as_str()).collect();
assert!(
members.contains("1") && members.contains("2") && members.contains("3"),
"dag_three_node_cycle_reports_all_members: all 3 tasks should be cycle members, got: {:?}",
diag.cycle_members
);
}
other => panic!(
"dag_three_node_cycle_reports_all_members: expected Blocked, got {:?}",
other
),
}
}
#[test]
fn dag_no_cycle_empty_cycle_members() {
let tasks = vec![task("1", false, &["99"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
!diag.has_cycle,
"dag_no_cycle_empty_cycle_members: should not detect cycle for external dep"
);
assert!(
diag.cycle_members.is_empty(),
"dag_no_cycle_empty_cycle_members: cycle_members should be empty, got: {:?}",
diag.cycle_members
);
}
other => panic!(
"dag_no_cycle_empty_cycle_members: expected Blocked, got {:?}",
other
),
}
}
// A task that depends on itself forms a one-node cycle.
#[test]
fn dag_self_cycle_detected() {
let tasks = vec![task("1", false, &["1"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
diag.has_cycle,
"dag_self_cycle_detected: should detect self-cycle"
);
assert_eq!(
diag.cycle_members,
vec!["1"],
"dag_self_cycle_detected: cycle_members should contain the self-referencing task"
);
}
other => panic!("dag_self_cycle_detected: expected Blocked, got {:?}", other),
}
}
// Task 3 depends on the 1<->2 cycle but cannot reach itself, so it is blocked
// without being reported as a cycle member.
#[test]
fn dag_downstream_of_cycle_excluded_from_cycle_members() {
let tasks = vec![
task("1", false, &["2"]),
task("2", false, &["1"]),
task("3", false, &["1"]),
];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Blocked(diag) => {
assert!(
diag.has_cycle,
"dag_downstream_of_cycle: should detect cycle"
);
let members: HashSet<&str> =
diag.cycle_members.iter().map(|s| s.as_str()).collect();
assert!(
members.contains("1") && members.contains("2"),
"dag_downstream_of_cycle: cycle_members should contain 1 and 2, got: {:?}",
diag.cycle_members
);
assert!(
!members.contains("3"),
"dag_downstream_of_cycle: task 3 is downstream, not a cycle member, got: {:?}",
diag.cycle_members
);
}
other => panic!("dag_downstream_of_cycle: expected Blocked, got {:?}", other),
}
}
// Repeated dependency entries must not produce a spurious cycle.
// NOTE(review): task 1 is runnable here, so the result is Runnable and the
// Blocked/cycle-analysis path (where duplicates are deduplicated) is not exercised.
#[test]
fn dag_duplicate_deps_not_false_cycle() {
let tasks = vec![task("1", false, &[]), task("2", false, &["1", "1"])];
match select_runnable(&tasks, SchedulerMode::Dag) {
ScheduleResult::Runnable(r) => {
assert_eq!(
r.len(),
1,
"dag_duplicate_deps: only task 1 should be runnable"
);
assert_eq!(r[0].number, "1");
}
other => panic!("dag_duplicate_deps: expected Runnable, got {:?}", other),
}
}
}