use crate::error::{Error, Result};
use crate::linalg::LinearOperator;
use crate::numeric::Float;
/// A sparse matrix stored in compressed sparse row (CSR) format.
///
/// Row `r` owns the entries at positions `row_offsets[r]..row_offsets[r + 1]` of both
/// `col_indices` and `values`. The validating constructor `from_csr` (which `from_triplets`
/// also goes through) guarantees:
/// - `row_offsets.len() == rows + 1`, `row_offsets[0] == 0`,
///   `row_offsets[rows] == values.len()`, and `row_offsets` never decreases;
/// - `col_indices.len() == values.len()`, and every column index is `< cols`.
///
/// Column indices within a row are not required to be sorted or unique. The derived
/// `PartialEq` compares the stored arrays element by element, so two matrices holding the same
/// entries in a different order compare unequal.
#[derive(Clone, Debug, PartialEq)]
pub struct CsrMatrix<T> {
// Number of rows.
rows: usize,
// Number of columns.
cols: usize,
// Row pointer array of length `rows + 1`; consecutive pairs delimit each row's entries.
row_offsets: Vec<usize>,
// Column index of each stored entry, parallel to `values`.
col_indices: Vec<usize>,
// Stored entry values, parallel to `col_indices`.
values: Vec<T>,
}
impl<T> CsrMatrix<T> {
pub fn from_csr(
shape: [usize; 2],
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self> {
if row_offsets.len() != shape[0] + 1 {
return Err(Error::shape(vec![shape[0] + 1], vec![row_offsets.len()]));
}
if col_indices.len() != values.len() {
return Err(Error::shape(vec![col_indices.len()], vec![values.len()]));
}
if row_offsets.first().copied() != Some(0)
|| row_offsets.last().copied() != Some(values.len())
|| row_offsets.windows(2).any(|window| window[0] > window[1])
{
return Err(Error::InvalidStride);
}
if col_indices.iter().any(|&col| col >= shape[1]) {
return Err(Error::IndexOutOfBounds);
}
Ok(Self {
rows: shape[0],
cols: shape[1],
row_offsets,
col_indices,
values,
})
}
pub fn from_triplets(shape: [usize; 2], entries: &[(usize, usize, T)]) -> Result<Self>
where
T: Clone,
{
let mut row_counts = vec![0usize; shape[0]];
for &(row, col, _) in entries {
if row >= shape[0] || col >= shape[1] {
return Err(Error::IndexOutOfBounds);
}
row_counts[row] += 1;
}
let mut row_offsets = vec![0usize; shape[0] + 1];
for row in 0..shape[0] {
row_offsets[row + 1] = row_offsets[row] + row_counts[row];
}
let mut next = row_offsets.clone();
let mut col_indices = vec![0usize; entries.len()];
let mut values = Vec::with_capacity(entries.len());
values.resize_with(entries.len(), || {
entries
.first()
.expect("entries is non-empty when resize needs values")
.2
.clone()
});
for (row, col, value) in entries
.iter()
.map(|&(row, col, ref value)| (row, col, value))
{
let offset = next[row];
col_indices[offset] = col;
values[offset] = value.clone();
next[row] += 1;
}
Self::from_csr(shape, row_offsets, col_indices, values)
}
pub fn shape(&self) -> [usize; 2] {
[self.rows, self.cols]
}
pub fn rows(&self) -> usize {
self.rows
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn nnz(&self) -> usize {
self.values.len()
}
pub fn row_offsets(&self) -> &[usize] {
&self.row_offsets
}
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
pub fn values(&self) -> &[T] {
&self.values
}
}
impl<T: Float> CsrMatrix<T> {
    /// Computes the sparse product `y = A * x`, overwriting every element of `y`.
    ///
    /// # Errors
    ///
    /// Returns `Error::shape` with expected `[cols, rows]` and actual `[x.len(), y.len()]`
    /// when `x.len() != cols` or `y.len() != rows`; `y` is left untouched in that case.
    pub fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let dims_ok = x.len() == self.cols && y.len() == self.rows;
        if !dims_ok {
            return Err(Error::shape(
                vec![self.cols, self.rows],
                vec![x.len(), y.len()],
            ));
        }
        // Each consecutive pair in `row_offsets` delimits one row's stored entries.
        for (out, bounds) in y.iter_mut().zip(self.row_offsets.windows(2)) {
            let (start, end) = (bounds[0], bounds[1]);
            let mut acc = T::zero();
            for (&col, &coeff) in self.col_indices[start..end]
                .iter()
                .zip(&self.values[start..end])
            {
                acc += coeff * x[col];
            }
            *out = acc;
        }
        Ok(())
    }

    /// Computes the transposed product `y = Aᵀ * x`, overwriting every element of `y`.
    ///
    /// # Errors
    ///
    /// Returns `Error::shape` with expected `[rows, cols]` and actual `[x.len(), y.len()]`
    /// when `x.len() != rows` or `y.len() != cols`; `y` is left untouched in that case.
    pub fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        let dims_ok = x.len() == self.rows && y.len() == self.cols;
        if !dims_ok {
            return Err(Error::shape(
                vec![self.rows, self.cols],
                vec![x.len(), y.len()],
            ));
        }
        // Scatter-add: each stored entry (r, c, v) contributes v * x[r] to y[c].
        y.fill(T::zero());
        for (&scale, bounds) in x.iter().zip(self.row_offsets.windows(2)) {
            let (start, end) = (bounds[0], bounds[1]);
            for (&col, &coeff) in self.col_indices[start..end]
                .iter()
                .zip(&self.values[start..end])
            {
                y[col] += coeff * scale;
            }
        }
        Ok(())
    }
}
/// Exposes a `CsrMatrix` through the generic `LinearOperator` interface by forwarding to the
/// inherent methods of the same names.
impl<T: Float> LinearOperator<T> for CsrMatrix<T> {
    fn rows(&self) -> usize {
        let Self { rows, .. } = self;
        *rows
    }

    fn cols(&self) -> usize {
        let Self { cols, .. } = self;
        *cols
    }

    fn matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        // Path calls on the type resolve to the inherent method before the trait method, so
        // this forwards rather than recursing.
        CsrMatrix::<T>::matvec(self, x, y)
    }

    fn t_matvec(&self, x: &[T], y: &mut [T]) -> Result<()> {
        CsrMatrix::<T>::t_matvec(self, x, y)
    }
}