Skip to main content

otspot_core/sparse/
csc.rs

1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
/// Sparse matrix in CSC (Compressed Sparse Column) format.
///
/// Non-zero entries are stored column by column and the matrix is described
/// by three arrays: column pointers, row indices, and values.
///
/// # Format details
///
/// The non-zeros of column `j` are stored in `values[col_ptr[j]..col_ptr[j + 1]]`
/// and their row indices in `row_ind[col_ptr[j]..col_ptr[j + 1]]`.
/// Within each column the row indices are sorted in ascending order.
///
/// Equality (`PartialEq`) compares the stored structure and values exactly, so
/// an explicitly stored zero is not equal to an absent entry.
#[derive(Debug, Clone, PartialEq)]
pub struct CscMatrix {
    /// Column pointers, length `ncols + 1`; column `j` occupies positions
    /// `col_ptr[j]..col_ptr[j + 1]` of `row_ind` and `values`.
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of every stored entry (length `nnz`), grouped by column.
    pub(crate) row_ind: Vec<usize>,
    /// Value of every stored entry (length `nnz`), parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
}
22
23impl CscMatrix {
24    /// 空の CSC 行列を生成する
25    ///
26    /// すべての要素がゼロの (nrows × ncols) 行列として初期化される。
27    ///
28    /// # 引数
29    /// - `nrows`: 行数
30    /// - `ncols`: 列数
31    pub fn new(nrows: usize, ncols: usize) -> Self {
32        Self {
33            col_ptr: vec![0; ncols + 1],
34            row_ind: Vec::new(),
35            values: Vec::new(),
36            nrows,
37            ncols,
38        }
39    }
40
41    pub fn nnz(&self) -> usize {
42        self.values.len()
43    }
44
45    pub fn col_ptr(&self) -> &[usize] {
46        &self.col_ptr
47    }
48
49    pub fn row_ind(&self) -> &[usize] {
50        &self.row_ind
51    }
52
53    pub fn values(&self) -> &[f64] {
54        &self.values
55    }
56
57    /// Returns a new matrix with all non-zero values multiplied by `factor`.
58    pub fn scale_values(&self, factor: f64) -> Self {
59        Self {
60            col_ptr: self.col_ptr.clone(),
61            row_ind: self.row_ind.clone(),
62            values: self.values.iter().map(|&v| v * factor).collect(),
63            nrows: self.nrows,
64            ncols: self.ncols,
65        }
66    }
67
68    pub fn nrows(&self) -> usize {
69        self.nrows
70    }
71
72    pub fn ncols(&self) -> usize {
73        self.ncols
74    }
75
76    /// 各行の∞ノルム(行ごとの最大絶対値)を一括計算する: O(nnz)
77    ///
78    /// CSC格式では行方向アクセスが非効率だが、全非ゼロ要素を1回走査して
79    /// 各行の最大絶対値を収集することで O(nnz) で完了する。
80    pub fn row_infinity_norms(&self) -> Vec<f64> {
81        let mut norms = vec![0.0_f64; self.nrows];
82        for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
83            let abs_val = val.abs();
84            if abs_val > norms[row] {
85                norms[row] = abs_val;
86            }
87        }
88        norms
89    }
90
91    /// Builds a CSC matrix from COO triplets.
92    ///
93    /// Duplicate `(row, col)` entries are summed; results with `|v| ≤ DROP_TOL` are dropped.
94    pub fn from_triplets(
95        rows: &[usize],
96        cols: &[usize],
97        vals: &[f64],
98        nrows: usize,
99        ncols: usize,
100    ) -> Result<Self, SolverError> {
101        if rows.len() != cols.len() || rows.len() != vals.len() {
102            return Err(SolverError::DimensionMismatch {
103                field: "triplet_arrays",
104                expected: rows.len(),
105                got: vals.len(),
106            });
107        }
108        for (i, &v) in vals.iter().enumerate() {
109            if !v.is_finite() {
110                return Err(SolverError::NonFiniteCoefficient {
111                    field: "matrix",
112                    index: i,
113                });
114            }
115        }
116        // CSC: 主軸=列、副軸=行
117        let (col_ptr, row_ind, values) = build_compressed_format(ncols, nrows, cols, rows, vals)?;
118        Ok(Self {
119            col_ptr,
120            row_ind,
121            values,
122            nrows,
123            ncols,
124        })
125    }
126
127    /// 転置行列を生成する(新しい CSC 行列として返す)
128    ///
129    /// 元の行列の行と列を入れ替えた行列を返す。
130    /// counting sort を使用するため O(nnz) の計算量となる。
131    pub fn transpose(&self) -> Self {
132        let nnz = self.nnz();
133        // Transposed matrix: (ncols x nrows)
134        // Step 1: count nnz per row of original (= nnz per col of transposed)
135        let mut row_count = vec![0usize; self.nrows];
136        for &r in &self.row_ind {
137            row_count[r] += 1;
138        }
139
140        // Step 2: prefix sum to build col_ptr of transposed matrix
141        let mut col_ptr = vec![0usize; self.nrows + 1];
142        for r in 0..self.nrows {
143            col_ptr[r + 1] = col_ptr[r] + row_count[r];
144        }
145
146        // Step 3: scatter non-zeros into transposed positions
147        // Process columns 0..ncols in order; for each (row, col, val) in original,
148        // write col as row_ind of transposed at position pos[row].
149        // Since col increases monotonically, row_ind within each transposed column
150        // is written in ascending order — no extra sort needed.
151        let mut row_ind = vec![0usize; nnz];
152        let mut values = vec![0.0f64; nnz];
153        let mut pos = col_ptr[..self.nrows].to_vec();
154
155        for col in 0..self.ncols {
156            let start = self.col_ptr[col];
157            let end = self.col_ptr[col + 1];
158            for k in start..end {
159                let row = self.row_ind[k];
160                let p = pos[row];
161                row_ind[p] = col;
162                values[p] = self.values[k];
163                pos[row] += 1;
164            }
165        }
166
167        Self {
168            col_ptr,
169            row_ind,
170            values,
171            nrows: self.ncols,
172            ncols: self.nrows,
173        }
174    }
175
176    /// Matrix-vector product y = A * x. O(nnz).
177    pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
178        if x.len() != self.ncols {
179            return Err(SolverError::DimensionMismatch {
180                field: "vector",
181                expected: self.ncols,
182                got: x.len(),
183            });
184        }
185
186        let mut y = vec![0.0; self.nrows];
187        for (col, &x_val) in x.iter().enumerate() {
188            let start = self.col_ptr[col];
189            let end = self.col_ptr[col + 1];
190            for idx in start..end {
191                let row = self.row_ind[idx];
192                let a_val = self.values[idx];
193                y[row] += a_val * x_val;
194            }
195        }
196        Ok(y)
197    }
198
199    /// Returns `(row_indices, values)` slices for column `j`; both are sorted by row index.
200    pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
201        if j >= self.ncols {
202            return Err(SolverError::IndexOutOfBounds {
203                context: "column",
204                index: j,
205                bound: self.ncols,
206            });
207        }
208        let start = self.col_ptr[j];
209        let end = self.col_ptr[j + 1];
210        Ok((&self.row_ind[start..end], &self.values[start..end]))
211    }
212
213    pub fn identity(n: usize) -> Self {
214        let col_ptr: Vec<usize> = (0..=n).collect();
215        let row_ind: Vec<usize> = (0..n).collect();
216        let values = vec![1.0; n];
217        Self {
218            col_ptr,
219            row_ind,
220            values,
221            nrows: n,
222            ncols: n,
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_triplets_basic() {
        // 3x3 matrix:
        // [1.0  0.0  2.0]
        // [0.0  3.0  0.0]
        // [4.0  0.0  5.0]
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        // Check column 0: [1.0 at row 0, 4.0 at row 2]
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[1.0, 4.0]);

        // Check column 1: [3.0 at row 1]
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);

        // Check column 2: [2.0 at row 0, 5.0 at row 2]
        let (row_idx, values) = mat.get_column(2).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[2.0, 5.0]);
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        // Same (row, col) appears twice -> values should be summed
        let rows = vec![0, 0, 1];
        let cols = vec![0, 0, 1];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 2).unwrap();

        // Column 0: row 0 should have 1.0 + 2.0 = 3.0
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0]);
        assert_eq!(values, &[3.0]);

        // Column 1: row 1 should have 3.0
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
    }

    #[test]
    fn test_transpose() {
        // 2x3 matrix:
        // [1.0  2.0  0.0]
        // [0.0  0.0  3.0]
        let rows = vec![0, 0, 1];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        let mat_t = mat.transpose();

        // Transposed should be 3x2:
        // [1.0  0.0]
        // [2.0  0.0]
        // [0.0  3.0]
        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);

        // Check column 0: [1.0 at row 0, 2.0 at row 1]
        let (row_idx, values) = mat_t.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 1]);
        assert_eq!(values, &[1.0, 2.0]);

        // Check column 1: [3.0 at row 2]
        let (row_idx, values) = mat_t.get_column(1).unwrap();
        assert_eq!(row_idx, &[2]);
        assert_eq!(values, &[3.0]);

        // Double transpose should return to original
        let mat_tt = mat_t.transpose();
        assert_eq!(mat_tt.nrows, mat.nrows);
        assert_eq!(mat_tt.ncols, mat.ncols);
        assert_eq!(mat_tt.row_ind, mat.row_ind);
        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
        assert_eq!(mat_tt.values, mat.values);
    }

    #[test]
    fn test_transpose_with_empty_row() {
        // 3x2 matrix with an empty middle row:
        // [1.0  0.0]
        // [0.0  0.0]
        // [2.0  3.0]
        let mat =
            CscMatrix::from_triplets(&[0, 2, 2], &[0, 0, 1], &[1.0, 2.0, 3.0], 3, 2).unwrap();
        let t = mat.transpose();

        // Transposed 2x3; the empty row becomes an empty column:
        // [1.0  0.0  2.0]
        // [0.0  0.0  3.0]
        assert_eq!(t.nrows(), 2);
        assert_eq!(t.ncols(), 3);
        assert_eq!(t.col_ptr(), &[0, 1, 1, 3]);
        assert_eq!(t.row_ind(), &[0, 0, 1]);
        assert_eq!(t.values(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mat_vec_mul() {
        // 3x3 matrix:
        // [1.0  0.0  2.0]
        // [0.0  3.0  0.0]
        // [4.0  0.0  5.0]
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = mat.mat_vec_mul(&x).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3]
        //         = [7.0, 6.0, 19.0]
        assert_eq!(y.len(), 3);
        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 19.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let mat = CscMatrix::identity(3);
        let x = vec![1.0, 2.0]; // Wrong size
        let result = mat.mat_vec_mul(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);

        // Each column should have exactly one entry at its own row
        for j in 0..4 {
            let (row_idx, values) = id.get_column(j).unwrap();
            assert_eq!(row_idx, &[j]);
            assert_eq!(values, &[1.0]);
        }

        // Identity * vector = vector
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = id.mat_vec_mul(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_new_is_all_zero() {
        let mat = CscMatrix::new(3, 2);
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 2);
        assert_eq!(mat.nnz(), 0);
        assert_eq!(mat.col_ptr(), &[0, 0, 0]);
        assert!(mat.row_ind().is_empty());
        assert!(mat.values().is_empty());
        assert_eq!(mat.row_infinity_norms(), vec![0.0, 0.0, 0.0]);
        assert_eq!(mat.mat_vec_mul(&[1.0, 2.0]).unwrap(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scale_values_keeps_pattern() {
        let id = CscMatrix::identity(3);
        let scaled = id.scale_values(2.5);
        assert_eq!(scaled.nrows(), 3);
        assert_eq!(scaled.ncols(), 3);
        assert_eq!(scaled.col_ptr(), id.col_ptr());
        assert_eq!(scaled.row_ind(), id.row_ind());
        assert_eq!(scaled.values(), &[2.5, 2.5, 2.5]);
        // The source matrix is left untouched.
        assert_eq!(id.values(), &[1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_row_infinity_norms() {
        // 3x3 matrix with an empty middle row:
        // [ 1.0  0.0  -2.0]
        // [ 0.0  0.0   0.0]
        // [-4.0  0.0   3.0]
        let rows = vec![0, 2, 0, 2];
        let cols = vec![0, 0, 2, 2];
        let vals = vec![1.0, -4.0, -2.0, 3.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_eq!(mat.row_infinity_norms(), vec![2.0, 0.0, 4.0]);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);

        // All columns should be empty
        for j in 0..3 {
            let (row_idx, values) = mat.get_column(j).unwrap();
            assert_eq!(row_idx.len(), 0);
            assert_eq!(values.len(), 0);
        }

        // mat_vec_mul should return zero vector
        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        let mat = CscMatrix::identity(3);
        let result = mat.get_column(3);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        // Row index out of bounds
        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());

        // Column index out of bounds
        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        // `cols` shorter than `rows`.
        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(result.is_err());

        // Only `vals` disagrees with `rows`.
        let result = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[1.0], 2, 2);
        assert!(matches!(result, Err(SolverError::DimensionMismatch { .. })));
    }

    /// Sentinel: non-finite values in triplets must be rejected at construction.
    /// Removing the finiteness check turns Err into Ok → assertion fails (no-op fail).
    #[test]
    fn test_sentinel_triplet_non_finite_rejected() {
        let r = CscMatrix::from_triplets(&[0], &[0], &[f64::NAN], 1, 1);
        assert!(r.is_err(), "NaN in triplet vals must be rejected");
        let r = CscMatrix::from_triplets(&[0], &[0], &[f64::INFINITY], 1, 1);
        assert!(r.is_err(), "+Inf in triplet vals must be rejected");
        let r = CscMatrix::from_triplets(&[0], &[0], &[f64::NEG_INFINITY], 1, 1);
        assert!(r.is_err(), "-Inf in triplet vals must be rejected");
        // Finite values still accepted.
        let r = CscMatrix::from_triplets(&[0], &[0], &[1.0], 1, 1);
        assert!(r.is_ok(), "finite value must still be accepted");
    }
}