use crate::dtype::{SparseIndex, SparseSupported};
use crate::gpu_ref::GpuRef;
/// A sparse matrix whose index arrays and values are held in [`GpuRef`] buffers.
///
/// `T` is the stored element type and `I` the integer type used for offsets and
/// indices. Each variant corresponds to one [`SparseFormat`] (see
/// [`SparseMatrix::format`]); the dense dimensions are available uniformly through
/// [`SparseMatrix::rows`] and [`SparseMatrix::cols`].
// NOTE(review): nothing in this type checks that buffer lengths agree with the scalar
// fields (e.g. `nnz`, `rows`); callers presumably enforce that — confirm.
#[derive(Clone)]
pub enum SparseMatrix<T: SparseSupported, I: SparseIndex> {
    /// Compressed sparse row (CSR) storage.
    Csr {
        /// Number of rows of the dense matrix.
        rows: i64,
        /// Number of columns of the dense matrix.
        cols: i64,
        /// Number of stored entries (presumably the length of `values`).
        nnz: i64,
        /// Per-row start offsets into `col_indices`/`values`
        /// (conventionally `rows + 1` entries — not checked here).
        row_offsets: GpuRef<I>,
        /// Column index of each stored entry.
        col_indices: GpuRef<I>,
        /// Value of each stored entry.
        values: GpuRef<T>,
    },
    /// Coordinate (COO) storage: one explicit (row, column, value) triple per entry.
    Coo {
        /// Number of rows of the dense matrix.
        rows: i64,
        /// Number of columns of the dense matrix.
        cols: i64,
        /// Number of stored entries (presumably the length of each array below).
        nnz: i64,
        /// Row index of each stored entry.
        row_indices: GpuRef<I>,
        /// Column index of each stored entry.
        col_indices: GpuRef<I>,
        /// Value of each stored entry.
        values: GpuRef<T>,
    },
    /// Compressed sparse column (CSC) storage.
    Csc {
        /// Number of rows of the dense matrix.
        rows: i64,
        /// Number of columns of the dense matrix.
        cols: i64,
        /// Number of stored entries (presumably the length of `values`).
        nnz: i64,
        /// Per-column start offsets into `row_indices`/`values`
        /// (conventionally `cols + 1` entries — not checked here).
        col_offsets: GpuRef<I>,
        /// Row index of each stored entry.
        row_indices: GpuRef<I>,
        /// Value of each stored entry.
        values: GpuRef<T>,
    },
    /// Blocked-ELL storage. Unlike the other variants it carries no `nnz` field.
    BlockedEll {
        /// Number of rows of the dense matrix.
        rows: i64,
        /// Number of columns of the dense matrix.
        cols: i64,
        /// Edge length of each stored block (presumed from the name — confirm).
        ell_block_size: i64,
        /// Width of the ELL layout; NOTE(review): confirm whether this counts
        /// blocks or scalar columns.
        ell_cols: i64,
        /// Block column indices of the ELL layout.
        col_indices: GpuRef<I>,
        /// Stored block values.
        values: GpuRef<T>,
    },
    /// Block sparse row (BSR) storage with square `block_size` x `block_size` blocks.
    ///
    /// Dimensions are stored in blocks; the dense size is
    /// `block_rows * block_size` by `block_cols * block_size`.
    Bsr {
        /// Number of block rows.
        block_rows: i64,
        /// Number of block columns.
        block_cols: i64,
        /// Edge length of every (square) block.
        block_size: i64,
        /// Number of stored blocks.
        nnz_blocks: i64,
        /// Per-block-row start offsets into `col_indices`
        /// (conventionally `block_rows + 1` entries — not checked here).
        row_offsets: GpuRef<I>,
        /// Block column index of each stored block (presumably).
        col_indices: GpuRef<I>,
        /// Values of the stored blocks.
        values: GpuRef<T>,
    },
}
impl<T: SparseSupported, I: SparseIndex> SparseMatrix<T, I> {
    /// Reports which storage layout this matrix uses.
    pub fn format(&self) -> SparseFormat {
        // Matching on the place `*self` with field-less patterns borrows nothing
        // and moves nothing.
        match *self {
            Self::Csr { .. } => SparseFormat::Csr,
            Self::Coo { .. } => SparseFormat::Coo,
            Self::Csc { .. } => SparseFormat::Csc,
            Self::BlockedEll { .. } => SparseFormat::BlockedEll,
            Self::Bsr { .. } => SparseFormat::Bsr,
        }
    }

    /// Number of rows of the equivalent dense matrix.
    ///
    /// For BSR the count is derived from the block geometry as
    /// `block_rows * block_size`; every other format stores it directly.
    pub fn rows(&self) -> i64 {
        match *self {
            Self::Bsr {
                block_rows,
                block_size,
                ..
            } => block_rows * block_size,
            Self::Csr { rows, .. }
            | Self::Coo { rows, .. }
            | Self::Csc { rows, .. }
            | Self::BlockedEll { rows, .. } => rows,
        }
    }

    /// Number of columns of the equivalent dense matrix.
    ///
    /// For BSR the count is derived from the block geometry as
    /// `block_cols * block_size`; every other format stores it directly.
    pub fn cols(&self) -> i64 {
        match *self {
            Self::Bsr {
                block_cols,
                block_size,
                ..
            } => block_cols * block_size,
            Self::Csr { cols, .. }
            | Self::Coo { cols, .. }
            | Self::Csc { cols, .. }
            | Self::BlockedEll { cols, .. } => cols,
        }
    }
}
/// Tag naming the storage layout of a sparse matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Compressed sparse row.
    Csr,
    /// Coordinate list.
    Coo,
    /// Compressed sparse column.
    Csc,
    /// Blocked-ELL.
    BlockedEll,
    /// Block sparse row.
    Bsr,
}

impl SparseFormat {
    /// Returns the lowercase, snake_case name of this format
    /// (`"csr"`, `"coo"`, `"csc"`, `"blocked_ell"` or `"bsr"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Csr => "csr",
            Self::Coo => "coo",
            Self::Csc => "csc",
            Self::BlockedEll => "blocked_ell",
            Self::Bsr => "bsr",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceState;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// All variants, in declaration order, for table-driven checks.
    const ALL_FORMATS: [SparseFormat; 5] = [
        SparseFormat::Csr,
        SparseFormat::Coo,
        SparseFormat::Csc,
        SparseFormat::BlockedEll,
        SparseFormat::Bsr,
    ];

    /// The names returned by `as_str` are observable output; pin each one.
    #[test]
    fn format_names_are_stable() {
        assert_eq!(SparseFormat::Csr.as_str(), "csr");
        assert_eq!(SparseFormat::Coo.as_str(), "coo");
        assert_eq!(SparseFormat::Csc.as_str(), "csc");
        assert_eq!(SparseFormat::BlockedEll.as_str(), "blocked_ell");
        assert_eq!(SparseFormat::Bsr.as_str(), "bsr");
    }

    /// No two formats may share a name, or the name would be ambiguous.
    #[test]
    fn format_names_are_distinct() {
        let names: HashSet<&str> = ALL_FORMATS.iter().map(|f| f.as_str()).collect();
        assert_eq!(names.len(), ALL_FORMATS.len());
    }

    /// Compile-time check that `SparseMatrix` can be instantiated for the
    /// (value, index) type pairs below; nothing is allocated at runtime.
    #[test]
    fn sparse_matrix_instantiates_for_supported_types() {
        fn instantiate<T: SparseSupported, I: SparseIndex>() -> Option<SparseMatrix<T, I>> {
            None
        }
        assert!(instantiate::<f32, i32>().is_none());
        assert!(instantiate::<f64, i64>().is_none());
    }

    /// Kept apart from the pure format-name tests so that a failure while creating
    /// the device state cannot mask their results.
    // NOTE(review): presumably needs device 0 to be available — confirm, and consider
    // `#[ignore]` for CPU-only CI.
    #[test]
    fn device_state_constructs_for_device_zero() {
        let _ = Arc::new(DeviceState::new(0));
    }
}