use crate::error::StatError;
use crate::stat_id::StatId;
use petgraph::algo::toposort;
use petgraph::graph::{DiGraph, NodeIndex};
use std::collections::HashMap;
/// Dependency graph between stats.
///
/// Each [`StatId`] is stored as a single node in `graph`; `node_map` is the
/// reverse lookup from a stat id to that node's index.
pub struct StatGraph {
    /// Stats as nodes; an edge `a -> b` means `a` is ordered before `b`.
    graph: DiGraph<StatId, (), u32>,
    /// Reverse index from a stat id to its node in `graph`.
    node_map: HashMap<StatId, NodeIndex<u32>>,
}
impl StatGraph {
    /// Creates an empty graph with no stats and no dependencies.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            node_map: HashMap::new(),
        }
    }

    /// Inserts `stat_id` as a node unless it is already present.
    ///
    /// Returns the node index for `stat_id`. Repeated calls with the same id
    /// return the same index and never create a duplicate node.
    pub fn add_node(&mut self, stat_id: StatId) -> NodeIndex {
        if let Some(&idx) = self.node_map.get(&stat_id) {
            return idx;
        }
        let idx = self.graph.add_node(stat_id.clone());
        self.node_map.insert(stat_id, idx);
        idx
    }

    /// Adds an ordering constraint: `to` must come before `from`
    /// (i.e. `from` is treated as depending on `to`).
    ///
    /// Both stats are inserted as nodes if missing. The stored edge points from
    /// `to` to `from`, so [`topological_sort`](Self::topological_sort) places
    /// `to` before `from`. Registering the same pair again keeps a single edge.
    pub fn add_edge(&mut self, from: StatId, to: StatId) {
        let from_idx = self.add_node(from);
        let to_idx = self.add_node(to);
        // `update_edge` reuses an existing `to -> from` edge instead of adding a
        // parallel one, so repeated registrations don't grow the graph.
        self.graph.update_edge(to_idx, from_idx, ());
    }

    /// Checks the graph for a dependency cycle.
    ///
    /// # Errors
    ///
    /// Returns [`StatError::CycleDetected`] with the stats forming the first
    /// cycle found, in edge order, with the starting stat repeated at the end
    /// (e.g. `[A, B, A]`; a stat that depends on itself yields `[A, A]`).
    /// Stats that merely lead into the cycle are not included.
    pub fn detect_cycles(&self) -> Result<(), StatError> {
        let mut visited = std::collections::HashSet::new();
        for start in self.graph.node_indices() {
            if visited.contains(&start) {
                continue;
            }
            if let Some(cycle) = self.find_cycle_from(start, &mut visited) {
                return Err(StatError::CycleDetected(cycle));
            }
        }
        Ok(())
    }

    /// Depth-first search from `start` that returns the first cycle it reaches.
    ///
    /// Uses an explicit stack instead of recursion so long dependency chains
    /// cannot overflow the call stack. Every node reached is added to `visited`.
    /// Nodes already in `visited` on entry are skipped: an earlier search that
    /// finished without finding a cycle has shown none is reachable from them.
    fn find_cycle_from(
        &self,
        start: NodeIndex,
        visited: &mut std::collections::HashSet<NodeIndex>,
    ) -> Option<Vec<StatId>> {
        // The current DFS path: each entry is a node plus its unexplored successors.
        let mut stack = vec![(
            start,
            self.graph
                .neighbors_directed(start, petgraph::Direction::Outgoing),
        )];
        // Mirror of the nodes currently on `stack`, for O(1) back-edge checks.
        let mut on_path = std::collections::HashSet::new();
        on_path.insert(start);
        visited.insert(start);

        while let Some((_, successors)) = stack.last_mut() {
            match successors.next() {
                // All successors explored: this node leaves the current path.
                None => {
                    if let Some((done, _)) = stack.pop() {
                        on_path.remove(&done);
                    }
                }
                // Back edge: `next` is an ancestor on the current path, so the
                // path segment from `next` to here, closed by this edge, is a cycle.
                Some(next) if on_path.contains(&next) => {
                    let first = stack
                        .iter()
                        .position(|&(node, _)| node == next)
                        .expect("every node in `on_path` is on `stack`");
                    let mut cycle: Vec<StatId> = stack[first..]
                        .iter()
                        .map(|&(node, _)| self.graph[node].clone())
                        .collect();
                    cycle.push(self.graph[next].clone());
                    return Some(cycle);
                }
                Some(next) => {
                    // `insert` returns false for nodes that were already fully explored.
                    if visited.insert(next) {
                        on_path.insert(next);
                        stack.push((
                            next,
                            self.graph
                                .neighbors_directed(next, petgraph::Direction::Outgoing),
                        ));
                    }
                }
            }
        }
        None
    }

    /// Returns every stat ordered so that each one comes after the stats it
    /// depends on (see [`add_edge`](Self::add_edge)).
    ///
    /// # Errors
    ///
    /// Returns [`StatError::CycleDetected`] if the graph contains a cycle; the
    /// payload is the same cycle path [`detect_cycles`](Self::detect_cycles) reports.
    pub fn topological_sort(&self) -> Result<Vec<StatId>, StatError> {
        match toposort(&self.graph, None) {
            Ok(order) => Ok(order
                .into_iter()
                .map(|idx| self.graph[idx].clone())
                .collect()),
            Err(cycle) => {
                // `toposort` only names a single node on the cycle; rerun the DFS
                // so the error carries the full cycle path.
                self.detect_cycles()?;
                // Defensive fallback: only reachable if the two algorithms disagree.
                Err(StatError::CycleDetected(vec![self.graph[cycle.node_id()].clone()]))
            }
        }
    }

    /// Returns all stats in the graph in insertion order (nodes are never removed).
    pub fn nodes(&self) -> Vec<StatId> {
        self.graph
            .node_indices()
            .map(|idx| self.graph[idx].clone())
            .collect()
    }

    /// Returns `true` if `stat_id` has been added as a node.
    pub fn contains_node(&self, stat_id: &StatId) -> bool {
        self.node_map.contains_key(stat_id)
    }
}
impl Default for StatGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_add_nodes() {
        let mut graph = StatGraph::new();
        let hp = StatId::from_str("HP");
        let atk = StatId::from_str("ATK");
        graph.add_node(hp.clone());
        graph.add_node(atk.clone());
        assert!(graph.contains_node(&hp));
        assert!(graph.contains_node(&atk));
    }

    #[test]
    fn test_graph_add_node_is_idempotent() {
        let mut graph = StatGraph::new();
        let hp = StatId::from_str("HP");
        let first = graph.add_node(hp.clone());
        let second = graph.add_node(hp);
        assert_eq!(first, second);
        assert_eq!(graph.nodes().len(), 1);
    }

    #[test]
    fn test_graph_add_edge() {
        let mut graph = StatGraph::new();
        let atk = StatId::from_str("ATK");
        let strength = StatId::from_str("STR");
        graph.add_edge(atk.clone(), strength.clone());
        assert!(graph.contains_node(&atk));
        assert!(graph.contains_node(&strength));
    }

    #[test]
    fn test_graph_no_cycle() {
        let mut graph = StatGraph::new();
        let strength = StatId::from_str("STR");
        let atk = StatId::from_str("ATK");
        let dps = StatId::from_str("DPS");
        graph.add_edge(atk.clone(), strength.clone());
        graph.add_edge(dps.clone(), atk.clone());
        assert!(graph.detect_cycles().is_ok());
    }

    #[test]
    fn test_graph_detect_cycle() {
        let mut graph = StatGraph::new();
        let a = StatId::from_str("A");
        let b = StatId::from_str("B");
        let c = StatId::from_str("C");
        graph.add_edge(b.clone(), a.clone());
        graph.add_edge(c.clone(), b.clone());
        graph.add_edge(a.clone(), c.clone());
        assert!(graph.detect_cycles().is_err());
    }

    #[test]
    fn test_graph_detect_self_cycle() {
        let mut graph = StatGraph::new();
        let a = StatId::from_str("A");
        graph.add_edge(a.clone(), a.clone());
        assert!(graph.detect_cycles().is_err());
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = StatGraph::new();
        let strength = StatId::from_str("STR");
        let dex = StatId::from_str("DEX");
        let atk = StatId::from_str("ATK");
        let crit = StatId::from_str("CRIT");
        graph.add_edge(atk.clone(), strength.clone());
        graph.add_edge(crit.clone(), dex.clone());
        let sorted = graph.topological_sort().unwrap();
        let str_pos = sorted.iter().position(|s| s == &strength).unwrap();
        let dex_pos = sorted.iter().position(|s| s == &dex).unwrap();
        let atk_pos = sorted.iter().position(|s| s == &atk).unwrap();
        let crit_pos = sorted.iter().position(|s| s == &crit).unwrap();
        assert!(str_pos < atk_pos);
        assert!(dex_pos < crit_pos);
    }

    #[test]
    fn test_topological_sort_rejects_cycle() {
        let mut graph = StatGraph::new();
        let a = StatId::from_str("A");
        let b = StatId::from_str("B");
        graph.add_edge(a.clone(), b.clone());
        graph.add_edge(b.clone(), a.clone());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = StatGraph::new();
        assert!(graph.detect_cycles().is_ok());
        assert!(graph.topological_sort().unwrap().is_empty());
        assert!(graph.nodes().is_empty());
    }
}