use ndarray::Array2;
use num_traits::Zero;
use sprs::{CompressedStorage, CsMat, TriMat};
use std::ops::Add;
/// Common interface for a bipartite graph stored as a `rows() x cols()` matrix
/// of values of type `N`.
///
/// NOTE(review): presumably the row indices and the column indices are the two
/// vertex partitions, and a stored value is an edge weight — confirm with callers.
pub trait BipartiteGraph<N>: PartialEq
where
    N: PartialEq,
{
    /// Constructs a graph from a dense matrix.
    fn from_dense(array: Array2<N>) -> Self;

    /// Number of rows of the underlying matrix.
    fn rows(&self) -> usize;

    /// Number of columns of the underlying matrix.
    fn cols(&self) -> usize;

    /// Returns the value stored at `(row, col)`, if there is one.
    fn get(&self, row: usize, col: usize) -> Option<&N>;

    /// Stores `val` at `(row, col)`.
    fn insert(&mut self, row: usize, col: usize, val: N);
}
/// Sparse bipartite graph backed by a compressed sparse matrix from `sprs`.
///
/// This is only an alias of [`CsMat`], so every `CsMat` method is also
/// available on it.
pub type SparseBipartiteGraph<N> = CsMat<N>;

/// [`BipartiteGraph`] implementation for [`SparseBipartiteGraph`].
///
/// Each trait method forwards to the `CsMat` inherent method with the same
/// name. In method-call resolution, inherent methods take priority over trait
/// methods, so calls such as `self.rows()` below reach `CsMat::rows` and do not
/// recurse back into this impl. Keep the `self.method(..)` call form for that reason.
impl<N> BipartiteGraph<N> for SparseBipartiteGraph<N>
where
    N: PartialEq + Clone + Zero + Add<Output = N>,
{
    /// Builds a CSR-stored graph from `array`, leaving out zero entries.
    fn from_dense(array: Array2<N>) -> Self {
        sparse_bartite_graph_from_dense(array, CompressedStorage::CSR)
    }

    /// Forwards to `CsMat::rows`.
    fn rows(&self) -> usize {
        self.rows()
    }

    /// Forwards to `CsMat::cols`.
    fn cols(&self) -> usize {
        self.cols()
    }

    /// Forwards to `CsMat::get`; `None` when nothing is stored at `(row, col)`.
    fn get(&self, row: usize, col: usize) -> Option<&N> {
        self.get(row, col)
    }

    /// Forwards to `CsMat::insert`.
    fn insert(&mut self, row: usize, col: usize, val: N) {
        self.insert(row, col, val)
    }
}
pub fn sparse_bartite_graph_from_dense<N: PartialEq + Clone + Zero + Add<Output = N>>(
array: Array2<N>,
storage: CompressedStorage,
) -> SparseBipartiteGraph<N> {
let mut tri_graph = TriMat::new((array.shape()[0], array.shape()[1]));
for ((x, y), value) in array.indexed_iter() {
if !(*value).is_zero() {
tri_graph.add_triplet(x, y, value.clone());
}
}
match storage {
CompressedStorage::CSR => tri_graph.to_csr(),
CompressedStorage::CSC => tri_graph.to_csc(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    type Graph = SparseBipartiteGraph<f64>;

    /// 2x3 dense matrix with three non-zero entries.
    fn small_dense() -> Array2<f64> {
        array![[0.0, 2.0, 0.0], [3.0, 0.0, 4.0]]
    }

    /// Checks that `graph` holds exactly the non-zero entries of `small_dense()`.
    fn assert_matches_small_dense(graph: &Graph) {
        assert_eq!((graph.rows(), graph.cols()), (2, 3));
        assert_eq!(graph.nnz(), 3);
        assert_eq!(graph.get(0, 1), Some(&2.0));
        assert_eq!(graph.get(1, 0), Some(&3.0));
        assert_eq!(graph.get(1, 2), Some(&4.0));
        // Zero entries of the dense input must not be stored.
        assert_eq!(graph.get(0, 0), None);
        assert_eq!(graph.get(1, 1), None);
    }

    /// Smoke test: an identity matrix is readable through `get`.
    #[test]
    fn build_sparse_bipartite_graph() {
        let graph: SparseBipartiteGraph<f64> = CsMat::eye(100);
        assert_eq!(*graph.get(99, 99).unwrap(), 1.0);
    }

    #[test]
    fn from_dense_csr_keeps_only_non_zero_entries() {
        let graph = sparse_bartite_graph_from_dense(small_dense(), CompressedStorage::CSR);
        assert!(graph.is_csr());
        assert_matches_small_dense(&graph);
    }

    #[test]
    fn from_dense_csc_keeps_only_non_zero_entries() {
        let graph = sparse_bartite_graph_from_dense(small_dense(), CompressedStorage::CSC);
        assert!(graph.is_csc());
        assert_matches_small_dense(&graph);
    }

    #[test]
    fn from_dense_all_zero_preserves_shape_with_no_entries() {
        let dense = Array2::<f64>::zeros((2, 3));
        let graph = sparse_bartite_graph_from_dense(dense, CompressedStorage::CSR);
        assert_eq!((graph.rows(), graph.cols()), (2, 3));
        assert_eq!(graph.nnz(), 0);
    }

    /// Exercises the trait methods explicitly so the delegation is covered.
    #[test]
    fn trait_methods_delegate_to_csmat() {
        let mut graph = <Graph as BipartiteGraph<f64>>::from_dense(small_dense());
        assert!(graph.is_csr());
        assert_eq!(<Graph as BipartiteGraph<f64>>::rows(&graph), 2);
        assert_eq!(<Graph as BipartiteGraph<f64>>::cols(&graph), 3);
        assert_eq!(<Graph as BipartiteGraph<f64>>::get(&graph, 0, 1), Some(&2.0));

        <Graph as BipartiteGraph<f64>>::insert(&mut graph, 0, 0, 5.0);
        assert_eq!(<Graph as BipartiteGraph<f64>>::get(&graph, 0, 0), Some(&5.0));
        assert_eq!(graph.nnz(), 4);
    }
}