use std::collections::HashMap;
use std::hash::Hash;
use crate::{
utils::graph::{algorithms, DirectedGraph, NodeId},
Result,
};
#[derive(Debug, Clone)]
pub struct IndexedGraph<K, E>
where
K: Hash + Eq + Clone,
{
graph: DirectedGraph<'static, (), E>,
key_to_node: HashMap<K, NodeId>,
node_to_key: HashMap<NodeId, K>,
}
impl<K, E> Default for IndexedGraph<K, E>
where
K: Hash + Eq + Clone,
{
fn default() -> Self {
Self::new()
}
}
impl<K, E> IndexedGraph<K, E>
where
K: Hash + Eq + Clone,
{
#[must_use]
pub fn new() -> Self {
Self {
graph: DirectedGraph::new(),
key_to_node: HashMap::new(),
node_to_key: HashMap::new(),
}
}
#[must_use]
pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self {
Self {
graph: DirectedGraph::with_capacity(node_capacity, edge_capacity),
key_to_node: HashMap::with_capacity(node_capacity),
node_to_key: HashMap::with_capacity(node_capacity),
}
}
pub fn add_node(&mut self, key: K) -> NodeId {
if let Some(&node_id) = self.key_to_node.get(&key) {
return node_id;
}
let node_id = self.graph.add_node(());
self.key_to_node.insert(key.clone(), node_id);
self.node_to_key.insert(node_id, key);
node_id
}
pub fn add_edge(&mut self, from: K, to: K, data: E) -> Result<bool>
where
E: Clone,
{
let from_node = self.add_node(from);
let to_node = self.add_node(to);
if self.graph.successors(from_node).any(|s| s == to_node) {
return Ok(false);
}
self.graph.add_edge(from_node, to_node, data)?;
Ok(true)
}
#[must_use]
pub fn get_node_id(&self, key: &K) -> Option<NodeId> {
self.key_to_node.get(key).copied()
}
#[must_use]
pub fn get_key(&self, node_id: NodeId) -> Option<&K> {
self.node_to_key.get(&node_id)
}
#[must_use]
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
#[must_use]
pub fn edge_count(&self) -> usize {
self.graph.edge_count()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.graph.is_empty()
}
#[must_use]
pub fn inner(&self) -> &DirectedGraph<'static, (), E> {
&self.graph
}
pub fn keys(&self) -> impl Iterator<Item = &K> {
self.key_to_node.keys()
}
#[must_use]
pub fn map_nodes_to_keys(&self, nodes: &[NodeId]) -> Vec<K> {
nodes
.iter()
.filter_map(|node_id| self.node_to_key.get(node_id).cloned())
.collect()
}
#[must_use]
pub fn map_sccs_to_keys(&self, sccs: &[Vec<NodeId>]) -> Vec<Vec<K>> {
sccs.iter().map(|scc| self.map_nodes_to_keys(scc)).collect()
}
}
impl<K, E> IndexedGraph<K, E>
where
K: Hash + Eq + Clone,
{
#[must_use]
pub fn find_cycle_from(&self, start: &K) -> Option<Vec<K>> {
let start_node = self.key_to_node.get(start)?;
let cycle_nodes = algorithms::find_cycle(&self.graph, *start_node)?;
Some(self.map_nodes_to_keys(&cycle_nodes))
}
#[must_use]
pub fn has_cycle_from(&self, start: &K) -> bool {
self.key_to_node
.get(start)
.is_some_and(|&start_node| algorithms::has_cycle(&self.graph, start_node))
}
#[must_use]
pub fn find_any_cycle(&self) -> Option<Vec<K>> {
for &start_node in self.key_to_node.values() {
if let Some(cycle_nodes) = algorithms::find_cycle(&self.graph, start_node) {
return Some(self.map_nodes_to_keys(&cycle_nodes));
}
}
None
}
#[must_use]
pub fn strongly_connected_components(&self) -> Vec<Vec<K>> {
let sccs = algorithms::strongly_connected_components(&self.graph);
self.map_sccs_to_keys(&sccs)
}
#[must_use]
pub fn topological_sort(&self) -> Option<Vec<K>> {
let order = algorithms::topological_sort(&self.graph)?;
Some(self.map_nodes_to_keys(&order))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_indexed_graph_basic() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.get_node_id(&"A"), Some(a));
        assert_eq!(graph.get_node_id(&"B"), Some(b));
        assert_eq!(graph.get_key(a), Some(&"A"));
        assert_eq!(graph.get_key(b), Some(&"B"));
    }

    #[test]
    fn test_indexed_graph_idempotent_add() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        let a1 = graph.add_node("A");
        let a2 = graph.add_node("A");
        assert_eq!(a1, a2);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_indexed_graph_default_is_empty() {
        let graph: IndexedGraph<&str, ()> = IndexedGraph::default();
        assert!(graph.is_empty());
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.keys().count(), 0);
    }

    #[test]
    fn test_indexed_graph_add_edge() {
        let mut graph: IndexedGraph<&str, i32> = IndexedGraph::new();
        assert!(graph.add_edge("A", "B", 10).unwrap());
        assert!(graph.add_edge("B", "C", 20).unwrap());
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
        // A duplicate edge is rejected without touching the edge count.
        assert!(!graph.add_edge("A", "B", 10).unwrap());
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_indexed_graph_unknown_key() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        assert_eq!(graph.get_node_id(&"Z"), None);
        assert!(graph.find_cycle_from(&"Z").is_none());
        assert!(!graph.has_cycle_from(&"Z"));
    }

    #[test]
    fn test_indexed_graph_map_nodes_to_keys() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        assert_eq!(graph.map_nodes_to_keys(&[b, a]), vec!["B", "A"]);
        assert_eq!(
            graph.map_sccs_to_keys(&[vec![a], vec![b, a]]),
            vec![vec!["A"], vec!["B", "A"]]
        );
        let mut keys: Vec<&str> = graph.keys().copied().collect();
        keys.sort_unstable();
        assert_eq!(keys, vec!["A", "B"]);
    }

    #[test]
    fn test_indexed_graph_find_cycle() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("B", "C", ()).unwrap();
        graph.add_edge("C", "A", ()).unwrap();
        assert!(graph.has_cycle_from(&"A"));
        let cycle = graph.find_cycle_from(&"A");
        assert!(cycle.is_some());
        let cycle = cycle.unwrap();
        assert!(cycle.contains(&"A"));
        assert!(cycle.contains(&"B"));
        assert!(cycle.contains(&"C"));
    }

    #[test]
    fn test_indexed_graph_no_cycle() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("B", "C", ()).unwrap();
        assert!(graph.find_cycle_from(&"A").is_none());
        assert!(!graph.has_cycle_from(&"A"));
    }

    #[test]
    fn test_indexed_graph_find_any_cycle() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("D", "A", ()).unwrap();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("B", "C", ()).unwrap();
        assert!(graph.find_any_cycle().is_none());

        graph.add_edge("C", "A", ()).unwrap();
        let cycle = graph
            .find_any_cycle()
            .expect("A -> B -> C -> A is a cycle");
        for key in ["A", "B", "C"] {
            assert!(cycle.contains(&key));
        }
    }

    #[test]
    fn test_indexed_graph_topological_sort() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("A", "C", ()).unwrap();
        graph.add_edge("B", "D", ()).unwrap();
        graph.add_edge("C", "D", ()).unwrap();
        let order = graph.topological_sort();
        assert!(order.is_some());
        let order = order.unwrap();
        assert_eq!(order.len(), 4);
        let pos = |k: &str| order.iter().position(|&x| x == k).unwrap();
        assert!(pos("A") < pos("B"));
        assert!(pos("A") < pos("C"));
        assert!(pos("B") < pos("D"));
        assert!(pos("C") < pos("D"));
    }

    #[test]
    fn test_indexed_graph_topological_sort_with_cycle() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("B", "A", ()).unwrap();
        assert!(graph.topological_sort().is_none());
    }

    #[test]
    fn test_indexed_graph_scc() {
        let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new();
        graph.add_edge("A", "B", ()).unwrap();
        graph.add_edge("B", "A", ()).unwrap();
        graph.add_edge("B", "C", ()).unwrap();
        let sccs = graph.strongly_connected_components();
        assert_eq!(sccs.len(), 2);
        let mut sizes: Vec<usize> = sccs.iter().map(|scc| scc.len()).collect();
        sizes.sort_unstable();
        assert_eq!(sizes, vec![1, 2]);
    }

    #[test]
    fn test_indexed_graph_with_integers() {
        let mut graph: IndexedGraph<i32, &str> = IndexedGraph::new();
        graph.add_edge(1, 2, "one-two").unwrap();
        graph.add_edge(2, 3, "two-three").unwrap();
        assert_eq!(graph.node_count(), 3);
        assert!(graph.topological_sort().is_some());
    }
}