sprs/sparse/
csmat.rs

1//! A sparse matrix in the Compressed Sparse Row/Column format
2//!
3//! In the CSR format, a matrix is a structure containing three vectors:
4//! indptr, indices, and data
5//! These vectors satisfy the relation
//! for i in [0, nrows),
7//! A(i, indices[indptr[i]..indptr[i+1]]) = data[indptr[i]..indptr[i+1]]
8//! In the CSC format, the relation is
9//! A(indices[indptr[i]..indptr[i+1]], i) = data[indptr[i]..indptr[i+1]]
10use ndarray::ArrayView;
11use num_traits::{Float, Num, Signed, Zero};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use std::cmp;
15use std::default::Default;
16use std::iter::{Enumerate, Zip};
17use std::mem;
18use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, MulAssign};
19use std::slice::Iter;
20
21use crate::{Ix1, Ix2, Shape};
22use ndarray::linalg::Dot;
23use ndarray::{self, Array, ArrayBase, ShapeBuilder};
24
25use crate::indexing::SpIndex;
26
27use crate::errors::StructureError;
28use crate::sparse::binop;
29use crate::sparse::permutation::PermViewI;
30use crate::sparse::prelude::*;
31use crate::sparse::prod;
32use crate::sparse::smmp;
33use crate::sparse::to_dense::assign_to_dense;
34use crate::sparse::utils;
35use crate::sparse::vec;
36
/// Storage order used by a `CsMat`.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[allow(clippy::upper_case_acronyms)]
pub enum CompressedStorage {
    /// Compressed row storage
    CSR,
    /// Compressed column storage
    CSC,
}

impl CompressedStorage {
    /// Return the opposite storage order: `CSR` yields `CSC`, and `CSC`
    /// yields `CSR`.
    pub fn other_storage(self) -> Self {
        match self {
            Self::CSR => Self::CSC,
            Self::CSC => Self::CSR,
        }
    }
}
57
58pub fn outer_dimension(
59    storage: CompressedStorage,
60    rows: usize,
61    cols: usize,
62) -> usize {
63    match storage {
64        CSR => rows,
65        CSC => cols,
66    }
67}
68
69pub fn inner_dimension(
70    storage: CompressedStorage,
71    rows: usize,
72    cols: usize,
73) -> usize {
74    match storage {
75        CSR => cols,
76        CSC => rows,
77    }
78}
79
80pub use self::CompressedStorage::{CSC, CSR};
81
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
/// Hold the index of a non-zero element in the compressed storage
///
/// An `NnzIndex` can be used to later access the non-zero element in constant
/// time.
///
/// The wrapped `usize` is the public field `0` of this tuple struct.
pub struct NnzIndex(pub usize);
88
/// Iterator over the stored entries of a compressed sparse matrix.
///
/// Each item is `(&value, (row, col))`. Instances are created by
/// `iter_rbr` on a matrix view.
pub struct CsIter<'a, N: 'a, I: 'a, Iptr: 'a = I>
where
    I: SpIndex,
    Iptr: SpIndex,
{
    /// Storage order, used to turn (outer, inner) into (row, col)
    storage: CompressedStorage,
    /// Outer dimension the iteration has reached so far
    cur_outer: I,
    /// View of the index pointers, used to detect the end of an outer
    /// dimension
    indptr: crate::IndPtrView<'a, Iptr>,
    /// Enumerated pairs of (inner index, value) over all stored entries
    inner_iter: Enumerate<Zip<Iter<'a, I>, Iter<'a, N>>>,
}
99
100impl<'a, N, I, Iptr> Iterator for CsIter<'a, N, I, Iptr>
101where
102    I: SpIndex,
103    Iptr: SpIndex,
104    N: 'a,
105{
106    type Item = (&'a N, (I, I));
107    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
108        match self.inner_iter.next() {
109            None => None,
110            Some((nnz_index, (&inner_ind, val))) => {
111                // loop to find the correct outer dimension. Looping
112                // is necessary because there can be several adjacent
113                // empty outer dimensions.
114                loop {
115                    let nnz_end = self
116                        .indptr
117                        .outer_inds_sz(self.cur_outer.index_unchecked())
118                        .end;
119                    if nnz_index == nnz_end.index_unchecked() {
120                        self.cur_outer += I::one();
121                    } else {
122                        break;
123                    }
124                }
125                let (row, col) = match self.storage {
126                    CSR => (self.cur_outer, inner_ind),
127                    CSC => (inner_ind, self.cur_outer),
128                };
129                Some((val, (row, col)))
130            }
131        }
132    }
133
134    fn size_hint(&self) -> (usize, Option<usize>) {
135        self.inner_iter.size_hint()
136    }
137}
138
139impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
140    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
141where
142    IptrStorage: Deref<Target = [Iptr]>,
143    IStorage: Deref<Target = [I]>,
144    DStorage: Deref<Target = [N]>,
145{
146    pub(crate) fn new_checked(
147        storage: CompressedStorage,
148        shape: (usize, usize),
149        indptr: IptrStorage,
150        indices: IStorage,
151        data: DStorage,
152    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
153        let (nrows, ncols) = shape;
154        let (inner, outer) = match storage {
155            CSR => (ncols, nrows),
156            CSC => (nrows, ncols),
157        };
158        if data.len() != indices.len() {
159            return Err((
160                indptr,
161                indices,
162                data,
163                StructureError::SizeMismatch(
164                    "data and indices have different sizes",
165                ),
166            ));
167        }
168        match crate::sparse::utils::check_compressed_structure(
169            inner,
170            outer,
171            indptr.as_ref(),
172            indices.as_ref(),
173        ) {
174            Err(e) => Err((indptr, indices, data, e)),
175            Ok(_) => Ok(Self {
176                storage,
177                nrows,
178                ncols,
179                indptr: crate::IndPtrBase::new_trusted(indptr),
180                indices,
181                data,
182            }),
183        }
184    }
185
    /// Create a new `CSR` sparse matrix
    ///
    /// See `new_csc` for the `CSC` equivalent
    ///
    /// This constructor can be used to construct all
    /// sparse matrix types.
    /// By using the type aliases one helps constrain the resulting type,
    /// as shown below
    ///
    /// # Panics
    ///
    /// Panics if the inputs fail validation, that is whenever `try_new`
    /// would return an error for the same arguments.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use sprs::*;
    /// // This creates an owned matrix
    /// let owned_matrix = CsMat::new((2, 2), vec![0, 1, 1], vec![1], vec![4_u8]);
    /// // This creates a matrix which only borrows the elements
    /// let borrow_matrix = CsMatView::new((2, 2), &[0, 1, 1], &[1], &[4_u8]);
    /// // A combination of storage types may also be used for a
    /// // general sparse matrix
    /// let mixed_matrix = CsMatBase::new((2, 2), &[0, 1, 1] as &[_], vec![1_i64].into_boxed_slice(), vec![4_u8]);
    /// ```
    pub fn new(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        // Drop the returned storages and keep only the error for the panic
        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
            .map_err(|(_, _, _, e)| e)
            .unwrap()
    }
217
    /// Create a new `CSC` sparse matrix
    ///
    /// See `new` for the `CSR` equivalent
    ///
    /// # Panics
    ///
    /// Panics if the inputs fail validation, that is whenever `try_new_csc`
    /// would return an error for the same arguments.
    pub fn new_csc(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        // Drop the returned storages and keep only the error for the panic
        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
            .map_err(|(_, _, _, e)| e)
            .unwrap()
    }
231
    /// Try to create a new `CSR` sparse matrix
    ///
    /// See `try_new_csc` for the `CSC` equivalent
    ///
    /// # Errors
    ///
    /// If validation fails, the three storages are returned together with
    /// the `StructureError` describing the problem.
    pub fn try_new(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
    }
243
    /// Try to create a new `CSC` sparse matrix
    ///
    /// See `try_new` for the `CSR` equivalent
    ///
    /// # Errors
    ///
    /// If validation fails, the three storages are returned together with
    /// the `StructureError` describing the problem.
    pub fn try_new_csc(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
    }
255
256    /// Create a `CsMat` matrix from raw data,
257    /// without checking their validity
258    ///
259    /// # Safety
260    /// This is unsafe because algorithms are free to assume
261    /// that properties guaranteed by
262    /// [`check_compressed_structure`](Self::check_compressed_structure) are enforced.
263    /// For instance, non out-of-bounds indices can be relied upon to
264    /// perform unchecked slice access.
265    pub unsafe fn new_unchecked(
266        storage: CompressedStorage,
267        shape: Shape,
268        indptr: IptrStorage,
269        indices: IStorage,
270        data: DStorage,
271    ) -> Self {
272        let (nrows, ncols) = shape;
273        Self {
274            storage,
275            nrows,
276            ncols,
277            indptr: crate::IndPtrBase::new_trusted(indptr),
278            indices,
279            data,
280        }
281    }
282
283    /// Internal analog to `new_unchecked` which is not marked as `unsafe` as
284    /// we should always construct valid matrices internally
285    pub(crate) fn new_trusted(
286        storage: CompressedStorage,
287        shape: Shape,
288        indptr: IptrStorage,
289        indices: IStorage,
290        data: DStorage,
291    ) -> Self {
292        let (nrows, ncols) = shape;
293        Self {
294            storage,
295            nrows,
296            ncols,
297            indptr: crate::IndPtrBase::new_trusted(indptr),
298            indices,
299            data,
300        }
301    }
302}
303
304impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
305    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
306where
307    IptrStorage: Deref<Target = [Iptr]>,
308    IStorage: DerefMut<Target = [I]>,
309    DStorage: DerefMut<Target = [N]>,
310{
311    fn new_from_unsorted_checked(
312        storage: CompressedStorage,
313        shape: (usize, usize),
314        indptr: IptrStorage,
315        mut indices: IStorage,
316        mut data: DStorage,
317    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
318    where
319        N: Clone,
320    {
321        let (nrows, ncols) = shape;
322        let (inner, outer) = match storage {
323            CSR => (ncols, nrows),
324            CSC => (nrows, ncols),
325        };
326        if data.len() != indices.len() {
327            return Err((
328                indptr,
329                indices,
330                data,
331                StructureError::SizeMismatch(
332                    "data and indices have different sizes",
333                ),
334            ));
335        }
336        let mut buf = Vec::new();
337        for start_stop in indptr.windows(2) {
338            let start = start_stop[0].to_usize().unwrap();
339            let stop = start_stop[1].to_usize().unwrap();
340            let indices = &mut indices[start..stop];
341            if utils::sorted_indices(indices) {
342                continue;
343            }
344            let data = &mut data[start..stop];
345            let len = stop - start;
346            let indices = &mut indices[..len];
347            let data = &mut data[..len];
348            utils::sort_indices_data_slices(indices, data, &mut buf);
349        }
350
351        match crate::sparse::utils::check_compressed_structure(
352            inner,
353            outer,
354            indptr.as_ref(),
355            indices.as_ref(),
356        ) {
357            Err(e) => Err((indptr, indices, data, e)),
358            Ok(_) => Ok(Self {
359                storage,
360                nrows,
361                ncols,
362                indptr: crate::IndPtrBase::new_trusted(indptr),
363                indices,
364                data,
365            }),
366        }
367    }
368
    /// Try to create a `CSR` matrix from possibly unsorted indices.
    ///
    /// A `CSC` matrix can be created with `new_from_unsorted_csc()`.
    ///
    /// If necessary, the indices will be sorted in place.
    ///
    /// # Errors
    ///
    /// On failure the storages are returned together with the
    /// `StructureError` describing the problem.
    pub fn new_from_unsorted(
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
    where
        N: Clone,
    {
        Self::new_from_unsorted_checked(CSR, shape, indptr, indices, data)
    }
385
    /// Try to create a `CSC` matrix from possibly unsorted indices.
    ///
    /// A `CSR` matrix can be created with `new_from_unsorted()`.
    ///
    /// If necessary, the indices will be sorted in place.
    ///
    /// # Errors
    ///
    /// On failure the storages are returned together with the
    /// `StructureError` describing the problem.
    pub fn new_from_unsorted_csc(
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
    where
        N: Clone,
    {
        Self::new_from_unsorted_checked(CSC, shape, indptr, indices, data)
    }
402}
403
404/// # Constructor methods for owned sparse matrices
405impl<N, I: SpIndex, Iptr: SpIndex> CsMatI<N, I, Iptr> {
406    /// Identity matrix, stored as a CSR matrix.
407    ///
408    /// ```rust
409    /// use sprs::{CsMat, CsVec};
410    /// let eye = CsMat::eye(5);
411    /// assert!(eye.is_csr());
412    /// let x = CsVec::new(5, vec![0, 2, 4], vec![1., 2., 3.]);
413    /// let y = &eye * &x;
414    /// assert_eq!(x, y);
415    /// ```
416    pub fn eye(dim: usize) -> Self
417    where
418        N: Num + Clone,
419    {
420        let _ = (I::from_usize(dim), Iptr::from_usize(dim)); // Make sure dim fits in type I & Iptr
421        let n = dim;
422        let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
423        let indices = (0..n).map(I::from_usize_unchecked).collect();
424        let data = vec![N::one(); n];
425        Self::new_trusted(CSR, (n, n), indptr, indices, data)
426    }
427
428    /// Identity matrix, stored as a CSC matrix.
429    ///
430    /// ```rust
431    /// use sprs::{CsMat, CsVec};
432    /// let eye = CsMat::eye_csc(5);
433    /// assert!(eye.is_csc());
434    /// let x = CsVec::new(5, vec![0, 2, 4], vec![1., 2., 3.]);
435    /// let y = &eye * &x;
436    /// assert_eq!(x, y);
437    /// ```
438    pub fn eye_csc(dim: usize) -> Self
439    where
440        N: Num + Clone,
441    {
442        let _ = (I::from_usize(dim), Iptr::from_usize(dim)); // Make sure dim fits in type I & Iptr
443        let n = dim;
444        let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
445        let indices = (0..n).map(I::from_usize_unchecked).collect();
446        let data = vec![N::one(); n];
447        Self::new_trusted(CSC, (n, n), indptr, indices, data)
448    }
449    /// Create an empty `CsMat` for building purposes
450    pub fn empty(storage: CompressedStorage, inner_size: usize) -> Self {
451        let shape = match storage {
452            CSR => (0, inner_size),
453            CSC => (inner_size, 0),
454        };
455        Self::new_trusted(
456            storage,
457            shape,
458            vec![Iptr::zero(); 1],
459            Vec::new(),
460            Vec::new(),
461        )
462    }
463
464    /// Create a new `CsMat` representing the zero matrix.
465    /// Hence it has no non-zero elements.
466    pub fn zero(shape: Shape) -> Self {
467        let (nrows, _ncols) = shape;
468        Self::new_trusted(
469            CSR,
470            shape,
471            vec![Iptr::zero(); nrows + 1],
472            Vec::new(),
473            Vec::new(),
474        )
475    }
476
    /// Reserve index-pointer storage for the given additional number of
    /// outer dimensions
    ///
    /// The request is forwarded to `reserve` on the `indptr` storage.
    pub fn reserve_outer_dim(&mut self, outer_dim_additional: usize) {
        self.indptr.reserve(outer_dim_additional);
    }
481
    /// Reserve the storage for the given additional number of nonzero data
    ///
    /// Both the `indices` and the `data` storage are grown.
    pub fn reserve_nnz(&mut self, nnz_additional: usize) {
        self.indices.reserve(nnz_additional);
        self.data.reserve(nnz_additional);
    }
487
    /// Reserve index-pointer storage for the given number of outer
    /// dimensions, without over-allocating
    ///
    /// `outer_dim_lim + 1` is forwarded to `reserve_exact` on the `indptr`
    /// storage.
    // NOTE(review): the extra entry presumably accounts for indptr holding
    // one more element than there are outer dimensions -- confirm. Whether
    // the amount is a total or an addition to the current length depends on
    // the `indptr` type's `reserve_exact`, which is not shown here.
    pub fn reserve_outer_dim_exact(&mut self, outer_dim_lim: usize) {
        self.indptr.reserve_exact(outer_dim_lim + 1);
    }
492
    /// Reserve the storage for the given number of nonzero data, without
    /// over-allocating
    ///
    /// `nnz_lim` is forwarded to `reserve_exact` on both the `indices` and
    /// the `data` storage.
    // NOTE(review): assuming both storages are `Vec`s, `reserve_exact` counts
    // elements in addition to the current length, not a total -- confirm.
    pub fn reserve_nnz_exact(&mut self, nnz_lim: usize) {
        self.indices.reserve_exact(nnz_lim);
        self.data.reserve_exact(nnz_lim);
    }
498
499    /// Create a CSR matrix from a dense matrix, ignoring elements lower than `epsilon`.
500    ///
501    /// If epsilon is negative, it will be clamped to zero.
502    pub fn csr_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
503    where
504        N: Num + Clone + cmp::PartialOrd + Signed,
505    {
506        let epsilon = if epsilon > N::zero() {
507            epsilon
508        } else {
509            N::zero()
510        };
511        let nrows = m.shape()[0];
512        let ncols = m.shape()[1];
513
514        let mut indptr = vec![Iptr::zero(); nrows + 1];
515        let mut nnz = 0;
516        for (row, row_count) in m.outer_iter().zip(&mut indptr[1..]) {
517            nnz += row.iter().filter(|&x| x.abs() > epsilon).count();
518            *row_count = Iptr::from_usize(nnz);
519        }
520
521        let mut indices = Vec::with_capacity(nnz);
522        let mut data = Vec::with_capacity(nnz);
523        for row in m.outer_iter() {
524            for (col_ind, x) in row.iter().enumerate() {
525                if x.abs() > epsilon {
526                    indices.push(I::from_usize(col_ind));
527                    data.push(x.clone());
528                }
529            }
530        }
531        Self {
532            storage: CompressedStorage::CSR,
533            nrows,
534            ncols,
535            indptr: crate::IndPtr::new_trusted(indptr),
536            indices,
537            data,
538        }
539    }
540
    /// Create a CSC matrix from a dense matrix, ignoring elements lower than `epsilon`.
    ///
    /// If epsilon is negative, it will be clamped to zero.
    ///
    /// This is implemented by building the CSR matrix of the dense matrix
    /// with its axes reversed, and transposing the result.
    pub fn csc_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
    where
        N: Num + Clone + cmp::PartialOrd + Signed,
    {
        Self::csr_from_dense(m.reversed_axes(), epsilon).transpose_into()
    }
550
    /// Append an outer dim to an existing matrix, compressing it in the process
    ///
    /// `data` is read as a dense slice: each element is paired with its
    /// position in the slice, and elements equal to zero are skipped.
    pub fn append_outer(self, data: &[N]) -> Self
    where
        N: Clone + Zero,
    {
        // SAFETY: enumerate is monotonically increasing
        unsafe {
            self.append_outer_iter_unchecked(
                data.iter()
                    .cloned()
                    .enumerate()
                    .filter(|(_, val)| !val.is_zero()),
            )
        }
    }
566
    /// Append an outer dim to an existing matrix, increasing the size along the outer
    /// dimension by one.
    ///
    /// The iterator yields `(inner index, value)` pairs; pairs whose value
    /// is zero are skipped.
    ///
    /// # Panics
    ///
    /// if the iterator index is **not** monotonically increasing, or if the
    /// last stored inner index is not smaller than the inner dimension
    pub fn append_outer_iter<Iter>(self, iter: Iter) -> Self
    where
        N: Zero,
        Iter: IntoIterator<Item = (usize, N)>,
    {
        let iter = iter.into_iter();
        // SAFETY: `AssertOrderedIterator` panics as soon as an index is not
        // strictly greater than the previous one, so the ordering required
        // by the unchecked variant is enforced while iterating.
        unsafe {
            self.append_outer_iter_unchecked(AssertOrderedIterator {
                prev: None,
                iter: iter.filter(|(_, val)| !val.is_zero()),
            })
        }
    }
586
587    /// Append an outer dim to an existing matrix, increasing the size along the outer
588    /// dimension by one.
589    ///
590    /// # Safety
591    ///
592    /// This is unsafe since indices for each inner dim should be monotonically increasing
593    /// which is not checked. The data values are additionally not checked for zero.
594    /// See `append_outer_iter` for the checked version
595    pub unsafe fn append_outer_iter_unchecked<Iter>(
596        mut self,
597        iter: Iter,
598    ) -> Self
599    where
600        Iter: IntoIterator<Item = (usize, N)>,
601    {
602        let iter = iter.into_iter();
603        if let (_, Some(nnz)) = iter.size_hint() {
604            self.reserve_nnz(nnz)
605        }
606        let mut nnz = self.nnz();
607        for (inner_ind, val) in iter {
608            self.indices.push(I::from_usize(inner_ind));
609            self.data.push(val);
610            nnz += 1;
611        }
612        if let Some(last_inner_ind) = self.indices.last() {
613            assert!(
614                last_inner_ind.index_unchecked() < self.inner_dims(),
615                "inner index out of range"
616            );
617        }
618        match self.storage {
619            CSR => self.nrows += 1,
620            CSC => self.ncols += 1,
621        }
622        self.indptr.push(Iptr::from_usize(nnz));
623        self
624    }
625
    /// Append an outer dim to an existing matrix, provided by a sparse vector
    ///
    /// # Panics
    ///
    /// Panics if the dimension of `vec` differs from the inner dimension of
    /// this matrix.
    pub fn append_outer_csvec(self, vec: CsVecViewI<N, I>) -> Self
    where
        N: Clone,
    {
        assert_eq!(self.inner_dims(), vec.dim());
        // SAFETY: CsVec has monotonically increasing indices
        unsafe {
            self.append_outer_iter_unchecked(
                vec.iter().map(|(i, val)| (i, val.clone())),
            )
        }
    }
639
    /// Insert an element in the matrix. If the element is already present,
    /// its value is overwritten.
    ///
    /// If `row` or `col` lies outside of the current shape, the matrix is
    /// grown so that it contains the new element.
    ///
    /// Warning: this is not an efficient operation, as it requires
    /// a non-constant lookup followed by two `Vec` insertions.
    ///
    /// The insertion will be efficient, however, if the elements are inserted
    /// according to the matrix's order, eg following the row order for a CSR
    /// matrix.
    pub fn insert(&mut self, row: usize, col: usize, val: N) {
        // Translate (row, col) into (outer, inner) for the storage in use
        match self.storage() {
            CSR => self.insert_outer_inner(row, col, val),
            CSC => self.insert_outer_inner(col, row, val),
        }
    }
655
656    fn insert_outer_inner(
657        &mut self,
658        outer_ind: usize,
659        inner_ind: usize,
660        val: N,
661    ) {
662        let outer_dims = self.outer_dims();
663        let inner_ind_idx = I::from_usize(inner_ind);
664        if outer_ind >= outer_dims {
665            // we need to add a new outer dimension
666            let last_nnz = self.indptr.nnz_i();
667            self.indptr.resize(outer_ind + 1, last_nnz);
668            self.set_outer_dims(outer_ind + 1);
669            self.indptr.push(last_nnz + Iptr::one());
670            self.indices.push(inner_ind_idx);
671            self.data.push(val);
672        } else {
673            // we need to search for an insertion spot
674            let range = self.indptr.outer_inds_sz(outer_ind);
675            let location =
676                self.indices[range.clone()].binary_search(&inner_ind_idx);
677            match location {
678                Ok(ind) => {
679                    let ind = range.start + ind.index_unchecked();
680                    self.data[ind] = val;
681                    return;
682                }
683                Err(ind) => {
684                    let ind = range.start + ind.index_unchecked();
685                    self.indices.insert(ind, inner_ind_idx);
686                    self.data.insert(ind, val);
687                    self.indptr.record_new_element(outer_ind);
688                }
689            }
690        }
691
692        if inner_ind >= self.inner_dims() {
693            self.set_inner_dims(inner_ind + 1);
694        }
695    }
696
697    fn set_outer_dims(&mut self, outer_dims: usize) {
698        match self.storage() {
699            CSR => self.nrows = outer_dims,
700            CSC => self.ncols = outer_dims,
701        }
702    }
703
704    fn set_inner_dims(&mut self, inner_dims: usize) {
705        match self.storage() {
706            CSR => self.ncols = inner_dims,
707            CSC => self.nrows = inner_dims,
708        }
709    }
710}
711
/// Iterator adaptor passing `(index, value)` pairs through unchanged while
/// asserting that the indices are strictly increasing.
pub(crate) struct AssertOrderedIterator<Iter> {
    /// Index of the last yielded pair, `None` before the first one
    prev: Option<usize>,
    /// The wrapped iterator
    iter: Iter,
}

impl<N, Iter: Iterator<Item = (usize, N)>> Iterator
    for AssertOrderedIterator<Iter>
{
    type Item = (usize, N);

    /// Yield the next pair of the wrapped iterator.
    ///
    /// # Panics
    ///
    /// Panics if the index is not strictly greater than the previous one.
    fn next(&mut self) -> Option<Self::Item> {
        let prev = &mut self.prev;
        self.iter.next().map(|(idx, n)| {
            if let Some(prev_idx) = *prev {
                assert!(
                    prev_idx < idx,
                    "index out of order. {} followed {}",
                    idx,
                    prev_idx
                );
            }
            *prev = Some(idx);
            (idx, n)
        })
    }

    /// Forward the size hint of the wrapped iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
741
742/// # Constructor methods for sparse matrix views
743///
744/// These constructors can be used to create views over non-matrix data
745/// such as slices.
746impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
747    CsMatViewI<'a, N, I, Iptr>
748{
    /// Get a view into count contiguous outer dimensions, starting from i.
    ///
    /// eg this gets the rows from i to i + count in a CSR matrix
    ///
    /// This function is now deprecated, as using an index and a count is not
    /// ergonomic. The replacement, `slice_outer`, leverages the
    /// `std::ops::Range` family of types, which is better integrated into the
    /// ecosystem.
    ///
    /// # Panics
    ///
    /// Panics if `i + count` overflows `usize`.
    #[deprecated(
        since = "0.10.0",
        note = "Please use the `slice_outer` method instead"
    )]
    pub fn middle_outer_views(
        &self,
        i: usize,
        count: usize,
    ) -> CsMatViewI<'a, N, I, Iptr> {
        let iend = i.checked_add(count).unwrap();
        // The outer dimension shrinks to `count`, the inner one is unchanged
        let (nrows, ncols) = match self.storage {
            CSR => (count, self.cols()),
            CSC => (self.rows(), count),
        };
        // Range in `indices`/`data` covered by the selected outer dimensions
        let data_range = self.indptr.outer_inds_slice(i, iend);
        CsMatViewI {
            storage: self.storage,
            nrows,
            ncols,
            indptr: self.indptr.middle_slice_rbr(i..iend),
            indices: &self.indices[data_range.clone()],
            data: &self.data[data_range],
        }
    }
781
    /// Get an iterator that yields the non-zero locations and values stored in
    /// this matrix, in the fastest iteration order.
    ///
    /// This method will yield the correct lifetime for iterating over a sparse
    /// matrix view.
    ///
    /// Each item has the form `(&value, (row, col))`.
    pub fn iter_rbr(&self) -> CsIter<'a, N, I, Iptr> {
        CsIter {
            storage: self.storage,
            // Iteration starts at the first outer dimension
            cur_outer: I::zero(),
            indptr: self.indptr.reborrow(),
            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
        }
    }
795}
796
797/// # Common methods for all variants of compressed sparse matrices.
798impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
799    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
800where
801    I: SpIndex,
802    Iptr: SpIndex,
803    IptrStorage: Deref<Target = [Iptr]>,
804    IndStorage: Deref<Target = [I]>,
805    DataStorage: Deref<Target = [N]>,
806{
    /// The underlying storage of this matrix, either `CSR` or `CSC`
    pub fn storage(&self) -> CompressedStorage {
        self.storage
    }
811
    /// The number of rows of this matrix
    ///
    /// See also `cols()` and `shape()`.
    pub fn rows(&self) -> usize {
        self.nrows
    }
816
    /// The number of columns of this matrix
    ///
    /// See also `rows()` and `shape()`.
    pub fn cols(&self) -> usize {
        self.ncols
    }
821
    /// The shape of the matrix, as `(rows, cols)`.
    /// Equivalent to `let shape = (a.rows(), a.cols())`.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }
827
    /// The number of non-zero elements this matrix stores.
    /// This is often relevant for the complexity of most sparse matrix
    /// algorithms, which are often linear in the number of non-zeros.
    ///
    /// The count is obtained from the index pointers.
    pub fn nnz(&self) -> usize {
        self.indptr.nnz()
    }
834
835    /// The density of the sparse matrix, defined as the number of non-zero
836    /// elements divided by the maximum number of elements
837    pub fn density(&self) -> f64 {
838        let rows = self.nrows as f64;
839        let cols = self.ncols as f64;
840        let nnz = self.nnz() as f64;
841        nnz / (rows * cols)
842    }
843
    /// Number of outer dimensions, that is, `self.rows()` for a CSR
    /// matrix, and `self.cols()` for a CSC matrix
    pub fn outer_dims(&self) -> usize {
        outer_dimension(self.storage, self.nrows, self.ncols)
    }
849
850    /// Number of inner dimensions, that ie equal to `self.cols()` for a CSR
851    /// matrix, and equal to `self.rows()` for a CSC matrix
852    pub fn inner_dims(&self) -> usize {
853        match self.storage {
854            CSC => self.nrows,
855            CSR => self.ncols,
856        }
857    }
858
    /// Access the element located at row i and column j.
    /// Will return None if there is no non-zero element at this location.
    ///
    /// This access is logarithmic in the number of non-zeros
    /// in the corresponding outer slice. It is therefore advisable not to rely
    /// on this for algorithms, and prefer [`outer_iterator`](Self::outer_iterator)
    /// which accesses elements in storage order.
    pub fn get(&self, i: usize, j: usize) -> Option<&N> {
        // Translate (row, col) into (outer, inner) for the storage in use
        match self.storage {
            CSR => self.get_outer_inner(i, j),
            CSC => self.get_outer_inner(j, i),
        }
    }
872
    /// The array of offsets in the `indices()` and `data()` slices.
    /// The elements of the slice at outer dimension i
    /// are available between the elements `indptr[i]` and `indptr[i+1]`
    /// in the `indices()` and `data()` slices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use sprs::{CsMat};
    /// let eye : CsMat<f64> = CsMat::eye(5);
    /// // get the element of row 3
    /// // there is only one element in this row, with a column index of 3
    /// // and a value of 1.
    /// let range = eye.indptr().outer_inds_sz(3);
    /// assert_eq!(range.start, 3);
    /// assert_eq!(range.end, 4);
    /// assert_eq!(eye.indices()[range.start], 3);
    /// assert_eq!(eye.data()[range.start], 1.);
    /// ```
    pub fn indptr(&self) -> crate::IndPtrView<Iptr> {
        // Wrap the raw index-pointer storage in a view
        crate::IndPtrView::new_trusted(self.indptr.raw_storage())
    }
895
    /// Get an indptr representation suitable for ffi, cloning if necessary to
    /// get a compatible representation.
    ///
    /// The result is a `Cow`, as produced by `to_proper` on the index
    /// pointers.
    ///
    /// # Warning
    ///
    /// For ffi usage, one needs to call `Cow::as_ptr`, but it's important
    /// to keep the `Cow` alive during the lifetime of the pointer. Example
    /// of a correct and incorrect ffi usage:
    ///
    /// ```rust
    /// let mat: sprs::CsMat<f64> = sprs::CsMat::eye(5);
    /// let mid = mat.view().middle_outer_views(1, 2);
    /// let ptr = {
    ///     let indptr_proper = mid.proper_indptr();
    ///     println!(
    ///         "ptr {:?} is valid as long as _indptr_proper_owned is in scope",
    ///         indptr_proper.as_ptr()
    ///     );
    ///     indptr_proper.as_ptr()
    /// };
    /// // This line is UB.
    /// // println!("ptr deref: {}", *ptr);
    /// ```
    pub fn proper_indptr(&self) -> std::borrow::Cow<[Iptr]> {
        self.indptr.to_proper()
    }
922
    /// The inner dimension location for each non-zero value, as a slice
    /// over the whole `indices` storage. See the documentation of
    /// `indptr()` for more explanations.
    pub fn indices(&self) -> &[I] {
        &self.indices[..]
    }
928
929    /// The non-zero values. See the documentation of `indptr()`
930    /// for more explanations.
931    pub fn data(&self) -> &[N] {
932        &self.data[..]
933    }
934
935    /// Destruct the matrix object and recycle its storage containers.
936    ///
937    /// # Example
938    ///
939    /// ```rust
940    /// use sprs::{CsMat};
941    /// let (indptr, indices, data) = CsMat::<i32>::eye(3).into_raw_storage();
942    /// assert_eq!(indptr, vec![0, 1, 2, 3]);
943    /// assert_eq!(indices, vec![0, 1, 2]);
944    /// assert_eq!(data, vec![1, 1, 1]);
945    /// ```
946    pub fn into_raw_storage(self) -> (IptrStorage, IndStorage, DataStorage) {
947        let Self {
948            indptr,
949            indices,
950            data,
951            ..
952        } = self;
953        (indptr.into_raw_storage(), indices, data)
954    }
955
956    /// Test whether the matrix is in CSC storage
957    pub fn is_csc(&self) -> bool {
958        self.storage == CSC
959    }
960
961    /// Test whether the matrix is in CSR storage
962    pub fn is_csr(&self) -> bool {
963        self.storage == CSR
964    }
965
966    /// Transpose a matrix in place
967    /// No allocation required (this is simply a storage order change)
968    pub fn transpose_mut(&mut self) {
969        mem::swap(&mut self.nrows, &mut self.ncols);
970        self.storage = self.storage.other_storage();
971    }
972
973    /// Transpose a matrix in place
974    /// No allocation required (this is simply a storage order change)
975    pub fn transpose_into(mut self) -> Self {
976        self.transpose_mut();
977        self
978    }
979
980    /// Transposed view of this matrix
981    /// No allocation required (this is simply a storage order change)
982    pub fn transpose_view(&self) -> CsMatViewI<N, I, Iptr> {
983        CsMatViewI {
984            storage: self.storage.other_storage(),
985            nrows: self.ncols,
986            ncols: self.nrows,
987            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
988            indices: &self.indices[..],
989            data: &self.data[..],
990        }
991    }
992
993    /// Get an owned version of this matrix. If the matrix was already
994    /// owned, this will make a deep copy.
995    pub fn to_owned(&self) -> CsMatI<N, I, Iptr>
996    where
997        N: Clone,
998    {
999        CsMatI {
1000            storage: self.storage,
1001            nrows: self.nrows,
1002            ncols: self.ncols,
1003            indptr: self.indptr.to_owned(),
1004            indices: self.indices.to_vec(),
1005            data: self.data.to_vec(),
1006        }
1007    }
1008
1009    /// Generate a one-hot matrix, compressing the inner dimension.
1010    ///
1011    /// Returns a matrix with the same size, the same CSR/CSC type,
1012    /// and a single value of 1.0 within each populated inner vector.
1013    ///
1014    /// See [`into_csc`](CsMatBase::into_csc) and [`into_csr`](CsMatBase::into_csr)
1015    /// if you need to prepare a matrix
1016    /// for one-hot compression.
1017    pub fn to_inner_onehot(&self) -> CsMatI<N, I, Iptr>
1018    where
1019        N: Clone + Float + PartialOrd,
1020    {
1021        let mut indptr_counter = 0_usize;
1022        let mut indptr: Vec<Iptr> = Vec::with_capacity(self.indptr.len());
1023
1024        let max_data_len = self.indptr.len().min(self.data.len());
1025        let mut indices: Vec<I> = Vec::with_capacity(max_data_len);
1026        let mut data = Vec::with_capacity(max_data_len);
1027
1028        for inner_vec in self.outer_iterator() {
1029            let hot_element = inner_vec
1030                .iter()
1031                .filter(|e| !e.1.is_nan())
1032                .max_by(|a, b| {
1033                    a.1.partial_cmp(b.1)
1034                        .expect("Unexpected NaN value was found")
1035                })
1036                .map(|a| a.0);
1037
1038            indptr.push(Iptr::from_usize(indptr_counter));
1039
1040            if let Some(inner_id) = hot_element {
1041                indices.push(I::from_usize(inner_id));
1042                data.push(N::one());
1043                indptr_counter += 1;
1044            }
1045        }
1046
1047        indptr.push(Iptr::from_usize(indptr_counter));
1048        CsMatI {
1049            storage: self.storage,
1050            nrows: self.rows(),
1051            ncols: self.cols(),
1052            indptr: crate::IndPtr::new_trusted(indptr),
1053            indices,
1054            data,
1055        }
1056    }
1057
1058    /// Clone the matrix with another integer type for indptr and indices
1059    ///
1060    /// # Panics
1061    ///
1062    /// If the indices or indptr values cannot be represented by the requested
1063    /// integer type.
1064    pub fn to_other_types<I2, N2, Iptr2>(&self) -> CsMatI<N2, I2, Iptr2>
1065    where
1066        N: Clone + Into<N2>,
1067        I2: SpIndex,
1068        Iptr2: SpIndex,
1069    {
1070        let indptr = crate::IndPtr::new_trusted(
1071            self.indptr
1072                .raw_storage()
1073                .iter()
1074                .map(|i| Iptr2::from_usize(i.index_unchecked()))
1075                .collect(),
1076        );
1077        let indices = self
1078            .indices
1079            .iter()
1080            .map(|i| I2::from_usize(i.index_unchecked()))
1081            .collect();
1082        let data = self.data.iter().map(|x| x.clone().into()).collect();
1083        CsMatI {
1084            storage: self.storage,
1085            nrows: self.nrows,
1086            ncols: self.ncols,
1087            indptr,
1088            indices,
1089            data,
1090        }
1091    }
1092
1093    /// Return a view into the current matrix
1094    pub fn view(&self) -> CsMatViewI<N, I, Iptr> {
1095        CsMatViewI {
1096            storage: self.storage,
1097            nrows: self.nrows,
1098            ncols: self.ncols,
1099            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1100            indices: &self.indices[..],
1101            data: &self.data[..],
1102        }
1103    }
1104
    /// Return a view of the structure of this matrix only.
    ///
    /// The view has the same storage order, shape, indptr and indices as
    /// this matrix; its data is a slice of `()` with the same length as
    /// `data()`, so no value can be read through it.
    pub fn structure_view(&self) -> CsStructureViewI<I, Iptr> {
        // Safety: std::slice::from_raw_parts requires its passed
        // pointer to be valid for the whole length of the slice. The element
        // type `()` is zero-sized, so the slice spans zero bytes whatever its
        // length, and since we cast a non-null pointer, the pointer is valid
        // as all pointers to zero-sized types are valid if they are not null.
        let zst_data = unsafe {
            std::slice::from_raw_parts(
                self.data.as_ptr().cast::<()>(),
                self.data.len(),
            )
        };
        CsStructureViewI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: zst_data,
        }
    }
1126
1127    pub fn to_dense(&self) -> Array<N, Ix2>
1128    where
1129        N: Clone + Zero,
1130    {
1131        let mut res = Array::zeros((self.rows(), self.cols()));
1132        assign_to_dense(res.view_mut(), self.view());
1133        res
1134    }
1135
1136    /// Return an outer iterator for the matrix
1137    ///
1138    /// This can be used for iterating over the rows (resp. cols) of
1139    /// a CSR (resp. CSC) matrix.
1140    ///
1141    /// ```rust
1142    /// use sprs::{CsMat};
1143    /// let eye = CsMat::eye(5);
1144    /// for (row_ind, row_vec) in eye.outer_iterator().enumerate() {
1145    ///     let (col_ind, &val): (_, &f64) = row_vec.iter().next().unwrap();
1146    ///     assert_eq!(row_ind, col_ind);
1147    ///     assert_eq!(val, 1.);
1148    /// }
1149    /// ```
1150    pub fn outer_iterator(
1151        &self,
1152    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewI<N, I>>
1153           + std::iter::ExactSizeIterator<Item = CsVecViewI<N, I>>
1154           + '_ {
1155        self.indptr.iter_outer_sz().map(move |range| {
1156            CsVecViewI::new_trusted(
1157                self.inner_dims(),
1158                // TODO: unsafe slice indexing
1159                &self.indices[range.clone()],
1160                &self.data[range],
1161            )
1162        })
1163    }
1164
1165    /// Return an outer iterator over P*A*P^T, where it is necessary to use
1166    /// `CsVec::iter_perm(perm.inv())` to iterate over the inner dimension.
1167    /// Unstable, this is a convenience function for the crate `sprs-ldl`
1168    /// for now.
1169    #[doc(hidden)]
1170    pub fn outer_iterator_papt<'a, 'perm: 'a>(
1171        &'a self,
1172        perm: PermViewI<'perm, I>,
1173    ) -> impl std::iter::DoubleEndedIterator<Item = (usize, CsVecViewI<'a, N, I>)>
1174           + std::iter::ExactSizeIterator<Item = (usize, CsVecViewI<'a, N, I>)>
1175           + 'a {
1176        (0..self.outer_dims()).map(move |outer_ind| {
1177            let outer_ind_perm = perm.at(outer_ind);
1178            let range = self.indptr.outer_inds_sz(outer_ind_perm);
1179            let indices = &self.indices[range.clone()];
1180            let data = &self.data[range];
1181            // CsMat invariants imply CsVec invariants
1182            let vec = CsVecBase::new_trusted(self.inner_dims(), indices, data);
1183            (outer_ind_perm, vec)
1184        })
1185    }
1186
1187    /// Get the max number of nnz for each outer dim
1188    pub fn max_outer_nnz(&self) -> usize {
1189        self.outer_iterator()
1190            .map(|outer| outer.indices().len())
1191            .max()
1192            .unwrap_or(0)
1193    }
1194
1195    /// Get the degrees of each vertex on a symmetric matrix
1196    ///
1197    /// The nonzero pattern of a symmetric matrix can be interpreted as
1198    /// an undirected graph. In such a graph, a vertex i is connected to another
1199    /// vertex j if there is a corresponding nonzero entry in the matrix at
1200    /// location (i, j).
1201    ///
1202    /// This function returns a vector containing the degree of each vertex,
1203    /// that is to say the number of neighbor of each vertex. We do not
1204    /// count diagonal entries as a neighbor.
1205    pub fn degrees(&self) -> Vec<usize> {
1206        self.outer_iterator()
1207            .enumerate()
1208            .map(|(outer_dim, outer)| {
1209                outer
1210                    .indices()
1211                    .iter()
1212                    .filter(|ind| ind.index() != outer_dim)
1213                    .count()
1214            })
1215            .collect()
1216    }
1217
1218    /// Get a view into the i-th outer dimension (eg i-th row for a CSR matrix)
1219    pub fn outer_view(&self, i: usize) -> Option<CsVecViewI<N, I>> {
1220        if i >= self.outer_dims() {
1221            return None;
1222        }
1223        let range = self.indptr.outer_inds_sz(i);
1224        // CsMat invariants imply CsVec invariants
1225        Some(CsVecViewI::new_trusted(
1226            self.inner_dims(),
1227            // TODO: unsafe slice indexing
1228            &self.indices[range.clone()],
1229            &self.data[range],
1230        ))
1231    }
1232
1233    /// Get the diagonal of a sparse matrix
1234    pub fn diag(&self) -> CsVecI<N, I>
1235    where
1236        N: Clone,
1237    {
1238        let shape = self.shape();
1239        let smallest_dim: usize = cmp::min(shape.0, shape.1);
1240        // Assuming most matrices have dense diagonals, it seems prudent
1241        // to allocate a bit of memory up front
1242        let heuristic = smallest_dim / 2;
1243        let mut index_vec = Vec::with_capacity(heuristic);
1244        let mut data_vec = Vec::with_capacity(heuristic);
1245
1246        for i in 0..smallest_dim {
1247            let optional_index = self.nnz_index(i, i);
1248            if let Some(idx) = optional_index {
1249                data_vec.push(self[idx].clone());
1250                index_vec.push(I::from_usize(i));
1251            }
1252        }
1253        data_vec.shrink_to_fit();
1254        index_vec.shrink_to_fit();
1255        CsVecI::new_trusted(smallest_dim, index_vec, data_vec)
1256    }
1257
1258    /// Iteration over all entries on the diagonal
1259    pub fn diag_iter(
1260        &self,
1261    ) -> impl ExactSizeIterator<Item = Option<&N>>
1262           + DoubleEndedIterator<Item = Option<&N>> {
1263        let smallest_dim = cmp::min(self.ncols, self.nrows);
1264        (0..smallest_dim).map(move |i| self.get_outer_inner(i, i))
1265    }
1266
1267    /// Iteration on outer blocks of size `block_size`
1268    ///
1269    /// # Panics
1270    ///
1271    /// If the block size is 0.
1272    pub fn outer_block_iter(
1273        &self,
1274        block_size: usize,
1275    ) -> impl std::iter::DoubleEndedIterator<Item = CsMatViewI<N, I, Iptr>>
1276           + std::iter::ExactSizeIterator<Item = CsMatViewI<N, I, Iptr>>
1277           + '_ {
1278        (0..self.outer_dims()).step_by(block_size).map(move |i| {
1279            let count = if i + block_size > self.outer_dims() {
1280                self.outer_dims() - i
1281            } else {
1282                block_size
1283            };
1284            self.view().slice_outer_rbr(i..i + count)
1285        })
1286    }
1287
1288    /// Return a new sparse matrix with the same sparsity pattern, with all non-zero values mapped by the function `f`.
1289    pub fn map<F, N2>(&self, f: F) -> CsMatI<N2, I, Iptr>
1290    where
1291        F: FnMut(&N) -> N2,
1292    {
1293        let data: Vec<N2> = self.data.iter().map(f).collect();
1294
1295        CsMatI {
1296            storage: self.storage,
1297            nrows: self.nrows,
1298            ncols: self.ncols,
1299            indptr: self.indptr.to_owned(),
1300            indices: self.indices.to_vec(),
1301            data,
1302        }
1303    }
1304
1305    /// Access an element given its `outer_ind` and `inner_ind`.
1306    /// Will return None if there is no non-zero element at this location.
1307    ///
1308    /// This access is logarithmic in the number of non-zeros
1309    /// in the corresponding outer slice. It is therefore advisable not to rely
1310    /// on this for algorithms, and prefer [`outer_iterator`](Self::outer_iterator)
1311    /// which accesses elements in storage order.
1312    pub fn get_outer_inner(
1313        &self,
1314        outer_ind: usize,
1315        inner_ind: usize,
1316    ) -> Option<&N> {
1317        self.outer_view(outer_ind)
1318            .and_then(|vec| vec.get_rbr(inner_ind))
1319    }
1320
1321    /// Find the non-zero index of the element specified by row and col
1322    ///
1323    /// Searching this index is logarithmic in the number of non-zeros
1324    /// in the corresponding outer slice.
1325    /// Once it is available, the `NnzIndex` enables retrieving the data with
1326    /// O(1) complexity.
1327    pub fn nnz_index(&self, row: usize, col: usize) -> Option<NnzIndex> {
1328        match self.storage() {
1329            CSR => self.nnz_index_outer_inner(row, col),
1330            CSC => self.nnz_index_outer_inner(col, row),
1331        }
1332    }
1333
1334    /// Find the non-zero index of the element specified by `outer_ind` and
1335    /// `inner_ind`.
1336    ///
1337    /// Searching this index is logarithmic in the number of non-zeros
1338    /// in the corresponding outer slice.
1339    pub fn nnz_index_outer_inner(
1340        &self,
1341        outer_ind: usize,
1342        inner_ind: usize,
1343    ) -> Option<NnzIndex> {
1344        if outer_ind >= self.outer_dims() {
1345            return None;
1346        }
1347        let offset = self.indptr.outer_inds_sz(outer_ind).start;
1348        self.outer_view(outer_ind)
1349            .and_then(|vec| vec.nnz_index(inner_ind))
1350            .map(|vec::NnzIndex(ind)| NnzIndex(ind + offset))
1351    }
1352
1353    /// Check the structure of `CsMat` components
1354    /// This will ensure that:
1355    /// * indptr is of length `outer_dim() + 1`
1356    /// * indices and data have the same length, `nnz == indptr[outer_dims()]`
1357    /// * indptr is sorted
1358    /// * indptr values do not exceed [`usize::MAX`](usize::MAX)`/ 2`, as that would mean
1359    ///   indices and indptr would take more space than the addressable memory
1360    /// * indices is sorted for each outer slice
1361    /// * indices are lower than `inner_dims()`
1362    pub fn check_compressed_structure(&self) -> Result<(), StructureError> {
1363        let inner = self.inner_dims();
1364        let outer = self.outer_dims();
1365
1366        if self.indices.len() != self.data.len() {
1367            return Err(StructureError::SizeMismatch(
1368                "Indices and data lengths do not match",
1369            ));
1370        }
1371
1372        utils::check_compressed_structure(
1373            inner,
1374            outer,
1375            self.indptr.raw_storage(),
1376            &self.indices,
1377        )
1378    }
1379
1380    /// Get an iterator that yields the non-zero locations and values stored in
1381    /// this matrix, in the fastest iteration order.
1382    pub fn iter(&self) -> CsIter<N, I, Iptr> {
1383        CsIter {
1384            storage: self.storage,
1385            cur_outer: I::zero(),
1386            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1387            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
1388        }
1389    }
1390}
1391
1392/// # Methods to convert between storage orders
1393impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1394    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1395where
1396    N: Default,
1397    I: SpIndex,
1398    Iptr: SpIndex,
1399    IptrStorage: Deref<Target = [Iptr]>,
1400    IndStorage: Deref<Target = [I]>,
1401    DataStorage: Deref<Target = [N]>,
1402{
1403    /// Create a matrix mathematically equal to this one, but with the
1404    /// opposed storage (a CSC matrix will be converted to CSR, and vice versa)
1405    pub fn to_other_storage(&self) -> CsMatI<N, I, Iptr>
1406    where
1407        N: Clone,
1408    {
1409        let mut indptr = vec![Iptr::zero(); self.inner_dims() + 1];
1410        let mut indices = vec![I::zero(); self.nnz()];
1411        let mut data = vec![N::default(); self.nnz()];
1412        raw::convert_mat_storage(
1413            self.view(),
1414            &mut indptr,
1415            &mut indices,
1416            &mut data,
1417        );
1418        CsMatI {
1419            storage: self.storage().other_storage(),
1420            nrows: self.nrows,
1421            ncols: self.ncols,
1422            indptr: crate::IndPtr::new_trusted(indptr),
1423            indices,
1424            data,
1425        }
1426    }
1427
1428    /// Create a new CSC matrix equivalent to this one.
1429    /// A new matrix will be created even if this matrix was already CSC.
1430    pub fn to_csc(&self) -> CsMatI<N, I, Iptr>
1431    where
1432        N: Clone,
1433    {
1434        match self.storage {
1435            CSR => self.to_other_storage(),
1436            CSC => self.to_owned(),
1437        }
1438    }
1439
1440    /// Create a new CSR matrix equivalent to this one.
1441    /// A new matrix will be created even if this matrix was already CSR.
1442    pub fn to_csr(&self) -> CsMatI<N, I, Iptr>
1443    where
1444        N: Clone,
1445    {
1446        match self.storage {
1447            CSR => self.to_owned(),
1448            CSC => self.to_other_storage(),
1449        }
1450    }
1451}
1452
1453impl<N, I, Iptr> CsMatI<N, I, Iptr>
1454where
1455    N: Default,
1456
1457    I: SpIndex,
1458    Iptr: SpIndex,
1459{
1460    /// Create a new CSC matrix equivalent to this one.
1461    /// If this matrix is CSR, it is converted to CSC
1462    /// If this matrix is CSC, it is returned by value
1463    pub fn into_csc(self) -> Self
1464    where
1465        N: Clone,
1466    {
1467        match self.storage {
1468            CSR => self.to_other_storage(),
1469            CSC => self,
1470        }
1471    }
1472
1473    /// Create a new CSR matrix equivalent to this one.
1474    /// If this matrix is CSC, it is converted to CSR
1475    /// If this matrix is CSR, it is returned by value
1476    pub fn into_csr(self) -> Self
1477    where
1478        N: Clone,
1479    {
1480        match self.storage {
1481            CSR => self,
1482            CSC => self.to_other_storage(),
1483        }
1484    }
1485}
1486
1487/// # Methods for sparse matrices holding mutable access to their values.
1488impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1489    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1490where
1491    I: SpIndex,
1492    Iptr: SpIndex,
1493    IptrStorage: Deref<Target = [Iptr]>,
1494    IndStorage: Deref<Target = [I]>,
1495    DataStorage: DerefMut<Target = [N]>,
1496{
1497    /// Mutable access to the non zero values
1498    ///
1499    /// This enables changing the values without changing the matrix's
1500    /// structure. To also change the matrix's structure,
1501    /// see [modify](fn.modify.html)
1502    pub fn data_mut(&mut self) -> &mut [N] {
1503        &mut self.data[..]
1504    }
1505
1506    /// Sparse matrix self-multiplication by a scalar
1507    pub fn scale(&mut self, val: N)
1508    where
1509        for<'r> N: MulAssign<&'r N>,
1510    {
1511        for data in self.data_mut() {
1512            *data *= &val;
1513        }
1514    }
1515
1516    /// Get a mutable view into the i-th outer dimension
1517    /// (eg i-th row for a CSR matrix)
1518    pub fn outer_view_mut(&mut self, i: usize) -> Option<CsVecViewMutI<N, I>> {
1519        if i >= self.outer_dims() {
1520            return None;
1521        }
1522        let range = self.indptr.outer_inds_sz(i);
1523        // CsMat invariants imply CsVec invariants
1524        Some(CsVecBase::new_trusted(
1525            self.inner_dims(),
1526            &self.indices[range.clone()],
1527            &mut self.data[range],
1528        ))
1529    }
1530
1531    /// Get a mutable reference to the element located at row i and column j.
1532    /// Will return None if there is no non-zero element at this location.
1533    ///
1534    /// This access is logarithmic in the number of non-zeros
1535    /// in the corresponding outer slice. It is therefore advisable not to rely
1536    /// on this for algorithms, and prefer [`outer_iterator_mut`](Self::outer_iterator_mut)
1537    /// which accesses elements in storage order.
1538    /// TODO: `outer_iterator_mut` is not yet implemented
1539    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut N> {
1540        match self.storage {
1541            CSR => self.get_outer_inner_mut(i, j),
1542            CSC => self.get_outer_inner_mut(j, i),
1543        }
1544    }
1545
1546    /// Get a mutable reference to an element given its `outer_ind` and `inner_ind`.
1547    /// Will return None if there is no non-zero element at this location.
1548    ///
1549    /// This access is logarithmic in the number of non-zeros
1550    /// in the corresponding outer slice. It is therefore advisable not to rely
1551    /// on this for algorithms, and prefer [`outer_iterator_mut`](Self::outer_iterator_mut)
1552    /// which accesses elements in storage order.
1553    pub fn get_outer_inner_mut(
1554        &mut self,
1555        outer_ind: usize,
1556        inner_ind: usize,
1557    ) -> Option<&mut N> {
1558        if let Some(NnzIndex(index)) =
1559            self.nnz_index_outer_inner(outer_ind, inner_ind)
1560        {
1561            Some(&mut self.data[index])
1562        } else {
1563            None
1564        }
1565    }
1566
1567    /// Set the value of the non-zero element located at (row, col)
1568    ///
1569    /// # Panics
1570    ///
1571    /// - on out-of-bounds access
1572    /// - if no non-zero element exists at the given location
1573    pub fn set(&mut self, row: usize, col: usize, val: N) {
1574        let outer = outer_dimension(self.storage(), row, col);
1575        let inner = inner_dimension(self.storage(), row, col);
1576        let vec::NnzIndex(index) = self
1577            .outer_view(outer)
1578            .and_then(|vec| vec.nnz_index(inner))
1579            .unwrap();
1580        self.data[index] = val;
1581    }
1582
1583    /// Apply a function to every non-zero element
1584    pub fn map_inplace<F>(&mut self, mut f: F)
1585    where
1586        F: FnMut(&N) -> N,
1587    {
1588        for val in &mut self.data[..] {
1589            *val = f(val);
1590        }
1591    }
1592
1593    /// Return a mutable outer iterator for the matrix
1594    ///
1595    /// This iterator yields mutable sparse vector views for each outer
1596    /// dimension. Only the non-zero values can be modified, the
1597    /// structure is kept immutable.
1598    pub fn outer_iterator_mut(
1599        &mut self,
1600    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewMutI<N, I>>
1601           + std::iter::ExactSizeIterator<Item = CsVecViewMutI<N, I>>
1602           + '_ {
1603        let inner_dim = self.inner_dims();
1604        let indices = &self.indices[..];
1605        let data_ptr: *mut N = self.data.as_mut_ptr();
1606        self.indptr.iter_outer_sz().map(move |range| {
1607            // # Safety
1608            // * ranges always point to exclusive parts of data
1609            // * lifetime bound to &mut self
1610            let data: &mut [N] = unsafe {
1611                std::slice::from_raw_parts_mut(
1612                    data_ptr.add(range.start),
1613                    range.end - range.start,
1614                )
1615            };
1616
1617            CsVecViewMutI::new_trusted(inner_dim, &indices[range], data)
1618        })
1619    }
1620
1621    /// Return a mutable view into the current matrix
1622    pub fn view_mut(&mut self) -> CsMatViewMutI<N, I, Iptr> {
1623        CsMatViewMutI {
1624            storage: self.storage,
1625            nrows: self.nrows,
1626            ncols: self.ncols,
1627            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
1628            indices: &self.indices[..],
1629            data: &mut self.data[..],
1630        }
1631    }
1632
1633    /// Iteration over all entries on the diagonal
1634    pub fn diag_iter_mut(
1635        &mut self,
1636    ) -> impl ExactSizeIterator<Item = Option<&mut N>>
1637           + DoubleEndedIterator<Item = Option<&mut N>>
1638           + '_ {
1639        let data_ptr: *mut N = self.data[..].as_mut_ptr();
1640        let smallest_dim = cmp::min(self.ncols, self.nrows);
1641        (0..smallest_dim).map(move |i| {
1642            let idx = self.nnz_index_outer_inner(i, i);
1643            if let Some(NnzIndex(idx)) = idx {
1644                // To obtain multiple mutable references to different
1645                // locations in data we must use a pointer and some unsafe.
1646                // # Safety
1647                // This is safe as
1648                // * NnzIndex provides bounds checking
1649                // * diagonal entries are never overlapping in memory
1650                // * no entries are requested more than once
1651                // * nnz_index_outer_inner does not modify or read from entries in self.data
1652                Some(unsafe { &mut *data_ptr.add(idx) })
1653            } else {
1654                None
1655            }
1656        })
1657    }
1658}
1659
1660impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
1661    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
1662where
1663    I: SpIndex,
1664    Iptr: SpIndex,
1665    IptrStorage: DerefMut<Target = [Iptr]>,
1666    IndStorage: DerefMut<Target = [I]>,
1667    DataStorage: DerefMut<Target = [N]>,
1668{
1669    /// Modify the matrix's structure without changing its nonzero count.
1670    ///
1671    /// The coherence of the structure will be checked afterwards.
1672    ///
1673    /// # Panics
1674    ///
1675    /// If the resulting matrix breaks the `CsMat` invariants
1676    /// (sorted indices, no out of bounds indices).
1677    ///
1678    /// # Example
1679    ///
1680    /// ```rust
1681    /// use sprs::CsMat;
1682    /// // |   1   |
1683    /// // | 1     |
1684    /// // |   1 1 |
1685    /// let mut mat = CsMat::new_csc((3, 3),
1686    ///                                   vec![0, 1, 3, 4],
1687    ///                                   vec![1, 0, 2, 2],
1688    ///                                   vec![1.; 4]);
1689    ///
1690    /// // | 1 2   |
1691    /// // | 1     |
1692    /// // |   1   |
1693    /// mat.modify(|indptr, indices, data| {
1694    ///     indptr[1] = 2;
1695    ///     indptr[2] = 4;
1696    ///     indices[0] = 0;
1697    ///     indices[1] = 1;
1698    ///     indices[2] = 0;
1699    ///     data[2] = 2.;
1700    /// });
1701    /// ```
1702    pub fn modify<F>(&mut self, mut f: F)
1703    where
1704        F: FnMut(&mut [Iptr], &mut [I], &mut [N]),
1705    {
1706        f(
1707            self.indptr.raw_storage_mut(),
1708            &mut self.indices[..],
1709            &mut self.data[..],
1710        );
1711        // This is safe as long as we do the check, if we panic
1712        // the structure can not be retrieved, as &mut self can not pass
1713        // safely across an unwind boundary
1714        self.check_compressed_structure().unwrap();
1715    }
1716}
1717
1718/// Raw functions acting directly on the compressed structure.
1719pub mod raw {
1720    use crate::indexing::SpIndex;
1721    use crate::sparse::prelude::*;
1722    use std::mem::swap;
1723
1724    /*
1725        /// Copy-convert a compressed matrix into the oppposite storage.
1726        ///
1727        /// The input compressed matrix does not need to have its indices sorted,
1728        /// but the output compressed matrix will have its indices sorted.
1729        ///
1730        /// Can be used to implement CSC <-> CSR conversions, or to implement
1731        /// same-storage (copy) transposition.
1732        ///
1733        /// # Panics
1734        ///
1735        /// Panics if indptr contains non-zero values
1736        ///
1737        /// Panics if the output slices don't match the input matrices'
1738        /// corresponding slices.
1739        pub fn convert_storage<N, I>(
1740            in_storage: super::CompressedStorage,
1741            shape: Shape,
1742            in_indtpr: &[I],
1743            in_indices: &[I],
1744            in_data: &[N],
1745            indptr: &mut [I],
1746            indices: &mut [I],
1747            data: &mut [N],
1748        ) where
1749            N: Clone,
1750            I: SpIndex,
1751        {
1752            // we're building a csmat even though the indices are not sorted,
1753            // but it's not a problem since we don't rely on this property.
1754            // FIXME: this would be better with an explicit unsorted matrix type
1755            let mat = CsMatBase {
1756                storage: in_storage,
1757                nrows: shape.0,
1758                ncols: shape.1,
1759                indptr: in_indtpr,
1760                indices: in_indices,
1761                data: in_data,
1762            };
1763
1764            convert_mat_storage(mat, indptr, indices, data);
1765        }
1766    */
1767
1768    /// Copy-convert a csmat into the oppposite storage.
1769    ///
1770    /// Can be used to implement CSC <-> CSR conversions, or to implement
1771    /// same-storage (copy) transposition.
1772    ///
1773    /// # Panics
1774    ///
1775    /// Panics if indptr contains non-zero values
1776    ///
1777    /// Panics if the output slices don't match the input matrices'
1778    /// corresponding slices.
1779    pub fn convert_mat_storage<N: Clone, I: SpIndex, Iptr: SpIndex>(
1780        mat: CsMatViewI<N, I, Iptr>,
1781        indptr: &mut [Iptr],
1782        indices: &mut [I],
1783        data: &mut [N],
1784    ) {
1785        assert_eq!(indptr.len(), mat.inner_dims() + 1);
1786        assert_eq!(indices.len(), mat.indices().len());
1787        assert_eq!(data.len(), mat.data().len());
1788
1789        assert!(indptr.iter().all(num_traits::Zero::is_zero));
1790
1791        assert!(
1792            I::try_from_usize(mat.rows()).is_some(),
1793            "Index type is not large enough to hold the number of rows requested (I::max_value={:?} vs. required {})", I::max_value(), mat.rows(),
1794        );
1795
1796        for vec in mat.outer_iterator() {
1797            for (inner_dim, _) in vec.iter() {
1798                indptr[inner_dim] += Iptr::one();
1799            }
1800        }
1801
1802        let mut cumsum = Iptr::zero();
1803        for iptr in indptr.iter_mut() {
1804            let tmp = *iptr;
1805            *iptr = cumsum;
1806            cumsum += tmp;
1807        }
1808        if let Some(last_iptr) = indptr.last() {
1809            assert_eq!(last_iptr.index(), mat.nnz());
1810        }
1811
1812        for (outer_dim, vec) in mat.outer_iterator().enumerate() {
1813            let outer_dim = I::from_usize_unchecked(outer_dim);
1814            for (inner_dim, val) in vec.iter() {
1815                let dest = indptr[inner_dim].index();
1816                data[dest] = val.clone();
1817                indices[dest] = outer_dim;
1818                indptr[inner_dim] += Iptr::one();
1819            }
1820        }
1821
1822        let mut last = Iptr::zero();
1823        for iptr in indptr.iter_mut() {
1824            swap(iptr, &mut last);
1825        }
1826    }
1827}
1828
1829impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::MulAssign<T>
1830    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1831where
1832    I: SpIndex,
1833    Iptr: SpIndex,
1834    IpStorage: Deref<Target = [Iptr]>,
1835    IStorage: Deref<Target = [I]>,
1836    DStorage: DerefMut<Target = [T]>,
1837    T: std::ops::MulAssign<T> + Clone,
1838{
1839    fn mul_assign(&mut self, rhs: T) {
1840        self.data_mut()
1841            .iter_mut()
1842            .for_each(|v| v.mul_assign(rhs.clone()));
1843    }
1844}
1845
1846impl<I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::DivAssign<T>
1847    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
1848where
1849    I: SpIndex,
1850    Iptr: SpIndex,
1851    IpStorage: Deref<Target = [Iptr]>,
1852    IStorage: Deref<Target = [I]>,
1853    DStorage: DerefMut<Target = [T]>,
1854    T: std::ops::DivAssign<T> + Clone,
1855{
1856    fn div_assign(&mut self, rhs: T) {
1857        self.data_mut()
1858            .iter_mut()
1859            .for_each(|v| v.div_assign(rhs.clone()));
1860    }
1861}
1862
1863impl<'a, 'b, N, I, Iptr, IpS1, IS1, DS1, IpS2, IS2, DS2>
1864    Mul<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
1865    for &'a CsMatBase<N, I, IpS1, IS1, DS1, Iptr>
1866where
1867    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Default + Send + Sync,
1868    I: 'a + SpIndex,
1869    Iptr: 'a + SpIndex,
1870    IpS1: 'a + Deref<Target = [Iptr]>,
1871    IS1: 'a + Deref<Target = [I]>,
1872    DS1: 'a + Deref<Target = [N]>,
1873    IpS2: 'b + Deref<Target = [Iptr]>,
1874    IS2: 'b + Deref<Target = [I]>,
1875    DS2: 'b + Deref<Target = [N]>,
1876{
1877    type Output = CsMatI<N, I, Iptr>;
1878
1879    fn mul(
1880        self,
1881        rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
1882    ) -> CsMatI<N, I, Iptr> {
1883        csmat_mul_csmat(self, rhs)
1884    }
1885}
1886
1887/// Multiply two sparse matrices.
1888///
1889/// This function is generic over `MulAcc`, and supports accumulating
1890/// into a different output type. This is not the default for `Mul`,
1891/// as type inference fails for intermediaries
1892pub fn csmat_mul_csmat<
1893    'a,
1894    'b,
1895    N,
1896    A,
1897    B,
1898    I,
1899    Iptr,
1900    IpS1,
1901    IS1,
1902    DS1,
1903    IpS2,
1904    IS2,
1905    DS2,
1906>(
1907    lhs: &'a CsMatBase<A, I, IpS1, IS1, DS1, Iptr>,
1908    rhs: &'b CsMatBase<B, I, IpS2, IS2, DS2, Iptr>,
1909) -> CsMatI<N, I, Iptr>
1910where
1911    N: 'a
1912        + Clone
1913        + crate::MulAcc<A, B>
1914        + crate::MulAcc<B, A>
1915        + num_traits::Zero
1916        + Default
1917        + Send
1918        + Sync,
1919    A: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1920    B: 'a + Clone + num_traits::Zero + Default + Send + Sync,
1921    I: 'a + SpIndex,
1922    Iptr: 'a + SpIndex,
1923    IpS1: 'a + Deref<Target = [Iptr]>,
1924    IS1: 'a + Deref<Target = [I]>,
1925    DS1: 'a + Deref<Target = [A]>,
1926    IpS2: 'b + Deref<Target = [Iptr]>,
1927    IS2: 'b + Deref<Target = [I]>,
1928    DS2: 'b + Deref<Target = [B]>,
1929{
1930    match (lhs.storage(), rhs.storage()) {
1931        (CSR, CSR) => smmp::mul_csr_csr(lhs.view(), rhs.view()),
1932        (CSR, CSC) => {
1933            let rhs_csr = rhs.to_other_storage();
1934            smmp::mul_csr_csr(lhs.view(), rhs_csr.view())
1935        }
1936        (CSC, CSR) => {
1937            let rhs_csc = rhs.to_other_storage();
1938            smmp::mul_csr_csr(rhs_csc.transpose_view(), lhs.transpose_view())
1939                .transpose_into()
1940        }
1941        (CSC, CSC) => {
1942            smmp::mul_csr_csr(rhs.transpose_view(), lhs.transpose_view())
1943                .transpose_into()
1944        }
1945    }
1946}
1947
1948impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Add<&'b ArrayBase<DS2, Ix2>>
1949    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1950where
1951    N: 'a + Copy + Num + Default,
1952    for<'r> &'r N: Mul<Output = N>,
1953    I: 'a + SpIndex,
1954    Iptr: 'a + SpIndex,
1955    IpS: 'a + Deref<Target = [Iptr]>,
1956    IS: 'a + Deref<Target = [I]>,
1957    DS: 'a + Deref<Target = [N]>,
1958    DS2: 'b + ndarray::Data<Elem = N>,
1959{
1960    type Output = Array<N, Ix2>;
1961
1962    fn add(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
1963        let is_standard_layout =
1964            utils::fastest_axis(rhs.view()) == ndarray::Axis(1);
1965        let neuter_element = N::one();
1966        match (self.storage(), is_standard_layout) {
1967            (CSR, true) | (CSC, false) => binop::add_dense_mat_same_ordering(
1968                self,
1969                rhs,
1970                neuter_element,
1971                neuter_element,
1972            ),
1973            (CSR, false) | (CSC, true) => {
1974                let lhs = self.to_other_storage();
1975                binop::add_dense_mat_same_ordering(
1976                    &lhs,
1977                    rhs,
1978                    neuter_element,
1979                    neuter_element,
1980                )
1981            }
1982        }
1983    }
1984}
1985
1986impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix2>>
1987    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
1988where
1989    N: 'a + crate::MulAcc + num_traits::Zero + Clone,
1990    I: 'a + SpIndex,
1991    Iptr: 'a + SpIndex,
1992    IpS: 'a + Deref<Target = [Iptr]>,
1993    IS: 'a + Deref<Target = [I]>,
1994    DS: 'a + Deref<Target = [N]>,
1995    DS2: 'b + ndarray::Data<Elem = N>,
1996{
1997    type Output = Array<N, Ix2>;
1998
1999    fn mul(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2000        let rows = self.rows();
2001        let cols = rhs.shape()[1];
2002        // when the number of colums is small, it is more efficient
2003        // to perform the product by iterating over the columns of
2004        // the rhs, otherwise iterating by rows can take advantage of
2005        // vectorized axpy.
2006        match (self.storage(), cols >= 8) {
2007            (CSR, true) => {
2008                let mut res = Array::zeros((rows, cols));
2009                prod::csr_mulacc_dense_rowmaj(
2010                    self.view(),
2011                    rhs.view(),
2012                    res.view_mut(),
2013                );
2014                res
2015            }
2016            (CSR, false) => {
2017                let mut res = Array::zeros((rows, cols).f());
2018                prod::csr_mulacc_dense_colmaj(
2019                    self.view(),
2020                    rhs.view(),
2021                    res.view_mut(),
2022                );
2023                res
2024            }
2025            (CSC, true) => {
2026                let mut res = Array::zeros((rows, cols));
2027                prod::csc_mulacc_dense_rowmaj(
2028                    self.view(),
2029                    rhs.view(),
2030                    res.view_mut(),
2031                );
2032                res
2033            }
2034            (CSC, false) => {
2035                let mut res = Array::zeros((rows, cols).f());
2036                prod::csc_mulacc_dense_colmaj(
2037                    self.view(),
2038                    rhs.view(),
2039                    res.view_mut(),
2040                );
2041                res
2042            }
2043        }
2044    }
2045}
2046
2047impl<N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
2048    for ArrayBase<DS2, Ix2>
2049where
2050    N: Clone + crate::MulAcc + num_traits::Zero + std::fmt::Debug,
2051    I: SpIndex,
2052    IpS: Deref<Target = [I]>,
2053    IS: Deref<Target = [I]>,
2054    DS: Deref<Target = [N]>,
2055    DS2: ndarray::Data<Elem = N>,
2056{
2057    type Output = Array<N, Ix2>;
2058
2059    fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
2060        let rhs_t = rhs.transpose_view();
2061        let lhs_t = self.t();
2062
2063        let rows = rhs_t.rows();
2064        let cols = lhs_t.ncols();
2065        // when the number of colums is small, it is more efficient
2066        // to perform the product by iterating over the columns of
2067        // the rhs, otherwise iterating by rows can take advantage of
2068        // vectorized axpy.
2069        let rres = match (rhs_t.storage(), cols >= 8) {
2070            (CSR, true) => {
2071                let mut res = Array::zeros((rows, cols));
2072                prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2073                res.reversed_axes()
2074            }
2075            (CSR, false) => {
2076                let mut res = Array::zeros((rows, cols).f());
2077                prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2078                res.reversed_axes()
2079            }
2080            (CSC, true) => {
2081                let mut res = Array::zeros((rows, cols));
2082                prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res.view_mut());
2083                res.reversed_axes()
2084            }
2085            (CSC, false) => {
2086                let mut res = Array::zeros((rows, cols).f());
2087                prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res.view_mut());
2088                res.reversed_axes()
2089            }
2090        };
2091
2092        assert_eq!(self.shape()[0], rres.shape()[0]);
2093        assert_eq!(rhs.cols(), rres.shape()[1]);
2094        rres
2095    }
2096}
2097
2098impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
2099    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2100where
2101    N: Clone + crate::MulAcc + num_traits::Zero,
2102    I: SpIndex,
2103    Iptr: SpIndex,
2104    IpS: Deref<Target = [Iptr]>,
2105    IS: Deref<Target = [I]>,
2106    DS: Deref<Target = [N]>,
2107    DS2: ndarray::Data<Elem = N>,
2108{
2109    type Output = Array<N, Ix2>;
2110
2111    fn dot(&self, rhs: &ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
2112        Mul::mul(self, rhs)
2113    }
2114}
2115
2116impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix1>>
2117    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2118where
2119    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
2120    I: 'a + SpIndex,
2121    Iptr: 'a + SpIndex,
2122    IpS: 'a + Deref<Target = [Iptr]>,
2123    IS: 'a + Deref<Target = [I]>,
2124    DS: 'a + Deref<Target = [N]>,
2125    DS2: 'b + ndarray::Data<Elem = N>,
2126{
2127    type Output = Array<N, Ix1>;
2128
2129    fn mul(self, rhs: &'b ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2130        let rows = self.rows();
2131        let cols = rhs.shape()[0];
2132        #[allow(deprecated)]
2133        let rhs_reshape = rhs.view().into_shape((cols, 1)).unwrap();
2134        let mut res = Array::zeros(rows);
2135        {
2136            #[allow(deprecated)]
2137            let res_reshape = res.view_mut().into_shape((rows, 1)).unwrap();
2138            match self.storage() {
2139                CSR => {
2140                    prod::csr_mulacc_dense_colmaj(
2141                        self.view(),
2142                        rhs_reshape,
2143                        res_reshape,
2144                    );
2145                }
2146                CSC => {
2147                    prod::csc_mulacc_dense_colmaj(
2148                        self.view(),
2149                        rhs_reshape,
2150                        res_reshape,
2151                    );
2152                }
2153            }
2154        }
2155        res
2156    }
2157}
2158
2159impl<N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix1>>
2160    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2161where
2162    N: Clone + crate::MulAcc + num_traits::Zero,
2163    I: SpIndex,
2164    Iptr: SpIndex,
2165    IpS: Deref<Target = [Iptr]>,
2166    IS: Deref<Target = [I]>,
2167    DS: Deref<Target = [N]>,
2168    DS2: ndarray::Data<Elem = N>,
2169{
2170    type Output = Array<N, Ix1>;
2171
2172    fn dot(&self, rhs: &ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
2173        Mul::mul(self, rhs)
2174    }
2175}
2176
2177impl<N, I, Iptr, IpS, IS, DS> Index<[usize; 2]>
2178    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2179where
2180    I: SpIndex,
2181    Iptr: SpIndex,
2182    IpS: Deref<Target = [Iptr]>,
2183    IS: Deref<Target = [I]>,
2184    DS: Deref<Target = [N]>,
2185{
2186    type Output = N;
2187
2188    fn index(&self, index: [usize; 2]) -> &N {
2189        let i = index[0];
2190        let j = index[1];
2191        self.get(i, j).unwrap()
2192    }
2193}
2194
2195impl<N, I, Iptr, IpS, IS, DS> IndexMut<[usize; 2]>
2196    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2197where
2198    I: SpIndex,
2199    Iptr: SpIndex,
2200    IpS: Deref<Target = [Iptr]>,
2201    IS: Deref<Target = [I]>,
2202    DS: DerefMut<Target = [N]>,
2203{
2204    fn index_mut(&mut self, index: [usize; 2]) -> &mut N {
2205        let i = index[0];
2206        let j = index[1];
2207        self.get_mut(i, j).unwrap()
2208    }
2209}
2210
2211impl<N, I, Iptr, IpS, IS, DS> Index<NnzIndex>
2212    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2213where
2214    I: SpIndex,
2215    Iptr: SpIndex,
2216    IpS: Deref<Target = [Iptr]>,
2217    IS: Deref<Target = [I]>,
2218    DS: Deref<Target = [N]>,
2219{
2220    type Output = N;
2221
2222    fn index(&self, index: NnzIndex) -> &N {
2223        let NnzIndex(i) = index;
2224        self.data().get(i).unwrap()
2225    }
2226}
2227
2228impl<N, I, Iptr, IpS, IS, DS> IndexMut<NnzIndex>
2229    for CsMatBase<N, I, IpS, IS, DS, Iptr>
2230where
2231    I: SpIndex,
2232    Iptr: SpIndex,
2233    IpS: Deref<Target = [Iptr]>,
2234    IS: Deref<Target = [I]>,
2235    DS: DerefMut<Target = [N]>,
2236{
2237    fn index_mut(&mut self, index: NnzIndex) -> &mut N {
2238        let NnzIndex(i) = index;
2239        self.data_mut().get_mut(i).unwrap()
2240    }
2241}
2242
/// `SparseMat` implementation for compressed matrices of any storage.
///
/// NOTE(review): each method calls the method of the same name on `self`.
/// This presumably resolves to inherent `rows`/`cols`/`nnz` methods of
/// `CsMatBase` defined elsewhere in this file (inherent methods take
/// precedence over trait methods); without them these calls would recurse.
/// Keep the call syntax as is when editing.
impl<N, I, Iptr, IpS, IS, DS> SparseMat for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    /// Number of rows, as reported by the matrix itself.
    fn rows(&self) -> usize {
        self.rows()
    }

    /// Number of columns, as reported by the matrix itself.
    fn cols(&self) -> usize {
        self.cols()
    }

    /// Number of stored values, as reported by the matrix itself.
    fn nnz(&self) -> usize {
        self.nnz()
    }
}
2263
/// `SparseMat` implementation for references to compressed matrices.
///
/// NOTE(review): `self` is a reference to a reference here; each method
/// dereferences once and calls the method of the same name on the
/// referenced matrix. This presumably resolves to the inherent methods of
/// `CsMatBase` -- keep the explicit `(*self)` when editing.
impl<'a, N, I, Iptr, IpS, IS, DS> SparseMat
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    /// Number of rows of the referenced matrix.
    fn rows(&self) -> usize {
        (*self).rows()
    }

    /// Number of columns of the referenced matrix.
    fn cols(&self) -> usize {
        (*self).cols()
    }

    /// Number of stored values of the referenced matrix.
    fn nnz(&self) -> usize {
        (*self).nnz()
    }
}
2286
2287impl<'a, N, I, IpS, IS, DS, Iptr> IntoIterator
2288    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
2289where
2290    I: 'a + SpIndex,
2291    Iptr: 'a + SpIndex,
2292    N: 'a,
2293    IpS: Deref<Target = [Iptr]>,
2294    IS: Deref<Target = [I]>,
2295    DS: Deref<Target = [N]>,
2296{
2297    type Item = (&'a N, (I, I));
2298    type IntoIter = CsIter<'a, N, I, Iptr>;
2299    fn into_iter(self) -> Self::IntoIter {
2300        self.iter()
2301    }
2302}
2303
2304impl<'a, N, I, Iptr> IntoIterator for CsMatViewI<'a, N, I, Iptr>
2305where
2306    I: 'a + SpIndex,
2307    Iptr: 'a + SpIndex,
2308    N: 'a,
2309{
2310    type Item = (&'a N, (I, I));
2311    type IntoIter = CsIter<'a, N, I, Iptr>;
2312    fn into_iter(self) -> Self::IntoIter {
2313        self.iter_rbr()
2314    }
2315}
2316
2317#[cfg(test)]
2318mod test {
2319    use super::CompressedStorage::CSR;
2320    use crate::errors::StructureErrorKind;
2321    use crate::sparse::{CsMat, CsMatI, CsMatView, CsVec};
2322    use crate::test_data::{mat1, mat1_csc, mat1_times_2};
2323    use ndarray::{arr2, Array};
2324
2325    #[test]
2326    fn test_copy() {
2327        let m = mat1();
2328        let view1 = m.view();
2329        let view2 = view1; // this shouldn't move
2330        assert_eq!(view1, view2);
2331    }
2332
2333    #[test]
2334    fn test_new_csr_success() {
2335        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2336        let indices_ok: &[usize] = &[0, 1, 2];
2337        let data_ok: &[f64] = &[1., 1., 1.];
2338        let m = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_ok);
2339        assert!(m.is_ok());
2340    }
2341
2342    #[test]
2343    #[should_panic]
2344    fn test_new_csr_bad_indptr_length() {
2345        let indptr_fail1: &[usize] = &[0, 1, 2];
2346        let indices_ok: &[usize] = &[0, 1, 2];
2347        let data_ok: &[f64] = &[1., 1., 1.];
2348        let res = CsMatView::try_new((3, 3), indptr_fail1, indices_ok, data_ok);
2349        res.unwrap(); // unreachable
2350    }
2351
2352    #[test]
2353    #[should_panic]
2354    fn test_new_csr_out_of_bounds_index() {
2355        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2356        let data_ok: &[f64] = &[1., 1., 1.];
2357        let indices_fail2: &[usize] = &[0, 1, 4];
2358        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail2, data_ok);
2359        res.unwrap(); //unreachable
2360    }
2361
2362    #[test]
2363    #[should_panic]
2364    fn test_new_csr_bad_nnz_count() {
2365        let indices_ok: &[usize] = &[0, 1, 2];
2366        let data_ok: &[f64] = &[1., 1., 1.];
2367        let indptr_fail2: &[usize] = &[0, 1, 2, 4];
2368        let res = CsMatView::try_new((3, 3), indptr_fail2, indices_ok, data_ok);
2369        res.unwrap(); //unreachable
2370    }
2371
2372    #[test]
2373    #[should_panic]
2374    fn test_new_csr_data_indices_mismatch1() {
2375        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2376        let data_ok: &[f64] = &[1., 1., 1.];
2377        let indices_fail1: &[usize] = &[0, 1];
2378        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail1, data_ok);
2379        res.unwrap(); //unreachable
2380    }
2381
2382    #[test]
2383    #[should_panic]
2384    fn test_new_csr_data_indices_mismatch2() {
2385        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2386        let indices_ok: &[usize] = &[0, 1, 2];
2387        let data_fail1: &[f64] = &[1., 1., 1., 1.];
2388        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail1);
2389        res.unwrap(); //unreachable
2390    }
2391
2392    #[test]
2393    #[should_panic]
2394    fn test_new_csr_data_indices_mismatch3() {
2395        let indptr_ok: &[usize] = &[0, 1, 2, 3];
2396        let indices_ok: &[usize] = &[0, 1, 2];
2397        let data_fail2: &[f64] = &[1., 1.];
2398        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail2);
2399        res.unwrap(); //unreachable
2400    }
2401
2402    #[test]
2403    fn test_new_csr_fails() {
2404        let indices_ok: &[usize] = &[0, 1, 2];
2405        let data_ok: &[f64] = &[1., 1., 1.];
2406        let indptr_fail3: &[usize] = &[0, 2, 1, 3];
2407        assert_eq!(
2408            CsMatView::try_new((3, 3), indptr_fail3, indices_ok, data_ok)
2409                .unwrap_err()
2410                .3
2411                .kind(),
2412            StructureErrorKind::Unsorted
2413        );
2414    }
2415
2416    #[test]
2417    fn test_new_csr_fail_indices_ordering() {
2418        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
2419        // good indices would be [2, 3, 3, 4, 2, 1, 3];
2420        let indices: &[usize] = &[3, 2, 3, 4, 2, 1, 3];
2421        let data: &[f64] = &[
2422            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
2423            0.88132896, 0.72527863,
2424        ];
2425        assert_eq!(
2426            CsMatView::try_new((5, 5), indptr, indices, data)
2427                .unwrap_err()
2428                .3
2429                .kind(),
2430            StructureErrorKind::Unsorted
2431        );
2432    }
2433
2434    #[test]
2435    fn test_new_csr_csc_success() {
2436        let indptr_ok: &[usize] = &[0, 2, 5, 6];
2437        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
2438        let data_ok: &[f64] = &[
2439            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
2440            0.46202352,
2441        ];
2442        assert!(
2443            CsMatView::try_new((3, 4), indptr_ok, indices_ok, data_ok).is_ok()
2444        );
2445        assert!(
2446            CsMatView::try_new_csc((4, 3), indptr_ok, indices_ok, data_ok)
2447                .is_ok()
2448        );
2449    }
2450
2451    #[test]
2452    #[should_panic]
2453    fn test_new_csc_bad_indptr_length() {
2454        let indptr_ok: &[usize] = &[0, 2, 5, 6];
2455        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
2456        let data_ok: &[f64] = &[
2457            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
2458            0.46202352,
2459        ];
2460        let res =
2461            CsMatView::try_new_csc((3, 4), indptr_ok, indices_ok, data_ok);
2462        res.unwrap(); //unreachable
2463    }
2464
2465    #[test]
2466    fn test_new_csr_vec_borrowed() {
2467        let indptr_ok = vec![0, 1, 2, 3];
2468        let indices_ok = vec![0, 1, 2];
2469        let data_ok: Vec<f64> = vec![1., 1., 1.];
2470        assert!(
2471            CsMatView::try_new((3, 3), &indptr_ok, &indices_ok, &data_ok)
2472                .is_ok()
2473        );
2474    }
2475
2476    #[test]
2477    fn test_new_csr_vec_owned() {
2478        let indptr_ok = vec![0, 1, 2, 3];
2479        let indices_ok = vec![0, 1, 2];
2480        let data_ok: Vec<f64> = vec![1., 1., 1.];
2481        assert!(CsMat::new_from_unsorted(
2482            (3, 3),
2483            indptr_ok,
2484            indices_ok,
2485            data_ok
2486        )
2487        .is_ok());
2488    }
2489
2490    #[test]
2491    fn test_csr_from_dense() {
2492        let m = Array::eye(3);
2493        let m_sparse = CsMat::csr_from_dense(m.view(), 0.);
2494
2495        assert_eq!(m_sparse, CsMat::eye(3));
2496
2497        let m = arr2(&[
2498            [1., 0., 2., 1e-7, 1.],
2499            [0., 0., 0., 1., 0.],
2500            [3., 0., 1., 0., 0.],
2501        ]);
2502        let m_sparse = CsMat::csr_from_dense(m.view(), 1e-5);
2503
2504        let expected_output = CsMat::new(
2505            (3, 5),
2506            vec![0, 3, 4, 6],
2507            vec![0, 2, 4, 3, 0, 2],
2508            vec![1., 2., 1., 1., 3., 1.],
2509        );
2510
2511        assert_eq!(m_sparse, expected_output);
2512    }
2513
2514    #[test]
2515    fn test_csc_from_dense() {
2516        let m = Array::eye(3);
2517        let m_sparse = CsMat::csc_from_dense(m.view(), 0.);
2518
2519        assert_eq!(m_sparse, CsMat::eye_csc(3));
2520
2521        let m = arr2(&[
2522            [1., 0., 2., 1e-7, 1.],
2523            [0., 0., 0., 1., 0.],
2524            [3., 0., 1., 0., 0.],
2525        ]);
2526        let m_sparse = CsMat::csc_from_dense(m.view(), 1e-5);
2527
2528        let expected_output = CsMat::new_csc(
2529            (3, 5),
2530            vec![0, 2, 2, 4, 5, 6],
2531            vec![0, 2, 0, 2, 1, 0],
2532            vec![1., 3., 2., 1., 1., 1.],
2533        );
2534
2535        assert_eq!(m_sparse, expected_output);
2536    }
2537
2538    #[test]
2539    fn owned_csr_unsorted_indices() {
2540        let indptr = vec![0, 3, 3, 5, 6, 7];
2541        let indices_sorted = &[1, 2, 3, 2, 3, 4, 4];
2542        let indices_shuffled = vec![1, 3, 2, 2, 3, 4, 4];
2543        let mut data: Vec<i32> = (0..7).collect();
2544        let m = CsMat::new_from_unsorted(
2545            (5, 5),
2546            indptr,
2547            indices_shuffled,
2548            data.clone(),
2549        )
2550        .unwrap();
2551        assert_eq!(m.indices(), indices_sorted);
2552        data.swap(1, 2);
2553        assert_eq!(m.data(), &data[..]);
2554    }
2555
2556    #[test]
2557    fn new_csr_with_empty_row() {
2558        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
2559        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
2560        let data: &[f64] = &[
2561            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
2562            0.39244208, 0.57202407,
2563        ];
2564        assert!(CsMatView::try_new((5, 5), indptr, indices, data).is_ok());
2565    }
2566
2567    #[test]
2568    fn csr_to_csc() {
2569        let a = mat1();
2570        let a_csc_ground_truth = mat1_csc();
2571        let a_csc = a.to_other_storage();
2572        assert_eq!(a_csc, a_csc_ground_truth);
2573    }
2574
2575    #[test]
2576    fn test_self_smul() {
2577        let mut a = mat1();
2578        a.scale(2.);
2579        let c_true = mat1_times_2();
2580        assert_eq!(a.indptr(), c_true.indptr());
2581        assert_eq!(a.indices(), c_true.indices());
2582        assert_eq!(a.data(), c_true.data());
2583    }
2584
2585    #[test]
2586    fn outer_block_iter() {
2587        let mat: CsMat<f64> = CsMat::eye(11);
2588        let mut block_iter = mat.outer_block_iter(3);
2589        assert_eq!(block_iter.next().unwrap().rows(), 3);
2590        assert_eq!(block_iter.next().unwrap().rows(), 3);
2591        assert_eq!(block_iter.next().unwrap().rows(), 3);
2592        assert_eq!(block_iter.next().unwrap().rows(), 2);
2593        assert_eq!(block_iter.next(), None);
2594
2595        let mut block_iter = mat.outer_block_iter(4);
2596        assert_eq!(block_iter.next().unwrap().cols(), 11);
2597        block_iter.next().unwrap();
2598        block_iter.next().unwrap();
2599        assert_eq!(block_iter.next(), None);
2600    }
2601
2602    #[test]
2603    fn middle_outer_views() {
2604        let size = 11;
2605        let csr: CsMat<f64> = CsMat::eye(size);
2606        #[allow(deprecated)]
2607        let v = csr.view().middle_outer_views(1, 3);
2608        assert_eq!(v.shape(), (3, size));
2609        assert_eq!(v.nnz(), 3);
2610
2611        let csc = csr.to_other_storage();
2612        #[allow(deprecated)]
2613        let v = csc.view().middle_outer_views(1, 3);
2614        assert_eq!(v.shape(), (size, 3));
2615        assert_eq!(v.nnz(), 3);
2616    }
2617
2618    #[test]
2619    fn nnz_index() {
2620        let mat: CsMat<f64> = CsMat::eye(11);
2621
2622        assert_eq!(mat.nnz_index(2, 3), None);
2623        assert_eq!(mat.nnz_index(5, 7), None);
2624        assert_eq!(mat.nnz_index(0, 11), None);
2625        assert_eq!(mat.nnz_index(0, 0), Some(super::NnzIndex(0)));
2626        assert_eq!(mat.nnz_index(7, 7), Some(super::NnzIndex(7)));
2627        assert_eq!(mat.nnz_index(10, 10), Some(super::NnzIndex(10)));
2628
2629        let index = mat.nnz_index(8, 8).unwrap();
2630        assert_eq!(mat[index], 1.);
2631        let mut mat = mat;
2632        mat[index] = 2.;
2633        assert_eq!(mat[index], 2.);
2634    }
2635
    #[test]
    fn index() {
        // CSC matrix with the following values:
        // | 0 2 0 |
        // | 1 0 0 |
        // | 0 3 4 |
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 2., 3., 4.],
        );
        // `mat[[row, col]]` gives access to each of the four stored values
        assert_eq!(mat[[1, 0]], 1.);
        assert_eq!(mat[[0, 1]], 2.);
        assert_eq!(mat[[2, 1]], 3.);
        assert_eq!(mat[[2, 2]], 4.);
        // `get` returns `None` for a location without a stored value...
        assert_eq!(mat.get(0, 0), None);
        // ...and for a location outside of the 3x3 shape
        assert_eq!(mat.get(4, 4), None);
    }
2654
    #[test]
    fn get_mut() {
        // CSC matrix with the following values:
        // | 0 1 0 |
        // | 1 0 0 |
        // | 0 1 1 |
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        // (2, 1) is the second value of column 1, ie the third stored value
        *mat.get_mut(2, 1).unwrap() = 3.;

        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 1.],
        );

        assert_eq!(mat, exp);

        // same kind of modification, through `IndexMut`: (2, 2) is the
        // last stored value
        mat[[2, 2]] = 5.;
        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 5.],
        );

        assert_eq!(mat, exp);
    }
2688
    #[test]
    fn map() {
        // CSC matrix with the following values:
        // | 0 1 0 |
        // | 1 0 0 |
        // | 0 1 1 |
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );

        // `map` produces a new matrix: same structure, every stored value
        // transformed (1 + 2 = 3)
        let mut res = mat.map(|&x| x + 2.);
        let expected = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![3.; 4],
        );
        assert_eq!(res, expected);

        // `map_inplace` modifies `res` itself: 3 / 3 = 1 gives back the
        // values of the initial matrix
        res.map_inplace(|&x| x / 3.);
        assert_eq!(res, mat);
    }
2713
    #[test]
    fn insert() {
        // Target matrix of the first series of insertions:
        // | 0 1 0 |
        // | 1 0 0 |
        // | 0 1 1 |
        // The matrix starts empty and is expected to have a (3, 3) shape
        // once the insertions are done.
        // NOTE(review): the second argument of `empty` is presumably the
        // inner dimension -- confirm against its definition.
        let mut mat = CsMat::empty(CSR, 0);
        mat.reserve_outer_dim(3);
        mat.reserve_nnz(4);
        // exercise the fast and easy path where the elements are added
        // in row order for a CSR matrix
        mat.insert(0, 1, 1.);
        mat.insert(1, 0, 1.);
        mat.insert(2, 1, 1.);
        mat.insert(2, 2, 1.);

        let expected =
            CsMat::new((3, 3), vec![0, 1, 2, 4], vec![1, 0, 1, 2], vec![1.; 4]);
        assert_eq!(mat, expected);

        // | 2 1 0 |
        // | 1 0 0 |
        // | 0 1 1 |
        // exercise adding inside an already formed row (ie a search needs
        // to be performed)
        mat.insert(0, 0, 2.);
        // the new value comes first in row 0, and the row offsets after it
        // are all increased by one
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 1., 1., 1.],
        );
        assert_eq!(mat, expected);

        // | 2 1 0 |
        // | 3 0 0 |
        // | 0 1 1 |
        // exercise the fact that inserting in an existing element
        // should change this element's value
        mat.insert(1, 0, 3.);
        // the structure is unchanged, only the third value differs
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 3., 1., 1.],
        );
        assert_eq!(mat, expected);
    }
2761
/// Non-regression test for https://github.com/vbarrielle/sprs/issues/129
///
/// A single value inserted into an all-zero 3x100 matrix must be the one
/// and only item produced by `iter`, at the location it was inserted.
#[test]
fn bug_129() {
    let mut wide = CsMat::zero((3, 100));
    wide.insert(2, 3, 42);

    let stored: Vec<_> = wide.iter().collect();
    assert_eq!(stored, vec![(&42, (2, 3))]);
}
2771
/// Rewrites every stored value of a CSC matrix through the mutable outer
/// iterator and checks that only the data changed, not the structure.
#[test]
fn iter_mut() {
    // | 0 1 0 |
    // | 1 0 0 |
    // | 0 1 1 |
    let indptr = vec![0, 1, 3, 4];
    let indices = vec![1, 0, 2, 2];
    let mut mat =
        CsMat::new_csc((3, 3), indptr.clone(), indices.clone(), vec![1.; 4]);

    // Each stored value becomes its row index plus one.
    mat.outer_iterator_mut().for_each(|mut column| {
        column
            .iter_mut()
            .for_each(|(row, value)| *value = row as f64 + 1.);
    });

    // Same structure as the input, with the rewritten values.
    let expected = CsMat::new_csc((3, 3), indptr, indices, vec![2., 1., 3., 3.]);
    assert_eq!(mat, expected);
}
2798
/// `modify` is expected to panic when the closure leaves the matrix in an
/// inconsistent state.
#[test]
#[should_panic]
fn modify_fail() {
    let mut mat = CsMat::new_csc(
        (3, 3),
        vec![0, 1, 3, 4],
        vec![1, 0, 2, 2],
        vec![1.; 4],
    );

    // The closure moves the column boundaries (indptr becomes
    // [0, 2, 4, 4]) and rewrites the first two indices, but "forgets" the
    // last index: indices end up as [0, 1, 2, 2], so the last entry lands
    // in the same column as its predecessor while carrying the same row
    // index. This is the situation the test expects to panic on.
    mat.modify(|col_starts, row_indices, values| {
        values[2] = 2.;
        row_indices[0] = 0;
        row_indices[1] = 1;
        col_starts[1] = 2;
        col_starts[2] = 4;
    });
}
2819
/// Checks that `to_other_types` can change the scalar and index types of a
/// matrix while keeping its index pointers and data values.
#[test]
fn convert_types() {
    // f32 identity converted to f64 data with u32 indices.
    let identity: CsMat<f32> = CsMat::eye(3);
    let widened: CsMatI<f64, u32> = identity.to_other_types();
    assert_eq!(widened.indptr(), &[0, 1, 2, 3][..]);

    // CSC matrix built with u32 index pointers, converted to f32 data
    // with usize indices and u32 index pointers.
    let source = CsMatI::new_csc(
        (3, 3),
        vec![0u32, 1, 3, 4],
        vec![1, 0, 2, 2],
        vec![1.; 4],
    );
    let converted: CsMatI<f32, usize, u32> = source.to_other_types();
    assert_eq!(converted.indptr(), &[0, 1, 3, 4][..]);
    assert_eq!(converted.data(), &[1.0f32; 4]);
}
2836
/// `iter` on a CSC matrix must yield every stored value together with its
/// (row, column) location, column after column, and then stop.
#[test]
fn iter() {
    let mat = CsMat::new_csc(
        (3, 3),
        vec![0, 1, 3, 4],
        vec![1, 0, 2, 2],
        vec![1.; 4],
    );

    let visited: Vec<_> = mat.iter().collect();
    let expected = vec![
        (&1., (1, 0)),
        (&1., (0, 1)),
        (&1., (2, 1)),
        (&1., (2, 2)),
    ];
    assert_eq!(visited, expected);
}
2852
/// Checks `degrees` on a 5x5 matrix with a symmetric pattern.
///
/// For this input the expected values equal the number of off-diagonal
/// stored entries in each row.
#[test]
fn degrees() {
    // | 1 0 0 3 1 |
    // | 0 2 0 0 0 |
    // | 0 0 0 1 0 |
    // | 3 0 1 1 0 |
    // | 1 0 0 0 1 |
    let indptr = vec![0, 3, 4, 5, 8, 10];
    let indices = vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4];
    let data = vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1];
    let mat = CsMat::new_csc((5, 5), indptr, indices, data);

    assert_eq!(&mat.degrees(), &[2, 0, 1, 2, 1]);
}
2870
/// Checks both diagonal accessors on a square matrix: `diag` returns the
/// stored diagonal entries as a sparse vector, while `diag_iter` yields
/// one item per diagonal position, `None` where nothing is stored.
#[test]
fn diag() {
    // | 1 0 0 3 1 |
    // | 0 2 0 0 0 |
    // | 0 0 0 1 0 |
    // | 3 0 1 1 0 |
    // | 1 0 0 0 1 |
    let mat = CsMat::new_csc(
        (5, 5),
        vec![0, 3, 4, 5, 8, 10],
        vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
        vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
    );

    // Position (2, 2) holds no entry, so it is absent from the vector.
    assert_eq!(
        mat.diag(),
        CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1])
    );

    // The iterator covers all five positions, in order, then ends.
    let along_diag: Vec<_> = mat.diag_iter().collect();
    assert_eq!(
        along_diag,
        vec![Some(&1), Some(&2), None, Some(&1), Some(&1)]
    );
}
2897
/// Mutates two diagonal entries through `diag_iter_mut` and checks that
/// the changes are visible through `diag`.
#[test]
#[cfg_attr(miri, ignore)]
fn diag_mut() {
    // | 1 0 0 3 1 |
    // | 0 2 0 0 0 |
    // | 0 0 0 1 0 |
    // | 3 0 1 1 0 |
    // | 1 0 0 0 1 |
    let mut mat = CsMat::new_csc(
        (5, 5),
        vec![0, 3, 4, 5, 8, 10],
        vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
        vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
    );

    // One optional mutable reference per diagonal position.
    let mut diags = mat.diag_iter_mut().collect::<Vec<_>>();
    // `if let` rather than `Option::map`: the closure was only run for
    // its side effect, which `map` is not meant for.
    if let Some(entry) = diags[4].as_mut() {
        **entry *= 3; // (4, 4): 1 -> 3
    }
    if let Some(entry) = diags[3].as_mut() {
        **entry -= 4; // (3, 3): 1 -> -3
    }

    let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, -3, 3]);
    assert_eq!(mat.diag(), expected);
}
2919
/// Same checks as for a square matrix, on a 5x6 matrix: the expected
/// diagonal has five positions and the extra column does not contribute.
#[test]
fn diag_rectangular() {
    // | 1 0 0 3 1 3|
    // | 0 2 0 0 0 0|
    // | 0 0 0 1 0 1|
    // | 3 0 1 1 0 0|
    // | 1 0 0 0 1 0|
    let mat = CsMat::new_csc(
        (5, 6),
        vec![0, 3, 4, 5, 8, 10, 12],
        vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4, 0, 2],
        vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1, 3, 1],
    );

    let stored_diag = mat.diag();
    assert_eq!(
        stored_diag,
        CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1])
    );

    // Five diagonal positions are visited, then the iterator is exhausted.
    let positions: Vec<_> = mat.diag_iter().collect();
    let expected_positions =
        vec![Some(&1), Some(&2), None, Some(&1), Some(&1)];
    assert_eq!(positions, expected_positions);
}
2946
/// `to_inner_onehot` applied to an all-zero matrix is expected to give
/// back an all-zero CSR matrix.
#[test]
fn onehot_zero() {
    let onehot: CsMat<f32> = CsMat::zero((3, 3)).to_inner_onehot();
    let all_zero = CsMat::zero((3, 3));

    assert!(onehot.is_csr());
    assert_eq!(all_zero, onehot);
}
2954
/// A 2x2 CSR matrix holding 2.0 on its diagonal and explicitly stored
/// zeros elsewhere is expected to become the identity matrix.
#[test]
fn onehot_eye() {
    let indptr = vec![0, 2, 4];
    let indices = vec![0, 1, 0, 1];
    let data = vec![2.0, 0.0, 0.0, 2.0];
    let scaled_eye = CsMat::new((2, 2), indptr, indices, data);

    let onehot = scaled_eye.to_inner_onehot();

    // The storage order is kept and the result equals the identity.
    assert!(onehot.is_csr());
    assert_eq!(CsMat::eye(2), onehot);
}
2969
/// On a CSC matrix with a single stored entry, `to_inner_onehot` is
/// expected to keep the structure and the CSC storage, and to replace the
/// stored 2.0 by 1.0.
#[test]
fn onehot_sparse_csc() {
    // Builds the 2x3 CSC matrix whose only entry, at (1, 1), is `value`.
    let single_entry = |value| {
        CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![value])
    };

    let onehot = single_entry(2.0).to_inner_onehot();

    assert!(onehot.is_csc());
    assert_eq!(single_entry(1.0), onehot);
}
2982
/// A stored NaN must not be selected by `to_inner_onehot`: with a NaN
/// next to the 2.0 of the first row, the result is still expected to be
/// the identity matrix.
#[test]
fn onehot_ignores_nan() {
    let not_a_number = std::f64::NAN;
    let with_nan = CsMat::new(
        (2, 2),
        vec![0, 2, 3],
        vec![0, 1, 1],
        vec![2.0, not_a_number, 2.0],
    );

    let onehot = with_nan.to_inner_onehot();

    assert!(onehot.is_csr());
    assert_eq!(CsMat::eye(2), onehot);
}
2997
/// Multiplying a matrix in place by a scalar must scale every stored
/// value and keep every stored entry.
#[test]
fn mul_assign() {
    let mut m1 = crate::TriMat::new((6, 9));
    m1.add_triplet(1, 1, 8_i32);
    m1.add_triplet(1, 2, 7);
    m1.add_triplet(0, 1, 6);
    m1.add_triplet(0, 8, 5);
    m1.add_triplet(4, 2, 4);
    let mut m1: CsMat<_> = m1.to_csr();

    m1 *= 2;

    // The loop below only checks the entries it visits, so make sure all
    // five triplets are still there: otherwise a matrix that lost entries
    // would pass the test vacuously.
    assert_eq!(m1.iter().count(), 5);
    for (&v, (row, col)) in m1.iter() {
        match (row, col) {
            (1, 1) => assert_eq!(v, 16),
            (1, 2) => assert_eq!(v, 14),
            (0, 1) => assert_eq!(v, 12),
            (0, 8) => assert_eq!(v, 10),
            (4, 2) => assert_eq!(v, 8),
            _ => panic!("unexpected stored entry at ({}, {})", row, col),
        }
    }
}
3020
/// Dividing a matrix in place by a scalar must divide every stored value
/// (integer division for this `i32` matrix) and keep every stored entry.
#[test]
fn div_assign() {
    let mut m1 = crate::TriMat::new((6, 9));
    m1.add_triplet(1, 1, 8_i32);
    m1.add_triplet(1, 2, 7);
    m1.add_triplet(0, 1, 6);
    m1.add_triplet(0, 8, 5);
    m1.add_triplet(4, 2, 4);
    let mut m1: CsMat<_> = m1.to_csr();

    m1 /= 2;

    // The loop below only checks the entries it visits, so make sure all
    // five triplets are still there: otherwise a matrix that lost entries
    // would pass the test vacuously.
    assert_eq!(m1.iter().count(), 5);
    for (&v, (row, col)) in m1.iter() {
        match (row, col) {
            (1, 1) => assert_eq!(v, 4),
            (1, 2) => assert_eq!(v, 3), // 7 / 2 truncates
            (0, 1) => assert_eq!(v, 3),
            (0, 8) => assert_eq!(v, 2), // 5 / 2 truncates
            (4, 2) => assert_eq!(v, 2),
            _ => panic!("unexpected stored entry at ({}, {})", row, col),
        }
    }
}
3043
/// Multiplies an empty 10x1 CSC matrix by an empty 1x9 CSR matrix. The
/// product is discarded: the test only requires the multiplication to
/// complete (presumably a non-regression test for issue 99).
#[test]
fn issue_99() {
    let tall_csc = crate::TriMat::<i32>::new((10, 1)).to_csc::<usize>();
    let wide_csr = crate::TriMat::<i32>::new((1, 9)).to_csr();

    let _product = &tall_csc * &wide_csr;
}
3050}
3051
3052#[cfg(feature = "approx")]
3053mod approx_impls {
3054    use super::*;
3055    use approx::*;
3056
3057    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
3058        AbsDiffEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
3059        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
3060    where
3061        I: SpIndex,
3062        Iptr: SpIndex,
3063        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
3064            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
3065        IS1: Deref<Target = [I]>,
3066        IS2: Deref<Target = [I]>,
3067        ISptr1: Deref<Target = [Iptr]>,
3068        ISptr2: Deref<Target = [Iptr]>,
3069        DS1: Deref<Target = [N]>,
3070        DS2: Deref<Target = [N]>,
3071        N: AbsDiffEq,
3072        N::Epsilon: Clone,
3073        N: num_traits::Zero,
3074    {
3075        type Epsilon = N::Epsilon;
3076        fn default_epsilon() -> N::Epsilon {
3077            N::default_epsilon()
3078        }
3079        fn abs_diff_eq(
3080            &self,
3081            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
3082            epsilon: N::Epsilon,
3083        ) -> bool {
3084            if self.shape() != other.shape() {
3085                return false;
3086            }
3087            if self.storage() == other.storage() {
3088                self.outer_iterator()
3089                    .zip(other.outer_iterator())
3090                    .all(|(r1, r2)| r1.abs_diff_eq(&r2, epsilon.clone()))
3091            } else {
3092                // Checks if all elements in self has a matching element
3093                // in other
3094                let all_matching = self.iter().all(|(n, (i, j))| {
3095                    n.abs_diff_eq(
3096                        other
3097                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
3098                            .unwrap_or(&N::zero()),
3099                        epsilon.clone(),
3100                    )
3101                });
3102                if !all_matching {
3103                    return false;
3104                }
3105
3106                // Must also check if all elements in other matches self
3107                other.iter().all(|(n, (i, j))| {
3108                    n.abs_diff_eq(
3109                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
3110                            .unwrap_or(&N::zero()),
3111                        epsilon.clone(),
3112                    )
3113                })
3114            }
3115        }
3116    }
3117    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
3118        UlpsEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
3119        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
3120    where
3121        I: SpIndex,
3122        Iptr: SpIndex,
3123        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
3124            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
3125        IS1: Deref<Target = [I]>,
3126        IS2: Deref<Target = [I]>,
3127        ISptr1: Deref<Target = [Iptr]>,
3128        ISptr2: Deref<Target = [Iptr]>,
3129        DS1: Deref<Target = [N]>,
3130        DS2: Deref<Target = [N]>,
3131        N: UlpsEq,
3132        N::Epsilon: Clone,
3133        N: num_traits::Zero,
3134    {
3135        fn default_max_ulps() -> u32 {
3136            N::default_max_ulps()
3137        }
3138        fn ulps_eq(
3139            &self,
3140            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
3141            epsilon: N::Epsilon,
3142            max_ulps: u32,
3143        ) -> bool {
3144            if self.shape() != other.shape() {
3145                return false;
3146            }
3147            if self.storage() == other.storage() {
3148                self.outer_iterator()
3149                    .zip(other.outer_iterator())
3150                    .all(|(r1, r2)| r1.ulps_eq(&r2, epsilon.clone(), max_ulps))
3151            } else {
3152                // Checks if all elements in self has a matching element
3153                // in other
3154                let all_matches = self.iter().all(|(n, (i, j))| {
3155                    n.ulps_eq(
3156                        other
3157                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
3158                            .unwrap_or(&N::zero()),
3159                        epsilon.clone(),
3160                        max_ulps,
3161                    )
3162                });
3163                if !all_matches {
3164                    return false;
3165                }
3166
3167                // Must also check if all elements in other matches self
3168                other.iter().all(|(n, (i, j))| {
3169                    n.ulps_eq(
3170                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
3171                            .unwrap_or(&N::zero()),
3172                        epsilon.clone(),
3173                        max_ulps,
3174                    )
3175                })
3176            }
3177        }
3178    }
3179    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
3180        RelativeEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
3181        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
3182    where
3183        I: SpIndex,
3184        Iptr: SpIndex,
3185        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
3186            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
3187        IS1: Deref<Target = [I]>,
3188        IS2: Deref<Target = [I]>,
3189        ISptr1: Deref<Target = [Iptr]>,
3190        ISptr2: Deref<Target = [Iptr]>,
3191        DS1: Deref<Target = [N]>,
3192        DS2: Deref<Target = [N]>,
3193        N: RelativeEq,
3194        N::Epsilon: Clone,
3195        N: num_traits::Zero,
3196    {
3197        fn default_max_relative() -> N::Epsilon {
3198            N::default_max_relative()
3199        }
3200        fn relative_eq(
3201            &self,
3202            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
3203            epsilon: N::Epsilon,
3204            max_relative: Self::Epsilon,
3205        ) -> bool {
3206            if self.shape() != other.shape() {
3207                return false;
3208            }
3209            if self.storage() == other.storage() {
3210                self.outer_iterator().zip(other.outer_iterator()).all(
3211                    |(r1, r2)| {
3212                        r1.relative_eq(
3213                            &r2,
3214                            epsilon.clone(),
3215                            max_relative.clone(),
3216                        )
3217                    },
3218                )
3219            } else {
3220                // Checks if all elements in self has a matching element
3221                // in other
3222                let all_matches = self.iter().all(|(n, (i, j))| {
3223                    n.relative_eq(
3224                        other
3225                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
3226                            .unwrap_or(&N::zero()),
3227                        epsilon.clone(),
3228                        max_relative.clone(),
3229                    )
3230                });
3231                if !all_matches {
3232                    return false;
3233                }
3234
3235                // Must also check if all elements in other matches self
3236                other.iter().all(|(n, (i, j))| {
3237                    n.relative_eq(
3238                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
3239                            .unwrap_or(&N::zero()),
3240                        epsilon.clone(),
3241                        max_relative.clone(),
3242                    )
3243                })
3244            }
3245        }
3246    }
3247
3248    #[cfg(test)]
3249    mod tests {
3250        use crate::*;
3251
/// Matrices of different shapes must never compare as approximately
/// equal, whatever the storage orders of the two operands.
#[test]
fn different_shapes() {
    // A 3x2 and a 2x3 matrix storing the same value at the same location.
    let tall: CsMat<_> = {
        let mut triplets = TriMat::new((3, 2));
        triplets.add_triplet(1, 1, 8_u8);
        triplets.to_csr()
    };
    let wide = {
        let mut triplets = TriMat::new((2, 3));
        triplets.add_triplet(1, 1, 8_u8);
        triplets.to_csr()
    };

    // CSR/CSR, CSR/CSC, CSC/CSR and CSC/CSC.
    ::approx::assert_abs_diff_ne!(tall, wide);
    ::approx::assert_abs_diff_ne!(tall, wide.to_csc());
    ::approx::assert_abs_diff_ne!(tall.to_csc(), wide);
    ::approx::assert_abs_diff_ne!(tall.to_csc(), wide.to_csc());
}
3266
/// A matrix must compare as approximately equal to its own clone, for
/// every combination of storage orders, with integer as well as with
/// floating point data.
#[test]
fn equal_elements() {
    // Integer data, compared with a zero tolerance.
    {
        let mut triplets = TriMat::new((6, 9));
        for &(row, col, value) in
            &[(1, 1, 8_u8), (1, 2, 7), (0, 1, 6), (0, 8, 5), (4, 2, 4)]
        {
            triplets.add_triplet(row, col, value);
        }
        let lhs: CsMat<_> = triplets.to_csr();
        let rhs = lhs.clone();
        let lhs_csc = lhs.to_csc();
        let rhs_csc = rhs.to_csc();

        ::approx::assert_abs_diff_eq!(lhs, rhs, epsilon = 0);
        ::approx::assert_abs_diff_eq!(lhs_csc, rhs, epsilon = 0);
        ::approx::assert_abs_diff_eq!(lhs, rhs_csc, epsilon = 0);
        ::approx::assert_abs_diff_eq!(lhs_csc, rhs_csc, epsilon = 0);
    }

    // Floating point data, compared with the default tolerances of the
    // three kinds of approximate equality.
    let mut triplets = TriMat::new((6, 9));
    for &(row, col, value) in &[
        (1, 1, 8.0_f32),
        (1, 2, 7.0),
        (0, 1, 6.0),
        (0, 8, 5.0),
        (4, 2, 4.0),
    ] {
        triplets.add_triplet(row, col, value);
    }
    let lhs: CsMat<_> = triplets.to_csr();
    let rhs = lhs.clone();
    let lhs_csc = lhs.to_csc();
    let rhs_csc = rhs.to_csc();

    ::approx::assert_abs_diff_eq!(lhs, rhs);
    ::approx::assert_abs_diff_eq!(lhs_csc, rhs);
    ::approx::assert_abs_diff_eq!(lhs, rhs_csc);
    ::approx::assert_abs_diff_eq!(lhs_csc, rhs_csc);

    ::approx::assert_relative_eq!(lhs, rhs);
    ::approx::assert_relative_eq!(lhs_csc, rhs);
    ::approx::assert_relative_eq!(lhs, rhs_csc);
    ::approx::assert_relative_eq!(lhs_csc, rhs_csc);

    ::approx::assert_ulps_eq!(lhs, rhs);
    ::approx::assert_ulps_eq!(lhs_csc, rhs);
    ::approx::assert_ulps_eq!(lhs, rhs_csc);
    ::approx::assert_ulps_eq!(lhs_csc, rhs_csc);
}
3313
/// Two matrices differing by 0.5 on one entry and by one extra entry of
/// 0.2 must be equal with a tolerance of 0.6 and different with a
/// tolerance of 0.4, for every combination of storage orders.
#[test]
fn almost_equal_elements() {
    let mut reference = TriMat::new((6, 9));
    for &(row, col, value) in &[
        (1, 1, 8.0_f32),
        (1, 2, 7.0),
        (0, 1, 6.0),
        (0, 8, 5.0),
        (4, 2, 4.0),
    ] {
        reference.add_triplet(row, col, value);
    }
    let reference: CsMat<_> = reference.to_csr();

    let mut perturbed = TriMat::new((6, 9));
    for &(row, col, value) in &[
        (1, 1, 8.0_f32),
        (1, 2, 7.0 - 0.5), // 0.5 subtracted
        (0, 1, 6.0),
        (0, 8, 5.0),
        (4, 2, 4.0),
        (4, 3, 0.2), // extra element
    ] {
        perturbed.add_triplet(row, col, value);
    }
    let perturbed = perturbed.to_csr();

    // Both differences (0.5 and 0.2) are within a tolerance of 0.6.
    ::approx::assert_abs_diff_eq!(reference, perturbed, epsilon = 0.6);
    ::approx::assert_abs_diff_eq!(
        reference.to_csc(),
        perturbed,
        epsilon = 0.6
    );
    ::approx::assert_abs_diff_eq!(
        reference,
        perturbed.to_csc(),
        epsilon = 0.6
    );
    ::approx::assert_abs_diff_eq!(
        reference.to_csc(),
        perturbed.to_csc(),
        epsilon = 0.6
    );

    // The difference of 0.5 exceeds a tolerance of 0.4.
    ::approx::assert_abs_diff_ne!(reference, perturbed, epsilon = 0.4);
    ::approx::assert_abs_diff_ne!(
        reference.to_csc(),
        perturbed,
        epsilon = 0.4
    );
    ::approx::assert_abs_diff_ne!(
        reference,
        perturbed.to_csc(),
        epsilon = 0.4
    );
    ::approx::assert_abs_diff_ne!(
        reference.to_csc(),
        perturbed.to_csc(),
        epsilon = 0.4
    );
}
3351    }
3352}