// math_sparse_data/lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod surface;
4use math_linear::{F32Matrix, F32MatrixView, MatrixShape};
5use vector_analysis_core::DenseVector;
6use video_analysis_core::{DetectError, Result};
7
8fn invalid_argument(message: impl Into<String>) -> DetectError {
9    DetectError::InvalidArgument(message.into())
10}
11
/// Selects how two sparse vectors are compared.
///
/// No consumer of this enum is visible in this module; the variants presumably map to
/// `SparseVector::dot` and `SparseVector::cosine_similarity` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseSimilarityMetric {
    /// Dot-product (inner-product) similarity.
    Dot,
    /// Cosine similarity: the inner product divided by both L2 norms.
    Cosine,
}
20
/// Sparse `f32` vector stored as parallel index/value lists.
///
/// The derived `PartialEq` compares the stored lists directly. Two vectors that represent
/// the same values but differ in storage order or duplicate indices therefore compare
/// unequal unless both are canonicalized first.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseVector {
    /// Logical (dense) length of the vector.
    dimensions: usize,
    /// Position of each stored value; parallel to `values`.
    indices: Vec<usize>,
    /// Stored values; parallel to `indices`.
    values: Vec<f32>,
}
28
/// Shape, density and per-row / per-column stored-entry statistics of a CSR matrix.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseMatrixSummary {
    /// Row count of the matrix.
    pub rows: usize,
    /// Column count of the matrix.
    pub cols: usize,
    /// Total number of stored entries.
    pub nnz: usize,
    /// Stored entries divided by `rows * cols`.
    pub density: f32,
    /// Fewest stored entries found in any row.
    pub row_nnz_min: usize,
    /// Most stored entries found in any row.
    pub row_nnz_max: usize,
    /// Average stored entries per row.
    pub row_nnz_mean: f32,
    /// Fewest stored entries found in any column.
    pub column_nnz_min: usize,
    /// Most stored entries found in any column.
    pub column_nnz_max: usize,
    /// Average stored entries per column.
    pub column_nnz_mean: f32,
}
53
54impl SparseVector {
55    /// Creates a new value.
56    pub fn new(dimensions: usize, indices: Vec<usize>, values: Vec<f32>) -> Result<Self> {
57        let vector = Self {
58            dimensions,
59            indices,
60            values,
61        };
62        vector.validate()?;
63        Ok(vector)
64    }
65
66    /// Returns dimensions.
67    pub fn dimensions(&self) -> usize {
68        self.dimensions
69    }
70
71    /// Returns indices.
72    pub fn indices(&self) -> &[usize] {
73        &self.indices
74    }
75
76    /// Returns values.
77    pub fn values(&self) -> &[f32] {
78        &self.values
79    }
80
81    /// Returns nnz.
82    pub fn nnz(&self) -> usize {
83        self.indices.len()
84    }
85
86    /// Validates this value.
87    pub fn validate(&self) -> Result<()> {
88        if self.dimensions == 0 {
89            return Err(invalid_argument(
90                "sparse vector dimensions must be greater than zero",
91            ));
92        }
93        if self.indices.len() != self.values.len() {
94            return Err(invalid_argument(
95                "sparse vector indices and values must have the same length",
96            ));
97        }
98        if self.values.iter().any(|value| !value.is_finite()) {
99            return Err(invalid_argument("sparse vector values must be finite"));
100        }
101        if self.indices.iter().any(|index| *index >= self.dimensions) {
102            return Err(invalid_argument("sparse vector index is out of bounds"));
103        }
104        Ok(())
105    }
106
107    /// Returns canonicalized.
108    pub fn canonicalized(&self) -> Result<Self> {
109        self.validate()?;
110        let mut pairs = self
111            .indices
112            .iter()
113            .copied()
114            .zip(self.values.iter().copied())
115            .collect::<Vec<_>>();
116        pairs.sort_by_key(|(index, _)| *index);
117        let mut indices = Vec::new();
118        let mut values = Vec::new();
119        for (index, value) in pairs {
120            if let Some(last) = indices.last().copied() {
121                if last == index {
122                    if let Some(last_value) = values.last_mut() {
123                        *last_value += value;
124                    }
125                    continue;
126                }
127            }
128            if value != 0.0 {
129                indices.push(index);
130                values.push(value);
131            }
132        }
133        Self::new(self.dimensions, indices, values)
134    }
135
136    /// Returns dot.
137    pub fn dot(&self, other: &Self) -> Result<f32> {
138        let left = self.canonicalized()?;
139        let right = other.canonicalized()?;
140        if left.dimensions != right.dimensions {
141            return Err(invalid_argument("sparse vector dimensions must match"));
142        }
143        let mut i = 0;
144        let mut j = 0;
145        let mut acc = 0.0;
146        while i < left.indices.len() && j < right.indices.len() {
147            match left.indices[i].cmp(&right.indices[j]) {
148                std::cmp::Ordering::Less => i += 1,
149                std::cmp::Ordering::Greater => j += 1,
150                std::cmp::Ordering::Equal => {
151                    acc += left.values[i] * right.values[j];
152                    i += 1;
153                    j += 1;
154                }
155            }
156        }
157        Ok(acc)
158    }
159
160    /// Returns cosine similarity.
161    pub fn cosine_similarity(&self, other: &Self) -> Result<f32> {
162        let left_norm = self
163            .values
164            .iter()
165            .map(|value| value * value)
166            .sum::<f32>()
167            .sqrt();
168        let right_norm = other
169            .values
170            .iter()
171            .map(|value| value * value)
172            .sum::<f32>()
173            .sqrt();
174        if left_norm <= f32::EPSILON || right_norm <= f32::EPSILON {
175            return Err(invalid_argument(
176                "cosine similarity requires non-zero sparse vectors",
177            ));
178        }
179        Ok(self.dot(other)? / (left_norm * right_norm))
180    }
181
182    /// Returns the L1 norm.
183    pub fn l1_norm(&self) -> Result<f32> {
184        self.validate()?;
185        Ok(self.values.iter().map(|value| value.abs()).sum())
186    }
187
188    /// Returns the L2 norm.
189    pub fn l2_norm(&self) -> Result<f32> {
190        self.validate()?;
191        Ok(self
192            .values
193            .iter()
194            .map(|value| value * value)
195            .sum::<f32>()
196            .sqrt())
197    }
198
199    /// Scales all sparse values by a finite factor.
200    pub fn scale(&self, factor: f32) -> Result<Self> {
201        self.validate()?;
202        if !factor.is_finite() {
203            return Err(invalid_argument(
204                "sparse vector scale factor must be finite",
205            ));
206        }
207        Self::new(
208            self.dimensions,
209            self.indices.clone(),
210            self.values.iter().map(|value| value * factor).collect(),
211        )?
212        .canonicalized()
213    }
214
215    /// Adds two sparse vectors with matching dimensions.
216    pub fn add(&self, other: &Self) -> Result<Self> {
217        let left = self.canonicalized()?;
218        let right = other.canonicalized()?;
219        if left.dimensions != right.dimensions {
220            return Err(invalid_argument("sparse vector dimensions must match"));
221        }
222        let mut indices = Vec::new();
223        let mut values = Vec::new();
224        let mut left_index = 0;
225        let mut right_index = 0;
226        while left_index < left.indices.len() || right_index < right.indices.len() {
227            match (
228                left.indices.get(left_index).copied(),
229                right.indices.get(right_index).copied(),
230            ) {
231                (Some(left_col), Some(right_col)) if left_col == right_col => {
232                    let value = left.values[left_index] + right.values[right_index];
233                    if value != 0.0 {
234                        indices.push(left_col);
235                        values.push(value);
236                    }
237                    left_index += 1;
238                    right_index += 1;
239                }
240                (Some(left_col), Some(right_col)) if left_col < right_col => {
241                    indices.push(left_col);
242                    values.push(left.values[left_index]);
243                    left_index += 1;
244                }
245                (Some(_), Some(right_col)) => {
246                    indices.push(right_col);
247                    values.push(right.values[right_index]);
248                    right_index += 1;
249                }
250                (Some(left_col), None) => {
251                    indices.push(left_col);
252                    values.push(left.values[left_index]);
253                    left_index += 1;
254                }
255                (None, Some(right_col)) => {
256                    indices.push(right_col);
257                    values.push(right.values[right_index]);
258                    right_index += 1;
259                }
260                (None, None) => break,
261            }
262        }
263        Self::new(left.dimensions, indices, values)
264    }
265
266    /// Returns the sparse Hadamard product, keeping only overlapping indices.
267    pub fn hadamard(&self, other: &Self) -> Result<Self> {
268        let left = self.canonicalized()?;
269        let right = other.canonicalized()?;
270        if left.dimensions != right.dimensions {
271            return Err(invalid_argument("sparse vector dimensions must match"));
272        }
273        let mut indices = Vec::new();
274        let mut values = Vec::new();
275        let mut i = 0;
276        let mut j = 0;
277        while i < left.indices.len() && j < right.indices.len() {
278            match left.indices[i].cmp(&right.indices[j]) {
279                std::cmp::Ordering::Less => i += 1,
280                std::cmp::Ordering::Greater => j += 1,
281                std::cmp::Ordering::Equal => {
282                    let value = left.values[i] * right.values[j];
283                    if value != 0.0 {
284                        indices.push(left.indices[i]);
285                        values.push(value);
286                    }
287                    i += 1;
288                    j += 1;
289                }
290            }
291        }
292        Self::new(left.dimensions, indices, values)
293    }
294
295    /// Removes entries whose absolute value is below a finite non-negative threshold.
296    pub fn prune_abs_below(&self, threshold: f32) -> Result<Self> {
297        let canonical = self.canonicalized()?;
298        if !threshold.is_finite() || threshold < 0.0 {
299            return Err(invalid_argument(
300                "sparse prune threshold must be finite and non-negative",
301            ));
302        }
303        let mut indices = Vec::new();
304        let mut values = Vec::new();
305        for (index, value) in canonical.indices.iter().copied().zip(canonical.values) {
306            if value.abs() >= threshold {
307                indices.push(index);
308                values.push(value);
309            }
310        }
311        Self::new(canonical.dimensions, indices, values)
312    }
313
314    /// Returns the top `k` entries sorted by descending absolute value, then index.
315    pub fn top_k_by_abs(&self, k: usize) -> Result<Vec<(usize, f32)>> {
316        let canonical = self.canonicalized()?;
317        let mut pairs = canonical
318            .indices
319            .into_iter()
320            .zip(canonical.values)
321            .collect::<Vec<_>>();
322        pairs.sort_by(|left, right| {
323            right
324                .1
325                .abs()
326                .partial_cmp(&left.1.abs())
327                .unwrap_or(std::cmp::Ordering::Equal)
328                .then_with(|| left.0.cmp(&right.0))
329        });
330        pairs.truncate(k);
331        Ok(pairs)
332    }
333
334    /// Returns normalize l2.
335    pub fn normalize_l2(&self) -> Result<Self> {
336        let norm = self
337            .values
338            .iter()
339            .map(|value| value * value)
340            .sum::<f32>()
341            .sqrt();
342        if norm <= f32::EPSILON {
343            return Err(invalid_argument(
344                "sparse vector norm must be greater than zero",
345            ));
346        }
347        Self::new(
348            self.dimensions,
349            self.indices.clone(),
350            self.values.iter().map(|value| value / norm).collect(),
351        )
352    }
353
354    /// Converts this value to dense.
355    pub fn to_dense(&self) -> Vec<f32> {
356        let mut dense = vec![0.0; self.dimensions];
357        for (&index, &value) in self.indices.iter().zip(&self.values) {
358            dense[index] = value;
359        }
360        dense
361    }
362
363    /// Builds this value from dense.
364    pub fn from_dense(values: &[f32]) -> Result<Self> {
365        if values.is_empty() {
366            return Err(invalid_argument("dense vector must not be empty"));
367        }
368        if values.iter().any(|value| !value.is_finite()) {
369            return Err(invalid_argument("dense vector values must be finite"));
370        }
371        let mut indices = Vec::new();
372        let mut sparse_values = Vec::new();
373        for (index, value) in values.iter().copied().enumerate() {
374            if value != 0.0 {
375                indices.push(index);
376                sparse_values.push(value);
377            }
378        }
379        Self::new(values.len(), indices, sparse_values)
380    }
381}
382
383impl TryFrom<&DenseVector> for SparseVector {
384    type Error = DetectError;
385
386    fn try_from(value: &DenseVector) -> Result<Self> {
387        Self::from_dense(value.as_slice())
388    }
389}
390
/// Coordinate-format (COO) sparse matrix stored as `(row, col, value)` triplets.
///
/// Triplets may appear in any order and may repeat a coordinate.
#[derive(Debug, Clone, PartialEq)]
pub struct CooMatrix {
    /// Row count.
    rows: usize,
    /// Column count.
    cols: usize,
    /// Stored `(row, col, value)` triplets.
    entries: Vec<(usize, usize, f32)>,
}
398
399impl CooMatrix {
400    /// Creates a new value.
401    pub fn new(rows: usize, cols: usize, entries: Vec<(usize, usize, f32)>) -> Result<Self> {
402        let matrix = Self {
403            rows,
404            cols,
405            entries,
406        };
407        matrix.validate()?;
408        Ok(matrix)
409    }
410
411    /// Returns rows.
412    pub fn rows(&self) -> usize {
413        self.rows
414    }
415
416    /// Returns cols.
417    pub fn cols(&self) -> usize {
418        self.cols
419    }
420
421    /// Returns entries.
422    pub fn entries(&self) -> &[(usize, usize, f32)] {
423        &self.entries
424    }
425
426    /// Returns nnz.
427    pub fn nnz(&self) -> usize {
428        self.entries.len()
429    }
430
431    /// Validates this value.
432    pub fn validate(&self) -> Result<()> {
433        if self.rows == 0 || self.cols == 0 {
434            return Err(invalid_argument(
435                "COO matrix rows and cols must be greater than zero",
436            ));
437        }
438        for &(row, col, value) in &self.entries {
439            if row >= self.rows || col >= self.cols {
440                return Err(invalid_argument("COO entry index is out of bounds"));
441            }
442            if !value.is_finite() {
443                return Err(invalid_argument("COO entry values must be finite"));
444            }
445        }
446        Ok(())
447    }
448
449    /// Returns canonicalized.
450    pub fn canonicalized(&self) -> Result<Self> {
451        self.validate()?;
452        let mut entries = self.entries.clone();
453        entries.sort_by_key(|(row, col, _)| (*row, *col));
454        let mut output = Vec::new();
455        for (row, col, value) in entries {
456            if let Some((last_row, last_col, last_value)) = output.last_mut() {
457                if *last_row == row && *last_col == col {
458                    *last_value += value;
459                    continue;
460                }
461            }
462            if value != 0.0 {
463                output.push((row, col, value));
464            }
465        }
466        Self::new(self.rows, self.cols, output)
467    }
468
469    /// Converts this value to csr.
470    pub fn to_csr(&self) -> Result<CsrMatrix> {
471        CsrMatrix::from_coo(self)
472    }
473
474    /// Returns a transposed COO matrix.
475    pub fn transpose(&self) -> Result<Self> {
476        self.validate()?;
477        Self::new(
478            self.cols,
479            self.rows,
480            self.entries
481                .iter()
482                .map(|(row, col, value)| (*col, *row, *value))
483                .collect(),
484        )
485        .and_then(|matrix| matrix.canonicalized())
486    }
487}
488
/// Compressed sparse row (CSR) matrix.
///
/// Row `r` owns positions `row_offsets[r]..row_offsets[r + 1]` of both `column_indices`
/// and `values`.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrMatrix {
    /// Row count.
    rows: usize,
    /// Column count.
    cols: usize,
    /// Start offset of each row, followed by the total entry count (`rows + 1` items).
    row_offsets: Vec<usize>,
    /// Column of each stored entry; parallel to `values`.
    column_indices: Vec<usize>,
    /// Stored entry values; parallel to `column_indices`.
    values: Vec<f32>,
}
498
499impl CsrMatrix {
500    /// Creates a new value.
501    pub fn new(
502        rows: usize,
503        cols: usize,
504        row_offsets: Vec<usize>,
505        column_indices: Vec<usize>,
506        values: Vec<f32>,
507    ) -> Result<Self> {
508        let matrix = Self {
509            rows,
510            cols,
511            row_offsets,
512            column_indices,
513            values,
514        };
515        matrix.validate()?;
516        Ok(matrix)
517    }
518
519    /// Builds this value from coo.
520    pub fn from_coo(coo: &CooMatrix) -> Result<Self> {
521        let canonical = coo.canonicalized()?;
522        let mut row_offsets = vec![0usize; canonical.rows + 1];
523        let mut column_indices = Vec::with_capacity(canonical.entries.len());
524        let mut values = Vec::with_capacity(canonical.entries.len());
525        let mut current_row = 0usize;
526        for (row, col, value) in canonical.entries {
527            while current_row < row {
528                row_offsets[current_row + 1] = column_indices.len();
529                current_row += 1;
530            }
531            column_indices.push(col);
532            values.push(value);
533        }
534        while current_row < canonical.rows {
535            row_offsets[current_row + 1] = column_indices.len();
536            current_row += 1;
537        }
538        Self::new(
539            canonical.rows,
540            canonical.cols,
541            row_offsets,
542            column_indices,
543            values,
544        )
545    }
546
547    /// Returns rows.
548    pub fn rows(&self) -> usize {
549        self.rows
550    }
551
552    /// Returns cols.
553    pub fn cols(&self) -> usize {
554        self.cols
555    }
556
557    /// Returns row.
558    pub fn row(&self, index: usize) -> Result<SparseRow<'_>> {
559        if index >= self.rows {
560            return Err(invalid_argument("CSR row index is out of bounds"));
561        }
562        let start = self.row_offsets[index];
563        let end = self.row_offsets[index + 1];
564        Ok(SparseRow {
565            cols: self.cols,
566            indices: &self.column_indices[start..end],
567            values: &self.values[start..end],
568        })
569    }
570
571    /// Returns rows iter.
572    pub fn rows_iter(&self) -> impl Iterator<Item = SparseRow<'_>> {
573        (0..self.rows).map(|index| self.row(index).expect("indices are validated"))
574    }
575
576    /// Returns the non-zero count for each row.
577    pub fn row_nnz(&self) -> Vec<usize> {
578        self.row_offsets
579            .windows(2)
580            .map(|window| window[1] - window[0])
581            .collect()
582    }
583
584    /// Returns stored-entry matrix density.
585    pub fn density(&self) -> Result<f32> {
586        self.validate()?;
587        let elements = self
588            .rows
589            .checked_mul(self.cols)
590            .ok_or_else(|| invalid_argument("CSR matrix element count overflowed usize"))?;
591        Ok(self.values.len() as f32 / elements as f32)
592    }
593
594    /// Returns the non-zero count for each column.
595    pub fn column_nnz(&self) -> Vec<usize> {
596        let mut counts = vec![0usize; self.cols];
597        for col in &self.column_indices {
598            if let Some(count) = counts.get_mut(*col) {
599                *count += 1;
600            }
601        }
602        counts
603    }
604
605    /// Returns sums of stored values by row.
606    pub fn row_sums(&self) -> Result<Vec<f32>> {
607        self.validate()?;
608        Ok(self
609            .rows_iter()
610            .map(|row| row.values().iter().sum::<f32>())
611            .collect())
612    }
613
614    /// Returns sums of stored values by column.
615    pub fn column_sums(&self) -> Result<Vec<f32>> {
616        self.validate()?;
617        let mut sums = vec![0.0; self.cols];
618        for (col, value) in self.column_indices.iter().zip(&self.values) {
619            sums[*col] += value;
620        }
621        Ok(sums)
622    }
623
624    /// Returns compact shape, density, row nnz, and column nnz statistics.
625    pub fn summary(&self) -> Result<SparseMatrixSummary> {
626        self.validate()?;
627        let row_nnz = self.row_nnz();
628        let column_nnz = self.column_nnz();
629        let row_nnz_min = row_nnz.iter().copied().min().unwrap_or(0);
630        let row_nnz_max = row_nnz.iter().copied().max().unwrap_or(0);
631        let column_nnz_min = column_nnz.iter().copied().min().unwrap_or(0);
632        let column_nnz_max = column_nnz.iter().copied().max().unwrap_or(0);
633        Ok(SparseMatrixSummary {
634            rows: self.rows,
635            cols: self.cols,
636            nnz: self.values.len(),
637            density: self.density()?,
638            row_nnz_min,
639            row_nnz_max,
640            row_nnz_mean: row_nnz.iter().sum::<usize>() as f32 / self.rows as f32,
641            column_nnz_min,
642            column_nnz_max,
643            column_nnz_mean: column_nnz.iter().sum::<usize>() as f32 / self.cols as f32,
644        })
645    }
646
647    /// Returns a CSR matrix whose non-zero rows have unit L2 norm.
648    pub fn l2_normalize_rows(&self) -> Result<Self> {
649        self.validate()?;
650        let mut values = self.values.clone();
651        for row in 0..self.rows {
652            let start = self.row_offsets[row];
653            let end = self.row_offsets[row + 1];
654            let norm = values[start..end]
655                .iter()
656                .map(|value| value * value)
657                .sum::<f32>()
658                .sqrt();
659            if norm > f32::EPSILON {
660                for value in &mut values[start..end] {
661                    *value /= norm;
662                }
663            }
664        }
665        Self::new(
666            self.rows,
667            self.cols,
668            self.row_offsets.clone(),
669            self.column_indices.clone(),
670            values,
671        )
672    }
673
674    /// Multiplies this CSR matrix by a dense finite matrix.
675    pub fn mul_dense_matrix(&self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
676        self.validate()?;
677        right.validate()?;
678        if self.cols != right.shape().rows {
679            return Err(invalid_argument(
680                "sparse matrix/dense matrix dimensions are incompatible",
681            ));
682        }
683        let shape = MatrixShape::new(self.rows, right.shape().cols)?;
684        let mut values = vec![0.0; shape.element_count()?];
685        for row in 0..self.rows {
686            for entry in self.row_offsets[row]..self.row_offsets[row + 1] {
687                let sparse_col = self.column_indices[entry];
688                let sparse_value = self.values[entry];
689                for col in 0..right.shape().cols {
690                    values[row * shape.cols + col] += sparse_value * right.get(sparse_col, col)?;
691                }
692            }
693        }
694        F32Matrix::new(shape, values)
695    }
696
697    /// Converts this CSR matrix into a row-major dense matrix.
698    pub fn to_dense_matrix(&self) -> Result<F32Matrix> {
699        self.validate()?;
700        let shape = MatrixShape::new(self.rows, self.cols)?;
701        let mut values = vec![0.0; shape.element_count()?];
702        for row in 0..self.rows {
703            for index in self.row_offsets[row]..self.row_offsets[row + 1] {
704                values[row * self.cols + self.column_indices[index]] = self.values[index];
705            }
706        }
707        F32Matrix::new(shape, values)
708    }
709
710    /// Multiplies this CSR matrix by a dense finite vector.
711    pub fn mul_dense_vector(&self, vector: &[f32]) -> Result<Vec<f32>> {
712        self.validate()?;
713        if vector.len() != self.cols {
714            return Err(invalid_argument(
715                "sparse matrix/vector dimensions are incompatible",
716            ));
717        }
718        if vector.iter().any(|value| !value.is_finite()) {
719            return Err(invalid_argument("dense vector values must be finite"));
720        }
721        let mut output = vec![0.0; self.rows];
722        for (row_index, row) in self.rows_iter().enumerate() {
723            output[row_index] = row
724                .indices()
725                .iter()
726                .zip(row.values())
727                .map(|(col, value)| vector[*col] * value)
728                .sum();
729        }
730        Ok(output)
731    }
732
733    /// Converts this CSR matrix to COO entries.
734    pub fn to_coo(&self) -> Result<CooMatrix> {
735        self.validate()?;
736        let mut entries = Vec::with_capacity(self.values.len());
737        for row in 0..self.rows {
738            for index in self.row_offsets[row]..self.row_offsets[row + 1] {
739                entries.push((row, self.column_indices[index], self.values[index]));
740            }
741        }
742        CooMatrix::new(self.rows, self.cols, entries)
743    }
744
745    /// Returns the matrix transpose.
746    pub fn transpose(&self) -> Result<Self> {
747        self.to_coo()?.transpose()?.to_csr()
748    }
749
750    /// Validates this value.
751    pub fn validate(&self) -> Result<()> {
752        if self.rows == 0 || self.cols == 0 {
753            return Err(invalid_argument(
754                "CSR matrix rows and cols must be greater than zero",
755            ));
756        }
757        if self.row_offsets.len() != self.rows + 1 {
758            return Err(invalid_argument(
759                "CSR row_offsets length must equal rows + 1",
760            ));
761        }
762        if self.column_indices.len() != self.values.len() {
763            return Err(invalid_argument(
764                "CSR column_indices and values must have the same length",
765            ));
766        }
767        if self.row_offsets.first().copied().unwrap_or_default() != 0 {
768            return Err(invalid_argument("CSR row_offsets must start at zero"));
769        }
770        if *self.row_offsets.last().unwrap_or(&0) != self.values.len() {
771            return Err(invalid_argument("CSR row_offsets must end at nnz"));
772        }
773        for window in self.row_offsets.windows(2) {
774            if window[0] > window[1] {
775                return Err(invalid_argument("CSR row_offsets must be non-decreasing"));
776            }
777        }
778        if self.column_indices.iter().any(|index| *index >= self.cols) {
779            return Err(invalid_argument("CSR column index is out of bounds"));
780        }
781        if self.values.iter().any(|value| !value.is_finite()) {
782            return Err(invalid_argument("CSR values must be finite"));
783        }
784        Ok(())
785    }
786}
787
/// Borrowed view of a single CSR row.
///
/// Holds the row's column indices and values, plus the parent matrix's column count.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseRow<'a> {
    /// Column count of the matrix the row was taken from.
    cols: usize,
    /// Column index of each stored entry in the row.
    indices: &'a [usize],
    /// Value of each stored entry in the row; parallel to `indices`.
    values: &'a [f32],
}
795
796impl<'a> SparseRow<'a> {
797    /// Returns cols.
798    pub fn cols(&self) -> usize {
799        self.cols
800    }
801
802    /// Returns indices.
803    pub fn indices(&self) -> &'a [usize] {
804        self.indices
805    }
806
807    /// Returns values.
808    pub fn values(&self) -> &'a [f32] {
809        self.values
810    }
811
812    /// Converts this value to sparse vector.
813    pub fn to_sparse_vector(&self) -> Result<SparseVector> {
814        SparseVector::new(self.cols, self.indices.to_vec(), self.values.to_vec())
815    }
816}
817
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CSR matrix from COO triplets; panics if the fixture is invalid.
    fn csr_from(rows: usize, cols: usize, entries: Vec<(usize, usize, f32)>) -> CsrMatrix {
        CooMatrix::new(rows, cols, entries)
            .and_then(|coo| coo.to_csr())
            .expect("test fixture must be a valid sparse matrix")
    }

    /// Asserts that `actual` lies strictly within `1.0e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1.0e-6,
            "expected {} to be within 1e-6 of {}",
            actual,
            expected
        );
    }

    #[test]
    fn sparse_vector_canonicalization_and_similarity_work() {
        let raw = SparseVector::new(4, vec![3, 1, 3], vec![2.0, 1.0, 1.0]).unwrap();
        let canonical = raw.canonicalized().unwrap();

        // Index 3 appears twice and is summed (2.0 + 1.0); indices come back sorted.
        assert_eq!(canonical.indices(), &[1, 3]);
        assert_eq!(canonical.values(), &[1.0, 3.0]);
        assert_eq!(canonical.dot(&canonical).unwrap(), 10.0);
        assert_close(canonical.cosine_similarity(&canonical).unwrap(), 1.0);
    }

    #[test]
    fn sparse_dot_matches_dense_dot() {
        let left = SparseVector::new(5, vec![0, 3, 4], vec![1.5, -2.0, 3.0]).unwrap();
        let right = SparseVector::new(5, vec![1, 3, 4], vec![8.0, 4.0, -1.0]).unwrap();
        let left_dense = left.to_dense();
        let right_dense = right.to_dense();
        let dense_dot: f32 = left_dense
            .iter()
            .zip(&right_dense)
            .map(|(a, b)| a * b)
            .sum();

        assert_eq!(left.dot(&right).unwrap(), dense_dot);
    }

    #[test]
    fn csr_and_coo_invariants_hold() {
        let csr = csr_from(2, 3, vec![(1, 2, 2.0), (0, 0, 1.0), (1, 2, 1.0)]);

        assert_eq!(csr.row(0).unwrap().indices(), &[0]);
        // The two (1, 2) triplets merge into one stored 3.0.
        assert_eq!(csr.row(1).unwrap().values(), &[3.0]);
    }

    #[test]
    fn coo_csr_round_trip_preserves_canonical_entries() {
        let entries = vec![(2, 1, 1.0), (0, 2, 5.0), (2, 1, 2.0), (1, 0, 0.0)];
        let canonical = CooMatrix::new(3, 3, entries)
            .unwrap()
            .canonicalized()
            .unwrap();
        let round_trip = canonical.to_csr().and_then(|csr| csr.to_coo()).unwrap();

        assert_eq!(round_trip.entries(), canonical.entries());
    }

    #[test]
    fn dense_sparse_round_trip_preserves_values() {
        let dense = [0.0, 1.0, 0.0, 2.0];

        assert_eq!(SparseVector::from_dense(&dense).unwrap().to_dense(), dense);
    }

    #[test]
    fn vector_ops_and_sparse_matrix_transpose_work() {
        let left = SparseVector::new(4, vec![0, 2], vec![1.0, -3.0]).unwrap();
        let right = SparseVector::new(4, vec![2, 3], vec![1.0, 2.0]).unwrap();

        let sum = left.add(&right).unwrap();
        assert_eq!(sum.indices(), &[0, 2, 3]);
        assert_eq!(sum.values(), &[1.0, -2.0, 2.0]);
        assert_eq!(left.top_k_by_abs(1).unwrap(), vec![(2, -3.0)]);

        let matrix = csr_from(2, 3, vec![(0, 1, 2.0), (1, 2, 3.0)]);
        assert_eq!(matrix.row_nnz(), vec![1, 1]);
        assert_eq!(
            matrix.mul_dense_vector(&[1.0, 2.0, 3.0]).unwrap(),
            vec![4.0, 9.0]
        );

        let transposed = matrix.transpose().unwrap();
        assert_eq!((transposed.rows(), transposed.cols()), (3, 2));
        let back = transposed.transpose().unwrap().to_coo().unwrap();
        assert_eq!(back.entries(), matrix.to_coo().unwrap().entries());
    }

    #[test]
    fn matrix_summary_reports_density_and_nnz_stats() {
        let matrix = csr_from(3, 4, vec![(0, 1, 2.0), (1, 3, 4.0), (2, 1, -1.0)]);
        let summary = matrix.summary().unwrap();

        assert_eq!((summary.rows, summary.cols, summary.nnz), (3, 4, 3));
        assert_close(summary.density, 0.25);
        assert_eq!((summary.row_nnz_min, summary.row_nnz_max), (1, 1));
        assert_eq!((summary.column_nnz_min, summary.column_nnz_max), (0, 2));
        assert_eq!(matrix.column_nnz(), vec![0, 2, 0, 1]);
        assert_eq!(matrix.row_sums().unwrap(), vec![2.0, 4.0, -1.0]);
        assert_eq!(matrix.column_sums().unwrap(), vec![0.0, 1.0, 0.0, 4.0]);
    }

    #[test]
    fn row_normalization_unit_norms_non_zero_rows() {
        let matrix = csr_from(3, 3, vec![(0, 0, 3.0), (0, 1, 4.0), (2, 2, 5.0)]);
        let normalized = matrix.l2_normalize_rows().unwrap();

        let first = normalized.row(0).unwrap().values();
        assert_close(first[0], 0.6);
        assert_close(first[1], 0.8);
        assert!(normalized.row(1).unwrap().values().is_empty());
        assert_close(normalized.row(2).unwrap().values()[0], 1.0);
    }

    #[test]
    fn sparse_dense_matrix_multiply_matches_dense_result() {
        let sparse = csr_from(2, 3, vec![(0, 1, 2.0), (1, 0, 1.0), (1, 2, 3.0)]);
        let right = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).unwrap();

        let product = sparse.mul_dense_matrix(&right.as_view()).unwrap();
        assert_eq!(product.values(), &[6.0, 8.0, 16.0, 20.0]);
    }

    #[test]
    fn dense_matrix_conversion_round_trips_through_coo_csr() {
        let coo = CooMatrix::new(2, 3, vec![(0, 1, 2.0), (1, 2, 3.0)]).unwrap();
        let csr = coo.to_csr().unwrap();

        assert_eq!(
            csr.to_dense_matrix().unwrap().values(),
            &[0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );
        assert_eq!(csr.to_coo().unwrap().entries(), coo.entries());
    }

    #[test]
    fn hadamard_keeps_only_overlapping_indices() {
        let left = SparseVector::new(5, vec![0, 2, 4], vec![1.0, 2.0, 3.0]).unwrap();
        let right = SparseVector::new(5, vec![1, 2, 4], vec![5.0, 7.0, 11.0]).unwrap();

        let product = left.hadamard(&right).unwrap();
        assert_eq!(product.indices(), &[2, 4]);
        assert_eq!(product.values(), &[14.0, 33.0]);
    }

    #[test]
    fn pruning_removes_small_values_and_rejects_invalid_thresholds() {
        let vector = SparseVector::new(4, vec![0, 1, 2], vec![0.01, -0.5, 2.0]).unwrap();
        let pruned = vector.prune_abs_below(0.1).unwrap();

        assert_eq!(pruned.indices(), &[1, 2]);
        assert_eq!(pruned.values(), &[-0.5, 2.0]);
        for bad_threshold in [-0.1, f32::NAN] {
            assert!(vector.prune_abs_below(bad_threshold).is_err());
        }
    }
}