1use crate::csr::CsrMatrix;
13use crate::error::{SparseError, SparseResult};
14use scirs2_core::numeric::{Float, NumAssign, SparseElement};
15use std::collections::VecDeque;
16use std::fmt::Debug;
17use std::iter::Sum;
18
/// Selects which matrix norm [`sparse_matrix_norm`] computes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseNorm {
    /// Maximum over columns of the sum of absolute values of the stored entries.
    One,
    /// Maximum over rows of the sum of absolute values of the stored entries.
    Inf,
    /// Square root of the sum of squares of all stored entries.
    Frobenius,
}
33
34pub fn sparse_matrix_norm<F>(a: &CsrMatrix<F>, norm_type: SparseNorm) -> SparseResult<F>
40where
41 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
42{
43 let (m, n_cols) = a.shape();
44 match norm_type {
45 SparseNorm::Inf => {
46 let mut max_row_sum = F::sparse_zero();
47 for i in 0..m {
48 let range = a.row_range(i);
49 let vals = &a.data[range];
50 let row_sum: F = vals.iter().map(|v| v.abs()).sum();
51 if row_sum > max_row_sum {
52 max_row_sum = row_sum;
53 }
54 }
55 Ok(max_row_sum)
56 }
57 SparseNorm::One => {
58 let mut col_sums = vec![F::sparse_zero(); n_cols];
59 for i in 0..m {
60 let range = a.row_range(i);
61 let indices = &a.indices[range.clone()];
62 let vals = &a.data[range];
63 for (idx, &col) in indices.iter().enumerate() {
64 col_sums[col] += vals[idx].abs();
65 }
66 }
67 let max_col =
68 col_sums
69 .iter()
70 .copied()
71 .fold(F::sparse_zero(), |acc, x| if x > acc { x } else { acc });
72 Ok(max_col)
73 }
74 SparseNorm::Frobenius => {
75 let mut sum_sq = F::sparse_zero();
76 for val in &a.data {
77 sum_sq += *val * *val;
78 }
79 Ok(sum_sq.sqrt())
80 }
81 }
82}
83
84pub fn spgemm<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
93where
94 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
95{
96 let (m, ka) = a.shape();
97 let (kb, n) = b.shape();
98 if ka != kb {
99 return Err(SparseError::ShapeMismatch {
100 expected: (m, ka),
101 found: (kb, n),
102 });
103 }
104
105 let mut values = vec![F::sparse_zero(); n];
108 let mut active = vec![false; n];
109 let mut rows_out = Vec::new();
110 let mut cols_out = Vec::new();
111 let mut data_out = Vec::new();
112
113 for i in 0..m {
114 let a_range = a.row_range(i);
115 let a_cols = &a.indices[a_range.clone()];
116 let a_vals = &a.data[a_range];
117
118 let mut col_list: Vec<usize> = Vec::new();
120 for (a_idx, &k_col) in a_cols.iter().enumerate() {
121 let a_ik = a_vals[a_idx];
122 let b_range = b.row_range(k_col);
123 let b_cols = &b.indices[b_range.clone()];
124 let b_vals = &b.data[b_range];
125
126 for (b_idx, &j) in b_cols.iter().enumerate() {
127 values[j] += a_ik * b_vals[b_idx];
128 if !active[j] {
129 active[j] = true;
130 col_list.push(j);
131 }
132 }
133 }
134
135 col_list.sort_unstable();
137 for &j in &col_list {
138 let val = values[j];
139 if val.abs() > F::epsilon() * F::from(0.01).unwrap_or(F::sparse_zero()) {
140 rows_out.push(i);
141 cols_out.push(j);
142 data_out.push(val);
143 }
144 values[j] = F::sparse_zero();
146 active[j] = false;
147 }
148 }
149
150 CsrMatrix::new(data_out, rows_out, cols_out, (m, n))
151}
152
153pub fn sparse_add<F>(
161 a: &CsrMatrix<F>,
162 b: &CsrMatrix<F>,
163 alpha: F,
164 beta: F,
165) -> SparseResult<CsrMatrix<F>>
166where
167 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
168{
169 let (ma, na) = a.shape();
170 let (mb, nb) = b.shape();
171 if ma != mb || na != nb {
172 return Err(SparseError::ShapeMismatch {
173 expected: (ma, na),
174 found: (mb, nb),
175 });
176 }
177
178 let mut rows_out = Vec::new();
179 let mut cols_out = Vec::new();
180 let mut data_out = Vec::new();
181
182 let mut b_vals = vec![F::sparse_zero(); na]; let mut b_flags = vec![false; na];
184
185 for i in 0..ma {
186 let b_range = b.row_range(i);
188 let b_cols = &b.indices[b_range.clone()];
189 let b_data = &b.data[b_range];
190 for (idx, &col) in b_cols.iter().enumerate() {
191 b_vals[col] = b_data[idx];
192 b_flags[col] = true;
193 }
194
195 let a_range = a.row_range(i);
197 let a_cols = &a.indices[a_range.clone()];
198 let a_data = &a.data[a_range];
199 let mut used_cols: Vec<usize> = Vec::new();
200
201 for (idx, &col) in a_cols.iter().enumerate() {
202 let val = alpha * a_data[idx]
203 + if b_flags[col] {
204 beta * b_vals[col]
205 } else {
206 F::sparse_zero()
207 };
208 if val.abs() > F::epsilon() {
209 rows_out.push(i);
210 cols_out.push(col);
211 data_out.push(val);
212 }
213 if b_flags[col] {
214 b_flags[col] = false;
215 b_vals[col] = F::sparse_zero();
216 }
217 used_cols.push(col);
218 }
219
220 for (idx, &col) in b_cols.iter().enumerate() {
222 if b_flags[col] {
223 let val = beta * b_data[idx];
224 if val.abs() > F::epsilon() {
225 rows_out.push(i);
226 cols_out.push(col);
227 data_out.push(val);
228 }
229 b_flags[col] = false;
230 b_vals[col] = F::sparse_zero();
231 }
232 }
233 }
234
235 CsrMatrix::new(data_out, rows_out, cols_out, (ma, na))
236}
237
238pub fn sparse_sub<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
240where
241 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
242{
243 sparse_add(a, b, F::sparse_one(), -F::sparse_one())
244}
245
246pub fn sparse_scale<F>(a: &CsrMatrix<F>, alpha: F) -> SparseResult<CsrMatrix<F>>
248where
249 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
250{
251 let (m, n) = a.shape();
252 let (rows_in, cols_in, data_in) = a.get_triplets();
253 let data_out: Vec<F> = data_in.iter().map(|&v| alpha * v).collect();
254 CsrMatrix::new(data_out, rows_in, cols_in, (m, n))
255}
256
257pub fn sparse_kronecker<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
265where
266 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
267{
268 let (ma, na) = a.shape();
269 let (mb, nb) = b.shape();
270 let out_rows = ma * mb;
271 let out_cols = na * nb;
272
273 let mut rows_out = Vec::new();
274 let mut cols_out = Vec::new();
275 let mut data_out = Vec::new();
276
277 for ia in 0..ma {
278 let a_range = a.row_range(ia);
279 let a_cols = &a.indices[a_range.clone()];
280 let a_vals = &a.data[a_range];
281
282 for ib in 0..mb {
283 let b_range = b.row_range(ib);
284 let b_cols = &b.indices[b_range.clone()];
285 let b_vals = &b.data[b_range];
286
287 let out_row = ia * mb + ib;
288
289 for (a_idx, &ja) in a_cols.iter().enumerate() {
290 for (b_idx, &jb) in b_cols.iter().enumerate() {
291 let out_col = ja * nb + jb;
292 let val = a_vals[a_idx] * b_vals[b_idx];
293 if val.abs() > F::epsilon() {
294 rows_out.push(out_row);
295 cols_out.push(out_col);
296 data_out.push(val);
297 }
298 }
299 }
300 }
301 }
302
303 CsrMatrix::new(data_out, rows_out, cols_out, (out_rows, out_cols))
304}
305
/// Output of [`reverse_cuthill_mckee`]: the reordering and its effect on bandwidth.
#[derive(Debug, Clone)]
pub struct RcmResult {
    /// New-to-old mapping: `permutation[new_index]` is the original node index.
    pub permutation: Vec<usize>,
    /// Old-to-new mapping: `inverse_permutation[old_index]` is the new position.
    pub inverse_permutation: Vec<usize>,
    /// Bandwidth (largest `|i - j|` over stored entries) of the input matrix.
    pub original_bandwidth: usize,
    /// Bandwidth of the matrix after applying the permutation to rows and columns.
    pub new_bandwidth: usize,
}
322
323fn compute_bandwidth<F>(a: &CsrMatrix<F>) -> usize
325where
326 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
327{
328 let m = a.rows();
329 let mut bw = 0usize;
330 for i in 0..m {
331 let range = a.row_range(i);
332 for &col in &a.indices[range] {
333 let diff = i.abs_diff(col);
334 if diff > bw {
335 bw = diff;
336 }
337 }
338 }
339 bw
340}
341
342fn node_degree<F>(a: &CsrMatrix<F>, i: usize) -> usize
344where
345 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
346{
347 let range = a.row_range(i);
348 a.indices[range].iter().filter(|&&col| col != i).count()
349}
350
351fn find_pseudo_peripheral<F>(a: &CsrMatrix<F>) -> usize
353where
354 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
355{
356 let n = a.rows();
357 if n == 0 {
358 return 0;
359 }
360
361 let mut start = 0;
363 let mut min_deg = usize::MAX;
364 for i in 0..n {
365 let deg = node_degree(a, i);
366 if deg < min_deg {
367 min_deg = deg;
368 start = i;
369 }
370 }
371
372 for _ in 0..5 {
374 let levels = bfs_levels(a, start);
375 let max_level = levels.iter().copied().max().unwrap_or(0);
376 if max_level == 0 {
377 break;
378 }
379 let mut best = start;
381 let mut best_deg = usize::MAX;
382 for i in 0..n {
383 if levels[i] == max_level {
384 let deg = node_degree(a, i);
385 if deg < best_deg {
386 best_deg = deg;
387 best = i;
388 }
389 }
390 }
391 if best == start {
392 break;
393 }
394 start = best;
395 }
396
397 start
398}
399
400fn bfs_levels<F>(a: &CsrMatrix<F>, start: usize) -> Vec<usize>
402where
403 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
404{
405 let n = a.rows();
406 let mut levels = vec![usize::MAX; n];
407 let mut queue = VecDeque::new();
408 levels[start] = 0;
409 queue.push_back(start);
410
411 while let Some(node) = queue.pop_front() {
412 let range = a.row_range(node);
413 for &neighbor in &a.indices[range] {
414 if levels[neighbor] == usize::MAX {
415 levels[neighbor] = levels[node] + 1;
416 queue.push_back(neighbor);
417 }
418 }
419 }
420
421 levels
422}
423
424pub fn reverse_cuthill_mckee<F>(a: &CsrMatrix<F>) -> SparseResult<RcmResult>
438where
439 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
440{
441 let (m, n) = a.shape();
442 if m != n {
443 return Err(SparseError::ValueError(
444 "RCM requires a square matrix".to_string(),
445 ));
446 }
447
448 let original_bandwidth = compute_bandwidth(a);
449
450 if m == 0 {
451 return Ok(RcmResult {
452 permutation: Vec::new(),
453 inverse_permutation: Vec::new(),
454 original_bandwidth: 0,
455 new_bandwidth: 0,
456 });
457 }
458
459 let mut visited = vec![false; m];
461 let mut cm_order = Vec::with_capacity(m);
462
463 while cm_order.len() < m {
465 let start = if cm_order.is_empty() {
467 find_pseudo_peripheral(a)
468 } else {
469 let mut s = 0;
471 for i in 0..m {
472 if !visited[i] {
473 s = i;
474 break;
475 }
476 }
477 s
478 };
479
480 if visited[start] {
481 break;
482 }
483
484 visited[start] = true;
485 cm_order.push(start);
486 let mut queue_start = cm_order.len() - 1;
487
488 while queue_start < cm_order.len() {
489 let node = cm_order[queue_start];
490 queue_start += 1;
491
492 let range = a.row_range(node);
494 let mut neighbors: Vec<usize> = a.indices[range]
495 .iter()
496 .copied()
497 .filter(|&nb| !visited[nb])
498 .collect();
499 neighbors.sort_by_key(|&nb| node_degree(a, nb));
500
501 for nb in neighbors {
502 if !visited[nb] {
503 visited[nb] = true;
504 cm_order.push(nb);
505 }
506 }
507 }
508 }
509
510 cm_order.reverse();
512
513 let mut inv_perm = vec![0usize; m];
515 for (new_idx, &old_idx) in cm_order.iter().enumerate() {
516 inv_perm[old_idx] = new_idx;
517 }
518
519 let mut new_bw = 0usize;
521 for i in 0..m {
522 let range = a.row_range(i);
523 let new_i = inv_perm[i];
524 for &col in &a.indices[range] {
525 let new_j = inv_perm[col];
526 let diff = new_i.abs_diff(new_j);
527 if diff > new_bw {
528 new_bw = diff;
529 }
530 }
531 }
532
533 Ok(RcmResult {
534 permutation: cm_order,
535 inverse_permutation: inv_perm,
536 original_bandwidth,
537 new_bandwidth: new_bw,
538 })
539}
540
541pub fn permute_matrix<F>(
546 a: &CsrMatrix<F>,
547 perm: &[usize],
548 inv_perm: &[usize],
549) -> SparseResult<CsrMatrix<F>>
550where
551 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
552{
553 let (m, n) = a.shape();
554 if perm.len() != m || inv_perm.len() != n {
555 return Err(SparseError::ValueError(
556 "Permutation size mismatch".to_string(),
557 ));
558 }
559
560 let mut rows_out = Vec::new();
561 let mut cols_out = Vec::new();
562 let mut data_out = Vec::new();
563
564 for new_i in 0..m {
565 let old_i = perm[new_i];
566 let range = a.row_range(old_i);
567 let old_cols = &a.indices[range.clone()];
568 let vals = &a.data[range];
569
570 for (idx, &old_j) in old_cols.iter().enumerate() {
571 let new_j = inv_perm[old_j];
572 rows_out.push(new_i);
573 cols_out.push(new_j);
574 data_out.push(vals[idx]);
575 }
576 }
577
578 CsrMatrix::new(data_out, rows_out, cols_out, (m, n))
579}
580
581pub fn condest_1norm<F>(a: &CsrMatrix<F>) -> SparseResult<Option<F>>
593where
594 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
595{
596 let (m, n) = a.shape();
597 if m != n || m == 0 {
598 return Err(SparseError::ValueError(
599 "condest requires a non-empty square matrix".to_string(),
600 ));
601 }
602
603 let a_norm = sparse_matrix_norm(a, SparseNorm::One)?;
604 if a_norm < F::epsilon() {
605 return Ok(None); }
607
608 let inv_n = F::sparse_one()
611 / F::from(n as f64)
612 .ok_or_else(|| SparseError::ValueError("Failed to convert n".to_string()))?;
613
614 let mut x = vec![inv_n; n];
615 let max_iter = 5;
616 let mut gamma = F::sparse_zero();
617
618 for _ in 0..max_iter {
619 let y = approximate_solve(a, &x)?;
621
622 let new_gamma: F = y.iter().map(|v| v.abs()).sum();
624 if new_gamma <= gamma {
625 break;
626 }
627 gamma = new_gamma;
628
629 let sign_y: Vec<F> = y
631 .iter()
632 .map(|&v| {
633 if v >= F::sparse_zero() {
634 F::sparse_one()
635 } else {
636 -F::sparse_one()
637 }
638 })
639 .collect();
640
641 let at = a.transpose();
642 let z = approximate_solve(&at, &sign_y)?;
643
644 let mut max_abs = F::sparse_zero();
646 let mut max_idx = 0;
647 for (j, &zj) in z.iter().enumerate() {
648 if zj.abs() > max_abs {
649 max_abs = zj.abs();
650 max_idx = j;
651 }
652 }
653
654 let ztx: F = z.iter().zip(x.iter()).map(|(&zi, &xi)| zi * xi).sum();
656 if max_abs <= ztx {
657 break;
658 }
659
660 for xi in x.iter_mut() {
662 *xi = F::sparse_zero();
663 }
664 x[max_idx] = F::sparse_one();
665 }
666
667 if gamma < F::epsilon() {
668 return Ok(None);
669 }
670
671 Ok(Some(a_norm * gamma))
672}
673
674fn approximate_solve<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
678where
679 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
680{
681 let n = b.len();
682 let (m, _) = a.shape();
683 if m != n {
684 return Err(SparseError::DimensionMismatch {
685 expected: m,
686 found: n,
687 });
688 }
689
690 let mut diag = vec![F::sparse_one(); n];
692 for i in 0..n {
693 let d = a.get(i, i);
694 if d.abs() > F::epsilon() {
695 diag[i] = d;
696 }
697 }
698
699 let mut x = vec![F::sparse_zero(); n];
701 for _ in 0..10 {
702 let mut x_new = vec![F::sparse_zero(); n];
703 for i in 0..n {
704 let range = a.row_range(i);
705 let cols = &a.indices[range.clone()];
706 let vals = &a.data[range];
707 let mut sum = b[i];
708 for (idx, &col) in cols.iter().enumerate() {
709 if col != i {
710 sum -= vals[idx] * x[col];
711 }
712 }
713 x_new[i] = sum / diag[i];
714 }
715 x = x_new;
716 }
717
718 Ok(x)
719}
720
721pub fn sparse_transpose<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
727where
728 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
729{
730 Ok(a.transpose())
731}
732
733pub fn sparse_extract_diagonal<F>(a: &CsrMatrix<F>) -> Vec<F>
739where
740 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
741{
742 let n = a.rows().min(a.cols());
743 let mut diag = vec![F::sparse_zero(); n];
744 for i in 0..n {
745 diag[i] = a.get(i, i);
746 }
747 diag
748}
749
750pub fn sparse_matrix_trace<F>(a: &CsrMatrix<F>) -> F
752where
753 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
754{
755 let diag = sparse_extract_diagonal(a);
756 diag.iter().copied().sum()
757}
758
759pub fn sparse_identity<F>(n: usize) -> SparseResult<CsrMatrix<F>>
765where
766 F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
767{
768 let rows: Vec<usize> = (0..n).collect();
769 let cols: Vec<usize> = (0..n).collect();
770 let data: Vec<F> = vec![F::sparse_one(); n];
771 CsrMatrix::new(data, rows, cols, (n, n))
772}
773
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 test matrix
    /// [[1, 2, 0],
    ///  [3, 4, 5],
    ///  [0, 6, 7]]
    fn build_test_matrix() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("valid matrix")
    }

    fn build_identity(n: usize) -> CsrMatrix<f64> {
        let rows: Vec<usize> = (0..n).collect();
        let cols: Vec<usize> = (0..n).collect();
        let data = vec![1.0; n];
        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid identity")
    }

    /// Tridiagonal [-1, 2, -1] matrix of size n.
    fn build_tridiag(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut data = Vec::new();
        for i in 0..n {
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                data.push(-1.0);
            }
            rows.push(i);
            cols.push(i);
            data.push(2.0);
            if i + 1 < n {
                rows.push(i);
                cols.push(i + 1);
                data.push(-1.0);
            }
        }
        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix")
    }

    /// Asserts that `perm` and `inv` are mutually inverse permutations of 0..n.
    fn assert_inverse_pair(perm: &[usize], inv: &[usize]) {
        assert_eq!(perm.len(), inv.len());
        for (new_idx, &old_idx) in perm.iter().enumerate() {
            assert_eq!(inv[old_idx], new_idx, "inverse mismatch at new index {new_idx}");
        }
    }

    #[test]
    fn test_frobenius_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::Frobenius).expect("frobenius norm");
        // 1 + 4 + 9 + 16 + 25 + 36 + 49 = 140
        let expected = (140.0_f64).sqrt();
        assert!(
            (nrm - expected).abs() < 1e-10,
            "Expected {expected}, got {nrm}"
        );
    }

    #[test]
    fn test_one_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::One).expect("1-norm");
        // Column sums: 4, 12, 12
        assert!((nrm - 12.0).abs() < 1e-10, "Expected 12.0, got {nrm}");
    }

    #[test]
    fn test_inf_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::Inf).expect("inf-norm");
        // Row sums: 3, 12, 13
        assert!((nrm - 13.0).abs() < 1e-10, "Expected 13.0, got {nrm}");
    }

    #[test]
    fn test_spgemm_identity() {
        let a = build_test_matrix();
        let eye = build_identity(3);
        let c = spgemm(&a, &eye).expect("spgemm A*I");
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (c.get(i, j) - a.get(i, j)).abs() < 1e-10,
                    "Mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_spgemm_square() {
        // [[1, 2], [3, 4]]^2 = [[7, 10], [15, 22]]
        let rows = vec![0, 0, 1, 1];
        let cols = vec![0, 1, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let a = CsrMatrix::new(data, rows, cols, (2, 2)).expect("valid matrix");
        let c = spgemm(&a, &a).expect("spgemm A*A");
        assert!((c.get(0, 0) - 7.0).abs() < 1e-10);
        assert!((c.get(0, 1) - 10.0).abs() < 1e-10);
        assert!((c.get(1, 0) - 15.0).abs() < 1e-10);
        assert!((c.get(1, 1) - 22.0).abs() < 1e-10);
    }

    #[test]
    fn test_spgemm_dimension_mismatch() {
        let a = CsrMatrix::new(vec![1.0], vec![0], vec![0], (1, 2)).expect("valid");
        let b = CsrMatrix::new(vec![1.0], vec![0], vec![0], (3, 1)).expect("valid");
        assert!(spgemm(&a, &b).is_err());
    }

    #[test]
    fn test_sparse_add() {
        let a = build_identity(3);
        let b = build_identity(3);
        let c = sparse_add(&a, &b, 2.0, 3.0).expect("sparse add");
        for i in 0..3 {
            // 2*1 + 3*1 = 5
            assert!((c.get(i, i) - 5.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sparse_sub() {
        let a = build_test_matrix();
        let c = sparse_sub(&a, &a).expect("sparse sub");
        for i in 0..3 {
            // A - A = 0 everywhere
            for j in 0..3 {
                assert!(c.get(i, j).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_sparse_scale() {
        let a = build_identity(3);
        let c = sparse_scale(&a, 5.0).expect("sparse scale");
        for i in 0..3 {
            assert!((c.get(i, i) - 5.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sparse_kronecker_identity() {
        let i2 = build_identity(2);
        let i3 = build_identity(3);
        let c = sparse_kronecker(&i2, &i3).expect("kronecker I2 x I3");
        let (m, n) = c.shape();
        // I2 ⊗ I3 = I6
        assert_eq!(m, 6);
        assert_eq!(n, 6);
        for i in 0..6 {
            for j in 0..6 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (c.get(i, j) - expected).abs() < 1e-10,
                    "Kronecker I2xI3 mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_sparse_kronecker_small() {
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
        let a = CsrMatrix::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            (2, 2),
        )
        .expect("valid");
        let b = CsrMatrix::new(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            (2, 2),
        )
        .expect("valid");
        let c = sparse_kronecker(&a, &b).expect("kronecker");
        assert!((c.get(0, 0) - 5.0).abs() < 1e-10);
        assert!((c.get(0, 2) - 10.0).abs() < 1e-10);
        assert!((c.get(3, 3) - 32.0).abs() < 1e-10);

        // Full check: C[(ia*2 + ib), (ja*2 + jb)] = A[ia][ja] * B[ib][jb].
        let a_dense = [[1.0, 2.0], [3.0, 4.0]];
        let b_dense = [[5.0, 6.0], [7.0, 8.0]];
        for ia in 0..2 {
            for ib in 0..2 {
                for ja in 0..2 {
                    for jb in 0..2 {
                        let expected = a_dense[ia][ja] * b_dense[ib][jb];
                        let (r, col) = (ia * 2 + ib, ja * 2 + jb);
                        assert!(
                            (c.get(r, col) - expected).abs() < 1e-10,
                            "Kronecker mismatch at ({r},{col})"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_rcm_tridiagonal() {
        let n = 10;
        let a = build_tridiag(n);
        let result = reverse_cuthill_mckee(&a).expect("rcm");
        assert_eq!(result.permutation.len(), n);
        assert_eq!(result.inverse_permutation.len(), n);
        assert_inverse_pair(&result.permutation, &result.inverse_permutation);
        // A path graph is already optimally banded, and RCM must keep bandwidth 1.
        assert_eq!(result.original_bandwidth, 1);
        assert_eq!(result.new_bandwidth, 1);
    }

    #[test]
    fn test_rcm_sparse_matrix() {
        let n = 8;
        // Symmetric pattern with entries at distance 1 and 3 from the diagonal.
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut data = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            data.push(4.0);
            if i + 1 < n {
                rows.push(i);
                cols.push(i + 1);
                data.push(-1.0);
                rows.push(i + 1);
                cols.push(i);
                data.push(-1.0);
            }
            if i + 3 < n {
                rows.push(i);
                cols.push(i + 3);
                data.push(-0.5);
                rows.push(i + 3);
                cols.push(i);
                data.push(-0.5);
            }
        }
        let a = CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix");
        let result = reverse_cuthill_mckee(&a).expect("rcm");
        // The result must be a permutation of 0..n.
        let mut sorted_perm = result.permutation.clone();
        sorted_perm.sort();
        let expected: Vec<usize> = (0..n).collect();
        assert_eq!(sorted_perm, expected);
        assert_inverse_pair(&result.permutation, &result.inverse_permutation);
    }

    #[test]
    fn test_rcm_identity() {
        let eye = build_identity(5);
        let result = reverse_cuthill_mckee(&eye).expect("rcm identity");
        assert_eq!(result.original_bandwidth, 0);
        assert_eq!(result.new_bandwidth, 0);
        assert_inverse_pair(&result.permutation, &result.inverse_permutation);
    }

    #[test]
    fn test_rcm_error_non_square() {
        let a = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 3)).expect("valid");
        assert!(reverse_cuthill_mckee(&a).is_err());
    }

    #[test]
    fn test_permute_matrix() {
        // A = [[1, 2], [3, 4]]; swapping rows and columns gives [[4, 3], [2, 1]].
        let a = CsrMatrix::new(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            (2, 2),
        )
        .expect("valid");
        let perm = vec![1, 0];
        let inv_perm = vec![1, 0];
        let b = permute_matrix(&a, &perm, &inv_perm).expect("permute");
        assert!((b.get(0, 0) - 4.0).abs() < 1e-10);
        assert!((b.get(0, 1) - 3.0).abs() < 1e-10);
        assert!((b.get(1, 0) - 2.0).abs() < 1e-10);
        assert!((b.get(1, 1) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_condest_identity() {
        let eye = build_identity(5);
        // The identity is perfectly conditioned, so an estimate must be produced.
        let c = condest_1norm(&eye)
            .expect("condest")
            .expect("identity should yield an estimate");
        assert!((c - 1.0).abs() < 1e-10, "Expected cond(I) = 1, got {c}");
    }

    #[test]
    fn test_condest_diagonal() {
        // diag(1, 100): cond_1 = 100.
        let a = CsrMatrix::new(vec![1.0, 100.0], vec![0, 1], vec![0, 1], (2, 2)).expect("valid");
        let c = condest_1norm(&a)
            .expect("condest")
            .expect("diagonal matrix should yield an estimate");
        assert!((c - 100.0).abs() < 1e-8, "Expected cond = 100, got {c}");
    }

    #[test]
    fn test_condest_error_non_square() {
        let a = CsrMatrix::new(vec![1.0], vec![0], vec![0], (1, 2)).expect("valid");
        assert!(condest_1norm(&a).is_err());
    }

    #[test]
    fn test_sparse_extract_diagonal() {
        let a = build_test_matrix();
        let diag = sparse_extract_diagonal(&a);
        assert_eq!(diag.len(), 3);
        assert!((diag[0] - 1.0).abs() < 1e-10);
        assert!((diag[1] - 4.0).abs() < 1e-10);
        assert!((diag[2] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_matrix_trace() {
        let a = build_test_matrix();
        let tr = sparse_matrix_trace(&a);
        // 1 + 4 + 7 = 12
        assert!((tr - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_identity() {
        let eye: CsrMatrix<f64> = sparse_identity(4).expect("sparse identity");
        for i in 0..4 {
            for j in 0..4 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye.get(i, j) - expected).abs() < 1e-10,
                    "Identity mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_sparse_transpose() {
        let a = build_test_matrix();
        let at = sparse_transpose(&a).expect("transpose");
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (at.get(i, j) - a.get(j, i)).abs() < 1e-10,
                    "Transpose mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_compute_bandwidth() {
        let a = build_tridiag(5);
        assert_eq!(compute_bandwidth(&a), 1);
    }
}