Skip to main content

otspot_core/sparse/
csc.rs

1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
/// Sparse matrix in CSC (Compressed Sparse Column) format.
///
/// Non-zero entries are stored column by column.
/// The matrix is represented by three arrays: column pointers, row indices, and values.
///
/// # Format details
///
/// The non-zero entries of column `j` are stored in `values[col_ptr[j]..col_ptr[j+1]]`,
/// and their row indices are stored in `row_ind[col_ptr[j]..col_ptr[j+1]]`.
/// Within each column, row indices are sorted in ascending order.
#[derive(Debug, Clone)]
pub struct CscMatrix {
    /// Column pointers: column `j` occupies `col_ptr[j]..col_ptr[j + 1]` of `row_ind`/`values`.
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of each stored entry (ascending within each column).
    pub(crate) row_ind: Vec<usize>,
    /// Value of each stored entry, parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
}
22
23impl CscMatrix {
24    /// 空の CSC 行列を生成する
25    ///
26    /// すべての要素がゼロの (nrows × ncols) 行列として初期化される。
27    ///
28    /// # 引数
29    /// - `nrows`: 行数
30    /// - `ncols`: 列数
31    pub fn new(nrows: usize, ncols: usize) -> Self {
32        Self {
33            col_ptr: vec![0; ncols + 1],
34            row_ind: Vec::new(),
35            values: Vec::new(),
36            nrows,
37            ncols,
38        }
39    }
40
41    pub fn nnz(&self) -> usize {
42        self.values.len()
43    }
44
45    pub fn col_ptr(&self) -> &[usize] {
46        &self.col_ptr
47    }
48
49    pub fn row_ind(&self) -> &[usize] {
50        &self.row_ind
51    }
52
53    pub fn values(&self) -> &[f64] {
54        &self.values
55    }
56
57    /// Returns a new matrix with all non-zero values multiplied by `factor`.
58    pub fn scale_values(&self, factor: f64) -> Self {
59        Self {
60            col_ptr: self.col_ptr.clone(),
61            row_ind: self.row_ind.clone(),
62            values: self.values.iter().map(|&v| v * factor).collect(),
63            nrows: self.nrows,
64            ncols: self.ncols,
65        }
66    }
67
68    pub fn nrows(&self) -> usize {
69        self.nrows
70    }
71
72    pub fn ncols(&self) -> usize {
73        self.ncols
74    }
75
76    /// 各行の∞ノルム(行ごとの最大絶対値)を一括計算する: O(nnz)
77    ///
78    /// CSC格式では行方向アクセスが非効率だが、全非ゼロ要素を1回走査して
79    /// 各行の最大絶対値を収集することで O(nnz) で完了する。
80    pub fn row_infinity_norms(&self) -> Vec<f64> {
81        let mut norms = vec![0.0_f64; self.nrows];
82        for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
83            let abs_val = val.abs();
84            if abs_val > norms[row] {
85                norms[row] = abs_val;
86            }
87        }
88        norms
89    }
90
91    /// Builds a CSC matrix from COO triplets.
92    ///
93    /// Duplicate `(row, col)` entries are summed; results with `|v| ≤ DROP_TOL` are dropped.
94    pub fn from_triplets(
95        rows: &[usize],
96        cols: &[usize],
97        vals: &[f64],
98        nrows: usize,
99        ncols: usize,
100    ) -> Result<Self, SolverError> {
101        if rows.len() != cols.len() || rows.len() != vals.len() {
102            return Err(SolverError::DimensionMismatch {
103                field: "triplet_arrays",
104                expected: rows.len(),
105                got: vals.len(),
106            });
107        }
108        // CSC: 主軸=列、副軸=行
109        let (col_ptr, row_ind, values) = build_compressed_format(ncols, nrows, cols, rows, vals)?;
110        Ok(Self {
111            col_ptr,
112            row_ind,
113            values,
114            nrows,
115            ncols,
116        })
117    }
118
119    /// 転置行列を生成する(新しい CSC 行列として返す)
120    ///
121    /// 元の行列の行と列を入れ替えた行列を返す。
122    /// counting sort を使用するため O(nnz) の計算量となる。
123    pub fn transpose(&self) -> Self {
124        let nnz = self.nnz();
125        // Transposed matrix: (ncols x nrows)
126        // Step 1: count nnz per row of original (= nnz per col of transposed)
127        let mut row_count = vec![0usize; self.nrows];
128        for &r in &self.row_ind {
129            row_count[r] += 1;
130        }
131
132        // Step 2: prefix sum to build col_ptr of transposed matrix
133        let mut col_ptr = vec![0usize; self.nrows + 1];
134        for r in 0..self.nrows {
135            col_ptr[r + 1] = col_ptr[r] + row_count[r];
136        }
137
138        // Step 3: scatter non-zeros into transposed positions
139        // Process columns 0..ncols in order; for each (row, col, val) in original,
140        // write col as row_ind of transposed at position pos[row].
141        // Since col increases monotonically, row_ind within each transposed column
142        // is written in ascending order — no extra sort needed.
143        let mut row_ind = vec![0usize; nnz];
144        let mut values = vec![0.0f64; nnz];
145        let mut pos = col_ptr[..self.nrows].to_vec();
146
147        for col in 0..self.ncols {
148            let start = self.col_ptr[col];
149            let end = self.col_ptr[col + 1];
150            for k in start..end {
151                let row = self.row_ind[k];
152                let p = pos[row];
153                row_ind[p] = col;
154                values[p] = self.values[k];
155                pos[row] += 1;
156            }
157        }
158
159        Self {
160            col_ptr,
161            row_ind,
162            values,
163            nrows: self.ncols,
164            ncols: self.nrows,
165        }
166    }
167
168    /// Matrix-vector product y = A * x. O(nnz).
169    pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
170        if x.len() != self.ncols {
171            return Err(SolverError::DimensionMismatch {
172                field: "vector",
173                expected: self.ncols,
174                got: x.len(),
175            });
176        }
177
178        let mut y = vec![0.0; self.nrows];
179        for (col, &x_val) in x.iter().enumerate() {
180            let start = self.col_ptr[col];
181            let end = self.col_ptr[col + 1];
182            for idx in start..end {
183                let row = self.row_ind[idx];
184                let a_val = self.values[idx];
185                y[row] += a_val * x_val;
186            }
187        }
188        Ok(y)
189    }
190
191    /// Returns `(row_indices, values)` slices for column `j`; both are sorted by row index.
192    pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
193        if j >= self.ncols {
194            return Err(SolverError::IndexOutOfBounds {
195                context: "column",
196                index: j,
197                bound: self.ncols,
198            });
199        }
200        let start = self.col_ptr[j];
201        let end = self.col_ptr[j + 1];
202        Ok((&self.row_ind[start..end], &self.values[start..end]))
203    }
204
205    pub fn identity(n: usize) -> Self {
206        let col_ptr: Vec<usize> = (0..=n).collect();
207        let row_ind: Vec<usize> = (0..n).collect();
208        let values = vec![1.0; n];
209        Self {
210            col_ptr,
211            row_ind,
212            values,
213            nrows: n,
214            ncols: n,
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_triplets_basic() {
        // 3x3 matrix:
        // [1.0  0.0  2.0]
        // [0.0  3.0  0.0]
        // [4.0  0.0  5.0]
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);

        // Check column 0: [1.0 at row 0, 4.0 at row 2]
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[1.0, 4.0]);

        // Check column 1: [3.0 at row 1]
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);

        // Check column 2: [2.0 at row 0, 5.0 at row 2]
        let (row_idx, values) = mat.get_column(2).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[2.0, 5.0]);
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        // Same (row, col) appears twice -> values should be summed
        let rows = vec![0, 0, 1];
        let cols = vec![0, 0, 1];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 2).unwrap();

        // Column 0: row 0 should have 1.0 + 2.0 = 3.0
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0]);
        assert_eq!(values, &[3.0]);

        // Column 1: row 1 should have 3.0
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
    }

    #[test]
    fn test_new_is_empty() {
        let mat = CscMatrix::new(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        assert_eq!(mat.nnz(), 0);
        assert_eq!(mat.col_ptr(), &[0, 0, 0, 0]);
        assert!(mat.row_ind().is_empty());
        assert!(mat.values().is_empty());
    }

    #[test]
    fn test_scale_values() {
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        let scaled = mat.scale_values(2.0);
        // Sparsity pattern and shape are unchanged; only values are scaled.
        assert_eq!(scaled.nrows(), 3);
        assert_eq!(scaled.ncols(), 3);
        assert_eq!(scaled.col_ptr(), mat.col_ptr());
        assert_eq!(scaled.row_ind(), mat.row_ind());
        assert_eq!(scaled.values(), &[2.0, 8.0, 6.0, 4.0, 10.0]);
        // The source matrix is left untouched.
        assert_eq!(mat.values(), &[1.0, 4.0, 3.0, 2.0, 5.0]);
    }

    #[test]
    fn test_row_infinity_norms() {
        // 3x3 matrix:
        // [ 1.0  0.0 -2.0]
        // [ 0.0  0.0  0.0]
        // [-4.0  0.0  5.0]
        let rows = vec![0, 2, 0, 2];
        let cols = vec![0, 0, 2, 2];
        let vals = vec![1.0, -4.0, -2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        // Empty rows get 0.0; signs are ignored.
        assert_eq!(mat.row_infinity_norms(), vec![2.0, 0.0, 5.0]);
    }

    #[test]
    fn test_transpose() {
        // 2x3 matrix:
        // [1.0  2.0  0.0]
        // [0.0  0.0  3.0]
        let rows = vec![0, 0, 1];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        let mat_t = mat.transpose();

        // Transposed should be 3x2:
        // [1.0  0.0]
        // [2.0  0.0]
        // [0.0  3.0]
        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);

        // Check column 0: [1.0 at row 0, 2.0 at row 1]
        let (row_idx, values) = mat_t.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 1]);
        assert_eq!(values, &[1.0, 2.0]);

        // Check column 1: [3.0 at row 2]
        let (row_idx, values) = mat_t.get_column(1).unwrap();
        assert_eq!(row_idx, &[2]);
        assert_eq!(values, &[3.0]);

        // Double transpose should return to original
        let mat_tt = mat_t.transpose();
        assert_eq!(mat_tt.nrows, mat.nrows);
        assert_eq!(mat_tt.ncols, mat.ncols);
        assert_eq!(mat_tt.row_ind, mat.row_ind);
        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
        assert_eq!(mat_tt.values, mat.values);
    }

    #[test]
    fn test_transpose_empty() {
        let mat = CscMatrix::new(2, 3);
        let mat_t = mat.transpose();
        assert_eq!(mat_t.nrows(), 3);
        assert_eq!(mat_t.ncols(), 2);
        assert_eq!(mat_t.nnz(), 0);
        assert_eq!(mat_t.col_ptr(), &[0, 0, 0]);
    }

    #[test]
    fn test_mat_vec_mul() {
        // 3x3 matrix:
        // [1.0  0.0  2.0]
        // [0.0  3.0  0.0]
        // [4.0  0.0  5.0]
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = mat.mat_vec_mul(&x).unwrap();

        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3]
        //         = [7.0, 6.0, 19.0]
        assert_eq!(y.len(), 3);
        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 19.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let mat = CscMatrix::identity(3);
        let x = vec![1.0, 2.0]; // Wrong size
        let result = mat.mat_vec_mul(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);

        // Each column should have exactly one entry at its own row
        for j in 0..4 {
            let (row_idx, values) = id.get_column(j).unwrap();
            assert_eq!(row_idx, &[j]);
            assert_eq!(values, &[1.0]);
        }

        // Identity * vector = vector
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = id.mat_vec_mul(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);

        // All columns should be empty
        for j in 0..3 {
            let (row_idx, values) = mat.get_column(j).unwrap();
            assert_eq!(row_idx.len(), 0);
            assert_eq!(values.len(), 0);
        }

        // mat_vec_mul should return zero vector
        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        let mat = CscMatrix::identity(3);
        let result = mat.get_column(3);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        // Row index out of bounds
        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());

        // Column index out of bounds
        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(result.is_err());
    }
}