use crate::task::TaskId;
use std::collections::{HashMap, HashSet, VecDeque};
#[derive(Debug, Clone)]
pub enum CycleDetectionResult {
NoCycle,
CycleDetected(Vec<TaskId>),
}
impl CycleDetectionResult {
pub fn has_cycle(&self) -> bool {
matches!(self, CycleDetectionResult::CycleDetected(_))
}
pub fn cycle_path(&self) -> Option<&[TaskId]> {
match self {
CycleDetectionResult::CycleDetected(path) => Some(path),
CycleDetectionResult::NoCycle => None,
}
}
}
/// Directed graph of tasks used for cycle detection, topological sorting and
/// strongly-connected-component queries.
///
/// `Debug` and `Clone` are derived so a detector can be logged and
/// snapshotted like the result type it produces.
#[derive(Debug, Clone)]
pub struct CycleDetector {
    /// Outgoing edges of every task: `adjacency[from]` lists each `to` that was
    /// registered with `add_dependency(from, to)`.
    adjacency: HashMap<TaskId, Vec<TaskId>>,
    /// Every task known to the detector, whether or not it has edges.
    tasks: HashSet<TaskId>,
}
impl CycleDetector {
pub fn new() -> Self {
Self {
adjacency: HashMap::new(),
tasks: HashSet::new(),
}
}
pub fn add_task(&mut self, task_id: TaskId) {
self.tasks.insert(task_id);
self.adjacency.entry(task_id).or_insert_with(Vec::new);
}
pub fn add_dependency(&mut self, from: TaskId, to: TaskId) {
self.add_task(from);
self.add_task(to);
self.adjacency.entry(from).or_insert_with(Vec::new).push(to);
}
pub fn detect_cycle(&self) -> CycleDetectionResult {
let mut visited = HashSet::new();
let mut in_stack = HashSet::new();
let mut parent = HashMap::new();
for &task_id in &self.tasks {
if !visited.contains(&task_id) {
if let Some(cycle) = self.dfs_cycle_detection(
task_id,
&mut visited,
&mut in_stack,
&mut parent,
) {
return CycleDetectionResult::CycleDetected(cycle);
}
}
}
CycleDetectionResult::NoCycle
}
fn dfs_cycle_detection(
&self,
node: TaskId,
visited: &mut HashSet<TaskId>,
in_stack: &mut HashSet<TaskId>,
parent: &mut HashMap<TaskId, TaskId>,
) -> Option<Vec<TaskId>> {
visited.insert(node);
in_stack.insert(node);
if let Some(successors) = self.adjacency.get(&node) {
for &successor in successors {
if !visited.contains(&successor) {
parent.insert(successor, node);
if let Some(cycle) = self.dfs_cycle_detection(
successor,
visited,
in_stack,
parent,
) {
return Some(cycle);
}
} else if in_stack.contains(&successor) {
return Some(self.reconstruct_cycle(node, successor, parent));
}
}
}
in_stack.remove(&node);
None
}
fn reconstruct_cycle(
&self,
start: TaskId,
end: TaskId,
parent: &HashMap<TaskId, TaskId>,
) -> Vec<TaskId> {
let mut cycle = vec![end];
let mut current = start;
while current != end {
cycle.push(current);
if let Some(&p) = parent.get(¤t) {
current = p;
} else {
break;
}
}
cycle.reverse();
cycle
}
pub fn topological_sort(&self) -> Option<Vec<TaskId>> {
if self.detect_cycle().has_cycle() {
return None;
}
let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
for &task_id in &self.tasks {
in_degree.entry(task_id).or_insert(0);
}
for successors in self.adjacency.values() {
for &successor in successors {
*in_degree.entry(successor).or_insert(0) += 1;
}
}
let mut queue: VecDeque<TaskId> = in_degree
.iter()
.filter(|(_, °ree)| degree == 0)
.map(|(&id, _)| id)
.collect();
let mut sorted = Vec::new();
while let Some(node) = queue.pop_front() {
sorted.push(node);
if let Some(successors) = self.adjacency.get(&node) {
for &successor in successors {
if let Some(degree) = in_degree.get_mut(&successor) {
*degree -= 1;
if *degree == 0 {
queue.push_back(successor);
}
}
}
}
}
if sorted.len() == self.tasks.len() {
Some(sorted)
} else {
None
}
}
pub fn strongly_connected_components(&self) -> Vec<Vec<TaskId>> {
let mut index_counter = 0;
let mut stack = Vec::new();
let mut indices = HashMap::new();
let mut lowlinks = HashMap::new();
let mut on_stack = HashSet::new();
let mut components = Vec::new();
for &task_id in &self.tasks {
if !indices.contains_key(&task_id) {
self.tarjan_scc(
task_id,
&mut index_counter,
&mut stack,
&mut indices,
&mut lowlinks,
&mut on_stack,
&mut components,
);
}
}
components
}
fn tarjan_scc(
&self,
node: TaskId,
index_counter: &mut usize,
stack: &mut Vec<TaskId>,
indices: &mut HashMap<TaskId, usize>,
lowlinks: &mut HashMap<TaskId, usize>,
on_stack: &mut HashSet<TaskId>,
components: &mut Vec<Vec<TaskId>>,
) {
indices.insert(node, *index_counter);
lowlinks.insert(node, *index_counter);
*index_counter += 1;
stack.push(node);
on_stack.insert(node);
if let Some(successors) = self.adjacency.get(&node) {
for &successor in successors {
if !indices.contains_key(&successor) {
self.tarjan_scc(
successor,
index_counter,
stack,
indices,
lowlinks,
on_stack,
components,
);
let successor_lowlink = *lowlinks.get(&successor).unwrap();
let node_lowlink = *lowlinks.get(&node).unwrap();
lowlinks.insert(node, node_lowlink.min(successor_lowlink));
} else if on_stack.contains(&successor) {
let successor_index = *indices.get(&successor).unwrap();
let node_lowlink = *lowlinks.get(&node).unwrap();
lowlinks.insert(node, node_lowlink.min(successor_index));
}
}
}
if lowlinks.get(&node) == indices.get(&node) {
let mut component = Vec::new();
loop {
let w = stack.pop().unwrap();
on_stack.remove(&w);
component.push(w);
if w == node {
break;
}
}
if component.len() > 1 || self.has_self_loop(node) {
components.push(component);
}
}
}
fn has_self_loop(&self, node: TaskId) -> bool {
if let Some(successors) = self.adjacency.get(&node) {
successors.contains(&node)
} else {
false
}
}
}
impl Default for CycleDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a detector containing exactly the given `(from, to)` edges.
    fn detector_with_edges(edges: &[(TaskId, TaskId)]) -> CycleDetector {
        let mut detector = CycleDetector::new();
        for &(from, to) in edges {
            detector.add_dependency(from, to);
        }
        detector
    }

    #[test]
    fn test_no_cycle() {
        let detector = detector_with_edges(&[(1, 2), (2, 3)]);
        let result = detector.detect_cycle();
        assert!(!result.has_cycle());
        assert!(result.cycle_path().is_none());
    }

    #[test]
    fn test_simple_cycle() {
        let edges = [(1, 2), (2, 3), (3, 1)];
        let detector = detector_with_edges(&edges);
        let result = detector.detect_cycle();
        assert!(result.has_cycle());

        // The path must be a real walk around the cycle, not just non-empty:
        // every consecutive pair (wrapping around) has to be an edge.
        let cycle = result.cycle_path().unwrap();
        assert_eq!(cycle.len(), 3);
        for i in 0..cycle.len() {
            let step = (cycle[i], cycle[(i + 1) % cycle.len()]);
            assert!(edges.contains(&step), "{:?} is not an edge", step);
        }
    }

    #[test]
    fn test_self_loop() {
        let detector = detector_with_edges(&[(1, 1)]);
        let result = detector.detect_cycle();
        assert!(result.has_cycle());
        let cycle = result.cycle_path().unwrap();
        assert_eq!(cycle, [1]);
    }

    #[test]
    fn test_complex_graph_no_cycle() {
        let detector = detector_with_edges(&[(1, 2), (1, 3), (2, 4), (3, 4)]);
        let result = detector.detect_cycle();
        assert!(!result.has_cycle());
    }

    #[test]
    fn test_topological_sort() {
        let detector = detector_with_edges(&[(1, 2), (2, 3), (1, 3)]);
        let sorted = detector.topological_sort();
        assert!(sorted.is_some());
        let sorted = sorted.unwrap();
        assert_eq!(sorted.len(), 3);
        let pos_1 = sorted.iter().position(|&x| x == 1).unwrap();
        let pos_2 = sorted.iter().position(|&x| x == 2).unwrap();
        let pos_3 = sorted.iter().position(|&x| x == 3).unwrap();
        assert!(pos_1 < pos_2);
        assert!(pos_1 < pos_3);
        assert!(pos_2 < pos_3);
    }

    #[test]
    fn test_topological_sort_with_cycle() {
        let detector = detector_with_edges(&[(1, 2), (2, 3), (3, 1)]);
        let sorted = detector.topological_sort();
        assert!(sorted.is_none());
    }

    #[test]
    fn test_topological_sort_includes_isolated_task() {
        let mut detector = CycleDetector::new();
        detector.add_task(7);
        assert!(!detector.detect_cycle().has_cycle());
        assert_eq!(detector.topological_sort(), Some(vec![7]));
    }

    #[test]
    fn test_strongly_connected_components() {
        // 1 <-> 2 form a component, 4 loops on itself, 3 is acyclic.
        let detector = detector_with_edges(&[(1, 2), (2, 1), (2, 3), (4, 4)]);
        let mut components = detector.strongly_connected_components();
        for component in components.iter_mut() {
            component.sort();
        }
        components.sort();
        assert_eq!(components, vec![vec![1, 2], vec![4]]);
    }

    #[test]
    fn test_strongly_connected_components_acyclic() {
        let detector = detector_with_edges(&[(1, 2), (2, 3)]);
        assert!(detector.strongly_connected_components().is_empty());
    }
}