use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::{Debug, Display};
use std::hash::Hash;
/// Whether the edges of a `Graph` are directed or undirected.
///
/// `Hash` is derived alongside `Eq` so the kind can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphType {
    /// An edge connects its two endpoints symmetrically.
    Undirected,
    /// An edge leads from its source to its target only.
    Directed,
}
/// Identifier of a node inside a `Graph`.
///
/// Ordering compares the raw index, which lets ids be sorted or stored in ordered
/// collections such as `BTreeMap`/`BTreeSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(pub(crate) usize);

impl Display for NodeId {
    /// Formats the id as `Node(<index>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Node({})", self.0)
    }
}
/// Identifier of an edge inside a `Graph`.
///
/// Ordering compares the raw index, which lets ids be sorted or stored in ordered
/// collections such as `BTreeMap`/`BTreeSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EdgeId(pub(crate) usize);

impl Display for EdgeId {
    /// Formats the id as `Edge(<index>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Edge({})", self.0)
    }
}
/// A node record carrying user data of type `N`, plus the ids of the edges attached to it.
#[derive(Debug, Clone)]
pub struct Node<N> {
    /// This node's id.
    pub id: NodeId,
    /// User payload.
    pub data: N,
    /// Edge ids recorded as "outgoing" for this node. `Graph::add_edge` pushes here the edges
    /// whose source is this node in directed graphs, and every incident edge in undirected
    /// graphs (a self-loop is pushed twice).
    pub(crate) outgoing: Vec<EdgeId>,
    /// Edge ids recorded as "incoming" for this node. `Graph::add_edge` only fills this for
    /// directed graphs (edges whose target is this node); it stays empty in undirected graphs.
    pub(crate) incoming: Vec<EdgeId>,
}
impl<N> Node<N> {
    /// Creates a node with the given id and payload and no attached edges.
    pub fn new(id: NodeId, data: N) -> Self {
        Self {
            id,
            data,
            incoming: Vec::new(),
            outgoing: Vec::new(),
        }
    }

    /// Sum of the outgoing and incoming edge-list lengths.
    pub fn degree(&self) -> usize {
        self.out_degree() + self.in_degree()
    }

    /// Length of the outgoing edge list.
    pub fn out_degree(&self) -> usize {
        let Node { outgoing, .. } = self;
        outgoing.len()
    }

    /// Length of the incoming edge list.
    pub fn in_degree(&self) -> usize {
        let Node { incoming, .. } = self;
        incoming.len()
    }
}
/// An edge record connecting `source` to `target`, with an optional weight of type `W`.
#[derive(Debug, Clone)]
pub struct Edge<W> {
    /// This edge's id.
    pub id: EdgeId,
    /// First endpoint, as passed to `Graph::add_edge`.
    pub source: NodeId,
    /// Second endpoint, as passed to `Graph::add_edge`.
    pub target: NodeId,
    /// Optional weight attached to the edge.
    pub weight: Option<W>,
}
impl<W> Edge<W> {
pub fn new(id: EdgeId, source: NodeId, target: NodeId, weight: Option<W>) -> Self {
Edge {
id,
source,
target,
weight,
}
}
}
/// Errors reported by graph operations.
///
/// `PartialEq`/`Eq` are derived so callers can compare errors directly
/// (e.g. `assert_eq!(err, GraphError::NodeNotFound(id))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphError {
    /// The referenced node does not exist in the graph.
    NodeNotFound(NodeId),
    /// The referenced edge does not exist in the graph.
    EdgeNotFound(EdgeId),
    /// The requested operation is not valid; the string explains why.
    InvalidOperation(String),
    /// A cycle was detected in the graph.
    CycleDetected,
    /// The graph is not connected.
    NotConnected,
    /// A negative-weight cycle was detected.
    NegativeWeightCycle,
}
impl std::fmt::Display for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GraphError::NodeNotFound(id) => write!(f, "Node not found: {}", id),
GraphError::EdgeNotFound(id) => write!(f, "Edge not found: {}", id),
GraphError::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg),
GraphError::CycleDetected => write!(f, "Cycle detected in graph"),
GraphError::NotConnected => write!(f, "Graph is not connected"),
GraphError::NegativeWeightCycle => write!(f, "Negative weight cycle detected"),
}
}
}
impl std::error::Error for GraphError {}
/// A graph whose nodes carry data of type `N` and whose edges carry an optional weight `W`.
///
/// `add_edge` does not reject parallel edges or self-loops.
#[derive(Debug, Clone)]
pub struct Graph<N, W> {
    /// Directed or undirected.
    graph_type: GraphType,
    /// All nodes, keyed by id.
    nodes: HashMap<NodeId, Node<N>>,
    /// All edges, keyed by id.
    edges: HashMap<EdgeId, Edge<W>>,
    /// Next raw id handed out by `add_node` (reset to 0 by `clear`).
    next_node_id: usize,
    /// Next raw id handed out by `add_edge` (reset to 0 by `clear`).
    next_edge_id: usize,
    /// Per node: `(neighbor, edge)` pairs in insertion order. Directed graphs list an edge
    /// under its source only; undirected graphs list it under both endpoints.
    adjacency: HashMap<NodeId, Vec<(NodeId, EdgeId)>>,
    /// Per node: `(predecessor, edge)` pairs. Only populated for directed graphs, but every
    /// node gets an (empty) entry regardless of graph type.
    reverse_adjacency: HashMap<NodeId, Vec<(NodeId, EdgeId)>>,
}
impl<N, W> Graph<N, W>
where
    N: Clone + Debug,
    W: Clone + Debug,
{
    /// Creates an empty graph of the given kind.
    pub fn new(graph_type: GraphType) -> Self {
        Graph {
            graph_type,
            nodes: HashMap::new(),
            edges: HashMap::new(),
            next_node_id: 0,
            next_edge_id: 0,
            adjacency: HashMap::new(),
            reverse_adjacency: HashMap::new(),
        }
    }

    /// Returns whether this graph is directed or undirected.
    pub fn graph_type(&self) -> GraphType {
        self.graph_type
    }

    /// Returns the number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Returns `true` if the graph contains no nodes (and therefore no edges).
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Adds a node carrying `data` and returns its newly allocated id.
    ///
    /// Ids come from an increasing counter and are not reused until [`Graph::clear`].
    pub fn add_node(&mut self, data: N) -> NodeId {
        let id = NodeId(self.next_node_id);
        self.next_node_id += 1;
        let node = Node::new(id, data);
        self.nodes.insert(id, node);
        self.adjacency.insert(id, Vec::new());
        self.reverse_adjacency.insert(id, Vec::new());
        id
    }

    /// Removes `node_id` together with every edge attached to it, returning the node's data.
    ///
    /// # Errors
    /// Returns [`GraphError::NodeNotFound`] if the node does not exist; the graph is then
    /// left unchanged.
    pub fn remove_node(&mut self, node_id: NodeId) -> Result<N, GraphError> {
        let node = self
            .nodes
            .get(&node_id)
            .ok_or(GraphError::NodeNotFound(node_id))?;
        // `add_edge` records every incident edge on the node itself (`outgoing`, plus
        // `incoming` for directed graphs), so there is no need to scan the whole edge map.
        // A self-loop is recorded twice (twice in `outgoing` when undirected, once in each
        // list when directed); dedup so each edge is removed exactly once.
        let mut incident: Vec<EdgeId> = node
            .outgoing
            .iter()
            .chain(node.incoming.iter())
            .copied()
            .collect();
        incident.sort_unstable_by_key(|edge_id| edge_id.0);
        incident.dedup();
        for edge_id in incident {
            self.remove_edge(edge_id)?;
        }
        let node = self
            .nodes
            .remove(&node_id)
            .ok_or(GraphError::NodeNotFound(node_id))?;
        self.adjacency.remove(&node_id);
        self.reverse_adjacency.remove(&node_id);
        Ok(node.data)
    }

    /// Returns the node with the given id, or `None` if it does not exist.
    pub fn get_node(&self, node_id: NodeId) -> Option<&Node<N>> {
        self.nodes.get(&node_id)
    }

    /// Returns mutable access to the node with the given id, or `None` if it does not exist.
    ///
    /// Only `data` should be modified through this reference: changing `id` desynchronizes
    /// the node from the key it is stored under.
    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut Node<N>> {
        self.nodes.get_mut(&node_id)
    }

    /// Adds an edge from `source` to `target` with an optional weight and returns its id.
    ///
    /// In an undirected graph the edge is listed in the adjacency and `outgoing` lists of
    /// both endpoints. Parallel edges and self-loops are accepted.
    ///
    /// # Errors
    /// Returns [`GraphError::NodeNotFound`] if either endpoint does not exist; no edge id is
    /// consumed in that case.
    pub fn add_edge(
        &mut self,
        source: NodeId,
        target: NodeId,
        weight: Option<W>,
    ) -> Result<EdgeId, GraphError> {
        if !self.nodes.contains_key(&source) {
            return Err(GraphError::NodeNotFound(source));
        }
        if !self.nodes.contains_key(&target) {
            return Err(GraphError::NodeNotFound(target));
        }
        let edge_id = EdgeId(self.next_edge_id);
        self.next_edge_id += 1;
        let edge = Edge::new(edge_id, source, target, weight);
        self.edges.insert(edge_id, edge);
        self.adjacency
            .get_mut(&source)
            .expect("operation should succeed")
            .push((target, edge_id));
        if self.graph_type == GraphType::Directed {
            self.reverse_adjacency
                .get_mut(&target)
                .expect("operation should succeed")
                .push((source, edge_id));
            self.nodes
                .get_mut(&source)
                .expect("operation should succeed")
                .outgoing
                .push(edge_id);
            self.nodes
                .get_mut(&target)
                .expect("operation should succeed")
                .incoming
                .push(edge_id);
        } else {
            self.adjacency
                .get_mut(&target)
                .expect("operation should succeed")
                .push((source, edge_id));
            self.nodes
                .get_mut(&source)
                .expect("operation should succeed")
                .outgoing
                .push(edge_id);
            self.nodes
                .get_mut(&target)
                .expect("operation should succeed")
                .outgoing
                .push(edge_id);
        }
        Ok(edge_id)
    }

    /// Removes the edge with the given id and returns it.
    ///
    /// # Errors
    /// Returns [`GraphError::EdgeNotFound`] if the edge does not exist.
    pub fn remove_edge(&mut self, edge_id: EdgeId) -> Result<Edge<W>, GraphError> {
        let edge = self
            .edges
            .remove(&edge_id)
            .ok_or(GraphError::EdgeNotFound(edge_id))?;
        if let Some(neighbors) = self.adjacency.get_mut(&edge.source) {
            neighbors.retain(|(_, eid)| *eid != edge_id);
        }
        if self.graph_type == GraphType::Directed {
            if let Some(neighbors) = self.reverse_adjacency.get_mut(&edge.target) {
                neighbors.retain(|(_, eid)| *eid != edge_id);
            }
            if let Some(node) = self.nodes.get_mut(&edge.source) {
                node.outgoing.retain(|eid| *eid != edge_id);
            }
            if let Some(node) = self.nodes.get_mut(&edge.target) {
                node.incoming.retain(|eid| *eid != edge_id);
            }
        } else {
            if let Some(neighbors) = self.adjacency.get_mut(&edge.target) {
                neighbors.retain(|(_, eid)| *eid != edge_id);
            }
            if let Some(node) = self.nodes.get_mut(&edge.source) {
                node.outgoing.retain(|eid| *eid != edge_id);
            }
            if let Some(node) = self.nodes.get_mut(&edge.target) {
                node.outgoing.retain(|eid| *eid != edge_id);
            }
        }
        Ok(edge)
    }

    /// Returns the edge with the given id, or `None` if it does not exist.
    pub fn get_edge(&self, edge_id: EdgeId) -> Option<&Edge<W>> {
        self.edges.get(&edge_id)
    }

    /// Returns mutable access to the edge with the given id, or `None` if it does not exist.
    ///
    /// Only `weight` should be modified through this reference: changing `id`, `source` or
    /// `target` desynchronizes the graph's adjacency bookkeeping.
    pub fn get_edge_mut(&mut self, edge_id: EdgeId) -> Option<&mut Edge<W>> {
        self.edges.get_mut(&edge_id)
    }

    /// Iterates over all `(id, node)` pairs in arbitrary (hash-map) order.
    pub fn nodes(&self) -> impl Iterator<Item = (&NodeId, &Node<N>)> {
        self.nodes.iter()
    }

    /// Iterates over all `(id, edge)` pairs in arbitrary (hash-map) order.
    pub fn edges(&self) -> impl Iterator<Item = (&EdgeId, &Edge<W>)> {
        self.edges.iter()
    }

    /// Iterates over all node ids in arbitrary (hash-map) order.
    pub fn node_ids(&self) -> impl Iterator<Item = NodeId> + '_ {
        self.nodes.keys().copied()
    }

    /// Iterates over all edge ids in arbitrary (hash-map) order.
    pub fn edge_ids(&self) -> impl Iterator<Item = EdgeId> + '_ {
        self.edges.keys().copied()
    }

    /// Returns the nodes reachable over one edge from `node_id`: successors in a directed
    /// graph, all adjacent nodes in an undirected one.
    ///
    /// A neighbor is listed once per connecting edge, in insertion order. Returns `None` if
    /// the node does not exist.
    pub fn neighbors(&self, node_id: NodeId) -> Option<Vec<NodeId>> {
        self.adjacency
            .get(&node_id)
            .map(|neighbors| neighbors.iter().map(|(n, _)| *n).collect())
    }

    /// Returns the sources of edges entering `node_id` (same as [`Graph::neighbors`] for
    /// undirected graphs), or `None` if the node does not exist.
    pub fn predecessors(&self, node_id: NodeId) -> Option<Vec<NodeId>> {
        if self.graph_type == GraphType::Directed {
            self.reverse_adjacency
                .get(&node_id)
                .map(|neighbors| neighbors.iter().map(|(n, _)| *n).collect())
        } else {
            self.neighbors(node_id)
        }
    }

    /// Returns the ids of edges attached to `node_id`, or `None` if the node does not exist.
    ///
    /// Directed graphs list outgoing edges, then incoming ones. Undirected graphs list every
    /// incident edge, with a self-loop appearing twice.
    pub fn edges_of(&self, node_id: NodeId) -> Option<Vec<EdgeId>> {
        let node = self.nodes.get(&node_id)?;
        if self.graph_type == GraphType::Directed {
            let mut edges: Vec<EdgeId> = node.outgoing.clone();
            edges.extend(node.incoming.iter().cloned());
            Some(edges)
        } else {
            Some(node.outgoing.clone())
        }
    }

    /// Returns `true` if an edge leads from `source` to `target` (either direction when
    /// undirected).
    pub fn has_edge(&self, source: NodeId, target: NodeId) -> bool {
        if let Some(neighbors) = self.adjacency.get(&source) {
            neighbors.iter().any(|(n, _)| *n == target)
        } else {
            false
        }
    }

    /// Returns the first-inserted edge leading from `source` to `target` (either direction
    /// when undirected), if any.
    pub fn get_edge_between(&self, source: NodeId, target: NodeId) -> Option<&Edge<W>> {
        let neighbors = self.adjacency.get(&source)?;
        let (_, edge_id) = neighbors.iter().find(|(n, _)| *n == target)?;
        self.edges.get(edge_id)
    }

    /// Returns the degree of `node_id`, or `None` if it does not exist.
    ///
    /// Directed: in-degree plus out-degree. Undirected: number of incident edges, with a
    /// self-loop counted twice.
    pub fn degree(&self, node_id: NodeId) -> Option<usize> {
        self.nodes.get(&node_id).map(|n| {
            if self.graph_type == GraphType::Directed {
                n.outgoing.len() + n.incoming.len()
            } else {
                n.outgoing.len()
            }
        })
    }

    /// Returns the number of edges entering `node_id` (the full degree for undirected
    /// graphs), or `None` if the node does not exist.
    pub fn in_degree(&self, node_id: NodeId) -> Option<usize> {
        if self.graph_type == GraphType::Directed {
            self.nodes.get(&node_id).map(|n| n.incoming.len())
        } else {
            self.degree(node_id)
        }
    }

    /// Returns the number of edges leaving `node_id` (the full degree for undirected
    /// graphs), or `None` if the node does not exist.
    pub fn out_degree(&self, node_id: NodeId) -> Option<usize> {
        self.nodes.get(&node_id).map(|n| n.outgoing.len())
    }

    /// Builds the subgraph induced by `node_ids`: copies of those nodes plus every edge whose
    /// endpoints are both among them.
    ///
    /// Nodes are renumbered densely in the order their ids first appear in `node_ids`. The
    /// first distinct existing id becomes `NodeId(0)`, and so on. Ids absent from this graph
    /// are skipped, and repeated ids are copied only once. Edges are copied in ascending
    /// `EdgeId` order, so the result is deterministic.
    pub fn subgraph(&self, node_ids: &[NodeId]) -> Graph<N, W> {
        let mut subgraph = Graph::new(self.graph_type);
        let mut id_map: HashMap<NodeId, NodeId> = HashMap::with_capacity(node_ids.len());
        for &node_id in node_ids {
            if id_map.contains_key(&node_id) {
                // Already copied: a repeated id must not create a second, orphaned node.
                continue;
            }
            if let Some(node) = self.nodes.get(&node_id) {
                let new_id = subgraph.add_node(node.data.clone());
                id_map.insert(node_id, new_id);
            }
        }
        for edge in self.edges_in_id_order() {
            if let (Some(&new_source), Some(&new_target)) =
                (id_map.get(&edge.source), id_map.get(&edge.target))
            {
                subgraph
                    .add_edge(new_source, new_target, edge.weight.clone())
                    .expect("both endpoints were added to the subgraph above");
            }
        }
        subgraph
    }

    /// Returns a copy of the graph with every edge's direction flipped.
    ///
    /// Node ids are preserved, so an id valid in `self` names the same node in the result.
    /// Edges are re-created in ascending `EdgeId` order, so the result is deterministic. Edge
    /// ids match the originals unless edges were removed from `self`, in which case they are
    /// renumbered densely.
    pub fn reverse(&self) -> Graph<N, W> {
        let mut reversed = self.copy_nodes_only(self.graph_type);
        for edge in self.edges_in_id_order() {
            reversed
                .add_edge(edge.target, edge.source, edge.weight.clone())
                .expect("every node was copied under its original id");
        }
        reversed
    }

    /// Returns an undirected copy of the graph with at most one edge per unordered node pair.
    ///
    /// Node ids are preserved. When several edges connect the same pair (in either
    /// direction), only the one with the lowest `EdgeId` is kept, together with its weight.
    pub fn to_undirected(&self) -> Graph<N, W> {
        let mut undirected = self.copy_nodes_only(GraphType::Undirected);
        let mut seen_pairs: HashSet<(NodeId, NodeId)> = HashSet::new();
        for edge in self.edges_in_id_order() {
            let (source, target) = (edge.source, edge.target);
            let pair = if source.0 <= target.0 {
                (source, target)
            } else {
                (target, source)
            };
            // `insert` returns false for a pair already connected by a lower-id edge.
            if seen_pairs.insert(pair) {
                undirected
                    .add_edge(source, target, edge.weight.clone())
                    .expect("every node was copied under its original id");
            }
        }
        undirected
    }

    /// Removes all nodes and edges and resets the id counters, so ids handed out before the
    /// call may be reused afterwards.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.edges.clear();
        self.adjacency.clear();
        self.reverse_adjacency.clear();
        self.next_node_id = 0;
        self.next_edge_id = 0;
    }

    /// Returns all edges sorted by ascending `EdgeId`, i.e. in creation order.
    ///
    /// Used so that copies of the graph do not depend on `HashMap` iteration order.
    fn edges_in_id_order(&self) -> Vec<&Edge<W>> {
        let mut ordered: Vec<&Edge<W>> = self.edges.values().collect();
        ordered.sort_unstable_by_key(|edge| edge.id.0);
        ordered
    }

    /// Returns an edge-less graph of kind `graph_type` holding a clone of every node of
    /// `self` under the same `NodeId`.
    ///
    /// The node id counter is carried over so nodes added later never collide with copies.
    fn copy_nodes_only(&self, graph_type: GraphType) -> Graph<N, W> {
        let mut copy = Graph::new(graph_type);
        for (&node_id, node) in &self.nodes {
            copy.nodes
                .insert(node_id, Node::new(node_id, node.data.clone()));
            copy.adjacency.insert(node_id, Vec::new());
            copy.reverse_adjacency.insert(node_id, Vec::new());
        }
        copy.next_node_id = self.next_node_id;
        copy
    }
}
/// The default graph is an empty undirected one.
impl<N, W> Default for Graph<N, W>
where
    W: Clone + Debug,
    N: Clone + Debug,
{
    /// Equivalent to `Graph::new(GraphType::Undirected)`.
    fn default() -> Self {
        let kind = GraphType::Undirected;
        Self::new(kind)
    }
}
/// Fluent builder that assembles a `Graph` while referring to nodes by string names.
///
/// Each name is bound to the node created by the most recent `add_node` call with that
/// name. Reusing a name adds a second node and rebinds the name; the earlier node stays in
/// the graph but can no longer be reached by name.
#[derive(Debug, Clone)]
pub struct GraphBuilder<N, W> {
    /// The graph under construction.
    graph: Graph<N, W>,
    /// Name → id of the node most recently added under that name.
    node_map: HashMap<String, NodeId>,
}

impl<N, W> GraphBuilder<N, W>
where
    N: Clone + Debug,
    W: Clone + Debug,
{
    /// Starts building an empty graph of the given kind.
    pub fn new(graph_type: GraphType) -> Self {
        GraphBuilder {
            graph: Graph::new(graph_type),
            node_map: HashMap::new(),
        }
    }

    /// Starts building an empty undirected graph.
    pub fn undirected() -> Self {
        Self::new(GraphType::Undirected)
    }

    /// Starts building an empty directed graph.
    pub fn directed() -> Self {
        Self::new(GraphType::Directed)
    }

    /// Adds a node carrying `data` and binds it to `name`.
    pub fn add_node(mut self, name: &str, data: N) -> Self {
        let id = self.graph.add_node(data);
        self.node_map.insert(name.to_string(), id);
        self
    }

    /// Adds an edge between the nodes bound to `source` and `target`.
    ///
    /// If either name has not been bound by `add_node`, this is a silent no-op: no edge is
    /// added and no error is reported. Use [`GraphBuilder::get_node_id`] to check names
    /// beforehand when that matters.
    pub fn add_edge(mut self, source: &str, target: &str, weight: Option<W>) -> Self {
        if let (Some(&src), Some(&tgt)) = (self.node_map.get(source), self.node_map.get(target)) {
            // Cannot fail: both ids came from `self.graph.add_node` and nodes are never
            // removed by the builder.
            let _ = self.graph.add_edge(src, tgt, weight);
        }
        self
    }

    /// Finishes building and returns the graph.
    pub fn build(self) -> Graph<N, W> {
        self.graph
    }

    /// Returns the id currently bound to `name`, if any.
    pub fn get_node_id(&self, name: &str) -> Option<NodeId> {
        self.node_map.get(name).copied()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_graph() {
        let graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.is_empty());
    }

    #[test]
    fn test_add_nodes() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.get_node(a).expect("node A was just added").data, "A");
        assert_eq!(graph.get_node(b).expect("node B was just added").data, "B");
    }

    #[test]
    fn test_add_edges() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, Some(1.0)).expect("both endpoints exist");
        graph.add_edge(b, c, Some(2.0)).expect("both endpoints exist");
        assert_eq!(graph.edge_count(), 2);
        assert!(graph.has_edge(a, b));
        // Undirected edges can be traversed in both directions.
        assert!(graph.has_edge(b, a));
        assert!(graph.has_edge(b, c));
        assert!(!graph.has_edge(a, c));
    }

    #[test]
    fn test_directed_graph() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Directed);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        graph.add_edge(a, b, Some(1.0)).expect("both endpoints exist");
        assert!(graph.has_edge(a, b));
        // Directed edges are one-way.
        assert!(!graph.has_edge(b, a));
    }

    #[test]
    fn test_neighbors() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, None).expect("both endpoints exist");
        graph.add_edge(a, c, None).expect("both endpoints exist");
        let neighbors = graph.neighbors(a).expect("node A exists");
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&b));
        assert!(neighbors.contains(&c));
    }

    #[test]
    fn test_graph_builder() {
        let graph: Graph<&str, f64> = GraphBuilder::undirected()
            .add_node("a", "A")
            .add_node("b", "B")
            .add_node("c", "C")
            .add_edge("a", "b", Some(1.0))
            .add_edge("b", "c", Some(2.0))
            .build();
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_remove_node() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, None).expect("both endpoints exist");
        graph.add_edge(b, c, None).expect("both endpoints exist");
        assert_eq!(graph.edge_count(), 2);
        graph.remove_node(b).expect("node B exists");
        assert_eq!(graph.node_count(), 2);
        // Removing B also removes both edges attached to it.
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_subgraph() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        let d = graph.add_node("D");
        graph.add_edge(a, b, None).expect("both endpoints exist");
        graph.add_edge(b, c, None).expect("both endpoints exist");
        graph.add_edge(c, d, None).expect("both endpoints exist");
        let subgraph = graph.subgraph(&[a, b, c]);
        assert_eq!(subgraph.node_count(), 3);
        // A-B and B-C are inside the node set; C-D is not.
        assert_eq!(subgraph.edge_count(), 2);
    }

    #[test]
    fn test_directed_degrees_and_predecessors() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Directed);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, None).expect("both endpoints exist");
        graph.add_edge(a, c, None).expect("both endpoints exist");
        assert_eq!(graph.out_degree(a), Some(2));
        assert_eq!(graph.in_degree(a), Some(0));
        assert_eq!(graph.in_degree(b), Some(1));
        assert_eq!(graph.degree(a), Some(2));
        assert_eq!(graph.predecessors(b), Some(vec![a]));
        assert_eq!(graph.neighbors(b), Some(vec![]));
    }

    #[test]
    fn test_remove_edge() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let edge = graph.add_edge(a, b, Some(3.0)).expect("both endpoints exist");
        let removed = graph.remove_edge(edge).expect("edge exists");
        assert_eq!(removed.weight, Some(3.0));
        assert_eq!(graph.edge_count(), 0);
        assert!(!graph.has_edge(a, b));
        assert!(!graph.has_edge(b, a));
        assert_eq!(graph.degree(a), Some(0));
        assert!(matches!(
            graph.remove_edge(edge),
            Err(GraphError::EdgeNotFound(id)) if id == edge
        ));
    }

    #[test]
    fn test_missing_node_errors() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Directed);
        let a = graph.add_node("A");
        let gone = graph.add_node("gone");
        graph.remove_node(gone).expect("node exists");
        assert!(graph.get_node(gone).is_none());
        assert!(graph.neighbors(gone).is_none());
        assert!(matches!(
            graph.add_edge(a, gone, None),
            Err(GraphError::NodeNotFound(id)) if id == gone
        ));
        assert!(matches!(
            graph.remove_node(gone),
            Err(GraphError::NodeNotFound(id)) if id == gone
        ));
    }

    #[test]
    fn test_remove_node_with_self_loop() {
        for &graph_type in &[GraphType::Undirected, GraphType::Directed] {
            let mut graph: Graph<&str, f64> = Graph::new(graph_type);
            let a = graph.add_node("A");
            let b = graph.add_node("B");
            graph.add_edge(a, a, None).expect("endpoint exists");
            graph.add_edge(a, b, None).expect("both endpoints exist");
            // The self-loop contributes two to A's degree in both graph kinds.
            assert_eq!(graph.degree(a), Some(3));
            assert_eq!(graph.remove_node(a).expect("node A exists"), "A");
            assert_eq!(graph.node_count(), 1);
            assert_eq!(graph.edge_count(), 0);
            assert_eq!(graph.degree(b), Some(0));
        }
    }

    #[test]
    fn test_reverse_and_to_undirected_counts() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Directed);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        graph.add_edge(a, b, Some(1.0)).expect("both endpoints exist");
        graph.add_edge(b, a, Some(2.0)).expect("both endpoints exist");

        let reversed = graph.reverse();
        assert_eq!(reversed.graph_type(), GraphType::Directed);
        assert_eq!(reversed.node_count(), 2);
        assert_eq!(reversed.edge_count(), 2);

        // A->B and B->A collapse into a single undirected edge.
        let undirected = graph.to_undirected();
        assert_eq!(undirected.graph_type(), GraphType::Undirected);
        assert_eq!(undirected.node_count(), 2);
        assert_eq!(undirected.edge_count(), 1);
    }

    #[test]
    fn test_clear_resets_graph() {
        let mut graph: Graph<&str, f64> = Graph::new(GraphType::Undirected);
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        graph.add_edge(a, b, None).expect("both endpoints exist");
        graph.clear();
        assert!(graph.is_empty());
        assert_eq!(graph.edge_count(), 0);
        // Id counters restart after `clear`.
        assert_eq!(graph.add_node("C"), a);
    }
}