use std::collections::HashMap;
use scirs2_core::ndarray::Array2;
use crate::error::{GraphError, Result};
/// Handle identifying a node of an [`AttributedGraph`].
///
/// The wrapped value is the node's position in insertion order: the first
/// node added to a graph is `NodeId(0)`, the second `NodeId(1)`, and so on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
impl std::fmt::Display for NodeId {
    /// Renders the handle as `NodeId(<index>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NodeId(index) = *self;
        write!(f, "NodeId({index})")
    }
}
#[derive(Debug, Clone)]
pub struct AttributedGraph<N, E> {
nodes: Vec<N>,
adj: Vec<Vec<(usize, usize)>>,
edges: Vec<(usize, usize, E)>,
}
impl<N, E> Default for AttributedGraph<N, E> {
fn default() -> Self {
Self::new()
}
}
impl<N, E> AttributedGraph<N, E> {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
adj: Vec::new(),
edges: Vec::new(),
}
}
pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self {
Self {
nodes: Vec::with_capacity(node_capacity),
adj: Vec::with_capacity(node_capacity),
edges: Vec::with_capacity(edge_capacity),
}
}
pub fn add_node(&mut self, data: N) -> NodeId {
let id = self.nodes.len();
self.nodes.push(data);
self.adj.push(Vec::new());
NodeId(id)
}
pub fn add_edge(&mut self, src: NodeId, dst: NodeId, edge_data: E) -> Result<()> {
self.validate_node(src)?;
self.validate_node(dst)?;
let edge_idx = self.edges.len();
self.edges.push((src.0, dst.0, edge_data));
self.adj[src.0].push((dst.0, edge_idx));
Ok(())
}
#[inline]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn edge_count(&self) -> usize {
self.edges.len()
}
pub fn node_data(&self, id: NodeId) -> Option<&N> {
self.nodes.get(id.0)
}
pub fn node_data_mut(&mut self, id: NodeId) -> Option<&mut N> {
self.nodes.get_mut(id.0)
}
pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &N)> {
self.nodes
.iter()
.enumerate()
.map(|(i, n)| (NodeId(i), n))
}
pub fn edges_iter(&self) -> impl Iterator<Item = (NodeId, NodeId, &E)> {
self.edges
.iter()
.map(|(s, d, e)| (NodeId(*s), NodeId(*d), e))
}
pub fn out_neighbors(&self, node: NodeId) -> Result<Vec<(NodeId, &E)>> {
self.validate_node(node)?;
let result = self.adj[node.0]
.iter()
.map(|&(dst, eidx)| (NodeId(dst), &self.edges[eidx].2))
.collect();
Ok(result)
}
pub fn has_edge(&self, src: NodeId, dst: NodeId) -> bool {
if src.0 >= self.nodes.len() || dst.0 >= self.nodes.len() {
return false;
}
self.adj[src.0].iter().any(|&(d, _)| d == dst.0)
}
pub fn edge_data(&self, src: NodeId, dst: NodeId) -> Option<&E> {
if src.0 >= self.nodes.len() {
return None;
}
self.adj[src.0]
.iter()
.find(|&&(d, _)| d == dst.0)
.map(|&(_, eidx)| &self.edges[eidx].2)
}
fn validate_node(&self, id: NodeId) -> Result<()> {
if id.0 < self.nodes.len() {
Ok(())
} else {
Err(GraphError::node_not_found_with_context(
id.0,
self.nodes.len(),
"AttributedGraph node validation",
))
}
}
}
/// Incremental builder for [`AttributedGraph`] that defers edge errors.
///
/// Invalid edges passed to [`AttributedGraphBuilder::edge`] are recorded
/// instead of failing immediately. [`AttributedGraphBuilder::build`] reports
/// the first recorded error; [`AttributedGraphBuilder::build_unchecked`]
/// ignores them and returns the graph of successfully added edges.
#[derive(Debug, Clone)]
pub struct AttributedGraphBuilder<N, E> {
    /// Graph under construction; holds only the edges that were accepted.
    graph: AttributedGraph<N, E>,
    /// Errors from rejected `edge` calls, in call order.
    errors: Vec<GraphError>,
}
impl<N, E> Default for AttributedGraphBuilder<N, E> {
    /// Equivalent to [`AttributedGraphBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<N, E> AttributedGraphBuilder<N, E> {
    /// Creates an empty builder.
    pub fn new() -> Self {
        // Zero capacities do not allocate, so this matches a plain `new()`.
        Self::with_capacity(0, 0)
    }

    /// Creates an empty builder whose graph is pre-sized for
    /// `node_capacity` nodes and `edge_capacity` edges.
    pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self {
        Self {
            errors: Vec::new(),
            graph: AttributedGraph::with_capacity(node_capacity, edge_capacity),
        }
    }

    /// Adds a node carrying `data` and returns its handle.
    pub fn node(&mut self, data: N) -> NodeId {
        self.graph.add_node(data)
    }

    /// Adds an edge `src -> dst` carrying `data`.
    ///
    /// If an endpoint is invalid, the error is recorded for [`build`](Self::build)
    /// instead of being returned here.
    pub fn edge(&mut self, src: NodeId, dst: NodeId, data: E) -> &mut Self {
        self.errors.extend(self.graph.add_edge(src, dst, data).err());
        self
    }

    /// Finishes construction.
    ///
    /// # Errors
    ///
    /// Returns the first error recorded by [`edge`](Self::edge), if any.
    pub fn build(self) -> Result<AttributedGraph<N, E>> {
        let Self { graph, errors } = self;
        errors.into_iter().next().map_or(Ok(graph), Err)
    }

    /// Returns the graph built so far, discarding any recorded errors.
    pub fn build_unchecked(self) -> AttributedGraph<N, E> {
        let Self { graph, .. } = self;
        graph
    }
}
/// One successor of a node, as produced by [`attributed_neighbors`].
///
/// It holds the neighbour's id, its borrowed node data, and the borrowed data
/// of the connecting edge.
#[derive(Debug)]
pub struct NeighborInfo<'a, N, E> {
    /// Id of the neighbouring (destination) node.
    pub id: NodeId,
    /// Borrowed data of the neighbouring node.
    pub node_data: &'a N,
    /// Borrowed data of the edge leading to the neighbour.
    pub edge_data: &'a E,
}
/// Stacks per-node feature vectors into an `(node_count, dim)` matrix.
///
/// `feature_fn` is called once for every node, in node-id order. Row `i` of
/// the result holds the features of `NodeId(i)`.
///
/// # Errors
///
/// Returns an invalid-parameter error if:
/// * the graph has no nodes;
/// * the first feature vector is empty; or
/// * the feature vectors do not all have the same length.
pub fn node_feature_matrix<N, E, F>(
    graph: &AttributedGraph<N, E>,
    feature_fn: F,
) -> Result<Array2<f64>>
where
    F: Fn(&N) -> Vec<f64>,
{
    if graph.node_count() == 0 {
        return Err(GraphError::invalid_parameter(
            "graph",
            "empty graph",
            "at least one node",
        ));
    }
    // Evaluate every node up front, so `feature_fn` runs for all nodes before
    // any shape validation happens.
    let rows: Vec<Vec<f64>> = graph.nodes.iter().map(&feature_fn).collect();
    let dim = rows[0].len();
    if dim == 0 {
        return Err(GraphError::invalid_parameter(
            "feature_fn",
            "zero-length feature vector",
            "non-empty feature vector",
        ));
    }
    if let Some((i, fv)) = rows.iter().enumerate().find(|(_, fv)| fv.len() != dim) {
        return Err(GraphError::InvalidParameter {
            param: "feature_fn".to_string(),
            value: format!("node {i} returned dim={}", fv.len()),
            expected: format!("uniform dim={dim}"),
            context: "node_feature_matrix".to_string(),
        });
    }
    Ok(Array2::from_shape_fn((rows.len(), dim), |(r, c)| rows[r][c]))
}
/// Stacks per-edge feature vectors into an `(edge_count, dim)` matrix.
///
/// `feature_fn` is called once for every edge, in edge-insertion order. Row
/// `i` of the result holds the features of the `i`-th inserted edge.
///
/// # Errors
///
/// Returns an invalid-parameter error if:
/// * the graph has no edges;
/// * the first feature vector is empty; or
/// * the feature vectors do not all have the same length.
pub fn edge_feature_matrix<N, E, F>(
    graph: &AttributedGraph<N, E>,
    feature_fn: F,
) -> Result<Array2<f64>>
where
    F: Fn(&E) -> Vec<f64>,
{
    if graph.edge_count() == 0 {
        return Err(GraphError::invalid_parameter(
            "graph",
            "graph has no edges",
            "at least one edge",
        ));
    }
    // Evaluate every edge up front, so `feature_fn` runs for all edges before
    // any shape validation happens.
    let rows: Vec<Vec<f64>> = graph.edges.iter().map(|edge| feature_fn(&edge.2)).collect();
    let dim = rows[0].len();
    if dim == 0 {
        return Err(GraphError::invalid_parameter(
            "feature_fn",
            "zero-length feature vector",
            "non-empty feature vector",
        ));
    }
    if let Some((i, fv)) = rows.iter().enumerate().find(|(_, fv)| fv.len() != dim) {
        return Err(GraphError::InvalidParameter {
            param: "feature_fn".to_string(),
            value: format!("edge {i} returned dim={}", fv.len()),
            expected: format!("uniform dim={dim}"),
            context: "edge_feature_matrix".to_string(),
        });
    }
    Ok(Array2::from_shape_fn((rows.len(), dim), |(r, c)| rows[r][c]))
}
/// Lists the successors of `node` in edge-insertion order.
///
/// Each entry carries the successor's id, its node data, and the data of the
/// connecting edge.
///
/// # Errors
///
/// Returns a node-not-found error if `node` is not in `graph`.
pub fn attributed_neighbors<'a, N, E>(
    node: NodeId,
    graph: &'a AttributedGraph<N, E>,
) -> Result<Vec<NeighborInfo<'a, N, E>>> {
    graph.validate_node(node)?;
    let outgoing = &graph.adj[node.0];
    let mut infos = Vec::with_capacity(outgoing.len());
    for &(target, edge_idx) in outgoing {
        infos.push(NeighborInfo {
            id: NodeId(target),
            node_data: &graph.nodes[target],
            edge_data: &graph.edges[edge_idx].2,
        });
    }
    Ok(infos)
}
/// Priority-queue entry for [`dijkstra_attributed`]: a tentative distance
/// `cost` to `node`.
///
/// The ordering is reversed on `cost`, so that `std::collections::BinaryHeap`
/// (a max-heap) pops the *smallest* cost first.
///
/// `f64::total_cmp` gives a genuine total order (NaN included), and ties are
/// broken on `node`. Equality is defined via `cmp`, so `PartialEq`, `Eq`,
/// `PartialOrd` and `Ord` all agree, as the standard library requires. The
/// previous derived `PartialEq` plus a NaN-as-`Equal` `Ord` violated that
/// contract.
#[derive(Debug, Clone)]
struct DijkstraEntry {
    cost: f64,
    node: usize,
}
impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for DijkstraEntry {}
impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed: a lower cost compares as "greater", so it is popped first.
        // Among equal costs, the lower node index is popped first.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node.cmp(&self.node))
    }
}
pub fn dijkstra_attributed<N, E, F>(
graph: &AttributedGraph<N, E>,
source: NodeId,
cost_fn: F,
) -> Result<HashMap<NodeId, f64>>
where
F: Fn(&E) -> f64,
{
use std::collections::BinaryHeap;
graph.validate_node(source)?;
let n = graph.node_count();
let mut dist = vec![f64::INFINITY; n];
dist[source.0] = 0.0;
let mut heap = BinaryHeap::new();
heap.push(DijkstraEntry {
cost: 0.0,
node: source.0,
});
while let Some(DijkstraEntry { cost, node }) = heap.pop() {
if cost > dist[node] {
continue;
}
for &(dst, eidx) in &graph.adj[node] {
let edge_cost = cost_fn(&graph.edges[eidx].2);
let new_cost = cost + edge_cost;
if new_cost < dist[dst] {
dist[dst] = new_cost;
heap.push(DijkstraEntry {
cost: new_cost,
node: dst,
});
}
}
}
let result = dist
.into_iter()
.enumerate()
.filter(|(_, d)| d.is_finite())
.map(|(i, d)| (NodeId(i), d))
.collect();
Ok(result)
}
/// Counts incoming edges per node; entry `i` is the in-degree of `NodeId(i)`.
///
/// Every stored edge contributes one, including parallel edges and self-loops.
pub fn in_degrees<N, E>(graph: &AttributedGraph<N, E>) -> Vec<usize> {
    let mut counts = vec![0usize; graph.node_count()];
    graph.edges.iter().for_each(|edge| counts[edge.1] += 1);
    counts
}
/// Counts outgoing edges per node; entry `i` is the out-degree of `NodeId(i)`.
pub fn out_degrees<N, E>(graph: &AttributedGraph<N, E>) -> Vec<usize> {
    let mut counts = Vec::with_capacity(graph.adj.len());
    for outgoing in &graph.adj {
        counts.push(outgoing.len());
    }
    counts
}
/// Returns the ids of all nodes whose data satisfies `predicate`, in
/// ascending id order.
///
/// `predicate` is evaluated exactly once per node.
pub fn filter_nodes<N, E, P>(graph: &AttributedGraph<N, E>, predicate: P) -> Vec<NodeId>
where
    P: Fn(&N) -> bool,
{
    graph
        .nodes
        .iter()
        .enumerate()
        .filter(|&(_, data)| predicate(data))
        .map(|(idx, _)| NodeId(idx))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct Rel {
        weight: f64,
    }

    /// Builds Alice -> Bob (1.0), Bob -> Charlie (2.0), Alice -> Charlie (5.0).
    fn make_graph() -> AttributedGraph<Person, Rel> {
        let mut g = AttributedGraph::new();
        let alice = g.add_node(Person {
            name: "Alice".into(),
            age: 30,
        });
        let bob = g.add_node(Person {
            name: "Bob".into(),
            age: 25,
        });
        let charlie = g.add_node(Person {
            name: "Charlie".into(),
            age: 35,
        });
        g.add_edge(alice, bob, Rel { weight: 1.0 }).unwrap();
        g.add_edge(bob, charlie, Rel { weight: 2.0 }).unwrap();
        g.add_edge(alice, charlie, Rel { weight: 5.0 }).unwrap();
        g
    }

    #[test]
    fn test_basic_construction() {
        let g = make_graph();
        assert_eq!(g.node_count(), 3);
        assert_eq!(g.edge_count(), 3);
    }

    #[test]
    fn test_node_id_display() {
        assert_eq!(NodeId(3).to_string(), "NodeId(3)");
    }

    #[test]
    fn test_node_data_access() {
        let g = make_graph();
        let alice = NodeId(0);
        let data = g.node_data(alice).unwrap();
        assert_eq!(data.name, "Alice");
        assert_eq!(data.age, 30);
    }

    #[test]
    fn test_edge_data_access() {
        let g = make_graph();
        let alice = NodeId(0);
        let bob = NodeId(1);
        let ed = g.edge_data(alice, bob).unwrap();
        assert!((ed.weight - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_edge_data_missing() {
        let g = make_graph();
        assert!(g.edge_data(NodeId(1), NodeId(0)).is_none());
        assert!(g.edge_data(NodeId(99), NodeId(0)).is_none());
        assert!(g.edge_data(NodeId(0), NodeId(99)).is_none());
    }

    #[test]
    fn test_has_edge() {
        let g = make_graph();
        assert!(g.has_edge(NodeId(0), NodeId(1)));
        assert!(!g.has_edge(NodeId(1), NodeId(0)));
        assert!(!g.has_edge(NodeId(2), NodeId(0)));
    }

    #[test]
    fn test_out_neighbors() {
        let g = make_graph();
        let nbrs = g.out_neighbors(NodeId(0)).unwrap();
        let ids: Vec<NodeId> = nbrs.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec![NodeId(1), NodeId(2)]);
        assert!((nbrs[1].1.weight - 5.0).abs() < 1e-12);
        assert!(g.out_neighbors(NodeId(99)).is_err());
    }

    #[test]
    fn test_invalid_node_add_edge() {
        let mut g = AttributedGraph::<i32, ()>::new();
        g.add_node(1);
        let err = g.add_edge(NodeId(0), NodeId(5), ());
        assert!(err.is_err());
    }

    #[test]
    fn test_builder() {
        let mut b = AttributedGraphBuilder::<i32, f64>::new();
        let a = b.node(1);
        let bb = b.node(2);
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant`
        // rejects literals that approximate PI.
        b.edge(a, bb, 2.5);
        let g = b.build().unwrap();
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_builder_bad_edge_deferred() {
        let mut b = AttributedGraphBuilder::<i32, f64>::new();
        b.node(1);
        b.edge(NodeId(0), NodeId(99), 1.0);
        assert!(b.build().is_err());
    }

    #[test]
    fn test_builder_unchecked_keeps_valid_edges() {
        let mut b = AttributedGraphBuilder::<i32, f64>::new();
        let a = b.node(1);
        let c = b.node(2);
        b.edge(a, c, 1.5).edge(a, NodeId(42), 0.5);
        let g = b.build_unchecked();
        assert_eq!(g.node_count(), 2);
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_node_feature_matrix() {
        let g = make_graph();
        let mat = node_feature_matrix(&g, |p| vec![p.age as f64]).unwrap();
        assert_eq!(mat.shape(), &[3, 1]);
        assert!((mat[[0, 0]] - 30.0).abs() < 1e-12);
        assert!((mat[[1, 0]] - 25.0).abs() < 1e-12);
        assert!((mat[[2, 0]] - 35.0).abs() < 1e-12);
    }

    #[test]
    fn test_edge_feature_matrix() {
        let g = make_graph();
        let mat = edge_feature_matrix(&g, |r| vec![r.weight]).unwrap();
        assert_eq!(mat.shape(), &[3, 1]);
        assert!((mat[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((mat[[1, 0]] - 2.0).abs() < 1e-12);
        assert!((mat[[2, 0]] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_node_feature_matrix_multi_dim() {
        let g = make_graph();
        let mat = node_feature_matrix(&g, |p| vec![p.age as f64, p.name.len() as f64]).unwrap();
        assert_eq!(mat.shape(), &[3, 2]);
    }

    #[test]
    fn test_node_feature_matrix_empty_graph() {
        let g = AttributedGraph::<i32, ()>::new();
        let result = node_feature_matrix(&g, |v| vec![*v as f64]);
        assert!(result.is_err());
    }

    #[test]
    fn test_edge_feature_matrix_no_edges() {
        let mut g = AttributedGraph::<i32, f64>::new();
        g.add_node(1);
        let result = edge_feature_matrix(&g, |v| vec![*v]);
        assert!(result.is_err());
    }

    #[test]
    fn test_attributed_neighbors() {
        let g = make_graph();
        let nbrs = attributed_neighbors(NodeId(0), &g).unwrap();
        assert_eq!(nbrs.len(), 2);
        assert_eq!(nbrs[0].node_data.name, "Bob");
        assert!((nbrs[0].edge_data.weight - 1.0).abs() < 1e-12);
        assert_eq!(nbrs[1].node_data.name, "Charlie");
        assert!((nbrs[1].edge_data.weight - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_attributed_neighbors_unknown_node() {
        let g = make_graph();
        assert!(attributed_neighbors(NodeId(99), &g).is_err());
    }

    #[test]
    fn test_dijkstra_simple() {
        let g = make_graph();
        let dist = dijkstra_attributed(&g, NodeId(0), |r| r.weight).unwrap();
        assert!((dist[&NodeId(0)] - 0.0).abs() < 1e-12);
        assert!((dist[&NodeId(1)] - 1.0).abs() < 1e-12);
        // Alice -> Bob -> Charlie (1 + 2) beats the direct edge (5).
        assert!((dist[&NodeId(2)] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_dijkstra_unreachable() {
        let mut g = AttributedGraph::<(), f64>::new();
        g.add_node(());
        g.add_node(());
        let dist = dijkstra_attributed(&g, NodeId(0), |e| *e).unwrap();
        assert!(dist.contains_key(&NodeId(0)));
        assert!(!dist.contains_key(&NodeId(1)));
    }

    #[test]
    fn test_dijkstra_invalid_source() {
        let g = make_graph();
        assert!(dijkstra_attributed(&g, NodeId(99), |r| r.weight).is_err());
    }

    #[test]
    fn test_dijkstra_single_node() {
        let mut g = AttributedGraph::<(), ()>::new();
        g.add_node(());
        let dist = dijkstra_attributed(&g, NodeId(0), |_| 1.0).unwrap();
        assert_eq!(dist.len(), 1);
        assert_eq!(dist[&NodeId(0)], 0.0);
    }

    #[test]
    fn test_in_out_degrees() {
        let g = make_graph();
        let out = out_degrees(&g);
        assert_eq!(out[0], 2);
        assert_eq!(out[1], 1);
        assert_eq!(out[2], 0);
        let inn = in_degrees(&g);
        assert_eq!(inn[0], 0);
        assert_eq!(inn[1], 1);
        assert_eq!(inn[2], 2);
    }

    #[test]
    fn test_filter_nodes() {
        let g = make_graph();
        let young = filter_nodes(&g, |p| p.age < 31);
        assert_eq!(young.len(), 2);
    }

    #[test]
    fn test_nodes_iterator() {
        let g = make_graph();
        let names: Vec<&str> = g.nodes().map(|(_, p)| p.name.as_str()).collect();
        assert_eq!(names, vec!["Alice", "Bob", "Charlie"]);
    }

    #[test]
    fn test_edges_iterator() {
        let g = make_graph();
        let edges: Vec<(NodeId, NodeId, f64)> = g
            .edges_iter()
            .map(|(s, d, e)| (s, d, e.weight))
            .collect();
        assert_eq!(edges.len(), 3);
    }
}