Skip to main content

evoc_rs/utils/
sparse.rs

1//! Sparse structures, conversion and matrix operations for `evoc-rs`.
2
3use crate::prelude::*;
4use faer::{Mat, MatRef};
5use rayon::prelude::*;
6
7/////////////
8// Helpers //
9/////////////
10
11/// Sparse accumulator for scatter-gather in CSR matrix multiplication.
12///
13/// Implements the classic sparse accumulator (SPA) pattern: a dense value and
14/// flag array indexed by column, plus a compact list of touched indices.
15/// Scatter writes into the dense arrays in O(1); gather collects and resets
16/// them in O(nnz log nnz). The flag array makes repeated scatter to the same
17/// index an accumulation rather than an overwrite.
18struct SpaAcc<T: EvocFloat> {
19    /// Dense value buffer; only entries whose flag is set carry meaningful
20    /// data.
21    values: Vec<T>,
22    /// Column indices touched since the last `gather_sorted` call. Used to
23    /// iterate and reset without scanning the full buffer.
24    indices: Vec<usize>,
25    /// Occupancy flags; `flags[i]` is `true` iff `values[i]` has been written
26    /// to in the current accumulation round.
27    flags: Vec<bool>,
28}
29
30impl<T: EvocFloat> SpaAcc<T> {
31    /// Create a new sparse accumulator capable of addressing columns `0..size`.
32    ///
33    /// The internal index buffer is pre-allocated for a tenth of `size` on the
34    /// assumption that typical rows are sparse; it will grow if needed.
35    ///
36    /// ### Params
37    ///
38    /// * `size` - Number of addressable columns (i.e. the column count of the
39    ///   matrix being multiplied into)
40    fn new(size: usize) -> Self {
41        Self {
42            values: vec![T::zero(); size],
43            indices: Vec::with_capacity(size / 10),
44            flags: vec![false; size],
45        }
46    }
47
48    /// Accumulate `val` at column `idx`.
49    ///
50    /// If `idx` has not been touched in the current round it is recorded and
51    /// its slot initialised to `val`; otherwise `val` is added to the existing
52    /// partial sum.
53    ///
54    /// ### Params
55    ///
56    /// * `idx` - Column index to accumulate into
57    /// * `val` - Value to add
58    ///
59    /// # Safety
60    ///
61    /// `idx` must be less than the `size` passed to `new`. Violating this
62    /// causes out-of-bounds writes to `values` and `flags`, which is
63    /// undefined behaviour.
64    #[inline]
65    unsafe fn scatter(&mut self, idx: usize, val: T) {
66        unsafe {
67            if !*self.flags.get_unchecked(idx) {
68                *self.flags.get_unchecked_mut(idx) = true;
69                self.indices.push(idx);
70                *self.values.get_unchecked_mut(idx) = val;
71            } else {
72                let cur = *self.values.get_unchecked(idx);
73                *self.values.get_unchecked_mut(idx) = cur + val;
74            }
75        }
76    }
77
78    /// Collect all accumulated entries in ascending column order, then reset
79    /// the accumulator for reuse.
80    ///
81    /// Sorting is done in-place on `indices` before reading `values`, so the
82    /// output is always ordered by column index. Every touched slot is zeroed
83    /// and its flag cleared before returning.
84    ///
85    /// ### Returns
86    ///
87    /// A `Vec` of `(column_index, accumulated_value)` pairs sorted by
88    /// `column_index`
89    #[inline]
90    fn gather_sorted(&mut self) -> Vec<(usize, T)> {
91        self.indices.sort_unstable();
92        let out: Vec<(usize, T)> = self
93            .indices
94            .iter()
95            // Safety: every index in `self.indices` was bounds-checked against
96            // `size` at scatter time, so all reads here are in bounds.
97            .map(|&i| unsafe { (i, *self.values.get_unchecked(i)) })
98            .collect();
99        for &i in &self.indices {
100            // Safety: same guarantee as above.
101            unsafe {
102                *self.flags.get_unchecked_mut(i) = false;
103                *self.values.get_unchecked_mut(i) = T::zero();
104            }
105        }
106        self.indices.clear();
107        out
108    }
109}
110
111/////////
112// COO //
113/////////
114
/// Coordinate list (COO) representation of a sparse graph.
///
/// Tensor-friendly layout: entry `k` is the edge
/// `row_indices[k] -> col_indices[k]` with weight `values[k]`, so the three
/// vectors are parallel and are expected to have equal length. The logical
/// shape is `n_samples x n_samples`. Duplicate coordinates are allowed;
/// [`Csr::from_coo`] sums them.
#[derive(Clone, Debug, PartialEq)]
pub struct CoordinateList<T> {
    /// Row index of each entry
    pub row_indices: Vec<usize>,
    /// Column index of each entry
    pub col_indices: Vec<usize>,
    /// Edge weight of each entry
    pub values: Vec<T>,
    /// Number of vertices in the graph
    pub n_samples: usize,
}
129
130/////////
131// CSR //
132/////////
133
/// Lightweight compressed-sparse-row (CSR) matrix for label propagation
/// sparse algebra.
///
/// Row `i` owns the stored entries at positions `indptr[i]..indptr[i + 1]`;
/// for each such position `p`, `indices[p]` is the column and `data[p]` the
/// value. A well-formed matrix satisfies:
///
/// * `indptr.len() == nrows + 1`, `indptr` is non-decreasing, and
///   `indptr[nrows] == indices.len() == data.len()`
/// * every `indices[p] < ncols`
///
/// Some operations (e.g. element-wise multiplication) additionally assume
/// that column indices within each row are sorted.
#[derive(Clone, Debug, PartialEq)]
pub struct Csr<T> {
    /// Row index pointers
    pub indptr: Vec<usize>,
    /// Column indices
    pub indices: Vec<usize>,
    /// Data for the sparse matrix
    pub data: Vec<T>,
    /// Number of rows
    pub nrows: usize,
    /// Number of columns
    pub ncols: usize,
}
148
149impl<T: EvocFloat> Csr<T> {
150    /// Construct a CSR matrix from its raw components.
151    ///
152    /// No reordering or deduplication is performed; the caller is responsible
153    /// for providing a valid CSR representation.
154    ///
155    /// ### Params
156    ///
157    /// * `indptr`  - Row pointer array of length `nrows + 1`
158    /// * `indices` - Column indices of each stored entry
159    /// * `data`    - Values corresponding to each entry in `indices`
160    /// * `nrows`   - Number of rows
161    /// * `ncols`   - Number of columns
162    pub fn new(
163        indptr: Vec<usize>,
164        indices: Vec<usize>,
165        data: Vec<T>,
166        nrows: usize,
167        ncols: usize,
168    ) -> Self {
169        debug_assert_eq!(indptr.len(), nrows + 1);
170        debug_assert_eq!(indices.len(), data.len());
171        debug_assert_eq!(*indptr.last().unwrap(), data.len());
172        Self {
173            indptr,
174            indices,
175            data,
176            nrows,
177            ncols,
178        }
179    }
180
181    /// Build a square CSR matrix from a COO coordinate list, summing duplicate
182    /// entries.
183    ///
184    /// Triplets are sorted by (row, column) in parallel before assembly.
185    /// Consecutive entries sharing the same (row, column) pair are folded into
186    /// a single stored value.
187    ///
188    /// ### Params
189    ///
190    /// * `coo` - Coordinate list with `n_samples x n_samples` logical shape
191    ///
192    /// ### Returns
193    ///
194    /// A square `n_samples x n_samples` CSR matrix
195    pub fn from_coo(coo: &CoordinateList<T>) -> Self {
196        let n = coo.n_samples;
197        let nnz = coo.values.len();
198        if nnz == 0 {
199            return Self::new(vec![0; n + 1], Vec::new(), Vec::new(), n, n);
200        }
201
202        let mut triplets: Vec<(usize, usize, T)> = (0..nnz)
203            .map(|i| (coo.row_indices[i], coo.col_indices[i], coo.values[i]))
204            .collect();
205        triplets.par_sort_unstable_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
206
207        let mut data = Vec::with_capacity(nnz);
208        let mut indices = Vec::with_capacity(nnz);
209        let mut indptr = vec![0usize; n + 1];
210
211        let mut prev_r = usize::MAX;
212        let mut prev_c = usize::MAX;
213        for &(r, c, v) in &triplets {
214            if r == prev_r && c == prev_c {
215                let last = data.last().copied().unwrap();
216                *data.last_mut().unwrap() = last + v;
217            } else {
218                data.push(v);
219                indices.push(c);
220                indptr[r + 1] += 1;
221                prev_r = r;
222                prev_c = c;
223            }
224        }
225        for i in 0..n {
226            indptr[i + 1] += indptr[i];
227        }
228
229        Self {
230            indptr,
231            indices,
232            data,
233            nrows: n,
234            ncols: n,
235        }
236    }
237
238    /// Build a partition indicator matrix of shape `n_points x n_parts`.
239    ///
240    /// Row `i` contains a single `1.0` at column `partition[i]`, encoding a
241    /// hard cluster assignment as a sparse one-hot matrix.
242    ///
243    /// ### Params
244    ///
245    /// * `partition` - Slice of length `n_points` where `partition[i]` is the
246    ///   part index assigned to point `i`
247    /// * `n_parts`   - Total number of parts (column count)
248    pub fn from_partition(partition: &[usize], n_parts: usize) -> Self {
249        let n = partition.len();
250        Self {
251            indptr: (0..=n).collect(),
252            indices: partition.to_vec(),
253            data: vec![T::one(); n],
254            nrows: n,
255            ncols: n_parts,
256        }
257    }
258
259    /// Number of stored non-zero entries.
260    pub fn nnz(&self) -> usize {
261        self.data.len()
262    }
263
264    /// Transpose into a new CSR matrix.
265    ///
266    /// Constructs the transpose via a two-pass counting sort: the first pass
267    /// builds column counts to allocate `indptr`; the second scatters each
268    /// entry to its transposed position using per-column cursors.
269    ///
270    /// ### Returns
271    ///
272    /// A new `ncols x nrows` CSR matrix equal to `self^T`
273    pub fn transpose(&self) -> Self {
274        let nnz = self.nnz();
275        let mut col_count = vec![0usize; self.ncols];
276        for &c in &self.indices {
277            col_count[c] += 1;
278        }
279
280        let mut indptr = vec![0usize; self.ncols + 1];
281        for i in 0..self.ncols {
282            indptr[i + 1] = indptr[i] + col_count[i];
283        }
284
285        let mut data = vec![T::zero(); nnz];
286        let mut indices = vec![0usize; nnz];
287        let mut cursor = indptr[..self.ncols].to_vec();
288
289        for row in 0..self.nrows {
290            for idx in self.indptr[row]..self.indptr[row + 1] {
291                let col = self.indices[idx];
292                let pos = cursor[col];
293                data[pos] = self.data[idx];
294                indices[pos] = row;
295                cursor[col] += 1;
296            }
297        }
298
299        Self {
300            indptr,
301            indices,
302            data,
303            nrows: self.ncols,
304            ncols: self.nrows,
305        }
306    }
307
308    /// Sparse-sparse matrix multiplication: `self` (m x k) * `other` (k x n)
309    /// -> (m x n).
310    ///
311    /// ### Params
312    ///
313    /// * `other` - Right-hand operand; its row count must equal `self.ncols`
314    ///
315    /// ### Returns
316    ///
317    /// A new `m x n` CSR matrix
318    pub fn matmul(&self, other: &Csr<T>) -> Self {
319        assert_eq!(
320            self.ncols, other.nrows,
321            "Dimension mismatch: ({} x {}) * ({} x {})",
322            self.nrows, self.ncols, other.nrows, other.ncols
323        );
324
325        let m = self.nrows;
326        let n = other.ncols;
327
328        let row_results: Vec<Vec<(usize, T)>> = (0..m)
329            .into_par_iter()
330            .map(|i| {
331                let mut acc = SpaAcc::new(n);
332                for a_idx in self.indptr[i]..self.indptr[i + 1] {
333                    let k = self.indices[a_idx];
334                    let a_val = self.data[a_idx];
335                    for b_idx in other.indptr[k]..other.indptr[k + 1] {
336                        unsafe {
337                            acc.scatter(other.indices[b_idx], a_val * other.data[b_idx]);
338                        }
339                    }
340                }
341                acc.gather_sorted()
342            })
343            .collect();
344
345        let total_nnz: usize = row_results.iter().map(|r| r.len()).sum();
346        let mut data = Vec::with_capacity(total_nnz);
347        let mut indices = Vec::with_capacity(total_nnz);
348        let mut indptr = Vec::with_capacity(m + 1);
349        indptr.push(0);
350
351        for row in row_results {
352            for (col, val) in row {
353                indices.push(col);
354                data.push(val);
355            }
356            indptr.push(data.len());
357        }
358
359        Self {
360            indptr,
361            indices,
362            data,
363            nrows: m,
364            ncols: n,
365        }
366    }
367
368    /// Element-wise (Hadamard) product of two matrices with identical shape.
369    ///
370    /// ### Params
371    ///
372    /// * `other` - Right-hand operand; must have the same shape as `self` and
373    ///   sorted column indices per row
374    ///
375    /// ### Returns
376    ///
377    /// A new CSR matrix containing only the entries where both operands are
378    /// non-zero
379    pub fn elementwise_mul(&self, other: &Csr<T>) -> Self {
380        assert_eq!(
381            (self.nrows, self.ncols),
382            (other.nrows, other.ncols),
383            "Shape mismatch for element-wise multiply"
384        );
385
386        let mut indptr = vec![0usize; self.nrows + 1];
387        let mut indices = Vec::new();
388        let mut data = Vec::new();
389
390        for i in 0..self.nrows {
391            let (mut p, end_p) = (self.indptr[i], self.indptr[i + 1]);
392            let (mut q, end_q) = (other.indptr[i], other.indptr[i + 1]);
393            while p < end_p && q < end_q {
394                let ci = self.indices[p];
395                let cj = other.indices[q];
396                match ci.cmp(&cj) {
397                    std::cmp::Ordering::Equal => {
398                        indices.push(ci);
399                        data.push(self.data[p] * other.data[q]);
400                        p += 1;
401                        q += 1;
402                    }
403                    std::cmp::Ordering::Less => p += 1,
404                    std::cmp::Ordering::Greater => q += 1,
405                }
406            }
407            indptr[i + 1] = data.len();
408        }
409
410        Self {
411            indptr,
412            indices,
413            data,
414            nrows: self.nrows,
415            ncols: self.ncols,
416        }
417    }
418
419    /// Column-wise L2 normalisation.
420    ///
421    /// Each column is scaled by the reciprocal of its L2 norm. Columns with
422    /// zero norm are left unchanged (scale factor of 1).
423    ///
424    /// ### Returns
425    ///
426    /// A new CSR matrix with the same sparsity pattern and unit-norm columns
427    pub fn normalise_cols_l2(&self) -> Self {
428        let mut col_sq = vec![T::zero(); self.ncols];
429        for (idx, &v) in self.data.iter().enumerate() {
430            let c = self.indices[idx];
431            col_sq[c] += v * v;
432        }
433
434        let col_inv: Vec<T> = col_sq
435            .iter()
436            .map(|&sq| {
437                let norm = sq.sqrt();
438                if norm > T::zero() {
439                    T::one() / norm
440                } else {
441                    T::one()
442                }
443            })
444            .collect();
445
446        let new_data: Vec<T> = self
447            .data
448            .iter()
449            .enumerate()
450            .map(|(idx, &v)| v * col_inv[self.indices[idx]])
451            .collect();
452
453        Self {
454            indptr: self.indptr.clone(),
455            indices: self.indices.clone(),
456            data: new_data,
457            nrows: self.nrows,
458            ncols: self.ncols,
459        }
460    }
461
462    /// Row-wise L1 normalisation.
463    ///
464    /// Each row is scaled by the reciprocal of the sum of absolute values of
465    /// its entries. Rows with zero norm are left unchanged.
466    ///
467    /// ### Returns
468    ///
469    /// A new CSR matrix with the same sparsity pattern and unit-L1-norm rows
470    pub fn normalise_rows_l1(&self) -> Self {
471        let mut new_data = self.data.clone();
472        for i in 0..self.nrows {
473            let start = self.indptr[i];
474            let end = self.indptr[i + 1];
475            let mut norm = T::zero();
476            for idx in start..end {
477                norm += self.data[idx].abs();
478            }
479            if norm > T::zero() {
480                let inv = T::one() / norm;
481                for idx in start..end {
482                    new_data[idx] = new_data[idx] * inv;
483                }
484            }
485        }
486
487        Self {
488            indptr: self.indptr.clone(),
489            indices: self.indices.clone(),
490            data: new_data,
491            nrows: self.nrows,
492            ncols: self.ncols,
493        }
494    }
495
496    /// Clamp all stored values to the closed interval `[lo, hi]`.
497    ///
498    /// ### Params
499    ///
500    /// * `lo` - Lower bound
501    /// * `hi` - Upper bound
502    pub fn clip_values(&mut self, lo: T, hi: T) {
503        for d in &mut self.data {
504            if *d < lo {
505                *d = lo;
506            } else if *d > hi {
507                *d = hi;
508            }
509        }
510    }
511
512    /// Convert to an adjacency list representation.
513    ///
514    /// Each entry `graph[i]` is a `Vec` of `(column, value)` pairs for row
515    /// `i`, suitable for consumption by `evoc_embedding`.
516    ///
517    /// ### Returns
518    ///
519    /// A `Vec` of length `nrows`, where `graph[i]` contains the neighbours and
520    /// edge weights of node `i`
521    pub fn to_adjacency_list(&self) -> Vec<Vec<(usize, T)>> {
522        (0..self.nrows)
523            .map(|i| {
524                (self.indptr[i]..self.indptr[i + 1])
525                    .map(|idx| (self.indices[idx], self.data[idx]))
526                    .collect()
527            })
528            .collect()
529    }
530
531    /// Sparse-dense matrix multiplication: `self` (m x k) * `rhs` (k x d)
532    /// -> (m x d).
533    ///
534    /// ### Params
535    ///
536    /// * `rhs` - Dense right-hand operand; its row count must equal
537    ///   `self.ncols`
538    ///
539    /// ### Returns
540    ///
541    /// A dense `m x d` matrix
542    pub fn matmul_dense(&self, rhs: &MatRef<T>) -> Mat<T> {
543        assert_eq!(
544            self.ncols,
545            rhs.nrows(),
546            "Dimension mismatch: CSR cols {} vs Mat rows {}",
547            self.ncols,
548            rhs.nrows()
549        );
550
551        let d = rhs.ncols();
552        let rows: Vec<Vec<T>> = (0..self.nrows)
553            .into_par_iter()
554            .map(|i| {
555                let mut row = vec![T::zero(); d];
556                for idx in self.indptr[i]..self.indptr[i + 1] {
557                    let j = self.indices[idx];
558                    let v = self.data[idx];
559                    for k in 0..d {
560                        row[k] += v * rhs[(j, k)];
561                    }
562                }
563                row
564            })
565            .collect();
566
567        Mat::from_fn(self.nrows, d, |i, j| rows[i][j])
568    }
569
570    /// Convert to a dense `faer::Mat`, filling structural zeros explicitly.
571    ///
572    /// ### Returns
573    ///
574    /// A dense `nrows x ncols` matrix
575    pub fn to_dense(&self) -> Mat<T> {
576        let mut dense = Mat::zeros(self.nrows, self.ncols);
577        for i in 0..self.nrows {
578            for idx in self.indptr[i]..self.indptr[i + 1] {
579                dense[(i, self.indices[idx])] = self.data[idx];
580            }
581        }
582        dense
583    }
584}
585
586///////////////////////////
587// Conversion utilities  //
588///////////////////////////
589
590/// Pack a row-major `Vec<Vec<T>>` into a `faer::Mat`.
591///
592/// All inner `Vec`s must have the same length. An empty outer slice produces a
593/// `0 x 0` matrix.
594///
595/// ### Params
596///
597/// * `rows` - Slice of rows, each of length `d`
598///
599/// ### Returns
600///
601/// A `rows.len() x d` matrix
602pub fn vecs_to_mat<T: EvocFloat>(rows: &[Vec<T>]) -> Mat<T> {
603    let n = rows.len();
604    if n == 0 {
605        return Mat::zeros(0, 0);
606    }
607    let d = rows[0].len();
608    Mat::from_fn(n, d, |i, j| rows[i][j])
609}
610
611/// Unpack a `faer::Mat` into a row-major `Vec<Vec<T>>`.
612///
613/// ### Params
614///
615/// * `mat` - Matrix to unpack
616///
617/// ### Returns
618///
619/// A `Vec` of length `nrows`, each inner `Vec` of length `ncols`
620pub fn mat_to_vecs<T: EvocFloat>(mat: &Mat<T>) -> Vec<Vec<T>> {
621    (0..mat.nrows())
622        .map(|i| (0..mat.ncols()).map(|j| mat[(i, j)]).collect())
623        .collect()
624}
625
626///////////
627// Tests //
628///////////
629
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: build a small 3x3 CSR.
    ///
    /// ```text
    /// [[1, 0, 2],
    ///  [0, 3, 0],
    ///  [4, 0, 5]]
    /// ```
    fn make_3x3() -> Csr<f64> {
        Csr::new(
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            3,
            3,
        )
    }

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    /// Dense row-major copy of a CSR matrix, summing any duplicate entries.
    fn dense_rows(m: &Csr<f64>) -> Vec<Vec<f64>> {
        let mut out = vec![vec![0.0; m.ncols]; m.nrows];
        for i in 0..m.nrows {
            for idx in m.indptr[i]..m.indptr[i + 1] {
                out[i][m.indices[idx]] += m.data[idx];
            }
        }
        out
    }

    /// Build an `n x n` CSR holding `value(i, j)` at every `(i, j)` for which
    /// `keep(i, j)` is true.
    fn csr_from_pattern(
        n: usize,
        keep: impl Fn(usize, usize) -> bool,
        value: impl Fn(usize, usize) -> f64,
    ) -> Csr<f64> {
        let mut coo = CoordinateList {
            row_indices: Vec::new(),
            col_indices: Vec::new(),
            values: Vec::new(),
            n_samples: n,
        };
        for i in 0..n {
            for j in 0..n {
                if keep(i, j) {
                    coo.row_indices.push(i);
                    coo.col_indices.push(j);
                    coo.values.push(value(i, j));
                }
            }
        }
        Csr::from_coo(&coo)
    }

    #[test]
    fn from_coo_basic() {
        let coo = CoordinateList {
            row_indices: vec![0, 0, 1, 2, 2],
            col_indices: vec![0, 2, 1, 0, 2],
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0],
            n_samples: 3,
        };
        let csr = Csr::from_coo(&coo);
        assert_eq!(csr.nrows, 3);
        assert_eq!(csr.ncols, 3);
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.indptr, vec![0, 2, 3, 5]);
        assert_eq!(csr.indices, vec![0, 2, 1, 0, 2]);
        assert_eq!(csr.data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn from_coo_unsorted_input_is_sorted() {
        // Same matrix as `from_coo_basic`, but with triplets shuffled.
        let coo = CoordinateList {
            row_indices: vec![2, 0, 1, 0, 2],
            col_indices: vec![2, 2, 1, 0, 0],
            values: vec![5.0, 2.0, 3.0, 1.0, 4.0],
            n_samples: 3,
        };
        let csr = Csr::from_coo(&coo);
        assert_eq!(csr.indptr, vec![0, 2, 3, 5]);
        assert_eq!(csr.indices, vec![0, 2, 1, 0, 2]);
        assert_eq!(csr.data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn from_coo_duplicates_summed() {
        // 3x3 so that column index 2 is in range.
        let coo = CoordinateList {
            row_indices: vec![0, 0, 0],
            col_indices: vec![1, 1, 2],
            values: vec![1.0, 3.0, 5.0],
            n_samples: 3,
        };
        let csr = Csr::from_coo(&coo);
        // (0,1) should be 1+3=4, (0,2) should be 5
        assert_eq!(csr.indptr, vec![0, 2, 2, 2]);
        assert_eq!(csr.indices, vec![1, 2]);
        assert!(approx_eq(csr.data[0], 4.0));
        assert!(approx_eq(csr.data[1], 5.0));
    }

    #[test]
    fn from_coo_empty() {
        let coo: CoordinateList<f64> = CoordinateList {
            row_indices: Vec::new(),
            col_indices: Vec::new(),
            values: Vec::new(),
            n_samples: 5,
        };
        let csr = Csr::from_coo(&coo);
        assert_eq!(csr.nrows, 5);
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.indptr, vec![0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn from_partition_basic() {
        let part = vec![2, 0, 1, 2];
        let csr = Csr::<f64>::from_partition(&part, 3);
        assert_eq!(csr.nrows, 4);
        assert_eq!(csr.ncols, 3);
        assert_eq!(csr.nnz(), 4);
        // Row 0 -> col 2, row 1 -> col 0, etc.
        assert_eq!(csr.indices, vec![2, 0, 1, 2]);
        assert!(csr.data.iter().all(|&v| approx_eq(v, 1.0)));
    }

    #[test]
    fn transpose_roundtrip() {
        let a = make_3x3();
        let at = a.transpose();
        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 3);
        assert_eq!(at.nnz(), 5);

        // A^T[0] should be cols [0, 2] with vals [1, 4]
        let row0: Vec<(usize, f64)> = (at.indptr[0]..at.indptr[1])
            .map(|idx| (at.indices[idx], at.data[idx]))
            .collect();
        assert_eq!(row0, vec![(0, 1.0), (2, 4.0)]);

        // Double transpose should recover original
        let att = at.transpose();
        assert_eq!(att.indptr, a.indptr);
        assert_eq!(att.indices, a.indices);
        assert_eq!(att.data, a.data);
    }

    #[test]
    fn transpose_non_square() {
        // 2x3: [[1, 2, 0], [0, 0, 3]]
        let m = Csr::new(vec![0, 2, 3], vec![0, 1, 2], vec![1.0, 2.0, 3.0], 2, 3);
        let mt = m.transpose();
        assert_eq!(mt.nrows, 3);
        assert_eq!(mt.ncols, 2);
        // T[0] = [1, 0], T[1] = [2, 0], T[2] = [0, 3]
        assert_eq!(mt.indptr, vec![0, 1, 2, 3]);
        assert_eq!(mt.indices, vec![0, 0, 1]);
        assert_eq!(mt.data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn matmul_identity() {
        let a = make_3x3();
        // 3x3 identity
        let eye = Csr::new(vec![0, 1, 2, 3], vec![0, 1, 2], vec![1.0, 1.0, 1.0], 3, 3);
        let result = a.matmul(&eye);
        assert_eq!(result.data, a.data);
        assert_eq!(result.indices, a.indices);
    }

    #[test]
    fn matmul_a_times_at() {
        let a = make_3x3();
        let at = a.transpose();
        let aat = a.matmul(&at);
        let dense = aat.to_dense();

        // A * A^T = [[5, 0, 14], [0, 9, 0], [14, 0, 41]]
        assert!(approx_eq(dense[(0, 0)], 5.0));
        assert!(approx_eq(dense[(0, 1)], 0.0));
        assert!(approx_eq(dense[(0, 2)], 14.0));
        assert!(approx_eq(dense[(1, 1)], 9.0));
        assert!(approx_eq(dense[(2, 0)], 14.0));
        assert!(approx_eq(dense[(2, 2)], 41.0));
    }

    #[test]
    fn matmul_non_square() {
        // (2x3) * (3x2)
        let a = Csr::new(vec![0, 2, 3], vec![0, 1, 2], vec![1.0, 2.0, 3.0], 2, 3);
        let b = Csr::new(vec![0, 1, 2, 3], vec![0, 1, 0], vec![4.0, 5.0, 6.0], 3, 2);
        let c = a.matmul(&b);
        assert_eq!(c.nrows, 2);
        assert_eq!(c.ncols, 2);
        let dense = c.to_dense();
        // Row 0: 1*[4,0] + 2*[0,5] = [4, 10]
        // Row 1: 3*[6,0] = [18, 0]
        assert!(approx_eq(dense[(0, 0)], 4.0));
        assert!(approx_eq(dense[(0, 1)], 10.0));
        assert!(approx_eq(dense[(1, 0)], 18.0));
        assert!(approx_eq(dense[(1, 1)], 0.0));
    }

    #[test]
    fn matmul_matches_dense_reference() {
        // Many rows (including an empty one) so the sparse accumulator is
        // exercised across consecutive rows.
        let n = 8;
        let a = csr_from_pattern(
            n,
            |i, j| i != 5 && (i * 3 + j) % 4 == 0,
            |i, j| (i + j + 1) as f64,
        );
        let b = csr_from_pattern(n, |i, j| (i + 2 * j) % 3 == 0, |i, j| (i * j) as f64 + 0.5);
        let c = a.matmul(&b);
        assert_eq!((c.nrows, c.ncols), (n, n));

        let (da, db, dc) = (dense_rows(&a), dense_rows(&b), dense_rows(&c));
        for i in 0..n {
            for j in 0..n {
                let expected: f64 = (0..n).map(|k| da[i][k] * db[k][j]).sum();
                assert!(approx_eq(dc[i][j], expected), "mismatch at ({}, {})", i, j);
            }
            let row = &c.indices[c.indptr[i]..c.indptr[i + 1]];
            assert!(row.windows(2).all(|w| w[0] < w[1]), "row {} not sorted", i);
        }
        // Empty row of `a` yields an empty row of the product.
        assert_eq!(c.indptr[5], c.indptr[6]);
    }

    #[test]
    fn matmul_dense_basic() {
        let a = make_3x3();
        // Dense 3x2: [[1, 0], [0, 1], [1, 1]]
        let rhs = Mat::from_fn(3, 2, |i, j| match (i, j) {
            (0, 0) | (1, 1) | (2, 0) | (2, 1) => 1.0_f64,
            _ => 0.0,
        });
        let result = a.matmul_dense(&rhs.as_ref());
        // Row 0: 1*[1,0] + 2*[1,1] = [3, 2]
        // Row 1: 3*[0,1] = [0, 3]
        // Row 2: 4*[1,0] + 5*[1,1] = [9, 5]
        assert!(approx_eq(result[(0, 0)], 3.0));
        assert!(approx_eq(result[(0, 1)], 2.0));
        assert!(approx_eq(result[(1, 0)], 0.0));
        assert!(approx_eq(result[(1, 1)], 3.0));
        assert!(approx_eq(result[(2, 0)], 9.0));
        assert!(approx_eq(result[(2, 1)], 5.0));
    }

    #[test]
    fn elementwise_mul_with_transpose() {
        let a = make_3x3();
        let at = a.transpose();
        let h = a.elementwise_mul(&at);
        let dense = h.to_dense();

        // A .* A^T:
        // (0,0): 1*1=1, (0,2): 2*4=8
        // (1,1): 3*3=9
        // (2,0): 4*2=8, (2,2): 5*5=25
        assert!(approx_eq(dense[(0, 0)], 1.0));
        assert!(approx_eq(dense[(0, 2)], 8.0));
        assert!(approx_eq(dense[(1, 1)], 9.0));
        assert!(approx_eq(dense[(2, 0)], 8.0));
        assert!(approx_eq(dense[(2, 2)], 25.0));
        assert_eq!(h.nnz(), 5);
    }

    #[test]
    fn elementwise_mul_disjoint() {
        // No overlapping entries -> empty result
        let a = Csr::new(vec![0, 1, 1], vec![0], vec![5.0], 2, 2);
        let b = Csr::new(vec![0, 0, 1], vec![1], vec![7.0], 2, 2);
        let h = a.elementwise_mul(&b);
        assert_eq!(h.nnz(), 0);
    }

    #[test]
    fn normalise_cols_l2_unit_norms() {
        // [[1, 0], [0, 2], [3, 4]]
        let m = Csr::new(
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0, 2.0, 3.0, 4.0],
            3,
            2,
        );
        let normed = m.normalise_cols_l2();

        // Check column norms are 1
        let mut col_sq = [0.0f64; 2];
        for (idx, &v) in normed.data.iter().enumerate() {
            col_sq[normed.indices[idx]] += v * v;
        }
        assert!(approx_eq(col_sq[0].sqrt(), 1.0));
        assert!(approx_eq(col_sq[1].sqrt(), 1.0));

        // Check specific values
        let c0_norm = (1.0f64 + 9.0).sqrt(); // sqrt(10)
        let c1_norm = (4.0f64 + 16.0).sqrt(); // sqrt(20)
        assert!(approx_eq(normed.data[0], 1.0 / c0_norm));
        assert!(approx_eq(normed.data[1], 2.0 / c1_norm));
        assert!(approx_eq(normed.data[2], 3.0 / c0_norm));
        assert!(approx_eq(normed.data[3], 4.0 / c1_norm));
    }

    #[test]
    fn normalise_cols_l2_empty_column() {
        // Column 1 has no entries -> should not panic
        let m = Csr::new(vec![0, 1, 1], vec![0], vec![3.0], 2, 2);
        let normed = m.normalise_cols_l2();
        assert!(approx_eq(normed.data[0], 1.0)); // 3/3 = 1
    }

    #[test]
    fn normalise_rows_l1_unit_sums() {
        let m = Csr::new(
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0, 2.0, 3.0, 4.0],
            3,
            2,
        );
        let normed = m.normalise_rows_l1();

        // Row sums should be 1
        for i in 0..normed.nrows {
            let sum: f64 = normed.data[normed.indptr[i]..normed.indptr[i + 1]]
                .iter()
                .map(|v: &f64| v.abs())
                .sum();
            assert!(approx_eq(sum, 1.0));
        }

        // Row 0: [1] -> [1]
        assert!(approx_eq(normed.data[0], 1.0));
        // Row 1: [2] -> [1]
        assert!(approx_eq(normed.data[1], 1.0));
        // Row 2: [3, 4] -> [3/7, 4/7]
        assert!(approx_eq(normed.data[2], 3.0 / 7.0));
        assert!(approx_eq(normed.data[3], 4.0 / 7.0));
    }

    #[test]
    fn normalise_rows_l1_empty_row() {
        let m = Csr::new(vec![0, 0, 1], vec![0], vec![5.0], 2, 2);
        let normed = m.normalise_rows_l1();
        // Row 0 is empty -> should not panic
        assert!(approx_eq(normed.data[0], 1.0));
    }

    #[test]
    fn clip_values_basic() {
        let mut m = Csr::new(vec![0, 3], vec![0, 1, 2], vec![-1.0, 0.5, 2.0], 1, 3);
        m.clip_values(0.0, 1.0);
        assert!(approx_eq(m.data[0], 0.0));
        assert!(approx_eq(m.data[1], 0.5));
        assert!(approx_eq(m.data[2], 1.0));
    }

    #[test]
    fn to_adjacency_list_roundtrip() {
        let a = make_3x3();
        let adj = a.to_adjacency_list();
        assert_eq!(adj.len(), 3);
        assert_eq!(adj[0], vec![(0, 1.0), (2, 2.0)]);
        assert_eq!(adj[1], vec![(1, 3.0)]);
        assert_eq!(adj[2], vec![(0, 4.0), (2, 5.0)]);
    }

    #[test]
    fn to_dense_roundtrip() {
        let a = make_3x3();
        let d = a.to_dense();
        assert!(approx_eq(d[(0, 0)], 1.0));
        assert!(approx_eq(d[(0, 1)], 0.0));
        assert!(approx_eq(d[(0, 2)], 2.0));
        assert!(approx_eq(d[(1, 1)], 3.0));
        assert!(approx_eq(d[(2, 0)], 4.0));
        assert!(approx_eq(d[(2, 2)], 5.0));
    }

    #[test]
    fn vecs_mat_roundtrip() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mat = vecs_to_mat(&rows);
        let back = mat_to_vecs(&mat);
        assert_eq!(rows, back);
    }

    #[test]
    fn vecs_to_mat_empty() {
        let rows: Vec<Vec<f64>> = Vec::new();
        let mat = vecs_to_mat(&rows);
        assert_eq!(mat.nrows(), 0);
        assert_eq!(mat.ncols(), 0);
    }
}