Skip to main content

otspot_core/sparse/
csr.rs

1use super::compress::build_compressed_format;
2use super::csc::CscMatrix;
3use crate::error::SolverError;
4
5/// 行圧縮形式(CSR: Compressed Sparse Row)の疎行列
6///
7/// 非ゼロ要素を行単位で格納する疎行列フォーマット。
8/// 行ポインタ・列インデックス・値の3配列で表現される。
9///
10/// # フォーマット詳細
11///
12/// 行 `i` の非ゼロ要素は `values[row_ptr[i]..row_ptr[i+1]]` に格納され、
13/// 対応する列インデックスは `col_ind[row_ptr[i]..row_ptr[i+1]]` に入る。
14/// 各行の列インデックスは昇順にソートされている。
15#[derive(Debug, Clone)]
16pub struct CsrMatrix {
17    /// 行ポインタ配列(長さ: nrows + 1)
18    /// `row_ptr[i]` は行 i の最初の非ゼロ要素の位置を示す
19    pub row_ptr: Vec<usize>,
20    /// 各非ゼロ要素の列インデックス
21    pub col_ind: Vec<usize>,
22    /// 各非ゼロ要素の値
23    pub values: Vec<f64>,
24    /// 行数
25    pub nrows: usize,
26    /// 列数
27    pub ncols: usize,
28}
29
30impl CsrMatrix {
31    /// 非ゼロ要素の総数を返す
32    pub fn nnz(&self) -> usize {
33        self.values.len()
34    }
35
36    /// COO(座標形式)のトリプレットから CSR 行列を構築する
37    ///
38    /// 同一 (row, col) への重複エントリは自動的に加算される。
39    /// ゼロ近傍の結果値(絶対値 DROP_TOL 以下)は格納しない。
40    ///
41    /// # 引数
42    /// - `rows`: 各エントリの行インデックス
43    /// - `cols`: 各エントリの列インデックス
44    /// - `vals`: 各エントリの値
45    /// - `nrows`: 行列の行数
46    /// - `ncols`: 行列の列数
47    ///
48    /// # エラー
49    /// - `rows`、`cols`、`vals` の長さが異なる場合
50    /// - 行/列インデックスが範囲外の場合
51    pub fn from_triplets(
52        rows: &[usize],
53        cols: &[usize],
54        vals: &[f64],
55        nrows: usize,
56        ncols: usize,
57    ) -> Result<Self, SolverError> {
58        if rows.len() != cols.len() || rows.len() != vals.len() {
59            return Err(SolverError::DimensionMismatch { field: "triplet_arrays", expected: rows.len(), got: vals.len() });
60        }
61        // CSR: 主軸=行、副軸=列
62        let (row_ptr, col_ind, values) =
63            build_compressed_format(nrows, ncols, rows, cols, vals)?;
64        Ok(Self { row_ptr, col_ind, values, nrows, ncols })
65    }
66
67    /// 行 i の非ゼロ要素を取得する
68    ///
69    /// 列インデックス配列と値配列のスライスを返す。両スライスの長さは等しく、
70    /// 列インデックスは昇順にソートされている。
71    ///
72    /// # 引数
73    /// - `i`: 取得する行インデックス(0-based)
74    ///
75    /// # 戻り値
76    /// - `Ok((col_indices, values))`: 行 i の列インデックスと値のスライスペア
77    /// - `Err`: `i` が範囲外の場合
78    pub fn get_row(&self, i: usize) -> Result<(&[usize], &[f64]), SolverError> {
79        if i >= self.nrows {
80            return Err(SolverError::IndexOutOfBounds { context: "row", index: i, bound: self.nrows });
81        }
82        let start = self.row_ptr[i];
83        let end = self.row_ptr[i + 1];
84        Ok((&self.col_ind[start..end], &self.values[start..end]))
85    }
86
87    /// CSC 行列を CSR 行列に変換する
88    ///
89    /// 直接変換アルゴリズムを使用する。
90    /// Pass 1: 各行の非ゼロ要素数を数え、prefix sum で row_ptr を構築する。
91    /// Pass 2: 列を昇順に走査して col_ind/values を埋める。
92    /// 列を昇順で処理するため、各行の col_ind は自動的にソート済みとなる。
93    /// 計算量は O(nnz)(トリプレット経由の O(nnz log nnz) より高速)。
94    ///
95    /// # 引数
96    /// - `csc`: 変換元の CSC 行列
97    pub fn from_csc(csc: &CscMatrix) -> Self {
98        let nnz = csc.nnz();
99        let nrows = csc.nrows;
100        let ncols = csc.ncols;
101
102        // Pass 1: 各行の要素数をカウントし、prefix sum で row_ptr を構築する
103        let mut row_ptr = vec![0usize; nrows + 1];
104        for &r in &csc.row_ind {
105            row_ptr[r + 1] += 1;
106        }
107        for i in 0..nrows {
108            row_ptr[i + 1] += row_ptr[i];
109        }
110
111        // Pass 2: 列を昇順に走査して col_ind/values を配置する
112        // cur[i] = 行 i の次の書き込み位置
113        let mut col_ind = vec![0usize; nnz];
114        let mut values = vec![0.0f64; nnz];
115        let mut cur = row_ptr[..nrows].to_vec();
116
117        for j in 0..ncols {
118            let start = csc.col_ptr[j];
119            let end = csc.col_ptr[j + 1];
120            for k in start..end {
121                let r = csc.row_ind[k];
122                let pos = cur[r];
123                col_ind[pos] = j;
124                values[pos] = csc.values[k];
125                cur[r] += 1;
126            }
127        }
128
129        Self {
130            row_ptr,
131            col_ind,
132            values,
133            nrows,
134            ncols,
135        }
136    }
137}
138
139#[cfg(test)]
140mod tests {
141    use super::*;
142
143    #[test]
144    fn test_csr_from_triplets() {
145        let rows = vec![0, 0, 1, 2, 2];
146        let cols = vec![0, 2, 1, 0, 2];
147        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
148        let mat = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
149        assert_eq!(mat.nrows, 3);
150        assert_eq!(mat.ncols, 3);
151        assert_eq!(mat.nnz(), 5);
152
153        let (ci, v) = mat.get_row(0).unwrap();
154        assert_eq!(ci, &[0, 2]);
155        assert_eq!(v, &[1.0, 2.0]);
156
157        let (ci, v) = mat.get_row(1).unwrap();
158        assert_eq!(ci, &[1]);
159        assert_eq!(v, &[3.0]);
160
161        let (ci, v) = mat.get_row(2).unwrap();
162        assert_eq!(ci, &[0, 2]);
163        assert_eq!(v, &[4.0, 5.0]);
164    }
165
166    #[test]
167    fn test_csr_from_csc() {
168        let rows = vec![0, 2, 1, 0, 2];
169        let cols = vec![0, 0, 1, 2, 2];
170        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
171        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
172        let csr = CsrMatrix::from_csc(&csc);
173
174        assert_eq!(csr.nrows, 3);
175        assert_eq!(csr.ncols, 3);
176        assert_eq!(csr.nnz(), 5);
177
178        let (ci, v) = csr.get_row(0).unwrap();
179        assert_eq!(ci, &[0, 2]);
180        assert_eq!(v, &[1.0, 2.0]);
181
182        let (ci, v) = csr.get_row(1).unwrap();
183        assert_eq!(ci, &[1]);
184        assert_eq!(v, &[3.0]);
185
186        let (ci, v) = csr.get_row(2).unwrap();
187        assert_eq!(ci, &[0, 2]);
188        assert_eq!(v, &[4.0, 5.0]);
189    }
190}