use crate::{Edge, Node};
use slab::Slab;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
pub type NodeId = usize;
#[derive(Debug)]
pub enum GraphError {
InvalidNodeId(NodeId),
EdgeNotFound { from: NodeId, to: NodeId },
MatrixError(String),
}
impl std::fmt::Display for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GraphError::InvalidNodeId(id) => write!(f, "Invalid node ID: {}", id),
GraphError::EdgeNotFound { from, to } => {
write!(f, "Edge not found between nodes {} and {}", from, to)
}
GraphError::MatrixError(msg) => write!(f, "Matrix error: {}", msg),
}
}
}
impl std::error::Error for GraphError {}
#[derive(Debug, Clone)]
pub struct Graph<W, N, E>
where
W: Copy + PartialEq,
N: Clone + Eq + Hash + Debug,
E: Clone + Debug,
{
pub nodes: Slab<Node<W, N>>,
pub edges: Slab<Edge<W, E>>,
pub directed: bool,
}
impl<W, N, E> Graph<W, N, E>
where
W: Copy + Default + PartialEq,
N: Clone + Eq + Hash + Debug,
E: Clone + Default + Debug,
{
pub fn new(directed: bool) -> Self {
Self {
nodes: Slab::with_capacity(1024),
edges: Slab::with_capacity(4096),
directed,
}
}
pub fn add_node(&mut self, data: N) -> NodeId {
self.nodes.insert(Node {
data,
neighbors: Vec::new(),
attributes: HashMap::new(),
})
}
pub fn remove_node(&mut self, node: NodeId) -> Result<bool, GraphError> {
if !self.nodes.contains(node) {
return Ok(false);
}
let edges_to_remove: Vec<_> = self
.edges
.iter()
.filter(|(_, e)| e.from == node || e.to == node)
.map(|(id, _)| id)
.collect();
for edge_id in edges_to_remove {
self.edges.remove(edge_id);
}
self.nodes.remove(node);
Ok(true)
}
pub fn add_edge(
&mut self,
from: NodeId,
to: NodeId,
weight: W,
edge_data: E,
) -> Result<(), GraphError> {
let (from, to) = self.normalize_edge(from, to);
if !self.nodes.contains(from) {
return Err(GraphError::InvalidNodeId(from));
}
if !self.nodes.contains(to) {
return Err(GraphError::InvalidNodeId(to));
}
let edge = Edge {
from,
to,
weight,
data: edge_data,
attributes: HashMap::new(),
};
let _edge_id = self.edges.insert(edge);
self.nodes[from].neighbors.push((to, weight));
if !self.directed {
self.nodes[to].neighbors.push((from, weight));
}
Ok(())
}
pub fn add_nodes_batch(&mut self, data_iter: impl Iterator<Item = N>) -> Vec<NodeId> {
data_iter.map(|data| self.add_node(data)).collect()
}
pub fn add_edges_batch(
&mut self,
edges: Vec<(NodeId, NodeId, W, E)>,
) -> Result<(), GraphError> {
for (from, to, weight, data) in edges {
self.add_edge(from, to, weight, data)?;
}
Ok(())
}
pub fn remove_edge(&mut self, from: NodeId, to: NodeId) -> Result<(), GraphError> {
let directed = self.directed;
let edge_ids: Vec<_> = self
.edges
.iter()
.filter(|(_, e)| Self::edge_matches(e, from, to, directed))
.map(|(id, _)| id)
.collect();
if edge_ids.is_empty() {
return Err(GraphError::EdgeNotFound { from, to });
}
for edge_id in edge_ids {
self.edges.remove(edge_id);
}
self.nodes[from].neighbors.retain(|&(n, _)| n != to);
if !self.directed {
self.nodes[to].neighbors.retain(|&(n, _)| n != from);
}
Ok(())
}
pub fn from_adjacency_matrix(
matrix: &[Vec<W>],
directed: bool,
default_node: N,
default_edge: E,
) -> Result<Self, GraphError>
where
W: PartialEq + Default,
N: Clone,
E: Clone,
{
if matrix.is_empty() || matrix.iter().any(|row| row.len() != matrix.len()) {
return Err(GraphError::MatrixError(
"Matrix must be non-empty and square".to_string(),
));
}
let mut graph = Graph::new(directed);
let n = matrix.len();
for _ in 0..n {
graph.add_node(default_node.clone());
}
let edges: Vec<_> = (0..n)
.flat_map(|i| {
let range = if directed { 0..n } else { i..n };
let edge_data = default_edge.clone();
range.filter_map(move |j| {
if matrix[i][j] != W::default() {
Some((i, j, matrix[i][j], edge_data.clone()))
} else {
None
}
})
})
.collect();
graph.add_edges_batch(edges)?;
Ok(graph)
}
pub fn to_adjacency_matrix(&self) -> Vec<Vec<W>>
where
W: Default + Copy,
{
let size = self.nodes.len();
let mut matrix = vec![vec![W::default(); size]; size];
for (_, edge) in self.edges.iter() {
matrix[edge.from][edge.to] = edge.weight;
if !self.directed {
matrix[edge.to][edge.from] = edge.weight;
}
}
matrix
}
pub fn set_node_attribute(
&mut self,
node: NodeId,
key: String,
value: String,
) -> Result<bool, GraphError> {
if let Some(node) = self.nodes.get_mut(node) {
node.attributes.insert(key, value);
Ok(true)
} else {
Err(GraphError::InvalidNodeId(node))
}
}
pub fn set_edge_attribute(
&mut self,
from: NodeId,
to: NodeId,
key: String,
value: String,
) -> Result<bool, GraphError> {
let mut found = false;
let directed = self.directed;
for (_, edge) in self.edges.iter_mut() {
if Self::edge_matches(edge, from, to, directed) {
edge.attributes.insert(key.clone(), value.clone());
found = true;
}
}
if found {
Ok(true)
} else {
Err(GraphError::EdgeNotFound { from, to })
}
}
pub fn get_node_attribute(&self, node: NodeId, key: &str) -> Option<&String> {
self.nodes.get(node).and_then(|n| n.attributes.get(key))
}
pub fn get_edge_attribute(&self, from: NodeId, to: NodeId, key: &str) -> Option<&String> {
self.edges
.iter()
.find(|(_, e)| Self::edge_matches(e, from, to, self.directed))
.and_then(|(_, e)| e.attributes.get(key))
}
pub fn get_all_edge_attributes(
&self,
from: NodeId,
to: NodeId,
) -> Option<&HashMap<String, String>> {
self.edges
.iter()
.find(|(_, e)| Self::edge_matches(e, from, to, self.directed))
.map(|(_, e)| &e.attributes)
}
pub fn get_all_edges(&self) -> Vec<(NodeId, NodeId, W, E)> {
self.edges
.iter()
.map(|(_, e)| (e.from, e.to, e.weight, e.data.clone()))
.collect()
}
pub fn get_neighbors(&self, node: NodeId) -> Vec<(NodeId, W)> {
self.nodes
.get(node)
.map(|n| n.neighbors.iter().map(|&(id, w)| (id, w)).collect())
.unwrap_or_default()
}
pub fn validate_graph(&self) -> bool {
let edges_valid = self
.edges
.iter()
.all(|(_, e)| self.nodes.contains(e.from) && self.nodes.contains(e.to));
let neighbors_valid = self
.nodes
.iter()
.all(|(_, n)| n.neighbors.iter().all(|&(id, _)| self.nodes.contains(id)));
edges_valid && neighbors_valid
}
pub fn all_nodes(&self) -> impl Iterator<Item = (NodeId, &N)> + '_ {
self.nodes.iter().map(|(id, node)| (id, &node.data))
}
fn normalize_edge(&self, mut from: NodeId, mut to: NodeId) -> (NodeId, NodeId) {
if !self.directed && from > to {
std::mem::swap(&mut from, &mut to);
}
(from, to)
}
fn edge_matches(edge: &Edge<W, E>, from: NodeId, to: NodeId, directed: bool) -> bool {
if directed {
edge.from == from && edge.to == to
} else {
(edge.from == from && edge.to == to) || (edge.from == to && edge.to == from)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mixed_types() {
let mut graph = Graph::<u32, String, &str>::new(false);
let paris = graph.add_node("Paris".into());
let london = graph.add_node("London".into());
assert!(graph.add_edge(paris, london, 343, "Eurostar").is_ok());
assert_eq!(graph.get_all_edges()[0].3, "Eurostar");
}
#[test]
fn test_directed_graph() {
let mut graph = Graph::<f64, i32, String>::new(true);
let a = graph.add_node(1);
let b = graph.add_node(2);
graph.add_edge(a, b, 1.5, "AB".into()).unwrap();
graph.add_edge(b, a, 2.5, "BA".into()).unwrap();
assert_eq!(graph.get_all_edges().len(), 2);
}
#[test]
fn test_attributes() {
let mut graph = Graph::<i32, &str, ()>::new(false);
let n1 = graph.add_node("Node1");
let n2 = graph.add_node("Node2");
graph
.set_node_attribute(n1, "color".into(), "red".into())
.unwrap();
graph.add_edge(n1, n2, 10, ()).unwrap();
graph
.set_edge_attribute(n1, n2, "type".into(), "road".into())
.unwrap();
assert_eq!(graph.get_node_attribute(n1, "color"), Some(&"red".into()));
assert_eq!(
graph.get_edge_attribute(n1, n2, "type"),
Some(&"road".into())
);
}
#[test]
fn test_empty_graph() {
let graph: Graph<u32, (), ()> = Graph::new(false);
assert!(graph.nodes.is_empty());
assert!(graph.edges.is_empty());
}
#[test]
fn test_complete_graph() {
let mut graph = Graph::<u32, i32, ()>::new(false);
let nodes: Vec<_> = (0..5).map(|i| graph.add_node(i)).collect();
for &a in &nodes {
for &b in &nodes {
if a != b {
graph.add_edge(a, b, 1, ()).unwrap();
}
}
}
assert_eq!(graph.edges.len(), 20);
}
#[test]
fn test_disconnected_graph() {
let mut graph = Graph::<u32, &str, ()>::new(false);
graph.add_node("A");
graph.add_node("B");
assert_eq!(graph.get_all_edges().len(), 0);
}
#[test]
fn test_graph_validation() {
let mut graph = Graph::<u32, &str, ()>::new(false);
let a = graph.add_node("A");
let b = graph.add_node("B");
graph.add_edge(a, b, 1, ()).unwrap();
assert!(graph.validate_graph());
}
#[test]
fn test_varied_node_edge_types() {
let mut graph = Graph::<f32, &str, char>::new(false);
let a = graph.add_node("Alpha");
let b = graph.add_node("Beta");
let c = graph.add_node("Gamma");
graph.add_edge(a, b, 2.5, 'X').unwrap();
graph.add_edge(b, c, 3.1, 'Y').unwrap();
assert_eq!(graph.get_all_edges().len(), 2);
}
#[test]
fn test_mixed_weight_types() {
let mut graph = Graph::<i64, u32, &str>::new(true);
let n1 = graph.add_node(100);
let n2 = graph.add_node(200);
graph.add_edge(n1, n2, 50, "Highway").unwrap();
assert_eq!(graph.get_neighbors(n1).len(), 1);
}
}
pub mod graph {
use super::*;
#[derive(Debug, Clone)]
pub struct Node<W, N> {
pub data: N,
pub neighbors: Vec<(NodeId, W)>,
pub attributes: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct Edge<W, E> {
pub from: NodeId,
pub to: NodeId,
pub weight: W,
pub data: E,
pub attributes: HashMap<String, String>,
}
}