1use crate::{Vector, VectorError};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub struct SparseVector {
8 pub values: HashMap<usize, f32>,
10 pub dimensions: usize,
12 pub metadata: Option<HashMap<String, String>>,
14}
15
16impl SparseVector {
17 pub fn new(
19 indices: Vec<usize>,
20 values: Vec<f32>,
21 dimensions: usize,
22 ) -> Result<Self, VectorError> {
23 if indices.len() != values.len() {
24 return Err(VectorError::InvalidDimensions(
25 "Indices and values must have same length".to_string(),
26 ));
27 }
28
29 if let Some(&max_idx) = indices.iter().max() {
30 if max_idx >= dimensions {
31 return Err(VectorError::InvalidDimensions(format!(
32 "Index {max_idx} exceeds dimensions {dimensions}"
33 )));
34 }
35 }
36
37 let mut sparse_values = HashMap::new();
38 for (idx, val) in indices.into_iter().zip(values.into_iter()) {
39 if val != 0.0 {
40 sparse_values.insert(idx, val);
42 }
43 }
44
45 Ok(Self {
46 values: sparse_values,
47 dimensions,
48 metadata: None,
49 })
50 }
51
52 pub fn from_dense(dense: &Vector) -> Self {
54 let values = dense.as_f32();
55 let mut sparse_values = HashMap::new();
56
57 for (idx, &val) in values.iter().enumerate() {
58 if val.abs() > f32::EPSILON {
59 sparse_values.insert(idx, val);
61 }
62 }
63
64 Self {
65 values: sparse_values,
66 dimensions: dense.dimensions,
67 metadata: dense.metadata.clone(),
68 }
69 }
70
71 pub fn to_dense(&self) -> Vector {
73 let mut values = vec![0.0; self.dimensions];
74
75 for (&idx, &val) in &self.values {
76 if idx < self.dimensions {
77 values[idx] = val;
78 }
79 }
80
81 let mut vec = Vector::new(values);
82 vec.metadata = self.metadata.clone();
83 vec
84 }
85
86 pub fn get(&self, index: usize) -> f32 {
88 self.values.get(&index).copied().unwrap_or(0.0)
89 }
90
91 pub fn set(&mut self, index: usize, value: f32) -> Result<(), VectorError> {
93 if index >= self.dimensions {
94 return Err(VectorError::InvalidDimensions(format!(
95 "Index {} exceeds dimensions {}",
96 index, self.dimensions
97 )));
98 }
99
100 if value.abs() > f32::EPSILON {
101 self.values.insert(index, value);
102 } else {
103 self.values.remove(&index);
104 }
105
106 Ok(())
107 }
108
109 pub fn nnz(&self) -> usize {
111 self.values.len()
112 }
113
114 pub fn sparsity(&self) -> f32 {
116 let non_zero = self.nnz() as f32;
117 let total = self.dimensions as f32;
118 (total - non_zero) / total
119 }
120
121 pub fn dot(&self, other: &SparseVector) -> Result<f32, VectorError> {
123 if self.dimensions != other.dimensions {
124 return Err(VectorError::DimensionMismatch {
125 expected: self.dimensions,
126 actual: other.dimensions,
127 });
128 }
129
130 let mut sum = 0.0;
131
132 if self.values.len() <= other.values.len() {
134 for (&idx, &val) in &self.values {
135 if let Some(&other_val) = other.values.get(&idx) {
136 sum += val * other_val;
137 }
138 }
139 } else {
140 for (&idx, &val) in &other.values {
141 if let Some(&self_val) = self.values.get(&idx) {
142 sum += val * self_val;
143 }
144 }
145 }
146
147 Ok(sum)
148 }
149
150 pub fn cosine_similarity(&self, other: &SparseVector) -> Result<f32, VectorError> {
152 let dot = self.dot(other)?;
153 let self_norm = self.l2_norm();
154 let other_norm = other.l2_norm();
155
156 if self_norm == 0.0 || other_norm == 0.0 {
157 Ok(0.0)
158 } else {
159 Ok(dot / (self_norm * other_norm))
160 }
161 }
162
163 pub fn l2_norm(&self) -> f32 {
165 self.values.values().map(|v| v * v).sum::<f32>().sqrt()
166 }
167
168 pub fn l1_norm(&self) -> f32 {
170 self.values.values().map(|v| v.abs()).sum()
171 }
172
173 pub fn add(&self, other: &SparseVector) -> Result<SparseVector, VectorError> {
175 if self.dimensions != other.dimensions {
176 return Err(VectorError::DimensionMismatch {
177 expected: self.dimensions,
178 actual: other.dimensions,
179 });
180 }
181
182 let mut result = self.clone();
183
184 for (&idx, &val) in &other.values {
185 let new_val = result.get(idx) + val;
186 result.set(idx, new_val)?;
187 }
188
189 Ok(result)
190 }
191
192 pub fn subtract(&self, other: &SparseVector) -> Result<SparseVector, VectorError> {
194 if self.dimensions != other.dimensions {
195 return Err(VectorError::DimensionMismatch {
196 expected: self.dimensions,
197 actual: other.dimensions,
198 });
199 }
200
201 let mut result = self.clone();
202
203 for (&idx, &val) in &other.values {
204 let new_val = result.get(idx) - val;
205 result.set(idx, new_val)?;
206 }
207
208 Ok(result)
209 }
210
211 pub fn scale(&self, scalar: f32) -> SparseVector {
213 let mut result = self.clone();
214
215 for val in result.values.values_mut() {
216 *val *= scalar;
217 }
218
219 result
220 }
221
222 pub fn normalize(&self) -> SparseVector {
224 let norm = self.l2_norm();
225 if norm > 0.0 {
226 self.scale(1.0 / norm)
227 } else {
228 self.clone()
229 }
230 }
231}
232
/// A sparse matrix in Compressed Sparse Row (CSR) layout.
///
/// Row `i` occupies the half-open range `row_ptrs[i]..row_ptrs[i + 1]` of
/// both `values` and `col_indices`.
#[derive(Clone, Debug, PartialEq)]
pub struct CSRMatrix {
    /// Stored entries, concatenated row by row.
    pub values: Vec<f32>,
    /// Column index of each entry in `values` (same length as `values`).
    pub col_indices: Vec<usize>,
    /// Row start offsets into `values`; holds `rows + 1` elements.
    pub row_ptrs: Vec<usize>,
    /// Matrix shape as `(rows, cols)`.
    pub shape: (usize, usize),
}
245
246impl CSRMatrix {
247 pub fn from_sparse_vectors(vectors: &[SparseVector]) -> Result<Self, VectorError> {
249 if vectors.is_empty() {
250 return Ok(Self {
251 values: Vec::new(),
252 col_indices: Vec::new(),
253 row_ptrs: vec![0],
254 shape: (0, 0),
255 });
256 }
257
258 let num_rows = vectors.len();
259 let num_cols = vectors[0].dimensions;
260
261 for (i, vec) in vectors.iter().enumerate() {
263 if vec.dimensions != num_cols {
264 return Err(VectorError::InvalidDimensions(format!(
265 "Vector {} has {} dimensions, expected {}",
266 i, vec.dimensions, num_cols
267 )));
268 }
269 }
270
271 let mut values = Vec::new();
272 let mut col_indices = Vec::new();
273 let mut row_ptrs = vec![0];
274
275 for vec in vectors {
276 let mut sorted_entries: Vec<_> = vec.values.iter().collect();
278 sorted_entries.sort_by_key(|&(&idx, _)| idx);
279
280 for (&idx, &val) in sorted_entries {
281 values.push(val);
282 col_indices.push(idx);
283 }
284
285 row_ptrs.push(values.len());
286 }
287
288 Ok(Self {
289 values,
290 col_indices,
291 row_ptrs,
292 shape: (num_rows, num_cols),
293 })
294 }
295
296 pub fn get_row(&self, row: usize) -> Option<SparseVector> {
298 if row >= self.shape.0 {
299 return None;
300 }
301
302 let start = self.row_ptrs[row];
303 let end = self.row_ptrs[row + 1];
304
305 let mut values = HashMap::new();
306 for i in start..end {
307 values.insert(self.col_indices[i], self.values[i]);
308 }
309
310 Some(SparseVector {
311 values,
312 dimensions: self.shape.1,
313 metadata: None,
314 })
315 }
316
317 pub fn multiply_vector(&self, vector: &SparseVector) -> Result<Vec<f32>, VectorError> {
319 if self.shape.1 != vector.dimensions {
320 return Err(VectorError::DimensionMismatch {
321 expected: self.shape.1,
322 actual: vector.dimensions,
323 });
324 }
325
326 let mut result = vec![0.0; self.shape.0];
327
328 for (row, result_val) in result.iter_mut().enumerate().take(self.shape.0) {
329 let start = self.row_ptrs[row];
330 let end = self.row_ptrs[row + 1];
331
332 let mut sum = 0.0;
333 for i in start..end {
334 let col = self.col_indices[i];
335 if let Some(&vec_val) = vector.values.get(&col) {
336 sum += self.values[i] * vec_val;
337 }
338 }
339 *result_val = sum;
340 }
341
342 Ok(result)
343 }
344
345 pub fn memory_usage(&self) -> usize {
347 self.values.len() * std::mem::size_of::<f32>()
348 + self.col_indices.len() * std::mem::size_of::<usize>()
349 + self.row_ptrs.len() * std::mem::size_of::<usize>()
350 }
351
352 pub fn sparsity(&self) -> f32 {
354 let total_elements = self.shape.0 * self.shape.1;
355 let non_zero = self.values.len();
356 (total_elements - non_zero) as f32 / total_elements as f32
357 }
358}
359
/// A sparse matrix in Coordinate (COO / triplet) layout.
///
/// Entry `k` is `values[k]` at position `(row_indices[k], col_indices[k])`;
/// the three vectors are kept in lockstep by `add_value`.
#[derive(Clone, Debug, PartialEq)]
pub struct COOMatrix {
    /// Row coordinate of each stored entry.
    pub row_indices: Vec<usize>,
    /// Column coordinate of each stored entry.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry.
    pub values: Vec<f32>,
    /// Matrix shape as `(rows, cols)`.
    pub shape: (usize, usize),
}
368
369impl COOMatrix {
370 pub fn new(rows: usize, cols: usize) -> Self {
372 Self {
373 row_indices: Vec::new(),
374 col_indices: Vec::new(),
375 values: Vec::new(),
376 shape: (rows, cols),
377 }
378 }
379
380 pub fn add_value(&mut self, row: usize, col: usize, value: f32) -> Result<(), VectorError> {
382 if row >= self.shape.0 || col >= self.shape.1 {
383 return Err(VectorError::InvalidDimensions(format!(
384 "Index ({}, {}) out of bounds for shape {:?}",
385 row, col, self.shape
386 )));
387 }
388
389 if value.abs() > f32::EPSILON {
390 self.row_indices.push(row);
391 self.col_indices.push(col);
392 self.values.push(value);
393 }
394
395 Ok(())
396 }
397
398 pub fn to_csr(&self) -> CSRMatrix {
400 let mut entries: Vec<_> = (0..self.values.len())
402 .map(|i| (self.row_indices[i], self.col_indices[i], self.values[i]))
403 .collect();
404 entries.sort_by_key(|&(r, c, _)| (r, c));
405
406 let mut values = Vec::new();
407 let mut col_indices = Vec::new();
408 let mut row_ptrs = vec![0];
409
410 let mut current_row = 0;
411 for (row, col, val) in entries {
412 while current_row < row {
413 row_ptrs.push(values.len());
414 current_row += 1;
415 }
416 values.push(val);
417 col_indices.push(col);
418 }
419
420 while current_row < self.shape.0 {
421 row_ptrs.push(values.len());
422 current_row += 1;
423 }
424
425 CSRMatrix {
426 values,
427 col_indices,
428 row_ptrs,
429 shape: self.shape,
430 }
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_vector_creation() {
        let sparse = SparseVector::new(vec![0, 3, 7], vec![1.0, 2.0, 3.0], 10).unwrap();

        // Stored positions return their value; unstored ones read as zero.
        for &(idx, expected) in &[(0, 1.0), (3, 2.0), (7, 3.0), (5, 0.0)] {
            assert_eq!(sparse.get(idx), expected);
        }
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.dimensions, 10);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let sparse = SparseVector::from_dense(&Vector::new(vec![0.0, 1.0, 0.0, 2.0, 0.0]));

        assert_eq!(sparse.nnz(), 2);
        for &(idx, expected) in &[(1, 1.0), (3, 2.0)] {
            assert_eq!(sparse.get(idx), expected);
        }

        let round_trip = sparse.to_dense();
        assert_eq!(round_trip.as_f32(), vec![0.0, 1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn test_sparse_operations() {
        let lhs = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 5).unwrap();
        let rhs = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0], 5).unwrap();

        // Only index 2 overlaps: 2.0 * 5.0.
        assert_eq!(lhs.dot(&rhs).unwrap(), 10.0);

        let total = lhs.add(&rhs).unwrap();
        for (idx, &expected) in [1.0, 4.0, 7.0, 6.0, 3.0].iter().enumerate() {
            assert_eq!(total.get(idx), expected);
        }

        let doubled = lhs.scale(2.0);
        for &(idx, expected) in &[(0, 2.0), (2, 4.0), (4, 6.0)] {
            assert_eq!(doubled.get(idx), expected);
        }
    }

    #[test]
    fn test_csr_matrix() {
        let rows = [
            SparseVector::new(vec![0, 2], vec![1.0, 2.0], 4).unwrap(),
            SparseVector::new(vec![1, 3], vec![3.0, 4.0], 4).unwrap(),
            SparseVector::new(vec![0, 1, 2], vec![5.0, 6.0, 7.0], 4).unwrap(),
        ];

        let csr = CSRMatrix::from_sparse_vectors(&rows).unwrap();
        assert_eq!(csr.shape, (3, 4));
        assert_eq!(csr.values.len(), 7);
        assert_eq!(csr.row_ptrs, vec![0, 2, 4, 7]);

        let middle = csr.get_row(1).unwrap();
        assert_eq!(middle.get(1), 3.0);
        assert_eq!(middle.get(3), 4.0);
    }

    #[test]
    fn test_coo_to_csr() {
        let mut coo = COOMatrix::new(3, 3);
        let triplets = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ];
        for &(row, col, value) in &triplets {
            coo.add_value(row, col, value).unwrap();
        }

        let csr = coo.to_csr();
        assert_eq!(csr.values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(csr.col_indices, vec![0, 2, 1, 0, 2]);
        assert_eq!(csr.row_ptrs, vec![0, 2, 3, 5]);
    }
}