Skip to main content

otspot_core/sparse/
vec.rs

1use crate::tolerances::ZERO_TOL;
2
3/// 疎ベクトル(インデックス・値のペアリスト、インデックスで昇順ソート済み)
4///
5/// ゼロでない要素のみをインデックスと値のペアで保持する。
6/// `indices` は常に昇順にソートされており、二分探索による O(log n) アクセスが可能。
7/// ゼロ近傍の値(絶対値が `ZERO_TOL` 以下)は自動的に除去される。
8#[derive(Debug, Clone)]
9pub struct SparseVec {
10    /// 非ゼロ要素のインデックス(昇順ソート済み)
11    pub indices: Vec<usize>,
12    /// 非ゼロ要素の値(`indices` と同じ順序)
13    pub values: Vec<f64>,
14    /// 論理的な長さ(ゼロ要素を含む全体の次元数)
15    pub len: usize, // logical length
16}
17
18impl SparseVec {
19    /// 指定した論理長の空疎ベクトルを生成する
20    ///
21    /// 非ゼロ要素は含まない(すべてゼロ)状態で初期化される。
22    ///
23    /// # 引数
24    /// - `len`: ベクトルの論理的な長さ(次元数)
25    pub fn new(len: usize) -> Self {
26        Self {
27            indices: Vec::new(),
28            values: Vec::new(),
29            len,
30        }
31    }
32
33    /// 密ベクトルから疎ベクトルを生成する
34    ///
35    /// 絶対値が ZERO_TOL(1e-12)を超える要素のみを保持し、残りは捨てる。
36    /// インデックスは元の配列の位置順(昇順)で格納される。
37    ///
38    /// # 引数
39    /// - `dense`: 変換元の密ベクトル(スライス)
40    pub fn from_dense(dense: &[f64]) -> Self {
41        let mut indices = Vec::new();
42        let mut values = Vec::new();
43        for (i, &v) in dense.iter().enumerate() {
44            if v.abs() > ZERO_TOL {
45                indices.push(i);
46                values.push(v);
47            }
48        }
49        Self {
50            indices,
51            values,
52            len: dense.len(),
53        }
54    }
55
56    /// 疎ベクトルを密ベクトルに変換する
57    ///
58    /// 非ゼロ要素を対応するインデックスに配置し、残りはゼロで埋める。
59    /// 返却ベクトルの長さは `self.len` と等しい。
60    pub fn to_dense(&self) -> Vec<f64> {
61        let mut dense = vec![0.0; self.len];
62        for (k, &idx) in self.indices.iter().enumerate() {
63            dense[idx] = self.values[k];
64        }
65        dense
66    }
67
68    /// 事前確保済みバッファに密ベクトルを書き込む
69    ///
70    /// `buf` を一旦ゼロクリアしてから非ゼロ要素を書き込む。
71    /// ヒープ割り当てを行わないため、反復ループ内での再利用に適する。
72    ///
73    /// # 引数
74    /// - `buf`: 書き込み先バッファ(長さ >= `self.len` であること)
75    pub fn to_dense_into(&self, buf: &mut [f64]) {
76        for v in buf.iter_mut() {
77            *v = 0.0;
78        }
79        for (k, &idx) in self.indices.iter().enumerate() {
80            buf[idx] = self.values[k];
81        }
82    }
83
84    /// 指定インデックスの値を取得する
85    ///
86    /// インデックスが非ゼロ要素として存在しない場合は `0.0` を返す。
87    /// 内部では二分探索を使用するため O(log n) で動作する。
88    ///
89    /// # 引数
90    /// - `idx`: 取得するインデックス
91    pub fn get(&self, idx: usize) -> f64 {
92        match self.indices.binary_search(&idx) {
93            Ok(pos) => self.values[pos],
94            Err(_) => 0.0,
95        }
96    }
97
98    /// 指定インデックスに値をセットする
99    ///
100    /// `val` の絶対値が ZERO_TOL 以下の場合、そのインデックスを非ゼロリストから削除する
101    /// (ゼロとみなす)。既存のエントリがない場合は挿入し、ある場合は上書きする。
102    /// ソート順を維持するため、挿入位置は二分探索で決定する。
103    ///
104    /// # 引数
105    /// - `idx`: セットするインデックス
106    /// - `val`: セットする値(ZERO_TOL 以下なら削除)
107    pub fn set(&mut self, idx: usize, val: f64) {
108        match self.indices.binary_search(&idx) {
109            Ok(pos) => {
110                if val.abs() <= ZERO_TOL {
111                    self.indices.remove(pos);
112                    self.values.remove(pos);
113                } else {
114                    self.values[pos] = val;
115                }
116            }
117            Err(pos) => {
118                if val.abs() > ZERO_TOL {
119                    self.indices.insert(pos, idx);
120                    self.values.insert(pos, val);
121                }
122            }
123        }
124    }
125
126    /// AXPY 演算: `self += alpha * other`
127    ///
128    /// 両ベクトルのインデックスリストをtwo-pointer mergeで走査し、
129    /// O(nnz_a + nnz_b) で演算する。ZERO_TOL 以下の結果はドロップする。
130    ///
131    /// # 引数
132    /// - `alpha`: スカラー倍率
133    /// - `other`: 加算する疎ベクトル
134    pub fn axpy(&mut self, alpha: f64, other: &SparseVec) {
135        let mut new_indices = Vec::new();
136        let mut new_values = Vec::new();
137        let (mut i, mut j) = (0, 0);
138
139        while i < self.indices.len() && j < other.indices.len() {
140            if self.indices[i] == other.indices[j] {
141                let val = self.values[i] + alpha * other.values[j];
142                if val.abs() > ZERO_TOL {
143                    new_indices.push(self.indices[i]);
144                    new_values.push(val);
145                }
146                i += 1;
147                j += 1;
148            } else if self.indices[i] < other.indices[j] {
149                new_indices.push(self.indices[i]);
150                new_values.push(self.values[i]);
151                i += 1;
152            } else {
153                let val = alpha * other.values[j];
154                if val.abs() > ZERO_TOL {
155                    new_indices.push(other.indices[j]);
156                    new_values.push(val);
157                }
158                j += 1;
159            }
160        }
161        // Drain remaining self
162        while i < self.indices.len() {
163            new_indices.push(self.indices[i]);
164            new_values.push(self.values[i]);
165            i += 1;
166        }
167        // Drain remaining other
168        while j < other.indices.len() {
169            let val = alpha * other.values[j];
170            if val.abs() > ZERO_TOL {
171                new_indices.push(other.indices[j]);
172                new_values.push(val);
173            }
174            j += 1;
175        }
176
177        self.indices = new_indices;
178        self.values = new_values;
179    }
180
181    /// 別の疎ベクトルとの内積を計算する
182    ///
183    /// 両ベクトルのインデックスリストをマージソート的に走査し、
184    /// 一致するインデックスの積を加算する。計算量は O(nnz_a + nnz_b)。
185    ///
186    /// # 引数
187    /// - `other`: 内積を取る相手の疎ベクトル
188    pub fn dot(&self, other: &SparseVec) -> f64 {
189        let mut result = 0.0;
190        let (mut i, mut j) = (0, 0);
191        while i < self.indices.len() && j < other.indices.len() {
192            if self.indices[i] == other.indices[j] {
193                result += self.values[i] * other.values[j];
194                i += 1;
195                j += 1;
196            } else if self.indices[i] < other.indices[j] {
197                i += 1;
198            } else {
199                j += 1;
200            }
201        }
202        result
203    }
204
205    /// 密ベクトルとの内積を計算する
206    ///
207    /// 疎ベクトルの非ゼロ要素のインデックスのみを参照するため、
208    /// 密ベクトルとの積でも O(nnz) で動作する。
209    ///
210    /// # 引数
211    /// - `dense`: 内積を取る相手の密ベクトル(スライス)
212    pub fn dot_dense(&self, dense: &[f64]) -> f64 {
213        let mut result = 0.0;
214        for (k, &idx) in self.indices.iter().enumerate() {
215            if idx < dense.len() {
216                result += self.values[k] * dense[idx];
217            }
218        }
219        result
220    }
221}
222
223#[cfg(test)]
224mod tests {
225    use super::*;
226
227    #[test]
228    fn test_sparse_vec_from_dense_to_dense() {
229        let dense = vec![1.0, 0.0, 0.0, 3.5, 0.0, -2.0];
230        let sv = SparseVec::from_dense(&dense);
231        assert_eq!(sv.len, 6);
232        assert_eq!(sv.indices, vec![0, 3, 5]);
233        assert_eq!(sv.values, vec![1.0, 3.5, -2.0]);
234
235        let back = sv.to_dense();
236        assert_eq!(back, dense);
237    }
238
239    #[test]
240    fn test_sparse_vec_get_set() {
241        let mut sv = SparseVec::new(5);
242        assert_eq!(sv.get(0), 0.0);
243        assert_eq!(sv.get(4), 0.0);
244
245        sv.set(2, 7.0);
246        sv.set(4, -1.0);
247        assert_eq!(sv.get(2), 7.0);
248        assert_eq!(sv.get(4), -1.0);
249        assert_eq!(sv.get(3), 0.0);
250
251        // Overwrite
252        sv.set(2, 3.0);
253        assert_eq!(sv.get(2), 3.0);
254
255        // Remove by setting to zero
256        sv.set(2, 0.0);
257        assert_eq!(sv.get(2), 0.0);
258        assert_eq!(sv.indices, vec![4]);
259    }
260
261    #[test]
262    fn test_sparse_vec_dot() {
263        let a = SparseVec::from_dense(&[1.0, 0.0, 3.0, 0.0]);
264        let b = SparseVec::from_dense(&[2.0, 5.0, 4.0, 0.0]);
265        // 1*2 + 0*5 + 3*4 + 0*0 = 14
266        assert!((a.dot(&b) - 14.0).abs() < 1e-10);
267
268        // Dot with dense
269        let dense = vec![2.0, 5.0, 4.0, 0.0];
270        assert!((a.dot_dense(&dense) - 14.0).abs() < 1e-10);
271    }
272
273    #[test]
274    fn test_sparse_vec_axpy() {
275        let mut a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
276        let b = SparseVec::from_dense(&[0.0, 2.0, 1.0]);
277        a.axpy(2.0, &b);
278        // a = [1, 0, 3] + 2*[0, 2, 1] = [1, 4, 5]
279        let dense = a.to_dense();
280        assert!((dense[0] - 1.0).abs() < 1e-10);
281        assert!((dense[1] - 4.0).abs() < 1e-10);
282        assert!((dense[2] - 5.0).abs() < 1e-10);
283    }
284
285    #[test]
286    fn test_dot_different_len() {
287        // SparseVec{len:3} と SparseVec{len:5} の dot → 共通インデックス範囲のみで正しく計算
288        let a = SparseVec { indices: vec![0, 2], values: vec![1.0, 2.0], len: 3 };
289        let b = SparseVec {
290            indices: vec![0, 1, 2, 3, 4],
291            values: vec![3.0, 4.0, 5.0, 6.0, 7.0],
292            len: 5,
293        };
294        // 共通インデックス: 0→1.0*3.0=3.0, 2→2.0*5.0=10.0 → 合計13.0
295        assert!((a.dot(&b) - 13.0).abs() < 1e-10);
296
297        // 空ベクトル同士の dot → 0.0
298        let empty_a = SparseVec::new(3);
299        let empty_b = SparseVec::new(5);
300        assert_eq!(empty_a.dot(&empty_b), 0.0);
301    }
302
303    #[test]
304    fn test_axpy_different_len() {
305        // SparseVec{len:3} に SparseVec{len:5} を axpy → len:3の範囲内のインデックスで正しく加算
306        let mut a = SparseVec { indices: vec![0], values: vec![1.0], len: 3 };
307        let b = SparseVec { indices: vec![0, 2], values: vec![2.0, 3.0], len: 5 };
308        a.axpy(1.0, &b);
309        assert!((a.get(0) - 3.0).abs() < 1e-10, "index 0: expected 3.0, got {}", a.get(0));
310        assert!((a.get(2) - 3.0).abs() < 1e-10, "index 2: expected 3.0, got {}", a.get(2));
311        assert_eq!(a.get(1), 0.0, "index 1 should remain 0");
312
313        // 空ベクトルへの axpy → other の内容がコピーされること
314        let mut empty = SparseVec::new(3);
315        let src = SparseVec { indices: vec![1, 2], values: vec![4.0, 5.0], len: 3 };
316        empty.axpy(1.0, &src);
317        assert!((empty.get(1) - 4.0).abs() < 1e-10, "index 1: expected 4.0, got {}", empty.get(1));
318        assert!((empty.get(2) - 5.0).abs() < 1e-10, "index 2: expected 5.0, got {}", empty.get(2));
319        assert_eq!(empty.get(0), 0.0, "index 0 should be 0");
320    }
321}