#[cfg(feature = "serde-serialize")]
mod csc_serde;
use crate::cs;
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csr::CsrMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A sparse matrix in Compressed Sparse Column (CSC) format.
///
/// The data lives in a generic compressed-sparse container ([`CsMatrix`]) whose *major*
/// dimension is the column dimension and whose *minor* dimension is the row dimension.
/// Every method below translates between `(row, col)` and the container's
/// `(major, minor) = (col, row)` convention.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscMatrix<T> {
    // Column-major compressed storage: major = columns, minor = rows.
    pub(crate) cs: CsMatrix<T>,
}

impl<T> CscMatrix<T> {
    /// Creates the `n x n` identity matrix (delegates to `CsMatrix::identity`).
    #[inline]
    pub fn identity(n: usize) -> Self
    where
        T: Scalar + One,
    {
        Self {
            cs: CsMatrix::identity(n),
        }
    }

    /// Creates an all-zero `nrows x ncols` matrix.
    ///
    /// `CsMatrix::new` takes its dimensions in `(major, minor)` order, which is
    /// `(ncols, nrows)` for CSC.
    pub fn zeros(nrows: usize, ncols: usize) -> Self {
        Self {
            cs: CsMatrix::new(ncols, nrows),
        }
    }

    /// Builds a CSC matrix from raw CSC arrays.
    ///
    /// Requirements:
    /// - `col_offsets` must describe `num_cols` columns.
    /// - The row indices within each column must be sorted, in bounds and free of duplicates.
    /// - `values` must hold exactly one value per row index.
    ///
    /// # Errors
    /// Returns a [`SparseFormatError`] if the column offsets or row indices do not form a
    /// valid pattern, or if the number of values differs from the number of row indices.
    pub fn try_from_csc_data(
        num_rows: usize,
        num_cols: usize,
        col_offsets: Vec<usize>,
        row_indices: Vec<usize>,
        values: Vec<T>,
    ) -> Result<Self, SparseFormatError> {
        // The pattern is built in (major, minor) = (cols, rows) order.
        let pattern = SparsityPattern::try_from_offsets_and_indices(
            num_cols,
            num_rows,
            col_offsets,
            row_indices,
        )
        .map_err(pattern_format_error_to_csc_error)?;
        Self::try_from_pattern_and_values(pattern, values)
    }

    /// Like [`CscMatrix::try_from_csc_data`], but the row indices within each column do not
    /// need to be sorted.
    ///
    /// The data is validated first. The `true` flag asks the helper to sort the row indices in
    /// place, and `values` is passed along so it can be permuted together with them.
    ///
    /// # Errors
    /// Returns a [`SparseFormatError`] if validation fails, or if the number of values differs
    /// from the number of row indices.
    pub fn try_from_unsorted_csc_data(
        num_rows: usize,
        num_cols: usize,
        col_offsets: Vec<usize>,
        mut row_indices: Vec<usize>,
        mut values: Vec<T>,
    ) -> Result<Self, SparseFormatError>
    where
        T: Scalar,
    {
        cs::validate_and_optionally_sort_cs_data(
            num_cols,
            num_rows,
            &col_offsets,
            &mut row_indices,
            Some(&mut values),
            true,
        )?;
        // SAFETY: `validate_and_optionally_sort_cs_data` returned `Ok`, so the offsets and
        // (now sorted) row indices have been checked against the invariants that
        // `from_offset_and_indices_unchecked` relies on.
        let pattern = unsafe {
            SparsityPattern::from_offset_and_indices_unchecked(
                num_cols,
                num_rows,
                col_offsets,
                row_indices,
            )
        };
        Self::try_from_pattern_and_values(pattern, values)
    }

    /// Builds a CSC matrix from a sparsity pattern (major = columns, minor = rows) and the
    /// values of its stored entries.
    ///
    /// # Errors
    /// Returns a [`SparseFormatError`] of kind [`SparseFormatErrorKind::InvalidStructure`] if
    /// `values.len()` differs from `pattern.nnz()`.
    pub fn try_from_pattern_and_values(
        pattern: SparsityPattern,
        values: Vec<T>,
    ) -> Result<Self, SparseFormatError> {
        if pattern.nnz() == values.len() {
            Ok(Self {
                cs: CsMatrix::from_pattern_and_values(pattern, values),
            })
        } else {
            Err(SparseFormatError::from_kind_and_msg(
                SparseFormatErrorKind::InvalidStructure,
                "Number of values and row indices must be the same",
            ))
        }
    }

    /// Number of rows (the pattern's minor dimension).
    #[inline]
    #[must_use]
    pub fn nrows(&self) -> usize {
        self.cs.pattern().minor_dim()
    }

    /// Number of columns (the pattern's major dimension).
    #[inline]
    #[must_use]
    pub fn ncols(&self) -> usize {
        self.cs.pattern().major_dim()
    }

    /// Number of explicitly stored entries.
    #[inline]
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.pattern().nnz()
    }

    /// Column offsets: where each column's run begins in `row_indices` / `values`.
    #[inline]
    #[must_use]
    pub fn col_offsets(&self) -> &[usize] {
        self.pattern().major_offsets()
    }

    /// Row indices of all stored entries, grouped column by column.
    #[inline]
    #[must_use]
    pub fn row_indices(&self) -> &[usize] {
        self.pattern().minor_indices()
    }

    /// Values of all stored entries, parallel to [`CscMatrix::row_indices`].
    #[inline]
    #[must_use]
    pub fn values(&self) -> &[T] {
        self.cs.values()
    }

    /// Mutable access to the stored values. The sparsity pattern cannot be changed this way.
    #[inline]
    pub fn values_mut(&mut self) -> &mut [T] {
        self.cs.values_mut()
    }

    /// Iterates over `(row, col, &value)` for every stored entry, in storage order.
    pub fn triplet_iter(&self) -> CscTripletIter<'_, T> {
        CscTripletIter {
            pattern_iter: self.pattern().entries(),
            values_iter: self.values().iter(),
        }
    }

    /// Iterates over `(row, col, &mut value)` for every stored entry, in storage order.
    pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<'_, T> {
        let (pattern, values) = self.cs.pattern_and_values_mut();
        CscTripletIterMut {
            pattern_iter: pattern.entries(),
            values_mut_iter: values.iter_mut(),
        }
    }

    /// Returns an immutable view of column `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    #[inline]
    #[must_use]
    pub fn col(&self, index: usize) -> CscCol<'_, T> {
        self.get_col(index).expect("Column index must be in bounds")
    }

    /// Returns a mutable view of column `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds.
    #[inline]
    pub fn col_mut(&mut self, index: usize) -> CscColMut<'_, T> {
        self.get_col_mut(index)
            .expect("Column index must be in bounds")
    }

    /// Returns an immutable view of column `index`, or `None` if the underlying lane lookup
    /// fails (i.e. `index` is out of bounds).
    #[inline]
    #[must_use]
    pub fn get_col(&self, index: usize) -> Option<CscCol<'_, T>> {
        self.cs.get_lane(index).map(|lane| CscCol { lane })
    }

    /// Returns a mutable view of column `index`, or `None` if `index` is out of bounds.
    #[inline]
    #[must_use]
    pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<'_, T>> {
        self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
    }

    /// Iterates over immutable views of all columns, in order.
    pub fn col_iter(&self) -> CscColIter<'_, T> {
        CscColIter {
            lane_iter: CsLaneIter::new(self.pattern(), self.values()),
        }
    }

    /// Iterates over mutable views of all columns, in order.
    pub fn col_iter_mut(&mut self) -> CscColIterMut<'_, T> {
        let (pattern, values) = self.cs.pattern_and_values_mut();
        CscColIterMut {
            lane_iter: CsLaneIterMut::new(pattern, values),
        }
    }

    /// Consumes the matrix and returns its raw arrays `(col_offsets, row_indices, values)`.
    pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        self.cs.disassemble()
    }

    /// Consumes the matrix and returns its sparsity pattern and values.
    pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
        self.cs.into_pattern_and_values()
    }

    /// Borrows the pattern immutably and the values mutably at the same time.
    #[inline]
    pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
        self.cs.pattern_and_values_mut()
    }

    /// The sparsity pattern (major = columns, minor = rows).
    #[must_use]
    pub fn pattern(&self) -> &SparsityPattern {
        self.cs.pattern()
    }

    /// Reinterprets this matrix's pattern and values as a CSR matrix, which yields the
    /// transpose without copying any data.
    pub fn transpose_as_csr(self) -> CsrMatrix<T> {
        let (pattern, values) = self.cs.take_pattern_and_values();
        // The pattern/value pair came from a valid matrix, so `nnz == values.len()` holds.
        CsrMatrix::try_from_pattern_and_values(pattern, values)
            .expect("pattern and values of a valid CscMatrix must form a valid CsrMatrix")
    }

    /// Looks up entry `(row_index, col_index)`.
    ///
    /// Returns `None` if the indices are out of bounds; [`CscMatrix::index_entry`] panics
    /// in that case instead.
    #[must_use]
    pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<'_, T>> {
        // CsMatrix takes (major, minor) = (col, row).
        self.cs.get_entry(col_index, row_index)
    }

    /// Mutable counterpart of [`CscMatrix::get_entry`].
    pub fn get_entry_mut(
        &mut self,
        row_index: usize,
        col_index: usize,
    ) -> Option<SparseEntryMut<'_, T>> {
        self.cs.get_entry_mut(col_index, row_index)
    }

    /// Like [`CscMatrix::get_entry`], but panics instead of returning `None`.
    ///
    /// # Panics
    /// Panics if the indices are out of bounds.
    #[must_use]
    pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<'_, T> {
        self.get_entry(row_index, col_index)
            .expect("Out of bounds matrix indices encountered")
    }

    /// Like [`CscMatrix::get_entry_mut`], but panics instead of returning `None`.
    ///
    /// # Panics
    /// Panics if the indices are out of bounds.
    pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<'_, T> {
        self.get_entry_mut(row_index, col_index)
            .expect("Out of bounds matrix indices encountered")
    }

    /// Borrows the raw arrays `(col_offsets, row_indices, values)`.
    #[must_use]
    pub fn csc_data(&self) -> (&[usize], &[usize], &[T]) {
        self.cs.cs_data()
    }

    /// Borrows the raw arrays, with mutable access to `values` only.
    pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
        self.cs.cs_data_mut()
    }

    /// Returns a new matrix that keeps only the stored entries for which
    /// `predicate(row, col, &value)` returns `true`.
    #[must_use]
    pub fn filter<P>(&self, predicate: P) -> Self
    where
        T: Clone,
        P: Fn(usize, usize, &T) -> bool,
    {
        Self {
            // CsMatrix::filter passes (major, minor) = (col, row); swap to (row, col).
            cs: self
                .cs
                .filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
        }
    }

    /// Keeps the stored entries with `row <= col`, including the diagonal.
    #[must_use]
    pub fn upper_triangle(&self) -> Self
    where
        T: Clone,
    {
        self.filter(|i, j, _| i <= j)
    }

    /// Keeps the stored entries with `row >= col`, including the diagonal.
    #[must_use]
    pub fn lower_triangle(&self) -> Self
    where
        T: Clone,
    {
        self.filter(|i, j, _| i >= j)
    }

    /// Returns a CSC matrix holding only the diagonal (via `CsMatrix::diagonal_as_matrix`).
    #[must_use]
    pub fn diagonal_as_csc(&self) -> Self
    where
        T: Clone,
    {
        Self {
            cs: self.cs.diagonal_as_matrix(),
        }
    }

    /// Computes the transpose as a new CSC matrix.
    ///
    /// Converts to CSR (`CsrMatrix::from(&CscMatrix)`), then reinterprets that CSR data as
    /// CSC.
    #[must_use]
    pub fn transpose(&self) -> CscMatrix<T>
    where
        T: Scalar,
    {
        CsrMatrix::from(self).transpose_as_csc()
    }
}

/// Wraps the default `CsMatrix` as a CSC matrix.
///
/// Written by hand rather than derived, because `#[derive(Default)]` would add a
/// `T: Default` bound.
impl<T> Default for CscMatrix<T> {
    fn default() -> Self {
        Self {
            cs: Default::default(),
        }
    }
}
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last col offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Col offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Row indices are not monotonically increasing (sorted) within each column.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
/// Iterator over `(row, col, &value)` for the stored entries of a CSC matrix.
///
/// It walks the sparsity pattern and the value slice in lockstep.
#[derive(Debug)]
pub struct CscTripletIter<'a, T> {
    pattern_iter: SparsityPatternIter<'a>,
    values_iter: Iter<'a, T>,
}

// Implemented by hand so that cloning the iterator does not require `T: Clone`
// (the fields only hold borrows).
impl<'a, T> Clone for CscTripletIter<'a, T> {
    fn clone(&self) -> Self {
        let Self {
            pattern_iter,
            values_iter,
        } = self;
        Self {
            pattern_iter: pattern_iter.clone(),
            values_iter: values_iter.clone(),
        }
    }
}

impl<'a, T: Clone> CscTripletIter<'a, T> {
    /// Adapts the iterator to yield owned values: `(row, col, value.clone())`.
    #[inline]
    pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
        self.map(|(row, col, value)| (row, col, value.clone()))
    }
}

impl<'a, T> Iterator for CscTripletIter<'a, T> {
    type Item = (usize, usize, &'a T);

    fn next(&mut self) -> Option<Self::Item> {
        // Advance both iterators every time, pattern first, so they stay in lockstep.
        let coordinates = self.pattern_iter.next();
        let value = self.values_iter.next();
        // Pattern entries are (major, minor) = (col, row); emit them as (row, col).
        coordinates
            .zip(value)
            .map(|((col, row), value)| (row, col, value))
    }
}
/// Iterator over `(row, col, &mut value)` for the stored entries of a CSC matrix.
///
/// It walks the sparsity pattern and the mutable value slice in lockstep.
#[derive(Debug)]
pub struct CscTripletIterMut<'a, T> {
    pattern_iter: SparsityPatternIter<'a>,
    values_mut_iter: IterMut<'a, T>,
}

impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
    type Item = (usize, usize, &'a mut T);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Advance both iterators every time, pattern first, so they stay in lockstep.
        let coordinates = self.pattern_iter.next();
        let value = self.values_mut_iter.next();
        // Pattern entries are (major, minor) = (col, row); emit them as (row, col).
        coordinates
            .zip(value)
            .map(|((col, row), value)| (row, col, value))
    }
}
/// Immutable view of a single column of a CSC matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscCol<'a, T> {
    lane: CsLane<'a, T>,
}

/// Mutable view of a single column of a CSC matrix.
#[derive(Debug, PartialEq, Eq)]
pub struct CscColMut<'a, T> {
    lane: CsLaneMut<'a, T>,
}

/// Generates the read-only accessors shared by the immutable and mutable column views.
///
/// Both views wrap a compressed-sparse lane whose minor dimension is the row dimension.
macro_rules! impl_csc_col_common_methods {
    ($col_view:ty) => {
        impl<'a, T> $col_view {
            /// Number of rows in the column: its dense length, not its stored-entry count.
            #[inline]
            #[must_use]
            pub fn nrows(&self) -> usize {
                self.lane.minor_dim()
            }

            /// Number of explicitly stored entries in the column.
            #[inline]
            #[must_use]
            pub fn nnz(&self) -> usize {
                self.lane.nnz()
            }

            /// Values of the stored entries, parallel to `row_indices()`.
            #[inline]
            #[must_use]
            pub fn values(&self) -> &[T] {
                self.lane.values()
            }

            /// Row indices of the stored entries, parallel to `values()`.
            #[inline]
            #[must_use]
            pub fn row_indices(&self) -> &[usize] {
                self.lane.minor_indices()
            }

            /// Looks up the entry in row `global_row_index` of this column.
            #[must_use]
            pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<'_, T>> {
                self.lane.get_entry(global_row_index)
            }
        }
    };
}

impl_csc_col_common_methods!(CscCol<'a, T>);
impl_csc_col_common_methods!(CscColMut<'a, T>);

impl<'a, T> CscColMut<'a, T> {
    /// Mutable lookup of the entry in row `global_row_index` of this column.
    #[must_use]
    pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<'_, T>> {
        self.lane.get_entry_mut(global_row_index)
    }

    /// Borrows the row indices immutably and the values mutably at the same time.
    pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
        self.lane.indices_and_values_mut()
    }

    /// Mutable access to the stored values of this column.
    pub fn values_mut(&mut self) -> &mut [T] {
        self.lane.values_mut()
    }
}
/// Iterator over immutable column views of a CSC matrix, in column order.
pub struct CscColIter<'a, T> {
    lane_iter: CsLaneIter<'a, T>,
}

impl<'a, T> Iterator for CscColIter<'a, T> {
    type Item = CscCol<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let lane = self.lane_iter.next()?;
        Some(CscCol { lane })
    }
}

/// Iterator over mutable column views of a CSC matrix, in column order.
pub struct CscColIterMut<'a, T> {
    lane_iter: CsLaneIterMut<'a, T>,
}

impl<'a, T> Iterator for CscColIterMut<'a, T>
where
    T: 'a,
{
    type Item = CscColMut<'a, T>;

    fn next(&mut self) -> Option<Self::Item> {
        let lane = self.lane_iter.next()?;
        Some(CscColMut { lane })
    }
}