use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::numeric::{Float, SparseElement};
use std::cmp::Ordering;
use std::fmt::Debug;
pub mod connected_components;
pub mod laplacian;
pub mod minimum_spanning_tree;
pub mod shortest_path;
pub mod traversal;
pub mod centrality;
pub mod community_detection;
pub mod max_flow;
pub use centrality::*;
pub use community_detection::*;
pub use connected_components::*;
pub use laplacian::*;
pub use max_flow::*;
pub use minimum_spanning_tree::*;
pub use shortest_path::*;
pub use traversal::*;
/// Selects how a graph's adjacency matrix is interpreted.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GraphMode {
    /// Edges are one-way: entry `(i, j)` connects `i` to `j` only.
    Directed,
    /// Edges are two-way: entry `(i, j)` connects `i` and `j` in both directions.
    Undirected,
}
/// Priority-queue entry pairing a vertex index with its distance.
///
/// The `Ord` impl is reversed on `distance`, so a max-heap such as
/// `std::collections::BinaryHeap` pops the entry with the *smallest* distance first.
/// Ties on `distance` are broken by `node` (the smaller index compares as greater, so it
/// is popped first). This keeps `Ord` consistent with `PartialEq`/`Eq`, as the standard
/// library requires: two entries are equal exactly when `cmp` returns `Equal`.
///
/// Distances that cannot be compared (NaN) are treated as equal to any other distance.
/// The ordering is then still reflexive but not transitive, so callers should keep NaN
/// out of the queue.
#[derive(Debug, Clone)]
struct PriorityQueueNode<T> {
    /// Priority key; smaller values are popped first from a max-heap.
    distance: T,
    /// Index of the vertex this entry refers to.
    node: usize,
}

impl<T> PartialEq for PriorityQueueNode<T>
where
    T: PartialOrd,
{
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality can never disagree with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl<T> Eq for PriorityQueueNode<T> where T: PartialOrd {}

impl<T> PartialOrd for PriorityQueueNode<T>
where
    T: PartialOrd,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for PriorityQueueNode<T>
where
    T: PartialOrd,
{
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed comparisons turn a max-heap into a min-heap on (distance, node).
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or(Ordering::Equal)
            .then_with(|| other.node.cmp(&self.node))
    }
}
#[allow(dead_code)]
pub fn validate_graph<T, S>(matrix: &S, directed: bool) -> SparseResult<()>
where
T: Float + SparseElement + Debug + Copy + 'static,
S: SparseArray<T>,
{
let (rows, cols) = matrix.shape();
if rows != cols {
return Err(SparseError::ValueError(
"Graph _matrix must be square".to_string(),
));
}
let (row_indices, col_indices, values) = matrix.find();
for &value in values.iter() {
if value < T::sparse_zero() {
return Err(SparseError::ValueError(
"Negative edge weights not supported".to_string(),
));
}
}
if !directed {
for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
if row != col {
let weight = values[i];
let reverse_weight = matrix.get(col, row);
if (weight - reverse_weight).abs() > T::from(1e-10).expect("Operation failed") {
return Err(SparseError::ValueError(
"Undirected graph _matrix must be symmetric".to_string(),
));
}
}
}
}
Ok(())
}
/// Converts a sparse adjacency matrix into an adjacency list.
///
/// Entry `k` of the result lists `(neighbor, weight)` pairs for vertex `k`; stored zero
/// entries are skipped.
///
/// When `directed` is `false`, an edge stored in only one direction `(i, j)` is mirrored
/// into vertex `j`'s list with the same weight. If both directions are stored (as
/// non-zero), each keeps its own weight and nothing is duplicated.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for API stability.
#[allow(dead_code)]
pub fn to_adjacency_list<T, S>(matrix: &S, directed: bool) -> SparseResult<Vec<Vec<(usize, T)>>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    use std::collections::HashSet;

    let (n, _) = matrix.shape();
    let mut adj_list = vec![Vec::new(); n];
    let (row_indices, col_indices, values) = matrix.find();

    // The set holds the positions of stored non-zero entries. It is built once so the
    // mirror check below is O(1) per edge instead of a rescan of every entry. Explicitly
    // stored zeros are excluded: they are not edges, so they must not suppress mirroring.
    let stored_edges: HashSet<(usize, usize)> = if directed {
        HashSet::new()
    } else {
        row_indices
            .iter()
            .zip(col_indices.iter())
            .zip(values.iter())
            .filter(|(_, w)| !SparseElement::is_zero(*w))
            .map(|((&r, &c), _)| (r, c))
            .collect()
    };

    for ((&row, &col), &weight) in row_indices
        .iter()
        .zip(col_indices.iter())
        .zip(values.iter())
    {
        if SparseElement::is_zero(&weight) {
            continue;
        }
        adj_list[row].push((col, weight));
        if !directed && row != col && !stored_edges.contains(&(col, row)) {
            adj_list[col].push((row, weight));
        }
    }

    Ok(adj_list)
}
/// Returns the number of vertices of the graph, i.e. the row count of `matrix`.
#[allow(dead_code)]
pub fn num_vertices<T, S>(matrix: &S) -> usize
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (vertex_count, _) = matrix.shape();
    vertex_count
}
/// Counts the edges of the graph described by `matrix`.
///
/// * `directed == true`: every stored entry is one edge, so this is `matrix.nnz()`.
/// * `directed == false`: each unordered endpoint pair `{i, j}` is counted once. It does
///   not matter whether the matrix stores both `(i, j)` and `(j, i)` or only one of them,
///   matching how `to_adjacency_list` mirrors one-sided entries. A self-loop `(i, i)`
///   counts as one edge.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for API stability.
#[allow(dead_code)]
pub fn num_edges<T, S>(matrix: &S, directed: bool) -> SparseResult<usize>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    use std::collections::HashSet;

    if directed {
        return Ok(matrix.nnz());
    }

    let (row_indices, col_indices, _) = matrix.find();
    // Keying by (min, max) collapses (i, j) and (j, i) into one edge. The previous
    // `off_diagonal / 2` undercounted edges stored in only one direction.
    let undirected_edges: HashSet<(usize, usize)> = row_indices
        .iter()
        .zip(col_indices.iter())
        .map(|(&r, &c)| (r.min(c), r.max(c)))
        .collect();

    Ok(undirected_edges.len())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Symmetric 4-vertex graph: edges 0-1 (1.0), 0-2 (2.0), 1-3 (3.0), 2-3 (1.0),
    /// each stored in both directions.
    fn create_test_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 3.0, 1.0];
        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).expect("Operation failed")
    }

    /// 2x2 graph with entries (0, 1) = data[0] and (1, 0) = data[1].
    fn two_vertex_graph(data: Vec<f64>) -> CsrArray<f64> {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).expect("Operation failed")
    }

    #[test]
    fn test_validate_graph_symmetric() {
        let graph = create_test_graph();
        for directed in [false, true] {
            assert!(validate_graph(&graph, directed).is_ok());
        }
    }

    #[test]
    fn test_validate_graph_asymmetric() {
        let graph = two_vertex_graph(vec![1.0, 2.0]);
        assert!(validate_graph(&graph, true).is_ok());
        assert!(validate_graph(&graph, false).is_err());
    }

    #[test]
    fn test_validate_graph_negative_weights() {
        let graph = two_vertex_graph(vec![-1.0, 1.0]);
        for directed in [true, false] {
            assert!(validate_graph(&graph, directed).is_err());
        }
    }

    #[test]
    fn test_to_adjacency_list() {
        let graph = create_test_graph();
        let adjacency = to_adjacency_list(&graph, false).expect("Operation failed");
        assert_eq!(adjacency.len(), 4);

        let vertex0 = &adjacency[0];
        assert_eq!(vertex0.len(), 2);
        assert!(vertex0.contains(&(1, 1.0)));
        assert!(vertex0.contains(&(2, 2.0)));

        let vertex1 = &adjacency[1];
        assert_eq!(vertex1.len(), 2);
        assert!(vertex1.contains(&(0, 1.0)));
        assert!(vertex1.contains(&(3, 3.0)));
    }

    #[test]
    fn test_num_vertices() {
        assert_eq!(num_vertices(&create_test_graph()), 4);
    }

    #[test]
    fn test_num_edges() {
        let graph = create_test_graph();
        let directed_count = num_edges(&graph, true).expect("Operation failed");
        let undirected_count = num_edges(&graph, false).expect("Operation failed");
        assert_eq!(directed_count, 8);
        assert_eq!(undirected_count, 4);
    }
}