use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Sparse representation of a DFT: a set of selected bins and their complex values.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseFftResult {
    /// Selected bin indices (the functions in this module return them in ascending order).
    pub frequencies: Vec<usize>,
    /// Complex bin values interleaved as `[re0, im0, re1, im1, ...]`; pair `i`
    /// belongs to `frequencies[i]`.
    pub amplitudes: Vec<f64>,
    /// Length of the analysed signal, i.e. the length of the dense spectrum.
    pub n: usize,
}
impl SparseFftResult {
    /// Complex value of the `i`-th selected bin.
    ///
    /// Returns `None` if `i >= self.frequencies.len()` or if `amplitudes` does
    /// not contain a complete `[re, im]` pair for entry `i`. The fields are
    /// public, so an inconsistent value must not cause a panic here.
    pub fn amplitude_complex(&self, i: usize) -> Option<Complex64> {
        if i >= self.frequencies.len() {
            return None;
        }
        let re = *self.amplitudes.get(2 * i)?;
        let im = *self.amplitudes.get(2 * i + 1)?;
        Some(Complex64::new(re, im))
    }
    /// Number of selected bins.
    pub fn sparsity(&self) -> usize {
        self.frequencies.len()
    }
}
/// Returns the `k` largest-magnitude DFT bins of `signal`.
///
/// The full spectrum is computed with [`fft`] and the bins are ranked by
/// magnitude. The sort is stable, so ties keep ascending bin order. The top
/// `min(k, n)` bins are returned in ascending bin order, with their complex
/// values interleaved as `[re, im, re, im, ...]` in `amplitudes`.
///
/// `n_trials` does not influence the result. The selection is made on the
/// exact spectrum, and the parameter is kept so existing callers still compile.
///
/// # Errors
/// Returns `FFTError::InvalidInput` for an empty signal and propagates any
/// error from [`fft`].
pub fn sparse_fft_simple(
    signal: &[f64],
    k: usize,
    n_trials: usize,
) -> FFTResult<SparseFftResult> {
    // Intentionally unused: see the doc comment above.
    let _ = n_trials;
    let n = signal.len();
    if n == 0 {
        return Err(FFTError::InvalidInput(
            "sparse_fft_simple: empty signal".into(),
        ));
    }
    if k == 0 {
        return Ok(SparseFftResult {
            frequencies: vec![],
            amplitudes: vec![],
            n,
        });
    }
    let spectrum = fft(signal, None)?;
    let mut by_magnitude: Vec<(usize, f64)> = spectrum
        .iter()
        .enumerate()
        .map(|(i, c)| (i, c.norm()))
        .collect();
    // Stable descending sort: equal magnitudes keep ascending bin order.
    by_magnitude.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let top_k = k.min(by_magnitude.len());
    let mut frequencies: Vec<usize> = by_magnitude[..top_k].iter().map(|&(i, _)| i).collect();
    frequencies.sort_unstable();
    let amplitudes: Vec<f64> = frequencies
        .iter()
        .flat_map(|&bin| [spectrum[bin].re, spectrum[bin].im])
        .collect();
    Ok(SparseFftResult {
        frequencies,
        amplitudes,
        n,
    })
}
/// Expands a [`SparseFftResult`] into a dense spectrum of `sparse.n` `[re, im]` pairs.
///
/// Bins not listed in `frequencies` are zero. Listed bins `>= n` are ignored.
/// A frequency without a complete `[re, im]` pair in `amplitudes` is skipped
/// rather than causing an out-of-bounds panic.
pub fn sparse_to_dense(sparse: &SparseFftResult) -> Vec<[f64; 2]> {
    let mut out = vec![[0.0f64; 2]; sparse.n];
    let pairs = sparse.amplitudes.chunks_exact(2);
    for (&bin, pair) in sparse.frequencies.iter().zip(pairs) {
        if let Some(slot) = out.get_mut(bin) {
            *slot = [pair[0], pair[1]];
        }
    }
    out
}
/// Sparse spectrum estimate by soft-thresholding the full DFT.
///
/// Each bin whose magnitude exceeds a threshold `t` has its magnitude reduced
/// by `t` while its phase is kept. Bins at or below `t` are dropped. Of the
/// survivors, the `k` with the largest shrunk magnitude are returned in
/// ascending bin order, with values interleaved `[re, im, ...]` in `amplitudes`.
///
/// `t = lambda * n / 2` when `lambda > 0`. Otherwise (zero, negative or NaN
/// `lambda`), `t` is three times the median bin magnitude.
///
/// # Errors
/// `FFTError::InvalidInput` for an empty signal; errors from [`fft`] are propagated.
pub fn sparse_fft_lasso(
    signal: &[f64],
    k: usize,
    lambda: f64,
) -> FFTResult<SparseFftResult> {
    let n = signal.len();
    if n == 0 {
        return Err(FFTError::InvalidInput(
            "sparse_fft_lasso: empty signal".into(),
        ));
    }
    let spectrum = fft(signal, None)?;

    let threshold = if lambda > 0.0 {
        lambda * n as f64 / 2.0
    } else {
        // Data-driven default: three times the median magnitude.
        let mut ascending: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        ascending.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
        ascending[spectrum.len() / 2] * 3.0
    };

    // Complex soft-thresholding: shrink the modulus, keep the phase.
    let mut survivors = Vec::new();
    for (bin, &c) in spectrum.iter().enumerate() {
        let mag = c.norm();
        if mag > threshold {
            let scale = (mag - threshold) / mag;
            survivors.push((bin, Complex64::new(c.re * scale, c.im * scale)));
        }
    }

    // Keep the k strongest (stable sort, so ties favour lower bins), then
    // restore ascending bin order.
    survivors.sort_by(|x, y| {
        y.1.norm()
            .partial_cmp(&x.1.norm())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    survivors.truncate(k);
    survivors.sort_by_key(|&(bin, _)| bin);

    let frequencies: Vec<usize> = survivors.iter().map(|&(bin, _)| bin).collect();
    let amplitudes: Vec<f64> = survivors
        .iter()
        .flat_map(|&(_, c)| [c.re, c.im])
        .collect();
    Ok(SparseFftResult {
        frequencies,
        amplitudes,
        n,
    })
}
/// Components estimated by `prony_method`, one entry per reported pole.
///
/// As produced by `prony_method`, all four vectors have the same length and
/// are sorted by ascending `frequencies`. `Default` is the empty result
/// (no components).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PronyResult {
    /// Normalised frequency of each pole, `arg(z) / (2π)`, in cycles per sample.
    pub frequencies: Vec<f64>,
    /// Modulus of each pole's fitted complex weight.
    pub amplitudes: Vec<f64>,
    /// Argument (radians) of each pole's fitted complex weight.
    pub phases: Vec<f64>,
    /// `ln|z|` of each pole, clamped to at least `-10.0`; negative values mean decay.
    pub damping: Vec<f64>,
}
/// Estimates exponential components of `signal` with Prony's method.
///
/// 1. Fit an order-`n_components` linear predictor
///    `x[t] + a0*x[t-1] + ... + a(p-1)*x[t-p] ≈ 0` by least squares.
/// 2. Take the roots `z` of `z^p + a0*z^(p-1) + ... + a(p-1)` as poles.
/// 3. Fit complex weights to the poles with `solve_vandermonde_ls`.
///
/// Only poles with `arg z >= 0` are reported, and the output is sorted by
/// ascending frequency. The predictor coefficients are real, so
/// `n_components == 1` yields a single real pole (frequency 0 or 0.5 only).
/// A general real sinusoid needs two components.
///
/// # Errors
/// `FFTError::InvalidInput` if `signal.len() < 2 * n_components + 2`; errors
/// from the helper solvers are propagated.
pub fn prony_method(signal: &[f64], n_components: usize) -> FFTResult<PronyResult> {
    let n = signal.len();
    let order = n_components;
    let required = 2 * order + 2;
    if n < required {
        return Err(FFTError::InvalidInput(format!(
            "prony_method: signal length {} too short for {} components (need >= {})",
            n, order, required
        )));
    }

    // Linear-prediction system: row t predicts signal[t + order] from the
    // `order` preceding samples, most recent first.
    let rows = n - order;
    let design: Vec<f64> = (0..rows)
        .flat_map(|t| (0..order).map(move |c| signal[t + order - 1 - c]))
        .collect();
    let target: Vec<f64> = signal[order..].iter().map(|&x| -x).collect();

    let predictor = solve_least_squares_normal(&design, &target, rows, order)?;
    let poles = companion_eigenvalues(&predictor)?;
    let weights = solve_vandermonde_ls(signal, &poles)?;

    // (frequency, amplitude, phase, damping) for each pole on or above the real axis.
    let mut components: Vec<(f64, f64, f64, f64)> = Vec::with_capacity(poles.len());
    for (idx, &(radius, angle)) in poles.iter().enumerate() {
        let (w_re, w_im) = weights[idx];
        if angle >= 0.0 {
            components.push((
                angle.abs() / (2.0 * PI),
                w_re.hypot(w_im),
                w_im.atan2(w_re),
                radius.ln().max(-10.0),
            ));
        }
    }
    components.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(std::cmp::Ordering::Equal));

    let mut result = PronyResult {
        frequencies: Vec::with_capacity(components.len()),
        amplitudes: Vec::with_capacity(components.len()),
        phases: Vec::with_capacity(components.len()),
        damping: Vec::with_capacity(components.len()),
    };
    for (freq, amp, phase, decay) in components {
        result.frequencies.push(freq);
        result.amplitudes.push(amp);
        result.phases.push(phase);
        result.damping.push(decay);
    }
    Ok(result)
}
/// Least-squares solution of `A x ≈ b` via the normal equations `(AᵀA) x = Aᵀb`.
///
/// `a` is an `m × p` matrix in row-major order and `b` has length `m`. The
/// `p × p` system is handed to `cholesky_solve`.
fn solve_least_squares_normal(a: &[f64], b: &[f64], m: usize, p: usize) -> FFTResult<Vec<f64>> {
    // Gram matrix AᵀA, row-major; each entry sums over rows in ascending order.
    let gram: Vec<f64> = (0..p * p)
        .map(|idx| {
            let (i, j) = (idx / p, idx % p);
            (0..m).fold(0.0, |acc, row| acc + a[row * p + i] * a[row * p + j])
        })
        .collect();
    // Right-hand side Aᵀb.
    let moment: Vec<f64> = (0..p)
        .map(|i| (0..m).fold(0.0, |acc, row| acc + a[row * p + i] * b[row]))
        .collect();
    cholesky_solve(&gram, &moment, p)
}
/// Solves `A x = b` for a symmetric `n × n` row-major matrix by Cholesky factorisation.
///
/// Only the lower triangle (including the diagonal) of `a` is read. The
/// factorised matrix is `A + δI` with `δ = 1e-8 · max(0, max diag) + 1e-14`,
/// a small ridge that keeps near-singular systems solvable. A negative
/// diagonal remainder is replaced by `sqrt(δ)`. Any division by a pivot with
/// `|pivot| < f64::EPSILON` yields `0.0` instead. This function always
/// returns `Ok`.
fn cholesky_solve(a: &[f64], b: &[f64], n: usize) -> FFTResult<Vec<f64>> {
    let at = |row: usize, col: usize| row * n + col;
    let is_tiny = |v: f64| v.abs() < f64::EPSILON;
    let shift = (0..n).map(|i| a[at(i, i)]).fold(0.0f64, f64::max) * 1e-8 + 1e-14;

    // Factor A + shift·I = L·Lᵀ, filling L row by row.
    let mut lower = vec![0.0f64; n * n];
    for i in 0..n {
        for j in 0..=i {
            let diag_shift = if i == j { shift } else { 0.0 };
            let s = (0..j).fold(a[at(i, j)] + diag_shift, |acc, k| {
                acc - lower[at(i, k)] * lower[at(j, k)]
            });
            let value = if i != j {
                let pivot = lower[at(j, j)];
                if is_tiny(pivot) {
                    0.0
                } else {
                    s / pivot
                }
            } else if s < 0.0 {
                shift.sqrt()
            } else {
                s.sqrt()
            };
            lower[at(i, j)] = value;
        }
    }

    // Forward substitution: L y = b.
    let mut y = vec![0.0f64; n];
    for i in 0..n {
        let s = (0..i).fold(b[i], |acc, k| acc - lower[at(i, k)] * y[k]);
        let pivot = lower[at(i, i)];
        y[i] = if is_tiny(pivot) { 0.0 } else { s / pivot };
    }

    // Back substitution: Lᵀ x = y.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let s = ((i + 1)..n).fold(y[i], |acc, k| acc - lower[at(k, i)] * x[k]);
        let pivot = lower[at(i, i)];
        x[i] = if is_tiny(pivot) { 0.0 } else { s / pivot };
    }
    Ok(x)
}
/// Finds roots of the monic polynomial `P(z) = z^p + a[0]·z^(p-1) + ... + a[p-1]`.
///
/// Each root is returned as `(|z|, arg z)`. No companion matrix is formed
/// despite the name.
///
/// * `p == 0`: no roots.
/// * `p == 1`: the single real root `-a[0]` (angle `0` or `π`).
/// * `p >= 2`: `P` is sampled on a 2048-point grid of the unit circle. A
///   Newton iteration is started from every sample where `|P|` is decreasing
///   and below `0.5·p`. The iteration uses the exact derivative from Horner's
///   scheme. An iterate is accepted only if `|P(z)|` is small relative to
///   `Σ|coef|·|z|^k`. Near-duplicates are collapsed: magnitude within 0.01
///   and angle within 0.1 rad, measured around the circle. The search stops
///   once `p` distinct roots are found. Angles are in `(-π, π]`.
///   If no root is accepted, the `p` angles in `[0, π)` with the smallest
///   `|P|` on the unit circle are returned as unit-magnitude fallbacks.
///
/// The result may contain fewer than `p` entries (e.g. repeated roots).
fn companion_eigenvalues(a: &[f64]) -> FFTResult<Vec<(f64, f64)>> {
    let p = a.len();
    if p == 0 {
        return Ok(vec![]);
    }
    if p == 1 {
        // Linear case: the only root is -a[0], which is real.
        let z = -a[0];
        let mag = z.abs();
        let arg = if z >= 0.0 { 0.0 } else { PI };
        return Ok(vec![(mag, arg)]);
    }
    const N_GRID: usize = 2048;
    const NEWTON_ITERS: usize = 50;
    // Relative backward-error bound for accepting a Newton iterate as a root.
    const ACCEPT_TOL: f64 = 1e-8;

    // Horner evaluation of P at z = re + i·im.
    let eval_poly = |re: f64, im: f64| -> (f64, f64) {
        let mut pr = 1.0f64;
        let mut pi_val = 0.0f64;
        for &c in a {
            let new_pr = pr * re - pi_val * im + c;
            let new_pi = pr * im + pi_val * re;
            pr = new_pr;
            pi_val = new_pi;
        }
        (pr, pi_val)
    };
    // Horner evaluation of P and P' together: P' <- P'·z + P (using P before
    // it is advanced), then P <- P·z + c.
    let eval_with_derivative = |re: f64, im: f64| -> ((f64, f64), (f64, f64)) {
        let (mut pr, mut pi_val) = (1.0f64, 0.0f64);
        let (mut dr, mut di) = (0.0f64, 0.0f64);
        for &c in a {
            let new_dr = dr * re - di * im + pr;
            let new_di = dr * im + di * re + pi_val;
            let new_pr = pr * re - pi_val * im + c;
            let new_pi = pr * im + pi_val * re;
            dr = new_dr;
            di = new_di;
            pr = new_pr;
            pi_val = new_pi;
        }
        ((pr, pi_val), (dr, di))
    };
    // Natural magnitude scale of P at |z| = r: Σ |coef|·r^power (leading 1 included).
    let poly_scale = |r: f64| -> f64 { a.iter().fold(1.0, |s, &c| s * r + c.abs()) };

    let mut poles: Vec<(f64, f64)> = Vec::new();
    let mut prev_mag = f64::MAX;
    for i in 0..=N_GRID {
        let theta = 2.0 * PI * i as f64 / N_GRID as f64;
        let (re, im) = (theta.cos(), theta.sin());
        let (pr, pi_v) = eval_poly(re, im);
        let mag = (pr * pr + pi_v * pi_v).sqrt();
        if i > 0 && mag < prev_mag && mag < 0.5 * (p as f64).max(1.0) {
            let (mut z_re, mut z_im) = (re, im);
            for _ in 0..NEWTON_ITERS {
                let ((fre, fim), (dre, dim)) = eval_with_derivative(z_re, z_im);
                if fre.hypot(fim) < 1e-12 {
                    break;
                }
                let denom = dre * dre + dim * dim;
                if denom < f64::EPSILON {
                    // Stationary point: the acceptance test below decides.
                    break;
                }
                // Newton step z <- z - P(z)/P'(z) via complex division.
                z_re -= (fre * dre + fim * dim) / denom;
                z_im -= (fim * dre - fre * dim) / denom;
            }
            let (fre, fim) = eval_poly(z_re, z_im);
            let r_mag = z_re.hypot(z_im);
            // Reject iterates that stalled or diverged instead of reaching a root.
            let converged =
                r_mag.is_finite() && fre.hypot(fim) <= ACCEPT_TOL * poly_scale(r_mag);
            if converged {
                let r_arg = z_im.atan2(z_re);
                let is_dup = poles.iter().any(|&(m, ang): &(f64, f64)| {
                    // Angular distance around the circle, so +π and -π coincide.
                    let gap = (ang - r_arg).abs();
                    (m - r_mag).abs() < 0.01 && gap.min(2.0 * PI - gap) < 0.1
                });
                if !is_dup {
                    poles.push((r_mag, r_arg));
                    if poles.len() == p {
                        // A degree-p polynomial has no further distinct roots.
                        break;
                    }
                }
            }
        }
        prev_mag = mag;
    }
    if poles.is_empty() {
        let n_grid2 = 256;
        let mut min_mags: Vec<(f64, f64)> = (0..n_grid2)
            .map(|i| {
                let theta = PI * i as f64 / n_grid2 as f64;
                let (re, im) = (theta.cos(), theta.sin());
                let (pr, pi_v) = eval_poly(re, im);
                ((pr * pr + pi_v * pi_v).sqrt(), theta)
            })
            .collect();
        min_mags.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        for (_, theta) in min_mags.iter().take(p) {
            poles.push((1.0, *theta));
        }
    }
    Ok(poles)
}
/// Fits complex weights `c_k` so that `signal[i] ≈ Re(Σ_k c_k · z_k^i)`.
///
/// Each root is given as `(|z_k|, arg z_k)`. Only the first
/// `min(signal.len(), 4p + 4)` samples are used. Writing `c = c_re + i·c_im`
/// gives `Re(c·z^i) = c_re·Re(z^i) - c_im·Im(z^i)`, a real least-squares
/// problem in `2p` unknowns. It is solved through the normal equations and
/// `cholesky_solve`. If that solver reports an error, every weight falls back
/// to `(1.0, 0.0)`. Weights are returned as `(re, im)` in root order.
fn solve_vandermonde_ls(signal: &[f64], roots: &[(f64, f64)]) -> FFTResult<Vec<(f64, f64)>> {
    let p = roots.len();
    if p == 0 {
        return Ok(vec![]);
    }
    let rows = signal.len().min(4 * p + 4);
    let cols = 2 * p;

    // Real design matrix: columns [Re(z_k^i) | -Im(z_k^i)].
    let mut design = vec![0.0f64; rows * cols];
    for i in 0..rows {
        let t = i as f64;
        for (k, &(r_mag, r_arg)) in roots.iter().enumerate() {
            let radius = r_mag.powi(i as i32);
            design[i * cols + k] = radius * (t * r_arg).cos();
            design[i * cols + p + k] = -(radius * (t * r_arg).sin());
        }
    }

    // Normal equations (DᵀD) w = Dᵀ signal, summing rows in ascending order.
    let gram: Vec<f64> = (0..cols * cols)
        .map(|idx| {
            let (i, j) = (idx / cols, idx % cols);
            (0..rows).fold(0.0, |acc, r| acc + design[r * cols + i] * design[r * cols + j])
        })
        .collect();
    let moment: Vec<f64> = (0..cols)
        .map(|i| (0..rows).fold(0.0, |acc, r| acc + design[r * cols + i] * signal[r]))
        .collect();

    let weights = if let Ok(sol) = cholesky_solve(&gram, &moment, cols) {
        (0..p).map(|k| (sol[k], sol[p + k])).collect()
    } else {
        vec![(1.0, 0.0); p]
    };
    Ok(weights)
}
/// Smallest power of two that is `>= n`, with `next_pow2(0) == 1`.
///
/// Saturates to `usize::MAX` when no such power of two fits in a `usize`.
fn next_pow2(n: usize) -> usize {
    n.checked_next_power_of_two().unwrap_or(usize::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;
    fn sinusoid_sum(n: usize, freqs_norm: &[f64]) -> Vec<f64> {
        (0..n)
            .map(|i| {
                freqs_norm
                    .iter()
                    .map(|&f| (2.0 * PI * f * i as f64).sin())
                    .sum()
            })
            .collect()
    }
    #[test]
    fn test_sparse_fft_simple_length() {
        let n = 256usize;
        let signal = sinusoid_sum(n, &[0.1, 0.25]);
        let result = sparse_fft_simple(&signal, 2, 3).expect("sparse_fft_simple");
        assert_eq!(result.n, n);
        assert_eq!(result.frequencies.len(), 2);
        assert_eq!(result.amplitudes.len(), 4);
    }
    #[test]
    fn test_sparse_fft_simple_empty_error() {
        let err = sparse_fft_simple(&[], 2, 1).unwrap_err();
        assert!(matches!(err, FFTError::InvalidInput(_)));
    }
    #[test]
    fn test_sparse_fft_simple_k_zero() {
        let signal = vec![1.0f64; 64];
        let result = sparse_fft_simple(&signal, 0, 1).expect("k=0");
        assert_eq!(result.frequencies.len(), 0);
    }
    #[test]
    fn test_sparse_to_dense() {
        let n = 64usize;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
            .collect();
        let sparse = sparse_fft_simple(&signal, 2, 2).expect("sparse_fft_simple");
        let dense = sparse_to_dense(&sparse);
        assert_eq!(dense.len(), n);
        for i in 0..n {
            let is_sparse = sparse.frequencies.contains(&i);
            if !is_sparse {
                assert_eq!(dense[i], [0.0, 0.0]);
            }
        }
    }
    #[test]
    fn test_sparse_fft_lasso_bounded_sparsity() {
        let n = 128usize;
        let signal = sinusoid_sum(n, &[0.05, 0.15, 0.3]);
        let result = sparse_fft_lasso(&signal, 3, 0.0).expect("sparse_fft_lasso");
        assert!(result.frequencies.len() <= 3, "Should have at most k=3 components");
        assert_eq!(result.n, n);
    }
    #[test]
    fn test_sparse_fft_lasso_explicit_lambda() {
        let n = 128usize;
        let signal = sinusoid_sum(n, &[0.1]);
        let result = sparse_fft_lasso(&signal, 4, 0.01).expect("sparse_fft_lasso");
        assert!(result.frequencies.len() <= 4);
    }
    #[test]
    fn test_prony_method_single_sinusoid() {
        // The linear-prediction coefficients are real, so a single component
        // has a real pole and can only represent DC (f = 0) or Nyquist
        // (f = 0.5). A general real sinusoid needs a conjugate pair (two
        // components), so this single-component test uses the Nyquist tone
        // cos(π·i).
        let n = 64usize;
        let f0 = 0.5_f64;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * f0 * i as f64).cos())
            .collect();
        let result = prony_method(&signal, 1).expect("prony_method");
        assert!(!result.frequencies.is_empty(), "Should find at least 1 component");
        assert!(
            (result.frequencies[0] - f0).abs() < 0.05,
            "Expected freq near {f0}, got {}",
            result.frequencies[0]
        );
    }
    #[test]
    fn test_prony_method_damped_exponential() {
        // x[i] = r^i is exactly one real pole z = r: frequency 0, damping ln r.
        let n = 64usize;
        let r = 0.9_f64;
        let signal: Vec<f64> = (0..n).map(|i| r.powi(i as i32)).collect();
        let result = prony_method(&signal, 1).expect("prony_method");
        assert_eq!(result.frequencies.len(), 1);
        assert!(result.frequencies[0].abs() < 1e-12);
        assert!(
            (result.damping[0] - r.ln()).abs() < 1e-6,
            "Expected damping near {}, got {}",
            r.ln(),
            result.damping[0]
        );
    }
    #[test]
    fn test_prony_method_too_short_error() {
        let signal = vec![1.0f64; 4];
        let err = prony_method(&signal, 4).unwrap_err();
        assert!(matches!(err, FFTError::InvalidInput(_)));
    }
    #[test]
    fn test_prony_result_fields() {
        let n = 64usize;
        let signal: Vec<f64> = (0..n).map(|i| (2.0 * PI * 0.2 * i as f64).sin()).collect();
        let result = prony_method(&signal, 1).expect("prony_method");
        assert_eq!(result.frequencies.len(), result.amplitudes.len());
        assert_eq!(result.frequencies.len(), result.phases.len());
        assert_eq!(result.frequencies.len(), result.damping.len());
        for &a in &result.amplitudes {
            assert!(a >= 0.0, "Amplitudes must be non-negative");
        }
    }
    #[test]
    fn test_amplitude_complex_accessor() {
        let sparse = SparseFftResult {
            frequencies: vec![5, 10],
            amplitudes: vec![1.0, 2.0, 3.0, 4.0],
            n: 64,
        };
        let c0 = sparse.amplitude_complex(0).expect("c0");
        assert!((c0.re - 1.0).abs() < f64::EPSILON);
        assert!((c0.im - 2.0).abs() < f64::EPSILON);
        let c1 = sparse.amplitude_complex(1).expect("c1");
        assert!((c1.re - 3.0).abs() < f64::EPSILON);
        let c_none = sparse.amplitude_complex(2);
        assert!(c_none.is_none());
    }
    #[test]
    fn test_sparsity_accessor() {
        let sparse = SparseFftResult {
            frequencies: vec![1, 2, 3],
            amplitudes: vec![0.0; 6],
            n: 32,
        };
        assert_eq!(sparse.sparsity(), 3);
    }
}