Skip to main content

otspot_core/sparse/
csc.rs

1use super::compress::build_compressed_format;
2use crate::error::SolverError;
3
4/// 列圧縮形式(CSC: Compressed Sparse Column)の疎行列
5///
6/// 非ゼロ要素を列単位で格納する疎行列フォーマット。
7/// 列ポインタ・行インデックス・値の3配列で表現される。
8///
9/// # フォーマット詳細
10///
11/// 列 `j` の非ゼロ要素は `values[col_ptr[j]..col_ptr[j+1]]` に格納され、
12/// 対応する行インデックスは `row_ind[col_ptr[j]..col_ptr[j+1]]` に入る。
13/// 各列の行インデックスは昇順にソートされている。
14#[derive(Debug, Clone)]
15pub struct CscMatrix {
16    pub(crate) col_ptr: Vec<usize>,
17    pub(crate) row_ind: Vec<usize>,
18    pub(crate) values: Vec<f64>,
19    pub(crate) nrows: usize,
20    pub(crate) ncols: usize,
21}
22
23impl CscMatrix {
24    /// 空の CSC 行列を生成する
25    ///
26    /// すべての要素がゼロの (nrows × ncols) 行列として初期化される。
27    ///
28    /// # 引数
29    /// - `nrows`: 行数
30    /// - `ncols`: 列数
31    pub fn new(nrows: usize, ncols: usize) -> Self {
32        Self {
33            col_ptr: vec![0; ncols + 1],
34            row_ind: Vec::new(),
35            values: Vec::new(),
36            nrows,
37            ncols,
38        }
39    }
40
41    /// 非ゼロ要素の総数を返す
42    pub fn nnz(&self) -> usize {
43        self.values.len()
44    }
45
46    /// Returns the column pointer array (length `ncols + 1`).
47    pub fn col_ptr(&self) -> &[usize] {
48        &self.col_ptr
49    }
50
51    /// Returns the row index array for non-zero entries.
52    pub fn row_ind(&self) -> &[usize] {
53        &self.row_ind
54    }
55
56    /// Returns the value array for non-zero entries.
57    pub fn values(&self) -> &[f64] {
58        &self.values
59    }
60
61    /// Returns a new matrix with all non-zero values multiplied by `factor`.
62    pub fn scale_values(&self, factor: f64) -> Self {
63        Self {
64            col_ptr: self.col_ptr.clone(),
65            row_ind: self.row_ind.clone(),
66            values: self.values.iter().map(|&v| v * factor).collect(),
67            nrows: self.nrows,
68            ncols: self.ncols,
69        }
70    }
71
72    /// Returns the number of rows.
73    pub fn nrows(&self) -> usize {
74        self.nrows
75    }
76
77    /// Returns the number of columns.
78    pub fn ncols(&self) -> usize {
79        self.ncols
80    }
81
82    /// 各行の∞ノルム(行ごとの最大絶対値)を一括計算する: O(nnz)
83    ///
84    /// CSC格式では行方向アクセスが非効率だが、全非ゼロ要素を1回走査して
85    /// 各行の最大絶対値を収集することで O(nnz) で完了する。
86    pub fn row_infinity_norms(&self) -> Vec<f64> {
87        let mut norms = vec![0.0_f64; self.nrows];
88        for (&val, &row) in self.values.iter().zip(self.row_ind.iter()) {
89            let abs_val = val.abs();
90            if abs_val > norms[row] {
91                norms[row] = abs_val;
92            }
93        }
94        norms
95    }
96
97    /// COO(座標形式)のトリプレットから CSC 行列を構築する
98    ///
99    /// 同一 (row, col) への重複エントリは自動的に加算される。
100    /// ゼロ近傍の結果値(絶対値 DROP_TOL 以下)は格納しない。
101    ///
102    /// # 引数
103    /// - `rows`: 各エントリの行インデックス
104    /// - `cols`: 各エントリの列インデックス
105    /// - `vals`: 各エントリの値
106    /// - `nrows`: 行列の行数
107    /// - `ncols`: 行列の列数
108    ///
109    /// # エラー
110    /// - `rows`、`cols`、`vals` の長さが異なる場合
111    /// - 行/列インデックスが範囲外の場合
112    pub fn from_triplets(
113        rows: &[usize],
114        cols: &[usize],
115        vals: &[f64],
116        nrows: usize,
117        ncols: usize,
118    ) -> Result<Self, SolverError> {
119        if rows.len() != cols.len() || rows.len() != vals.len() {
120            return Err(SolverError::DimensionMismatch { field: "triplet_arrays", expected: rows.len(), got: vals.len() });
121        }
122        // CSC: 主軸=列、副軸=行
123        let (col_ptr, row_ind, values) =
124            build_compressed_format(ncols, nrows, cols, rows, vals)?;
125        Ok(Self { col_ptr, row_ind, values, nrows, ncols })
126    }
127
128    /// 転置行列を生成する(新しい CSC 行列として返す)
129    ///
130    /// 元の行列の行と列を入れ替えた行列を返す。
131    /// counting sort を使用するため O(nnz) の計算量となる。
132    pub fn transpose(&self) -> Self {
133        let nnz = self.nnz();
134        // Transposed matrix: (ncols x nrows)
135        // Step 1: count nnz per row of original (= nnz per col of transposed)
136        let mut row_count = vec![0usize; self.nrows];
137        for &r in &self.row_ind {
138            row_count[r] += 1;
139        }
140
141        // Step 2: prefix sum to build col_ptr of transposed matrix
142        let mut col_ptr = vec![0usize; self.nrows + 1];
143        for r in 0..self.nrows {
144            col_ptr[r + 1] = col_ptr[r] + row_count[r];
145        }
146
147        // Step 3: scatter non-zeros into transposed positions
148        // Process columns 0..ncols in order; for each (row, col, val) in original,
149        // write col as row_ind of transposed at position pos[row].
150        // Since col increases monotonically, row_ind within each transposed column
151        // is written in ascending order — no extra sort needed.
152        let mut row_ind = vec![0usize; nnz];
153        let mut values = vec![0.0f64; nnz];
154        let mut pos = col_ptr[..self.nrows].to_vec();
155
156        for col in 0..self.ncols {
157            let start = self.col_ptr[col];
158            let end = self.col_ptr[col + 1];
159            for k in start..end {
160                let row = self.row_ind[k];
161                let p = pos[row];
162                row_ind[p] = col;
163                values[p] = self.values[k];
164                pos[row] += 1;
165            }
166        }
167
168        Self {
169            col_ptr,
170            row_ind,
171            values,
172            nrows: self.ncols,
173            ncols: self.nrows,
174        }
175    }
176
177    /// 行列ベクトル積を計算する: y = A * x
178    ///
179    /// CSC 形式の列走査を利用して O(nnz) で計算する。
180    ///
181    /// # 引数
182    /// - `x`: 入力ベクトル(長さ: ncols)
183    ///
184    /// # 戻り値
185    /// - `Ok(y)`: 結果ベクトル(長さ: nrows)
186    /// - `Err`: `x` の長さが `ncols` と一致しない場合
187    pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
188        if x.len() != self.ncols {
189            return Err(SolverError::DimensionMismatch { field: "vector", expected: self.ncols, got: x.len() });
190        }
191
192        let mut y = vec![0.0; self.nrows];
193        for (col, &x_val) in x.iter().enumerate() {
194            let start = self.col_ptr[col];
195            let end = self.col_ptr[col + 1];
196            for idx in start..end {
197                let row = self.row_ind[idx];
198                let a_val = self.values[idx];
199                y[row] += a_val * x_val;
200            }
201        }
202        Ok(y)
203    }
204
205    /// 列 j の非ゼロ要素を取得する
206    ///
207    /// 行インデックス配列と値配列のスライスを返す。両スライスの長さは等しく、
208    /// 行インデックスは昇順にソートされている。
209    ///
210    /// # 引数
211    /// - `j`: 取得する列インデックス(0-based)
212    ///
213    /// # 戻り値
214    /// - `Ok((row_indices, values))`: 列 j の行インデックスと値のスライスペア
215    /// - `Err`: `j` が範囲外の場合
216    pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
217        if j >= self.ncols {
218            return Err(SolverError::IndexOutOfBounds { context: "column", index: j, bound: self.ncols });
219        }
220        let start = self.col_ptr[j];
221        let end = self.col_ptr[j + 1];
222        Ok((&self.row_ind[start..end], &self.values[start..end]))
223    }
224
225    /// n×n 単位行列を CSC 形式で生成する
226    ///
227    /// 対角要素が 1.0 で、非対角要素がゼロの正方行列を返す。
228    ///
229    /// # 引数
230    /// - `n`: 行列のサイズ(n×n)
231    pub fn identity(n: usize) -> Self {
232        let col_ptr: Vec<usize> = (0..=n).collect();
233        let row_ind: Vec<usize> = (0..n).collect();
234        let values = vec![1.0; n];
235        Self {
236            col_ptr,
237            row_ind,
238            values,
239            nrows: n,
240            ncols: n,
241        }
242    }
243}
244
245#[cfg(test)]
246mod tests {
247    use super::*;
248
249    #[test]
250    fn test_from_triplets_basic() {
251        // 3x3 matrix:
252        // [1.0  0.0  2.0]
253        // [0.0  3.0  0.0]
254        // [4.0  0.0  5.0]
255        let rows = vec![0, 2, 1, 0, 2];
256        let cols = vec![0, 0, 1, 2, 2];
257        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
258
259        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
260
261        assert_eq!(mat.nrows, 3);
262        assert_eq!(mat.ncols, 3);
263        assert_eq!(mat.nnz(), 5);
264
265        // Check column 0: [1.0 at row 0, 4.0 at row 2]
266        let (row_idx, values) = mat.get_column(0).unwrap();
267        assert_eq!(row_idx, &[0, 2]);
268        assert_eq!(values, &[1.0, 4.0]);
269
270        // Check column 1: [3.0 at row 1]
271        let (row_idx, values) = mat.get_column(1).unwrap();
272        assert_eq!(row_idx, &[1]);
273        assert_eq!(values, &[3.0]);
274
275        // Check column 2: [2.0 at row 0, 5.0 at row 2]
276        let (row_idx, values) = mat.get_column(2).unwrap();
277        assert_eq!(row_idx, &[0, 2]);
278        assert_eq!(values, &[2.0, 5.0]);
279    }
280
281    #[test]
282    fn test_from_triplets_duplicate_entries() {
283        // Same (row, col) appears twice -> values should be summed
284        let rows = vec![0, 0, 1];
285        let cols = vec![0, 0, 1];
286        let vals = vec![1.0, 2.0, 3.0];
287
288        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 2).unwrap();
289
290        // Column 0: row 0 should have 1.0 + 2.0 = 3.0
291        let (row_idx, values) = mat.get_column(0).unwrap();
292        assert_eq!(row_idx, &[0]);
293        assert_eq!(values, &[3.0]);
294
295        // Column 1: row 1 should have 3.0
296        let (row_idx, values) = mat.get_column(1).unwrap();
297        assert_eq!(row_idx, &[1]);
298        assert_eq!(values, &[3.0]);
299    }
300
301    #[test]
302    fn test_transpose() {
303        // 2x3 matrix:
304        // [1.0  2.0  0.0]
305        // [0.0  0.0  3.0]
306        let rows = vec![0, 0, 1];
307        let cols = vec![0, 1, 2];
308        let vals = vec![1.0, 2.0, 3.0];
309
310        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
311        let mat_t = mat.transpose();
312
313        // Transposed should be 3x2:
314        // [1.0  0.0]
315        // [2.0  0.0]
316        // [0.0  3.0]
317        assert_eq!(mat_t.nrows, 3);
318        assert_eq!(mat_t.ncols, 2);
319        assert_eq!(mat_t.nnz(), 3);
320
321        // Check column 0: [1.0 at row 0, 2.0 at row 1]
322        let (row_idx, values) = mat_t.get_column(0).unwrap();
323        assert_eq!(row_idx, &[0, 1]);
324        assert_eq!(values, &[1.0, 2.0]);
325
326        // Check column 1: [3.0 at row 2]
327        let (row_idx, values) = mat_t.get_column(1).unwrap();
328        assert_eq!(row_idx, &[2]);
329        assert_eq!(values, &[3.0]);
330
331        // Double transpose should return to original
332        let mat_tt = mat_t.transpose();
333        assert_eq!(mat_tt.nrows, mat.nrows);
334        assert_eq!(mat_tt.ncols, mat.ncols);
335        assert_eq!(mat_tt.row_ind, mat.row_ind);
336        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
337        assert_eq!(mat_tt.values, mat.values);
338    }
339
340    #[test]
341    fn test_mat_vec_mul() {
342        // 3x3 matrix:
343        // [1.0  0.0  2.0]
344        // [0.0  3.0  0.0]
345        // [4.0  0.0  5.0]
346        let rows = vec![0, 2, 1, 0, 2];
347        let cols = vec![0, 0, 1, 2, 2];
348        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
349        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
350
351        let x = vec![1.0, 2.0, 3.0];
352        let y = mat.mat_vec_mul(&x).unwrap();
353
354        // Expected: [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3]
355        //         = [7.0, 6.0, 19.0]
356        assert_eq!(y.len(), 3);
357        assert!((y[0] - 7.0).abs() < 1e-10);
358        assert!((y[1] - 6.0).abs() < 1e-10);
359        assert!((y[2] - 19.0).abs() < 1e-10);
360    }
361
362    #[test]
363    fn test_mat_vec_mul_dimension_mismatch() {
364        let mat = CscMatrix::identity(3);
365        let x = vec![1.0, 2.0]; // Wrong size
366        let result = mat.mat_vec_mul(&x);
367        assert!(result.is_err());
368    }
369
370    #[test]
371    fn test_identity() {
372        let id = CscMatrix::identity(4);
373        assert_eq!(id.nrows, 4);
374        assert_eq!(id.ncols, 4);
375        assert_eq!(id.nnz(), 4);
376
377        // Each column should have exactly one entry at its own row
378        for j in 0..4 {
379            let (row_idx, values) = id.get_column(j).unwrap();
380            assert_eq!(row_idx, &[j]);
381            assert_eq!(values, &[1.0]);
382        }
383
384        // Identity * vector = vector
385        let x = vec![1.0, 2.0, 3.0, 4.0];
386        let y = id.mat_vec_mul(&x).unwrap();
387        assert_eq!(y, x);
388    }
389
390    #[test]
391    fn test_empty_matrix() {
392        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
393        assert_eq!(mat.nrows, 2);
394        assert_eq!(mat.ncols, 3);
395        assert_eq!(mat.nnz(), 0);
396
397        // All columns should be empty
398        for j in 0..3 {
399            let (row_idx, values) = mat.get_column(j).unwrap();
400            assert_eq!(row_idx.len(), 0);
401            assert_eq!(values.len(), 0);
402        }
403
404        // mat_vec_mul should return zero vector
405        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
406        assert_eq!(y, vec![0.0, 0.0]);
407    }
408
409    #[test]
410    fn test_get_column_out_of_bounds() {
411        let mat = CscMatrix::identity(3);
412        let result = mat.get_column(3);
413        assert!(result.is_err());
414    }
415
416    #[test]
417    fn test_from_triplets_out_of_bounds() {
418        // Row index out of bounds
419        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
420        assert!(result.is_err());
421
422        // Column index out of bounds
423        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
424        assert!(result.is_err());
425    }
426
427    #[test]
428    fn test_from_triplets_mismatched_lengths() {
429        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
430        assert!(result.is_err());
431    }
432}