use na::{DMatrix, RealField};
use nalgebra::DMatrixView;
use crate::Real;
use crate::csv::{CsVecBuilder, CsVecMut, CsVecRef};
use crate::traits::IntoView;
/// Compressed sparse storage shared by the row-major (CSR) and column-major (CSC) wrappers.
///
/// Lane `i` (a row for CSR, a column for CSC) occupies
/// `primary_offsets[i]..primary_offsets[i + 1]` of `secondary_indices` and `values`.
/// `primary_offsets` is seeded with a single `0`, so it holds one more entry than there
/// are lanes.
#[derive(Clone, Debug)]
pub struct CsMatrix<T> {
    /// Secondary coordinate (column for CSR, row for CSC) of each stored value.
    pub(crate) secondary_indices: Vec<usize>,
    /// Start offset of every lane into `secondary_indices`/`values`, followed by the end
    /// offset of the last lane.
    pub(crate) primary_offsets: Vec<usize>,
    /// Stored values, parallel to `secondary_indices`.
    pub(crate) values: Vec<T>,
    /// Size of the secondary dimension (number of columns for CSR, rows for CSC).
    pub(crate) num_secondary: usize,
}

/// An empty matrix with no lanes and a zero-sized secondary dimension.
///
/// Implemented by hand rather than derived: a derived `Default` would leave
/// `primary_offsets` empty instead of `[0]` (the state `CsMatrix::new` and
/// `CsMatrix::clear` produce), which makes `len() - 1` lane counting underflow.
/// Unlike the derive, this impl also does not require `T: Default`.
impl<T> Default for CsMatrix<T> {
    fn default() -> Self {
        Self {
            secondary_indices: Vec::new(),
            primary_offsets: vec![0],
            values: Vec::new(),
            num_secondary: 0,
        }
    }
}
impl<T> CsMatrix<T>
where
    T: Real,
{
    /// Builds a sparse matrix from a dense one.
    ///
    /// With `row_sparse == true`, every dense row becomes one primary lane (CSR layout) and
    /// the secondary dimension is `dense_mat.ncols()`. Otherwise every dense column becomes a
    /// lane (CSC layout) and the secondary dimension is `dense_mat.nrows()`.
    ///
    /// `zero_threshold` is forwarded unchanged to each lane's [`CsVecBuilder`].
    /// NOTE(review): presumably entries whose magnitude does not exceed it are dropped —
    /// confirm in `CsVecBuilder::extend_with_nonzeros`.
    #[inline]
    pub fn from_dense(dense_mat: DMatrixView<T>, row_sparse: bool, zero_threshold: T) -> Self {
        let secondary_size = if row_sparse {
            dense_mat.ncols()
        } else {
            dense_mat.nrows()
        };
        let mut csm = CsMatrix::new(secondary_size);
        // Each builder is dropped at the end of its iteration, before the next one is created.
        // NOTE(review): this relies on `CsVecBuilder` finishing its lane on drop — confirm.
        if row_sparse {
            for i in 0..dense_mat.nrows() {
                let mut rb = csm.new_lane_builder(zero_threshold);
                let lane = dense_mat.row(i);
                rb.extend_with_nonzeros(lane.iter().copied().enumerate());
            }
        } else {
            for i in 0..dense_mat.ncols() {
                let mut rb = csm.new_lane_builder(zero_threshold);
                let lane = dense_mat.column(i);
                rb.extend_with_nonzeros(lane.iter().copied().enumerate());
            }
        }
        csm
    }

    /// Creates an empty matrix (no lanes) whose secondary dimension is `secondary_size`.
    pub fn new(secondary_size: usize) -> Self {
        // A single leading 0 so that lane `i` is always `offsets[i]..offsets[i + 1]`.
        let primary_offsets = vec![0];
        Self {
            secondary_indices: Vec::new(),
            values: Vec::new(),
            primary_offsets,
            num_secondary: secondary_size,
        }
    }

    /// Starts a builder that appends a new lane to this matrix, forwarding `zero_threshold`.
    #[inline]
    pub fn new_lane_builder(&mut self, zero_threshold: T) -> CsVecBuilder<T> {
        CsVecBuilder::new(self, zero_threshold)
    }

    /// Removes every lane and value, then sets the secondary dimension to `secondary_size`.
    pub fn reset(&mut self, secondary_size: usize) {
        self.clear();
        self.num_secondary = secondary_size;
    }

    /// Removes every lane and value while keeping `num_secondary` and the vectors' capacity.
    #[inline]
    pub fn clear(&mut self) {
        self.secondary_indices.clear();
        self.values.clear();
        self.primary_offsets.clear();
        // Re-seed the leading offset so the lane-indexing invariant still holds.
        self.primary_offsets.push(0);
    }

    /// Number of primary lanes (rows for CSR, columns for CSC).
    #[inline]
    pub fn num_primary(&self) -> usize {
        // `primary_offsets` has one more entry than there are lanes. Saturate so that an
        // unseeded (empty) offset vector reports zero lanes instead of underflowing.
        self.primary_offsets.len().saturating_sub(1)
    }

    /// Size of the secondary dimension.
    #[inline]
    pub fn num_secondary(&self) -> usize {
        self.num_secondary
    }

    /// Borrows lane `lane_index` as a sparse vector of length `num_secondary`.
    ///
    /// # Panics
    ///
    /// Panics if `lane_index >= self.num_primary()`.
    #[inline]
    pub fn get_lane(&self, lane_index: usize) -> CsVecRef<T> {
        let start = self.primary_offsets[lane_index];
        let end = self.primary_offsets[lane_index + 1];
        let col_indices = &self.secondary_indices[start..end];
        let values = &self.values[start..end];
        CsVecRef::from_parts_unchecked(col_indices, values, self.num_secondary)
    }

    /// Borrows lane `lane_index` with mutable values; its indices stay read-only.
    ///
    /// # Panics
    ///
    /// Panics if `lane_index >= self.num_primary()`.
    #[inline]
    pub fn get_lane_mut(&mut self, lane_index: usize) -> CsVecMut<T> {
        let start = self.primary_offsets[lane_index];
        let end = self.primary_offsets[lane_index + 1];
        let col_indices = &self.secondary_indices[start..end];
        let values = &mut self.values[start..end];
        CsVecMut {
            col_indices,
            values,
        }
    }

    /// Returns a borrowed, read-only view over this matrix's storage.
    #[inline]
    pub fn as_view(&self) -> CsMatrixView<T> {
        CsMatrixView {
            secondary_indices: &self.secondary_indices,
            primary_offsets: &self.primary_offsets,
            values: &self.values,
            num_secondary: self.num_secondary,
        }
    }

    /// Fraction of the `num_primary × num_secondary` slots that hold a stored value.
    ///
    /// Returns `0.0` when the matrix has no slots (zero lanes or zero secondary size),
    /// instead of the `NaN` a plain `0 / 0` division would produce.
    #[inline]
    pub fn dense_rate(&self) -> f32 {
        let (primary, secondary) = (self.num_primary(), self.num_secondary());
        if primary == 0 || secondary == 0 {
            return 0.0;
        }
        // Multiply in `f64` so that very large shapes cannot overflow `usize`.
        (self.values.len() as f64 / (primary as f64 * secondary as f64)) as f32
    }
}
impl<'a, T> IntoView for &'a CsMatrix<T> {
type View = CsMatrixView<'a, T>;
fn into_view(self) -> Self::View {
CsMatrixView {
secondary_indices: &self.secondary_indices,
primary_offsets: &self.primary_offsets,
values: &self.values,
num_secondary: self.num_secondary,
}
}
}
/// Borrowed, read-only view of compressed sparse storage (a `CsMatrix` or caller slices).
///
/// Uses the same layout as `CsMatrix`: lane `i` spans
/// `primary_offsets[i]..primary_offsets[i + 1]` of `secondary_indices` and `values`.
#[derive(Clone, Copy, Debug)]
pub struct CsMatrixView<'a, T> {
    /// Secondary coordinate of each stored value.
    secondary_indices: &'a [usize],
    /// Lane start offsets followed by the end offset of the last lane.
    primary_offsets: &'a [usize],
    /// Stored values, parallel to `secondary_indices`.
    values: &'a [T],
    /// Size of the secondary dimension.
    num_secondary: usize,
}
impl<'a, T> CsMatrixView<'a, T>
where
    T: RealField,
{
    /// Number of primary lanes.
    #[inline]
    pub fn num_primary(&self) -> usize {
        // One more offset than lanes. Saturate so that a view built from an empty offset
        // slice (e.g. via `from_parts_unchecked`) reports zero lanes instead of underflowing.
        self.primary_offsets.len().saturating_sub(1)
    }

    /// Size of the secondary dimension.
    #[inline]
    pub fn num_secondary(&self) -> usize {
        self.num_secondary
    }

    /// Borrows lane `lane_index` for the full view lifetime `'a`.
    ///
    /// # Panics
    ///
    /// Panics if `lane_index >= self.num_primary()`, or if the lane's offsets fall outside
    /// the index/value slices.
    #[inline]
    pub fn get_lane(self, lane_index: usize) -> CsVecRef<'a, T> {
        let start = self.primary_offsets[lane_index];
        let end = self.primary_offsets[lane_index + 1];
        let col_indices = &self.secondary_indices[start..end];
        let values = &self.values[start..end];
        CsVecRef::from_parts_unchecked(col_indices, values, self.num_secondary)
    }

    /// Fraction of the `num_primary × num_secondary` slots that hold a stored value.
    ///
    /// Returns `0.0` when the view has no slots (zero lanes or zero secondary size),
    /// instead of the `NaN` a plain `0 / 0` division would produce.
    #[inline]
    pub fn dense_rate(&self) -> f32 {
        let (primary, secondary) = (self.num_primary(), self.num_secondary());
        if primary == 0 || secondary == 0 {
            return 0.0;
        }
        // Multiply in `f64` so that very large shapes cannot overflow `usize`.
        (self.values.len() as f64 / (primary as f64 * secondary as f64)) as f32
    }
}
/// A view is already a view: conversion is the identity.
impl<'a, T> IntoView for CsMatrixView<'a, T> {
    type View = Self;
    #[inline]
    fn into_view(self) -> Self::View {
        self
    }
}
#[derive(Default, Clone)]
pub struct CsrMatrix<T>(CsMatrix<T>);
impl<T: Real> CsrMatrix<T> {
    /// Converts a dense matrix to CSR, one sparse row per dense row, passing
    /// `T::zero_threshold()` to the row builders.
    #[inline]
    pub fn from_dense(dense_mat: DMatrixView<T>) -> Self {
        let threshold = T::zero_threshold();
        let inner = CsMatrix::from_dense(dense_mat, true, threshold);
        Self(inner)
    }

    /// Creates an empty CSR matrix (no rows) with `ncols` columns.
    pub fn new(ncols: usize) -> Self {
        let inner = CsMatrix::new(ncols);
        Self(inner)
    }

    /// Drops every row and sets the column count to `ncols`.
    pub fn reset(&mut self, ncols: usize) {
        let Self(inner) = self;
        inner.reset(ncols);
    }

    /// Drops every row while keeping the column count.
    #[inline]
    pub fn clear(&mut self) {
        let Self(inner) = self;
        inner.clear();
    }

    /// Starts a builder that appends a new row, forwarding `zero_threshold`.
    #[inline]
    pub fn new_row_builder(&mut self, zero_threshold: T) -> CsVecBuilder<T> {
        let Self(inner) = self;
        inner.new_lane_builder(zero_threshold)
    }

    /// Borrows row `row_index` with mutable values (panics if the row does not exist).
    #[inline]
    pub fn get_row_mut(&mut self, row_index: usize) -> CsVecMut<T> {
        let Self(inner) = self;
        inner.get_lane_mut(row_index)
    }

    /// Returns a borrowed, read-only CSR view of this matrix.
    #[inline]
    pub fn as_view(&self) -> CsrMatrixView<T> {
        let lanes = self.0.as_view();
        CsrMatrixView(lanes)
    }
}
/// Borrowing a `CsrMatrix` yields its read-only CSR view.
impl<'a, T: Real> IntoView for &'a CsrMatrix<T> {
    type View = CsrMatrixView<'a, T>;
    fn into_view(self) -> Self::View {
        // Same as wrapping the inner storage view; reuse the inherent accessor.
        self.as_view()
    }
}
/// Borrowed compressed sparse row (CSR) view: a `CsMatrixView` whose primary lanes are rows.
#[derive(Clone, Copy, Debug)]
pub struct CsrMatrixView<'a, T>(CsMatrixView<'a, T>);
impl<'a, T> CsrMatrixView<'a, T> {
    /// Wraps raw CSR slices in a view without validating them.
    ///
    /// Nothing checks that `row_offsets` is non-decreasing or within the bounds of
    /// `col_indices`/`values`, or that every column index is `< ncol`; the caller must
    /// uphold these properties.
    #[inline]
    pub fn from_parts_unchecked(
        row_offsets: &'a [usize],
        col_indices: &'a [usize],
        values: &'a [T],
        ncol: usize,
    ) -> Self {
        let storage = CsMatrixView {
            primary_offsets: row_offsets,
            secondary_indices: col_indices,
            num_secondary: ncol,
            values,
        };
        Self(storage)
    }
}
impl<'a, T: Real> CsrMatrixView<'a, T> {
    /// Fraction of the `nrows × ncols` slots that hold a stored value.
    #[inline]
    pub fn dense_rate(&self) -> f32 {
        let Self(storage) = self;
        storage.dense_rate()
    }

    /// Reinterprets the same storage as a CSC view of the transposed matrix (no copy).
    #[inline]
    pub fn transpose(&self) -> CscMatrixView<'a, T> {
        let &Self(storage) = self;
        CscMatrixView(storage)
    }

    /// `(nrows, ncols)` of the matrix.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        let rows = self.nrows();
        let cols = self.ncols();
        (rows, cols)
    }

    /// Number of rows (primary lanes).
    #[inline]
    pub fn nrows(&self) -> usize {
        let Self(storage) = self;
        storage.num_primary()
    }

    /// Number of columns (secondary dimension).
    #[inline]
    pub fn ncols(&self) -> usize {
        let Self(storage) = self;
        storage.num_secondary()
    }

    /// Borrows row `row_index` for the view lifetime `'a` (panics if out of range).
    #[inline]
    pub fn get_row(self, row_index: usize) -> CsVecRef<'a, T> {
        let Self(storage) = self;
        storage.get_lane(row_index)
    }
}
/// A CSR view converts into itself unchanged.
impl<'a, T> IntoView for CsrMatrixView<'a, T> {
    type View = Self;
    #[inline]
    fn into_view(self) -> Self::View {
        self
    }
}
/// A borrowed CSR view converts into a copy of the view (requires the view to be `Copy`).
impl<'b, T: Copy> IntoView for &CsrMatrixView<'b, T> {
    type View = CsrMatrixView<'b, T>;
    #[inline]
    fn into_view(self) -> Self::View {
        // Pattern-match through the reference; this copies the view out.
        let &view = self;
        view
    }
}
/// Read-only CSR operations available on anything that converts into a `CsrMatrixView`.
pub trait CsrMatrixViewMethods<'a, T> {
    /// Number of rows.
    fn nrows(self) -> usize;
    /// Number of columns.
    fn ncols(self) -> usize;
    /// Sparse row `row_index`.
    fn get_row(self, row_index: usize) -> CsVecRef<'a, T>;
    /// Expands the matrix into a dense `nrows × ncols` matrix; entries that are not stored
    /// are zero.
    ///
    /// # Panics
    ///
    /// Panics if a row yields a column index `>= ncols()`. This can happen with views built
    /// by `CsrMatrixView::from_parts_unchecked` or with custom implementations of this trait.
    fn to_dense(self) -> DMatrix<T>
    where
        Self: Sized + Copy,
        T: Real,
    {
        let (nrows, ncols) = (self.nrows(), self.ncols());
        let mut dense = DMatrix::zeros(nrows, ncols);
        for i in 0..nrows {
            let row = self.get_row(i);
            for (col, value) in row.iter() {
                // Checked indexing on purpose: a safe default method of a public trait cannot
                // trust implementors (or unchecked views) to keep column indices `< ncols`.
                // An out-of-range index must panic rather than write out of bounds.
                dense[(i, col)] = value;
            }
        }
        dense
    }
}
impl<'a, T, V> CsrMatrixViewMethods<'a, V> for &'a T
where
V: Real,
&'a T: IntoView<View = CsrMatrixView<'a, V>>,
{
#[inline]
fn nrows(self) -> usize {
CsrMatrixView::nrows(&self.into_view())
}
#[inline]
fn ncols(self) -> usize {
CsrMatrixView::ncols(&self.into_view())
}
#[inline]
fn get_row(self, row_index: usize) -> CsVecRef<'a, V> {
self.into_view().get_row(row_index)
}
}
/// Borrowed compressed sparse column (CSC) view: a `CsMatrixView` whose primary lanes are
/// columns.
///
/// Derives `Debug` like `CsrMatrixView` and `CsMatrixView`, so CSC views can be printed
/// in diagnostics too.
#[derive(Clone, Copy, Debug)]
pub struct CscMatrixView<'a, T>(CsMatrixView<'a, T>);
impl<'a, T: Real> CscMatrixView<'a, T> {
    /// Number of rows (the secondary dimension).
    #[inline]
    pub fn nrows(&self) -> usize {
        let Self(storage) = self;
        storage.num_secondary()
    }

    /// Number of columns (primary lanes).
    #[inline]
    pub fn ncols(&self) -> usize {
        let Self(storage) = self;
        storage.num_primary()
    }

    /// Borrows column `col_index` for the view lifetime `'a` (panics if out of range).
    #[inline]
    pub fn get_col(&self, col_index: usize) -> CsVecRef<'a, T> {
        let &Self(storage) = self;
        storage.get_lane(col_index)
    }

    /// Reinterprets the same storage as a CSR view of the transposed matrix (no copy).
    #[inline]
    pub fn transpose(&self) -> CsrMatrixView<'a, T> {
        let &Self(storage) = self;
        CsrMatrixView(storage)
    }

    /// Expands the matrix into a dense `nrows × ncols` matrix; entries that are not stored
    /// are zero. Uses checked indexing, so an out-of-range row index panics.
    pub fn to_dense(&self) -> DMatrix<T> {
        let (rows, cols) = (self.nrows(), self.ncols());
        let mut out = DMatrix::zeros(rows, cols);
        for j in 0..cols {
            let column = self.get_col(j);
            for (i, v) in column.iter() {
                out[(i, j)] = v;
            }
        }
        out
    }
}