1#![allow(clippy::needless_range_loop, clippy::ptr_arg)]
2use rand::RngExt;
33
/// Soft-thresholding (shrinkage) operator, the proximal map of `lambda * |x|`.
///
/// Inputs inside the closed band `[-lambda, lambda]` map to `0.0`; anything
/// outside the band is moved towards zero by exactly `lambda`. A NaN input
/// falls through every comparison and also yields `0.0`.
#[allow(dead_code)]
pub fn soft_threshold(x: f64, lambda: f64) -> f64 {
    match x {
        above if above > lambda => above - lambda,
        below if below < -lambda => below + lambda,
        _ => 0.0,
    }
}
51
/// Nyquist rate of a band-limited signal: the minimum sampling rate, equal to
/// twice the bandwidth (same units as `bandwidth`, e.g. Hz).
#[allow(dead_code)]
pub fn nyquist_rate(bandwidth: f64) -> f64 {
    bandwidth * 2.0
}
59
/// Fraction of measurements relative to signal length, `m / n`.
///
/// Returns `0.0` when the signal length `n` is zero instead of dividing by zero.
#[allow(dead_code)]
pub fn compression_ratio(n: usize, m: usize) -> f64 {
    match n {
        0 => 0.0,
        len => m as f64 / len as f64,
    }
}
70
/// Euclidean (ℓ²) norm of a slice: `sqrt(Σ xᵢ²)`.
#[allow(dead_code)]
pub fn l2_norm(x: &[f64]) -> f64 {
    let sum_of_squares: f64 = x.iter().map(|v| v * v).sum();
    sum_of_squares.sqrt()
}
78
/// Rescales `x` in place to unit ℓ² norm.
///
/// Vectors whose norm is not greater than `1e-14` (including all-zero
/// vectors) are left untouched rather than divided into NaN/inf.
#[allow(dead_code)]
pub fn normalise(x: &mut Vec<f64>) {
    let length = x.iter().map(|v| v * v).sum::<f64>().sqrt();
    let needs_scaling = length > 1e-14;
    if needs_scaling {
        x.iter_mut().for_each(|v| *v /= length);
    }
}
91
/// Dense matrix–vector product `A x` for a row-major matrix.
///
/// Each row is paired element-wise with `x`, so when a row and `x` differ in
/// length only their common prefix contributes to that output entry.
#[allow(dead_code)]
pub fn mat_vec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(a.len());
    for row in a {
        let dot: f64 = row.iter().zip(x.iter()).map(|(r, v)| r * v).sum();
        product.push(dot);
    }
    product
}
101
/// Computes `Aᵀ x` for a row-major matrix `A` with `m = a.len()` rows.
///
/// The output length is the length of the first row. Panics if a row is
/// shorter than the first row, or if `x` has fewer than `m` entries while the
/// output is non-empty. Returns an empty vector for an empty matrix.
#[allow(dead_code)]
pub fn mat_transpose_vec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let Some(first_row) = a.first() else {
        return Vec::new();
    };
    let mut accum = vec![0.0_f64; first_row.len()];
    for (i, row) in a.iter().enumerate() {
        for (j, slot) in accum.iter_mut().enumerate() {
            *slot += row[j] * x[i];
        }
    }
    accum
}
120
121#[allow(dead_code)]
125pub fn spectral_norm(a: &[Vec<f64>], max_iter: usize) -> f64 {
126 if a.is_empty() {
127 return 0.0;
128 }
129 let n = a[0].len();
130 let mut v = vec![1.0_f64; n];
131 normalise(&mut v);
132 for _ in 0..max_iter {
133 let av = mat_vec(a, &v);
134 let mut atav = mat_transpose_vec(a, &av);
135 normalise(&mut atav);
136 v = atav;
137 }
138 let av = mat_vec(a, &v);
139 l2_norm(&av)
140}
141
/// Orthonormal DCT-II basis of a fixed length.
///
/// `transform` computes DCT-II coefficients and `inverse` the matching DCT-III
/// synthesis. Both use orthonormal scaling, `√(1/n)` for the DC term and
/// `√(2/n)` otherwise, so a round trip reproduces the input up to rounding.
#[allow(dead_code)]
pub struct DctBasis {
    /// Transform length. Longer inputs are truncated to this many samples;
    /// shorter inputs shrink the effective length to the input length.
    pub n: usize,
}

#[allow(dead_code)]
impl DctBasis {
    /// Creates a basis of length `n`.
    pub fn new(n: usize) -> Self {
        DctBasis { n }
    }

    /// Forward orthonormal DCT-II of the first `min(self.n, x.len())` samples.
    pub fn transform(&self, x: &[f64]) -> Vec<f64> {
        let len = self.n.min(x.len());
        let step = std::f64::consts::PI / len as f64;
        // Scale factors are loop-invariant, so compute them once.
        let dc_scale = (1.0 / len as f64).sqrt();
        let ac_scale = (2.0 / len as f64).sqrt();
        (0..len)
            .map(|k| {
                let acc = (0..len).fold(0.0, |acc, j| {
                    acc + x[j] * ((j as f64 + 0.5) * k as f64 * step).cos()
                });
                let scale = if k == 0 { dc_scale } else { ac_scale };
                acc * scale
            })
            .collect()
    }

    /// Inverse transform (orthonormal DCT-III) of the first
    /// `min(self.n, coeffs.len())` coefficients.
    pub fn inverse(&self, coeffs: &[f64]) -> Vec<f64> {
        let len = self.n.min(coeffs.len());
        let step = std::f64::consts::PI / len as f64;
        let dc_scale = (1.0 / len as f64).sqrt();
        let ac_scale = (2.0 / len as f64).sqrt();
        (0..len)
            .map(|j| {
                (1..len).fold(dc_scale * coeffs[0], |acc, k| {
                    acc + ac_scale * coeffs[k] * ((j as f64 + 0.5) * k as f64 * step).cos()
                })
            })
            .collect()
    }

    /// Keeps the `k` largest-magnitude coefficients at their original indices
    /// and zeros the rest. Ties keep the lower index first (stable sort).
    pub fn truncate(&self, coeffs: &[f64], k: usize) -> Vec<f64> {
        let mut order: Vec<usize> = (0..coeffs.len()).collect();
        order.sort_by(|&p, &q| {
            coeffs[q]
                .abs()
                .partial_cmp(&coeffs[p].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut kept = vec![0.0_f64; coeffs.len()];
        for &idx in order.iter().take(k) {
            kept[idx] = coeffs[idx];
        }
        kept
    }
}
221
222#[allow(dead_code)]
231pub struct RandomMeasurementMatrix {
232 pub m: usize,
234 pub n: usize,
236 pub matrix: Vec<Vec<f64>>,
238}
239
240#[allow(dead_code)]
241impl RandomMeasurementMatrix {
242 pub fn generate_gaussian(m: usize, n: usize) -> Self {
247 use rand::RngExt as _;
248 let mut rng = rand::rng();
249 let scale = 1.0 / (m as f64).sqrt();
250 let matrix: Vec<Vec<f64>> = (0..m)
251 .map(|_| {
252 (0..n)
253 .map(|_| {
254 let u1: f64 = rng.random_range(1e-12_f64..1.0_f64);
256 let u2: f64 = rng.random_range(0.0_f64..1.0_f64);
257 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
258 z * scale
259 })
260 .collect()
261 })
262 .collect();
263 Self { m, n, matrix }
264 }
265
266 pub fn generate_bernoulli(m: usize, n: usize) -> Self {
270 let mut rng = rand::rng();
271 let scale = 1.0 / (m as f64).sqrt();
272 let matrix: Vec<Vec<f64>> = (0..m)
273 .map(|_| {
274 (0..n)
275 .map(|_| {
276 if rng.random_range(0.0_f64..1.0_f64) < 0.5 {
277 scale
278 } else {
279 -scale
280 }
281 })
282 .collect()
283 })
284 .collect();
285 Self { m, n, matrix }
286 }
287
288 pub fn measure(&self, x: &[f64]) -> Vec<f64> {
292 mat_vec(&self.matrix, x)
293 }
294
295 pub fn coherence(&self) -> f64 {
299 SparsityMetrics::coherence(&self.matrix)
300 }
301}
302
303#[allow(dead_code)]
314pub struct BasisPursuit;
315
316#[allow(dead_code)]
317impl BasisPursuit {
318 fn lipschitz(a: &[Vec<f64>]) -> f64 {
320 spectral_norm(a, 20).powi(2).max(1e-10)
321 }
322
323 pub fn solve_lasso(a: &[Vec<f64>], b: &[f64], lambda: f64, max_iter: usize) -> Vec<f64> {
332 if a.is_empty() || b.is_empty() {
333 return Vec::new();
334 }
335 let n = a[0].len();
336 let l = Self::lipschitz(a);
337 let step = 1.0 / l;
338 let mut x = vec![0.0_f64; n];
339
340 for _ in 0..max_iter {
341 let residual: Vec<f64> = mat_vec(a, &x)
342 .iter()
343 .zip(b.iter())
344 .map(|(r, bi)| r - bi)
345 .collect();
346 let grad = mat_transpose_vec(a, &residual);
347 x = x
348 .iter()
349 .zip(grad.iter())
350 .map(|(xi, gi)| soft_threshold(xi - step * gi, step * lambda))
351 .collect();
352 }
353 x
354 }
355
356 pub fn solve_fista(a: &[Vec<f64>], b: &[f64], lambda: f64, max_iter: usize) -> Vec<f64> {
365 if a.is_empty() || b.is_empty() {
366 return Vec::new();
367 }
368 let n = a[0].len();
369 let l = Self::lipschitz(a);
370 let step = 1.0 / l;
371
372 let mut x = vec![0.0_f64; n];
373 let mut y = x.clone();
374 let mut t = 1.0_f64;
375
376 for _ in 0..max_iter {
377 let x_prev = x.clone();
378
379 let residual: Vec<f64> = mat_vec(a, &y)
381 .iter()
382 .zip(b.iter())
383 .map(|(r, bi)| r - bi)
384 .collect();
385 let grad = mat_transpose_vec(a, &residual);
386 x = y
387 .iter()
388 .zip(grad.iter())
389 .map(|(yi, gi)| soft_threshold(yi - step * gi, step * lambda))
390 .collect();
391
392 let t_new = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
394 let momentum = (t - 1.0) / t_new;
395 y = x
396 .iter()
397 .zip(x_prev.iter())
398 .map(|(xi, xi_prev)| xi + momentum * (xi - xi_prev))
399 .collect();
400 t = t_new;
401 }
402 x
403 }
404
405 pub fn objective(a: &[Vec<f64>], b: &[f64], x: &[f64], lambda: f64) -> f64 {
409 if a.is_empty() || b.is_empty() {
410 return 0.0;
411 }
412 let ax = mat_vec(a, x);
413 let residual_sq: f64 = ax
414 .iter()
415 .zip(b.iter())
416 .map(|(r, bi)| (r - bi).powi(2))
417 .sum();
418 let l1: f64 = x.iter().map(|xi| xi.abs()).sum();
419 0.5 * residual_sq + lambda * l1
420 }
421}
422
/// Orthogonal Matching Pursuit (OMP), a greedy solver for sparse `A x ≈ b`.
///
/// `A` is row-major (`m × n`). Each iteration selects the column most
/// correlated with the current residual, then re-fits every selected
/// coefficient by least squares on the support.
#[allow(dead_code)]
pub struct OrthogonalMatchingPursuit {
    /// Maximum number of columns (non-zeros) to select; the solvers further
    /// cap this at `min(n, m)`.
    pub max_k: usize,
}

#[allow(dead_code)]
impl OrthogonalMatchingPursuit {
    /// Creates a solver that selects at most `max_k` columns.
    pub fn new(max_k: usize) -> Self {
        Self { max_k }
    }

    /// Index of the unselected column `j` maximising `|⟨A[:, j], residual⟩|`,
    /// or `None` when no unselected column has a strictly positive correlation.
    ///
    /// Ties go to the lowest index. Returning `None` instead of a default index
    /// is what keeps a column that is already in the support from being
    /// selected again once the residual is orthogonal to all remaining columns.
    fn most_correlated(a: &[Vec<f64>], residual: &[f64], selected: &[bool]) -> Option<usize> {
        let mut best: Option<(usize, f64)> = None;
        for (j, &taken) in selected.iter().enumerate() {
            if taken {
                continue;
            }
            let corr = a
                .iter()
                .zip(residual.iter())
                .map(|(row, &r)| row[j] * r)
                .sum::<f64>()
                .abs();
            let to_beat = best.map_or(0.0, |(_, c)| c);
            if corr > to_beat {
                best = Some((j, corr));
            }
        }
        best.map(|(j, _)| j)
    }

    /// Greedily recovers a sparse `x` (length `n`) with `A x ≈ b`.
    ///
    /// Runs at most `min(max_k, n, m)` iterations and stops early when the
    /// residual norm drops below `1e-12` or no remaining column correlates with
    /// the residual. `b` is expected to have `m = a.len()` entries. Returns an
    /// empty vector if `a` or `b` is empty.
    pub fn solve(&self, a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let m = a.len();
        let n = a[0].len();
        let k = self.max_k.min(n).min(m);

        let mut residual = b.to_vec();
        // O(1) membership test instead of scanning `support` for every column.
        let mut selected = vec![false; n];
        let mut support: Vec<usize> = Vec::with_capacity(k);
        let mut x = vec![0.0_f64; n];

        for _ in 0..k {
            let Some(best_idx) = Self::most_correlated(a, &residual, &selected) else {
                break;
            };
            selected[best_idx] = true;
            support.push(best_idx);

            // Least-squares re-fit on the support via the normal equations
            // (A_Sᵀ A_S) c = A_Sᵀ b.
            let s = support.len();
            let mut ata = vec![vec![0.0_f64; s]; s];
            let mut atb = vec![0.0_f64; s];
            for (si, &ci) in support.iter().enumerate() {
                for (sj, &cj) in support.iter().enumerate() {
                    ata[si][sj] = (0..m).map(|i| a[i][ci] * a[i][cj]).sum();
                }
                atb[si] = (0..m).map(|i| a[i][ci] * b[i]).sum();
            }
            let coeffs = gauss_solve(&ata, &atb);

            // Support indices are distinct, so each coefficient lands in its own slot.
            x.fill(0.0);
            for (&ci, &c) in support.iter().zip(coeffs.iter()) {
                x[ci] = c;
            }

            // residual = b − A_S c; off-support entries of x are zero, so only
            // support columns contribute.
            residual = (0..m)
                .map(|i| {
                    let ax_i: f64 = support
                        .iter()
                        .zip(coeffs.iter())
                        .map(|(&ci, &c)| a[i][ci] * c)
                        .sum();
                    b[i] - ax_i
                })
                .collect();

            let res_norm: f64 = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
            if res_norm < 1e-12 {
                break;
            }
        }
        x
    }

    /// Returns the column indices chosen by a greedy pursuit, in selection order.
    ///
    /// Unlike `solve`, the residual is deflated only by projecting it onto the
    /// newly selected column (not onto the span of the whole support). With
    /// non-orthogonal columns the result can therefore differ from the support
    /// `solve` builds. Stops early when no remaining column correlates with the
    /// residual or the selected column has squared norm below `1e-14`. Every
    /// index appears at most once.
    pub fn support(&self, a: &[Vec<f64>], b: &[f64]) -> Vec<usize> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let m = a.len();
        let n = a[0].len();
        let k = self.max_k.min(n).min(m);
        let mut residual = b.to_vec();
        let mut selected = vec![false; n];
        let mut support: Vec<usize> = Vec::with_capacity(k);

        for _ in 0..k {
            let Some(best_idx) = Self::most_correlated(a, &residual, &selected) else {
                break;
            };
            selected[best_idx] = true;
            support.push(best_idx);

            let col_norm_sq: f64 = (0..m).map(|i| a[i][best_idx].powi(2)).sum();
            if col_norm_sq < 1e-14 {
                break;
            }
            let proj: f64 =
                (0..m).map(|i| a[i][best_idx] * residual[i]).sum::<f64>() / col_norm_sq;
            for i in 0..m {
                residual[i] -= proj * a[i][best_idx];
            }
        }
        support
    }
}

/// Solves the square system `a · x = b` by Gaussian elimination with partial
/// pivoting.
///
/// A pivot with magnitude below `1e-14` is skipped during elimination, and the
/// matching unknown is set to `0.0` during back-substitution. Rank-deficient
/// systems therefore produce finite values instead of NaN/inf. Returns an
/// empty vector when `b` is empty.
fn gauss_solve(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    if n == 0 {
        return Vec::new();
    }
    let mut mat: Vec<Vec<f64>> = a.to_vec();
    let mut rhs: Vec<f64> = b.to_vec();

    for col in 0..n {
        // Partial pivoting: bring the largest-magnitude entry of this column up.
        let pivot = (col..n).max_by(|&i, &j| {
            mat[i][col]
                .abs()
                .partial_cmp(&mat[j][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        if let Some(p) = pivot {
            mat.swap(col, p);
            rhs.swap(col, p);
        }
        let diag = mat[col][col];
        if diag.abs() < 1e-14 {
            continue;
        }
        for row in (col + 1)..n {
            let factor = mat[row][col] / diag;
            for k in col..n {
                let v = mat[col][k];
                mat[row][k] -= factor * v;
            }
            rhs[row] -= factor * rhs[col];
        }
    }

    // Back-substitution on the upper-triangular system.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut s = rhs[i];
        for j in (i + 1)..n {
            s -= mat[i][j] * x[j];
        }
        let d = mat[i][i];
        x[i] = if d.abs() < 1e-14 { 0.0 } else { s / d };
    }
    x
}
606
/// Sparsity measures for vectors and incoherence measures for (row-major)
/// matrices.
#[allow(dead_code)]
pub struct SparsityMetrics;

#[allow(dead_code)]
impl SparsityMetrics {
    /// Counts entries whose magnitude is strictly greater than `threshold`.
    pub fn l0_norm(x: &[f64], threshold: f64) -> usize {
        x.iter()
            .fold(0, |count, v| if v.abs() > threshold { count + 1 } else { count })
    }

    /// Sum of absolute values.
    pub fn l1_norm(x: &[f64]) -> f64 {
        x.iter().map(|v| v.abs()).sum()
    }

    /// Euclidean norm, identical to the free function `l2_norm`.
    pub fn l2_norm(x: &[f64]) -> f64 {
        let energy: f64 = x.iter().map(|v| v * v).sum();
        energy.sqrt()
    }

    /// Gini sparsity index of the magnitudes, clamped to `[0, 1]`.
    ///
    /// `0` means all magnitudes are equal, and larger values mean sparser.
    /// An empty slice gives `0.0`; a slice whose magnitudes sum below `1e-14`
    /// is treated as maximally sparse and gives `1.0`.
    pub fn gini(x: &[f64]) -> f64 {
        if x.is_empty() {
            return 0.0;
        }
        let len = x.len() as f64;
        let mut mags: Vec<f64> = x.iter().map(|v| v.abs()).collect();
        mags.sort_by(|p, q| p.partial_cmp(q).unwrap_or(std::cmp::Ordering::Equal));
        let total: f64 = mags.iter().sum();
        if total < 1e-14 {
            return 1.0;
        }
        // Ascending magnitudes weighted by their 1-based rank.
        let rank_weighted: f64 = mags
            .iter()
            .zip(1_usize..)
            .map(|(v, rank)| rank as f64 * v)
            .sum();
        ((2.0 * rank_weighted) / (len * total) - (len + 1.0) / len).clamp(0.0, 1.0)
    }

    /// Copies the first `n` columns of row-major `a` into column vectors and
    /// returns them together with their ℓ² norms.
    fn columns_with_norms(a: &[Vec<f64>], n: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
        let cols: Vec<Vec<f64>> = (0..n)
            .map(|j| a.iter().map(|row| row[j]).collect())
            .collect();
        let norms: Vec<f64> = cols
            .iter()
            .map(|c| c.iter().map(|x| x * x).sum::<f64>().sqrt())
            .collect();
        (cols, norms)
    }

    /// Mutual coherence: the largest `|⟨aᵢ, aⱼ⟩| / (‖aᵢ‖‖aⱼ‖)` over distinct
    /// column pairs.
    ///
    /// Columns with norm below `1e-14` are ignored. Returns `0.0` for an empty
    /// matrix or one with fewer than two usable columns.
    pub fn coherence(a: &[Vec<f64>]) -> f64 {
        let Some(first_row) = a.first() else {
            return 0.0;
        };
        let (cols, norms) = Self::columns_with_norms(a, first_row.len());
        let mut worst = 0.0_f64;
        for i in 0..cols.len() {
            if norms[i] < 1e-14 {
                continue;
            }
            for j in (i + 1)..cols.len() {
                if norms[j] < 1e-14 {
                    continue;
                }
                let dot: f64 = cols[i].iter().zip(cols[j].iter()).map(|(p, q)| p * q).sum();
                worst = worst.max((dot / (norms[i] * norms[j])).abs());
            }
        }
        worst
    }

    /// Babel function μ₁(k): for each column, the sum of its `k` largest
    /// normalised absolute inner products with the other columns, maximised
    /// over columns.
    ///
    /// Columns with norm below `1e-14` are skipped. Returns `0.0` for an empty
    /// matrix.
    pub fn babel_function(a: &[Vec<f64>], k: usize) -> f64 {
        let Some(first_row) = a.first() else {
            return 0.0;
        };
        let (cols, norms) = Self::columns_with_norms(a, first_row.len());
        let n = cols.len();
        let mut best = 0.0_f64;
        for i in 0..n {
            if norms[i] < 1e-14 {
                continue;
            }
            let mut corrs: Vec<f64> = (0..n)
                .filter(|&j| j != i && norms[j] > 1e-14)
                .map(|j| {
                    let dot: f64 =
                        cols[i].iter().zip(cols[j].iter()).map(|(p, q)| p * q).sum();
                    (dot / (norms[i] * norms[j])).abs()
                })
                .collect();
            // Largest correlations first.
            corrs.sort_by(|p, q| q.partial_cmp(p).unwrap_or(std::cmp::Ordering::Equal));
            let babel: f64 = corrs.iter().take(k).sum();
            best = best.max(babel);
        }
        best
    }
}
729
/// Heuristic checks and textbook bounds from compressed-sensing recovery theory.
#[allow(dead_code)]
pub struct RecoveryGuarantee;

#[allow(dead_code)]
impl RecoveryGuarantee {
    /// Cheap empirical estimate of the restricted isometry constant δ_k of a
    /// row-major matrix.
    ///
    /// Each probe is a unit vector with value `1/√k` on `k` cyclically
    /// contiguous columns. Probes start at columns `0..min(n, 50)`, and the
    /// function returns the largest observed `|‖A v‖² − 1|`. The true δ_k is a
    /// supremum over all k-sparse unit vectors, so this can only
    /// under-estimate it.
    ///
    /// `k` is clamped to the column count. Returns `0.0` for an empty matrix or
    /// `k == 0`, since the only 0-sparse vector is zero and δ₀ = 0.
    pub fn rip_constant(a: &[Vec<f64>], k: usize) -> f64 {
        if a.is_empty() {
            return 0.0;
        }
        let n = a[0].len();
        let k = k.min(n);
        if k == 0 {
            // Without this guard the probe would be the zero vector (not unit
            // norm) and the loop would report |0 − 1| = 1.
            return 0.0;
        }

        let entry = 1.0 / (k as f64).sqrt();
        let mut max_dev = 0.0_f64;
        // Starts beyond n would only repeat (wrapped) supports already probed.
        for start in 0..n.min(50) {
            let mut v = vec![0.0_f64; n];
            for offset in 0..k {
                v[(start + offset) % n] = entry;
            }
            let energy: f64 = a
                .iter()
                .map(|row| {
                    let av_i: f64 = row.iter().zip(v.iter()).map(|(ai, vi)| ai * vi).sum();
                    av_i * av_i
                })
                .sum();
            max_dev = max_dev.max((energy - 1.0).abs());
        }
        max_dev
    }

    /// Rule-of-thumb sufficiency check `m ≥ 2k · max(ln(n/k), 1)`.
    ///
    /// Degenerate inputs (`k == 0`, `n == 0` or `k > n`) return `true`.
    pub fn exact_recovery_condition(k: usize, m: usize, n: usize) -> bool {
        if k == 0 || n == 0 || k > n {
            return true;
        }
        let required = 2.0 * k as f64 * ((n as f64 / k as f64).ln()).max(1.0);
        m as f64 >= required
    }

    /// Measurement-count estimate `⌈C · k · ln(n/k)⌉` with `C = 4`.
    ///
    /// Returns `0` for degenerate inputs (`k == 0`, `n == 0` or `k > n`).
    /// NOTE(review): the formula is asymptotic. For `k` close to `n` it falls
    /// below `k`, and it is exactly `0` at `k == n` because `ln 1 = 0`.
    pub fn rip_measurement_lower_bound(k: usize, n: usize) -> usize {
        if k == 0 || n == 0 || k > n {
            return 0;
        }
        let c = 4.0_f64;
        (c * k as f64 * (n as f64 / k as f64).ln()).ceil() as usize
    }

    /// LASSO error bound `λ√k · (1 + δ) / (1 − √2·δ)`.
    ///
    /// The denominator is floored at `1e-14`. For `δ ≥ 1/√2` the bound is
    /// therefore huge rather than negative or infinite.
    pub fn lasso_error_bound(lambda: f64, k: usize, rip_delta: f64) -> f64 {
        let c = (1.0 + rip_delta) / (1.0 - 2.0_f64.sqrt() * rip_delta).max(1e-14);
        c * lambda * (k as f64).sqrt()
    }
}
816
817#[allow(dead_code)]
828pub struct KSvd {
829 pub n_atoms: usize,
831 pub sparsity: usize,
833 pub n_iter: usize,
835}
836
837#[allow(dead_code)]
838impl KSvd {
839 pub fn new(n_atoms: usize, sparsity: usize, n_iter: usize) -> Self {
845 Self {
846 n_atoms,
847 sparsity,
848 n_iter,
849 }
850 }
851
852 pub fn fit(&self, signals: &[Vec<f64>]) -> Vec<Vec<f64>> {
859 if signals.is_empty() {
860 return Vec::new();
861 }
862 let d = signals[0].len();
863 let n_signals = signals.len();
864 let n_atoms = self.n_atoms.min(d);
865
866 let mut rng = rand::rng();
868 let mut dict: Vec<Vec<f64>> = (0..n_atoms)
869 .map(|k| {
870 let idx = k % n_signals;
871 let _ = idx; let pick = rng.random_range(0..n_signals);
873 let mut atom = signals[pick].clone();
874 let norm = l2_norm(&atom);
875 if norm > 1e-14 {
876 for v in atom.iter_mut() {
877 *v /= norm;
878 }
879 }
880 atom
881 })
882 .collect();
883
884 let omp = OrthogonalMatchingPursuit::new(self.sparsity);
885
886 for _iter in 0..self.n_iter {
887 let dict_t: Vec<Vec<f64>> = (0..d)
894 .map(|i| (0..n_atoms).map(|k| dict[k][i]).collect())
895 .collect();
896
897 let codes: Vec<Vec<f64>> = signals.iter().map(|y| omp.solve(&dict_t, y)).collect();
898
899 for k in 0..n_atoms {
901 let using: Vec<usize> = (0..n_signals)
903 .filter(|&s| codes[s][k].abs() > 1e-14)
904 .collect();
905 if using.is_empty() {
906 let pick = rng.random_range(0..n_signals);
908 let mut atom = signals[pick].clone();
909 let norm = l2_norm(&atom);
910 if norm > 1e-14 {
911 for v in atom.iter_mut() {
912 *v /= norm;
913 }
914 }
915 dict[k] = atom;
916 continue;
917 }
918
919 let e_rows: Vec<Vec<f64>> = using
923 .iter()
924 .map(|&s| {
925 let mut e = signals[s].clone();
926 for j in 0..n_atoms {
927 if j == k {
928 continue;
929 }
930 let coef = codes[s][j];
931 for i in 0..d {
932 e[i] -= coef * dict[j][i];
933 }
934 }
935 e
936 })
937 .collect();
938
939 let mut atom = dict[k].clone();
941 for _pi in 0..10 {
942 let e_atom: Vec<f64> = e_rows
945 .iter()
946 .map(|row| row.iter().zip(atom.iter()).map(|(a, b)| a * b).sum::<f64>())
947 .collect();
948 let mut new_atom = vec![0.0_f64; d];
949 for (e_row, &ea) in e_rows.iter().zip(e_atom.iter()) {
950 for (i, &ei) in e_row.iter().enumerate() {
951 new_atom[i] += ei * ea;
952 }
953 }
954 normalise(&mut new_atom);
955 atom = new_atom;
956 }
957 dict[k] = atom;
958
959 for &s in &using {
961 e_rows.iter().position(|_| true).map(|_| ()).unwrap_or(());
962 if let Some(pos) = using.iter().position(|&u| u == s) {
964 let dot: f64 = e_rows[pos]
965 .iter()
966 .zip(dict[k].iter())
967 .map(|(a, b)| a * b)
968 .sum();
969 let _ = (s, dot, pos);
970 }
971 }
972 let _ = e_rows.len();
974 }
975 }
976
977 dict
978 }
979
980 pub fn encode(&self, dict: &[Vec<f64>], signal: &[f64]) -> Vec<f64> {
984 if dict.is_empty() || signal.is_empty() {
985 return Vec::new();
986 }
987 let d = signal.len();
988 let n_atoms = dict.len();
989 let dict_t: Vec<Vec<f64>> = (0..d)
991 .map(|i| (0..n_atoms).map(|k| dict[k][i]).collect())
992 .collect();
993 let omp = OrthogonalMatchingPursuit::new(self.sparsity);
994 omp.solve(&dict_t, signal)
995 }
996
997 pub fn reconstruct(dict: &[Vec<f64>], code: &[f64]) -> Vec<f64> {
1001 if dict.is_empty() {
1002 return Vec::new();
1003 }
1004 let d = dict[0].len();
1005 let mut out = vec![0.0_f64; d];
1006 for (k, atom) in dict.iter().enumerate() {
1007 if k >= code.len() {
1008 break;
1009 }
1010 for (i, &ai) in atom.iter().enumerate() {
1011 out[i] += code[k] * ai;
1012 }
1013 }
1014 out
1015 }
1016}
1017
1018#[allow(dead_code)]
1028pub struct MriCompressedSensing {
1029 pub n: usize,
1031 pub m: usize,
1033}
1034
1035#[allow(dead_code)]
1036impl MriCompressedSensing {
1037 pub fn new(n: usize, m: usize) -> Self {
1040 Self { n, m }
1041 }
1042
1043 pub fn sample_kspace_indices(&self) -> Vec<usize> {
1047 let mut rng = rand::rng();
1048 let mut indices: Vec<usize> = (0..self.n).collect();
1049 for i in 0..self.m.min(self.n) {
1051 let j = rng.random_range(i..self.n);
1052 indices.swap(i, j);
1053 }
1054 indices[..self.m.min(self.n)].to_vec()
1055 }
1056
1057 pub fn build_measurement_matrix(&self, kspace_indices: &[usize]) -> Vec<Vec<f64>> {
1062 let scale = 1.0 / (self.m as f64).sqrt();
1063 kspace_indices
1064 .iter()
1065 .map(|&ki| {
1066 (0..self.n)
1067 .map(|j| {
1068 (2.0 * std::f64::consts::PI * ki as f64 * j as f64 / self.n as f64).cos()
1069 * scale
1070 })
1071 .collect()
1072 })
1073 .collect()
1074 }
1075
1076 #[allow(clippy::too_many_arguments)]
1085 pub fn reconstruct_fista(
1086 &self,
1087 measurements: &[f64],
1088 kspace_indices: &[usize],
1089 lambda: f64,
1090 max_iter: usize,
1091 ) -> Vec<f64> {
1092 let a = self.build_measurement_matrix(kspace_indices);
1093 BasisPursuit::solve_fista(&a, measurements, lambda, max_iter)
1094 }
1095
1096 pub fn psnr(original: &[f64], reconstructed: &[f64], max_val: f64) -> f64 {
1100 let n = original.len().min(reconstructed.len());
1101 if n == 0 {
1102 return 0.0;
1103 }
1104 let mse: f64 = original[..n]
1105 .iter()
1106 .zip(reconstructed[..n].iter())
1107 .map(|(a, b)| (a - b).powi(2))
1108 .sum::<f64>()
1109 / n as f64;
1110 if mse < 1e-14 {
1111 return f64::INFINITY;
1112 }
1113 20.0 * (max_val / mse.sqrt()).log10()
1114 }
1115}
1116
1117#[allow(dead_code)]
1123pub struct SparseSignal;
1124
1125#[allow(dead_code)]
1126impl SparseSignal {
1127 pub fn generate(n: usize, k: usize, amplitude: f64) -> Vec<f64> {
1131 let mut rng = rand::rng();
1132 let mut signal = vec![0.0_f64; n];
1133 let k = k.min(n);
1134
1135 let mut indices: Vec<usize> = (0..n).collect();
1137 for i in 0..k {
1138 let j = rng.random_range(i..n);
1139 indices.swap(i, j);
1140 }
1141 for &idx in &indices[..k] {
1142 signal[idx] = rng.random_range(-amplitude..amplitude);
1143 }
1144 signal
1145 }
1146
1147 pub fn add_noise(signal: &[f64], sigma: f64) -> Vec<f64> {
1149 let mut rng = rand::rng();
1150 signal
1151 .iter()
1152 .map(|&x| {
1153 let u1: f64 = rng.random_range(1e-12_f64..1.0_f64);
1154 let u2: f64 = rng.random_range(0.0_f64..1.0_f64);
1155 let noise = (-2.0_f64 * u1.ln()).sqrt()
1156 * (2.0_f64 * std::f64::consts::PI * u2).cos()
1157 * sigma;
1158 x + noise
1159 })
1160 .collect()
1161 }
1162
1163 pub fn support_error(truth: &[f64], recovered: &[f64], threshold: f64) -> f64 {
1167 let n = truth.len().min(recovered.len());
1168 if n == 0 {
1169 return 0.0;
1170 }
1171 let mismatches: usize = truth[..n]
1172 .iter()
1173 .zip(recovered[..n].iter())
1174 .filter(|&(t, r): &(&f64, &f64)| {
1175 let t_nonzero = t.abs() > threshold;
1176 let r_nonzero = r.abs() > threshold;
1177 t_nonzero != r_nonzero
1178 })
1179 .count();
1180 mismatches as f64 / n as f64
1181 }
1182
1183 pub fn relative_error(truth: &[f64], recovered: &[f64]) -> f64 {
1185 let n = truth.len().min(recovered.len());
1186 let err: f64 = truth[..n]
1187 .iter()
1188 .zip(recovered[..n].iter())
1189 .map(|(a, b)| (a - b).powi(2))
1190 .sum::<f64>()
1191 .sqrt();
1192 let norm: f64 = truth[..n].iter().map(|x| x * x).sum::<f64>().sqrt();
1193 if norm < 1e-14 { err } else { err / norm }
1194 }
1195}
1196
1197#[cfg(test)]
1202mod tests {
1203 use super::*;
1204
1205 #[test]
1208 fn test_soft_threshold_positive_above() {
1209 assert!((soft_threshold(3.0, 1.0) - 2.0).abs() < 1e-12);
1210 }
1211
1212 #[test]
1213 fn test_soft_threshold_positive_below() {
1214 assert_eq!(soft_threshold(0.5, 1.0), 0.0);
1215 }
1216
1217 #[test]
1218 fn test_soft_threshold_negative_above() {
1219 assert!((soft_threshold(-3.0, 1.0) + 2.0).abs() < 1e-12);
1220 }
1221
1222 #[test]
1223 fn test_soft_threshold_zero() {
1224 assert_eq!(soft_threshold(0.0, 1.0), 0.0);
1225 }
1226
1227 #[test]
1228 fn test_soft_threshold_exact_boundary() {
1229 assert_eq!(soft_threshold(1.0, 1.0), 0.0);
1230 assert_eq!(soft_threshold(-1.0, 1.0), 0.0);
1231 }
1232
1233 #[test]
1234 fn test_soft_threshold_zero_lambda() {
1235 assert!((soft_threshold(5.0, 0.0) - 5.0).abs() < 1e-12);
1236 }
1237
1238 #[test]
1241 fn test_nyquist_rate_basic() {
1242 assert!((nyquist_rate(1000.0) - 2000.0).abs() < 1e-9);
1243 }
1244
1245 #[test]
1246 fn test_nyquist_rate_zero() {
1247 assert_eq!(nyquist_rate(0.0), 0.0);
1248 }
1249
1250 #[test]
1253 fn test_compression_ratio_half() {
1254 assert!((compression_ratio(100, 50) - 0.5).abs() < 1e-12);
1255 }
1256
1257 #[test]
1258 fn test_compression_ratio_zero_n() {
1259 assert_eq!(compression_ratio(0, 10), 0.0);
1260 }
1261
1262 #[test]
1263 fn test_compression_ratio_full() {
1264 assert!((compression_ratio(10, 10) - 1.0).abs() < 1e-12);
1265 }
1266
1267 #[test]
1270 fn test_l2_norm_known() {
1271 let x = vec![3.0, 4.0];
1272 assert!((l2_norm(&x) - 5.0).abs() < 1e-12);
1273 }
1274
1275 #[test]
1276 fn test_normalise_unit_vector() {
1277 let mut v = vec![3.0, 0.0, 4.0];
1278 normalise(&mut v);
1279 assert!((l2_norm(&v) - 1.0).abs() < 1e-12);
1280 }
1281
1282 #[test]
1283 fn test_normalise_zero_vector_unchanged() {
1284 let mut v = vec![0.0, 0.0, 0.0];
1285 normalise(&mut v);
1286 assert_eq!(v, vec![0.0, 0.0, 0.0]);
1287 }
1288
1289 #[test]
1292 fn test_mat_vec_identity() {
1293 let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1294 let x = vec![3.0, 7.0];
1295 let y = mat_vec(&a, &x);
1296 assert!((y[0] - 3.0).abs() < 1e-12);
1297 assert!((y[1] - 7.0).abs() < 1e-12);
1298 }
1299
1300 #[test]
1301 fn test_mat_transpose_vec_basic() {
1302 let a = vec![vec![1.0, 2.0, 3.0]]; let x = vec![2.0]; let y = mat_transpose_vec(&a, &x);
1305 assert_eq!(y.len(), 3);
1306 assert!((y[0] - 2.0).abs() < 1e-12);
1307 assert!((y[1] - 4.0).abs() < 1e-12);
1308 assert!((y[2] - 6.0).abs() < 1e-12);
1309 }
1310
1311 #[test]
1314 fn test_dct_roundtrip() {
1315 let n = 8;
1316 let basis = DctBasis::new(n);
1317 let signal: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
1318 let coeffs = basis.transform(&signal);
1319 let recovered = basis.inverse(&coeffs);
1320 for (a, b) in signal.iter().zip(recovered.iter()) {
1321 assert!((a - b).abs() < 1e-10, "DCT roundtrip mismatch: {a} vs {b}");
1322 }
1323 }
1324
1325 #[test]
1326 fn test_dct_dc_component() {
1327 let n = 8;
1329 let basis = DctBasis::new(n);
1330 let signal = vec![1.0_f64; n];
1331 let coeffs = basis.transform(&signal);
1332 for k in 1..n {
1333 assert!(
1334 coeffs[k].abs() < 1e-10,
1335 "non-DC coefficient k={k} should be ~0, got {}",
1336 coeffs[k]
1337 );
1338 }
1339 assert!(coeffs[0].abs() > 0.5, "DC component should be non-zero");
1340 }
1341
1342 #[test]
1343 fn test_dct_length_preserved() {
1344 let n = 16;
1345 let basis = DctBasis::new(n);
1346 let signal: Vec<f64> = vec![1.0; n];
1347 let coeffs = basis.transform(&signal);
1348 assert_eq!(coeffs.len(), n);
1349 }
1350
1351 #[test]
1352 fn test_dct_energy_preservation() {
1353 let n = 8;
1355 let basis = DctBasis::new(n);
1356 let signal: Vec<f64> = (0..n).map(|i| i as f64).collect();
1357 let coeffs = basis.transform(&signal);
1358 let e_signal: f64 = signal.iter().map(|x| x * x).sum();
1359 let e_coeffs: f64 = coeffs.iter().map(|x| x * x).sum();
1360 assert!((e_signal - e_coeffs).abs() / (e_signal + 1.0) < 1e-10);
1361 }
1362
1363 #[test]
1364 fn test_dct_new_n() {
1365 let basis = DctBasis::new(4);
1366 assert_eq!(basis.n, 4);
1367 }
1368
1369 #[test]
1370 fn test_dct_truncate_keeps_k_largest() {
1371 let basis = DctBasis::new(8);
1372 let coeffs = vec![1.0, 5.0, 0.1, 3.0, 0.0, 2.0, 0.0, 0.0];
1373 let truncated = basis.truncate(&coeffs, 2);
1374 let nonzero = truncated.iter().filter(|&&v| v.abs() > 1e-14).count();
1375 assert_eq!(nonzero, 2, "truncate(k=2) should leave 2 non-zeros");
1376 assert!((truncated[1] - 5.0).abs() < 1e-12, "5.0 should be kept");
1377 assert!((truncated[3] - 3.0).abs() < 1e-12, "3.0 should be kept");
1378 }
1379
1380 #[test]
1383 fn test_random_measurement_matrix_dimensions() {
1384 let mat = RandomMeasurementMatrix::generate_gaussian(10, 20);
1385 assert_eq!(mat.m, 10);
1386 assert_eq!(mat.n, 20);
1387 assert_eq!(mat.matrix.len(), 10);
1388 assert_eq!(mat.matrix[0].len(), 20);
1389 }
1390
1391 #[test]
1392 fn test_measurement_output_length() {
1393 let mat = RandomMeasurementMatrix::generate_gaussian(5, 10);
1394 let x = vec![1.0_f64; 10];
1395 let y = mat.measure(&x);
1396 assert_eq!(y.len(), 5);
1397 }
1398
1399 #[test]
1400 fn test_measurement_linearity() {
1401 let mat = RandomMeasurementMatrix::generate_gaussian(5, 8);
1402 let x1: Vec<f64> = (0..8).map(|i| i as f64).collect();
1403 let x2: Vec<f64> = (0..8).map(|i| (8 - i) as f64).collect();
1404 let y1 = mat.measure(&x1);
1405 let y2 = mat.measure(&x2);
1406 let y_sum: Vec<f64> = x1.iter().zip(x2.iter()).map(|(a, b)| a + b).collect();
1407 let y_direct = mat.measure(&y_sum);
1408 for (a, b) in y_direct
1409 .iter()
1410 .zip(y1.iter().zip(y2.iter()).map(|(a, b)| a + b))
1411 {
1412 assert!((a - b).abs() < 1e-10, "linearity: {a} vs {b}");
1413 }
1414 }
1415
1416 #[test]
1417 fn test_bernoulli_matrix_dimensions() {
1418 let mat = RandomMeasurementMatrix::generate_bernoulli(8, 16);
1419 assert_eq!(mat.m, 8);
1420 assert_eq!(mat.n, 16);
1421 }
1422
1423 #[test]
1424 fn test_bernoulli_entries_are_plus_minus_scale() {
1425 let m = 5;
1426 let n = 10;
1427 let mat = RandomMeasurementMatrix::generate_bernoulli(m, n);
1428 let scale = 1.0 / (m as f64).sqrt();
1429 for row in &mat.matrix {
1430 for &v in row {
1431 let diff = (v.abs() - scale).abs();
1432 assert!(diff < 1e-12, "Bernoulli entry |{v}| ≠ {scale}");
1433 }
1434 }
1435 }
1436
1437 #[test]
1440 fn test_ista_trivial_identity() {
1441 let a: Vec<Vec<f64>> = (0..4)
1443 .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1444 .collect();
1445 let b = vec![1.0, 0.0, 0.0, 0.0];
1446 let x = BasisPursuit::solve_lasso(&a, &b, 1e-4, 200);
1447 assert_eq!(x.len(), 4);
1448 assert!((x[0] - 1.0).abs() < 0.05, "x[0] should be ~1, got {}", x[0]);
1449 assert!(x[1].abs() < 0.05);
1450 }
1451
1452 #[test]
1453 fn test_ista_empty_input() {
1454 let x = BasisPursuit::solve_lasso(&[], &[], 1.0, 100);
1455 assert!(x.is_empty());
1456 }
1457
1458 #[test]
1459 fn test_ista_sparse_recovery() {
1460 let a: Vec<Vec<f64>> = (0..4)
1462 .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1463 .collect();
1464 let b = vec![3.0, 0.0, 0.0, 0.0];
1465 let x = BasisPursuit::solve_lasso(&a, &b, 0.01, 300);
1466 assert!((x[0] - 3.0).abs() < 0.1, "x[0] ≈ 3, got {}", x[0]);
1467 }
1468
1469 #[test]
1472 fn test_fista_identity_recovery() {
1473 let a: Vec<Vec<f64>> = (0..4)
1474 .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1475 .collect();
1476 let b = vec![0.0, 2.5, 0.0, 0.0];
1477 let x = BasisPursuit::solve_fista(&a, &b, 1e-3, 300);
1478 assert!((x[1] - 2.5).abs() < 0.05, "x[1] ≈ 2.5, got {}", x[1]);
1479 }
1480
1481 #[test]
1482 fn test_fista_empty_input() {
1483 let x = BasisPursuit::solve_fista(&[], &[], 1.0, 100);
1484 assert!(x.is_empty());
1485 }
1486
1487 #[test]
1488 fn test_fista_objective_decreases() {
1489 let a: Vec<Vec<f64>> = (0..3)
1490 .map(|i| (0..3).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1491 .collect();
1492 let b = vec![1.0, -1.0, 0.5];
1493 let lambda = 0.1;
1494 let x0 = vec![0.0_f64; 3];
1495 let obj0 = BasisPursuit::objective(&a, &b, &x0, lambda);
1496 let x_hat = BasisPursuit::solve_fista(&a, &b, lambda, 100);
1497 let obj1 = BasisPursuit::objective(&a, &b, &x_hat, lambda);
1498 assert!(
1499 obj1 <= obj0 + 1e-10,
1500 "FISTA should decrease objective: {obj1} > {obj0}"
1501 );
1502 }
1503
1504 #[test]
1507 fn test_omp_exact_1_sparse() {
1508 let a: Vec<Vec<f64>> = (0..4)
1510 .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1511 .collect();
1512 let b = vec![0.0, 5.0, 0.0, 0.0];
1513 let omp = OrthogonalMatchingPursuit::new(1);
1514 let x = omp.solve(&a, &b);
1515 assert_eq!(x.len(), 4);
1516 assert!((x[1] - 5.0).abs() < 1e-10, "x[1] should be 5, got {}", x[1]);
1517 }
1518
1519 #[test]
1520 fn test_omp_exact_2_sparse() {
1521 let a: Vec<Vec<f64>> = (0..4)
1522 .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1523 .collect();
1524 let b = vec![2.0, 0.0, 7.0, 0.0];
1525 let omp = OrthogonalMatchingPursuit::new(2);
1526 let x = omp.solve(&a, &b);
1527 assert!((x[0] - 2.0).abs() < 1e-8);
1528 assert!((x[2] - 7.0).abs() < 1e-8);
1529 }
1530
1531 #[test]
1532 fn test_omp_empty_input() {
1533 let omp = OrthogonalMatchingPursuit::new(3);
1534 let x = omp.solve(&[], &[]);
1535 assert!(x.is_empty());
1536 }
1537
1538 #[test]
1539 fn test_omp_new() {
1540 let omp = OrthogonalMatchingPursuit::new(5);
1541 assert_eq!(omp.max_k, 5);
1542 }
1543
1544 #[test]
1545 fn test_omp_residual_decreases() {
1546 let a: Vec<Vec<f64>> = (0..6)
1548 .map(|i| (0..6).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1549 .collect();
1550 let b = vec![1.0, -1.0, 2.0, 0.0, -2.0, 0.0];
1551 let omp = OrthogonalMatchingPursuit::new(4);
1552 let x = omp.solve(&a, &b);
1553 let residual: f64 = b
1555 .iter()
1556 .enumerate()
1557 .map(|(i, &bi)| {
1558 let ax: f64 = a[i].iter().zip(x.iter()).map(|(aij, xj)| aij * xj).sum();
1559 (bi - ax).powi(2)
1560 })
1561 .sum::<f64>()
1562 .sqrt();
1563 assert!(residual < 1e-8, "residual should be tiny, got {residual}");
1564 }
1565
#[test]
fn test_omp_support_length() {
    // Identity dictionary, three non-zero entries in b, budget of three atoms:
    // the selected support should hold exactly three indices.
    let dim = 5;
    let mut a = vec![vec![0.0_f64; dim]; dim];
    (0..dim).for_each(|d| a[d][d] = 1.0);
    let b = vec![1.0, 0.0, 3.0, 0.0, 2.0];
    let solver = OrthogonalMatchingPursuit::new(3);
    let supp = solver.support(&a, &b);
    let count = supp.len();
    assert_eq!(count, 3, "support should have 3 elements, got {}", count);
}
1581
#[test]
fn test_l0_norm_basic() {
    // With threshold 0.5 only 1.0 and -2.0 are counted; 0.0 and 0.001 are not.
    let signal = vec![0.0, 1.0, 0.0, -2.0, 0.001];
    let count = SparsityMetrics::l0_norm(&signal, 0.5);
    assert_eq!(count, 2);
}
1589
#[test]
fn test_l0_norm_all_zero() {
    // An all-zero vector has no entry above a positive threshold.
    let zeros = vec![0.0; 3];
    let count = SparsityMetrics::l0_norm(&zeros, 1e-6);
    assert_eq!(count, 0);
}
1595
#[test]
fn test_l1_norm_basic() {
    // |1| + |-2| + |3| = 6.
    let v = vec![1.0, -2.0, 3.0];
    let deviation = (SparsityMetrics::l1_norm(&v) - 6.0).abs();
    assert!(deviation < 1e-12);
}
1601
#[test]
fn test_l1_norm_empty() {
    // The L1 norm of an empty input is the empty sum, 0.
    let norm = SparsityMetrics::l1_norm(&[]);
    assert_eq!(norm, 0.0);
}
1606
#[test]
fn test_l2_norm_sparsity() {
    // 3-4-5 triangle: sqrt(3^2 + 4^2) = 5.
    let legs = vec![3.0, 4.0];
    let hyp = SparsityMetrics::l2_norm(&legs);
    assert!((hyp - 5.0).abs() < 1e-12);
}
1612
#[test]
fn test_gini_sparse_is_near_one() {
    // All mass in a single coefficient: a highly sparse signal, so the Gini
    // index should sit well above 0.6.
    let spike = vec![0.0, 0.0, 0.0, 10.0];
    let g = SparsityMetrics::gini(&spike);
    assert!(g > 0.6, "sparse signal should have high Gini, got {g}");
}
1619
#[test]
fn test_gini_uniform_is_near_zero() {
    // Equal magnitudes everywhere: no sparsity, so the Gini index is near 0.
    let flat = vec![1.0; 4];
    let g = SparsityMetrics::gini(&flat);
    assert!(g < 0.1, "uniform signal should have low Gini, got {g}");
}
1626
#[test]
fn test_coherence_identity_is_zero() {
    // Identity columns are mutually orthogonal, so their coherence is 0.
    let eye: Vec<Vec<f64>> = (0..4)
        .map(|r| {
            let mut row = vec![0.0; 4];
            row[r] = 1.0;
            row
        })
        .collect();
    let mu = SparsityMetrics::coherence(&eye);
    assert!(mu < 1e-12, "identity coherence should be 0, got {mu}");
}
1636
#[test]
fn test_coherence_empty() {
    // Degenerate input: an empty matrix reports a coherence of 0.
    let mu = SparsityMetrics::coherence(&[]);
    assert_eq!(mu, 0.0);
}
1641
#[test]
fn test_coherence_collinear_columns() {
    // Rows [1, 1] and [0, 0] make both columns equal to (1, 0)^T; identical
    // columns should give the maximal coherence of 1.
    let first_row = vec![1.0, 1.0];
    let second_row = vec![0.0, 0.0];
    let a = vec![first_row, second_row];
    let mu = SparsityMetrics::coherence(&a);
    let gap = (mu - 1.0).abs();
    assert!(gap < 1e-10, "collinear columns → coherence=1, got {mu}");
}
1652
#[test]
fn test_babel_function_identity() {
    // Orthogonal identity columns have no cross-correlation, so the Babel
    // function evaluated at k = 2 should be 0.
    let mut eye = vec![vec![0.0_f64; 4]; 4];
    for i in 0..4 {
        eye[i][i] = 1.0;
    }
    let babel = SparsityMetrics::babel_function(&eye, 2);
    assert!(babel < 1e-12, "identity babel should be 0, got {babel}");
}
1661
#[test]
fn test_exact_recovery_condition_sufficient_measurements() {
    // Presumably (k, m, n) = (1, 20, 100): a 1-sparse signal with ample
    // measurements should satisfy the condition — TODO confirm argument order.
    let satisfied = RecoveryGuarantee::exact_recovery_condition(1, 20, 100);
    assert!(satisfied);
}
1669
#[test]
fn test_exact_recovery_condition_insufficient() {
    // Presumably (k, m, n) = (50, 5, 100): far more non-zeros than
    // measurements, so the condition must fail — TODO confirm argument order.
    let satisfied = RecoveryGuarantee::exact_recovery_condition(50, 5, 100);
    assert!(!satisfied);
}
1675
#[test]
fn test_exact_recovery_condition_k_zero() {
    // Zero sparsity is treated as trivially recoverable (first two arguments
    // presumably k = 0 and m = 0 — confirm against the signature).
    let ok = RecoveryGuarantee::exact_recovery_condition(0, 0, 100);
    assert!(ok);
}
1680
#[test]
fn test_rip_constant_identity() {
    // The identity preserves every vector's norm, so its restricted isometry
    // constant (second argument presumably the sparsity level) should vanish.
    let n = 4;
    let mut a: Vec<Vec<f64>> = Vec::with_capacity(n);
    for i in 0..n {
        let mut row = vec![0.0; n];
        row[i] = 1.0;
        a.push(row);
    }
    let delta = RecoveryGuarantee::rip_constant(&a, 1);
    assert!(
        delta < 1e-10,
        "identity RIP constant should be ~0, got {delta}"
    );
}
1694
#[test]
fn test_rip_constant_empty() {
    // Degenerate input: an empty matrix yields a zero constant.
    let delta = RecoveryGuarantee::rip_constant(&[], 1);
    assert_eq!(delta, 0.0);
}
1699
#[test]
fn test_rip_measurement_lower_bound_nonzero() {
    // For a non-trivial problem size the required measurement count is > 0.
    let needed = RecoveryGuarantee::rip_measurement_lower_bound(5, 100);
    assert!(needed > 0, "lower bound should be positive");
}
1705
#[test]
fn test_lasso_error_bound_positive() {
    // For these positive inputs the LASSO error bound must be strictly positive.
    let bound = RecoveryGuarantee::lasso_error_bound(0.1, 4, 0.1);
    let positive = bound > 0.0;
    assert!(positive, "LASSO error bound should be positive");
}
1711
#[test]
fn test_gauss_solve_2x2() {
    // [2 1; 1 3] · (1, 3)^T = (5, 10)^T, so the unique solution is (1, 3).
    let (a, b) = (vec![vec![2.0, 1.0], vec![1.0, 3.0]], vec![5.0, 10.0]);
    let x = gauss_solve(&a, &b);
    let (x0, x1) = (x[0], x[1]);
    assert!((x0 - 1.0).abs() < 1e-10, "x[0]={}", x0);
    assert!((x1 - 3.0).abs() < 1e-10, "x[1]={}", x1);
}
1723
#[test]
fn test_gauss_solve_1x1() {
    // Scalar system 4·x = 8 → x = 2.
    let coeff = vec![vec![4.0]];
    let rhs = vec![8.0];
    let solution = gauss_solve(&coeff, &rhs);
    assert!((solution[0] - 2.0).abs() < 1e-12);
}
1731
#[test]
fn test_gauss_solve_empty() {
    // A 0x0 system has an empty solution vector.
    let solution = gauss_solve(&[], &[]);
    assert!(solution.is_empty());
}
1737
#[test]
fn test_mri_cs_new() {
    // The constructor stores both size parameters unchanged in `n` and `m`.
    let mri = MriCompressedSensing::new(64, 20);
    assert_eq!((mri.n, mri.m), (64, 20));
}
1746
#[test]
fn test_mri_kspace_indices_length() {
    // One sampled index per requested measurement (m = 10).
    let mri = MriCompressedSensing::new(32, 10);
    let count = mri.sample_kspace_indices().len();
    assert_eq!(count, 10);
}
1753
#[test]
fn test_mri_measurement_matrix_shape() {
    // Eight sampled indices on a size-16 problem: expect one row per index
    // and one column per signal sample, i.e. an 8 x 16 matrix.
    let mri = MriCompressedSensing::new(16, 8);
    let rows: Vec<usize> = (0..8).collect();
    let phi = mri.build_measurement_matrix(&rows);
    assert_eq!(phi.len(), 8);
    assert_eq!(phi[0].len(), 16);
}
1762
#[test]
fn test_psnr_identical_signals() {
    // Zero reconstruction error should make the PSNR infinite.
    let signal = vec![1.0, 2.0, 3.0];
    let psnr = MriCompressedSensing::psnr(&signal, &signal, 3.0);
    assert!(psnr.is_infinite(), "identical signals → PSNR = ∞");
}
1769
#[test]
fn test_psnr_known_value() {
    // One of the two samples differs by 1.0 (the value passed as the third,
    // presumably peak, argument), so the error is non-zero and PSNR is finite.
    // NOTE(review): the name suggests pinning an exact dB value, but only
    // finiteness is checked; the exact figure depends on psnr's MSE convention.
    let original = vec![1.0, 0.0];
    let reconstructed = vec![0.0; 2];
    let psnr = MriCompressedSensing::psnr(&original, &reconstructed, 1.0);
    assert!(psnr.is_finite(), "PSNR should be finite for non-identical");
}
1777
#[test]
fn test_sparse_signal_generate_sparsity() {
    // generate(20, 3, ..) should give a length-20 vector whose count of
    // entries with magnitude above 1e-14 is exactly 3.
    let sig = SparseSignal::generate(20, 3, 1.0);
    assert_eq!(sig.len(), 20);
    let mut nnz = 0usize;
    for v in &sig {
        if v.abs() > 1e-14 {
            nnz += 1;
        }
    }
    assert_eq!(
        nnz, 3,
        "generated signal should have exactly 3 non-zeros, got {nnz}"
    );
}
1790
#[test]
fn test_sparse_signal_generate_length() {
    // The output length equals the requested length (100).
    let generated = SparseSignal::generate(100, 5, 2.0);
    assert_eq!(generated.len(), 100);
}
1796
#[test]
fn test_sparse_signal_relative_error_zero() {
    // A signal compared against itself has zero relative error.
    let s = vec![1.0, 2.0, 3.0];
    let err = SparseSignal::relative_error(&s, &s);
    let is_zero = err < 1e-12;
    assert!(
        is_zero,
        "identical signals should have 0 relative error"
    );
}
1806
#[test]
fn test_sparse_signal_support_error_identical() {
    // Identical vectors share the same support at threshold 0.5.
    let signal = vec![0.0, 1.0, 0.0, 2.0];
    let err = SparseSignal::support_error(&signal, &signal, 0.5);
    assert_eq!(err, 0.0, "identical support → error = 0");
}
1813
#[test]
fn test_ksvd_new() {
    // new(8, 2, 5) stores its arguments in n_atoms, sparsity and n_iter.
    let learner = KSvd::new(8, 2, 5);
    let fields = (learner.n_atoms, learner.sparsity, learner.n_iter);
    assert_eq!(fields, (8, 2, 5));
}
1823
#[test]
fn test_ksvd_reconstruct_zero_code() {
    // An all-zero code selects no atoms, so the reconstruction is zero.
    let dict: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let code = vec![0.0; 2];
    let rec = KSvd::reconstruct(&dict, &code);
    assert_eq!(rec, vec![0.0, 0.0]);
}
1831
#[test]
fn test_ksvd_reconstruct_unit_code() {
    // With an identity dictionary the reconstruction equals the code.
    let dict: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let code = vec![3.0, 5.0];
    let rec = KSvd::reconstruct(&dict, &code);
    for (i, expected) in [3.0_f64, 5.0].iter().enumerate() {
        assert!((rec[i] - expected).abs() < 1e-12);
    }
}
1840
#[test]
fn test_ksvd_fit_returns_correct_shape() {
    // Training set: the first four length-6 standard basis vectors. The
    // learned dictionary should hold 4 atoms of length 6.
    let signals: Vec<Vec<f64>> = (0..4)
        .map(|i| {
            let mut s = vec![0.0; 6];
            s[i] = 1.0;
            s
        })
        .collect();
    let ksvd = KSvd::new(4, 1, 2);
    let dict = ksvd.fit(&signals);
    assert_eq!(dict.len(), 4, "dict should have 4 atoms");
    assert_eq!(dict[0].len(), 6, "each atom should have length 6");
}
1852
#[test]
fn test_ksvd_encode_length() {
    // Encoding against a 4-atom dictionary yields one coefficient per atom.
    let mut dict = vec![vec![0.0_f64; 4]; 4];
    for (k, atom) in dict.iter_mut().enumerate() {
        atom[k] = 1.0;
    }
    let ksvd = KSvd::new(4, 1, 1);
    let signal = vec![0.0, 1.0, 0.0, 0.0];
    let code = ksvd.encode(&dict, &signal);
    assert_eq!(code.len(), 4, "code length should match n_atoms");
}
1863
#[test]
fn test_spectral_norm_identity() {
    // For I_4, spectral_norm's start vector (1,1,1,1)/2 is already a fixed
    // point of v -> normalise(A^T A v), so the estimate ||A v||_2 is exactly
    // 1 up to rounding. A tight tolerance catches regressions that a loose
    // 1% tolerance would hide.
    let a: Vec<Vec<f64>> = (0..4)
        .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect();
    let sn = spectral_norm(&a, 30);
    assert!(
        (sn - 1.0).abs() < 1e-10,
        "spectral norm of I should be ~1, got {sn}"
    );
}

#[test]
fn test_spectral_norm_dominant_diagonal() {
    // diag(3, 1): the identity case never moves the iterate, so this actually
    // exercises the power iteration. After k steps v ∝ (9^k, 1); at k = 30 the
    // second component is negligible and ||A v||_2 = 3 (the largest singular
    // value).
    let a = vec![vec![3.0, 0.0], vec![0.0, 1.0]];
    let sn = spectral_norm(&a, 30);
    assert!(
        (sn - 3.0).abs() < 1e-10,
        "spectral norm of diag(3, 1) should be 3, got {sn}"
    );
}
1875
#[test]
fn test_spectral_norm_empty() {
    // spectral_norm returns 0.0 immediately for a matrix with no rows.
    assert_eq!(spectral_norm(&[], 10), 0.0);
}
1881}