#[cfg(feature = "serde-serialize")]
mod csr_serde;
use crate::cs;
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A sparse matrix stored in Compressed Sparse Row (CSR) format.
///
/// Rows are the major dimension of the underlying compressed storage and columns the minor
/// dimension: [`nrows`](Self::nrows) reports the pattern's major dimension and
/// [`ncols`](Self::ncols) its minor dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrMatrix<T> {
    // Shared compressed-storage core; CSR interprets each lane as a row.
    pub(crate) cs: CsMatrix<T>,
}

impl<T> CsrMatrix<T> {
    /// Creates the `n x n` identity matrix in CSR format (delegates to `CsMatrix::identity`).
    #[inline]
    pub fn identity(n: usize) -> Self
    where
        T: Scalar + One,
    {
        Self {
            cs: CsMatrix::identity(n),
        }
    }

    /// Creates an `nrows x ncols` zero matrix (delegates to `CsMatrix::new`).
    pub fn zeros(nrows: usize, ncols: usize) -> Self {
        Self {
            cs: CsMatrix::new(nrows, ncols),
        }
    }

    /// Creates a CSR matrix from raw row offsets, column indices and values.
    ///
    /// The offsets and indices are validated by
    /// [`SparsityPattern::try_from_offsets_and_indices`]. No sorting is performed, so the column
    /// indices must already be sorted within each row; use
    /// [`try_from_unsorted_csr_data`](Self::try_from_unsorted_csr_data) for unsorted input.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseFormatError`] if the offsets/indices do not form a valid sparsity
    /// pattern (the error message is phrased in row/column terms), or if `values.len()` differs
    /// from the number of column indices.
    pub fn try_from_csr_data(
        num_rows: usize,
        num_cols: usize,
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<T>,
    ) -> Result<Self, SparseFormatError> {
        let pattern = SparsityPattern::try_from_offsets_and_indices(
            num_rows,
            num_cols,
            row_offsets,
            col_indices,
        )
        .map_err(pattern_format_error_to_csr_error)?;
        Self::try_from_pattern_and_values(pattern, values)
    }

    /// Creates a CSR matrix from raw CSR data whose column indices may be unsorted within rows.
    ///
    /// The data is validated and the column indices of each row are sorted in place. `values` is
    /// handed to the same routine mutably, so it is reordered together with the indices before
    /// the pattern is built.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseFormatError`] if validation fails, or if the number of values does not
    /// match the number of column indices.
    pub fn try_from_unsorted_csr_data(
        num_rows: usize,
        num_cols: usize,
        row_offsets: Vec<usize>,
        mut col_indices: Vec<usize>,
        mut values: Vec<T>,
    ) -> Result<Self, SparseFormatError>
    where
        T: Scalar,
    {
        cs::validate_and_optionally_sort_cs_data(
            num_rows,
            num_cols,
            &row_offsets,
            &mut col_indices,
            Some(&mut values),
            true,
        )?;

        // SAFETY: `validate_and_optionally_sort_cs_data` returned `Ok`, so `row_offsets` and the
        // now-sorted `col_indices` have been checked against `num_rows`/`num_cols`. These are the
        // invariants the unchecked constructor relies on.
        let pattern = unsafe {
            SparsityPattern::from_offset_and_indices_unchecked(
                num_rows,
                num_cols,
                row_offsets,
                col_indices,
            )
        };
        Self::try_from_pattern_and_values(pattern, values)
    }

    /// Creates a CSR matrix from an existing sparsity pattern and matching values.
    ///
    /// The pattern's major dimension becomes the rows and its minor dimension the columns.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseFormatError`] of kind
    /// [`InvalidStructure`](SparseFormatErrorKind::InvalidStructure) if `values.len()` differs
    /// from `pattern.nnz()`.
    pub fn try_from_pattern_and_values(
        pattern: SparsityPattern,
        values: Vec<T>,
    ) -> Result<Self, SparseFormatError> {
        if pattern.nnz() == values.len() {
            Ok(Self {
                cs: CsMatrix::from_pattern_and_values(pattern, values),
            })
        } else {
            Err(SparseFormatError::from_kind_and_msg(
                SparseFormatErrorKind::InvalidStructure,
                "Number of values and column indices must be the same",
            ))
        }
    }

    /// Number of rows (the pattern's major dimension).
    #[inline]
    #[must_use]
    pub fn nrows(&self) -> usize {
        self.cs.pattern().major_dim()
    }

    /// Number of columns (the pattern's minor dimension).
    #[inline]
    #[must_use]
    pub fn ncols(&self) -> usize {
        self.cs.pattern().minor_dim()
    }

    /// Number of explicitly stored entries.
    #[inline]
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.cs.pattern().nnz()
    }

    /// The row offset array of the CSR representation.
    #[inline]
    #[must_use]
    pub fn row_offsets(&self) -> &[usize] {
        let (offsets, _, _) = self.cs.cs_data();
        offsets
    }

    /// The column index array of the CSR representation.
    #[inline]
    #[must_use]
    pub fn col_indices(&self) -> &[usize] {
        let (_, indices, _) = self.cs.cs_data();
        indices
    }

    /// The stored values, parallel to [`col_indices`](Self::col_indices).
    #[inline]
    #[must_use]
    pub fn values(&self) -> &[T] {
        self.cs.values()
    }

    /// Mutable access to the stored values; the sparsity structure cannot be changed this way.
    #[inline]
    pub fn values_mut(&mut self) -> &mut [T] {
        self.cs.values_mut()
    }

    /// Iterator over `(row, col, &value)` triplets of the explicitly stored entries.
    #[must_use]
    pub fn triplet_iter(&self) -> CsrTripletIter<'_, T> {
        CsrTripletIter {
            pattern_iter: self.pattern().entries(),
            values_iter: self.values().iter(),
        }
    }

    /// Iterator over `(row, col, &mut value)` triplets of the explicitly stored entries.
    pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<'_, T> {
        let (pattern, values) = self.cs.pattern_and_values_mut();
        CsrTripletIterMut {
            pattern_iter: pattern.entries(),
            values_mut_iter: values.iter_mut(),
        }
    }

    /// Immutable view of row `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    #[must_use]
    pub fn row(&self, index: usize) -> CsrRow<'_, T> {
        self.get_row(index).expect("Row index must be in bounds")
    }

    /// Mutable view of row `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    pub fn row_mut(&mut self, index: usize) -> CsrRowMut<'_, T> {
        self.get_row_mut(index)
            .expect("Row index must be in bounds")
    }

    /// Immutable view of row `index`, or `None` if `index` is out of bounds.
    #[inline]
    #[must_use]
    pub fn get_row(&self, index: usize) -> Option<CsrRow<'_, T>> {
        self.cs.get_lane(index).map(|lane| CsrRow { lane })
    }

    /// Mutable view of row `index`, or `None` if `index` is out of bounds.
    #[inline]
    #[must_use]
    pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<'_, T>> {
        self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
    }

    /// Iterator over immutable views of every row.
    #[must_use]
    pub fn row_iter(&self) -> CsrRowIter<'_, T> {
        CsrRowIter {
            lane_iter: CsLaneIter::new(self.pattern(), self.values()),
        }
    }

    /// Iterator over mutable views of every row.
    pub fn row_iter_mut(&mut self) -> CsrRowIterMut<'_, T> {
        let (pattern, values) = self.cs.pattern_and_values_mut();
        CsrRowIterMut {
            lane_iter: CsLaneIterMut::new(pattern, values),
        }
    }

    /// Consumes the matrix and returns its raw row offset, column index and value buffers.
    pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        self.cs.disassemble()
    }

    /// Consumes the matrix and returns its sparsity pattern together with the values.
    pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
        self.cs.into_pattern_and_values()
    }

    /// Borrows the sparsity pattern immutably and the values mutably at the same time.
    #[inline]
    pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
        self.cs.pattern_and_values_mut()
    }

    /// The sparsity pattern of this matrix.
    #[must_use]
    pub fn pattern(&self) -> &SparsityPattern {
        self.cs.pattern()
    }

    /// Consumes the matrix and returns its transpose as a CSC matrix.
    ///
    /// The CSR arrays of a matrix, read as CSC arrays, describe its transpose. So the stored
    /// pattern and values are reused as-is, without copying or reordering.
    pub fn transpose_as_csc(self) -> CscMatrix<T> {
        let (pattern, values) = self.cs.take_pattern_and_values();
        // The pattern and values come from a valid matrix, so `pattern.nnz() == values.len()`
        // holds and the conversion cannot fail.
        CscMatrix::try_from_pattern_and_values(pattern, values)
            .expect("pattern and values taken from a valid CsrMatrix must form a valid CscMatrix")
    }

    /// Immutable access to the entry at `(row_index, col_index)`, or `None` if the indices are
    /// out of bounds.
    #[must_use]
    pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<'_, T>> {
        self.cs.get_entry(row_index, col_index)
    }

    /// Mutable access to the entry at `(row_index, col_index)`, or `None` if the indices are
    /// out of bounds.
    pub fn get_entry_mut(
        &mut self,
        row_index: usize,
        col_index: usize,
    ) -> Option<SparseEntryMut<'_, T>> {
        self.cs.get_entry_mut(row_index, col_index)
    }

    /// Immutable access to the entry at `(row_index, col_index)`.
    ///
    /// # Panics
    ///
    /// Panics if the indices are out of bounds.
    #[must_use]
    pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<'_, T> {
        self.get_entry(row_index, col_index)
            .expect("Out of bounds matrix indices encountered")
    }

    /// Mutable access to the entry at `(row_index, col_index)`.
    ///
    /// # Panics
    ///
    /// Panics if the indices are out of bounds.
    pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<'_, T> {
        self.get_entry_mut(row_index, col_index)
            .expect("Out of bounds matrix indices encountered")
    }

    /// Borrows the raw `(row_offsets, col_indices, values)` arrays.
    #[must_use]
    pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
        self.cs.cs_data()
    }

    /// Borrows the raw arrays, with mutable access to the values only.
    pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
        self.cs.cs_data_mut()
    }

    /// Returns a new matrix that keeps only the stored entries for which
    /// `predicate(row, col, &value)` returns `true`.
    #[must_use]
    pub fn filter<P>(&self, predicate: P) -> Self
    where
        T: Clone,
        P: Fn(usize, usize, &T) -> bool,
    {
        Self {
            cs: self
                .cs
                .filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
        }
    }

    /// Returns the upper triangular part: stored entries with `row <= col` (diagonal included).
    #[must_use]
    pub fn upper_triangle(&self) -> Self
    where
        T: Clone,
    {
        self.filter(|i, j, _| i <= j)
    }

    /// Returns the lower triangular part: stored entries with `row >= col` (diagonal included).
    #[must_use]
    pub fn lower_triangle(&self) -> Self
    where
        T: Clone,
    {
        self.filter(|i, j, _| i >= j)
    }

    /// Returns the diagonal of this matrix as a CSR matrix (delegates to
    /// `CsMatrix::diagonal_as_matrix`).
    #[must_use]
    pub fn diagonal_as_csr(&self) -> Self
    where
        T: Clone,
    {
        Self {
            cs: self.cs.diagonal_as_matrix(),
        }
    }

    /// Returns the transpose as a new CSR matrix.
    ///
    /// The matrix is first converted to CSC, and that CSC data is then reinterpreted as the CSR
    /// data of the transpose.
    #[must_use]
    pub fn transpose(&self) -> CsrMatrix<T>
    where
        T: Scalar,
    {
        CscMatrix::from(self).transpose_as_csr()
    }
}
impl<T> Default for CsrMatrix<T> {
fn default() -> Self {
Self {
cs: Default::default(),
}
}
}
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last row offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Row offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Column indices are not monotonically increasing (sorted) within each row.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
/// Iterator over `(row, col, &value)` triplets of a [`CsrMatrix`].
///
/// Each entry yielded by the sparsity pattern's entry iterator is paired with the
/// corresponding stored value. Created by [`CsrMatrix::triplet_iter`].
#[derive(Debug)]
pub struct CsrTripletIter<'a, T> {
    pattern_iter: SparsityPatternIter<'a>,
    values_iter: Iter<'a, T>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone` bound, but
// cloning the borrowing iterators never clones a `T`.
impl<'a, T> Clone for CsrTripletIter<'a, T> {
    fn clone(&self) -> Self {
        CsrTripletIter {
            pattern_iter: self.pattern_iter.clone(),
            values_iter: self.values_iter.clone(),
        }
    }
}

impl<'a, T: Clone> CsrTripletIter<'a, T> {
    /// Converts this iterator into one that yields `(row, col, value)` with owned clones of the
    /// values.
    #[inline]
    pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
        self.map(|(i, j, v)| (i, j, v.clone()))
    }
}

impl<'a, T> Iterator for CsrTripletIter<'a, T> {
    type Item = (usize, usize, &'a T);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Advance both iterators in lockstep; stop as soon as either is exhausted.
        let next_entry = self.pattern_iter.next();
        let next_value = self.values_iter.next();
        match (next_entry, next_value) {
            (Some((i, j)), Some(v)) => Some((i, j, v)),
            _ => None,
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // One triplet is produced per remaining value. The iterator is built from a matrix's
        // pattern and its values, whose counts are equal (`pattern.nnz() == values.len()`), so
        // the slice iterator's exact hint is also exact here. This lets `collect`/`extend`
        // preallocate.
        self.values_iter.size_hint()
    }
}
/// Iterator over `(row, col, &mut value)` triplets of a [`CsrMatrix`].
///
/// Created by [`CsrMatrix::triplet_iter_mut`]. Only the values can be modified; the row and
/// column indices are yielded by value.
#[derive(Debug)]
pub struct CsrTripletIterMut<'a, T> {
    pattern_iter: SparsityPatternIter<'a>,
    values_mut_iter: IterMut<'a, T>,
}

impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
    type Item = (usize, usize, &'a mut T);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Advance both iterators in lockstep; stop as soon as either is exhausted.
        let next_entry = self.pattern_iter.next();
        let next_value = self.values_mut_iter.next();
        match (next_entry, next_value) {
            (Some((i, j)), Some(v)) => Some((i, j, v)),
            _ => None,
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // One triplet is produced per remaining value. The iterator is built from a matrix's
        // pattern and its values, whose counts are equal (`pattern.nnz() == values.len()`), so
        // the slice iterator's exact hint is also exact here.
        self.values_mut_iter.size_hint()
    }
}
/// An immutable view of a single row of a [`CsrMatrix`].
///
/// Obtained from [`CsrMatrix::row`], [`CsrMatrix::get_row`] or [`CsrMatrix::row_iter`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CsrRow<'a, T> {
    // The row is one lane of the shared compressed storage.
    lane: CsLane<'a, T>,
}

/// A mutable view of a single row of a [`CsrMatrix`].
///
/// Obtained from [`CsrMatrix::row_mut`], [`CsrMatrix::get_row_mut`] or
/// [`CsrMatrix::row_iter_mut`]. Only values can be modified through it.
#[derive(Debug, Eq, PartialEq)]
pub struct CsrRowMut<'a, T> {
    // The row is one mutable lane of the shared compressed storage.
    lane: CsLaneMut<'a, T>,
}
/// Generates the read-only accessors shared by [`CsrRow`] and [`CsrRowMut`].
///
/// `$row_type` must be a row-view type that has a `lane` field and whose lifetime parameter is
/// named `'a`, because the generated `impl` introduces `'a` itself.
macro_rules! impl_csr_row_common_methods {
    ($row_type:ty) => {
        impl<'a, T> $row_type {
            /// Number of explicitly stored entries in this row.
            #[inline]
            #[must_use]
            pub fn nnz(&self) -> usize {
                self.lane.nnz()
            }

            /// Number of columns of the matrix this row belongs to (the lane's minor dimension).
            #[inline]
            #[must_use]
            pub fn ncols(&self) -> usize {
                self.lane.minor_dim()
            }

            /// Global column indices of the stored entries of this row.
            #[inline]
            #[must_use]
            pub fn col_indices(&self) -> &[usize] {
                self.lane.minor_indices()
            }

            /// Stored values of this row, parallel to [`col_indices`](Self::col_indices).
            #[inline]
            #[must_use]
            pub fn values(&self) -> &[T] {
                self.lane.values()
            }

            /// Immutable access to the entry at global column `global_col_index`.
            ///
            /// Delegates to the lane's `get_entry`; see [`SparseEntry`] for the possible results.
            #[inline]
            #[must_use]
            pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<'_, T>> {
                self.lane.get_entry(global_col_index)
            }
        }
    };
}

// Both row views share the same read-only API.
impl_csr_row_common_methods!(CsrRow<'a, T>);
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
/// Mutation-only accessors for a mutable row view; the column structure stays read-only.
impl<'a, T> CsrRowMut<'a, T> {
    /// Mutable access to the entry at global column `global_col_index`.
    ///
    /// Delegates to the lane's `get_entry_mut`; see [`SparseEntryMut`] for the possible results.
    #[inline]
    #[must_use]
    pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<'_, T>> {
        self.lane.get_entry_mut(global_col_index)
    }

    /// Mutable access to the stored values of this row.
    #[inline]
    pub fn values_mut(&mut self) -> &mut [T] {
        self.lane.values_mut()
    }

    /// Simultaneous access to the row's column indices (read-only) and values (mutable).
    #[inline]
    pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
        self.lane.indices_and_values_mut()
    }
}
/// Iterator over immutable [`CsrRow`] views, created by [`CsrMatrix::row_iter`].
pub struct CsrRowIter<'a, T> {
    lane_iter: CsLaneIter<'a, T>,
}

impl<'a, T> Iterator for CsrRowIter<'a, T> {
    type Item = CsrRow<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        // Each lane of the compressed storage is one row.
        let lane = self.lane_iter.next()?;
        Some(CsrRow { lane })
    }
}
/// Iterator over mutable [`CsrRowMut`] views, created by [`CsrMatrix::row_iter_mut`].
pub struct CsrRowIterMut<'a, T> {
    lane_iter: CsLaneIterMut<'a, T>,
}

impl<'a, T> Iterator for CsrRowIterMut<'a, T>
where
    T: 'a,
{
    type Item = CsrRowMut<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        // Each mutable lane of the compressed storage is one row.
        let lane = self.lane_iter.next()?;
        Some(CsrRowMut { lane })
    }
}