Skip to main content

scirs2_sparse/
csr_array.rs

1// CSR Array implementation
2//
3// This module provides the CSR (Compressed Sparse Row) array format,
4// which is efficient for row-wise operations and is one of the most common formats.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::{SparseArray, SparseSum};
13
14/// Insert a value into an `Array1` at position `idx`, shifting subsequent
15/// elements to the right.  ndarray's `Array1` does not provide an insert
16/// method, so we convert to `Vec`, insert, and convert back.
17fn array1_insert<T: Clone + Default>(arr: &Array1<T>, idx: usize, value: T) -> Array1<T> {
18    let mut v = arr.to_vec();
19    v.insert(idx, value);
20    Array1::from_vec(v)
21}
22
23/// CSR Array format - Compressed Sparse Row matrix representation
24///
25/// The CSR (Compressed Sparse Row) format is one of the most popular sparse matrix formats,
26/// storing a sparse array using three arrays:
27/// - `data`: array of non-zero values in row-major order
28/// - `indices`: column indices of the non-zero values
29/// - `indptr`: row pointers; `indptr[i]` is the index into `data`/`indices` where row `i` starts
30///
31/// The CSR format is particularly efficient for:
32/// - ✅ Matrix-vector multiplications (`A * x`)
33/// - ✅ Matrix-matrix multiplications with other sparse matrices
34/// - ✅ Row-wise operations and row slicing
35/// - ✅ Iterating over non-zero elements row by row
36/// - ✅ Adding and subtracting sparse matrices
37///
38/// But less efficient for:
39/// - ❌ Column-wise operations and column slicing
40/// - ❌ Inserting or modifying individual elements after construction
41/// - ❌ Operations that require column access patterns
42///
43/// # Memory Layout
44///
45/// For a matrix with `m` rows, `n` columns, and `nnz` non-zero elements:
46/// - `data`: length `nnz` - stores the actual non-zero values
47/// - `indices`: length `nnz` - stores column indices for each non-zero value
48/// - `indptr`: length `m+1` - stores cumulative count of non-zeros per row
49///
50/// # Examples
51///
52/// ## Basic Construction and Access
53/// ```
54/// use scirs2_sparse::csr_array::CsrArray;
55/// use scirs2_sparse::SparseArray;
56///
57/// // Create a 3x3 matrix:
58/// // [1.0, 0.0, 2.0]
59/// // [0.0, 3.0, 0.0]
60/// // [4.0, 0.0, 5.0]
61/// let rows = vec![0, 0, 1, 2, 2];
62/// let cols = vec![0, 2, 1, 0, 2];
63/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
64/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
65///
66/// // Access elements
67/// assert_eq!(matrix.get(0, 0), 1.0);
68/// assert_eq!(matrix.get(0, 1), 0.0);  // Zero element
69/// assert_eq!(matrix.get(1, 1), 3.0);
70///
71/// // Get matrix properties
72/// assert_eq!(matrix.shape(), (3, 3));
73/// assert_eq!(matrix.nnz(), 5);
74/// ```
75///
76/// ## Matrix Operations
77/// ```
78/// use scirs2_sparse::csr_array::CsrArray;
79/// use scirs2_sparse::SparseArray;
80/// use scirs2_core::ndarray::Array1;
81///
82/// let rows = vec![0, 1, 2];
83/// let cols = vec![0, 1, 2];
84/// let data = vec![2.0, 3.0, 4.0];
85/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
86///
87/// // Matrix-vector multiplication
88/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
89/// let y = matrix.dot_vector(&x.view()).expect("Operation failed");
90/// assert_eq!(y[0], 2.0);  // 2.0 * 1.0
91/// assert_eq!(y[1], 6.0);  // 3.0 * 2.0
92/// assert_eq!(y[2], 12.0); // 4.0 * 3.0
93/// ```
94///
95/// ## Format Conversion
96/// ```
97/// use scirs2_sparse::csr_array::CsrArray;
98/// use scirs2_sparse::SparseArray;
99///
100/// let rows = vec![0, 1];
101/// let cols = vec![0, 1];
102/// let data = vec![1.0, 2.0];
103/// let csr = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).expect("Operation failed");
104///
105/// // Convert to dense array
106/// let dense = csr.to_array();
107/// assert_eq!(dense[[0, 0]], 1.0);
108/// assert_eq!(dense[[1, 1]], 2.0);
109///
110/// // Convert to other sparse formats
111/// let coo = csr.to_coo();
112/// let csc = csr.to_csc();
113/// ```
114#[derive(Clone)]
115pub struct CsrArray<T>
116where
117    T: SparseElement + Div<Output = T> + 'static,
118{
119    /// Non-zero values
120    data: Array1<T>,
121    /// Column indices of non-zero values
122    indices: Array1<usize>,
123    /// Row pointers (indices into data/indices for the start of each row)
124    indptr: Array1<usize>,
125    /// Shape of the sparse array
126    shape: (usize, usize),
127    /// Whether indices are sorted for each row
128    has_sorted_indices: bool,
129}
130
131impl<T> CsrArray<T>
132where
133    T: SparseElement + Div<Output = T> + Zero + 'static,
134{
135    /// Creates a new CSR array from raw components
136    ///
137    /// # Arguments
138    /// * `data` - Array of non-zero values
139    /// * `indices` - Column indices of non-zero values
140    /// * `indptr` - Index pointers for the start of each row
141    /// * `shape` - Shape of the sparse array
142    ///
143    /// # Returns
144    /// A new `CsrArray`
145    ///
146    /// # Errors
147    /// Returns an error if the data is not consistent
148    pub fn new(
149        data: Array1<T>,
150        indices: Array1<usize>,
151        indptr: Array1<usize>,
152        shape: (usize, usize),
153    ) -> SparseResult<Self> {
154        // Validation
155        if data.len() != indices.len() {
156            return Err(SparseError::InconsistentData {
157                reason: "data and indices must have the same length".to_string(),
158            });
159        }
160
161        if indptr.len() != shape.0 + 1 {
162            return Err(SparseError::InconsistentData {
163                reason: format!(
164                    "indptr length ({}) must be one more than the number of rows ({})",
165                    indptr.len(),
166                    shape.0
167                ),
168            });
169        }
170
171        if let Some(&max_idx) = indices.iter().max() {
172            if max_idx >= shape.1 {
173                return Err(SparseError::IndexOutOfBounds {
174                    index: (0, max_idx),
175                    shape,
176                });
177            }
178        }
179
180        if let Some((&last, &first)) = indptr.iter().next_back().zip(indptr.iter().next()) {
181            if first != 0 {
182                return Err(SparseError::InconsistentData {
183                    reason: "first element of indptr must be 0".to_string(),
184                });
185            }
186
187            if last != data.len() {
188                return Err(SparseError::InconsistentData {
189                    reason: format!(
190                        "last element of indptr ({}) must equal data length ({})",
191                        last,
192                        data.len()
193                    ),
194                });
195            }
196        }
197
198        let has_sorted_indices = Self::check_sorted_indices(&indices, &indptr);
199
200        Ok(Self {
201            data,
202            indices,
203            indptr,
204            shape,
205            has_sorted_indices,
206        })
207    }
208
209    /// Create a CSR array from triplet format (COO-like)
210    ///
211    /// This function creates a CSR (Compressed Sparse Row) array from coordinate triplets.
212    /// The triplets represent non-zero elements as (row, column, value) tuples.
213    ///
214    /// # Arguments
215    /// * `rows` - Row indices of non-zero elements
216    /// * `cols` - Column indices of non-zero elements  
217    /// * `data` - Values of non-zero elements
218    /// * `shape` - Shape of the sparse array (nrows, ncols)
219    /// * `sorted` - Whether the triplets are already sorted by (row, col). If false, sorting will be performed.
220    ///
221    /// # Returns
222    /// A new `CsrArray` containing the sparse matrix
223    ///
224    /// # Errors
225    /// Returns an error if:
226    /// - `rows`, `cols`, and `data` have different lengths
227    /// - Any index is out of bounds for the given shape
228    /// - The resulting data structure is inconsistent
229    ///
230    /// # Examples
231    ///
232    /// Create a simple 3x3 sparse matrix:
233    /// ```
234    /// use scirs2_sparse::csr_array::CsrArray;
235    /// use scirs2_sparse::SparseArray;
236    ///
237    /// // Create a 3x3 matrix with the following structure:
238    /// // [1.0, 0.0, 2.0]
239    /// // [0.0, 3.0, 0.0]
240    /// // [4.0, 0.0, 5.0]
241    /// let rows = vec![0, 0, 1, 2, 2];
242    /// let cols = vec![0, 2, 1, 0, 2];
243    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
244    /// let shape = (3, 3);
245    ///
246    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
247    /// assert_eq!(matrix.get(0, 0), 1.0);
248    /// assert_eq!(matrix.get(0, 1), 0.0);
249    /// assert_eq!(matrix.get(1, 1), 3.0);
250    /// ```
251    ///
252    /// Create an empty sparse matrix:
253    /// ```
254    /// use scirs2_sparse::csr_array::CsrArray;
255    /// use scirs2_sparse::SparseArray;
256    ///
257    /// let rows: Vec<usize> = vec![];
258    /// let cols: Vec<usize> = vec![];
259    /// let data: Vec<f64> = vec![];
260    /// let shape = (5, 5);
261    ///
262    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
263    /// assert_eq!(matrix.nnz(), 0);
264    /// assert_eq!(matrix.shape(), (5, 5));
265    /// ```
266    ///
267    /// Handle duplicate entries (they will be preserved):
268    /// ```
269    /// use scirs2_sparse::csr_array::CsrArray;
270    /// use scirs2_sparse::SparseArray;
271    ///
272    /// // Multiple entries at the same position
273    /// let rows = vec![0, 0];
274    /// let cols = vec![0, 0];
275    /// let data = vec![1.0, 2.0];
276    /// let shape = (2, 2);
277    ///
278    /// let matrix = CsrArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
279    /// // Note: CSR format preserves duplicates; use sum_duplicates() to combine them
280    /// assert_eq!(matrix.nnz(), 2);
281    /// ```
282    pub fn from_triplets(
283        rows: &[usize],
284        cols: &[usize],
285        data: &[T],
286        shape: (usize, usize),
287        sorted: bool,
288    ) -> SparseResult<Self> {
289        if rows.len() != cols.len() || rows.len() != data.len() {
290            return Err(SparseError::InconsistentData {
291                reason: "rows, cols, and data must have the same length".to_string(),
292            });
293        }
294
295        if rows.is_empty() {
296            // Empty matrix
297            let indptr = Array1::zeros(shape.0 + 1);
298            return Self::new(Array1::zeros(0), Array1::zeros(0), indptr, shape);
299        }
300
301        let nnz = rows.len();
302        let mut all_data: Vec<(usize, usize, T)> = Vec::with_capacity(nnz);
303
304        for i in 0..nnz {
305            if rows[i] >= shape.0 || cols[i] >= shape.1 {
306                return Err(SparseError::IndexOutOfBounds {
307                    index: (rows[i], cols[i]),
308                    shape,
309                });
310            }
311            all_data.push((rows[i], cols[i], data[i]));
312        }
313
314        if !sorted {
315            all_data.sort_by_key(|&(row, col_, _)| (row, col_));
316        }
317
318        // Count elements per row
319        let mut row_counts = vec![0; shape.0];
320        for &(row_, _, _) in &all_data {
321            row_counts[row_] += 1;
322        }
323
324        // Create indptr
325        let mut indptr = Vec::with_capacity(shape.0 + 1);
326        indptr.push(0);
327        let mut cumsum = 0;
328        for &count in &row_counts {
329            cumsum += count;
330            indptr.push(cumsum);
331        }
332
333        // Create indices and data arrays
334        let mut indices = Vec::with_capacity(nnz);
335        let mut values = Vec::with_capacity(nnz);
336
337        for (_, col, val) in all_data {
338            indices.push(col);
339            values.push(val);
340        }
341
342        Self::new(
343            Array1::from_vec(values),
344            Array1::from_vec(indices),
345            Array1::from_vec(indptr),
346            shape,
347        )
348    }
349
350    /// Checks if column indices are sorted for each row
351    fn check_sorted_indices(indices: &Array1<usize>, indptr: &Array1<usize>) -> bool {
352        for row in 0..indptr.len() - 1 {
353            let start = indptr[row];
354            let end = indptr[row + 1];
355
356            for i in start..end.saturating_sub(1) {
357                if i + 1 < indices.len() && indices[i] > indices[i + 1] {
358                    return false;
359                }
360            }
361        }
362        true
363    }
364
365    /// Get the raw data array
366    pub fn get_data(&self) -> &Array1<T> {
367        &self.data
368    }
369
370    /// Get the raw indices array
371    pub fn get_indices(&self) -> &Array1<usize> {
372        &self.indices
373    }
374
375    /// Get the raw indptr array
376    pub fn get_indptr(&self) -> &Array1<usize> {
377        &self.indptr
378    }
379
380    /// Get the number of rows
381    pub fn nrows(&self) -> usize {
382        self.shape.0
383    }
384
385    /// Get the number of columns  
386    pub fn ncols(&self) -> usize {
387        self.shape.1
388    }
389
390    /// Get the shape (rows, cols)
391    pub fn shape(&self) -> (usize, usize) {
392        self.shape
393    }
394}
395
396impl<T> SparseArray<T> for CsrArray<T>
397where
398    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
399{
400    fn shape(&self) -> (usize, usize) {
401        self.shape
402    }
403
404    fn nnz(&self) -> usize {
405        self.data.len()
406    }
407
408    fn dtype(&self) -> &str {
409        "float" // Placeholder, ideally we would return the actual type
410    }
411
412    fn to_array(&self) -> Array2<T> {
413        let (rows, cols) = self.shape;
414        let mut result = Array2::zeros((rows, cols));
415
416        for row in 0..rows {
417            let start = self.indptr[row];
418            let end = self.indptr[row + 1];
419
420            for i in start..end {
421                let col = self.indices[i];
422                result[[row, col]] = self.data[i];
423            }
424        }
425
426        result
427    }
428
429    fn toarray(&self) -> Array2<T> {
430        self.to_array()
431    }
432
433    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
434        // This would convert to COO format
435        // For now we just return self
436        Ok(Box::new(self.clone()))
437    }
438
439    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
440        Ok(Box::new(self.clone()))
441    }
442
443    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
444        // This would convert to CSC format
445        // For now we just return self
446        Ok(Box::new(self.clone()))
447    }
448
449    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
450        // This would convert to DOK format
451        // For now we just return self
452        Ok(Box::new(self.clone()))
453    }
454
455    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
456        // This would convert to LIL format
457        // For now we just return self
458        Ok(Box::new(self.clone()))
459    }
460
461    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
462        // This would convert to DIA format
463        // For now we just return self
464        Ok(Box::new(self.clone()))
465    }
466
467    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
468        // This would convert to BSR format
469        // For now we just return self
470        Ok(Box::new(self.clone()))
471    }
472
473    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
474        if self.shape() != other.shape() {
475            return Err(SparseError::DimensionMismatch {
476                expected: self.shape().0,
477                found: other.shape().0,
478            });
479        }
480
481        // Fast path: if `other` is also a CsrArray with sorted indices,
482        // perform a sorted row merge in O(nnz(A) + nnz(B)) time.
483        if let Some(other_csr) = other.as_any().downcast_ref::<CsrArray<T>>() {
484            if self.has_sorted_indices && other_csr.has_sorted_indices {
485                let (nrows, _) = self.shape();
486                let mut data = Vec::new();
487                let mut indices = Vec::new();
488                let mut indptr = vec![0usize];
489
490                for row in 0..nrows {
491                    let a_start = self.indptr[row];
492                    let a_end = self.indptr[row + 1];
493                    let b_start = other_csr.indptr[row];
494                    let b_end = other_csr.indptr[row + 1];
495
496                    let a_cols = &self.indices.as_slice().unwrap_or(&[])[a_start..a_end];
497                    let a_data = &self.data.as_slice().unwrap_or(&[])[a_start..a_end];
498                    let b_cols = &other_csr.indices.as_slice().unwrap_or(&[])[b_start..b_end];
499                    let b_data = &other_csr.data.as_slice().unwrap_or(&[])[b_start..b_end];
500
501                    let mut ai = 0;
502                    let mut bi = 0;
503                    while ai < a_cols.len() && bi < b_cols.len() {
504                        if a_cols[ai] < b_cols[bi] {
505                            let val = a_data[ai];
506                            if val != T::sparse_zero() {
507                                data.push(val);
508                                indices.push(a_cols[ai]);
509                            }
510                            ai += 1;
511                        } else if a_cols[ai] > b_cols[bi] {
512                            let val = b_data[bi];
513                            if val != T::sparse_zero() {
514                                data.push(val);
515                                indices.push(b_cols[bi]);
516                            }
517                            bi += 1;
518                        } else {
519                            let val = a_data[ai] + b_data[bi];
520                            if val != T::sparse_zero() {
521                                data.push(val);
522                                indices.push(a_cols[ai]);
523                            }
524                            ai += 1;
525                            bi += 1;
526                        }
527                    }
528                    while ai < a_cols.len() {
529                        let val = a_data[ai];
530                        if val != T::sparse_zero() {
531                            data.push(val);
532                            indices.push(a_cols[ai]);
533                        }
534                        ai += 1;
535                    }
536                    while bi < b_cols.len() {
537                        let val = b_data[bi];
538                        if val != T::sparse_zero() {
539                            data.push(val);
540                            indices.push(b_cols[bi]);
541                        }
542                        bi += 1;
543                    }
544                    indptr.push(data.len());
545                }
546
547                return CsrArray::new(
548                    Array1::from_vec(data),
549                    Array1::from_vec(indices),
550                    Array1::from_vec(indptr),
551                    self.shape(),
552                )
553                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>);
554            }
555        }
556
557        // Fallback: dense conversion
558        let self_array = self.to_array();
559        let other_array = other.to_array();
560        let result = &self_array + &other_array;
561
562        let (rows, cols) = self.shape();
563        let mut data = Vec::new();
564        let mut indices = Vec::new();
565        let mut indptr = vec![0];
566
567        for row in 0..rows {
568            for col in 0..cols {
569                let val = result[[row, col]];
570                if val != T::sparse_zero() {
571                    data.push(val);
572                    indices.push(col);
573                }
574            }
575            indptr.push(data.len());
576        }
577
578        CsrArray::new(
579            Array1::from_vec(data),
580            Array1::from_vec(indices),
581            Array1::from_vec(indptr),
582            self.shape(),
583        )
584        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
585    }
586
587    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
588        // Similar to add, this is a placeholder
589        let self_array = self.to_array();
590        let other_array = other.to_array();
591
592        if self.shape() != other.shape() {
593            return Err(SparseError::DimensionMismatch {
594                expected: self.shape().0,
595                found: other.shape().0,
596            });
597        }
598
599        let result = &self_array - &other_array;
600
601        // Convert back to CSR
602        let (rows, cols) = self.shape();
603        let mut data = Vec::new();
604        let mut indices = Vec::new();
605        let mut indptr = vec![0];
606
607        for row in 0..rows {
608            for col in 0..cols {
609                let val = result[[row, col]];
610                if val != T::sparse_zero() {
611                    data.push(val);
612                    indices.push(col);
613                }
614            }
615            indptr.push(data.len());
616        }
617
618        CsrArray::new(
619            Array1::from_vec(data),
620            Array1::from_vec(indices),
621            Array1::from_vec(indptr),
622            self.shape(),
623        )
624        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
625    }
626
627    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
628        // This is element-wise multiplication (Hadamard product)
629        // In the sparse array API, * is element-wise, not matrix multiplication
630        let self_array = self.to_array();
631        let other_array = other.to_array();
632
633        if self.shape() != other.shape() {
634            return Err(SparseError::DimensionMismatch {
635                expected: self.shape().0,
636                found: other.shape().0,
637            });
638        }
639
640        let result = &self_array * &other_array;
641
642        // Convert back to CSR
643        let (rows, cols) = self.shape();
644        let mut data = Vec::new();
645        let mut indices = Vec::new();
646        let mut indptr = vec![0];
647
648        for row in 0..rows {
649            for col in 0..cols {
650                let val = result[[row, col]];
651                if val != T::sparse_zero() {
652                    data.push(val);
653                    indices.push(col);
654                }
655            }
656            indptr.push(data.len());
657        }
658
659        CsrArray::new(
660            Array1::from_vec(data),
661            Array1::from_vec(indices),
662            Array1::from_vec(indptr),
663            self.shape(),
664        )
665        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
666    }
667
668    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
669        // Element-wise division
670        let self_array = self.to_array();
671        let other_array = other.to_array();
672
673        if self.shape() != other.shape() {
674            return Err(SparseError::DimensionMismatch {
675                expected: self.shape().0,
676                found: other.shape().0,
677            });
678        }
679
680        let result = &self_array / &other_array;
681
682        // Convert back to CSR
683        let (rows, cols) = self.shape();
684        let mut data = Vec::new();
685        let mut indices = Vec::new();
686        let mut indptr = vec![0];
687
688        for row in 0..rows {
689            for col in 0..cols {
690                let val = result[[row, col]];
691                if val != T::sparse_zero() {
692                    data.push(val);
693                    indices.push(col);
694                }
695            }
696            indptr.push(data.len());
697        }
698
699        CsrArray::new(
700            Array1::from_vec(data),
701            Array1::from_vec(indices),
702            Array1::from_vec(indptr),
703            self.shape(),
704        )
705        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
706    }
707
708    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
709        let (m, n) = self.shape();
710        let (p, q) = other.shape();
711
712        if n != p {
713            return Err(SparseError::DimensionMismatch {
714                expected: n,
715                found: p,
716            });
717        }
718
719        // Fast path: if `other` is also a CsrArray, use scatter-gather
720        // row-by-row multiplication in O(nnz(A) * avg_nnz_per_row(B)) time.
721        if let Some(other_csr) = other.as_any().downcast_ref::<CsrArray<T>>() {
722            let mut data = Vec::new();
723            let mut col_indices = Vec::new();
724            let mut indptr = vec![0usize];
725
726            let mut workspace = vec![T::sparse_zero(); q];
727            let mut marker = vec![false; q];
728
729            for i in 0..m {
730                let a_start = self.indptr[i];
731                let a_end = self.indptr[i + 1];
732                let mut touched: Vec<usize> = Vec::new();
733
734                for a_idx in a_start..a_end {
735                    let k = self.indices[a_idx];
736                    let a_ik = self.data[a_idx];
737                    if a_ik == T::sparse_zero() {
738                        continue;
739                    }
740                    let b_start = other_csr.indptr[k];
741                    let b_end = other_csr.indptr[k + 1];
742                    for b_idx in b_start..b_end {
743                        let j = other_csr.indices[b_idx];
744                        workspace[j] = workspace[j] + a_ik * other_csr.data[b_idx];
745                        if !marker[j] {
746                            marker[j] = true;
747                            touched.push(j);
748                        }
749                    }
750                }
751
752                touched.sort_unstable();
753                for &j in &touched {
754                    let val = workspace[j];
755                    if val != T::sparse_zero() {
756                        data.push(val);
757                        col_indices.push(j);
758                    }
759                    workspace[j] = T::sparse_zero();
760                    marker[j] = false;
761                }
762                indptr.push(data.len());
763            }
764
765            return CsrArray::new(
766                Array1::from_vec(data),
767                Array1::from_vec(col_indices),
768                Array1::from_vec(indptr),
769                (m, q),
770            )
771            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>);
772        }
773
774        // Fallback: use dense `other` matrix
775        let other_array = other.to_array();
776        let mut data = Vec::new();
777        let mut col_indices = Vec::new();
778        let mut indptr = vec![0];
779
780        for row in 0..m {
781            let start = self.indptr[row];
782            let end = self.indptr[row + 1];
783
784            for j in 0..q {
785                let mut sum = T::sparse_zero();
786                for idx in start..end {
787                    let col = self.indices[idx];
788                    sum = sum + self.data[idx] * other_array[[col, j]];
789                }
790                if sum != T::sparse_zero() {
791                    data.push(sum);
792                    col_indices.push(j);
793                }
794            }
795            indptr.push(data.len());
796        }
797
798        CsrArray::new(
799            Array1::from_vec(data),
800            Array1::from_vec(col_indices),
801            Array1::from_vec(indptr),
802            (m, q),
803        )
804        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
805    }
806
807    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
808        let (m, n) = self.shape();
809        if n != other.len() {
810            return Err(SparseError::DimensionMismatch {
811                expected: n,
812                found: other.len(),
813            });
814        }
815
816        let mut result = Array1::zeros(m);
817
818        for row in 0..m {
819            let start = self.indptr[row];
820            let end = self.indptr[row + 1];
821
822            let mut sum = T::sparse_zero();
823            for idx in start..end {
824                let col = self.indices[idx];
825                sum = sum + self.data[idx] * other[col];
826            }
827            result[row] = sum;
828        }
829
830        Ok(result)
831    }
832
833    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
834        // Transpose is non-trivial for CSR format
835        // A full implementation would convert to CSC format or implement
836        // an efficient algorithm
837        let (rows, cols) = self.shape();
838        let mut row_indices = Vec::with_capacity(self.nnz());
839        let mut col_indices = Vec::with_capacity(self.nnz());
840        let mut values = Vec::with_capacity(self.nnz());
841
842        for row in 0..rows {
843            let start = self.indptr[row];
844            let end = self.indptr[row + 1];
845
846            for idx in start..end {
847                let col = self.indices[idx];
848                row_indices.push(col); // Note: rows and cols are swapped for transposition
849                col_indices.push(row);
850                values.push(self.data[idx]);
851            }
852        }
853
854        // We need to create CSR from this "COO" representation
855        CsrArray::from_triplets(
856            &row_indices,
857            &col_indices,
858            &values,
859            (cols, rows), // Swapped dimensions
860            false,
861        )
862        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
863    }
864
865    fn copy(&self) -> Box<dyn SparseArray<T>> {
866        Box::new(self.clone())
867    }
868
869    fn get(&self, i: usize, j: usize) -> T {
870        if i >= self.shape.0 || j >= self.shape.1 {
871            return T::sparse_zero();
872        }
873
874        let start = self.indptr[i];
875        let end = self.indptr[i + 1];
876
877        for idx in start..end {
878            if self.indices[idx] == j {
879                return self.data[idx];
880            }
881            // If indices are sorted, we can break early
882            if self.has_sorted_indices && self.indices[idx] > j {
883                break;
884            }
885        }
886
887        T::sparse_zero()
888    }
889
890    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
891        if i >= self.shape.0 || j >= self.shape.1 {
892            return Err(SparseError::IndexOutOfBounds {
893                index: (i, j),
894                shape: self.shape,
895            });
896        }
897
898        let start = self.indptr[i];
899        let end = self.indptr[i + 1];
900
901        // Try to find existing element
902        for idx in start..end {
903            if self.indices[idx] == j {
904                self.data[idx] = value;
905                return Ok(());
906            }
907            if self.has_sorted_indices && self.indices[idx] > j {
908                // Insert at position `idx` to maintain sorted order
909                self.data = array1_insert(&self.data, idx, value);
910                self.indices = array1_insert(&self.indices, idx, j);
911                // Increment indptr for all subsequent rows
912                for row_ptr in self.indptr.iter_mut().skip(i + 1) {
913                    *row_ptr += 1;
914                }
915                return Ok(());
916            }
917        }
918
919        // Element not found - insert at end of this row's range
920        self.data = array1_insert(&self.data, end, value);
921        self.indices = array1_insert(&self.indices, end, j);
922        // Increment indptr for all subsequent rows
923        for row_ptr in self.indptr.iter_mut().skip(i + 1) {
924            *row_ptr += 1;
925        }
926        // If we inserted at the end, indices may no longer be sorted
927        // (only if there are elements after this row that come before j).
928        // Re-check sorted state for this row.
929        if self.has_sorted_indices {
930            let new_end = self.indptr[i + 1];
931            let new_start = self.indptr[i];
932            for k in new_start..new_end.saturating_sub(1) {
933                if self.indices[k] > self.indices[k + 1] {
934                    self.has_sorted_indices = false;
935                    break;
936                }
937            }
938        }
939        Ok(())
940    }
941
942    fn eliminate_zeros(&mut self) {
943        // Find all non-zero entries
944        let mut new_data = Vec::new();
945        let mut new_indices = Vec::new();
946        let mut new_indptr = vec![0];
947
948        let (rows, _) = self.shape();
949
950        for row in 0..rows {
951            let start = self.indptr[row];
952            let end = self.indptr[row + 1];
953
954            for idx in start..end {
955                if !SparseElement::is_zero(&self.data[idx]) {
956                    new_data.push(self.data[idx]);
957                    new_indices.push(self.indices[idx]);
958                }
959            }
960            new_indptr.push(new_data.len());
961        }
962
963        // Replace data with filtered data
964        self.data = Array1::from_vec(new_data);
965        self.indices = Array1::from_vec(new_indices);
966        self.indptr = Array1::from_vec(new_indptr);
967    }
968
969    fn sort_indices(&mut self) {
970        if self.has_sorted_indices {
971            return;
972        }
973
974        let (rows, _) = self.shape();
975
976        for row in 0..rows {
977            let start = self.indptr[row];
978            let end = self.indptr[row + 1];
979
980            if start == end {
981                continue;
982            }
983
984            // Extract the non-zero elements for this row
985            let mut row_data = Vec::with_capacity(end - start);
986            for idx in start..end {
987                row_data.push((self.indices[idx], self.data[idx]));
988            }
989
990            // Sort by column index
991            row_data.sort_by_key(|&(col_, _)| col_);
992
993            // Put the sorted data back
994            for (i, (col, val)) in row_data.into_iter().enumerate() {
995                self.indices[start + i] = col;
996                self.data[start + i] = val;
997            }
998        }
999
1000        self.has_sorted_indices = true;
1001    }
1002
1003    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
1004        if self.has_sorted_indices {
1005            return Box::new(self.clone());
1006        }
1007
1008        let mut sorted = self.clone();
1009        sorted.sort_indices();
1010        Box::new(sorted)
1011    }
1012
1013    fn has_sorted_indices(&self) -> bool {
1014        self.has_sorted_indices
1015    }
1016
1017    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
1018        match axis {
1019            None => {
1020                // Sum all elements
1021                let mut sum = T::sparse_zero();
1022                for &val in self.data.iter() {
1023                    sum = sum + val;
1024                }
1025                Ok(SparseSum::Scalar(sum))
1026            }
1027            Some(0) => {
1028                // Sum along rows (result is a row vector)
1029                let (_, cols) = self.shape();
1030                let mut result = vec![T::sparse_zero(); cols];
1031
1032                for row in 0..self.shape.0 {
1033                    let start = self.indptr[row];
1034                    let end = self.indptr[row + 1];
1035
1036                    for idx in start..end {
1037                        let col = self.indices[idx];
1038                        result[col] = result[col] + self.data[idx];
1039                    }
1040                }
1041
1042                // Convert to CSR format
1043                let mut data = Vec::new();
1044                let mut indices = Vec::new();
1045                let mut indptr = vec![0];
1046
1047                for (col, &val) in result.iter().enumerate() {
1048                    if val != T::sparse_zero() {
1049                        data.push(val);
1050                        indices.push(col);
1051                    }
1052                }
1053                indptr.push(data.len());
1054
1055                let result_array = CsrArray::new(
1056                    Array1::from_vec(data),
1057                    Array1::from_vec(indices),
1058                    Array1::from_vec(indptr),
1059                    (1, cols),
1060                )?;
1061
1062                Ok(SparseSum::SparseArray(Box::new(result_array)))
1063            }
1064            Some(1) => {
1065                // Sum along columns (result is a column vector)
1066                let mut result = Vec::with_capacity(self.shape.0);
1067
1068                for row in 0..self.shape.0 {
1069                    let start = self.indptr[row];
1070                    let end = self.indptr[row + 1];
1071
1072                    let mut row_sum = T::sparse_zero();
1073                    for idx in start..end {
1074                        row_sum = row_sum + self.data[idx];
1075                    }
1076                    result.push(row_sum);
1077                }
1078
1079                // Convert to CSR format
1080                let mut data = Vec::new();
1081                let mut indices = Vec::new();
1082                let mut indptr = vec![0];
1083
1084                for &val in result.iter() {
1085                    if val != T::sparse_zero() {
1086                        data.push(val);
1087                        indices.push(0);
1088                        indptr.push(data.len());
1089                    } else {
1090                        indptr.push(data.len());
1091                    }
1092                }
1093
1094                let result_array = CsrArray::new(
1095                    Array1::from_vec(data),
1096                    Array1::from_vec(indices),
1097                    Array1::from_vec(indptr),
1098                    (self.shape.0, 1),
1099                )?;
1100
1101                Ok(SparseSum::SparseArray(Box::new(result_array)))
1102            }
1103            _ => Err(SparseError::InvalidAxis),
1104        }
1105    }
1106
1107    fn max(&self) -> T {
1108        if self.data.is_empty() {
1109            // Empty sparse matrix - all elements are implicitly zero
1110            return T::sparse_zero();
1111        }
1112
1113        let mut max_val = self.data[0];
1114        for &val in self.data.iter().skip(1) {
1115            if val > max_val {
1116                max_val = val;
1117            }
1118        }
1119
1120        // Check if max_val is less than zero, as zeros aren't explicitly stored
1121        // If the matrix has implicit zeros and max_val < 0, then max is 0
1122        let zero = T::sparse_zero();
1123        if max_val < zero && self.nnz() < self.shape.0 * self.shape.1 {
1124            max_val = zero;
1125        }
1126
1127        max_val
1128    }
1129
1130    fn min(&self) -> T {
1131        if self.data.is_empty() {
1132            // Empty sparse matrix - all elements are implicitly zero
1133            return T::sparse_zero();
1134        }
1135
1136        let mut min_val = self.data[0];
1137        for &val in self.data.iter().skip(1) {
1138            if val < min_val {
1139                min_val = val;
1140            }
1141        }
1142
1143        // Check if min_val is greater than zero, as zeros aren't explicitly stored
1144        // If the matrix has implicit zeros and min_val > 0, then min is 0
1145        let zero = T::sparse_zero();
1146        if min_val > zero && self.nnz() < self.shape.0 * self.shape.1 {
1147            min_val = zero;
1148        }
1149
1150        min_val
1151    }
1152
1153    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
1154        let nnz = self.nnz();
1155        let mut rows = Vec::with_capacity(nnz);
1156        let mut cols = Vec::with_capacity(nnz);
1157        let mut values = Vec::with_capacity(nnz);
1158
1159        for row in 0..self.shape.0 {
1160            let start = self.indptr[row];
1161            let end = self.indptr[row + 1];
1162
1163            for idx in start..end {
1164                let col = self.indices[idx];
1165                rows.push(row);
1166                cols.push(col);
1167                values.push(self.data[idx]);
1168            }
1169        }
1170
1171        (
1172            Array1::from_vec(rows),
1173            Array1::from_vec(cols),
1174            Array1::from_vec(values),
1175        )
1176    }
1177
1178    fn slice(
1179        &self,
1180        row_range: (usize, usize),
1181        col_range: (usize, usize),
1182    ) -> SparseResult<Box<dyn SparseArray<T>>> {
1183        let (start_row, end_row) = row_range;
1184        let (start_col, end_col) = col_range;
1185
1186        if start_row >= self.shape.0
1187            || end_row > self.shape.0
1188            || start_col >= self.shape.1
1189            || end_col > self.shape.1
1190        {
1191            return Err(SparseError::InvalidSliceRange);
1192        }
1193
1194        if start_row >= end_row || start_col >= end_col {
1195            return Err(SparseError::InvalidSliceRange);
1196        }
1197
1198        let mut data = Vec::new();
1199        let mut indices = Vec::new();
1200        let mut indptr = vec![0];
1201
1202        for row in start_row..end_row {
1203            let start = self.indptr[row];
1204            let end = self.indptr[row + 1];
1205
1206            for idx in start..end {
1207                let col = self.indices[idx];
1208                if col >= start_col && col < end_col {
1209                    data.push(self.data[idx]);
1210                    indices.push(col - start_col);
1211                }
1212            }
1213            indptr.push(data.len());
1214        }
1215
1216        CsrArray::new(
1217            Array1::from_vec(data),
1218            Array1::from_vec(indices),
1219            Array1::from_vec(indptr),
1220            (end_row - start_row, end_col - start_col),
1221        )
1222        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
1223    }
1224
1225    fn as_any(&self) -> &dyn std::any::Any {
1226        self
1227    }
1228
1229    fn get_indptr(&self) -> Option<&Array1<usize>> {
1230        Some(&self.indptr)
1231    }
1232
1233    fn indptr(&self) -> Option<&Array1<usize>> {
1234        Some(&self.indptr)
1235    }
1236}
1237
1238impl<T> fmt::Debug for CsrArray<T>
1239where
1240    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
1241{
1242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1243        write!(
1244            f,
1245            "CsrArray<{}x{}, nnz={}>",
1246            self.shape.0,
1247            self.shape.1,
1248            self.nnz()
1249        )
1250    }
1251}
1252
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Dense form of the 3x3 matrix used throughout these tests:
    //
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    const DENSE: [[f64; 3]; 3] = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];

    // Builds the shared matrix directly from its CSR components.
    fn fixture() -> CsrArray<f64> {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let indices = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let indptr = Array1::from_vec(vec![0, 2, 3, 5]);
        CsrArray::new(data, indices, indptr, (3, 3)).expect("Operation failed")
    }

    // Checks shape, entry count, every stored entry and one implicit zero.
    fn assert_fixture_contents(csr: &CsrArray<f64>) {
        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        for &(row, col) in [(0, 0), (0, 2), (1, 1), (2, 0), (2, 2), (0, 1)].iter() {
            assert_eq!(csr.get(row, col), DENSE[row][col]);
        }
    }

    #[test]
    fn test_csr_array_construction() {
        let csr = fixture();
        assert_fixture_contents(&csr);
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false)
            .expect("Operation failed");

        assert_fixture_contents(&csr);
    }

    #[test]
    fn test_csr_array_to_array() {
        let dense = fixture().to_array();

        assert_eq!(dense.shape(), &[3, 3]);
        for (row, expected_row) in DENSE.iter().enumerate() {
            for (col, &expected) in expected_row.iter().enumerate() {
                assert_eq!(dense[[row, col]], expected);
            }
        }
    }

    #[test]
    fn test_csr_array_dot_vector() {
        let csr = fixture();
        let vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = csr.dot_vector(&vec.view()).expect("Operation failed");

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(result.len(), 3);
        for (i, &expected) in [7.0, 6.0, 19.0].iter().enumerate() {
            assert_relative_eq!(result[i], expected);
        }
    }

    #[test]
    fn test_csr_array_sum() {
        let csr = fixture();

        // Sum all elements
        match csr.sum(None).expect("Operation failed") {
            SparseSum::Scalar(total) => {
                assert_relative_eq!(total, 15.0);
            }
            _ => panic!("Expected scalar sum"),
        }

        // Sum along rows
        match csr.sum(Some(0)).expect("Operation failed") {
            SparseSum::SparseArray(row_sum) => {
                let row_sum_array = row_sum.to_array();
                assert_eq!(row_sum_array.shape(), &[1, 3]);
                for (col, &expected) in [5.0, 3.0, 7.0].iter().enumerate() {
                    assert_relative_eq!(row_sum_array[[0, col]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }

        // Sum along columns
        match csr.sum(Some(1)).expect("Operation failed") {
            SparseSum::SparseArray(col_sum) => {
                let col_sum_array = col_sum.to_array();
                assert_eq!(col_sum_array.shape(), &[3, 1]);
                for (row, &expected) in [3.0, 3.0, 9.0].iter().enumerate() {
                    assert_relative_eq!(col_sum_array[[row, 0]], expected);
                }
            }
            _ => panic!("Expected sparse array sum"),
        }
    }
}