use crate::error::{GnnError, GnnResult};
use crate::graph::csr::CsrGraph;
/// Graph stored as a coordinate (COO) edge list.
///
/// Edge `i` runs from `src[i]` to `dst[i]` and carries weight `weight[i]`.
/// The constructors guarantee that the three edge vectors have equal length
/// and that every endpoint lies in `0..n_nodes`.
#[derive(Debug, Clone)]
pub struct CooGraph {
    /// Number of nodes; valid node indices are `0..n_nodes`.
    n_nodes: usize,
    /// Source node of each edge.
    src: Vec<usize>,
    /// Destination node of each edge.
    dst: Vec<usize>,
    /// Weight of each edge (`1.0` when built via `CooGraph::new`).
    weight: Vec<f32>,
}
impl CooGraph {
    /// Builds an unweighted graph with `n_nodes` nodes and one edge per
    /// `(src[i], dst[i])` pair; every edge receives weight `1.0`.
    ///
    /// # Errors
    ///
    /// * [`GnnError::EmptyGraph`] if `n_nodes == 0`.
    /// * [`GnnError::DimensionMismatch`] if `src.len() != dst.len()`
    ///   (`expected` is `src.len()`, `got` is `dst.len()`).
    /// * [`GnnError::NodeIndexOutOfRange`] for the first endpoint `>= n_nodes`;
    ///   all of `src` is scanned before any entry of `dst`.
    pub fn new(n_nodes: usize, src: Vec<usize>, dst: Vec<usize>) -> GnnResult<Self> {
        Self::validate_edges(n_nodes, &src, &dst)?;
        let weight = vec![1.0_f32; src.len()];
        Ok(Self {
            n_nodes,
            src,
            dst,
            weight,
        })
    }

    /// Builds a weighted graph; `weight[i]` is the weight of edge `src[i] -> dst[i]`.
    ///
    /// # Errors
    ///
    /// Every error [`CooGraph::new`] can report (checked first), plus
    /// [`GnnError::EdgeFeatureMismatch`] with `(n_edges, weight.len())` when
    /// `weight` does not have one entry per edge.
    pub fn with_weights(
        n_nodes: usize,
        src: Vec<usize>,
        dst: Vec<usize>,
        weight: Vec<f32>,
    ) -> GnnResult<Self> {
        // Validate directly rather than via `new`, which would allocate a
        // throwaway all-ones weight vector.
        Self::validate_edges(n_nodes, &src, &dst)?;
        if weight.len() != src.len() {
            return Err(GnnError::EdgeFeatureMismatch(src.len(), weight.len()));
        }
        Ok(Self {
            n_nodes,
            src,
            dst,
            weight,
        })
    }

    /// Checks the invariants shared by both constructors: a non-empty node
    /// set, equal-length endpoint lists, and every endpoint in `0..n_nodes`.
    fn validate_edges(n_nodes: usize, src: &[usize], dst: &[usize]) -> GnnResult<()> {
        if n_nodes == 0 {
            return Err(GnnError::EmptyGraph);
        }
        if src.len() != dst.len() {
            return Err(GnnError::DimensionMismatch {
                expected: src.len(),
                got: dst.len(),
            });
        }
        // Chaining src before dst keeps the reported index identical to a
        // "check all sources, then all destinations" scan.
        if let Some(&idx) = src.iter().chain(dst).find(|&&v| v >= n_nodes) {
            return Err(GnnError::NodeIndexOutOfRange { idx, n_nodes });
        }
        Ok(())
    }

    /// Converts the edge list into a [`CsrGraph`] with one row per source node.
    ///
    /// Within each row, neighbours are ordered by destination. Parallel edges
    /// (same `(src, dst)`) keep their insertion order, so their weights appear
    /// in a predictable order.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `CsrGraph::with_weights`.
    pub fn to_csr(&self) -> GnnResult<CsrGraph> {
        let n = self.n_nodes;

        // Stable sort: matches `sort_by_src` and keeps parallel edges in
        // insertion order instead of an unspecified one.
        let mut order: Vec<usize> = (0..self.src.len()).collect();
        order.sort_by_key(|&i| (self.src[i], self.dst[i]));

        // Count out-degrees into row_ptr[s + 1], then prefix-sum into offsets.
        // Indexing is in range because the constructors enforce s < n_nodes.
        let mut row_ptr = vec![0usize; n + 1];
        for &s in &self.src {
            row_ptr[s + 1] += 1;
        }
        let mut running = 0usize;
        for slot in row_ptr.iter_mut() {
            running += *slot;
            *slot = running;
        }

        let col_idx: Vec<usize> = order.iter().map(|&i| self.dst[i]).collect();
        let weights: Vec<f32> = order.iter().map(|&i| self.weight[i]).collect();
        CsrGraph::with_weights(n, row_ptr, col_idx, weights)
    }

    /// Number of edges.
    #[inline]
    pub fn n_edges(&self) -> usize {
        self.src.len()
    }

    /// Number of nodes.
    #[inline]
    pub fn n_nodes(&self) -> usize {
        self.n_nodes
    }

    /// Source node of every edge.
    #[inline]
    pub fn src(&self) -> &[usize] {
        &self.src
    }

    /// Destination node of every edge.
    #[inline]
    pub fn dst(&self) -> &[usize] {
        &self.dst
    }

    /// Weight of every edge.
    #[inline]
    pub fn weight(&self) -> &[f32] {
        &self.weight
    }

    /// Reorders the edges in place by `(src, dst)`.
    ///
    /// The sort is stable, so parallel edges keep their relative order (and
    /// their weights move with them).
    pub fn sort_by_src(&mut self) {
        let mut order: Vec<usize> = (0..self.src.len()).collect();
        order.sort_by_key(|&i| (self.src[i], self.dst[i]));
        // Rebuild each vector from the permutation; no full pre-clone needed.
        self.src = order.iter().map(|&i| self.src[i]).collect();
        self.dst = order.iter().map(|&i| self.dst[i]).collect();
        self.weight = order.iter().map(|&i| self.weight[i]).collect();
    }

    /// Adds the reverse of every edge whose reverse is not already present,
    /// then sorts the edges by `(src, dst)`.
    ///
    /// Only edges that existed before the call count as existing reverses. A
    /// reverse edge copies the weight of the edge it mirrors. Self-loops are
    /// their own reverse and are never duplicated. Each parallel edge lacking a
    /// reverse gets its own reverse copy.
    pub fn make_undirected(&mut self) {
        // O(1) lookups instead of rescanning all edges for every edge.
        let existing: std::collections::HashSet<(usize, usize)> = self
            .src
            .iter()
            .copied()
            .zip(self.dst.iter().copied())
            .collect();

        let mut missing: Vec<(usize, usize, f32)> = Vec::new();
        for ((&s, &d), &w) in self.src.iter().zip(&self.dst).zip(&self.weight) {
            if !existing.contains(&(d, s)) {
                missing.push((d, s, w));
            }
        }

        self.src.reserve(missing.len());
        self.dst.reserve(missing.len());
        self.weight.reserve(missing.len());
        for (s, d, w) in missing {
            self.src.push(s);
            self.dst.push(d);
            self.weight.push(w);
        }
        self.sort_by_src();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the graph's edges as `(src, dst)` pairs.
    fn edge_pairs(g: &CooGraph) -> Vec<(usize, usize)> {
        g.src().iter().copied().zip(g.dst().iter().copied()).collect()
    }

    #[test]
    fn empty_graph_error() {
        let err = CooGraph::new(0, vec![], vec![]);
        assert_eq!(err.unwrap_err(), GnnError::EmptyGraph);
    }

    #[test]
    fn basic_construction() {
        let g = CooGraph::new(4, vec![0, 1, 2], vec![1, 2, 3])
            .expect("test invariant: value must be valid");
        assert_eq!(g.n_nodes(), 4);
        assert_eq!(g.n_edges(), 3);
    }

    #[test]
    fn out_of_range_src_error() {
        let err = CooGraph::new(3, vec![5], vec![1]);
        assert!(matches!(
            err,
            Err(GnnError::NodeIndexOutOfRange { idx: 5, n_nodes: 3 })
        ));
    }

    #[test]
    fn out_of_range_dst_error() {
        let err = CooGraph::new(3, vec![0], vec![10]);
        assert!(matches!(
            err,
            Err(GnnError::NodeIndexOutOfRange { idx: 10, n_nodes: 3 })
        ));
    }

    #[test]
    fn to_csr_roundtrip() {
        let src = vec![0usize, 1, 2, 0];
        let dst = vec![1usize, 2, 0, 2];
        let coo = CooGraph::new(3, src, dst).expect("test invariant: value must be valid");
        let csr = coo.to_csr().expect("test invariant: value must be valid");
        assert_eq!(csr.n_nodes(), 3);
        assert_eq!(csr.n_edges(), 4);
        assert_eq!(
            csr.degree(0).expect("test invariant: value must be valid"),
            2
        );
        assert_eq!(
            csr.degree(1).expect("test invariant: value must be valid"),
            1
        );
        assert_eq!(
            csr.degree(2).expect("test invariant: value must be valid"),
            1
        );
    }

    #[test]
    fn to_csr_sorted_neighbors() {
        let src = vec![0usize, 0];
        let dst = vec![2usize, 1];
        let coo = CooGraph::new(3, src, dst).expect("test invariant: value must be valid");
        let csr = coo.to_csr().expect("test invariant: value must be valid");
        let nb = csr
            .neighbors(0)
            .expect("test invariant: value must be valid");
        assert_eq!(nb, &[1, 2]);
    }

    #[test]
    fn sort_by_src_orders_edges() {
        let mut g = CooGraph::new(4, vec![2, 0, 1], vec![3, 1, 2])
            .expect("test invariant: value must be valid");
        g.sort_by_src();
        assert_eq!(g.src(), &[0, 1, 2]);
        assert_eq!(g.dst(), &[1, 2, 3]);
    }

    #[test]
    fn make_undirected_adds_reverses() {
        let mut g =
            CooGraph::new(3, vec![0, 1], vec![1, 2]).expect("test invariant: value must be valid");
        g.make_undirected();
        assert_eq!(g.n_edges(), 4);
        let pairs = edge_pairs(&g);
        assert!(pairs.contains(&(1, 0)));
        assert!(pairs.contains(&(2, 1)));
    }

    #[test]
    fn make_undirected_no_duplicate_reverses() {
        let mut g =
            CooGraph::new(3, vec![0, 1], vec![1, 0]).expect("test invariant: value must be valid");
        let before = g.n_edges();
        g.make_undirected();
        assert_eq!(g.n_edges(), before);
    }

    #[test]
    fn make_undirected_copies_weights_and_sorts() {
        let mut g = CooGraph::with_weights(3, vec![0, 1], vec![1, 2], vec![0.5, 2.5])
            .expect("test invariant: value must be valid");
        g.make_undirected();
        assert_eq!(edge_pairs(&g), vec![(0, 1), (1, 0), (1, 2), (2, 1)]);
        let expected = [0.5_f32, 0.5, 2.5, 2.5];
        assert_eq!(g.weight().len(), expected.len());
        for (got, want) in g.weight().iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn make_undirected_keeps_self_loops_single() {
        let mut g =
            CooGraph::new(2, vec![0, 1], vec![0, 0]).expect("test invariant: value must be valid");
        g.make_undirected();
        assert_eq!(edge_pairs(&g), vec![(0, 0), (0, 1), (1, 0)]);
    }

    #[test]
    fn with_weights_correct() {
        let g = CooGraph::with_weights(3, vec![0, 1], vec![1, 2], vec![0.5, 1.5])
            .expect("test invariant: value must be valid");
        assert!((g.weight()[0] - 0.5).abs() < 1e-6);
        assert!((g.weight()[1] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn with_weights_length_mismatch_error() {
        let err = CooGraph::with_weights(3, vec![0, 1], vec![1, 2], vec![1.0]);
        assert_eq!(err.unwrap_err(), GnnError::EdgeFeatureMismatch(2, 1));
    }

    #[test]
    fn coo_to_csr_preserves_weights() {
        let coo = CooGraph::with_weights(3, vec![0, 1], vec![1, 2], vec![2.0, 3.0])
            .expect("test invariant: value must be valid");
        let csr = coo.to_csr().expect("test invariant: value must be valid");
        let w0 = csr
            .edge_weights(0)
            .expect("test invariant: value must be valid");
        assert!((w0[0] - 2.0).abs() < 1e-6);
        let w1 = csr
            .edge_weights(1)
            .expect("test invariant: value must be valid");
        assert!((w1[0] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn length_mismatch_error() {
        let err = CooGraph::new(3, vec![0, 1], vec![1]);
        assert!(matches!(err, Err(GnnError::DimensionMismatch { .. })));
    }
}