Skip to main content

oxiphysics_gpu/
sparse_gpu.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! GPU-ready sparse matrix formats and iterative solvers with data-oriented layouts.
6//!
7//! Provides CSR, ELLPACK, Hybrid (ELL+COO), and Block-CSR formats along with
8//! iterative solvers (CG, BiCGSTAB, preconditioned CG) suitable for GPU offload.
9#![allow(missing_docs)]
10#![allow(dead_code)]
11
12// ---------------------------------------------------------------------------
13// Sparse vector operations
14// ---------------------------------------------------------------------------
15
/// Dot product `Σ x[i] * y[i]` of two vectors.
///
/// If the slices differ in length, the surplus elements of the longer one are ignored.
pub fn dot(x: &[f64], y: &[f64]) -> f64 {
    let products = x.iter().zip(y).map(|(&xi, &yi)| xi * yi);
    products.sum()
}
20
/// AXPY: return the new vector `y + alpha * x`, computed element-wise.
///
/// The result has the length of the shorter input.
pub fn axpy(alpha: f64, x: &[f64], y: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(x.len().min(y.len()));
    for (&yi, &xi) in y.iter().zip(x) {
        out.push(yi + alpha * xi);
    }
    out
}
28
29/// Euclidean norm of a vector.
30pub fn norm2(x: &[f64]) -> f64 {
31    dot(x, x).sqrt()
32}
33
/// Return a copy of `x` in which every element is multiplied by `s`.
pub fn scale_vec(x: &[f64], s: f64) -> Vec<f64> {
    let mut scaled = x.to_vec();
    for v in scaled.iter_mut() {
        *v *= s;
    }
    scaled
}
38
39// ---------------------------------------------------------------------------
40// SparseTriplet – coordinate (COO) format for assembly
41// ---------------------------------------------------------------------------
42
/// Coordinate (COO) format sparse matrix for incremental assembly.
///
/// Entries are stored as three parallel arrays: the `i`-th entry is
/// `(rows[i], cols[i], vals[i])`. Repeated coordinates are allowed; they are
/// summed when the store is converted with [`SparseTriplet::to_csr`].
#[derive(Debug, Clone, PartialEq)]
pub struct SparseTriplet {
    /// Row index of each entry.
    pub rows: Vec<usize>,
    /// Column index of each entry.
    pub cols: Vec<usize>,
    /// Value of each entry.
    pub vals: Vec<f64>,
}
49
50impl SparseTriplet {
51    /// Create an empty triplet store.
52    pub fn new() -> Self {
53        Self {
54            rows: Vec::new(),
55            cols: Vec::new(),
56            vals: Vec::new(),
57        }
58    }
59
60    /// Push a single entry `(row, col, val)`.
61    pub fn add(&mut self, row: usize, col: usize, val: f64) {
62        self.rows.push(row);
63        self.cols.push(col);
64        self.vals.push(val);
65    }
66
67    /// Convert to [`CsrMatrix`], sorting by (row, col) and summing duplicates.
68    pub fn to_csr(&self, n_rows: usize, n_cols: usize) -> CsrMatrix {
69        // Sort indices by (row, col)
70        let mut order: Vec<usize> = (0..self.rows.len()).collect();
71        order.sort_by_key(|&i| (self.rows[i], self.cols[i]));
72
73        // Accumulate into (row, col, val) triples, summing duplicates
74        let mut entries: Vec<(usize, usize, f64)> = Vec::new();
75        for &i in &order {
76            let r = self.rows[i];
77            let c = self.cols[i];
78            let v = self.vals[i];
79            if let Some(last) = entries.last_mut()
80                && last.0 == r
81                && last.1 == c
82            {
83                last.2 += v;
84                continue;
85            }
86            entries.push((r, c, v));
87        }
88
89        // Build CSR
90        let nnz = entries.len();
91        let mut row_ptr = vec![0usize; n_rows + 1];
92        let mut col_idx = Vec::with_capacity(nnz);
93        let mut values = Vec::with_capacity(nnz);
94
95        for &(r, c, v) in &entries {
96            row_ptr[r + 1] += 1;
97            col_idx.push(c);
98            values.push(v);
99        }
100        for i in 0..n_rows {
101            row_ptr[i + 1] += row_ptr[i];
102        }
103
104        CsrMatrix {
105            n_rows,
106            n_cols,
107            row_ptr,
108            col_idx,
109            values,
110        }
111    }
112}
113
114impl Default for SparseTriplet {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120// ---------------------------------------------------------------------------
121// CsrMatrix – Compressed Sparse Row
122// ---------------------------------------------------------------------------
123
/// Compressed Sparse Row (CSR) matrix stored as plain `f64` arrays.
///
/// The stored entries of row `r` occupy positions `row_ptr[r]..row_ptr[r + 1]`
/// of `col_idx` and `values`.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrMatrix {
    /// Number of rows.
    pub n_rows: usize,
    /// Number of columns.
    pub n_cols: usize,
    /// Row start indices, length `n_rows + 1`.
    pub row_ptr: Vec<usize>,
    /// Column indices of non-zeros.
    pub col_idx: Vec<usize>,
    /// Non-zero values.
    pub values: Vec<f64>,
}
135
136impl CsrMatrix {
137    /// Create an empty (all-zero) CSR matrix.
138    pub fn new(n_rows: usize, n_cols: usize) -> Self {
139        Self {
140            n_rows,
141            n_cols,
142            row_ptr: vec![0; n_rows + 1],
143            col_idx: Vec::new(),
144            values: Vec::new(),
145        }
146    }
147
148    /// Build from a dense row-major matrix (rows of length `n_cols`).
149    pub fn from_dense(m: &[Vec<f64>]) -> Self {
150        let n_rows = m.len();
151        let n_cols = if n_rows > 0 { m[0].len() } else { 0 };
152        let mut row_ptr = vec![0usize; n_rows + 1];
153        let mut col_idx = Vec::new();
154        let mut values = Vec::new();
155        for (r, row) in m.iter().enumerate() {
156            for (c, &v) in row.iter().enumerate() {
157                if v != 0.0 {
158                    col_idx.push(c);
159                    values.push(v);
160                }
161            }
162            row_ptr[r + 1] = col_idx.len();
163        }
164        Self {
165            n_rows,
166            n_cols,
167            row_ptr,
168            col_idx,
169            values,
170        }
171    }
172
173    /// Number of stored non-zeros.
174    pub fn nnz(&self) -> usize {
175        self.values.len()
176    }
177
178    /// Sparse matrix–vector product `y = A * x`.
179    pub fn spmv(&self, x: &[f64]) -> Vec<f64> {
180        let mut y = vec![0.0f64; self.n_rows];
181        for r in 0..self.n_rows {
182            let start = self.row_ptr[r];
183            let end = self.row_ptr[r + 1];
184            let mut sum = 0.0;
185            for k in start..end {
186                sum += self.values[k] * x[self.col_idx[k]];
187            }
188            y[r] = sum;
189        }
190        y
191    }
192
193    /// Get the value at `(row, col)`, returning `0.0` if not stored.
194    pub fn get(&self, row: usize, col: usize) -> f64 {
195        let start = self.row_ptr[row];
196        let end = self.row_ptr[row + 1];
197        for k in start..end {
198            if self.col_idx[k] == col {
199                return self.values[k];
200            }
201        }
202        0.0
203    }
204
205    /// Return `A^T` as a new `CsrMatrix`.
206    pub fn transpose(&self) -> CsrMatrix {
207        // Count non-zeros per column (which become rows of A^T)
208        let mut row_ptr = vec![0usize; self.n_cols + 1];
209        for &c in &self.col_idx {
210            row_ptr[c + 1] += 1;
211        }
212        for i in 0..self.n_cols {
213            row_ptr[i + 1] += row_ptr[i];
214        }
215
216        let nnz = self.values.len();
217        let mut col_idx = vec![0usize; nnz];
218        let mut values = vec![0.0f64; nnz];
219        let mut pos = row_ptr[..self.n_cols].to_vec();
220
221        for r in 0..self.n_rows {
222            let start = self.row_ptr[r];
223            let end = self.row_ptr[r + 1];
224            for k in start..end {
225                let c = self.col_idx[k];
226                let dest = pos[c];
227                col_idx[dest] = r;
228                values[dest] = self.values[k];
229                pos[c] += 1;
230            }
231        }
232
233        CsrMatrix {
234            n_rows: self.n_cols,
235            n_cols: self.n_rows,
236            row_ptr,
237            col_idx,
238            values,
239        }
240    }
241
242    /// Add two CSR matrices of identical shape.
243    pub fn add(&self, other: &CsrMatrix) -> CsrMatrix {
244        assert_eq!(self.n_rows, other.n_rows);
245        assert_eq!(self.n_cols, other.n_cols);
246        // Assemble via triplet
247        let mut trip = SparseTriplet::new();
248        for r in 0..self.n_rows {
249            for k in self.row_ptr[r]..self.row_ptr[r + 1] {
250                trip.add(r, self.col_idx[k], self.values[k]);
251            }
252            for k in other.row_ptr[r]..other.row_ptr[r + 1] {
253                trip.add(r, other.col_idx[k], other.values[k]);
254            }
255        }
256        trip.to_csr(self.n_rows, self.n_cols)
257    }
258
259    /// Return a new matrix with every value multiplied by `s`.
260    pub fn scale(&self, s: f64) -> CsrMatrix {
261        CsrMatrix {
262            n_rows: self.n_rows,
263            n_cols: self.n_cols,
264            row_ptr: self.row_ptr.clone(),
265            col_idx: self.col_idx.clone(),
266            values: self.values.iter().map(|v| v * s).collect(),
267        }
268    }
269}
270
271// ---------------------------------------------------------------------------
272// EllMatrix – ELLPACK / ITPACK format (GPU-friendly padded layout)
273// ---------------------------------------------------------------------------
274
/// ELLPACK-format sparse matrix: every row is padded to `max_nnz_per_row` slots.
///
/// The fixed row width gives a regular row-major layout (the GPU-friendly
/// padded layout described in this module).
#[derive(Debug, Clone, PartialEq)]
pub struct EllMatrix {
    /// Number of rows.
    pub n_rows: usize,
    /// Number of columns.
    pub n_cols: usize,
    /// Slots per row. When built by [`EllMatrix::from_csr`], this is the length of the longest CSR row.
    pub max_nnz_per_row: usize,
    /// Column indices, row-major `[n_rows × max_nnz_per_row]`.
    pub col_idx: Vec<usize>,
    /// Values, row-major `[n_rows × max_nnz_per_row]`.
    /// Padding slots written by `from_csr` hold `0.0` (with column `0`).
    pub values: Vec<f64>,
}
285
286impl EllMatrix {
287    /// Convert a [`CsrMatrix`] to ELLPACK.
288    pub fn from_csr(csr: &CsrMatrix) -> Self {
289        let n_rows = csr.n_rows;
290        let n_cols = csr.n_cols;
291        let max_nnz_per_row = (0..n_rows)
292            .map(|r| csr.row_ptr[r + 1] - csr.row_ptr[r])
293            .max()
294            .unwrap_or(0);
295
296        let size = n_rows * max_nnz_per_row;
297        let mut col_idx = vec![0usize; size];
298        let mut values = vec![0.0f64; size];
299
300        for r in 0..n_rows {
301            let start = csr.row_ptr[r];
302            let end = csr.row_ptr[r + 1];
303            for (j, k) in (start..end).enumerate() {
304                col_idx[r * max_nnz_per_row + j] = csr.col_idx[k];
305                values[r * max_nnz_per_row + j] = csr.values[k];
306            }
307        }
308
309        Self {
310            n_rows,
311            n_cols,
312            max_nnz_per_row,
313            col_idx,
314            values,
315        }
316    }
317
318    /// Sparse matrix–vector product `y = A * x`.
319    pub fn spmv(&self, x: &[f64]) -> Vec<f64> {
320        let mut y = vec![0.0f64; self.n_rows];
321        for r in 0..self.n_rows {
322            let mut sum = 0.0;
323            for j in 0..self.max_nnz_per_row {
324                let v = self.values[r * self.max_nnz_per_row + j];
325                if v != 0.0 {
326                    let c = self.col_idx[r * self.max_nnz_per_row + j];
327                    sum += v * x[c];
328                }
329            }
330            y[r] = sum;
331        }
332        y
333    }
334}
335
336// ---------------------------------------------------------------------------
337// HybridMatrix – ELL + COO for irregular sparsity
338// ---------------------------------------------------------------------------
339
340/// Hybrid ELL+COO matrix: regular rows stored in ELL, overflow in COO.
341pub struct HybridMatrix {
342    pub ell: EllMatrix,
343    pub coo_row: Vec<usize>,
344    pub coo_col: Vec<usize>,
345    pub coo_val: Vec<f64>,
346}
347
348impl HybridMatrix {
349    /// Sparse matrix–vector product `y = (ELL + COO) * x`.
350    pub fn spmv(&self, x: &[f64]) -> Vec<f64> {
351        let mut y = self.ell.spmv(x);
352        for k in 0..self.coo_val.len() {
353            y[self.coo_row[k]] += self.coo_val[k] * x[self.coo_col[k]];
354        }
355        y
356    }
357}
358
359// ---------------------------------------------------------------------------
360// BlockCsrMatrix – block sparse row for FEM
361// ---------------------------------------------------------------------------
362
/// Block-CSR matrix where every stored entry is a `block_size × block_size` dense tile.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockCsrMatrix {
    /// Edge length of each square tile.
    pub block_size: usize,
    /// Number of block rows (scalar rows = `n_block_rows * block_size`).
    pub n_block_rows: usize,
    /// Number of block columns (scalar columns = `n_block_cols * block_size`).
    pub n_block_cols: usize,
    /// Block row start indices, length `n_block_rows + 1`.
    pub row_ptr: Vec<usize>,
    /// Block column index of each stored tile.
    pub col_idx: Vec<usize>,
    /// Dense tiles in row-major order, each of length `block_size * block_size`.
    pub blocks: Vec<Vec<f64>>,
}
375
376impl BlockCsrMatrix {
377    /// Sparse matrix–vector product treating the matrix as `(n_block_rows * block_size)` × `(n_block_cols * block_size)`.
378    pub fn spmv_block(&self, x: &[f64]) -> Vec<f64> {
379        let bs = self.block_size;
380        let n = self.n_block_rows * bs;
381        let mut y = vec![0.0f64; n];
382        for br in 0..self.n_block_rows {
383            let row_start = self.row_ptr[br];
384            let row_end = self.row_ptr[br + 1];
385            for k in row_start..row_end {
386                let bc = self.col_idx[k];
387                let blk = &self.blocks[k];
388                // Multiply dense block into y
389                for i in 0..bs {
390                    let mut s = 0.0;
391                    for j in 0..bs {
392                        s += blk[i * bs + j] * x[bc * bs + j];
393                    }
394                    y[br * bs + i] += s;
395                }
396            }
397        }
398        y
399    }
400
401    /// Convert to a flat [`CsrMatrix`].
402    pub fn to_csr(&self) -> CsrMatrix {
403        let bs = self.block_size;
404        let n_rows = self.n_block_rows * bs;
405        let n_cols = self.n_block_cols * bs;
406        let mut trip = SparseTriplet::new();
407        for br in 0..self.n_block_rows {
408            for k in self.row_ptr[br]..self.row_ptr[br + 1] {
409                let bc = self.col_idx[k];
410                let blk = &self.blocks[k];
411                for i in 0..bs {
412                    for j in 0..bs {
413                        let v = blk[i * bs + j];
414                        if v != 0.0 {
415                            trip.add(br * bs + i, bc * bs + j, v);
416                        }
417                    }
418                }
419            }
420        }
421        trip.to_csr(n_rows, n_cols)
422    }
423}
424
425// ---------------------------------------------------------------------------
426// Iterative solvers
427// ---------------------------------------------------------------------------
428
429/// Conjugate Gradient solver for symmetric positive-definite systems `A x = b`.
430///
431/// Returns `(solution, iterations_used)`.
432pub fn cg_solve(
433    a: &CsrMatrix,
434    b: &[f64],
435    x0: &[f64],
436    max_iter: usize,
437    tol: f64,
438) -> (Vec<f64>, usize) {
439    let n = b.len();
440    let mut x = x0.to_vec();
441    // r = b - A*x
442    let ax = a.spmv(&x);
443    let mut r: Vec<f64> = (0..n).map(|i| b[i] - ax[i]).collect();
444    let mut p = r.clone();
445    let mut rs_old = dot(&r, &r);
446
447    for iter in 0..max_iter {
448        if rs_old.sqrt() < tol {
449            return (x, iter);
450        }
451        let ap = a.spmv(&p);
452        let alpha = rs_old / dot(&p, &ap);
453        x = axpy(alpha, &p, &x);
454        r = axpy(-alpha, &ap, &r);
455        let rs_new = dot(&r, &r);
456        let beta = rs_new / rs_old;
457        p = axpy(beta, &p, &r);
458        rs_old = rs_new;
459    }
460    (x, max_iter)
461}
462
463/// BiCGSTAB solver for general (possibly non-symmetric) systems `A x = b`.
464///
465/// Returns `(solution, iterations_used)`.
466pub fn bicgstab_solve(
467    a: &CsrMatrix,
468    b: &[f64],
469    x0: &[f64],
470    max_iter: usize,
471    tol: f64,
472) -> (Vec<f64>, usize) {
473    let n = b.len();
474    let mut x = x0.to_vec();
475    let ax = a.spmv(&x);
476    let mut r: Vec<f64> = (0..n).map(|i| b[i] - ax[i]).collect();
477    let r_hat = r.clone();
478    let mut rho = 1.0_f64;
479    let mut alpha_s = 1.0_f64;
480    let mut omega = 1.0_f64;
481    let mut v = vec![0.0f64; n];
482    let mut p = vec![0.0f64; n];
483
484    for iter in 0..max_iter {
485        if norm2(&r) < tol {
486            return (x, iter);
487        }
488        let rho_new = dot(&r_hat, &r);
489        let beta = (rho_new / rho) * (alpha_s / omega);
490        p = axpy(beta, &p, &axpy(-beta * omega, &v, &r));
491        v = a.spmv(&p);
492        let denom = dot(&r_hat, &v);
493        if denom.abs() < 1e-300 {
494            return (x, iter);
495        }
496        alpha_s = rho_new / denom;
497        let s: Vec<f64> = axpy(-alpha_s, &v, &r);
498        if norm2(&s) < tol {
499            x = axpy(alpha_s, &p, &x);
500            return (x, iter + 1);
501        }
502        let t = a.spmv(&s);
503        let tt = dot(&t, &t);
504        omega = if tt.abs() < 1e-300 {
505            0.0
506        } else {
507            dot(&t, &s) / tt
508        };
509        x = axpy(omega, &s, &axpy(alpha_s, &p, &x));
510        r = axpy(-omega, &t, &s);
511        rho = rho_new;
512    }
513    (x, max_iter)
514}
515
516/// Conjugate Gradient with Jacobi (diagonal) preconditioner for `A x = b` (SPD).
517///
518/// Returns `(solution, iterations_used)`.
519pub fn jacobi_preconditioned_cg(
520    a: &CsrMatrix,
521    b: &[f64],
522    max_iter: usize,
523    tol: f64,
524) -> (Vec<f64>, usize) {
525    let n = b.len();
526    // Build inverse diagonal preconditioner M^{-1}
527    let mut m_inv = vec![1.0f64; n];
528    for r in 0..n {
529        let d = a.get(r, r);
530        if d.abs() > 1e-300 {
531            m_inv[r] = 1.0 / d;
532        }
533    }
534
535    let mut x = vec![0.0f64; n];
536    let ax = a.spmv(&x);
537    let mut r: Vec<f64> = (0..n).map(|i| b[i] - ax[i]).collect();
538    // z = M^{-1} r
539    let z: Vec<f64> = (0..n).map(|i| m_inv[i] * r[i]).collect();
540    let mut p = z.clone();
541    let mut rz_old = dot(&r, &z);
542
543    for iter in 0..max_iter {
544        if norm2(&r) < tol {
545            return (x, iter);
546        }
547        let ap = a.spmv(&p);
548        let alpha = rz_old / dot(&p, &ap);
549        x = axpy(alpha, &p, &x);
550        r = axpy(-alpha, &ap, &r);
551        let z_new: Vec<f64> = (0..n).map(|i| m_inv[i] * r[i]).collect();
552        let rz_new = dot(&r, &z_new);
553        let beta = rz_new / rz_old;
554        p = axpy(beta, &p, &z_new);
555        rz_old = rz_new;
556    }
557    (x, max_iter)
558}
559
560// ---------------------------------------------------------------------------
561// GPU simulation utilities
562// ---------------------------------------------------------------------------
563
/// Estimate SpMV throughput in GFLOPS given matrix dimensions and nnz count.
///
/// Uses a simple bandwidth-bound roofline model with 100 GB/s of memory bandwidth.
/// Each stored non-zero costs 24 bytes of traffic (8-byte column index, 8-byte
/// value, 8-byte gather from `x`) and 2 flops (one multiply, one add), i.e.
/// 12 bytes per flop. Row-pointer and output traffic (proportional to `n`) is
/// ignored, so `n` does not affect the result.
///
/// Returns `0.0` for an empty matrix (`nnz == 0`); the formula would otherwise give `0 / 0 = NaN`.
pub fn simulate_spmv_throughput(n: usize, nnz: usize) -> f64 {
    const BYTES_PER_NNZ: f64 = 24.0;
    const BANDWIDTH_GB_S: f64 = 100.0; // typical GPU HBM bandwidth
    let _ = n; // size kept for context; traffic is dominated by nnz
    if nnz == 0 {
        return 0.0;
    }
    // Work in f64 so that `nnz * 24` cannot overflow `usize`.
    let nnz_f = nnz as f64;
    let bytes_transferred = nnz_f * BYTES_PER_NNZ;
    let time_s = bytes_transferred / (BANDWIDTH_GB_S * 1e9);
    let flops = 2.0 * nnz_f;
    flops / time_s / 1e9 // GFLOPS
}
577
/// Pick an ELLPACK row width (slots per row) that limits padding waste.
///
/// Returns the element at index `len * 3 / 4` of the sorted per-row nnz
/// counts (a 75th-percentile choice), or `0` for an empty distribution.
pub fn optimal_ell_row_width(nnz_distribution: &[usize]) -> usize {
    if nnz_distribution.is_empty() {
        return 0;
    }
    let mut counts = nnz_distribution.to_vec();
    let rank = counts.len() * 3 / 4;
    // Partial selection yields the same element a full sort would put at `rank`.
    *counts.select_nth_unstable(rank).1
}
590
591// ---------------------------------------------------------------------------
592// SpMV – segmented (row-parallel) variant
593// ---------------------------------------------------------------------------
594
595/// Segmented SpMV: processes each row independently to prepare for
596/// GPU-style parallel execution. Functionally identical to `CsrMatrix::spmv`
597/// but structured for row-parallel dispatch.
598pub fn spmv_segmented(a: &CsrMatrix, x: &[f64]) -> Vec<f64> {
599    let mut y = vec![0.0_f64; a.n_rows];
600    for r in 0..a.n_rows {
601        let start = a.row_ptr[r];
602        let end = a.row_ptr[r + 1];
603        let mut acc = 0.0_f64;
604        for k in start..end {
605            acc += a.values[k] * x[a.col_idx[k]];
606        }
607        y[r] = acc;
608    }
609    y
610}
611
612// ---------------------------------------------------------------------------
613// Sparse matrix assembly helpers
614// ---------------------------------------------------------------------------
615
616/// Assemble a 1D Laplacian matrix of size `n × n` (tridiagonal: 2 on diag, -1 off-diag).
617pub fn assemble_1d_laplacian(n: usize) -> CsrMatrix {
618    let mut trip = SparseTriplet::new();
619    for i in 0..n {
620        trip.add(i, i, 2.0);
621        if i > 0 {
622            trip.add(i, i - 1, -1.0);
623        }
624        if i + 1 < n {
625            trip.add(i, i + 1, -1.0);
626        }
627    }
628    trip.to_csr(n, n)
629}
630
631// ---------------------------------------------------------------------------
632// CSR-to-ELL conversion (alternative entry point)
633// ---------------------------------------------------------------------------
634
635/// Convert a CSR matrix to ELLPACK format (convenience wrapper).
636pub fn csr_to_ell(csr: &CsrMatrix) -> EllMatrix {
637    EllMatrix::from_csr(csr)
638}
639
640// ---------------------------------------------------------------------------
641// Extract diagonal
642// ---------------------------------------------------------------------------
643
644/// Extract the main diagonal of a CSR matrix.
645pub fn extract_diagonal(a: &CsrMatrix) -> Vec<f64> {
646    let n = a.n_rows.min(a.n_cols);
647    let mut diag = vec![0.0_f64; n];
648    for r in 0..n {
649        diag[r] = a.get(r, r);
650    }
651    diag
652}
653
654// ---------------------------------------------------------------------------
655// Per-row nnz distribution
656// ---------------------------------------------------------------------------
657
658/// Compute the number of non-zeros per row for a CSR matrix.
659pub fn compute_nnz_per_row(a: &CsrMatrix) -> Vec<usize> {
660    (0..a.n_rows)
661        .map(|r| a.row_ptr[r + 1] - a.row_ptr[r])
662        .collect()
663}
664
665// ---------------------------------------------------------------------------
666// Frobenius norm
667// ---------------------------------------------------------------------------
668
669/// Compute the Frobenius norm of a sparse matrix: `sqrt(sum(a_ij^2))`.
670pub fn frobenius_norm(a: &CsrMatrix) -> f64 {
671    let sum_sq: f64 = a.values.iter().map(|v| v * v).sum();
672    sum_sq.sqrt()
673}
674
675// ---------------------------------------------------------------------------
676// Sparse triangular solves
677// ---------------------------------------------------------------------------
678
679/// Forward-substitution solve `L x = b` where `L` is lower-triangular (CSR).
680///
681/// Assumes `L` has non-zero diagonal entries. The diagonal entry for row `i`
682/// is taken as `L.get(i, i)`.
683pub fn sparse_lower_triangular_solve(l: &CsrMatrix, b: &[f64]) -> Vec<f64> {
684    let n = b.len();
685    let mut x = vec![0.0_f64; n];
686    for i in 0..n {
687        let mut sum = b[i];
688        let start = l.row_ptr[i];
689        let end = l.row_ptr[i + 1];
690        let mut diag = 1.0_f64;
691        for k in start..end {
692            let c = l.col_idx[k];
693            if c < i {
694                sum -= l.values[k] * x[c];
695            } else if c == i {
696                diag = l.values[k];
697            }
698        }
699        x[i] = sum / diag;
700    }
701    x
702}
703
704/// Back-substitution solve `U x = b` where `U` is upper-triangular (CSR).
705///
706/// Assumes `U` has non-zero diagonal entries.
707pub fn sparse_upper_triangular_solve(u: &CsrMatrix, b: &[f64]) -> Vec<f64> {
708    let n = b.len();
709    let mut x = vec![0.0_f64; n];
710    for i in (0..n).rev() {
711        let mut sum = b[i];
712        let start = u.row_ptr[i];
713        let end = u.row_ptr[i + 1];
714        let mut diag = 1.0_f64;
715        for k in start..end {
716            let c = u.col_idx[k];
717            if c > i {
718                sum -= u.values[k] * x[c];
719            } else if c == i {
720                diag = u.values[k];
721            }
722        }
723        x[i] = sum / diag;
724    }
725    x
726}
727
728// ---------------------------------------------------------------------------
729// Tests
730// ---------------------------------------------------------------------------
731
732#[cfg(test)]
733mod tests {
734    use super::*;
735
736    #[test]
737    fn test_csr_from_dense_nnz() {
738        let m = vec![
739            vec![1.0, 0.0, 2.0],
740            vec![0.0, 3.0, 0.0],
741            vec![4.0, 5.0, 6.0],
742        ];
743        let csr = CsrMatrix::from_dense(&m);
744        assert_eq!(csr.n_rows, 3);
745        assert_eq!(csr.n_cols, 3);
746        assert_eq!(csr.nnz(), 6);
747    }
748
749    #[test]
750    fn test_csr_spmv_identity() {
751        // 3x3 identity
752        let m = vec![
753            vec![1.0, 0.0, 0.0],
754            vec![0.0, 1.0, 0.0],
755            vec![0.0, 0.0, 1.0],
756        ];
757        let csr = CsrMatrix::from_dense(&m);
758        let x = vec![3.0, 7.0, -2.0];
759        let y = csr.spmv(&x);
760        assert_eq!(y, x);
761    }
762
763    #[test]
764    fn test_csr_spmv_known_3x3() {
765        // A = [[2,1,0],[1,3,1],[0,1,2]]
766        let m = vec![
767            vec![2.0, 1.0, 0.0],
768            vec![1.0, 3.0, 1.0],
769            vec![0.0, 1.0, 2.0],
770        ];
771        let csr = CsrMatrix::from_dense(&m);
772        let x = vec![1.0, 2.0, 3.0];
773        let y = csr.spmv(&x);
774        // y[0] = 2+2 = 4, y[1] = 1+6+3 = 10, y[2] = 2+6 = 8
775        assert!((y[0] - 4.0).abs() < 1e-12);
776        assert!((y[1] - 10.0).abs() < 1e-12);
777        assert!((y[2] - 8.0).abs() < 1e-12);
778    }
779
780    #[test]
781    fn test_cg_solve_diagonal_spd() {
782        // A = diag(1,2,3,4), b = [1,2,3,4], exact solution x = [1,1,1,1]
783        let m = vec![
784            vec![1.0, 0.0, 0.0, 0.0],
785            vec![0.0, 2.0, 0.0, 0.0],
786            vec![0.0, 0.0, 3.0, 0.0],
787            vec![0.0, 0.0, 0.0, 4.0],
788        ];
789        let a = CsrMatrix::from_dense(&m);
790        let b = vec![1.0, 2.0, 3.0, 4.0];
791        let x0 = vec![0.0; 4];
792        let (x, _iters) = cg_solve(&a, &b, &x0, 100, 1e-12);
793        for v in &x {
794            assert!((v - 1.0).abs() < 1e-10, "x value {v} not close to 1.0");
795        }
796    }
797
798    #[test]
799    fn test_sparse_triplet_to_csr_duplicate_sum() {
800        let mut trip = SparseTriplet::new();
801        trip.add(0, 0, 1.0);
802        trip.add(0, 0, 2.0); // duplicate → should sum to 3.0
803        trip.add(1, 1, 5.0);
804        let csr = trip.to_csr(2, 2);
805        assert!((csr.get(0, 0) - 3.0).abs() < 1e-12);
806        assert!((csr.get(1, 1) - 5.0).abs() < 1e-12);
807        assert_eq!(csr.nnz(), 2);
808    }
809
810    #[test]
811    fn test_ell_spmv_matches_csr() {
812        let m = vec![
813            vec![2.0, 1.0, 0.0],
814            vec![1.0, 3.0, 1.0],
815            vec![0.0, 1.0, 2.0],
816        ];
817        let csr = CsrMatrix::from_dense(&m);
818        let ell = EllMatrix::from_csr(&csr);
819        let x = vec![1.0, -1.0, 2.0];
820        let y_csr = csr.spmv(&x);
821        let y_ell = ell.spmv(&x);
822        for (a, b) in y_csr.iter().zip(y_ell.iter()) {
823            assert!((a - b).abs() < 1e-12, "ELL mismatch: {a} vs {b}");
824        }
825    }
826
827    #[test]
828    fn test_csr_transpose() {
829        // A = [[1,2,3],[4,5,6]]  →  A^T = [[1,4],[2,5],[3,6]]
830        let m = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
831        let csr = CsrMatrix::from_dense(&m);
832        let at = csr.transpose();
833        assert_eq!(at.n_rows, 3);
834        assert_eq!(at.n_cols, 2);
835        assert!((at.get(0, 0) - 1.0).abs() < 1e-12);
836        assert!((at.get(0, 1) - 4.0).abs() < 1e-12);
837        assert!((at.get(1, 0) - 2.0).abs() < 1e-12);
838        assert!((at.get(1, 1) - 5.0).abs() < 1e-12);
839        assert!((at.get(2, 0) - 3.0).abs() < 1e-12);
840        assert!((at.get(2, 1) - 6.0).abs() < 1e-12);
841    }
842
843    // ── spmv_segmented ─────────────────────────────────────────────────────
844
845    #[test]
846    fn test_spmv_segmented_identity() {
847        let m = vec![
848            vec![1.0, 0.0, 0.0],
849            vec![0.0, 1.0, 0.0],
850            vec![0.0, 0.0, 1.0],
851        ];
852        let csr = CsrMatrix::from_dense(&m);
853        let x = vec![3.0, 7.0, -2.0];
854        let y = spmv_segmented(&csr, &x);
855        for (a, b) in y.iter().zip(x.iter()) {
856            assert!((a - b).abs() < 1e-12);
857        }
858    }
859
860    #[test]
861    fn test_spmv_segmented_matches_csr() {
862        let m = vec![
863            vec![2.0, 1.0, 0.0],
864            vec![1.0, 3.0, 1.0],
865            vec![0.0, 1.0, 2.0],
866        ];
867        let csr = CsrMatrix::from_dense(&m);
868        let x = vec![1.0, -1.0, 2.0];
869        let y_std = csr.spmv(&x);
870        let y_seg = spmv_segmented(&csr, &x);
871        for (a, b) in y_std.iter().zip(y_seg.iter()) {
872            assert!((a - b).abs() < 1e-12);
873        }
874    }
875
876    // ── assemble_1d_laplacian ──────────────────────────────────────────────
877
878    #[test]
879    fn test_assemble_1d_laplacian_3x3() {
880        let l = assemble_1d_laplacian(3);
881        // Expected: [[2,-1,0],[-1,2,-1],[0,-1,2]]
882        assert!((l.get(0, 0) - 2.0).abs() < 1e-12);
883        assert!((l.get(0, 1) - (-1.0)).abs() < 1e-12);
884        assert!((l.get(0, 2)).abs() < 1e-12);
885        assert!((l.get(1, 0) - (-1.0)).abs() < 1e-12);
886        assert!((l.get(1, 1) - 2.0).abs() < 1e-12);
887        assert!((l.get(1, 2) - (-1.0)).abs() < 1e-12);
888        assert!((l.get(2, 2) - 2.0).abs() < 1e-12);
889    }
890
891    #[test]
892    fn test_assemble_1d_laplacian_spd() {
893        // 1D Laplacian is symmetric positive definite
894        let n = 5;
895        let l = assemble_1d_laplacian(n);
896        // Check symmetry
897        for i in 0..n {
898            for j in 0..n {
899                assert!((l.get(i, j) - l.get(j, i)).abs() < 1e-12);
900            }
901        }
902        // CG should converge for SPD systems
903        let b = vec![1.0; n];
904        let x0 = vec![0.0; n];
905        let (x, iters) = cg_solve(&l, &b, &x0, 200, 1e-10);
906        assert!(iters < 200);
907        // Verify Ax ≈ b
908        let ax = l.spmv(&x);
909        for i in 0..n {
910            assert!((ax[i] - b[i]).abs() < 1e-8);
911        }
912    }
913
914    // ── csr_to_ell ─────────────────────────────────────────────────────────
915
916    #[test]
917    fn test_csr_to_ell_spmv() {
918        let m = vec![
919            vec![5.0, 0.0, 1.0, 0.0],
920            vec![0.0, 3.0, 0.0, 2.0],
921            vec![1.0, 0.0, 4.0, 0.0],
922            vec![0.0, 0.0, 0.0, 6.0],
923        ];
924        let csr = CsrMatrix::from_dense(&m);
925        let ell = csr_to_ell(&csr);
926        let x = vec![1.0, 2.0, 3.0, 4.0];
927        let y_csr = csr.spmv(&x);
928        let y_ell = ell.spmv(&x);
929        for (a, b) in y_csr.iter().zip(y_ell.iter()) {
930            assert!((a - b).abs() < 1e-12, "mismatch: {a} vs {b}");
931        }
932    }
933
934    #[test]
935    fn test_csr_to_ell_max_nnz() {
936        let m = vec![
937            vec![1.0, 0.0, 0.0],
938            vec![1.0, 2.0, 3.0], // 3 nnz → max
939            vec![0.0, 0.0, 1.0],
940        ];
941        let csr = CsrMatrix::from_dense(&m);
942        let ell = csr_to_ell(&csr);
943        assert_eq!(ell.max_nnz_per_row, 3);
944    }
945
946    // ── BlockCsrMatrix additional tests ────────────────────────────────────
947
948    #[test]
949    fn test_block_csr_spmv_2x2() {
950        // Single 2x2 block [[1,2],[3,4]] at block position (0,0)
951        let bcsr = BlockCsrMatrix {
952            block_size: 2,
953            n_block_rows: 1,
954            n_block_cols: 1,
955            row_ptr: vec![0, 1],
956            col_idx: vec![0],
957            blocks: vec![vec![1.0, 2.0, 3.0, 4.0]],
958        };
959        let x = vec![1.0, 1.0];
960        let y = bcsr.spmv_block(&x);
961        assert!((y[0] - 3.0).abs() < 1e-12); // 1+2
962        assert!((y[1] - 7.0).abs() < 1e-12); // 3+4
963    }
964
965    #[test]
966    fn test_block_csr_to_csr_roundtrip() {
967        let bcsr = BlockCsrMatrix {
968            block_size: 2,
969            n_block_rows: 2,
970            n_block_cols: 2,
971            row_ptr: vec![0, 1, 2],
972            col_idx: vec![0, 1],
973            blocks: vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
974        };
975        let csr = bcsr.to_csr();
976        let x = vec![1.0, 1.0, 1.0, 1.0];
977        let y_block = bcsr.spmv_block(&x);
978        let y_csr = csr.spmv(&x);
979        for (a, b) in y_block.iter().zip(y_csr.iter()) {
980            assert!((a - b).abs() < 1e-12);
981        }
982    }
983
984    // ── sparse_lower_triangular_solve ──────────────────────────────────────
985
986    #[test]
987    fn test_lower_tri_solve_identity() {
988        let m = vec![
989            vec![1.0, 0.0, 0.0],
990            vec![0.0, 1.0, 0.0],
991            vec![0.0, 0.0, 1.0],
992        ];
993        let l = CsrMatrix::from_dense(&m);
994        let b = vec![3.0, 7.0, -2.0];
995        let x = sparse_lower_triangular_solve(&l, &b);
996        for (a, bv) in x.iter().zip(b.iter()) {
997            assert!((a - bv).abs() < 1e-12);
998        }
999    }
1000
1001    #[test]
1002    fn test_lower_tri_solve_3x3() {
1003        // L = [[2,0,0],[1,3,0],[4,2,5]]
1004        let m = vec![
1005            vec![2.0, 0.0, 0.0],
1006            vec![1.0, 3.0, 0.0],
1007            vec![4.0, 2.0, 5.0],
1008        ];
1009        let l = CsrMatrix::from_dense(&m);
1010        let b = vec![4.0, 7.0, 26.0];
1011        let x = sparse_lower_triangular_solve(&l, &b);
1012        // x[0] = 4/2 = 2
1013        // x[1] = (7 - 1*2)/3 = 5/3
1014        // x[2] = (26 - 4*2 - 2*(5/3))/5
1015        assert!((x[0] - 2.0).abs() < 1e-10);
1016        assert!((x[1] - 5.0 / 3.0).abs() < 1e-10);
1017        let expected_x2 = (26.0 - 8.0 - 10.0 / 3.0) / 5.0;
1018        assert!((x[2] - expected_x2).abs() < 1e-10);
1019    }
1020
1021    #[test]
1022    fn test_lower_tri_solve_verify_lx_eq_b() {
1023        let m = vec![
1024            vec![3.0, 0.0, 0.0, 0.0],
1025            vec![1.0, 2.0, 0.0, 0.0],
1026            vec![0.0, 4.0, 5.0, 0.0],
1027            vec![2.0, 0.0, 1.0, 6.0],
1028        ];
1029        let l = CsrMatrix::from_dense(&m);
1030        let b = vec![9.0, 8.0, 22.0, 29.0];
1031        let x = sparse_lower_triangular_solve(&l, &b);
1032        // Verify L*x = b
1033        let lx = l.spmv(&x);
1034        for i in 0..4 {
1035            assert!(
1036                (lx[i] - b[i]).abs() < 1e-10,
1037                "row {i}: {} vs {}",
1038                lx[i],
1039                b[i]
1040            );
1041        }
1042    }
1043
1044    // ── sparse_upper_triangular_solve ──────────────────────────────────────
1045
1046    #[test]
1047    fn test_upper_tri_solve_identity() {
1048        let m = vec![
1049            vec![1.0, 0.0, 0.0],
1050            vec![0.0, 1.0, 0.0],
1051            vec![0.0, 0.0, 1.0],
1052        ];
1053        let u = CsrMatrix::from_dense(&m);
1054        let b = vec![3.0, 7.0, -2.0];
1055        let x = sparse_upper_triangular_solve(&u, &b);
1056        for (a, bv) in x.iter().zip(b.iter()) {
1057            assert!((a - bv).abs() < 1e-12);
1058        }
1059    }
1060
1061    #[test]
1062    fn test_upper_tri_solve_verify_ux_eq_b() {
1063        // U = [[2,1,3],[0,4,2],[0,0,5]]
1064        let m = vec![
1065            vec![2.0, 1.0, 3.0],
1066            vec![0.0, 4.0, 2.0],
1067            vec![0.0, 0.0, 5.0],
1068        ];
1069        let u = CsrMatrix::from_dense(&m);
1070        let b = vec![13.0, 14.0, 10.0];
1071        let x = sparse_upper_triangular_solve(&u, &b);
1072        let ux = u.spmv(&x);
1073        for i in 0..3 {
1074            assert!(
1075                (ux[i] - b[i]).abs() < 1e-10,
1076                "row {i}: {} vs {}",
1077                ux[i],
1078                b[i]
1079            );
1080        }
1081    }
1082
1083    // ── CsrMatrix::add & scale ─────────────────────────────────────────────
1084
1085    #[test]
1086    fn test_csr_add() {
1087        let m1 = vec![vec![1.0, 0.0], vec![0.0, 2.0]];
1088        let m2 = vec![vec![0.0, 3.0], vec![4.0, 0.0]];
1089        let a = CsrMatrix::from_dense(&m1);
1090        let b = CsrMatrix::from_dense(&m2);
1091        let c = a.add(&b);
1092        assert!((c.get(0, 0) - 1.0).abs() < 1e-12);
1093        assert!((c.get(0, 1) - 3.0).abs() < 1e-12);
1094        assert!((c.get(1, 0) - 4.0).abs() < 1e-12);
1095        assert!((c.get(1, 1) - 2.0).abs() < 1e-12);
1096    }
1097
1098    #[test]
1099    fn test_csr_scale() {
1100        let m = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1101        let a = CsrMatrix::from_dense(&m);
1102        let b = a.scale(2.0);
1103        assert!((b.get(0, 0) - 2.0).abs() < 1e-12);
1104        assert!((b.get(0, 1) - 4.0).abs() < 1e-12);
1105        assert!((b.get(1, 0) - 6.0).abs() < 1e-12);
1106        assert!((b.get(1, 1) - 8.0).abs() < 1e-12);
1107    }
1108
1109    // ── BiCGSTAB ───────────────────────────────────────────────────────────
1110
1111    #[test]
1112    fn test_bicgstab_diagonal() {
1113        let m = vec![
1114            vec![2.0, 0.0, 0.0],
1115            vec![0.0, 3.0, 0.0],
1116            vec![0.0, 0.0, 4.0],
1117        ];
1118        let a = CsrMatrix::from_dense(&m);
1119        let b = vec![4.0, 9.0, 16.0];
1120        let x0 = vec![0.0; 3];
1121        let (x, _iters) = bicgstab_solve(&a, &b, &x0, 100, 1e-10);
1122        assert!((x[0] - 2.0).abs() < 1e-8);
1123        assert!((x[1] - 3.0).abs() < 1e-8);
1124        assert!((x[2] - 4.0).abs() < 1e-8);
1125    }
1126
1127    #[test]
1128    fn test_bicgstab_nonsymmetric() {
1129        // Non-symmetric but diagonally dominant
1130        let m = vec![
1131            vec![4.0, 1.0, 0.0],
1132            vec![0.0, 3.0, 1.0],
1133            vec![0.0, 0.0, 5.0],
1134        ];
1135        let a = CsrMatrix::from_dense(&m);
1136        let b = vec![5.0, 4.0, 5.0];
1137        let x0 = vec![0.0; 3];
1138        let (x, _iters) = bicgstab_solve(&a, &b, &x0, 200, 1e-10);
1139        // Verify Ax ≈ b
1140        let ax = a.spmv(&x);
1141        for i in 0..3 {
1142            assert!(
1143                (ax[i] - b[i]).abs() < 1e-6,
1144                "row {i}: {} vs {}",
1145                ax[i],
1146                b[i]
1147            );
1148        }
1149    }
1150
1151    // ── Jacobi preconditioned CG ───────────────────────────────────────────
1152
1153    #[test]
1154    fn test_jacobi_pcg_laplacian() {
1155        let n = 8;
1156        let a = assemble_1d_laplacian(n);
1157        let b: Vec<f64> = (0..n).map(|i| (i as f64 + 1.0).sin()).collect();
1158        let (x, iters) = jacobi_preconditioned_cg(&a, &b, 500, 1e-10);
1159        assert!(iters < 500, "PCG should converge, used {iters} iterations");
1160        let ax = a.spmv(&x);
1161        for i in 0..n {
1162            assert!(
1163                (ax[i] - b[i]).abs() < 1e-6,
1164                "row {i}: {} vs {}",
1165                ax[i],
1166                b[i]
1167            );
1168        }
1169    }
1170
1171    // ── GPU simulation utilities ───────────────────────────────────────────
1172
1173    #[test]
1174    fn test_simulate_spmv_throughput() {
1175        let gflops = simulate_spmv_throughput(100, 1000);
1176        assert!(gflops > 0.0);
1177        // With 100 GB/s bandwidth and 24 bytes/nnz, bigger nnz should give higher GFLOPS
1178        let gflops2 = simulate_spmv_throughput(1000, 10000);
1179        assert!((gflops2 - gflops).abs() < 1.0); // roofline is flat w.r.t. nnz ratio
1180    }
1181
1182    #[test]
1183    fn test_optimal_ell_row_width_empty() {
1184        assert_eq!(optimal_ell_row_width(&[]), 0);
1185    }
1186
1187    #[test]
1188    fn test_optimal_ell_row_width_uniform() {
1189        // All rows have 5 nnz
1190        let dist = vec![5; 10];
1191        assert_eq!(optimal_ell_row_width(&dist), 5);
1192    }
1193
1194    #[test]
1195    fn test_optimal_ell_row_width_varied() {
1196        let dist = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
1197        let w = optimal_ell_row_width(&dist);
1198        // 75th percentile index = (10 * 3) / 4 = 7 → sorted[7] = 8
1199        assert_eq!(w, 8);
1200    }
1201
    // ── SparseTriplet::default ─────────────────────────────────────────────
1203
1204    #[test]
1205    fn test_sparse_triplet_default() {
1206        let t = SparseTriplet::default();
1207        assert!(t.rows.is_empty());
1208        assert!(t.cols.is_empty());
1209        assert!(t.vals.is_empty());
1210    }
1211
1212    // ── vector operations ──────────────────────────────────────────────────
1213
1214    #[test]
1215    fn test_dot_product() {
1216        let x = vec![1.0, 2.0, 3.0];
1217        let y = vec![4.0, 5.0, 6.0];
1218        assert!((dot(&x, &y) - 32.0).abs() < 1e-12);
1219    }
1220
1221    #[test]
1222    fn test_axpy() {
1223        let x = vec![1.0, 2.0, 3.0];
1224        let y = vec![10.0, 20.0, 30.0];
1225        let z = axpy(2.0, &x, &y);
1226        assert_eq!(z, vec![12.0, 24.0, 36.0]);
1227    }
1228
1229    #[test]
1230    fn test_norm2() {
1231        let x = vec![3.0, 4.0];
1232        assert!((norm2(&x) - 5.0).abs() < 1e-12);
1233    }
1234
1235    #[test]
1236    fn test_scale_vec() {
1237        let x = vec![1.0, 2.0, 3.0];
1238        let s = scale_vec(&x, 3.0);
1239        assert_eq!(s, vec![3.0, 6.0, 9.0]);
1240    }
1241
1242    // ── HybridMatrix ──────────────────────────────────────────────────────
1243
1244    #[test]
1245    fn test_hybrid_spmv() {
1246        let m = vec![
1247            vec![1.0, 2.0, 0.0],
1248            vec![0.0, 3.0, 0.0],
1249            vec![0.0, 0.0, 4.0],
1250        ];
1251        let csr = CsrMatrix::from_dense(&m);
1252        let ell = EllMatrix::from_csr(&csr);
1253        let hybrid = HybridMatrix {
1254            ell,
1255            coo_row: vec![],
1256            coo_col: vec![],
1257            coo_val: vec![],
1258        };
1259        let x = vec![1.0, 1.0, 1.0];
1260        let y = hybrid.spmv(&x);
1261        assert!((y[0] - 3.0).abs() < 1e-12);
1262        assert!((y[1] - 3.0).abs() < 1e-12);
1263        assert!((y[2] - 4.0).abs() < 1e-12);
1264    }
1265
1266    #[test]
1267    fn test_hybrid_spmv_with_coo() {
1268        // ELL part: identity, COO part: adds off-diagonal
1269        let m = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1270        let csr = CsrMatrix::from_dense(&m);
1271        let ell = EllMatrix::from_csr(&csr);
1272        let hybrid = HybridMatrix {
1273            ell,
1274            coo_row: vec![0],
1275            coo_col: vec![1],
1276            coo_val: vec![5.0],
1277        };
1278        let x = vec![1.0, 2.0];
1279        let y = hybrid.spmv(&x);
1280        assert!((y[0] - 11.0).abs() < 1e-12); // 1*1 + 5*2
1281        assert!((y[1] - 2.0).abs() < 1e-12);
1282    }
1283
    // ── CsrMatrix::new (empty matrix) ──────────────────────────────────────
1285
1286    #[test]
1287    fn test_csr_empty() {
1288        let csr = CsrMatrix::new(3, 3);
1289        assert_eq!(csr.nnz(), 0);
1290        let y = csr.spmv(&[1.0, 2.0, 3.0]);
1291        assert_eq!(y, vec![0.0, 0.0, 0.0]);
1292    }
1293
1294    // ── extract_diagonal ──────────────────────────────────────────────────
1295
1296    #[test]
1297    fn test_extract_diagonal() {
1298        let m = vec![
1299            vec![5.0, 1.0, 0.0],
1300            vec![0.0, 3.0, 2.0],
1301            vec![0.0, 0.0, 7.0],
1302        ];
1303        let csr = CsrMatrix::from_dense(&m);
1304        let diag = extract_diagonal(&csr);
1305        assert!((diag[0] - 5.0).abs() < 1e-12);
1306        assert!((diag[1] - 3.0).abs() < 1e-12);
1307        assert!((diag[2] - 7.0).abs() < 1e-12);
1308    }
1309
1310    #[test]
1311    fn test_extract_diagonal_missing() {
1312        // Matrix with zero diagonal
1313        let m = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
1314        let csr = CsrMatrix::from_dense(&m);
1315        let diag = extract_diagonal(&csr);
1316        assert!((diag[0]).abs() < 1e-12);
1317        assert!((diag[1]).abs() < 1e-12);
1318    }
1319
1320    // ── compute_nnz_per_row ────────────────────────────────────────────────
1321
1322    #[test]
1323    fn test_compute_nnz_per_row() {
1324        let m = vec![
1325            vec![1.0, 2.0, 0.0],
1326            vec![0.0, 3.0, 0.0],
1327            vec![4.0, 5.0, 6.0],
1328        ];
1329        let csr = CsrMatrix::from_dense(&m);
1330        let nnz = compute_nnz_per_row(&csr);
1331        assert_eq!(nnz, vec![2, 1, 3]);
1332    }
1333
1334    // ── frobenius_norm ─────────────────────────────────────────────────────
1335
1336    #[test]
1337    fn test_frobenius_norm() {
1338        // Identity 3x3 → frobenius = sqrt(3)
1339        let m = vec![
1340            vec![1.0, 0.0, 0.0],
1341            vec![0.0, 1.0, 0.0],
1342            vec![0.0, 0.0, 1.0],
1343        ];
1344        let csr = CsrMatrix::from_dense(&m);
1345        let f = frobenius_norm(&csr);
1346        assert!((f - 3.0_f64.sqrt()).abs() < 1e-12);
1347    }
1348
1349    #[test]
1350    fn test_frobenius_norm_known() {
1351        let m = vec![vec![3.0, 4.0], vec![0.0, 0.0]];
1352        let csr = CsrMatrix::from_dense(&m);
1353        let f = frobenius_norm(&csr);
1354        // sqrt(9 + 16) = 5
1355        assert!((f - 5.0).abs() < 1e-12);
1356    }
1357}