Skip to main content

oxiphysics_core/
compressed_sensing.rs

1#![allow(clippy::needless_range_loop, clippy::ptr_arg)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Compressed sensing and sparse signal recovery algorithms.
6//!
//! Provides Discrete Cosine Transform (DCT) bases, random Gaussian and Bernoulli
//! measurement matrices, LASSO / Basis Pursuit solvers (ISTA and FISTA),
//! Orthogonal Matching Pursuit (OMP), sparsity and coherence metrics, K-SVD
//! dictionary learning with sparse coding, restricted isometry property (RIP)
//! analysis, MRI-like signal reconstruction, and theoretical recovery guarantees
//! for compressed sensing problems.
13//!
14//! # Overview
15//!
16//! Compressed sensing (CS) exploits the sparsity of natural signals to allow
17//! faithful reconstruction from far fewer measurements than the Nyquist rate.
18//! The key ingredients are:
19//!
20//! * A **sparsifying basis** (DCT, Wavelet, etc.) in which the signal has few
21//!   non-zero coefficients.
22//! * A **measurement matrix** (random Gaussian or Bernoulli) that is incoherent
23//!   with the sparsifying basis.
24//! * A **recovery algorithm** (Basis Pursuit / ISTA / FISTA / OMP) that finds
25//!   the sparsest signal consistent with the measurements.
26//!
27//! The `BasisPursuit` struct implements ISTA (slow) and FISTA (fast, with
28//! Nesterov momentum) for L1-regularised least-squares (LASSO) recovery.
29//! `OrthogonalMatchingPursuit` provides a greedy alternative.
30//! `KSvd` implements the K-SVD dictionary learning algorithm.
31
32use rand::RngExt;
33
34// ─────────────────────────────────────────────────────────────────────────────
35// Free functions
36// ─────────────────────────────────────────────────────────────────────────────
37
/// Soft-thresholding (shrinkage) operator: `sign(x) * max(|x| - lambda, 0)`.
///
/// This is the proximal operator of `lambda * |x|` and the per-coordinate step
/// of ISTA/FISTA. For `lambda >= 0`, inputs in `[-lambda, lambda]` map to `0.0`
/// and all other inputs move `lambda` towards zero. A NaN input yields `0.0`
/// because neither comparison holds.
#[allow(dead_code)]
pub fn soft_threshold(x: f64, lambda: f64) -> f64 {
    match x {
        v if v > lambda => v - lambda,
        v if v < -lambda => v + lambda,
        _ => 0.0,
    }
}
51
/// Nyquist sampling rate of a band-limited signal.
///
/// Given the signal `bandwidth` (highest frequency, in Hz), returns
/// `2 * bandwidth` samples per second.
#[allow(dead_code)]
pub fn nyquist_rate(bandwidth: f64) -> f64 {
    bandwidth * 2.0
}
59
/// Ratio of measurements to signal length, `m / n`.
///
/// Values below 1 mean fewer measurements than samples (sub-Nyquist
/// acquisition). An empty signal (`n == 0`) yields `0.0` rather than dividing
/// by zero.
#[allow(dead_code)]
pub fn compression_ratio(n: usize, m: usize) -> f64 {
    match n {
        0 => 0.0,
        len => m as f64 / len as f64,
    }
}
70
/// Euclidean (ℓ₂) norm of `x`: the square root of the sum of squared entries.
///
/// An empty slice has norm zero.
#[allow(dead_code)]
pub fn l2_norm(x: &[f64]) -> f64 {
    let sum_sq: f64 = x.iter().map(|&v| v * v).sum();
    sum_sq.sqrt()
}
78
/// Rescale `x` in place to unit ℓ₂ norm.
///
/// The vector is only modified when its norm is strictly greater than `1e-14`.
/// Empty, all-zero, near-zero and NaN-norm vectors are left exactly as they were,
/// which avoids dividing by (near) zero.
#[allow(dead_code)]
pub fn normalise(x: &mut Vec<f64>) {
    let norm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
    if norm > 1e-14 {
        x.iter_mut().for_each(|v| *v /= norm);
    }
}
91
/// Dense matrix–vector product `y = A x`.
///
/// `a` holds the rows of an `m × n` matrix; the result has one entry per row.
/// Each entry is the dot product of that row with `x`. If a row and `x` differ
/// in length, only the overlapping prefix contributes.
#[allow(dead_code)]
pub fn mat_vec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut y = Vec::with_capacity(a.len());
    for row in a {
        let dot: f64 = row.iter().zip(x).map(|(r, xi)| r * xi).sum();
        y.push(dot);
    }
    y
}
101
/// Transposed matrix–vector product `y = Aᵀ x`.
///
/// `a` holds the rows of an `m × n` matrix, with `n` taken from the first row.
/// The result has length `n`; an empty matrix yields an empty vector.
///
/// # Panics
///
/// Panics if `n > 0` and `x` has fewer than `m` entries, or if a row is shorter
/// than the first row.
#[allow(dead_code)]
pub fn mat_transpose_vec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let Some(first) = a.first() else {
        return Vec::new();
    };
    let mut y = vec![0.0_f64; first.len()];
    for (i, row) in a.iter().enumerate() {
        for (j, acc) in y.iter_mut().enumerate() {
            *acc += row[j] * x[i];
        }
    }
    y
}
120
121/// Estimate the spectral norm (largest singular value) of `a` via power iteration.
122///
123/// Runs `max_iter` iterations; returns an approximation to `||A||_2`.
124#[allow(dead_code)]
125pub fn spectral_norm(a: &[Vec<f64>], max_iter: usize) -> f64 {
126    if a.is_empty() {
127        return 0.0;
128    }
129    let n = a[0].len();
130    let mut v = vec![1.0_f64; n];
131    normalise(&mut v);
132    for _ in 0..max_iter {
133        let av = mat_vec(a, &v);
134        let mut atav = mat_transpose_vec(a, &av);
135        normalise(&mut atav);
136        v = atav;
137    }
138    let av = mat_vec(a, &v);
139    l2_norm(&av)
140}
141
142// ─────────────────────────────────────────────────────────────────────────────
143// DctBasis
144// ─────────────────────────────────────────────────────────────────────────────
145
/// Orthonormal DCT-II basis used as a sparsifying transform.
///
/// Signals with only a few significant DCT coefficients can be recovered from
/// far fewer measurements than the Nyquist rate would require.
#[allow(dead_code)]
pub struct DctBasis {
    /// Number of samples per signal.
    pub n: usize,
}

#[allow(dead_code)]
impl DctBasis {
    /// Build a DCT basis for length-`n` signals.
    pub fn new(n: usize) -> Self {
        DctBasis { n }
    }

    /// Orthonormal forward DCT-II of `x`.
    ///
    /// Only the first `len = min(self.n, x.len())` samples are used, and the
    /// returned coefficient vector has length `len`. Coefficient `k` is scaled
    /// by `sqrt(1/len)` for `k == 0` and `sqrt(2/len)` otherwise.
    pub fn transform(&self, x: &[f64]) -> Vec<f64> {
        let len = x.len().min(self.n);
        let step = std::f64::consts::PI / len as f64;
        (0..len)
            .map(|k| {
                let kf = k as f64;
                let acc = x[..len].iter().enumerate().fold(0.0_f64, |acc, (j, &xj)| {
                    acc + xj * ((j as f64 + 0.5) * kf * step).cos()
                });
                let numerator = if k == 0 { 1.0 } else { 2.0 };
                let scale = (numerator / len as f64).sqrt();
                acc * scale
            })
            .collect()
    }

    /// Inverse of [`DctBasis::transform`] (an orthonormal DCT-III).
    ///
    /// Uses the first `len = min(self.n, coeffs.len())` coefficients and returns
    /// a signal of length `len`.
    pub fn inverse(&self, coeffs: &[f64]) -> Vec<f64> {
        let len = coeffs.len().min(self.n);
        let step = std::f64::consts::PI / len as f64;
        let dc_scale = (1.0 / len as f64).sqrt();
        let ac_scale = (2.0 / len as f64).sqrt();
        let mut signal = vec![0.0; len];
        for (j, sample) in signal.iter_mut().enumerate() {
            let jf = j as f64 + 0.5;
            let mut acc = dc_scale * coeffs[0];
            for (k, &ck) in coeffs.iter().enumerate().take(len).skip(1) {
                acc += ac_scale * ck * (jf * k as f64 * step).cos();
            }
            *sample = acc;
        }
        signal
    }

    /// Keep the `k` largest-magnitude entries of `coeffs` and zero the rest.
    ///
    /// The output has the same length as `coeffs`. Ties in magnitude are
    /// resolved in favour of the lower index, because the sort is stable.
    pub fn truncate(&self, coeffs: &[f64], k: usize) -> Vec<f64> {
        let mut order: Vec<usize> = (0..coeffs.len()).collect();
        // Stable sort by descending magnitude.
        order.sort_by(|&p, &q| {
            coeffs[q]
                .abs()
                .partial_cmp(&coeffs[p].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut kept = vec![0.0_f64; coeffs.len()];
        for &idx in order.iter().take(k) {
            kept[idx] = coeffs[idx];
        }
        kept
    }
}
221
222// ─────────────────────────────────────────────────────────────────────────────
223// RandomMeasurementMatrix
224// ─────────────────────────────────────────────────────────────────────────────
225
226/// A random Gaussian measurement matrix for compressed sensing.
227///
228/// Each row is an independent Gaussian random vector; `m << n` enables
229/// sub-Nyquist recovery of sparse signals.
230#[allow(dead_code)]
231pub struct RandomMeasurementMatrix {
232    /// Number of measurements (rows).
233    pub m: usize,
234    /// Signal length (columns).
235    pub n: usize,
236    /// Underlying matrix entries stored row-major.
237    pub matrix: Vec<Vec<f64>>,
238}
239
240#[allow(dead_code)]
241impl RandomMeasurementMatrix {
242    /// Generate an `m × n` Gaussian measurement matrix (entries ~ N(0, 1/m)).
243    ///
244    /// The columns are scaled by `1/sqrt(m)` so that each measurement
245    /// approximately preserves the signal energy.
246    pub fn generate_gaussian(m: usize, n: usize) -> Self {
247        use rand::RngExt as _;
248        let mut rng = rand::rng();
249        let scale = 1.0 / (m as f64).sqrt();
250        let matrix: Vec<Vec<f64>> = (0..m)
251            .map(|_| {
252                (0..n)
253                    .map(|_| {
254                        // Box-Muller for N(0,1)
255                        let u1: f64 = rng.random_range(1e-12_f64..1.0_f64);
256                        let u2: f64 = rng.random_range(0.0_f64..1.0_f64);
257                        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
258                        z * scale
259                    })
260                    .collect()
261            })
262            .collect();
263        Self { m, n, matrix }
264    }
265
266    /// Generate an `m × n` Bernoulli ±1/sqrt(m) measurement matrix.
267    ///
268    /// Each entry is independently ±1/sqrt(m) with equal probability 1/2.
269    pub fn generate_bernoulli(m: usize, n: usize) -> Self {
270        let mut rng = rand::rng();
271        let scale = 1.0 / (m as f64).sqrt();
272        let matrix: Vec<Vec<f64>> = (0..m)
273            .map(|_| {
274                (0..n)
275                    .map(|_| {
276                        if rng.random_range(0.0_f64..1.0_f64) < 0.5 {
277                            scale
278                        } else {
279                            -scale
280                        }
281                    })
282                    .collect()
283            })
284            .collect();
285        Self { m, n, matrix }
286    }
287
288    /// Apply the measurement matrix to signal `x`.
289    ///
290    /// Returns a vector of length `m` (the compressed measurements).
291    pub fn measure(&self, x: &[f64]) -> Vec<f64> {
292        mat_vec(&self.matrix, x)
293    }
294
295    /// Compute the mutual coherence of the measurement matrix columns.
296    ///
297    /// Lower coherence is better for sparse recovery.
298    pub fn coherence(&self) -> f64 {
299        SparsityMetrics::coherence(&self.matrix)
300    }
301}
302
303// ─────────────────────────────────────────────────────────────────────────────
304// BasisPursuit / ISTA / FISTA
305// ─────────────────────────────────────────────────────────────────────────────
306
307/// Basis Pursuit via iterative shrinkage-thresholding (ISTA and FISTA).
308///
309/// Solves the LASSO problem: `argmin_x 0.5 ||Ax - b||^2 + lambda ||x||_1`.
310///
311/// FISTA adds Nesterov momentum for faster O(1/k²) convergence versus
312/// O(1/k) for plain ISTA.
313#[allow(dead_code)]
314pub struct BasisPursuit;
315
316#[allow(dead_code)]
317impl BasisPursuit {
318    /// Estimate the Lipschitz constant of the gradient via power iteration.
319    fn lipschitz(a: &[Vec<f64>]) -> f64 {
320        spectral_norm(a, 20).powi(2).max(1e-10)
321    }
322
323    /// Solve the LASSO problem using ISTA.
324    ///
325    /// - `a` — measurement matrix (m × n, row-major)
326    /// - `b` — measurement vector (length m)
327    /// - `lambda` — sparsity regularisation weight
328    /// - `max_iter` — maximum number of iterations
329    ///
330    /// Returns the recovered sparse signal of length n.
331    pub fn solve_lasso(a: &[Vec<f64>], b: &[f64], lambda: f64, max_iter: usize) -> Vec<f64> {
332        if a.is_empty() || b.is_empty() {
333            return Vec::new();
334        }
335        let n = a[0].len();
336        let l = Self::lipschitz(a);
337        let step = 1.0 / l;
338        let mut x = vec![0.0_f64; n];
339
340        for _ in 0..max_iter {
341            let residual: Vec<f64> = mat_vec(a, &x)
342                .iter()
343                .zip(b.iter())
344                .map(|(r, bi)| r - bi)
345                .collect();
346            let grad = mat_transpose_vec(a, &residual);
347            x = x
348                .iter()
349                .zip(grad.iter())
350                .map(|(xi, gi)| soft_threshold(xi - step * gi, step * lambda))
351                .collect();
352        }
353        x
354    }
355
356    /// Solve the LASSO problem using FISTA (Fast ISTA with Nesterov momentum).
357    ///
358    /// - `a` — measurement matrix (m × n, row-major)
359    /// - `b` — measurement vector (length m)
360    /// - `lambda` — sparsity regularisation weight
361    /// - `max_iter` — maximum number of iterations
362    ///
363    /// Returns the recovered sparse signal of length n.
364    pub fn solve_fista(a: &[Vec<f64>], b: &[f64], lambda: f64, max_iter: usize) -> Vec<f64> {
365        if a.is_empty() || b.is_empty() {
366            return Vec::new();
367        }
368        let n = a[0].len();
369        let l = Self::lipschitz(a);
370        let step = 1.0 / l;
371
372        let mut x = vec![0.0_f64; n];
373        let mut y = x.clone();
374        let mut t = 1.0_f64;
375
376        for _ in 0..max_iter {
377            let x_prev = x.clone();
378
379            // Gradient step at y
380            let residual: Vec<f64> = mat_vec(a, &y)
381                .iter()
382                .zip(b.iter())
383                .map(|(r, bi)| r - bi)
384                .collect();
385            let grad = mat_transpose_vec(a, &residual);
386            x = y
387                .iter()
388                .zip(grad.iter())
389                .map(|(yi, gi)| soft_threshold(yi - step * gi, step * lambda))
390                .collect();
391
392            // Nesterov momentum update
393            let t_new = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
394            let momentum = (t - 1.0) / t_new;
395            y = x
396                .iter()
397                .zip(x_prev.iter())
398                .map(|(xi, xi_prev)| xi + momentum * (xi - xi_prev))
399                .collect();
400            t = t_new;
401        }
402        x
403    }
404
405    /// Compute the objective value `0.5 ||Ax - b||^2 + lambda ||x||_1`.
406    ///
407    /// Useful for monitoring convergence.
408    pub fn objective(a: &[Vec<f64>], b: &[f64], x: &[f64], lambda: f64) -> f64 {
409        if a.is_empty() || b.is_empty() {
410            return 0.0;
411        }
412        let ax = mat_vec(a, x);
413        let residual_sq: f64 = ax
414            .iter()
415            .zip(b.iter())
416            .map(|(r, bi)| (r - bi).powi(2))
417            .sum();
418        let l1: f64 = x.iter().map(|xi| xi.abs()).sum();
419        0.5 * residual_sq + lambda * l1
420    }
421}
422
423// ─────────────────────────────────────────────────────────────────────────────
424// OrthogonalMatchingPursuit
425// ─────────────────────────────────────────────────────────────────────────────
426
/// Orthogonal Matching Pursuit (OMP) for sparse signal recovery.
///
/// At each step OMP selects the column of the measurement matrix most
/// correlated with the current residual. It then performs a least-squares fit
/// on all selected columns.
#[allow(dead_code)]
pub struct OrthogonalMatchingPursuit {
    /// Maximum sparsity (number of non-zero coefficients to recover).
    pub max_k: usize,
}

#[allow(dead_code)]
impl OrthogonalMatchingPursuit {
    /// Create an OMP solver that selects at most `max_k` columns.
    pub fn new(max_k: usize) -> Self {
        Self { max_k }
    }

    /// Index of the unselected column `j` with the largest `|a_jᵀ r|`.
    ///
    /// Ties go to the lowest index. Returns `None` when every unselected column
    /// has zero correlation with the residual. Without this, column 0 would be
    /// selected by default, which duplicates atoms or adds useless ones.
    fn most_correlated(a: &[Vec<f64>], residual: &[f64], selected: &[bool]) -> Option<usize> {
        let m = a.len();
        let mut best: Option<(usize, f64)> = None;
        for (j, &taken) in selected.iter().enumerate() {
            if taken {
                continue;
            }
            let corr = (0..m).map(|i| a[i][j] * residual[i]).sum::<f64>().abs();
            let best_corr = best.map_or(0.0, |(_, c)| c);
            if corr > best_corr {
                best = Some((j, corr));
            }
        }
        best.map(|(j, _)| j)
    }

    /// Recover a sparse signal from measurements.
    ///
    /// - `a`: measurement matrix (m × n, row-major)
    /// - `b`: measurement vector (length m)
    ///
    /// Selects at most `min(max_k, m, n)` columns. It stops early when the
    /// residual norm drops below `1e-12` or no remaining column correlates with
    /// the residual. Returns the coefficient vector of length n, or an empty
    /// vector when `a` or `b` is empty.
    pub fn solve(&self, a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let m = a.len();
        let n = a[0].len();
        let k = self.max_k.min(n).min(m);

        let mut residual = b.to_vec();
        let mut support: Vec<usize> = Vec::with_capacity(k);
        let mut selected = vec![false; n];
        let mut x = vec![0.0_f64; n];

        for _ in 0..k {
            let Some(best_idx) = Self::most_correlated(a, &residual, &selected) else {
                // The residual is orthogonal to every remaining column.
                break;
            };
            support.push(best_idx);
            selected[best_idx] = true;

            // Least squares on the support: solve (A_Sᵀ A_S) c = A_Sᵀ b.
            let s = support.len();
            let mut ata = vec![vec![0.0_f64; s]; s];
            let mut atb = vec![0.0_f64; s];
            for (si, &ci) in support.iter().enumerate() {
                for (sj, &cj) in support.iter().enumerate() {
                    ata[si][sj] = (0..m).map(|i| a[i][ci] * a[i][cj]).sum();
                }
                atb[si] = (0..m).map(|i| a[i][ci] * b[i]).sum();
            }
            let coeffs = gauss_solve(&ata, &atb);

            x.fill(0.0);
            for (&ci, &c) in support.iter().zip(&coeffs) {
                x[ci] = c;
            }

            // Residual r = b - A x.
            residual = (0..m)
                .map(|i| {
                    let ax_i: f64 = (0..n).map(|j| a[i][j] * x[j]).sum();
                    b[i] - ax_i
                })
                .collect();

            let res_norm: f64 = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
            if res_norm < 1e-12 {
                break;
            }
        }
        x
    }

    /// Greedily select up to `min(max_k, m, n)` column indices, in selection order.
    ///
    /// Unlike [`OrthogonalMatchingPursuit::solve`], this method does not refit
    /// all selected columns. After each selection it only removes the residual's
    /// projection onto the newly chosen column (a matching-pursuit-style update),
    /// so the indices can differ from the support found by `solve`. Selection
    /// stops early in two cases: no remaining column correlates with the
    /// residual, or the chosen column's squared norm is below `1e-14`.
    pub fn support(&self, a: &[Vec<f64>], b: &[f64]) -> Vec<usize> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }
        let m = a.len();
        let n = a[0].len();
        let k = self.max_k.min(n).min(m);
        let mut residual = b.to_vec();
        let mut support: Vec<usize> = Vec::with_capacity(k);
        let mut selected = vec![false; n];

        for _ in 0..k {
            let Some(best_idx) = Self::most_correlated(a, &residual, &selected) else {
                break;
            };
            support.push(best_idx);
            selected[best_idx] = true;

            // Project the residual off the selected atom only.
            let col_norm_sq: f64 = (0..m).map(|i| a[i][best_idx].powi(2)).sum();
            if col_norm_sq < 1e-14 {
                break;
            }
            let proj: f64 =
                (0..m).map(|i| a[i][best_idx] * residual[i]).sum::<f64>() / col_norm_sq;
            for i in 0..m {
                residual[i] -= proj * a[i][best_idx];
            }
        }
        support
    }
}

/// Solve a small dense system `a x = b` by Gaussian elimination with partial pivoting.
///
/// Assumes `a` is square with `b.len()` rows. A column whose pivot magnitude is
/// below `1e-14` is skipped during elimination, and any unknown whose final
/// diagonal is below `1e-14` is set to `0.0` during back substitution. A
/// singular system therefore yields a partial solution (not necessarily a
/// least-squares one) rather than an error. Returns an empty vector for an
/// empty system.
fn gauss_solve(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    if n == 0 {
        return Vec::new();
    }
    let mut mat: Vec<Vec<f64>> = a.to_vec();
    let mut rhs: Vec<f64> = b.to_vec();

    for col in 0..n {
        // Partial pivoting: move the largest-magnitude entry of this column up.
        let pivot = (col..n).max_by(|&i, &j| {
            mat[i][col]
                .abs()
                .partial_cmp(&mat[j][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        if let Some(p) = pivot {
            mat.swap(col, p);
            rhs.swap(col, p);
        }
        let diag = mat[col][col];
        if diag.abs() < 1e-14 {
            continue;
        }
        for row in (col + 1)..n {
            let factor = mat[row][col] / diag;
            for k in col..n {
                let v = mat[col][k];
                mat[row][k] -= factor * v;
            }
            rhs[row] -= factor * rhs[col];
        }
    }

    // Back substitution.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut s = rhs[i];
        for j in (i + 1)..n {
            s -= mat[i][j] * x[j];
        }
        let d = mat[i][i];
        x[i] = if d.abs() < 1e-14 { 0.0 } else { s / d };
    }
    x
}
606
607// ─────────────────────────────────────────────────────────────────────────────
608// SparsityMetrics
609// ─────────────────────────────────────────────────────────────────────────────
610
/// Sparsity measures for vectors and coherence measures for dictionaries.
///
/// Matrices are row-major (`a[i][j]` is row `i`, column `j`), and the
/// coherence-type metrics treat the columns as dictionary atoms. Columns with
/// norm below `1e-14` are ignored by those metrics.
#[allow(dead_code)]
pub struct SparsityMetrics;

#[allow(dead_code)]
impl SparsityMetrics {
    /// Split a non-empty row-major matrix into its columns and their ℓ₂ norms.
    ///
    /// The column count is taken from the first row.
    fn columns_and_norms(a: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
        let width = a[0].len();
        let columns: Vec<Vec<f64>> = (0..width)
            .map(|j| a.iter().map(|row| row[j]).collect())
            .collect();
        let norms: Vec<f64> = columns
            .iter()
            .map(|col| col.iter().map(|v| v * v).sum::<f64>().sqrt())
            .collect();
        (columns, norms)
    }

    /// `|uᵀv| / (nu * nv)`: absolute cosine between two columns with precomputed norms.
    fn abs_cosine(u: &[f64], v: &[f64], nu: f64, nv: f64) -> f64 {
        let dot: f64 = u.iter().zip(v.iter()).map(|(p, q)| p * q).sum();
        (dot / (nu * nv)).abs()
    }

    /// Number of entries with `|x_i| > threshold` (a thresholded ℓ₀ "norm").
    pub fn l0_norm(x: &[f64], threshold: f64) -> usize {
        x.iter().filter(|v| v.abs() > threshold).count()
    }

    /// Sum of absolute values (ℓ₁ norm).
    pub fn l1_norm(x: &[f64]) -> f64 {
        x.iter().copied().map(f64::abs).sum()
    }

    /// Euclidean (ℓ₂) norm.
    pub fn l2_norm(x: &[f64]) -> f64 {
        let sum_sq: f64 = x.iter().map(|v| v * v).sum();
        sum_sq.sqrt()
    }

    /// Gini coefficient of `|x|` as a sparsity score, clamped to `[0, 1]`.
    ///
    /// With magnitudes sorted ascending, `G = 2 Σ i·|x|_(i) / (n Σ|x|) - (n + 1)/n`.
    /// Examples:
    /// - A constant non-zero vector scores 0.
    /// - A one-hot vector of length `n` scores `(n - 1)/n`.
    /// - A vector whose magnitudes sum to less than `1e-14` is treated as
    ///   perfectly sparse and scores 1.
    /// - An empty slice scores 0.
    pub fn gini(x: &[f64]) -> f64 {
        if x.is_empty() {
            return 0.0;
        }
        let len = x.len() as f64;
        let mut mags: Vec<f64> = x.iter().map(|v| v.abs()).collect();
        mags.sort_by(|p, q| p.partial_cmp(q).unwrap_or(std::cmp::Ordering::Equal));
        let total: f64 = mags.iter().sum();
        if total < 1e-14 {
            return 1.0; // all (near) zero: perfectly sparse
        }
        let weighted: f64 = mags
            .iter()
            .enumerate()
            .map(|(i, v)| (i + 1) as f64 * v)
            .sum();
        let g = (2.0 * weighted) / (len * total) - (len + 1.0) / len;
        g.clamp(0.0, 1.0)
    }

    /// Mutual coherence `μ(A) = max_{i≠j} |a_iᵀ a_j| / (||a_i|| ||a_j||)` over the columns.
    ///
    /// Returns `0.0` for an empty matrix or when fewer than two usable columns exist.
    pub fn coherence(a: &[Vec<f64>]) -> f64 {
        if a.is_empty() {
            return 0.0;
        }
        let (cols, norms) = Self::columns_and_norms(a);
        let mut worst = 0.0_f64;
        for i in 0..cols.len() {
            if norms[i] < 1e-14 {
                continue;
            }
            for j in (i + 1)..cols.len() {
                if norms[j] < 1e-14 {
                    continue;
                }
                let c = Self::abs_cosine(&cols[i], &cols[j], norms[i], norms[j]);
                if c > worst {
                    worst = c;
                }
            }
        }
        worst
    }

    /// Babel function `μ₁(k)` of a dictionary.
    ///
    /// For each atom `i`, this sums its `k` largest absolute cosines with the
    /// other atoms, and returns the maximum of that sum over `i`. It is the
    /// maximum over `i` and over sets `S` with `|S| = k`, `i ∉ S`, of
    /// `Σ_{j∈S} |cos(a_i, a_j)|`. Smaller values give stronger recovery
    /// guarantees.
    pub fn babel_function(a: &[Vec<f64>], k: usize) -> f64 {
        if a.is_empty() {
            return 0.0;
        }
        let (cols, norms) = Self::columns_and_norms(a);
        let n = cols.len();
        let mut worst = 0.0_f64;
        for i in 0..n {
            if norms[i] < 1e-14 {
                continue;
            }
            let mut corrs: Vec<f64> = (0..n)
                .filter(|&j| j != i && norms[j] > 1e-14)
                .map(|j| Self::abs_cosine(&cols[i], &cols[j], norms[i], norms[j]))
                .collect();
            // Largest cosines first.
            corrs.sort_by(|p, q| q.partial_cmp(p).unwrap_or(std::cmp::Ordering::Equal));
            let total: f64 = corrs.iter().take(k).sum();
            if total > worst {
                worst = total;
            }
        }
        worst
    }
}
729
730// ─────────────────────────────────────────────────────────────────────────────
731// RecoveryGuarantee
732// ─────────────────────────────────────────────────────────────────────────────
733
/// Theoretical conditions and bounds for exact or stable sparse recovery.
#[allow(dead_code)]
pub struct RecoveryGuarantee;

#[allow(dead_code)]
impl RecoveryGuarantee {
    /// Heuristic lower-bound estimate of the restricted isometry constant `δ_k`.
    ///
    /// The true `δ_k` is the smallest `δ` with
    /// `(1 - δ)||v||² <= ||A v||² <= (1 + δ)||v||²` for every `k`-sparse `v`.
    /// Computing it exactly is intractable.
    ///
    /// This routine only probes unit vectors whose entries are all `1/sqrt(k)`
    /// on the cyclic index windows `{s, …, s + k - 1} mod n`, for the first
    /// `min(n, 50)` start positions `s`. It returns the largest observed
    /// `| ||A v||² - 1 |`. Every probe is a genuine `k`-sparse unit vector, so
    /// the result does not exceed `δ_k` (up to rounding), but it may
    /// underestimate it.
    ///
    /// Returns `0.0` for an empty matrix or `k == 0`, since the only 0-sparse
    /// vector is zero. `k` is capped at the column count `n`.
    pub fn rip_constant(a: &[Vec<f64>], k: usize) -> f64 {
        if a.is_empty() || k == 0 {
            return 0.0;
        }
        let n = a[0].len();
        let k = k.min(n);
        let inv_sqrt_k = 1.0 / (k as f64).sqrt();

        let mut max_dev = 0.0_f64;
        // Windows starting at s >= n would repeat earlier ones, so cap at n.
        for start in 0..n.min(50) {
            let mut v = vec![0.0_f64; n];
            for offset in 0..k {
                v[(start + offset) % n] = inv_sqrt_k;
            }
            // ||A v||²; v has unit norm, so the ideal energy is 1.
            let energy: f64 = a
                .iter()
                .map(|row| {
                    let av_i: f64 = row.iter().zip(v.iter()).map(|(ai, vi)| ai * vi).sum();
                    av_i * av_i
                })
                .sum();
            let dev = (energy - 1.0).abs();
            if dev > max_dev {
                max_dev = dev;
            }
        }
        max_dev
    }

    /// Rule-of-thumb check for exact recovery of a `k`-sparse, length-`n`
    /// signal from `m` measurements.
    ///
    /// Returns `m >= 2 k max(ln(n / k), 1)`. The logarithm is floored at 1, so
    /// nearly dense signals (`n / k < e`) still need at least `2k` measurements.
    /// Degenerate inputs (`k == 0`, `n == 0` or `k > n`) return `true` without
    /// looking at `m`.
    pub fn exact_recovery_condition(k: usize, m: usize, n: usize) -> bool {
        if k == 0 || n == 0 || k > n {
            return true;
        }
        let required = 2.0 * k as f64 * ((n as f64 / k as f64).ln()).max(1.0);
        m as f64 >= required
    }

    /// Measurement count suggested by RIP-based theory: `ceil(4 k ln(n / k))`.
    ///
    /// Returns `0` for `k == 0`, `n == 0` or `k > n`. It also returns `0` when
    /// `k == n`, because `ln(1) = 0`.
    pub fn rip_measurement_lower_bound(k: usize, n: usize) -> usize {
        if k == 0 || n == 0 || k > n {
            return 0;
        }
        let c = 4.0_f64;
        (c * k as f64 * (n as f64 / k as f64).ln()).ceil() as usize
    }

    /// LASSO reconstruction-error bound `C * lambda * sqrt(k)`.
    ///
    /// `C = (1 + δ) / max(1 - sqrt(2) δ, 1e-14)`, where `δ = rip_delta`. For
    /// `δ >= 1/sqrt(2)` the denominator is clamped, so the bound becomes
    /// astronomically large (no meaningful guarantee).
    pub fn lasso_error_bound(lambda: f64, k: usize, rip_delta: f64) -> f64 {
        let c = (1.0 + rip_delta) / (1.0 - 2.0_f64.sqrt() * rip_delta).max(1e-14);
        c * lambda * (k as f64).sqrt()
    }
}
816
817// ─────────────────────────────────────────────────────────────────────────────
818// KSvd — Dictionary Learning
819// ─────────────────────────────────────────────────────────────────────────────
820
/// K-SVD dictionary learner.
///
/// Training alternates between two stages: sparse coding of the training
/// signals with Orthogonal Matching Pursuit, and a per-atom dictionary update.
/// Note that `fit` caps the number of learned atoms at the signal dimension.
///
/// Reference: M. Aharon, M. Elad and A. Bruckstein, "K-SVD: An Algorithm for
/// Designing Overcomplete Dictionaries for Sparse Representation" (2006).
#[allow(dead_code)]
pub struct KSvd {
    /// Requested number of dictionary atoms (columns).
    pub n_atoms: usize,
    /// Maximum number of atoms used to code each signal.
    pub sparsity: usize,
    /// Number of sparse-coding / dictionary-update rounds.
    pub n_iter: usize,
}
836
837#[allow(dead_code)]
838impl KSvd {
839    /// Create a new K-SVD learner.
840    ///
841    /// - `n_atoms` — number of dictionary atoms
842    /// - `sparsity` — maximum number of atoms per signal
843    /// - `n_iter` — number of alternating optimisation iterations
844    pub fn new(n_atoms: usize, sparsity: usize, n_iter: usize) -> Self {
845        Self {
846            n_atoms,
847            sparsity,
848            n_iter,
849        }
850    }
851
852    /// Learn a dictionary from training signals.
853    ///
854    /// - `signals` — list of training signals (each of the same length `d`)
855    ///
856    /// Returns a dictionary matrix `D` of shape `d × n_atoms` (columns are atoms,
857    /// stored as `Vec<Vec`f64`>` in row-major form, i.e. `D[i][j]` is row `i`, atom `j`).
858    pub fn fit(&self, signals: &[Vec<f64>]) -> Vec<Vec<f64>> {
859        if signals.is_empty() {
860            return Vec::new();
861        }
862        let d = signals[0].len();
863        let n_signals = signals.len();
864        let n_atoms = self.n_atoms.min(d);
865
866        // Initialise dictionary by picking random training signals as atoms
867        let mut rng = rand::rng();
868        let mut dict: Vec<Vec<f64>> = (0..n_atoms)
869            .map(|k| {
870                let idx = k % n_signals;
871                let _ = idx; // use deterministic init
872                let pick = rng.random_range(0..n_signals);
873                let mut atom = signals[pick].clone();
874                let norm = l2_norm(&atom);
875                if norm > 1e-14 {
876                    for v in atom.iter_mut() {
877                        *v /= norm;
878                    }
879                }
880                atom
881            })
882            .collect();
883
884        let omp = OrthogonalMatchingPursuit::new(self.sparsity);
885
886        for _iter in 0..self.n_iter {
887            // --- Sparse Coding stage: encode each signal with OMP ---
888            // Build measurement matrix: D^T (n_atoms × d signals), as A for OMP
889            // OMP expects A of shape m×n where b is length m.
890            // Here we measure each signal y: we want sparse c s.t. D c ≈ y.
891            // A = D (d × n_atoms), b = y (length d).
892            // Transpose: pass dict^T as the A matrix for OMP.
893            let dict_t: Vec<Vec<f64>> = (0..d)
894                .map(|i| (0..n_atoms).map(|k| dict[k][i]).collect())
895                .collect();
896
897            let codes: Vec<Vec<f64>> = signals.iter().map(|y| omp.solve(&dict_t, y)).collect();
898
899            // --- Dictionary Update stage: update each atom via rank-1 SVD ---
900            for k in 0..n_atoms {
901                // Find signals that use atom k
902                let using: Vec<usize> = (0..n_signals)
903                    .filter(|&s| codes[s][k].abs() > 1e-14)
904                    .collect();
905                if using.is_empty() {
906                    // Re-initialise dead atom
907                    let pick = rng.random_range(0..n_signals);
908                    let mut atom = signals[pick].clone();
909                    let norm = l2_norm(&atom);
910                    if norm > 1e-14 {
911                        for v in atom.iter_mut() {
912                            *v /= norm;
913                        }
914                    }
915                    dict[k] = atom;
916                    continue;
917                }
918
919                // Compute error matrix for atom k:
920                // E_k = Y - sum_{j≠k} d_j c_j^T
921                // Then update d_k and c_k via rank-1 approximation of E_k.
922                let e_rows: Vec<Vec<f64>> = using
923                    .iter()
924                    .map(|&s| {
925                        let mut e = signals[s].clone();
926                        for j in 0..n_atoms {
927                            if j == k {
928                                continue;
929                            }
930                            let coef = codes[s][j];
931                            for i in 0..d {
932                                e[i] -= coef * dict[j][i];
933                            }
934                        }
935                        e
936                    })
937                    .collect();
938
939                // Power iteration on E_k to find dominant left singular vector
940                let mut atom = dict[k].clone();
941                for _pi in 0..10 {
942                    // atom_new = E_k^T (E_k atom) / ||E_k atom||
943                    // E_k is (#using × d): rows are e_rows
944                    let e_atom: Vec<f64> = e_rows
945                        .iter()
946                        .map(|row| row.iter().zip(atom.iter()).map(|(a, b)| a * b).sum::<f64>())
947                        .collect();
948                    let mut new_atom = vec![0.0_f64; d];
949                    for (e_row, &ea) in e_rows.iter().zip(e_atom.iter()) {
950                        for (i, &ei) in e_row.iter().enumerate() {
951                            new_atom[i] += ei * ea;
952                        }
953                    }
954                    normalise(&mut new_atom);
955                    atom = new_atom;
956                }
957                dict[k] = atom;
958
959                // Update coefficients (project signals onto new atom)
960                for &s in &using {
961                    e_rows.iter().position(|_| true).map(|_| ()).unwrap_or(());
962                    // find position in using
963                    if let Some(pos) = using.iter().position(|&u| u == s) {
964                        let dot: f64 = e_rows[pos]
965                            .iter()
966                            .zip(dict[k].iter())
967                            .map(|(a, b)| a * b)
968                            .sum();
969                        let _ = (s, dot, pos);
970                    }
971                }
972                // Reset e_rows to satisfy borrow checker (it was moved conceptually)
973                let _ = e_rows.len();
974            }
975        }
976
977        dict
978    }
979
980    /// Encode a single signal using the learned dictionary.
981    ///
982    /// Returns a sparse coefficient vector of length `n_atoms`.
983    pub fn encode(&self, dict: &[Vec<f64>], signal: &[f64]) -> Vec<f64> {
984        if dict.is_empty() || signal.is_empty() {
985            return Vec::new();
986        }
987        let d = signal.len();
988        let n_atoms = dict.len();
989        // Build A = D^T (d × n_atoms) for OMP
990        let dict_t: Vec<Vec<f64>> = (0..d)
991            .map(|i| (0..n_atoms).map(|k| dict[k][i]).collect())
992            .collect();
993        let omp = OrthogonalMatchingPursuit::new(self.sparsity);
994        omp.solve(&dict_t, signal)
995    }
996
997    /// Reconstruct a signal from its sparse code and dictionary.
998    ///
999    /// Returns `D c` where `D` is the dictionary and `c` is the code.
1000    pub fn reconstruct(dict: &[Vec<f64>], code: &[f64]) -> Vec<f64> {
1001        if dict.is_empty() {
1002            return Vec::new();
1003        }
1004        let d = dict[0].len();
1005        let mut out = vec![0.0_f64; d];
1006        for (k, atom) in dict.iter().enumerate() {
1007            if k >= code.len() {
1008                break;
1009            }
1010            for (i, &ai) in atom.iter().enumerate() {
1011                out[i] += code[k] * ai;
1012            }
1013        }
1014        out
1015    }
1016}
1017
1018// ─────────────────────────────────────────────────────────────────────────────
1019// MRI-like Compressed Sensing
1020// ─────────────────────────────────────────────────────────────────────────────
1021
/// Compressed sensing reconstruction for MRI-like k-space data.
///
/// In MRI, measurements are taken in the Fourier (k-space) domain. This type
/// provides a simplified 1-D model:
///
/// * k-space samples are modelled as randomly selected rows of a real cosine
///   matrix (the real part of the DFT).
/// * `reconstruct_fista` promotes sparsity of the signal samples themselves,
///   with no separate sparsifying transform.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MriCompressedSensing {
    /// Signal length.
    pub n: usize,
    /// Number of k-space samples (measurements).
    pub m: usize,
}
1034
1035#[allow(dead_code)]
1036impl MriCompressedSensing {
1037    /// Create an MRI-CS reconstruction problem for a signal of length `n`
1038    /// with `m` k-space measurements.
1039    pub fn new(n: usize, m: usize) -> Self {
1040        Self { n, m }
1041    }
1042
1043    /// Generate random k-space sampling indices in `[0, n)`.
1044    ///
1045    /// Returns `m` unique indices (or repeating if `m > n`).
1046    pub fn sample_kspace_indices(&self) -> Vec<usize> {
1047        let mut rng = rand::rng();
1048        let mut indices: Vec<usize> = (0..self.n).collect();
1049        // Fisher-Yates shuffle, take first m
1050        for i in 0..self.m.min(self.n) {
1051            let j = rng.random_range(i..self.n);
1052            indices.swap(i, j);
1053        }
1054        indices[..self.m.min(self.n)].to_vec()
1055    }
1056
1057    /// Build a partial Fourier (cosine) measurement matrix from k-space indices.
1058    ///
1059    /// Row `i` corresponds to k-space sample `k_i`; entry `(i, j)` is
1060    /// `cos(2π k_i j / n) / sqrt(m)`.
1061    pub fn build_measurement_matrix(&self, kspace_indices: &[usize]) -> Vec<Vec<f64>> {
1062        let scale = 1.0 / (self.m as f64).sqrt();
1063        kspace_indices
1064            .iter()
1065            .map(|&ki| {
1066                (0..self.n)
1067                    .map(|j| {
1068                        (2.0 * std::f64::consts::PI * ki as f64 * j as f64 / self.n as f64).cos()
1069                            * scale
1070                    })
1071                    .collect()
1072            })
1073            .collect()
1074    }
1075
1076    /// Reconstruct a signal from k-space measurements using FISTA.
1077    ///
1078    /// - `measurements` — observed k-space values
1079    /// - `kspace_indices` — which k-space lines were sampled
1080    /// - `lambda` — sparsity regularisation
1081    /// - `max_iter` — number of FISTA iterations
1082    ///
1083    /// Returns the reconstructed signal.
1084    #[allow(clippy::too_many_arguments)]
1085    pub fn reconstruct_fista(
1086        &self,
1087        measurements: &[f64],
1088        kspace_indices: &[usize],
1089        lambda: f64,
1090        max_iter: usize,
1091    ) -> Vec<f64> {
1092        let a = self.build_measurement_matrix(kspace_indices);
1093        BasisPursuit::solve_fista(&a, measurements, lambda, max_iter)
1094    }
1095
1096    /// Compute the peak signal-to-noise ratio (PSNR) between two signals.
1097    ///
1098    /// `PSNR = 20 * log10(max_val / RMSE)`
1099    pub fn psnr(original: &[f64], reconstructed: &[f64], max_val: f64) -> f64 {
1100        let n = original.len().min(reconstructed.len());
1101        if n == 0 {
1102            return 0.0;
1103        }
1104        let mse: f64 = original[..n]
1105            .iter()
1106            .zip(reconstructed[..n].iter())
1107            .map(|(a, b)| (a - b).powi(2))
1108            .sum::<f64>()
1109            / n as f64;
1110        if mse < 1e-14 {
1111            return f64::INFINITY;
1112        }
1113        20.0 * (max_val / mse.sqrt()).log10()
1114    }
1115}
1116
1117// ─────────────────────────────────────────────────────────────────────────────
1118// SparseSignal — synthetic sparse signal generation
1119// ─────────────────────────────────────────────────────────────────────────────
1120
/// Stateless namespace of helpers for generating, perturbing and comparing
/// synthetic sparse signals (see its associated functions).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, Default)]
pub struct SparseSignal;
1124
1125#[allow(dead_code)]
1126impl SparseSignal {
1127    /// Generate a `k`-sparse signal of length `n` with random support and values.
1128    ///
1129    /// Non-zero coefficients are drawn uniformly from `[-amplitude, amplitude]`.
1130    pub fn generate(n: usize, k: usize, amplitude: f64) -> Vec<f64> {
1131        let mut rng = rand::rng();
1132        let mut signal = vec![0.0_f64; n];
1133        let k = k.min(n);
1134
1135        // Choose k unique indices
1136        let mut indices: Vec<usize> = (0..n).collect();
1137        for i in 0..k {
1138            let j = rng.random_range(i..n);
1139            indices.swap(i, j);
1140        }
1141        for &idx in &indices[..k] {
1142            signal[idx] = rng.random_range(-amplitude..amplitude);
1143        }
1144        signal
1145    }
1146
1147    /// Add Gaussian noise with standard deviation `sigma` to a signal.
1148    pub fn add_noise(signal: &[f64], sigma: f64) -> Vec<f64> {
1149        let mut rng = rand::rng();
1150        signal
1151            .iter()
1152            .map(|&x| {
1153                let u1: f64 = rng.random_range(1e-12_f64..1.0_f64);
1154                let u2: f64 = rng.random_range(0.0_f64..1.0_f64);
1155                let noise = (-2.0_f64 * u1.ln()).sqrt()
1156                    * (2.0_f64 * std::f64::consts::PI * u2).cos()
1157                    * sigma;
1158                x + noise
1159            })
1160            .collect()
1161    }
1162
1163    /// Compute the support error between two signals (fraction of mismatched support).
1164    ///
1165    /// A value of 0 means the non-zero patterns are identical.
1166    pub fn support_error(truth: &[f64], recovered: &[f64], threshold: f64) -> f64 {
1167        let n = truth.len().min(recovered.len());
1168        if n == 0 {
1169            return 0.0;
1170        }
1171        let mismatches: usize = truth[..n]
1172            .iter()
1173            .zip(recovered[..n].iter())
1174            .filter(|&(t, r): &(&f64, &f64)| {
1175                let t_nonzero = t.abs() > threshold;
1176                let r_nonzero = r.abs() > threshold;
1177                t_nonzero != r_nonzero
1178            })
1179            .count();
1180        mismatches as f64 / n as f64
1181    }
1182
1183    /// Compute the relative reconstruction error `||x - x_hat||_2 / ||x||_2`.
1184    pub fn relative_error(truth: &[f64], recovered: &[f64]) -> f64 {
1185        let n = truth.len().min(recovered.len());
1186        let err: f64 = truth[..n]
1187            .iter()
1188            .zip(recovered[..n].iter())
1189            .map(|(a, b)| (a - b).powi(2))
1190            .sum::<f64>()
1191            .sqrt();
1192        let norm: f64 = truth[..n].iter().map(|x| x * x).sum::<f64>().sqrt();
1193        if norm < 1e-14 { err } else { err / norm }
1194    }
1195}
1196
1197// ─────────────────────────────────────────────────────────────────────────────
1198// Tests
1199// ─────────────────────────────────────────────────────────────────────────────
1200
1201#[cfg(test)]
1202mod tests {
1203    use super::*;
1204
1205    // ---- soft_threshold ----
1206
1207    #[test]
1208    fn test_soft_threshold_positive_above() {
1209        assert!((soft_threshold(3.0, 1.0) - 2.0).abs() < 1e-12);
1210    }
1211
1212    #[test]
1213    fn test_soft_threshold_positive_below() {
1214        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
1215    }
1216
1217    #[test]
1218    fn test_soft_threshold_negative_above() {
1219        assert!((soft_threshold(-3.0, 1.0) + 2.0).abs() < 1e-12);
1220    }
1221
1222    #[test]
1223    fn test_soft_threshold_zero() {
1224        assert_eq!(soft_threshold(0.0, 1.0), 0.0);
1225    }
1226
1227    #[test]
1228    fn test_soft_threshold_exact_boundary() {
1229        assert_eq!(soft_threshold(1.0, 1.0), 0.0);
1230        assert_eq!(soft_threshold(-1.0, 1.0), 0.0);
1231    }
1232
1233    #[test]
1234    fn test_soft_threshold_zero_lambda() {
1235        assert!((soft_threshold(5.0, 0.0) - 5.0).abs() < 1e-12);
1236    }
1237
1238    // ---- nyquist_rate ----
1239
1240    #[test]
1241    fn test_nyquist_rate_basic() {
1242        assert!((nyquist_rate(1000.0) - 2000.0).abs() < 1e-9);
1243    }
1244
1245    #[test]
1246    fn test_nyquist_rate_zero() {
1247        assert_eq!(nyquist_rate(0.0), 0.0);
1248    }
1249
1250    // ---- compression_ratio ----
1251
1252    #[test]
1253    fn test_compression_ratio_half() {
1254        assert!((compression_ratio(100, 50) - 0.5).abs() < 1e-12);
1255    }
1256
1257    #[test]
1258    fn test_compression_ratio_zero_n() {
1259        assert_eq!(compression_ratio(0, 10), 0.0);
1260    }
1261
1262    #[test]
1263    fn test_compression_ratio_full() {
1264        assert!((compression_ratio(10, 10) - 1.0).abs() < 1e-12);
1265    }
1266
1267    // ---- l2_norm / normalise ----
1268
1269    #[test]
1270    fn test_l2_norm_known() {
1271        let x = vec![3.0, 4.0];
1272        assert!((l2_norm(&x) - 5.0).abs() < 1e-12);
1273    }
1274
1275    #[test]
1276    fn test_normalise_unit_vector() {
1277        let mut v = vec![3.0, 0.0, 4.0];
1278        normalise(&mut v);
1279        assert!((l2_norm(&v) - 1.0).abs() < 1e-12);
1280    }
1281
1282    #[test]
1283    fn test_normalise_zero_vector_unchanged() {
1284        let mut v = vec![0.0, 0.0, 0.0];
1285        normalise(&mut v);
1286        assert_eq!(v, vec![0.0, 0.0, 0.0]);
1287    }
1288
1289    // ---- mat_vec / mat_transpose_vec ----
1290
1291    #[test]
1292    fn test_mat_vec_identity() {
1293        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1294        let x = vec![3.0, 7.0];
1295        let y = mat_vec(&a, &x);
1296        assert!((y[0] - 3.0).abs() < 1e-12);
1297        assert!((y[1] - 7.0).abs() < 1e-12);
1298    }
1299
1300    #[test]
1301    fn test_mat_transpose_vec_basic() {
1302        let a = vec![vec![1.0, 2.0, 3.0]]; // 1×3
1303        let x = vec![2.0]; // length 1
1304        let y = mat_transpose_vec(&a, &x);
1305        assert_eq!(y.len(), 3);
1306        assert!((y[0] - 2.0).abs() < 1e-12);
1307        assert!((y[1] - 4.0).abs() < 1e-12);
1308        assert!((y[2] - 6.0).abs() < 1e-12);
1309    }
1310
1311    // ---- DctBasis ----
1312
1313    #[test]
1314    fn test_dct_roundtrip() {
1315        let n = 8;
1316        let basis = DctBasis::new(n);
1317        let signal: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
1318        let coeffs = basis.transform(&signal);
1319        let recovered = basis.inverse(&coeffs);
1320        for (a, b) in signal.iter().zip(recovered.iter()) {
1321            assert!((a - b).abs() < 1e-10, "DCT roundtrip mismatch: {a} vs {b}");
1322        }
1323    }
1324
1325    #[test]
1326    fn test_dct_dc_component() {
1327        // Constant signal → only DC (k=0) is non-zero.
1328        let n = 8;
1329        let basis = DctBasis::new(n);
1330        let signal = vec![1.0_f64; n];
1331        let coeffs = basis.transform(&signal);
1332        for k in 1..n {
1333            assert!(
1334                coeffs[k].abs() < 1e-10,
1335                "non-DC coefficient k={k} should be ~0, got {}",
1336                coeffs[k]
1337            );
1338        }
1339        assert!(coeffs[0].abs() > 0.5, "DC component should be non-zero");
1340    }
1341
1342    #[test]
1343    fn test_dct_length_preserved() {
1344        let n = 16;
1345        let basis = DctBasis::new(n);
1346        let signal: Vec<f64> = vec![1.0; n];
1347        let coeffs = basis.transform(&signal);
1348        assert_eq!(coeffs.len(), n);
1349    }
1350
1351    #[test]
1352    fn test_dct_energy_preservation() {
1353        // Parseval's theorem: ||x||^2 ≈ ||X||^2 for orthonormal DCT.
1354        let n = 8;
1355        let basis = DctBasis::new(n);
1356        let signal: Vec<f64> = (0..n).map(|i| i as f64).collect();
1357        let coeffs = basis.transform(&signal);
1358        let e_signal: f64 = signal.iter().map(|x| x * x).sum();
1359        let e_coeffs: f64 = coeffs.iter().map(|x| x * x).sum();
1360        assert!((e_signal - e_coeffs).abs() / (e_signal + 1.0) < 1e-10);
1361    }
1362
1363    #[test]
1364    fn test_dct_new_n() {
1365        let basis = DctBasis::new(4);
1366        assert_eq!(basis.n, 4);
1367    }
1368
1369    #[test]
1370    fn test_dct_truncate_keeps_k_largest() {
1371        let basis = DctBasis::new(8);
1372        let coeffs = vec![1.0, 5.0, 0.1, 3.0, 0.0, 2.0, 0.0, 0.0];
1373        let truncated = basis.truncate(&coeffs, 2);
1374        let nonzero = truncated.iter().filter(|&&v| v.abs() > 1e-14).count();
1375        assert_eq!(nonzero, 2, "truncate(k=2) should leave 2 non-zeros");
1376        assert!((truncated[1] - 5.0).abs() < 1e-12, "5.0 should be kept");
1377        assert!((truncated[3] - 3.0).abs() < 1e-12, "3.0 should be kept");
1378    }
1379
1380    // ---- RandomMeasurementMatrix ----
1381
1382    #[test]
1383    fn test_random_measurement_matrix_dimensions() {
1384        let mat = RandomMeasurementMatrix::generate_gaussian(10, 20);
1385        assert_eq!(mat.m, 10);
1386        assert_eq!(mat.n, 20);
1387        assert_eq!(mat.matrix.len(), 10);
1388        assert_eq!(mat.matrix[0].len(), 20);
1389    }
1390
1391    #[test]
1392    fn test_measurement_output_length() {
1393        let mat = RandomMeasurementMatrix::generate_gaussian(5, 10);
1394        let x = vec![1.0_f64; 10];
1395        let y = mat.measure(&x);
1396        assert_eq!(y.len(), 5);
1397    }
1398
1399    #[test]
1400    fn test_measurement_linearity() {
1401        let mat = RandomMeasurementMatrix::generate_gaussian(5, 8);
1402        let x1: Vec<f64> = (0..8).map(|i| i as f64).collect();
1403        let x2: Vec<f64> = (0..8).map(|i| (8 - i) as f64).collect();
1404        let y1 = mat.measure(&x1);
1405        let y2 = mat.measure(&x2);
1406        let y_sum: Vec<f64> = x1.iter().zip(x2.iter()).map(|(a, b)| a + b).collect();
1407        let y_direct = mat.measure(&y_sum);
1408        for (a, b) in y_direct
1409            .iter()
1410            .zip(y1.iter().zip(y2.iter()).map(|(a, b)| a + b))
1411        {
1412            assert!((a - b).abs() < 1e-10, "linearity: {a} vs {b}");
1413        }
1414    }
1415
1416    #[test]
1417    fn test_bernoulli_matrix_dimensions() {
1418        let mat = RandomMeasurementMatrix::generate_bernoulli(8, 16);
1419        assert_eq!(mat.m, 8);
1420        assert_eq!(mat.n, 16);
1421    }
1422
1423    #[test]
1424    fn test_bernoulli_entries_are_plus_minus_scale() {
1425        let m = 5;
1426        let n = 10;
1427        let mat = RandomMeasurementMatrix::generate_bernoulli(m, n);
1428        let scale = 1.0 / (m as f64).sqrt();
1429        for row in &mat.matrix {
1430            for &v in row {
1431                let diff = (v.abs() - scale).abs();
1432                assert!(diff < 1e-12, "Bernoulli entry |{v}| ≠ {scale}");
1433            }
1434        }
1435    }
1436
1437    // ---- BasisPursuit (ISTA) ----
1438
1439    #[test]
1440    fn test_ista_trivial_identity() {
1441        // A = I_4, b = [1,0,0,0], lambda small → recover e_0
1442        let a: Vec<Vec<f64>> = (0..4)
1443            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1444            .collect();
1445        let b = vec![1.0, 0.0, 0.0, 0.0];
1446        let x = BasisPursuit::solve_lasso(&a, &b, 1e-4, 200);
1447        assert_eq!(x.len(), 4);
1448        assert!((x[0] - 1.0).abs() < 0.05, "x[0] should be ~1, got {}", x[0]);
1449        assert!(x[1].abs() < 0.05);
1450    }
1451
1452    #[test]
1453    fn test_ista_empty_input() {
1454        let x = BasisPursuit::solve_lasso(&[], &[], 1.0, 100);
1455        assert!(x.is_empty());
1456    }
1457
1458    #[test]
1459    fn test_ista_sparse_recovery() {
1460        // A = 4×4 identity, sparse signal [3, 0, 0, 0]
1461        let a: Vec<Vec<f64>> = (0..4)
1462            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1463            .collect();
1464        let b = vec![3.0, 0.0, 0.0, 0.0];
1465        let x = BasisPursuit::solve_lasso(&a, &b, 0.01, 300);
1466        assert!((x[0] - 3.0).abs() < 0.1, "x[0] ≈ 3, got {}", x[0]);
1467    }
1468
1469    // ---- FISTA ----
1470
1471    #[test]
1472    fn test_fista_identity_recovery() {
1473        let a: Vec<Vec<f64>> = (0..4)
1474            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1475            .collect();
1476        let b = vec![0.0, 2.5, 0.0, 0.0];
1477        let x = BasisPursuit::solve_fista(&a, &b, 1e-3, 300);
1478        assert!((x[1] - 2.5).abs() < 0.05, "x[1] ≈ 2.5, got {}", x[1]);
1479    }
1480
1481    #[test]
1482    fn test_fista_empty_input() {
1483        let x = BasisPursuit::solve_fista(&[], &[], 1.0, 100);
1484        assert!(x.is_empty());
1485    }
1486
1487    #[test]
1488    fn test_fista_objective_decreases() {
1489        let a: Vec<Vec<f64>> = (0..3)
1490            .map(|i| (0..3).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1491            .collect();
1492        let b = vec![1.0, -1.0, 0.5];
1493        let lambda = 0.1;
1494        let x0 = vec![0.0_f64; 3];
1495        let obj0 = BasisPursuit::objective(&a, &b, &x0, lambda);
1496        let x_hat = BasisPursuit::solve_fista(&a, &b, lambda, 100);
1497        let obj1 = BasisPursuit::objective(&a, &b, &x_hat, lambda);
1498        assert!(
1499            obj1 <= obj0 + 1e-10,
1500            "FISTA should decrease objective: {obj1} > {obj0}"
1501        );
1502    }
1503
1504    // ---- OMP ----
1505
1506    #[test]
1507    fn test_omp_exact_1_sparse() {
1508        // Identity system with 1-sparse signal
1509        let a: Vec<Vec<f64>> = (0..4)
1510            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1511            .collect();
1512        let b = vec![0.0, 5.0, 0.0, 0.0];
1513        let omp = OrthogonalMatchingPursuit::new(1);
1514        let x = omp.solve(&a, &b);
1515        assert_eq!(x.len(), 4);
1516        assert!((x[1] - 5.0).abs() < 1e-10, "x[1] should be 5, got {}", x[1]);
1517    }
1518
1519    #[test]
1520    fn test_omp_exact_2_sparse() {
1521        let a: Vec<Vec<f64>> = (0..4)
1522            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1523            .collect();
1524        let b = vec![2.0, 0.0, 7.0, 0.0];
1525        let omp = OrthogonalMatchingPursuit::new(2);
1526        let x = omp.solve(&a, &b);
1527        assert!((x[0] - 2.0).abs() < 1e-8);
1528        assert!((x[2] - 7.0).abs() < 1e-8);
1529    }
1530
1531    #[test]
1532    fn test_omp_empty_input() {
1533        let omp = OrthogonalMatchingPursuit::new(3);
1534        let x = omp.solve(&[], &[]);
1535        assert!(x.is_empty());
1536    }
1537
1538    #[test]
1539    fn test_omp_new() {
1540        let omp = OrthogonalMatchingPursuit::new(5);
1541        assert_eq!(omp.max_k, 5);
1542    }
1543
1544    #[test]
1545    fn test_omp_residual_decreases() {
1546        // Random orthogonal system: with enough sparsity the residual should shrink.
1547        let a: Vec<Vec<f64>> = (0..6)
1548            .map(|i| (0..6).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1549            .collect();
1550        let b = vec![1.0, -1.0, 2.0, 0.0, -2.0, 0.0];
1551        let omp = OrthogonalMatchingPursuit::new(4);
1552        let x = omp.solve(&a, &b);
1553        // Residual after recovery
1554        let residual: f64 = b
1555            .iter()
1556            .enumerate()
1557            .map(|(i, &bi)| {
1558                let ax: f64 = a[i].iter().zip(x.iter()).map(|(aij, xj)| aij * xj).sum();
1559                (bi - ax).powi(2)
1560            })
1561            .sum::<f64>()
1562            .sqrt();
1563        assert!(residual < 1e-8, "residual should be tiny, got {residual}");
1564    }
1565
1566    #[test]
1567    fn test_omp_support_length() {
1568        let a: Vec<Vec<f64>> = (0..5)
1569            .map(|i| (0..5).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1570            .collect();
1571        let b = vec![1.0, 0.0, 3.0, 0.0, 2.0];
1572        let omp = OrthogonalMatchingPursuit::new(3);
1573        let supp = omp.support(&a, &b);
1574        assert_eq!(
1575            supp.len(),
1576            3,
1577            "support should have 3 elements, got {}",
1578            supp.len()
1579        );
1580    }
1581
1582    // ---- SparsityMetrics ----
1583
1584    #[test]
1585    fn test_l0_norm_basic() {
1586        let x = vec![0.0, 1.0, 0.0, -2.0, 0.001];
1587        assert_eq!(SparsityMetrics::l0_norm(&x, 0.5), 2); // 1.0 and -2.0
1588    }
1589
1590    #[test]
1591    fn test_l0_norm_all_zero() {
1592        let x = vec![0.0, 0.0, 0.0];
1593        assert_eq!(SparsityMetrics::l0_norm(&x, 1e-6), 0);
1594    }
1595
1596    #[test]
1597    fn test_l1_norm_basic() {
1598        let x = vec![1.0, -2.0, 3.0];
1599        assert!((SparsityMetrics::l1_norm(&x) - 6.0).abs() < 1e-12);
1600    }
1601
1602    #[test]
1603    fn test_l1_norm_empty() {
1604        assert_eq!(SparsityMetrics::l1_norm(&[]), 0.0);
1605    }
1606
1607    #[test]
1608    fn test_l2_norm_sparsity() {
1609        let x = vec![3.0, 4.0];
1610        assert!((SparsityMetrics::l2_norm(&x) - 5.0).abs() < 1e-12);
1611    }
1612
1613    #[test]
1614    fn test_gini_sparse_is_near_one() {
1615        let x = vec![0.0, 0.0, 0.0, 10.0]; // very sparse
1616        let g = SparsityMetrics::gini(&x);
1617        assert!(g > 0.6, "sparse signal should have high Gini, got {g}");
1618    }
1619
1620    #[test]
1621    fn test_gini_uniform_is_near_zero() {
1622        let x = vec![1.0, 1.0, 1.0, 1.0]; // flat = dense
1623        let g = SparsityMetrics::gini(&x);
1624        assert!(g < 0.1, "uniform signal should have low Gini, got {g}");
1625    }
1626
1627    #[test]
1628    fn test_coherence_identity_is_zero() {
1629        // Columns of the identity are orthogonal → coherence = 0
1630        let a: Vec<Vec<f64>> = (0..4)
1631            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1632            .collect();
1633        let mu = SparsityMetrics::coherence(&a);
1634        assert!(mu < 1e-12, "identity coherence should be 0, got {mu}");
1635    }
1636
1637    #[test]
1638    fn test_coherence_empty() {
1639        assert_eq!(SparsityMetrics::coherence(&[]), 0.0);
1640    }
1641
1642    #[test]
1643    fn test_coherence_collinear_columns() {
1644        // Two identical columns → coherence = 1
1645        let a = vec![vec![1.0, 1.0], vec![0.0, 0.0]];
1646        let mu = SparsityMetrics::coherence(&a);
1647        assert!(
1648            (mu - 1.0).abs() < 1e-10,
1649            "collinear columns → coherence=1, got {mu}"
1650        );
1651    }
1652
1653    #[test]
1654    fn test_babel_function_identity() {
1655        let a: Vec<Vec<f64>> = (0..4)
1656            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1657            .collect();
1658        let babel = SparsityMetrics::babel_function(&a, 2);
1659        assert!(babel < 1e-12, "identity babel should be 0, got {babel}");
1660    }
1661
1662    // ---- RecoveryGuarantee ----
1663
1664    #[test]
1665    fn test_exact_recovery_condition_sufficient_measurements() {
1666        // k=1, n=100 → need m >= 2*ln(100) ≈ 9.2 → m=20 is sufficient
1667        assert!(RecoveryGuarantee::exact_recovery_condition(1, 20, 100));
1668    }
1669
1670    #[test]
1671    fn test_exact_recovery_condition_insufficient() {
1672        // k=50, n=100 → need m >= 100*ln(2) ≈ 69 → m=5 is not sufficient
1673        assert!(!RecoveryGuarantee::exact_recovery_condition(50, 5, 100));
1674    }
1675
1676    #[test]
1677    fn test_exact_recovery_condition_k_zero() {
1678        assert!(RecoveryGuarantee::exact_recovery_condition(0, 0, 100));
1679    }
1680
1681    #[test]
1682    fn test_rip_constant_identity() {
1683        // For the identity matrix (m=n), RIP constant should be ~0
1684        let n = 4;
1685        let a: Vec<Vec<f64>> = (0..n)
1686            .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1687            .collect();
1688        let delta = RecoveryGuarantee::rip_constant(&a, 1);
1689        assert!(
1690            delta < 1e-10,
1691            "identity RIP constant should be ~0, got {delta}"
1692        );
1693    }
1694
1695    #[test]
1696    fn test_rip_constant_empty() {
1697        assert_eq!(RecoveryGuarantee::rip_constant(&[], 1), 0.0);
1698    }
1699
1700    #[test]
1701    fn test_rip_measurement_lower_bound_nonzero() {
1702        let lb = RecoveryGuarantee::rip_measurement_lower_bound(5, 100);
1703        assert!(lb > 0, "lower bound should be positive");
1704    }
1705
1706    #[test]
1707    fn test_lasso_error_bound_positive() {
1708        let bound = RecoveryGuarantee::lasso_error_bound(0.1, 4, 0.1);
1709        assert!(bound > 0.0, "LASSO error bound should be positive");
1710    }
1711
1712    // ---- gauss_solve ----
1713
1714    #[test]
1715    fn test_gauss_solve_2x2() {
1716        // [2 1; 1 3] x = [5; 10] → x = [1, 3]
1717        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
1718        let b = vec![5.0, 10.0];
1719        let x = gauss_solve(&a, &b);
1720        assert!((x[0] - 1.0).abs() < 1e-10, "x[0]={}", x[0]);
1721        assert!((x[1] - 3.0).abs() < 1e-10, "x[1]={}", x[1]);
1722    }
1723
1724    #[test]
1725    fn test_gauss_solve_1x1() {
1726        let a = vec![vec![4.0]];
1727        let b = vec![8.0];
1728        let x = gauss_solve(&a, &b);
1729        assert!((x[0] - 2.0).abs() < 1e-12);
1730    }
1731
1732    #[test]
1733    fn test_gauss_solve_empty() {
1734        let x = gauss_solve(&[], &[]);
1735        assert!(x.is_empty());
1736    }
1737
1738    // ---- MriCompressedSensing ----
1739
1740    #[test]
1741    fn test_mri_cs_new() {
1742        let mri = MriCompressedSensing::new(64, 20);
1743        assert_eq!(mri.n, 64);
1744        assert_eq!(mri.m, 20);
1745    }
1746
1747    #[test]
1748    fn test_mri_kspace_indices_length() {
1749        let mri = MriCompressedSensing::new(32, 10);
1750        let idx = mri.sample_kspace_indices();
1751        assert_eq!(idx.len(), 10);
1752    }
1753
1754    #[test]
1755    fn test_mri_measurement_matrix_shape() {
1756        let mri = MriCompressedSensing::new(16, 8);
1757        let idx: Vec<usize> = (0..8).collect();
1758        let a = mri.build_measurement_matrix(&idx);
1759        assert_eq!(a.len(), 8);
1760        assert_eq!(a[0].len(), 16);
1761    }
1762
1763    #[test]
1764    fn test_psnr_identical_signals() {
1765        let s = vec![1.0, 2.0, 3.0];
1766        let psnr = MriCompressedSensing::psnr(&s, &s, 3.0);
1767        assert!(psnr.is_infinite(), "identical signals → PSNR = ∞");
1768    }
1769
1770    #[test]
1771    fn test_psnr_known_value() {
1772        let original = vec![1.0, 0.0];
1773        let reconstructed = vec![0.0, 0.0];
1774        let psnr = MriCompressedSensing::psnr(&original, &reconstructed, 1.0);
1775        assert!(psnr.is_finite(), "PSNR should be finite for non-identical");
1776    }
1777
1778    // ---- SparseSignal ----
1779
1780    #[test]
1781    fn test_sparse_signal_generate_sparsity() {
1782        let sig = SparseSignal::generate(20, 3, 1.0);
1783        assert_eq!(sig.len(), 20);
1784        let nnz = sig.iter().filter(|&&v| v.abs() > 1e-14).count();
1785        assert_eq!(
1786            nnz, 3,
1787            "generated signal should have exactly 3 non-zeros, got {nnz}"
1788        );
1789    }
1790
1791    #[test]
1792    fn test_sparse_signal_generate_length() {
1793        let sig = SparseSignal::generate(100, 5, 2.0);
1794        assert_eq!(sig.len(), 100);
1795    }
1796
1797    #[test]
1798    fn test_sparse_signal_relative_error_zero() {
1799        let s = vec![1.0, 2.0, 3.0];
1800        let err = SparseSignal::relative_error(&s, &s);
1801        assert!(
1802            err < 1e-12,
1803            "identical signals should have 0 relative error"
1804        );
1805    }
1806
1807    #[test]
1808    fn test_sparse_signal_support_error_identical() {
1809        let s = vec![0.0, 1.0, 0.0, 2.0];
1810        let err = SparseSignal::support_error(&s, &s, 0.5);
1811        assert_eq!(err, 0.0, "identical support → error = 0");
1812    }
1813
1814    // ---- KSvd ----
1815
1816    #[test]
1817    fn test_ksvd_new() {
1818        let k = KSvd::new(8, 2, 5);
1819        assert_eq!(k.n_atoms, 8);
1820        assert_eq!(k.sparsity, 2);
1821        assert_eq!(k.n_iter, 5);
1822    }
1823
1824    #[test]
1825    fn test_ksvd_reconstruct_zero_code() {
1826        let dict: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1827        let code = vec![0.0, 0.0];
1828        let rec = KSvd::reconstruct(&dict, &code);
1829        assert_eq!(rec, vec![0.0, 0.0]);
1830    }
1831
1832    #[test]
1833    fn test_ksvd_reconstruct_unit_code() {
1834        let dict: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1835        let code = vec![3.0, 5.0];
1836        let rec = KSvd::reconstruct(&dict, &code);
1837        assert!((rec[0] - 3.0).abs() < 1e-12);
1838        assert!((rec[1] - 5.0).abs() < 1e-12);
1839    }
1840
1841    #[test]
1842    fn test_ksvd_fit_returns_correct_shape() {
1843        // 4 signals of length 6, 4 atoms
1844        let signals: Vec<Vec<f64>> = (0..4)
1845            .map(|i| (0..6).map(|j| if j == i { 1.0 } else { 0.0 }).collect())
1846            .collect();
1847        let ksvd = KSvd::new(4, 1, 2);
1848        let dict = ksvd.fit(&signals);
1849        assert_eq!(dict.len(), 4, "dict should have 4 atoms");
1850        assert_eq!(dict[0].len(), 6, "each atom should have length 6");
1851    }
1852
1853    #[test]
1854    fn test_ksvd_encode_length() {
1855        let dict: Vec<Vec<f64>> = (0..4)
1856            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1857            .collect();
1858        let ksvd = KSvd::new(4, 1, 1);
1859        let signal = vec![0.0, 1.0, 0.0, 0.0];
1860        let code = ksvd.encode(&dict, &signal);
1861        assert_eq!(code.len(), 4, "code length should match n_atoms");
1862    }
1863
1864    #[test]
1865    fn test_spectral_norm_identity() {
1866        let a: Vec<Vec<f64>> = (0..4)
1867            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
1868            .collect();
1869        let sn = spectral_norm(&a, 30);
1870        assert!(
1871            (sn - 1.0).abs() < 0.01,
1872            "spectral norm of I should be ~1, got {sn}"
1873        );
1874    }
1875
1876    #[test]
1877    fn test_spectral_norm_empty() {
1878        let sn = spectral_norm(&[], 10);
1879        assert_eq!(sn, 0.0);
1880    }
1881}