use super::types::{CsConfig, CsResult, Measurement};
use crate::error::{FFTError, FFTResult};
/// Orthogonal Matching Pursuit (OMP) solver that recovers a sparse real
/// signal of length `n` from a subset of its DFT coefficients.
///
/// Each iteration adds the dictionary column most correlated with the current
/// residual to the support, then re-fits every selected coefficient by least
/// squares.
pub struct OmpSolver {
    /// Solver parameters (target sparsity, iteration cap, residual tolerance).
    pub config: CsConfig,
}

impl OmpSolver {
    /// Pivot norms below this value during the Gram–Schmidt QR are treated as
    /// a rank-deficient (singular) column sub-matrix.
    const PIVOT_TOL: f64 = 1e-14;

    /// Greedy selection stops once the best unselected column satisfies
    /// `|<a_j, r>| <= CORRELATION_RTOL * ||a_j|| * ||r||`, i.e. the residual is
    /// numerically orthogonal to every remaining column. Adding such a column
    /// cannot reduce the residual and would only make the QR sub-problem
    /// singular (e.g. measurements that are not exactly in the range of the
    /// sensing matrix).
    const CORRELATION_RTOL: f64 = 1e-10;

    /// Creates a solver with the given configuration.
    pub fn new(config: CsConfig) -> Self {
        Self { config }
    }

    /// Builds the real-valued partial-DFT sensing matrix.
    ///
    /// Measured frequency `idx = indices[row]` contributes two rows:
    /// row `2 * row` holds `cos(2π·idx·col/n)` and row `2 * row + 1` holds
    /// `-sin(2π·idx·col/n)` (the real and imaginary parts of
    /// `exp(-2πi·idx·col/n)`). The result has `2 * indices.len()` rows and
    /// `n` columns.
    fn build_matrix(indices: &[usize], n: usize) -> Vec<Vec<f64>> {
        use std::f64::consts::TAU;
        let mut a = Vec::with_capacity(2 * indices.len());
        for &idx in indices {
            let (re_row, im_row): (Vec<f64>, Vec<f64>) = (0..n)
                .map(|col| {
                    let angle = TAU * (idx as f64) * (col as f64) / (n as f64);
                    (angle.cos(), -angle.sin())
                })
                .unzip();
            a.push(re_row);
            a.push(im_row);
        }
        a
    }

    /// Returns the measurement vector; `values` is already interleaved as
    /// `[re_0, im_0, re_1, im_1, ...]`, matching the row order of `build_matrix`.
    fn flatten_measurements(meas: &Measurement) -> Vec<f64> {
        meas.values.clone()
    }

    /// Euclidean norm of `v`.
    fn l2_norm(v: &[f64]) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    /// Inner product of `u` and `v` (trailing elements of the longer slice are ignored).
    fn dot(u: &[f64], v: &[f64]) -> f64 {
        u.iter().zip(v).map(|(p, q)| p * q).sum()
    }

    /// Computes the residual `y - A·x`.
    fn residual(a: &[Vec<f64>], y: &[f64], x: &[f64]) -> Vec<f64> {
        a.iter()
            .zip(y)
            .map(|(row, &yi)| yi - Self::dot(row, x))
            .collect()
    }

    /// Finds the column not already in `support` whose inner product with the
    /// residual `r` has the largest magnitude.
    ///
    /// Returns `(column, |<a_column, r>|)`; ties go to the lowest column index.
    /// Returns `None` only when every column is already in `support`.
    fn most_correlated_column(
        a: &[Vec<f64>],
        r: &[f64],
        support: &[usize],
    ) -> Option<(usize, f64)> {
        let n = a.first().map_or(0, Vec::len);
        // O(1) membership test instead of a linear `support.contains` per column.
        let mut selected = vec![false; n];
        for &col in support {
            if let Some(flag) = selected.get_mut(col) {
                *flag = true;
            }
        }
        let mut best: Option<(usize, f64)> = None;
        for j in (0..n).filter(|&j| !selected[j]) {
            let corr = a
                .iter()
                .zip(r)
                .map(|(row, &ri)| row[j] * ri)
                .sum::<f64>()
                .abs();
            let better = match best {
                None => true,
                Some((_, best_corr)) => corr > best_corr,
            };
            if better {
                best = Some((j, corr));
            }
        }
        best
    }

    /// Solves `min_c ||A_S·c - y||`, where `A_S` holds the columns of `a` listed
    /// in `support` (in that order), with one modified Gram–Schmidt QR pass
    /// followed by back substitution. Coefficients are returned in `support` order.
    ///
    /// # Errors
    /// Returns [`FFTError::ComputationError`] when a pivot norm falls below
    /// `PIVOT_TOL`, i.e. the selected columns are numerically dependent.
    fn least_squares(a: &[Vec<f64>], y: &[f64], support: &[usize]) -> FFTResult<Vec<f64>> {
        let s = support.len();
        if s == 0 {
            return Ok(vec![]);
        }
        // q[k] starts as column `support[k]` of `a` and is orthonormalised in
        // place; `r_mat` accumulates the upper-triangular factor R and `qty`
        // accumulates Qᵀy.
        let mut q: Vec<Vec<f64>> = support
            .iter()
            .map(|&col| a.iter().map(|row| row[col]).collect())
            .collect();
        let mut r_mat = vec![vec![0.0_f64; s]; s];
        let mut qty = vec![0.0_f64; s];
        let mut y_rem = y.to_vec();
        for k in 0..s {
            let norm = Self::l2_norm(&q[k]);
            if norm < Self::PIVOT_TOL {
                return Err(FFTError::ComputationError(
                    "OMP: singular sub-matrix encountered in QR".into(),
                ));
            }
            r_mat[k][k] = norm;
            q[k].iter_mut().for_each(|v| *v /= norm);

            let (head, tail) = q.split_at_mut(k + 1);
            let qk = &head[k];

            // Project the running right-hand side onto q_k and remove that component.
            let coef = Self::dot(qk, &y_rem);
            qty[k] = coef;
            y_rem
                .iter_mut()
                .zip(qk)
                .for_each(|(yi, &qi)| *yi -= coef * qi);

            // Orthogonalise the not-yet-processed columns against q_k.
            for (offset, qj) in tail.iter_mut().enumerate() {
                let proj = Self::dot(qk, qj.as_slice());
                r_mat[k][k + 1 + offset] = proj;
                qj.iter_mut().zip(qk).for_each(|(v, &qi)| *v -= proj * qi);
            }
        }
        // Back substitution on R·c = Qᵀy. Every diagonal entry is >= PIVOT_TOL
        // here, so no zero-division guard is needed.
        let mut c = vec![0.0_f64; s];
        for k in (0..s).rev() {
            let acc = r_mat[k][k + 1..]
                .iter()
                .zip(&c[k + 1..])
                .fold(qty[k], |acc, (rkj, cj)| acc - rkj * cj);
            c[k] = acc / r_mat[k][k];
        }
        Ok(c)
    }

    /// Recovers a length-`n` sparse signal from partial DFT measurements.
    ///
    /// `measurements.values` must contain the real and imaginary parts of the
    /// DFT coefficient at each entry of `measurements.indices`, interleaved as
    /// `[re_0, im_0, re_1, im_1, ...]`.
    ///
    /// The greedy loop stops when any of the following holds:
    /// - the residual norm drops below `config.tol`;
    /// - the support reaches `min(config.sparsity, n)` columns;
    /// - no unselected column is correlated with the residual (see `CORRELATION_RTOL`);
    /// - `config.max_iter` passes have run.
    ///
    /// `iterations` counts loop passes, including a pass that only detects a
    /// stopping condition.
    ///
    /// # Errors
    /// - [`FFTError::ValueError`] if `measurements.indices` is empty.
    /// - [`FFTError::DimensionError`] if `values.len() != 2 * indices.len()`.
    /// - [`FFTError::ComputationError`] if a least-squares sub-problem is
    ///   numerically singular.
    pub fn recover(&self, measurements: &Measurement, n: usize) -> FFTResult<CsResult> {
        if measurements.indices.is_empty() {
            return Err(FFTError::ValueError(
                "No measurement indices provided".into(),
            ));
        }
        if measurements.values.len() != 2 * measurements.indices.len() {
            return Err(FFTError::DimensionError(
                "values must have length 2·|indices| (re/im interleaved)".into(),
            ));
        }
        let a = Self::build_matrix(&measurements.indices, n);
        let y = Self::flatten_measurements(measurements);
        let sparsity = self.config.sparsity.min(n);
        let mut x = vec![0.0_f64; n];
        let mut support: Vec<usize> = Vec::with_capacity(sparsity);
        let mut r = Self::residual(&a, &y, &x);
        let mut iters = 0;
        for iter in 0..self.config.max_iter {
            iters = iter + 1;
            let res_norm = Self::l2_norm(&r);
            if res_norm < self.config.tol || support.len() >= sparsity {
                break;
            }
            let (k_star, corr) = match Self::most_correlated_column(&a, &r, &support) {
                Some(best) => best,
                None => break,
            };
            let col_norm = a
                .iter()
                .map(|row| row[k_star] * row[k_star])
                .sum::<f64>()
                .sqrt();
            // The residual is orthogonal to every remaining column: no further
            // progress is possible, and selecting one would make the QR singular.
            // The NaN check stops the loop on non-finite data as well.
            if corr.is_nan() || corr <= Self::CORRELATION_RTOL * col_norm * res_norm {
                break;
            }
            // `k_star` is never already in `support`, so no dedup is needed.
            support.push(k_star);
            support.sort_unstable();
            let c = Self::least_squares(&a, &y, &support)?;
            x.fill(0.0);
            for (&col, &coef) in support.iter().zip(&c) {
                x[col] = coef;
            }
            r = Self::residual(&a, &y, &x);
        }
        Ok(CsResult {
            recovered: x,
            iterations: iters,
            residual: Self::l2_norm(&r),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::TAU;

    /// Simulates partial-DFT measurements of `signal` at the given frequency
    /// `indices`, interleaving real and imaginary parts as `recover` expects.
    fn make_measurement(signal: &[f64], indices: &[usize]) -> Measurement {
        let n = signal.len();
        let m = indices.len();
        let mut values = vec![0.0_f64; 2 * m];
        for (row, &idx) in indices.iter().enumerate() {
            let mut re = 0.0_f64;
            let mut im = 0.0_f64;
            for (j, &s) in signal.iter().enumerate() {
                let angle = TAU * (idx as f64) * (j as f64) / (n as f64);
                re += s * angle.cos();
                im += s * (-angle.sin());
            }
            values[2 * row] = re;
            values[2 * row + 1] = im;
        }
        Measurement {
            indices: indices.to_vec(),
            values,
        }
    }

    /// Maximum absolute element-wise error between `expected` and `actual`.
    ///
    /// Panics if the lengths differ or `actual` contains a non-finite value:
    /// `f64::max` silently ignores NaN, so without this check a NaN recovery
    /// would pass any `err < tol` assertion.
    fn max_abs_error(expected: &[f64], actual: &[f64]) -> f64 {
        assert_eq!(expected.len(), actual.len(), "length mismatch");
        assert!(
            actual.iter().all(|v| v.is_finite()),
            "non-finite value in recovery: {:?}",
            actual
        );
        expected
            .iter()
            .zip(actual)
            .map(|(e, a)| (e - a).abs())
            .fold(0.0, f64::max)
    }

    #[test]
    fn test_omp_recovers_sparse() {
        let n = 32;
        let mut signal = vec![0.0_f64; n];
        signal[0] = 3.0;
        signal[5] = -2.0;
        signal[12] = 1.5;
        let indices: Vec<usize> = (0..10).collect();
        let meas = make_measurement(&signal, &indices);
        let cfg = CsConfig {
            sparsity: 3,
            max_iter: 50,
            tol: 1e-6,
        };
        let solver = OmpSolver::new(cfg);
        let result = solver.recover(&meas, n).expect("OMP should succeed");
        assert!(result.residual < 1e-4, "residual={}", result.residual);
        let err = max_abs_error(&signal, &result.recovered);
        assert!(err < 0.1, "max_err={}", err);
    }

    #[test]
    fn test_omp_exact_measurements() {
        let n = 16;
        let mut signal = vec![0.0_f64; n];
        signal[2] = 5.0;
        signal[9] = -3.0;
        let indices: Vec<usize> = (0..n).collect();
        let meas = make_measurement(&signal, &indices);
        let cfg = CsConfig {
            sparsity: 2,
            max_iter: 50,
            tol: 1e-8,
        };
        let solver = OmpSolver::new(cfg);
        let result = solver.recover(&meas, n).expect("OMP should succeed");
        assert!(result.residual < 1e-6, "residual={}", result.residual);
        // With every frequency measured the dictionary columns are mutually
        // orthogonal, so the recovery should be exact up to rounding.
        let err = max_abs_error(&signal, &result.recovered);
        assert!(err < 1e-9, "max_err={}", err);
    }

    #[test]
    fn test_omp_residual_decreases() {
        let n = 16;
        let mut signal = vec![0.0_f64; n];
        signal[3] = 4.0;
        signal[7] = -1.0;
        signal[11] = 2.5;
        let indices: Vec<usize> = (0..8).collect();
        let meas = make_measurement(&signal, &indices);
        let cfg1 = CsConfig {
            sparsity: 1,
            max_iter: 10,
            tol: 1e-12,
        };
        let res1 = OmpSolver::new(cfg1).recover(&meas, n).expect("ok");
        let cfg2 = CsConfig {
            sparsity: 2,
            max_iter: 10,
            tol: 1e-12,
        };
        let res2 = OmpSolver::new(cfg2).recover(&meas, n).expect("ok");
        assert!(
            res2.residual <= res1.residual + 1e-10,
            "residual should decrease: {} vs {}",
            res2.residual,
            res1.residual
        );
    }

    #[test]
    fn test_omp_overcomplete_dict() {
        let n = 32;
        let mut signal = vec![0.0_f64; n];
        signal[1] = 1.0;
        // 5 frequencies (10 real rows) for 32 unknowns: an underdetermined system.
        let indices = vec![0usize, 1, 2, 3, 4];
        let meas = make_measurement(&signal, &indices);
        let cfg = CsConfig {
            sparsity: 1,
            max_iter: 20,
            tol: 1e-6,
        };
        let result = OmpSolver::new(cfg).recover(&meas, n).expect("ok");
        assert_eq!(result.recovered.len(), n);
        // Column 1 is the unique best match for a single-atom signal, so the
        // one-atom fit should reproduce it exactly.
        let err = max_abs_error(&signal, &result.recovered);
        assert!(err < 1e-9, "max_err={}", err);
        assert!(result.residual < 1e-9, "residual={}", result.residual);
    }
}