1use crate::tolerances::ZERO_TOL;
2
/// A sparse vector of `f64` values with a fixed logical (dense) length.
///
/// Only the stored entries are kept, as two parallel vectors: `indices[k]` is the
/// position of the `k`-th stored entry and `values[k]` is its value. The methods on
/// this type assume, but do not check, that:
///
/// * `indices` and `values` have the same length,
/// * `indices` is sorted in increasing order with no duplicates (lookups use binary
///   search and the binary operations are sorted merges), and
/// * every stored index is `< len` (dense conversion writes into a buffer of
///   length `len`).
///
/// Because the fields are public, code that builds a `SparseVec` by hand is
/// responsible for upholding these invariants.
///
/// Equality (`PartialEq`) compares the stored representation field by field. As a
/// result, two vectors with the same dense contents but different stored entries
/// (for example an explicit tiny entry versus no entry) compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseVec {
    /// Positions of the stored entries, in increasing order.
    pub indices: Vec<usize>,
    /// Values of the stored entries; `values[k]` belongs to `indices[k]`.
    pub values: Vec<f64>,
    /// Logical (dense) length of the vector.
    pub len: usize,
}
17
18impl SparseVec {
19 pub fn new(len: usize) -> Self {
26 Self {
27 indices: Vec::new(),
28 values: Vec::new(),
29 len,
30 }
31 }
32
33 pub fn from_dense(dense: &[f64]) -> Self {
41 let mut indices = Vec::new();
42 let mut values = Vec::new();
43 for (i, &v) in dense.iter().enumerate() {
44 if v.abs() > ZERO_TOL {
45 indices.push(i);
46 values.push(v);
47 }
48 }
49 Self {
50 indices,
51 values,
52 len: dense.len(),
53 }
54 }
55
56 pub fn to_dense(&self) -> Vec<f64> {
61 let mut dense = vec![0.0; self.len];
62 for (k, &idx) in self.indices.iter().enumerate() {
63 dense[idx] = self.values[k];
64 }
65 dense
66 }
67
68 pub fn to_dense_into(&self, buf: &mut [f64]) {
76 for v in buf.iter_mut() {
77 *v = 0.0;
78 }
79 for (k, &idx) in self.indices.iter().enumerate() {
80 buf[idx] = self.values[k];
81 }
82 }
83
84 pub fn get(&self, idx: usize) -> f64 {
92 match self.indices.binary_search(&idx) {
93 Ok(pos) => self.values[pos],
94 Err(_) => 0.0,
95 }
96 }
97
98 pub fn set(&mut self, idx: usize, val: f64) {
108 match self.indices.binary_search(&idx) {
109 Ok(pos) => {
110 if val.abs() <= ZERO_TOL {
111 self.indices.remove(pos);
112 self.values.remove(pos);
113 } else {
114 self.values[pos] = val;
115 }
116 }
117 Err(pos) => {
118 if val.abs() > ZERO_TOL {
119 self.indices.insert(pos, idx);
120 self.values.insert(pos, val);
121 }
122 }
123 }
124 }
125
126 pub fn axpy(&mut self, alpha: f64, other: &SparseVec) {
135 let mut new_indices = Vec::new();
136 let mut new_values = Vec::new();
137 let (mut i, mut j) = (0, 0);
138
139 while i < self.indices.len() && j < other.indices.len() {
140 if self.indices[i] == other.indices[j] {
141 let val = self.values[i] + alpha * other.values[j];
142 if val.abs() > ZERO_TOL {
143 new_indices.push(self.indices[i]);
144 new_values.push(val);
145 }
146 i += 1;
147 j += 1;
148 } else if self.indices[i] < other.indices[j] {
149 new_indices.push(self.indices[i]);
150 new_values.push(self.values[i]);
151 i += 1;
152 } else {
153 let val = alpha * other.values[j];
154 if val.abs() > ZERO_TOL {
155 new_indices.push(other.indices[j]);
156 new_values.push(val);
157 }
158 j += 1;
159 }
160 }
161 while i < self.indices.len() {
163 new_indices.push(self.indices[i]);
164 new_values.push(self.values[i]);
165 i += 1;
166 }
167 while j < other.indices.len() {
169 let val = alpha * other.values[j];
170 if val.abs() > ZERO_TOL {
171 new_indices.push(other.indices[j]);
172 new_values.push(val);
173 }
174 j += 1;
175 }
176
177 self.indices = new_indices;
178 self.values = new_values;
179 }
180
181 pub fn dot(&self, other: &SparseVec) -> f64 {
189 let mut result = 0.0;
190 let (mut i, mut j) = (0, 0);
191 while i < self.indices.len() && j < other.indices.len() {
192 if self.indices[i] == other.indices[j] {
193 result += self.values[i] * other.values[j];
194 i += 1;
195 j += 1;
196 } else if self.indices[i] < other.indices[j] {
197 i += 1;
198 } else {
199 j += 1;
200 }
201 }
202 result
203 }
204
205 pub fn dot_dense(&self, dense: &[f64]) -> f64 {
213 let mut result = 0.0;
214 for (k, &idx) in self.indices.iter().enumerate() {
215 if idx < dense.len() {
216 result += self.values[k] * dense[idx];
217 }
218 }
219 result
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vec_from_dense_to_dense() {
        let dense = vec![1.0, 0.0, 0.0, 3.5, 0.0, -2.0];
        let sv = SparseVec::from_dense(&dense);
        assert_eq!(sv.len, 6);
        assert_eq!(sv.indices, vec![0, 3, 5]);
        assert_eq!(sv.values, vec![1.0, 3.5, -2.0]);

        let back = sv.to_dense();
        assert_eq!(back, dense);
    }

    #[test]
    fn test_from_dense_all_zero() {
        let sv = SparseVec::from_dense(&[0.0, 0.0, 0.0]);
        assert_eq!(sv.len, 3);
        assert!(sv.indices.is_empty());
        assert!(sv.values.is_empty());
        assert_eq!(sv.to_dense(), vec![0.0; 3]);
    }

    #[test]
    fn test_to_dense_into_clears_stale_buffer() {
        let sv = SparseVec::from_dense(&[0.0, 2.0, 0.0]);
        // The buffer is longer than `len` and pre-filled with garbage; all of it must be reset.
        let mut buf = vec![9.0; 5];
        sv.to_dense_into(&mut buf);
        assert_eq!(buf, vec![0.0, 2.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sparse_vec_get_set() {
        let mut sv = SparseVec::new(5);
        assert_eq!(sv.get(0), 0.0);
        assert_eq!(sv.get(4), 0.0);

        sv.set(2, 7.0);
        sv.set(4, -1.0);
        assert_eq!(sv.get(2), 7.0);
        assert_eq!(sv.get(4), -1.0);
        assert_eq!(sv.get(3), 0.0);

        // Overwrite an existing entry.
        sv.set(2, 3.0);
        assert_eq!(sv.get(2), 3.0);

        // Setting an entry to zero removes it from the sparse storage.
        sv.set(2, 0.0);
        assert_eq!(sv.get(2), 0.0);
        assert_eq!(sv.indices, vec![4]);
    }

    #[test]
    fn test_set_keeps_indices_sorted() {
        let mut sv = SparseVec::new(6);
        sv.set(4, 1.0);
        sv.set(1, 2.0);
        sv.set(3, 3.0);
        assert_eq!(sv.indices, vec![1, 3, 4]);
        assert_eq!(sv.values, vec![2.0, 3.0, 1.0]);
    }

    #[test]
    fn test_sparse_vec_dot() {
        let a = SparseVec::from_dense(&[1.0, 0.0, 3.0, 0.0]);
        let b = SparseVec::from_dense(&[2.0, 5.0, 4.0, 0.0]);
        // 1*2 + 3*4 = 14
        assert!((a.dot(&b) - 14.0).abs() < 1e-10);

        let dense = vec![2.0, 5.0, 4.0, 0.0];
        assert!((a.dot_dense(&dense) - 14.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_dense_shorter_slice() {
        let a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
        // Index 2 is out of range for the dense slice and must be skipped, not panic.
        assert_eq!(a.dot_dense(&[2.0]), 2.0);
    }

    #[test]
    fn test_sparse_vec_axpy() {
        let mut a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
        let b = SparseVec::from_dense(&[0.0, 2.0, 1.0]);
        a.axpy(2.0, &b);
        let dense = a.to_dense();
        // [1, 0, 3] + 2 * [0, 2, 1] = [1, 4, 5]
        assert!((dense[0] - 1.0).abs() < 1e-10);
        assert!((dense[1] - 4.0).abs() < 1e-10);
        assert!((dense[2] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_axpy_cancellation_removes_entry() {
        let mut a = SparseVec::from_dense(&[1.0, 0.0, 3.0]);
        let b = SparseVec::from_dense(&[0.0, 2.0, 3.0]);
        // [1, 0, 3] - [0, 2, 3] = [1, -2, 0]; the cancelled entry must not be stored.
        a.axpy(-1.0, &b);
        assert_eq!(a.indices, vec![0, 1]);
        assert_eq!(a.values, vec![1.0, -2.0]);
    }

    #[test]
    fn test_dot_different_len() {
        // `a` has len 3, `b` has len 5; only the overlapping indices 0 and 2 contribute.
        let a = SparseVec { indices: vec![0, 2], values: vec![1.0, 2.0], len: 3 };
        let b = SparseVec {
            indices: vec![0, 1, 2, 3, 4],
            values: vec![3.0, 4.0, 5.0, 6.0, 7.0],
            len: 5,
        };
        // 1*3 + 2*5 = 13
        assert!((a.dot(&b) - 13.0).abs() < 1e-10);

        // Two empty vectors of different lengths have a zero inner product.
        let empty_a = SparseVec::new(3);
        let empty_b = SparseVec::new(5);
        assert_eq!(empty_a.dot(&empty_b), 0.0);
    }

    #[test]
    fn test_axpy_different_len() {
        // `a` has len 3, `b` has len 5, but all of `b`'s stored indices fit in `a`.
        let mut a = SparseVec { indices: vec![0], values: vec![1.0], len: 3 };
        let b = SparseVec { indices: vec![0, 2], values: vec![2.0, 3.0], len: 5 };
        a.axpy(1.0, &b);
        assert!((a.get(0) - 3.0).abs() < 1e-10, "index 0: expected 3.0, got {}", a.get(0));
        assert!((a.get(2) - 3.0).abs() < 1e-10, "index 2: expected 3.0, got {}", a.get(2));
        assert_eq!(a.get(1), 0.0, "index 1 should remain 0");

        // axpy into an empty vector copies the scaled source entries.
        let mut empty = SparseVec::new(3);
        let src = SparseVec { indices: vec![1, 2], values: vec![4.0, 5.0], len: 3 };
        empty.axpy(1.0, &src);
        assert!((empty.get(1) - 4.0).abs() < 1e-10, "index 1: expected 4.0, got {}", empty.get(1));
        assert!((empty.get(2) - 5.0).abs() < 1e-10, "index 2: expected 5.0, got {}", empty.get(2));
        assert_eq!(empty.get(0), 0.0, "index 0 should be 0");
    }
}