oxirs_vec/
sparse.rs

1use crate::{Vector, VectorError};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5/// Sparse vector representation using a hash map for efficient storage
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub struct SparseVector {
8    /// Non-zero values indexed by their position
9    pub values: HashMap<usize, f32>,
10    /// Total dimensions of the vector
11    pub dimensions: usize,
12    /// Optional metadata
13    pub metadata: Option<HashMap<String, String>>,
14}
15
16impl SparseVector {
17    /// Create a new sparse vector from indices and values
18    pub fn new(
19        indices: Vec<usize>,
20        values: Vec<f32>,
21        dimensions: usize,
22    ) -> Result<Self, VectorError> {
23        if indices.len() != values.len() {
24            return Err(VectorError::InvalidDimensions(
25                "Indices and values must have same length".to_string(),
26            ));
27        }
28
29        if let Some(&max_idx) = indices.iter().max() {
30            if max_idx >= dimensions {
31                return Err(VectorError::InvalidDimensions(format!(
32                    "Index {max_idx} exceeds dimensions {dimensions}"
33                )));
34            }
35        }
36
37        let mut sparse_values = HashMap::new();
38        for (idx, val) in indices.into_iter().zip(values.into_iter()) {
39            if val != 0.0 {
40                // Only store non-zero values
41                sparse_values.insert(idx, val);
42            }
43        }
44
45        Ok(Self {
46            values: sparse_values,
47            dimensions,
48            metadata: None,
49        })
50    }
51
52    /// Create sparse vector from dense vector
53    pub fn from_dense(dense: &Vector) -> Self {
54        let values = dense.as_f32();
55        let mut sparse_values = HashMap::new();
56
57        for (idx, &val) in values.iter().enumerate() {
58            if val.abs() > f32::EPSILON {
59                // Only store non-zero values
60                sparse_values.insert(idx, val);
61            }
62        }
63
64        Self {
65            values: sparse_values,
66            dimensions: dense.dimensions,
67            metadata: dense.metadata.clone(),
68        }
69    }
70
71    /// Convert to dense vector
72    pub fn to_dense(&self) -> Vector {
73        let mut values = vec![0.0; self.dimensions];
74
75        for (&idx, &val) in &self.values {
76            if idx < self.dimensions {
77                values[idx] = val;
78            }
79        }
80
81        let mut vec = Vector::new(values);
82        vec.metadata = self.metadata.clone();
83        vec
84    }
85
86    /// Get value at index
87    pub fn get(&self, index: usize) -> f32 {
88        self.values.get(&index).copied().unwrap_or(0.0)
89    }
90
91    /// Set value at index
92    pub fn set(&mut self, index: usize, value: f32) -> Result<(), VectorError> {
93        if index >= self.dimensions {
94            return Err(VectorError::InvalidDimensions(format!(
95                "Index {} exceeds dimensions {}",
96                index, self.dimensions
97            )));
98        }
99
100        if value.abs() > f32::EPSILON {
101            self.values.insert(index, value);
102        } else {
103            self.values.remove(&index);
104        }
105
106        Ok(())
107    }
108
109    /// Get number of non-zero elements
110    pub fn nnz(&self) -> usize {
111        self.values.len()
112    }
113
114    /// Sparsity ratio (percentage of zero elements)
115    pub fn sparsity(&self) -> f32 {
116        let non_zero = self.nnz() as f32;
117        let total = self.dimensions as f32;
118        (total - non_zero) / total
119    }
120
121    /// Dot product with another sparse vector
122    pub fn dot(&self, other: &SparseVector) -> Result<f32, VectorError> {
123        if self.dimensions != other.dimensions {
124            return Err(VectorError::DimensionMismatch {
125                expected: self.dimensions,
126                actual: other.dimensions,
127            });
128        }
129
130        let mut sum = 0.0;
131
132        // Only iterate over the smaller set of indices
133        if self.values.len() <= other.values.len() {
134            for (&idx, &val) in &self.values {
135                if let Some(&other_val) = other.values.get(&idx) {
136                    sum += val * other_val;
137                }
138            }
139        } else {
140            for (&idx, &val) in &other.values {
141                if let Some(&self_val) = self.values.get(&idx) {
142                    sum += val * self_val;
143                }
144            }
145        }
146
147        Ok(sum)
148    }
149
150    /// Compute cosine similarity with another sparse vector
151    pub fn cosine_similarity(&self, other: &SparseVector) -> Result<f32, VectorError> {
152        let dot = self.dot(other)?;
153        let self_norm = self.l2_norm();
154        let other_norm = other.l2_norm();
155
156        if self_norm == 0.0 || other_norm == 0.0 {
157            Ok(0.0)
158        } else {
159            Ok(dot / (self_norm * other_norm))
160        }
161    }
162
163    /// Compute L2 norm
164    pub fn l2_norm(&self) -> f32 {
165        self.values.values().map(|v| v * v).sum::<f32>().sqrt()
166    }
167
168    /// Compute L1 norm
169    pub fn l1_norm(&self) -> f32 {
170        self.values.values().map(|v| v.abs()).sum()
171    }
172
173    /// Add another sparse vector
174    pub fn add(&self, other: &SparseVector) -> Result<SparseVector, VectorError> {
175        if self.dimensions != other.dimensions {
176            return Err(VectorError::DimensionMismatch {
177                expected: self.dimensions,
178                actual: other.dimensions,
179            });
180        }
181
182        let mut result = self.clone();
183
184        for (&idx, &val) in &other.values {
185            let new_val = result.get(idx) + val;
186            result.set(idx, new_val)?;
187        }
188
189        Ok(result)
190    }
191
192    /// Subtract another sparse vector
193    pub fn subtract(&self, other: &SparseVector) -> Result<SparseVector, VectorError> {
194        if self.dimensions != other.dimensions {
195            return Err(VectorError::DimensionMismatch {
196                expected: self.dimensions,
197                actual: other.dimensions,
198            });
199        }
200
201        let mut result = self.clone();
202
203        for (&idx, &val) in &other.values {
204            let new_val = result.get(idx) - val;
205            result.set(idx, new_val)?;
206        }
207
208        Ok(result)
209    }
210
211    /// Scale by scalar
212    pub fn scale(&self, scalar: f32) -> SparseVector {
213        let mut result = self.clone();
214
215        for val in result.values.values_mut() {
216            *val *= scalar;
217        }
218
219        result
220    }
221
222    /// Normalize to unit length
223    pub fn normalize(&self) -> SparseVector {
224        let norm = self.l2_norm();
225        if norm > 0.0 {
226            self.scale(1.0 / norm)
227        } else {
228            self.clone()
229        }
230    }
231}
232
/// Sparse matrix in Compressed Sparse Row (CSR) layout, intended for batch operations.
#[derive(Debug, Clone, PartialEq)]
pub struct CSRMatrix {
    /// Stored entry values, laid out row after row
    pub values: Vec<f32>,
    /// Column index of the entry at the same position in `values`
    pub col_indices: Vec<usize>,
    /// Offsets into `values` / `col_indices`: row `r` occupies `row_ptrs[r]..row_ptrs[r + 1]`
    pub row_ptrs: Vec<usize>,
    /// Matrix shape as `(rows, cols)`
    pub shape: (usize, usize),
}
245
246impl CSRMatrix {
247    /// Create CSR matrix from sparse vectors
248    pub fn from_sparse_vectors(vectors: &[SparseVector]) -> Result<Self, VectorError> {
249        if vectors.is_empty() {
250            return Ok(Self {
251                values: Vec::new(),
252                col_indices: Vec::new(),
253                row_ptrs: vec![0],
254                shape: (0, 0),
255            });
256        }
257
258        let num_rows = vectors.len();
259        let num_cols = vectors[0].dimensions;
260
261        // Verify all vectors have same dimensions
262        for (i, vec) in vectors.iter().enumerate() {
263            if vec.dimensions != num_cols {
264                return Err(VectorError::InvalidDimensions(format!(
265                    "Vector {} has {} dimensions, expected {}",
266                    i, vec.dimensions, num_cols
267                )));
268            }
269        }
270
271        let mut values = Vec::new();
272        let mut col_indices = Vec::new();
273        let mut row_ptrs = vec![0];
274
275        for vec in vectors {
276            // Sort by column index for CSR format
277            let mut sorted_entries: Vec<_> = vec.values.iter().collect();
278            sorted_entries.sort_by_key(|&(&idx, _)| idx);
279
280            for (&idx, &val) in sorted_entries {
281                values.push(val);
282                col_indices.push(idx);
283            }
284
285            row_ptrs.push(values.len());
286        }
287
288        Ok(Self {
289            values,
290            col_indices,
291            row_ptrs,
292            shape: (num_rows, num_cols),
293        })
294    }
295
296    /// Get a specific row as sparse vector
297    pub fn get_row(&self, row: usize) -> Option<SparseVector> {
298        if row >= self.shape.0 {
299            return None;
300        }
301
302        let start = self.row_ptrs[row];
303        let end = self.row_ptrs[row + 1];
304
305        let mut values = HashMap::new();
306        for i in start..end {
307            values.insert(self.col_indices[i], self.values[i]);
308        }
309
310        Some(SparseVector {
311            values,
312            dimensions: self.shape.1,
313            metadata: None,
314        })
315    }
316
317    /// Matrix-vector multiplication
318    pub fn multiply_vector(&self, vector: &SparseVector) -> Result<Vec<f32>, VectorError> {
319        if self.shape.1 != vector.dimensions {
320            return Err(VectorError::DimensionMismatch {
321                expected: self.shape.1,
322                actual: vector.dimensions,
323            });
324        }
325
326        let mut result = vec![0.0; self.shape.0];
327
328        for (row, result_val) in result.iter_mut().enumerate().take(self.shape.0) {
329            let start = self.row_ptrs[row];
330            let end = self.row_ptrs[row + 1];
331
332            let mut sum = 0.0;
333            for i in start..end {
334                let col = self.col_indices[i];
335                if let Some(&vec_val) = vector.values.get(&col) {
336                    sum += self.values[i] * vec_val;
337                }
338            }
339            *result_val = sum;
340        }
341
342        Ok(result)
343    }
344
345    /// Get memory usage in bytes
346    pub fn memory_usage(&self) -> usize {
347        self.values.len() * std::mem::size_of::<f32>()
348            + self.col_indices.len() * std::mem::size_of::<usize>()
349            + self.row_ptrs.len() * std::mem::size_of::<usize>()
350    }
351
352    /// Get sparsity of the matrix
353    pub fn sparsity(&self) -> f32 {
354        let total_elements = self.shape.0 * self.shape.1;
355        let non_zero = self.values.len();
356        (total_elements - non_zero) as f32 / total_elements as f32
357    }
358}
359
/// Sparse matrix in Coordinate (COO) layout, convenient for incremental construction.
///
/// The three vectors are parallel: position `i` of each one describes the same entry.
#[derive(Debug, Clone, PartialEq)]
pub struct COOMatrix {
    /// Row coordinate of each stored entry
    pub row_indices: Vec<usize>,
    /// Column coordinate of each stored entry
    pub col_indices: Vec<usize>,
    /// Value of each stored entry
    pub values: Vec<f32>,
    /// Matrix shape as `(rows, cols)`
    pub shape: (usize, usize),
}
368
369impl COOMatrix {
370    /// Create empty COO matrix
371    pub fn new(rows: usize, cols: usize) -> Self {
372        Self {
373            row_indices: Vec::new(),
374            col_indices: Vec::new(),
375            values: Vec::new(),
376            shape: (rows, cols),
377        }
378    }
379
380    /// Add a value to the matrix
381    pub fn add_value(&mut self, row: usize, col: usize, value: f32) -> Result<(), VectorError> {
382        if row >= self.shape.0 || col >= self.shape.1 {
383            return Err(VectorError::InvalidDimensions(format!(
384                "Index ({}, {}) out of bounds for shape {:?}",
385                row, col, self.shape
386            )));
387        }
388
389        if value.abs() > f32::EPSILON {
390            self.row_indices.push(row);
391            self.col_indices.push(col);
392            self.values.push(value);
393        }
394
395        Ok(())
396    }
397
398    /// Convert to CSR format
399    pub fn to_csr(&self) -> CSRMatrix {
400        // Sort by row, then column
401        let mut entries: Vec<_> = (0..self.values.len())
402            .map(|i| (self.row_indices[i], self.col_indices[i], self.values[i]))
403            .collect();
404        entries.sort_by_key(|&(r, c, _)| (r, c));
405
406        let mut values = Vec::new();
407        let mut col_indices = Vec::new();
408        let mut row_ptrs = vec![0];
409
410        let mut current_row = 0;
411        for (row, col, val) in entries {
412            while current_row < row {
413                row_ptrs.push(values.len());
414                current_row += 1;
415            }
416            values.push(val);
417            col_indices.push(col);
418        }
419
420        while current_row < self.shape.0 {
421            row_ptrs.push(values.len());
422            current_row += 1;
423        }
424
425        CSRMatrix {
426            values,
427            col_indices,
428            row_ptrs,
429            shape: self.shape,
430        }
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `vector` reads back every `(index, value)` pair in `expected`.
    fn assert_entries(vector: &SparseVector, expected: &[(usize, f32)]) {
        for &(index, value) in expected {
            assert_eq!(vector.get(index), value);
        }
    }

    #[test]
    fn test_sparse_vector_creation() {
        let sparse = SparseVector::new(vec![0, 3, 7], vec![1.0, 2.0, 3.0], 10).unwrap();

        // Index 5 was never supplied, so it must read as an implicit zero.
        assert_entries(&sparse, &[(0, 1.0), (3, 2.0), (7, 3.0), (5, 0.0)]);
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.dimensions, 10);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let sparse = SparseVector::from_dense(&Vector::new(vec![0.0, 1.0, 0.0, 2.0, 0.0]));

        assert_eq!(sparse.nnz(), 2);
        assert_entries(&sparse, &[(1, 1.0), (3, 2.0)]);

        // Converting back must reproduce the original dense values.
        let round_trip = sparse.to_dense();
        assert_eq!(round_trip.as_f32(), vec![0.0, 1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn test_sparse_operations() {
        let lhs = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 5).unwrap();
        let rhs = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0], 5).unwrap();

        // The operands share only index 2, so the dot product is 2.0 * 5.0.
        assert_eq!(lhs.dot(&rhs).unwrap(), 10.0);

        // Element-wise sum
        let sum = lhs.add(&rhs).unwrap();
        assert_entries(&sum, &[(0, 1.0), (1, 4.0), (2, 7.0), (3, 6.0), (4, 3.0)]);

        // Scalar multiplication
        let doubled = lhs.scale(2.0);
        assert_entries(&doubled, &[(0, 2.0), (2, 4.0), (4, 6.0)]);
    }

    #[test]
    fn test_csr_matrix() {
        let rows = vec![
            SparseVector::new(vec![0, 2], vec![1.0, 2.0], 4).unwrap(),
            SparseVector::new(vec![1, 3], vec![3.0, 4.0], 4).unwrap(),
            SparseVector::new(vec![0, 1, 2], vec![5.0, 6.0, 7.0], 4).unwrap(),
        ];

        let csr = CSRMatrix::from_sparse_vectors(&rows).unwrap();

        assert_eq!(csr.shape, (3, 4));
        assert_eq!(csr.values.len(), 7);
        assert_eq!(csr.row_ptrs, vec![0, 2, 4, 7]);

        // Extracting a row gives back that row's entries.
        let second_row = csr.get_row(1).unwrap();
        assert_entries(&second_row, &[(1, 3.0), (3, 4.0)]);
    }

    #[test]
    fn test_coo_to_csr() {
        let mut coo = COOMatrix::new(3, 3);
        for &(row, col, value) in &[
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ] {
            coo.add_value(row, col, value).unwrap();
        }

        let csr = coo.to_csr();
        assert_eq!(csr.values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(csr.col_indices, vec![0, 2, 1, 0, 2]);
        assert_eq!(csr.row_ptrs, vec![0, 2, 3, 5]);
    }
}