algebra_sparse/
set.rs

1// Copyright (C) 2020-2025 algebra-sparse authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::traits::IntoView;
16use crate::{CsVecBuilder, CsrMatrixView, Real};
17
18/// A collection of CSR matrices stored efficiently in a single data structure.
19///
20/// This structure is designed to store multiple CSR matrices with different dimensions
21/// in a compact format. It's particularly useful for applications that need to manage
22/// many sparse matrices with similar sparsity patterns or for systems that generate
23/// multiple matrices during computation.
24///
25/// # Format
26///
27/// The set stores all matrices in contiguous arrays:
28/// - `col_indices`: Column indices for all matrices concatenated
29/// - `values`: Non-zero values for all matrices concatenated
30/// - `row_offsets`: Row offsets for all matrices concatenated
31/// - `ncols`: Number of columns for each individual matrix
32/// - `partition`: Metadata to locate each matrix within the concatenated data
33///
34/// # Examples
35///
36/// ```rust
37/// use algebra_sparse::CsrMatrixSet;
38/// use algebra_sparse::traits::IntoView;
39/// use algebra_sparse::CsrMatrixSetMethods;
40///
41/// let mut set: CsrMatrixSet<f64> = CsrMatrixSet::default();
42///
43/// // Add first matrix to the set
44/// {
45///     let mut builder1 = set.new_matrix(3, 1e-10);
46///     builder1.new_row().push(0, 1.0);
47/// } // builder1 is dropped here and matrix is added to set
48///
49/// // Add second matrix to the set
50/// {
51///     let mut builder2 = set.new_matrix(2, 1e-10);
52///     builder2.new_row().push(1, 3.0);
53/// } // builder2 is dropped here and matrix is added to set
54///
55/// println!("Set contains {} matrices", (&set).len());
56/// ```
57#[derive(Clone)]
58pub struct CsrMatrixSet<T> {
59    /// The column indices of the non-zero entries for all matrices.
60    col_indices: Vec<usize>,
61    /// The non-zero values for all matrices.
62    values: Vec<T>,
63    /// The offsets of each row in the `col_indices` and `values` arrays for all matrices.
64    row_offsets: Vec<usize>,
65    /// The number of columns for each matrix in the set.
66    ncols: Vec<usize>,
67    /// Partition information to locate individual matrices within the concatenated data.
68    partition: Vec<Partition>,
69}
70
71impl<T> Default for CsrMatrixSet<T> {
72    /// Creates a new empty CSR matrix set.
73    #[inline]
74    fn default() -> Self {
75        Self {
76            col_indices: Vec::new(),
77            values: Vec::new(),
78            row_offsets: vec![0],
79            ncols: Vec::new(),
80            partition: Vec::new(),
81        }
82    }
83}
84
/// Locates one matrix of a `CsrMatrixSet` inside the set's shared arrays.
#[derive(Clone)]
struct Partition {
    /// Index of this matrix's first entry in `row_offsets`.
    pub row_offset: usize,
    /// Number of `row_offsets` entries owned by this matrix.
    pub row_len: usize,
    /// Index of this matrix's first entry in `col_indices` / `values`.
    pub value_offset: usize,
    /// Number of `col_indices` / `values` entries owned by this matrix.
    pub value_len: usize,
}
92
93impl Partition {
94    #[inline]
95    pub fn value_range(&self) -> std::ops::Range<usize> {
96        self.value_offset..self.value_offset + self.value_len
97    }
98
99    #[inline]
100    pub fn row_offset_range(&self) -> std::ops::Range<usize> {
101        self.row_offset..self.row_offset + self.row_len
102    }
103}
104
105impl<T: Real> CsrMatrixSet<T> {
106    /// Clears all matrices from the set.
107    ///
108    /// This removes all data and allows reuse of the set.
109    pub fn clear(&mut self) {
110        self.col_indices.clear();
111        self.values.clear();
112        self.row_offsets.clear();
113        self.ncols.clear();
114        self.partition.clear();
115    }
116
117    /// Creates a new matrix builder for this set.
118    ///
119    /// The matrix will be automatically added to the set when the builder is dropped.
120    ///
121    /// # Arguments
122    /// * `ncol` - Number of columns for the new matrix
123    /// * `zero_threshold` - Values below this threshold are filtered out
124    ///
125    /// # Returns
126    /// A `CsrMatrixBuilder` for constructing the new matrix
127    pub fn new_matrix(&mut self, ncol: usize, zero_threshold: T) -> CsrMatrixBuilder<T> {
128        let value_start = self.values.len();
129        let row_start = self.row_offsets.len();
130        self.row_offsets.push(0);
131        CsrMatrixBuilder {
132            set: self,
133            zero_threshold,
134            value_start,
135            row_start,
136            ncol,
137        }
138    }
139}
140
141impl<T> CsrMatrixSet<T> {
142    /// Returns the view of the matrix at the given index.
143    ///
144    /// # Arguments
145    /// * `index` - Index of the matrix to retrieve
146    ///
147    /// # Returns
148    /// A `CsrMatrixView` representing the requested matrix
149    ///
150    /// # Panics
151    ///
152    /// Panics if the index is out of bounds
153    #[inline]
154    pub fn get(&self, index: usize) -> CsrMatrixView<T> {
155        let partition = &self.partition[index];
156        CsrMatrixView::from_parts_unchecked(
157            &self.row_offsets[partition.row_offset_range()],
158            &self.col_indices[partition.value_range()],
159            &self.values[partition.value_range()],
160            self.ncols[index],
161        )
162    }
163
164    /// Returns a view of the entire matrix set.
165    #[inline]
166    pub fn as_view(&self) -> CsrMatrixSetView<'_, T> {
167        CsrMatrixSetView {
168            col_indices: &self.col_indices,
169            values: &self.values,
170            row_offsets: &self.row_offsets,
171            ncols: &self.ncols,
172            partition: &self.partition,
173        }
174    }
175}
176
177/// An immutable view of a CSR matrix set for efficient read-only access.
178///
179/// This structure provides zero-cost abstraction access to multiple CSR matrices stored
180/// in a compact, consolidated format. It's designed for scenarios where you need to
181/// read or process multiple sparse matrices without the overhead of copying data
182/// or creating separate matrix objects.
183///
184/// # Use Cases
185///
186/// Views are particularly useful for:
187///
188/// ## Parallel Processing
189/// ```rust
190/// use algebra_sparse::CsrMatrixSet;
191///
192/// let mut set = CsrMatrixSet::default();
193/// // Add some matrices to the set
194/// {
195///     let mut builder = set.new_matrix(3, 1e-10);
196///     builder.new_row().push(0, 1.0);
197/// }
198/// {
199///     let mut builder = set.new_matrix(2, 1e-10);
200///     builder.new_row().push(1, 2.0);
201/// }
202///
203/// let view = set.as_view();
204/// assert_eq!(view.len(),2);
205/// let (left, right) = view.split_at(view.len() / 2);
206/// assert_eq!(left.len(),1);
207/// assert_eq!(right.len(),1);
208/// ```
209#[derive(Clone, Copy)]
210pub struct CsrMatrixSetView<'a, T> {
211    col_indices: &'a [usize],
212    values: &'a [T],
213    row_offsets: &'a [usize],
214    ncols: &'a [usize],
215    partition: &'a [Partition],
216}
217
218impl<'a, T> CsrMatrixSetView<'a, T> {
219    /// Returns the view of the matrix at the given index.
220    ///
221    /// # Arguments
222    /// * `index` - Index of the matrix to retrieve
223    ///
224    /// # Returns
225    /// A `CsrMatrixView` representing the requested matrix
226    ///
227    /// # Panics
228    ///
229    /// Panics if the index is out of bounds
230    #[inline]
231    pub fn get(self, index: usize) -> CsrMatrixView<'a, T> {
232        let partition = &self.partition[index];
233        CsrMatrixView::from_parts_unchecked(
234            &self.row_offsets[partition.row_offset_range()],
235            &self.col_indices[partition.value_range()],
236            &self.values[partition.value_range()],
237            self.ncols[index],
238        )
239    }
240
241    /// Returns the number of matrices in the set.
242    #[inline]
243    pub fn len(&self) -> usize {
244        self.partition.len()
245    }
246
247    /// Returns true if the set contains no matrices.
248    #[inline]
249    pub fn is_empty(&self) -> bool {
250        self.partition.is_empty()
251    }
252
253    /// Splits the matrix set view into two at the given index.
254    ///
255    /// This is a zero-cost operation that creates two independent views that reference
256    /// the same underlying data but represent disjoint subsets of the matrices.
257    /// The operation is O(1) and involves no copying or allocation of matrix data.
258    ///
259    /// # Arguments
260    /// * `index` - The split position. The left view will contain matrices at indices `[0, index)`,
261    ///   and the right view will contain matrices at indices `[index, len)`.
262    ///
263    /// # Returns
264    /// A tuple of two views: `(left, right)` where:
265    /// - `left` contains matrices `0..index`
266    /// - `right` contains matrices `index..len`
267    ///
268    /// # Panics
269    ///
270    /// Panics if `index > len()`. Splitting at `index = 0` or `index = len()` is allowed
271    /// and will return an empty view on one side.
272    ///
273    /// # Examples
274    ///
275    /// ```rust
276    /// use algebra_sparse::CsrMatrixSet;
277    ///
278    /// let mut set = CsrMatrixSet::default();
279    /// // Add 3 matrices to the set
280    /// {
281    ///     let mut builder = set.new_matrix(2, 1e-10);
282    ///     builder.new_row().push(0, 1.0);
283    /// }
284    /// {
285    ///     let mut builder = set.new_matrix(2, 1e-10);
286    ///     builder.new_row().push(1, 2.0);
287    /// }
288    /// {
289    ///     let mut builder = set.new_matrix(2, 1e-10);
290    ///     builder.new_row().push(0, 3.0);
291    /// }
292    ///
293    /// let view = set.as_view();
294    ///
295    /// // Split in the middle
296    /// let (left, right) = view.split_at(1);
297    /// assert_eq!(left.len(), 1);  // matrices 0
298    /// assert_eq!(right.len(), 2); // matrices 1, 2
299    ///
300    /// // Split at the beginning
301    /// let (empty, all) = view.split_at(0);
302    /// assert!(empty.is_empty());
303    /// assert_eq!(all.len(), 3);
304    ///
305    /// // Split at the end
306    /// let (all, empty) = view.split_at(3);
307    /// assert_eq!(all.len(), 3);
308    /// assert!(empty.is_empty());
309    /// ```
310    ///
311    /// # Parallel Processing
312    ///
313    /// This method is particularly useful for parallel processing:
314    ///
315    /// ```rust
316    /// # use algebra_sparse::CsrMatrixSet;
317    /// # let mut set: CsrMatrixSet<f64> = CsrMatrixSet::default();
318    /// let view = set.as_view();
319    /// let midpoint = view.len() / 2;
320    /// let (left, right) = view.split_at(midpoint);
321    ///
322    /// // Process halves independently (parallel processing example)
323    /// // This is conceptual - actual parallel processing would use rayon or similar
324    /// ```
325    ///
326    /// # Memory Efficiency
327    ///
328    /// Both resulting views share references to the same underlying data:
329    /// - No matrix data is copied during the split
330    /// - Both views have independent lifetimes
331    #[inline]
332    pub fn split_at(self, index: usize) -> (Self, Self) {
333        let (left_partition, right_partition) = self.partition.split_at(index);
334        let (left_ncols, right_ncols) = self.ncols.split_at(index);
335        let left = CsrMatrixSetView {
336            col_indices: self.col_indices,
337            values: self.values,
338            row_offsets: self.row_offsets,
339            ncols: left_ncols,
340            partition: left_partition,
341        };
342        let right = CsrMatrixSetView {
343            col_indices: self.col_indices,
344            values: self.values,
345            row_offsets: self.row_offsets,
346            ncols: right_ncols,
347            partition: right_partition,
348        };
349        (left, right)
350    }
351}
352
353impl<'a, T> IntoView for &'a CsrMatrixSet<T> {
354    type View = CsrMatrixSetView<'a, T>;
355
356    #[inline]
357    fn into_view(self) -> Self::View {
358        self.as_view()
359    }
360}
361
/// Set-level queries for CSR matrix set handles.
///
/// In this module it is implemented for every reference whose `IntoView`
/// conversion yields a [`CsrMatrixSetView`] with element type `V`.
pub trait CsrMatrixSetMethods<V> {
    /// Returns the number of matrices in the set.
    fn len(&self) -> usize;

    /// Returns `true` when the set holds no matrices.
    ///
    /// Provided in terms of [`len`](Self::len).
    #[inline]
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
372
373impl<'a, T, V> CsrMatrixSetMethods<V> for &'a T
374where
375    &'a T: IntoView<View = CsrMatrixSetView<'a, V>>,
376    V: Real,
377{
378    #[inline]
379    fn len(&self) -> usize {
380        self.into_view().len()
381    }
382}
383
384/// A builder for constructing CSR matrices within a `CsrMatrixSet`.
385///
386/// This builder allows efficient construction of CSR matrices that will be stored
387/// in a matrix set. When the builder is dropped, the matrix is automatically
388/// added to the parent set.
389///
390/// # Examples
391///
392/// ```rust
393/// use algebra_sparse::CsrMatrixSet;
394///
395/// let mut set = CsrMatrixSet::default();
396/// let mut builder = set.new_matrix(3, 1e-10);
397///
398/// let mut row_builder = builder.new_row();
399/// row_builder.push(0, 1.0);
400/// row_builder.push(2, 2.0);
401/// // Matrix is automatically added to set when builder is dropped
402/// ```
403pub struct CsrMatrixBuilder<'a, T> {
404    set: &'a mut CsrMatrixSet<T>,
405    /// The value index start for this matrix in the set's `values` array.
406    value_start: usize,
407    /// The row index start for this matrix in the set's `row_offsets` array.
408    row_start: usize,
409    /// The number of columns of this matrix.
410    ncol: usize,
411    /// Values below this threshold will be ignored during construction.
412    zero_threshold: T,
413}
414
415/// Automatically finalizes the CSR matrix and adds it to the set when the builder is dropped.
416///
417/// This implementation ensures that the matrix is properly added to the parent set
418/// when the builder goes out of scope, including updating all necessary metadata.
419impl<T> Drop for CsrMatrixBuilder<'_, T> {
420    fn drop(&mut self) {
421        self.set.ncols.push(self.ncol);
422        let partition = Partition {
423            value_offset: self.value_start,
424            value_len: self.set.values.len() - self.value_start,
425            row_offset: self.row_start,
426            row_len: self.set.row_offsets.len() - self.row_start,
427        };
428        self.set.partition.push(partition);
429    }
430}
431
432impl<T: Real> CsrMatrixBuilder<'_, T> {
433    /// Returns the number of columns for the matrix being built.
434    #[inline]
435    pub fn ncol(&self) -> usize {
436        self.ncol
437    }
438
439    /// Creates a new row builder for this matrix.
440    ///
441    /// The returned builder can be used to add non-zero elements to the next row.
442    ///
443    /// # Returns
444    /// A `CsVecBuilder` for constructing a sparse row
445    #[inline]
446    pub fn new_row(&mut self) -> CsVecBuilder<T> {
447        CsVecBuilder::from_parts_unchecked(
448            &mut self.set.col_indices,
449            &mut self.set.row_offsets,
450            &mut self.set.values,
451            self.value_start,
452            self.zero_threshold,
453        )
454    }
455}
456
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;
    use crate::csm::CsrMatrixViewMethods;
    use crate::traits::IntoView;

    /// Zero threshold used for every matrix built in these tests.
    const THRESHOLD: f32 = 1e-10;

    /// Appends one matrix with `ncol` columns to `set`. Each element of `rows`
    /// lists the `(column, value)` pairs of one row.
    fn push_matrix(set: &mut CsrMatrixSet<f32>, ncol: usize, rows: &[&[(usize, f32)]]) {
        let mut builder = set.new_matrix(ncol, THRESHOLD);
        for row in rows {
            builder.new_row().extend_with_nonzeros(row.to_vec());
        }
        // `builder` is dropped here, which commits the matrix to `set`.
    }

    /// Appends the four reference matrices used throughout these tests.
    fn fill_test_matrices(set: &mut CsrMatrixSet<f32>) {
        // First matrix: 3x3
        push_matrix(
            set,
            3,
            &[&[(0, 1.0), (2, 2.0)], &[(1, 3.0)], &[(0, 4.0), (1, 5.0)]],
        );
        // Second matrix: 2x2
        push_matrix(set, 2, &[&[(0, 6.0), (1, 7.0)], &[(1, 8.0)]]);
        // Third matrix: 1x1
        push_matrix(set, 1, &[&[(0, 9.0)]]);
        // Fourth matrix: 2x3
        push_matrix(set, 3, &[&[(1, 10.0)], &[(0, 11.0), (2, 12.0)]]);
    }

    /// Creates a fresh set containing the four reference matrices.
    fn create_test_matrix_set() -> CsrMatrixSet<f32> {
        let mut set = CsrMatrixSet::default();
        fill_test_matrices(&mut set);
        set
    }

    #[test]
    fn test_split_at_beginning() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        // Split at index 0: left empty, right contains all matrices
        let (left, right) = view.split_at(0);

        assert_eq!(left.len(), 0);
        assert_eq!(right.len(), 4);

        for i in 0..4 {
            assert_relative_eq!(set.get(i).to_dense(), right.get(i).to_dense());
        }
    }

    #[test]
    fn test_split_at_end() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        // Split at index 4: left contains all matrices, right empty
        let (left, right) = view.split_at(4);

        assert_eq!(left.len(), 4);
        assert_eq!(right.len(), 0);

        for i in 0..4 {
            assert_relative_eq!(set.get(i).to_dense(), left.get(i).to_dense());
        }
    }

    #[test]
    fn test_split_at_middle() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        // Split at index 2: left contains matrices 0,1; right contains matrices 2,3
        let (left, right) = view.split_at(2);

        assert_eq!(left.len(), 2);
        assert_eq!(right.len(), 2);

        for i in 0..2 {
            assert_relative_eq!(set.get(i).to_dense(), left.get(i).to_dense());
            assert_relative_eq!(set.get(i + 2).to_dense(), right.get(i).to_dense());
        }
    }

    #[test]
    fn test_split_at_various_positions() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        for split_index in 0..=4 {
            let (left, right) = view.split_at(split_index);

            assert_eq!(left.len(), split_index);
            assert_eq!(right.len(), 4 - split_index);

            // Every matrix must appear exactly once, in the expected half.
            for i in 0..4 {
                let split_matrix = if i < split_index {
                    left.get(i)
                } else {
                    right.get(i - split_index)
                };
                assert_relative_eq!(set.get(i).to_dense(), split_matrix.to_dense());
            }
        }
    }

    #[test]
    fn test_split_multiple_times() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        let (left, right) = view.split_at(2);
        let (left_left, left_right) = left.split_at(1);
        let (right_left, right_right) = right.split_at(1);

        assert_eq!(left_left.len(), 1);
        assert_eq!(left_right.len(), 1);
        assert_eq!(right_left.len(), 1);
        assert_eq!(right_right.len(), 1);

        assert_relative_eq!(set.get(0).to_dense(), left_left.get(0).to_dense());
        assert_relative_eq!(set.get(1).to_dense(), left_right.get(0).to_dense());
        assert_relative_eq!(set.get(2).to_dense(), right_left.get(0).to_dense());
        assert_relative_eq!(set.get(3).to_dense(), right_right.get(0).to_dense());
    }

    #[test]
    fn test_split_single_matrix() {
        let mut set = CsrMatrixSet::default();
        push_matrix(&mut set, 2, &[&[(0, 1.0), (1, 2.0)]]);

        // Also exercises the `IntoView` conversion of `&CsrMatrixSet`.
        let view = (&set).into_view();

        let (left, right) = view.split_at(0);
        assert_eq!(left.len(), 0);
        assert_eq!(right.len(), 1);

        let (left, right) = view.split_at(1);
        assert_eq!(left.len(), 1);
        assert_eq!(right.len(), 0);

        assert_relative_eq!(set.get(0).to_dense(), left.get(0).to_dense());
    }

    #[test]
    fn test_split_empty_view() {
        let set = CsrMatrixSet::<f32>::default();
        let view = set.as_view();

        let (left, right) = view.split_at(0);

        assert_eq!(left.len(), 0);
        assert_eq!(right.len(), 0);
        assert!(left.is_empty());
        assert!(right.is_empty());
    }

    #[test]
    fn test_split_view_data_integrity() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        let (left, right) = view.split_at(2);

        // Views are read-only, so integrity is checked by comparing every matrix
        // obtained through a split view with the same matrix taken from the set.
        let original_matrices: Vec<_> = (0..4).map(|i| set.get(i).to_dense()).collect();
        let left_matrices: Vec<_> = (0..2).map(|i| left.get(i).to_dense()).collect();
        let right_matrices: Vec<_> = (0..2).map(|i| right.get(i).to_dense()).collect();

        for (i, original) in original_matrices.iter().enumerate() {
            let split_matrix = if i < 2 {
                &left_matrices[i]
            } else {
                &right_matrices[i - 2]
            };
            assert_relative_eq!(original, split_matrix);
        }
    }

    #[test]
    fn test_split_view_independence() {
        let set = create_test_matrix_set();
        let view = set.as_view();

        let (left1, right1) = view.split_at(2);
        let (left2, right2) = view.split_at(2);

        // Splitting the same view twice must produce identical results.
        assert_eq!(left1.len(), left2.len());
        assert_eq!(right1.len(), right2.len());

        for i in 0..left1.len() {
            assert_relative_eq!(left1.get(i).to_dense(), left2.get(i).to_dense());
        }

        for i in 0..right1.len() {
            assert_relative_eq!(right1.get(i).to_dense(), right2.get(i).to_dense());
        }
    }

    #[test]
    fn test_clear_allows_reuse() {
        let reference = create_test_matrix_set();
        let mut set = create_test_matrix_set();

        set.clear();
        assert!(set.as_view().is_empty());

        // Refilling a cleared set must reproduce exactly the same matrices.
        fill_test_matrices(&mut set);
        assert_eq!(set.as_view().len(), reference.as_view().len());
        for i in 0..reference.as_view().len() {
            assert_relative_eq!(set.get(i).to_dense(), reference.get(i).to_dense());
        }
    }

    #[test]
    fn test_builder_without_rows_adds_matrix() {
        let mut set = CsrMatrixSet::<f32>::default();
        {
            // Dropping a builder that never received a row still commits a matrix.
            let _builder = set.new_matrix(4, THRESHOLD);
        }
        assert_eq!(set.as_view().len(), 1);
    }
}