oxiblas_sparse/
csc.rs

1//! Compressed Sparse Column (CSC) matrix format.
2//!
3//! CSC stores matrix data using three arrays:
4//! - `values`: Non-zero values (column-major order)
5//! - `row_indices`: Row index for each value
6//! - `col_ptrs`: Index into values/row_indices for start of each column
7//!
8//! For an m×n matrix with nnz non-zeros:
9//! - `values` has length nnz
10//! - `row_indices` has length nnz
11//! - `col_ptrs` has length n+1
12
13use oxiblas_core::scalar::{Field, Scalar};
14use std::ops::Index;
15
/// Error type for CSC matrix operations.
///
/// Values of this type are returned by [`CscMatrix::new`] when the raw
/// component arrays are inconsistent with each other or with the requested
/// shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CscError {
    /// Invalid column pointer array length.
    ///
    /// `CscMatrix::new` also reports this variant when the final column
    /// pointer differs from the number of stored values; in that case
    /// `expected` is the number of values and `actual` is the final pointer.
    InvalidColPtrs {
        /// Expected length.
        expected: usize,
        /// Actual length.
        actual: usize,
    },
    /// Mismatched array lengths.
    ///
    /// The value array and the row index array must have one element per
    /// stored entry, so their lengths have to agree.
    LengthMismatch {
        /// Number of values.
        values_len: usize,
        /// Number of row indices.
        row_indices_len: usize,
    },
    /// Row index out of bounds.
    InvalidRowIndex {
        /// The invalid index.
        index: usize,
        /// Number of rows.
        nrows: usize,
    },
    /// Column pointers not monotonically increasing.
    ///
    /// Reported when a pointer is strictly smaller than its predecessor;
    /// equal neighbours (an empty column) are allowed.
    InvalidColPtrOrder,
    /// Duplicate entry at same position.
    ///
    /// Not returned by `CscMatrix::new`, which does not look for duplicates.
    DuplicateEntry {
        /// Row of duplicate.
        row: usize,
        /// Column of duplicate.
        col: usize,
    },
}
50
51impl core::fmt::Display for CscError {
52    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53        match self {
54            Self::InvalidColPtrs { expected, actual } => {
55                write!(
56                    f,
57                    "Invalid col_ptrs length: expected {expected}, got {actual}"
58                )
59            }
60            Self::LengthMismatch {
61                values_len,
62                row_indices_len,
63            } => {
64                write!(
65                    f,
66                    "Length mismatch: values={values_len}, row_indices={row_indices_len}"
67                )
68            }
69            Self::InvalidRowIndex { index, nrows } => {
70                write!(f, "Row index {index} out of bounds for {nrows} rows")
71            }
72            Self::InvalidColPtrOrder => {
73                write!(f, "Column pointers must be monotonically increasing")
74            }
75            Self::DuplicateEntry { row, col } => {
76                write!(f, "Duplicate entry at ({row}, {col})")
77            }
78        }
79    }
80}
81
// No method is overridden: the trait's default implementations are used as-is.
impl std::error::Error for CscError {}
83
/// Compressed Sparse Column matrix.
///
/// Stored entries are grouped by column: the entries of column `j` occupy
/// positions `col_ptrs[j]..col_ptrs[j + 1]` of both `row_indices` and
/// `values`.
///
/// Efficient for:
/// - Column slicing
/// - Matrix-vector products with transpose (y = A^T * x)
/// - Column-wise traversal
/// - Direct solvers (LU, Cholesky)
#[derive(Debug, Clone)]
pub struct CscMatrix<T: Scalar> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Column pointers (length ncols + 1).
    ///
    /// Entry `j` is the offset of the first stored entry of column `j`;
    /// the last entry equals the number of stored entries.
    col_ptrs: Vec<usize>,
    /// Row indices for each non-zero.
    ///
    /// Parallel to `values`: element `k` is the row of `values[k]`.
    row_indices: Vec<usize>,
    /// Non-zero values.
    values: Vec<T>,
}
104
105impl<T: Scalar + Clone> CscMatrix<T> {
106    /// Creates a new CSC matrix from raw components.
107    ///
108    /// # Arguments
109    ///
110    /// * `nrows` - Number of rows
111    /// * `ncols` - Number of columns
112    /// * `col_ptrs` - Column pointers (length ncols + 1)
113    /// * `row_indices` - Row indices for each non-zero
114    /// * `values` - Non-zero values
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the input is invalid.
119    pub fn new(
120        nrows: usize,
121        ncols: usize,
122        col_ptrs: Vec<usize>,
123        row_indices: Vec<usize>,
124        values: Vec<T>,
125    ) -> Result<Self, CscError> {
126        // Validate col_ptrs length
127        if col_ptrs.len() != ncols + 1 {
128            return Err(CscError::InvalidColPtrs {
129                expected: ncols + 1,
130                actual: col_ptrs.len(),
131            });
132        }
133
134        // Validate values and row_indices have same length
135        if values.len() != row_indices.len() {
136            return Err(CscError::LengthMismatch {
137                values_len: values.len(),
138                row_indices_len: row_indices.len(),
139            });
140        }
141
142        // Validate col_ptrs are monotonically increasing
143        for i in 1..col_ptrs.len() {
144            if col_ptrs[i] < col_ptrs[i - 1] {
145                return Err(CscError::InvalidColPtrOrder);
146            }
147        }
148
149        // Validate col_ptrs[ncols] equals nnz
150        let nnz = values.len();
151        if col_ptrs[ncols] != nnz {
152            return Err(CscError::InvalidColPtrs {
153                expected: nnz,
154                actual: col_ptrs[ncols],
155            });
156        }
157
158        // Validate row indices
159        for &row in &row_indices {
160            if row >= nrows {
161                return Err(CscError::InvalidRowIndex { index: row, nrows });
162            }
163        }
164
165        Ok(Self {
166            nrows,
167            ncols,
168            col_ptrs,
169            row_indices,
170            values,
171        })
172    }
173
174    /// Creates a CSC matrix without validation (unsafe but faster).
175    ///
176    /// # Safety
177    ///
178    /// The caller must ensure:
179    /// - `col_ptrs.len() == ncols + 1`
180    /// - `values.len() == row_indices.len()`
181    /// - `col_ptrs` is monotonically increasing
182    /// - All row indices are < nrows
183    #[inline]
184    pub unsafe fn new_unchecked(
185        nrows: usize,
186        ncols: usize,
187        col_ptrs: Vec<usize>,
188        row_indices: Vec<usize>,
189        values: Vec<T>,
190    ) -> Self {
191        Self {
192            nrows,
193            ncols,
194            col_ptrs,
195            row_indices,
196            values,
197        }
198    }
199
200    /// Creates an empty CSC matrix with given dimensions.
201    pub fn zeros(nrows: usize, ncols: usize) -> Self {
202        Self {
203            nrows,
204            ncols,
205            col_ptrs: vec![0; ncols + 1],
206            row_indices: Vec::new(),
207            values: Vec::new(),
208        }
209    }
210
211    /// Creates an identity matrix in CSC format.
212    pub fn eye(n: usize) -> Self
213    where
214        T: Field,
215    {
216        let mut col_ptrs = Vec::with_capacity(n + 1);
217        let mut row_indices = Vec::with_capacity(n);
218        let mut values = Vec::with_capacity(n);
219
220        for i in 0..n {
221            col_ptrs.push(i);
222            row_indices.push(i);
223            values.push(T::one());
224        }
225        col_ptrs.push(n);
226
227        Self {
228            nrows: n,
229            ncols: n,
230            col_ptrs,
231            row_indices,
232            values,
233        }
234    }
235
    /// Returns the number of rows.
    ///
    /// This is the first component of [`shape`](Self::shape).
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }
241
    /// Returns the number of columns.
    ///
    /// This is the second component of [`shape`](Self::shape).
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }
247
248    /// Returns the shape (nrows, ncols).
249    #[inline]
250    pub fn shape(&self) -> (usize, usize) {
251        (self.nrows, self.ncols)
252    }
253
    /// Returns the number of non-zero elements.
    ///
    /// This is the length of the value array, i.e. the number of stored
    /// entries; a stored entry whose value happens to be zero is counted too.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }
259
260    /// Returns the density (nnz / total_elements).
261    #[inline]
262    pub fn density(&self) -> f64 {
263        if self.nrows == 0 || self.ncols == 0 {
264            0.0
265        } else {
266            self.nnz() as f64 / (self.nrows * self.ncols) as f64
267        }
268    }
269
270    /// Returns a reference to the column pointers.
271    #[inline]
272    pub fn col_ptrs(&self) -> &[usize] {
273        &self.col_ptrs
274    }
275
276    /// Returns a reference to the row indices.
277    #[inline]
278    pub fn row_indices(&self) -> &[usize] {
279        &self.row_indices
280    }
281
282    /// Returns a reference to the values.
283    #[inline]
284    pub fn values(&self) -> &[T] {
285        &self.values
286    }
287
288    /// Returns a mutable reference to the values.
289    #[inline]
290    pub fn values_mut(&mut self) -> &mut [T] {
291        &mut self.values
292    }
293
294    /// Gets the value at (row, col), returning None if not present.
295    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
296        if row >= self.nrows || col >= self.ncols {
297            return None;
298        }
299
300        let start = self.col_ptrs[col];
301        let end = self.col_ptrs[col + 1];
302
303        for i in start..end {
304            if self.row_indices[i] == row {
305                return Some(&self.values[i]);
306            }
307        }
308
309        None
310    }
311
312    /// Gets the value at (row, col), returning zero if not present.
313    pub fn get_or_zero(&self, row: usize, col: usize) -> T
314    where
315        T: Field,
316    {
317        self.get(row, col).cloned().unwrap_or_else(T::zero)
318    }
319
320    /// Returns an iterator over the non-zeros in a column.
321    pub fn col_iter(&self, col: usize) -> impl Iterator<Item = (usize, &T)> {
322        let start = self.col_ptrs[col];
323        let end = self.col_ptrs[col + 1];
324
325        self.row_indices[start..end]
326            .iter()
327            .zip(self.values[start..end].iter())
328            .map(|(&row, val)| (row, val))
329    }
330
331    /// Returns an iterator over all non-zeros as (row, col, value).
332    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
333        (0..self.ncols).flat_map(move |col| {
334            let start = self.col_ptrs[col];
335            let end = self.col_ptrs[col + 1];
336
337            self.row_indices[start..end]
338                .iter()
339                .zip(self.values[start..end].iter())
340                .map(move |(&row, val)| (row, col, val))
341        })
342    }
343
    /// Converts to CSR format.
    ///
    /// Delegates to `crate::convert::csc_to_csr`. `self` is only borrowed,
    /// so the CSC matrix stays usable after the call.
    pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
        crate::convert::csc_to_csr(self)
    }
348
349    /// Converts to dense matrix.
350    pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
351    where
352        T: Field + bytemuck::Zeroable,
353    {
354        let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
355
356        for col in 0..self.ncols {
357            let start = self.col_ptrs[col];
358            let end = self.col_ptrs[col + 1];
359
360            for i in start..end {
361                dense[(self.row_indices[i], col)] = self.values[i].clone();
362            }
363        }
364
365        dense
366    }
367
368    /// Creates a CSC matrix from a dense matrix.
369    pub fn from_dense(dense: &oxiblas_matrix::MatRef<'_, T>) -> Self
370    where
371        T: Field,
372    {
373        let (nrows, ncols) = dense.shape();
374        let mut col_ptrs = Vec::with_capacity(ncols + 1);
375        let mut row_indices = Vec::new();
376        let mut values = Vec::new();
377
378        let eps = <T as Scalar>::epsilon();
379
380        col_ptrs.push(0);
381
382        for j in 0..ncols {
383            for i in 0..nrows {
384                let val = dense[(i, j)].clone();
385                if Scalar::abs(val.clone()) > eps {
386                    row_indices.push(i);
387                    values.push(val);
388                }
389            }
390            col_ptrs.push(values.len());
391        }
392
393        Self {
394            nrows,
395            ncols,
396            col_ptrs,
397            row_indices,
398            values,
399        }
400    }
401
402    /// Returns the transpose of this matrix.
403    pub fn transpose(&self) -> Self {
404        // Transpose of CSC is equivalent to CSR with swapped interpretation
405        // Then convert back to CSC
406        let csr = self.to_csr();
407        csr.to_csc()
408    }
409
410    /// Scales all values by a scalar.
411    pub fn scale(&mut self, alpha: T) {
412        for val in &mut self.values {
413            *val = val.clone() * alpha.clone();
414        }
415    }
416
417    /// Returns a scaled copy of this matrix.
418    pub fn scaled(&self, alpha: T) -> Self {
419        let mut result = self.clone();
420        result.scale(alpha);
421        result
422    }
423
424    /// Returns the number of non-zeros in a column.
425    #[inline]
426    pub fn col_nnz(&self, col: usize) -> usize {
427        self.col_ptrs[col + 1] - self.col_ptrs[col]
428    }
429
430    /// Checks if the matrix is structurally symmetric.
431    ///
432    /// Returns true if A\[i,j\] != 0 implies A\[j,i\] != 0.
433    pub fn is_structurally_symmetric(&self) -> bool {
434        if self.nrows != self.ncols {
435            return false;
436        }
437
438        for col in 0..self.ncols {
439            let start = self.col_ptrs[col];
440            let end = self.col_ptrs[col + 1];
441
442            for i in start..end {
443                let row = self.row_indices[i];
444                if self.get(col, row).is_none() {
445                    return false;
446                }
447            }
448        }
449
450        true
451    }
452}
453
454impl<T: Scalar + Clone> Index<(usize, usize)> for CscMatrix<T>
455where
456    T: Field,
457{
458    type Output = T;
459
460    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
461        self.get(row, col)
462            .expect("Index out of bounds or zero element")
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 3x3 fixture:
    //
    // [1 0 4]
    // [0 3 0]
    // [2 0 5]
    //
    // Column 0: 1, 2 at rows 0, 2
    // Column 1: 3 at row 1
    // Column 2: 4, 5 at rows 0, 2
    fn sample() -> CscMatrix<f64> {
        let col_ptrs = vec![0, 2, 3, 5];
        let row_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];

        CscMatrix::new(3, 3, col_ptrs, row_indices, values).unwrap()
    }

    // 3x3 diagonal matrix with the values 1, 2, 3 on the diagonal.
    fn diagonal() -> CscMatrix<f64> {
        let col_ptrs = vec![0, 1, 2, 3];
        let row_indices = vec![0, 1, 2];
        let values = vec![1.0f64, 2.0, 3.0];

        CscMatrix::new(3, 3, col_ptrs, row_indices, values).unwrap()
    }

    #[test]
    fn test_csc_new() {
        let csc = sample();

        assert_eq!(csc.nrows(), 3);
        assert_eq!(csc.ncols(), 3);
        assert_eq!(csc.nnz(), 5);
    }

    #[test]
    fn test_csc_get() {
        let csc = sample();

        // Stored entries as (row, col, value).
        let stored = [
            (0, 0, 1.0),
            (2, 0, 2.0),
            (1, 1, 3.0),
            (0, 2, 4.0),
            (2, 2, 5.0),
        ];
        for &(row, col, expected) in &stored {
            assert_eq!(csc.get(row, col), Some(&expected));
        }

        // Zero elements
        assert_eq!(csc.get(1, 0), None);
        assert_eq!(csc.get(0, 1), None);
    }

    #[test]
    fn test_csc_zeros() {
        let csc: CscMatrix<f64> = CscMatrix::zeros(5, 3);

        assert_eq!(csc.nrows(), 5);
        assert_eq!(csc.ncols(), 3);
        assert_eq!(csc.nnz(), 0);
    }

    #[test]
    fn test_csc_eye() {
        let csc: CscMatrix<f64> = CscMatrix::eye(4);

        assert_eq!(csc.nrows(), 4);
        assert_eq!(csc.ncols(), 4);
        assert_eq!(csc.nnz(), 4);

        for diag in 0..4 {
            assert_eq!(csc.get(diag, diag), Some(&1.0));
        }
    }

    #[test]
    fn test_csc_density() {
        let csc = diagonal();

        // 3 stored entries out of 9 positions.
        let expected = 3.0 / 9.0;
        assert!((csc.density() - expected).abs() < 1e-10);
    }

    #[test]
    fn test_csc_col_iter() {
        let csc = sample();

        let first_col: Vec<_> = csc.col_iter(0).collect();
        assert_eq!(first_col, vec![(0, &1.0), (2, &2.0)]);

        let second_col: Vec<_> = csc.col_iter(1).collect();
        assert_eq!(second_col, vec![(1, &3.0)]);
    }

    #[test]
    fn test_csc_scale() {
        let mut csc = diagonal();
        csc.scale(2.0);

        assert_eq!(csc.values(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_csc_invalid_col_ptrs() {
        // Two columns need three column pointers; only two are supplied.
        let result = CscMatrix::new(2, 2, vec![0, 1], vec![0, 1], vec![1.0f64, 2.0]);

        assert!(matches!(result, Err(CscError::InvalidColPtrs { .. })));
    }

    #[test]
    fn test_csc_invalid_row_index() {
        // Row 5 does not exist in a matrix with 3 rows.
        let result = CscMatrix::new(3, 1, vec![0, 1], vec![5], vec![1.0f64]);

        assert!(matches!(result, Err(CscError::InvalidRowIndex { .. })));
    }

    #[test]
    fn test_csc_col_nnz() {
        let csc = sample();

        let counts: Vec<usize> = (0..3).map(|col| csc.col_nnz(col)).collect();
        assert_eq!(counts, vec![2, 1, 2]);
    }

    #[test]
    fn test_csc_structurally_symmetric() {
        // Symmetric pattern: a fully populated 2x2 matrix.
        let col_ptrs = vec![0, 2, 4];
        let row_indices = vec![0, 1, 0, 1];
        let values = vec![1.0f64, 2.0, 2.0, 3.0];

        let csc = CscMatrix::new(2, 2, col_ptrs, row_indices, values).unwrap();
        assert!(csc.is_structurally_symmetric());
    }
}
612}