1#![doc = include_str!("../README.md")]
2
3pub mod surface;
4use math_linear::{F32Matrix, F32MatrixView, MatrixShape};
5use vector_analysis_core::DenseVector;
6use video_analysis_core::{DetectError, Result};
7
8fn invalid_argument(message: impl Into<String>) -> DetectError {
9 DetectError::InvalidArgument(message.into())
10}
11
/// Selector for a sparse-vector similarity measure.
///
/// NOTE(review): not referenced in this file; presumably chooses between
/// [`SparseVector::dot`] and [`SparseVector::cosine_similarity`] — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseSimilarityMetric {
    /// Raw inner product (presumably maps to [`SparseVector::dot`]).
    Dot,
    /// Length-normalized inner product (presumably maps to
    /// [`SparseVector::cosine_similarity`]).
    Cosine,
}
20
/// A sparse vector stored as parallel index/value arrays.
///
/// The fields are private, so outside this module instances are created through
/// [`SparseVector::new`], which enforces: `dimensions > 0`,
/// `indices.len() == values.len()`, every index `< dimensions`, and every value finite.
/// Indices may be unsorted or repeated and explicit zeros may be stored; see
/// [`SparseVector::canonicalized`] for the normalized form.
///
/// The derived `PartialEq` compares the stored representation, so two vectors with the same
/// mathematical value but a different entry order compare unequal unless both are
/// canonicalized first.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseVector {
    // Length of the equivalent dense vector.
    dimensions: usize,
    // Component positions, in storage order.
    indices: Vec<usize>,
    // Component values, parallel to `indices`.
    values: Vec<f32>,
}
28
/// Structural statistics of a CSR matrix, as produced by [`CsrMatrix::summary`].
///
/// All counts are of *stored* entries, so explicit zeros or duplicate column entries in a
/// hand-built CSR matrix are counted as stored.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseMatrixSummary {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Number of stored entries.
    pub nnz: usize,
    /// `nnz / (rows * cols)`.
    pub density: f32,
    /// Fewest stored entries in any row.
    pub row_nnz_min: usize,
    /// Most stored entries in any row.
    pub row_nnz_max: usize,
    /// Average number of stored entries per row.
    pub row_nnz_mean: f32,
    /// Fewest stored entries in any column.
    pub column_nnz_min: usize,
    /// Most stored entries in any column.
    pub column_nnz_max: usize,
    /// Average number of stored entries per column.
    pub column_nnz_mean: f32,
}
53
54impl SparseVector {
55 pub fn new(dimensions: usize, indices: Vec<usize>, values: Vec<f32>) -> Result<Self> {
57 let vector = Self {
58 dimensions,
59 indices,
60 values,
61 };
62 vector.validate()?;
63 Ok(vector)
64 }
65
66 pub fn dimensions(&self) -> usize {
68 self.dimensions
69 }
70
71 pub fn indices(&self) -> &[usize] {
73 &self.indices
74 }
75
76 pub fn values(&self) -> &[f32] {
78 &self.values
79 }
80
81 pub fn nnz(&self) -> usize {
83 self.indices.len()
84 }
85
86 pub fn validate(&self) -> Result<()> {
88 if self.dimensions == 0 {
89 return Err(invalid_argument(
90 "sparse vector dimensions must be greater than zero",
91 ));
92 }
93 if self.indices.len() != self.values.len() {
94 return Err(invalid_argument(
95 "sparse vector indices and values must have the same length",
96 ));
97 }
98 if self.values.iter().any(|value| !value.is_finite()) {
99 return Err(invalid_argument("sparse vector values must be finite"));
100 }
101 if self.indices.iter().any(|index| *index >= self.dimensions) {
102 return Err(invalid_argument("sparse vector index is out of bounds"));
103 }
104 Ok(())
105 }
106
107 pub fn canonicalized(&self) -> Result<Self> {
109 self.validate()?;
110 let mut pairs = self
111 .indices
112 .iter()
113 .copied()
114 .zip(self.values.iter().copied())
115 .collect::<Vec<_>>();
116 pairs.sort_by_key(|(index, _)| *index);
117 let mut indices = Vec::new();
118 let mut values = Vec::new();
119 for (index, value) in pairs {
120 if let Some(last) = indices.last().copied() {
121 if last == index {
122 if let Some(last_value) = values.last_mut() {
123 *last_value += value;
124 }
125 continue;
126 }
127 }
128 if value != 0.0 {
129 indices.push(index);
130 values.push(value);
131 }
132 }
133 Self::new(self.dimensions, indices, values)
134 }
135
136 pub fn dot(&self, other: &Self) -> Result<f32> {
138 let left = self.canonicalized()?;
139 let right = other.canonicalized()?;
140 if left.dimensions != right.dimensions {
141 return Err(invalid_argument("sparse vector dimensions must match"));
142 }
143 let mut i = 0;
144 let mut j = 0;
145 let mut acc = 0.0;
146 while i < left.indices.len() && j < right.indices.len() {
147 match left.indices[i].cmp(&right.indices[j]) {
148 std::cmp::Ordering::Less => i += 1,
149 std::cmp::Ordering::Greater => j += 1,
150 std::cmp::Ordering::Equal => {
151 acc += left.values[i] * right.values[j];
152 i += 1;
153 j += 1;
154 }
155 }
156 }
157 Ok(acc)
158 }
159
160 pub fn cosine_similarity(&self, other: &Self) -> Result<f32> {
162 let left_norm = self
163 .values
164 .iter()
165 .map(|value| value * value)
166 .sum::<f32>()
167 .sqrt();
168 let right_norm = other
169 .values
170 .iter()
171 .map(|value| value * value)
172 .sum::<f32>()
173 .sqrt();
174 if left_norm <= f32::EPSILON || right_norm <= f32::EPSILON {
175 return Err(invalid_argument(
176 "cosine similarity requires non-zero sparse vectors",
177 ));
178 }
179 Ok(self.dot(other)? / (left_norm * right_norm))
180 }
181
182 pub fn l1_norm(&self) -> Result<f32> {
184 self.validate()?;
185 Ok(self.values.iter().map(|value| value.abs()).sum())
186 }
187
188 pub fn l2_norm(&self) -> Result<f32> {
190 self.validate()?;
191 Ok(self
192 .values
193 .iter()
194 .map(|value| value * value)
195 .sum::<f32>()
196 .sqrt())
197 }
198
199 pub fn scale(&self, factor: f32) -> Result<Self> {
201 self.validate()?;
202 if !factor.is_finite() {
203 return Err(invalid_argument(
204 "sparse vector scale factor must be finite",
205 ));
206 }
207 Self::new(
208 self.dimensions,
209 self.indices.clone(),
210 self.values.iter().map(|value| value * factor).collect(),
211 )?
212 .canonicalized()
213 }
214
215 pub fn add(&self, other: &Self) -> Result<Self> {
217 let left = self.canonicalized()?;
218 let right = other.canonicalized()?;
219 if left.dimensions != right.dimensions {
220 return Err(invalid_argument("sparse vector dimensions must match"));
221 }
222 let mut indices = Vec::new();
223 let mut values = Vec::new();
224 let mut left_index = 0;
225 let mut right_index = 0;
226 while left_index < left.indices.len() || right_index < right.indices.len() {
227 match (
228 left.indices.get(left_index).copied(),
229 right.indices.get(right_index).copied(),
230 ) {
231 (Some(left_col), Some(right_col)) if left_col == right_col => {
232 let value = left.values[left_index] + right.values[right_index];
233 if value != 0.0 {
234 indices.push(left_col);
235 values.push(value);
236 }
237 left_index += 1;
238 right_index += 1;
239 }
240 (Some(left_col), Some(right_col)) if left_col < right_col => {
241 indices.push(left_col);
242 values.push(left.values[left_index]);
243 left_index += 1;
244 }
245 (Some(_), Some(right_col)) => {
246 indices.push(right_col);
247 values.push(right.values[right_index]);
248 right_index += 1;
249 }
250 (Some(left_col), None) => {
251 indices.push(left_col);
252 values.push(left.values[left_index]);
253 left_index += 1;
254 }
255 (None, Some(right_col)) => {
256 indices.push(right_col);
257 values.push(right.values[right_index]);
258 right_index += 1;
259 }
260 (None, None) => break,
261 }
262 }
263 Self::new(left.dimensions, indices, values)
264 }
265
266 pub fn hadamard(&self, other: &Self) -> Result<Self> {
268 let left = self.canonicalized()?;
269 let right = other.canonicalized()?;
270 if left.dimensions != right.dimensions {
271 return Err(invalid_argument("sparse vector dimensions must match"));
272 }
273 let mut indices = Vec::new();
274 let mut values = Vec::new();
275 let mut i = 0;
276 let mut j = 0;
277 while i < left.indices.len() && j < right.indices.len() {
278 match left.indices[i].cmp(&right.indices[j]) {
279 std::cmp::Ordering::Less => i += 1,
280 std::cmp::Ordering::Greater => j += 1,
281 std::cmp::Ordering::Equal => {
282 let value = left.values[i] * right.values[j];
283 if value != 0.0 {
284 indices.push(left.indices[i]);
285 values.push(value);
286 }
287 i += 1;
288 j += 1;
289 }
290 }
291 }
292 Self::new(left.dimensions, indices, values)
293 }
294
295 pub fn prune_abs_below(&self, threshold: f32) -> Result<Self> {
297 let canonical = self.canonicalized()?;
298 if !threshold.is_finite() || threshold < 0.0 {
299 return Err(invalid_argument(
300 "sparse prune threshold must be finite and non-negative",
301 ));
302 }
303 let mut indices = Vec::new();
304 let mut values = Vec::new();
305 for (index, value) in canonical.indices.iter().copied().zip(canonical.values) {
306 if value.abs() >= threshold {
307 indices.push(index);
308 values.push(value);
309 }
310 }
311 Self::new(canonical.dimensions, indices, values)
312 }
313
314 pub fn top_k_by_abs(&self, k: usize) -> Result<Vec<(usize, f32)>> {
316 let canonical = self.canonicalized()?;
317 let mut pairs = canonical
318 .indices
319 .into_iter()
320 .zip(canonical.values)
321 .collect::<Vec<_>>();
322 pairs.sort_by(|left, right| {
323 right
324 .1
325 .abs()
326 .partial_cmp(&left.1.abs())
327 .unwrap_or(std::cmp::Ordering::Equal)
328 .then_with(|| left.0.cmp(&right.0))
329 });
330 pairs.truncate(k);
331 Ok(pairs)
332 }
333
334 pub fn normalize_l2(&self) -> Result<Self> {
336 let norm = self
337 .values
338 .iter()
339 .map(|value| value * value)
340 .sum::<f32>()
341 .sqrt();
342 if norm <= f32::EPSILON {
343 return Err(invalid_argument(
344 "sparse vector norm must be greater than zero",
345 ));
346 }
347 Self::new(
348 self.dimensions,
349 self.indices.clone(),
350 self.values.iter().map(|value| value / norm).collect(),
351 )
352 }
353
354 pub fn to_dense(&self) -> Vec<f32> {
356 let mut dense = vec![0.0; self.dimensions];
357 for (&index, &value) in self.indices.iter().zip(&self.values) {
358 dense[index] = value;
359 }
360 dense
361 }
362
363 pub fn from_dense(values: &[f32]) -> Result<Self> {
365 if values.is_empty() {
366 return Err(invalid_argument("dense vector must not be empty"));
367 }
368 if values.iter().any(|value| !value.is_finite()) {
369 return Err(invalid_argument("dense vector values must be finite"));
370 }
371 let mut indices = Vec::new();
372 let mut sparse_values = Vec::new();
373 for (index, value) in values.iter().copied().enumerate() {
374 if value != 0.0 {
375 indices.push(index);
376 sparse_values.push(value);
377 }
378 }
379 Self::new(values.len(), indices, sparse_values)
380 }
381}
382
383impl TryFrom<&DenseVector> for SparseVector {
384 type Error = DetectError;
385
386 fn try_from(value: &DenseVector) -> Result<Self> {
387 Self::from_dense(value.as_slice())
388 }
389}
390
/// A sparse matrix in coordinate (COO) form: a list of `(row, col, value)` triplets.
///
/// The fields are private; [`CooMatrix::new`] guarantees non-zero dimensions, in-bounds
/// coordinates and finite values. Triplets may be unsorted or repeated;
/// [`CooMatrix::canonicalized`] sorts them and sums repeats.
#[derive(Debug, Clone, PartialEq)]
pub struct CooMatrix {
    // Number of rows.
    rows: usize,
    // Number of columns.
    cols: usize,
    // `(row, col, value)` triplets in storage order.
    entries: Vec<(usize, usize, f32)>,
}
398
399impl CooMatrix {
400 pub fn new(rows: usize, cols: usize, entries: Vec<(usize, usize, f32)>) -> Result<Self> {
402 let matrix = Self {
403 rows,
404 cols,
405 entries,
406 };
407 matrix.validate()?;
408 Ok(matrix)
409 }
410
411 pub fn rows(&self) -> usize {
413 self.rows
414 }
415
416 pub fn cols(&self) -> usize {
418 self.cols
419 }
420
421 pub fn entries(&self) -> &[(usize, usize, f32)] {
423 &self.entries
424 }
425
426 pub fn nnz(&self) -> usize {
428 self.entries.len()
429 }
430
431 pub fn validate(&self) -> Result<()> {
433 if self.rows == 0 || self.cols == 0 {
434 return Err(invalid_argument(
435 "COO matrix rows and cols must be greater than zero",
436 ));
437 }
438 for &(row, col, value) in &self.entries {
439 if row >= self.rows || col >= self.cols {
440 return Err(invalid_argument("COO entry index is out of bounds"));
441 }
442 if !value.is_finite() {
443 return Err(invalid_argument("COO entry values must be finite"));
444 }
445 }
446 Ok(())
447 }
448
449 pub fn canonicalized(&self) -> Result<Self> {
451 self.validate()?;
452 let mut entries = self.entries.clone();
453 entries.sort_by_key(|(row, col, _)| (*row, *col));
454 let mut output = Vec::new();
455 for (row, col, value) in entries {
456 if let Some((last_row, last_col, last_value)) = output.last_mut() {
457 if *last_row == row && *last_col == col {
458 *last_value += value;
459 continue;
460 }
461 }
462 if value != 0.0 {
463 output.push((row, col, value));
464 }
465 }
466 Self::new(self.rows, self.cols, output)
467 }
468
469 pub fn to_csr(&self) -> Result<CsrMatrix> {
471 CsrMatrix::from_coo(self)
472 }
473
474 pub fn transpose(&self) -> Result<Self> {
476 self.validate()?;
477 Self::new(
478 self.cols,
479 self.rows,
480 self.entries
481 .iter()
482 .map(|(row, col, value)| (*col, *row, *value))
483 .collect(),
484 )
485 .and_then(|matrix| matrix.canonicalized())
486 }
487}
488
/// A sparse matrix in compressed sparse row (CSR) form.
///
/// Row `r`'s entries live at `row_offsets[r]..row_offsets[r + 1]` in `column_indices` /
/// `values`. The fields are private; [`CsrMatrix::new`] validates the structure.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrMatrix {
    // Number of rows.
    rows: usize,
    // Number of columns.
    cols: usize,
    // `rows + 1` non-decreasing offsets, starting at 0 and ending at nnz.
    row_offsets: Vec<usize>,
    // Column of each stored entry.
    column_indices: Vec<usize>,
    // Value of each stored entry, parallel to `column_indices`.
    values: Vec<f32>,
}
498
499impl CsrMatrix {
500 pub fn new(
502 rows: usize,
503 cols: usize,
504 row_offsets: Vec<usize>,
505 column_indices: Vec<usize>,
506 values: Vec<f32>,
507 ) -> Result<Self> {
508 let matrix = Self {
509 rows,
510 cols,
511 row_offsets,
512 column_indices,
513 values,
514 };
515 matrix.validate()?;
516 Ok(matrix)
517 }
518
519 pub fn from_coo(coo: &CooMatrix) -> Result<Self> {
521 let canonical = coo.canonicalized()?;
522 let mut row_offsets = vec![0usize; canonical.rows + 1];
523 let mut column_indices = Vec::with_capacity(canonical.entries.len());
524 let mut values = Vec::with_capacity(canonical.entries.len());
525 let mut current_row = 0usize;
526 for (row, col, value) in canonical.entries {
527 while current_row < row {
528 row_offsets[current_row + 1] = column_indices.len();
529 current_row += 1;
530 }
531 column_indices.push(col);
532 values.push(value);
533 }
534 while current_row < canonical.rows {
535 row_offsets[current_row + 1] = column_indices.len();
536 current_row += 1;
537 }
538 Self::new(
539 canonical.rows,
540 canonical.cols,
541 row_offsets,
542 column_indices,
543 values,
544 )
545 }
546
547 pub fn rows(&self) -> usize {
549 self.rows
550 }
551
552 pub fn cols(&self) -> usize {
554 self.cols
555 }
556
557 pub fn row(&self, index: usize) -> Result<SparseRow<'_>> {
559 if index >= self.rows {
560 return Err(invalid_argument("CSR row index is out of bounds"));
561 }
562 let start = self.row_offsets[index];
563 let end = self.row_offsets[index + 1];
564 Ok(SparseRow {
565 cols: self.cols,
566 indices: &self.column_indices[start..end],
567 values: &self.values[start..end],
568 })
569 }
570
571 pub fn rows_iter(&self) -> impl Iterator<Item = SparseRow<'_>> {
573 (0..self.rows).map(|index| self.row(index).expect("indices are validated"))
574 }
575
576 pub fn row_nnz(&self) -> Vec<usize> {
578 self.row_offsets
579 .windows(2)
580 .map(|window| window[1] - window[0])
581 .collect()
582 }
583
584 pub fn density(&self) -> Result<f32> {
586 self.validate()?;
587 let elements = self
588 .rows
589 .checked_mul(self.cols)
590 .ok_or_else(|| invalid_argument("CSR matrix element count overflowed usize"))?;
591 Ok(self.values.len() as f32 / elements as f32)
592 }
593
594 pub fn column_nnz(&self) -> Vec<usize> {
596 let mut counts = vec![0usize; self.cols];
597 for col in &self.column_indices {
598 if let Some(count) = counts.get_mut(*col) {
599 *count += 1;
600 }
601 }
602 counts
603 }
604
605 pub fn row_sums(&self) -> Result<Vec<f32>> {
607 self.validate()?;
608 Ok(self
609 .rows_iter()
610 .map(|row| row.values().iter().sum::<f32>())
611 .collect())
612 }
613
614 pub fn column_sums(&self) -> Result<Vec<f32>> {
616 self.validate()?;
617 let mut sums = vec![0.0; self.cols];
618 for (col, value) in self.column_indices.iter().zip(&self.values) {
619 sums[*col] += value;
620 }
621 Ok(sums)
622 }
623
624 pub fn summary(&self) -> Result<SparseMatrixSummary> {
626 self.validate()?;
627 let row_nnz = self.row_nnz();
628 let column_nnz = self.column_nnz();
629 let row_nnz_min = row_nnz.iter().copied().min().unwrap_or(0);
630 let row_nnz_max = row_nnz.iter().copied().max().unwrap_or(0);
631 let column_nnz_min = column_nnz.iter().copied().min().unwrap_or(0);
632 let column_nnz_max = column_nnz.iter().copied().max().unwrap_or(0);
633 Ok(SparseMatrixSummary {
634 rows: self.rows,
635 cols: self.cols,
636 nnz: self.values.len(),
637 density: self.density()?,
638 row_nnz_min,
639 row_nnz_max,
640 row_nnz_mean: row_nnz.iter().sum::<usize>() as f32 / self.rows as f32,
641 column_nnz_min,
642 column_nnz_max,
643 column_nnz_mean: column_nnz.iter().sum::<usize>() as f32 / self.cols as f32,
644 })
645 }
646
647 pub fn l2_normalize_rows(&self) -> Result<Self> {
649 self.validate()?;
650 let mut values = self.values.clone();
651 for row in 0..self.rows {
652 let start = self.row_offsets[row];
653 let end = self.row_offsets[row + 1];
654 let norm = values[start..end]
655 .iter()
656 .map(|value| value * value)
657 .sum::<f32>()
658 .sqrt();
659 if norm > f32::EPSILON {
660 for value in &mut values[start..end] {
661 *value /= norm;
662 }
663 }
664 }
665 Self::new(
666 self.rows,
667 self.cols,
668 self.row_offsets.clone(),
669 self.column_indices.clone(),
670 values,
671 )
672 }
673
674 pub fn mul_dense_matrix(&self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
676 self.validate()?;
677 right.validate()?;
678 if self.cols != right.shape().rows {
679 return Err(invalid_argument(
680 "sparse matrix/dense matrix dimensions are incompatible",
681 ));
682 }
683 let shape = MatrixShape::new(self.rows, right.shape().cols)?;
684 let mut values = vec![0.0; shape.element_count()?];
685 for row in 0..self.rows {
686 for entry in self.row_offsets[row]..self.row_offsets[row + 1] {
687 let sparse_col = self.column_indices[entry];
688 let sparse_value = self.values[entry];
689 for col in 0..right.shape().cols {
690 values[row * shape.cols + col] += sparse_value * right.get(sparse_col, col)?;
691 }
692 }
693 }
694 F32Matrix::new(shape, values)
695 }
696
697 pub fn to_dense_matrix(&self) -> Result<F32Matrix> {
699 self.validate()?;
700 let shape = MatrixShape::new(self.rows, self.cols)?;
701 let mut values = vec![0.0; shape.element_count()?];
702 for row in 0..self.rows {
703 for index in self.row_offsets[row]..self.row_offsets[row + 1] {
704 values[row * self.cols + self.column_indices[index]] = self.values[index];
705 }
706 }
707 F32Matrix::new(shape, values)
708 }
709
710 pub fn mul_dense_vector(&self, vector: &[f32]) -> Result<Vec<f32>> {
712 self.validate()?;
713 if vector.len() != self.cols {
714 return Err(invalid_argument(
715 "sparse matrix/vector dimensions are incompatible",
716 ));
717 }
718 if vector.iter().any(|value| !value.is_finite()) {
719 return Err(invalid_argument("dense vector values must be finite"));
720 }
721 let mut output = vec![0.0; self.rows];
722 for (row_index, row) in self.rows_iter().enumerate() {
723 output[row_index] = row
724 .indices()
725 .iter()
726 .zip(row.values())
727 .map(|(col, value)| vector[*col] * value)
728 .sum();
729 }
730 Ok(output)
731 }
732
733 pub fn to_coo(&self) -> Result<CooMatrix> {
735 self.validate()?;
736 let mut entries = Vec::with_capacity(self.values.len());
737 for row in 0..self.rows {
738 for index in self.row_offsets[row]..self.row_offsets[row + 1] {
739 entries.push((row, self.column_indices[index], self.values[index]));
740 }
741 }
742 CooMatrix::new(self.rows, self.cols, entries)
743 }
744
745 pub fn transpose(&self) -> Result<Self> {
747 self.to_coo()?.transpose()?.to_csr()
748 }
749
750 pub fn validate(&self) -> Result<()> {
752 if self.rows == 0 || self.cols == 0 {
753 return Err(invalid_argument(
754 "CSR matrix rows and cols must be greater than zero",
755 ));
756 }
757 if self.row_offsets.len() != self.rows + 1 {
758 return Err(invalid_argument(
759 "CSR row_offsets length must equal rows + 1",
760 ));
761 }
762 if self.column_indices.len() != self.values.len() {
763 return Err(invalid_argument(
764 "CSR column_indices and values must have the same length",
765 ));
766 }
767 if self.row_offsets.first().copied().unwrap_or_default() != 0 {
768 return Err(invalid_argument("CSR row_offsets must start at zero"));
769 }
770 if *self.row_offsets.last().unwrap_or(&0) != self.values.len() {
771 return Err(invalid_argument("CSR row_offsets must end at nnz"));
772 }
773 for window in self.row_offsets.windows(2) {
774 if window[0] > window[1] {
775 return Err(invalid_argument("CSR row_offsets must be non-decreasing"));
776 }
777 }
778 if self.column_indices.iter().any(|index| *index >= self.cols) {
779 return Err(invalid_argument("CSR column index is out of bounds"));
780 }
781 if self.values.iter().any(|value| !value.is_finite()) {
782 return Err(invalid_argument("CSR values must be finite"));
783 }
784 Ok(())
785 }
786}
787
/// Borrowed view of one row of a [`CsrMatrix`], as returned by [`CsrMatrix::row`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SparseRow<'a> {
    // Column count of the parent matrix (the row's dense length).
    cols: usize,
    // Column indices of the row's stored entries.
    indices: &'a [usize],
    // Values of the row's stored entries, parallel to `indices`.
    values: &'a [f32],
}
795
796impl<'a> SparseRow<'a> {
797 pub fn cols(&self) -> usize {
799 self.cols
800 }
801
802 pub fn indices(&self) -> &'a [usize] {
804 self.indices
805 }
806
807 pub fn values(&self) -> &'a [f32] {
809 self.values
810 }
811
812 pub fn to_sparse_vector(&self) -> Result<SparseVector> {
814 SparseVector::new(self.cols, self.indices.to_vec(), self.values.to_vec())
815 }
816}
817
// Unit tests pinning the canonicalization, conversion and arithmetic behavior of the
// sparse vector / COO / CSR types defined above.
#[cfg(test)]
mod tests {
    use super::*;

    // Repeated index 3 (2.0 + 1.0) must merge into one entry; a vector's cosine similarity
    // with itself is 1.
    #[test]
    fn sparse_vector_canonicalization_and_similarity_work() {
        let vector = SparseVector::new(4, vec![3, 1, 3], vec![2.0, 1.0, 1.0])
            .unwrap()
            .canonicalized()
            .unwrap();
        assert_eq!(vector.indices(), &[1, 3]);
        assert_eq!(vector.values(), &[1.0, 3.0]);
        assert_eq!(vector.dot(&vector).unwrap(), 10.0);
        assert!((vector.cosine_similarity(&vector).unwrap() - 1.0).abs() < 1.0e-6);
    }

    // The sparse merge-based dot product must equal the dense element-wise product sum.
    #[test]
    fn sparse_dot_matches_dense_dot() {
        let left = SparseVector::new(5, vec![0, 3, 4], vec![1.5, -2.0, 3.0]).unwrap();
        let right = SparseVector::new(5, vec![1, 3, 4], vec![8.0, 4.0, -1.0]).unwrap();
        let dense_dot = left
            .to_dense()
            .iter()
            .zip(right.to_dense())
            .map(|(left, right)| *left * right)
            .sum::<f32>();

        assert_eq!(left.dot(&right).unwrap(), dense_dot);
    }

    // Duplicate COO coordinates (1, 2) are summed when converting to CSR.
    #[test]
    fn csr_and_coo_invariants_hold() {
        let coo = CooMatrix::new(2, 3, vec![(1, 2, 2.0), (0, 0, 1.0), (1, 2, 1.0)]).unwrap();
        let csr = coo.to_csr().unwrap();
        assert_eq!(csr.row(0).unwrap().indices(), &[0]);
        assert_eq!(csr.row(1).unwrap().values(), &[3.0]);
    }

    // A canonical COO matrix survives COO -> CSR -> COO unchanged.
    #[test]
    fn coo_csr_round_trip_preserves_canonical_entries() {
        let coo = CooMatrix::new(
            3,
            3,
            vec![(2, 1, 1.0), (0, 2, 5.0), (2, 1, 2.0), (1, 0, 0.0)],
        )
        .unwrap();
        let canonical = coo.canonicalized().unwrap();
        let round_trip = canonical.to_csr().unwrap().to_coo().unwrap();

        assert_eq!(round_trip.entries(), canonical.entries());
    }

    // Dense -> sparse -> dense keeps zeros in place.
    #[test]
    fn dense_sparse_round_trip_preserves_values() {
        let dense = [0.0, 1.0, 0.0, 2.0];
        let sparse = SparseVector::from_dense(&dense).unwrap();
        assert_eq!(sparse.to_dense(), dense);
    }

    // Vector add / top-k, and CSR row counts, mat-vec product and double transpose.
    #[test]
    fn vector_ops_and_sparse_matrix_transpose_work() {
        let left = SparseVector::new(4, vec![0, 2], vec![1.0, -3.0]).unwrap();
        let right = SparseVector::new(4, vec![2, 3], vec![1.0, 2.0]).unwrap();
        let added = left.add(&right).unwrap();
        assert_eq!(added.indices(), &[0, 2, 3]);
        assert_eq!(added.values(), &[1.0, -2.0, 2.0]);
        assert_eq!(left.top_k_by_abs(1).unwrap(), vec![(2, -3.0)]);

        let matrix = CooMatrix::new(2, 3, vec![(0, 1, 2.0), (1, 2, 3.0)])
            .unwrap()
            .to_csr()
            .unwrap();
        assert_eq!(matrix.row_nnz(), vec![1, 1]);
        assert_eq!(
            matrix.mul_dense_vector(&[1.0, 2.0, 3.0]).unwrap(),
            vec![4.0, 9.0]
        );
        let transposed = matrix.transpose().unwrap();
        assert_eq!(transposed.rows(), 3);
        assert_eq!(transposed.cols(), 2);
        assert_eq!(
            transposed.transpose().unwrap().to_coo().unwrap().entries(),
            matrix.to_coo().unwrap().entries()
        );
    }

    // 3 stored entries in a 3x4 matrix gives density 0.25; checks row/column statistics.
    #[test]
    fn matrix_summary_reports_density_and_nnz_stats() {
        let matrix = CooMatrix::new(3, 4, vec![(0, 1, 2.0), (1, 3, 4.0), (2, 1, -1.0)])
            .unwrap()
            .to_csr()
            .unwrap();
        let summary = matrix.summary().unwrap();

        assert_eq!(summary.rows, 3);
        assert_eq!(summary.cols, 4);
        assert_eq!(summary.nnz, 3);
        assert!((summary.density - 0.25).abs() < 1.0e-6);
        assert_eq!(summary.row_nnz_min, 1);
        assert_eq!(summary.row_nnz_max, 1);
        assert_eq!(summary.column_nnz_min, 0);
        assert_eq!(summary.column_nnz_max, 2);
        assert_eq!(matrix.column_nnz(), vec![0, 2, 0, 1]);
        assert_eq!(matrix.row_sums().unwrap(), vec![2.0, 4.0, -1.0]);
        assert_eq!(matrix.column_sums().unwrap(), vec![0.0, 1.0, 0.0, 4.0]);
    }

    // Row (3, 4) normalizes to (0.6, 0.8); the empty row 1 stays empty.
    #[test]
    fn row_normalization_unit_norms_non_zero_rows() {
        let matrix = CooMatrix::new(3, 3, vec![(0, 0, 3.0), (0, 1, 4.0), (2, 2, 5.0)])
            .unwrap()
            .to_csr()
            .unwrap();
        let normalized = matrix.l2_normalize_rows().unwrap();

        assert!((normalized.row(0).unwrap().values()[0] - 0.6).abs() < 1.0e-6);
        assert!((normalized.row(0).unwrap().values()[1] - 0.8).abs() < 1.0e-6);
        assert!(normalized.row(1).unwrap().values().is_empty());
        assert!((normalized.row(2).unwrap().values()[0] - 1.0).abs() < 1.0e-6);
    }

    // Sparse (2x3) times dense (3x2) must equal the hand-computed dense product.
    #[test]
    fn sparse_dense_matrix_multiply_matches_dense_result() {
        let sparse = CooMatrix::new(2, 3, vec![(0, 1, 2.0), (1, 0, 1.0), (1, 2, 3.0)])
            .unwrap()
            .to_csr()
            .unwrap();
        let right = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).unwrap();
        let product = sparse.mul_dense_matrix(&right.as_view()).unwrap();

        assert_eq!(product.values(), &[6.0, 8.0, 16.0, 20.0]);
    }

    // CSR -> dense fills zeros row-major; CSR -> COO reproduces the original triplets.
    #[test]
    fn dense_matrix_conversion_round_trips_through_coo_csr() {
        let coo = CooMatrix::new(2, 3, vec![(0, 1, 2.0), (1, 2, 3.0)]).unwrap();
        let csr = coo.to_csr().unwrap();
        let dense = csr.to_dense_matrix().unwrap();

        assert_eq!(dense.values(), &[0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);
        assert_eq!(csr.to_coo().unwrap().entries(), coo.entries());
    }

    // Only indices 2 and 4 appear in both operands.
    #[test]
    fn hadamard_keeps_only_overlapping_indices() {
        let left = SparseVector::new(5, vec![0, 2, 4], vec![1.0, 2.0, 3.0]).unwrap();
        let right = SparseVector::new(5, vec![1, 2, 4], vec![5.0, 7.0, 11.0]).unwrap();
        let product = left.hadamard(&right).unwrap();

        assert_eq!(product.indices(), &[2, 4]);
        assert_eq!(product.values(), &[14.0, 33.0]);
    }

    // |0.01| < 0.1 is dropped; negative and NaN thresholds are rejected.
    #[test]
    fn pruning_removes_small_values_and_rejects_invalid_thresholds() {
        let vector = SparseVector::new(4, vec![0, 1, 2], vec![0.01, -0.5, 2.0]).unwrap();
        let pruned = vector.prune_abs_below(0.1).unwrap();

        assert_eq!(pruned.indices(), &[1, 2]);
        assert_eq!(pruned.values(), &[-0.5, 2.0]);
        assert!(vector.prune_abs_below(-0.1).is_err());
        assert!(vector.prune_abs_below(f32::NAN).is_err());
    }
}