use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::Complex64;
use std::collections::HashMap;
use super::estimation::estimate_sparsity;
use super::types::{AdaptiveSfftConfig, AdaptiveSfftResult};
/// Adaptive sparse FFT solver.
///
/// Holds only the seed for the pseudo-random permutations used by `compute`;
/// all algorithmic tuning lives in `AdaptiveSfftConfig`.
pub struct AdaptiveSparseFft {
    // Seed for the xorshift-based shuffle in `generate_permutation`.
    seed: u64,
}
impl AdaptiveSparseFft {
    /// Creates a solver with a fixed, arbitrary default seed (reproducible runs).
    pub fn new() -> Self {
        Self {
            seed: 0x517c1e6f3a5b9d2e,
        }
    }

    /// Creates a solver with a caller-supplied RNG seed.
    ///
    /// NOTE(review): seed 0 is a fixed point of the xorshift generator in
    /// `generate_permutation`, so every iteration would reuse one degenerate
    /// permutation — confirm whether 0 should be rejected or remapped.
    pub fn with_seed(seed: u64) -> Self {
        Self { seed }
    }

    /// Estimates a sparse frequency-domain representation of `signal`.
    ///
    /// Strategy: iteratively (1) pseudo-randomly permute the samples,
    /// (2) bucket them down to a short FFT, (3) vote for the original bins each
    /// sub-spectrum peak could alias to, stopping once the voted bins capture
    /// `config.confidence` of the total spectral energy. Final coefficients are
    /// always read from the exact full-length FFT, so the sparse pipeline only
    /// decides *which* bins to keep, never their values.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` when `signal` has fewer than 4 samples,
    /// and propagates any error from the underlying `fft`/`ifft` calls.
    pub fn compute(
        &self,
        signal: &[f64],
        config: &AdaptiveSfftConfig,
    ) -> FFTResult<AdaptiveSfftResult> {
        let n = signal.len();
        if n < 4 {
            return Err(FFTError::ValueError(
                "Signal must have at least 4 samples".to_string(),
            ));
        }
        // Initial sparsity guess, clamped to [1, config.max_sparsity].
        let initial_k = estimate_sparsity(signal)?.max(1).min(config.max_sparsity);
        // Exact reference spectrum: validates votes and supplies the final
        // coefficients. Computing it forfeits the asymptotic advantage of a
        // sparse FFT; presumably this module trades speed for robustness here.
        let full_spectrum = fft(signal, None)?;
        let n_spectrum = full_spectrum.len();
        let bin_energies: Vec<f64> = full_spectrum
            .iter()
            .map(|c| c.re * c.re + c.im * c.im)
            .collect();
        let total_energy: f64 = bin_energies.iter().sum();
        if total_energy < f64::EPSILON {
            // Numerically silent signal: nothing to recover.
            return Ok(AdaptiveSfftResult::empty());
        }
        // A bin qualifies as a candidate only if its energy is at least
        // `energy_threshold` times the strongest bin's energy.
        let max_bin_energy = bin_energies.iter().cloned().fold(0.0_f64, f64::max);
        let abs_threshold = max_bin_energy * config.energy_threshold;
        // bin index -> number of iterations that nominated it.
        let mut candidate_votes: HashMap<usize, usize> = HashMap::new();
        let mut rng_state = self.seed;
        let mut current_k = initial_k;
        let mut actual_iterations = 0;
        for _iter in 0..config.max_iterations {
            actual_iterations += 1;
            // Fresh pseudo-random relabeling of the samples each iteration so
            // that different frequencies collide in different buckets.
            let perm = generate_permutation(n, &mut rng_state);
            let inv_perm = invert_permutation(&perm);
            let permuted: Vec<f64> = (0..n).map(|i| signal[perm[i]]).collect();
            // Bucket count tracks the (shrinking) residual sparsity estimate.
            let b = compute_bucket_count(n, current_k);
            let sub_signal = subsample_avg(&permuted, b);
            let sub_spectrum = fft(&sub_signal, None)?;
            let b_actual = sub_spectrum.len();
            let peaks = find_top_peaks(&sub_spectrum, current_k);
            for (sub_bin, _) in &peaks {
                // Each sub-spectrum bin aliases ~n/b_actual bins of the
                // full-length spectrum; nominate every alias clearing the
                // energy floor.
                let num_aliases = n.div_ceil(b_actual);
                for alias in 0..num_aliases {
                    let perm_bin = (sub_bin + alias * b_actual) % n_spectrum;
                    // NOTE(review): `inv_perm` inverts a *time-domain* shuffle
                    // but is applied to *frequency* indices here. That mapping
                    // is exact only for modular-dilation permutations, not an
                    // arbitrary Fisher–Yates shuffle — the energy check below
                    // masks wrong nominations, but confirm the intent.
                    let orig_bin = inv_perm[perm_bin];
                    if orig_bin < n_spectrum && bin_energies[orig_bin] >= abs_threshold {
                        *candidate_votes.entry(orig_bin).or_insert(0) += 1;
                    }
                }
            }
            // Stop early once the nominated bins explain enough of the energy.
            let captured: f64 = candidate_votes
                .keys()
                .map(|&b| bin_energies[b])
                .sum::<f64>();
            if captured / total_energy >= config.confidence {
                break;
            }
            if actual_iterations < config.max_iterations {
                // Re-estimate sparsity on the residual (signal minus the
                // components found so far) to refine the next bucket count.
                let found: HashMap<usize, Complex64> = candidate_votes
                    .keys()
                    .filter(|&&b| b < n_spectrum)
                    .map(|&b| (b, full_spectrum[b]))
                    .collect();
                let residual = subtract_components(signal, &found, n)?;
                let residual_k = estimate_sparsity(&residual).unwrap_or(1).max(1);
                current_k = residual_k.min(config.max_sparsity);
                // NOTE(review): unreachable — `current_k` is `.max(1)`-clamped
                // above, so it can never be 0 here.
                if current_k == 0 {
                    break;
                }
            }
        }
        if candidate_votes.is_empty() {
            // Fallback 1: no iteration produced a vote; take the top peaks of
            // the full spectrum that clear the energy floor.
            let peaks = find_top_peaks(&full_spectrum, initial_k.min(config.max_sparsity));
            for (bin, _) in peaks {
                if bin_energies[bin] >= abs_threshold {
                    candidate_votes.insert(bin, 1);
                }
            }
            if candidate_votes.is_empty() {
                // Fallback 2: guarantee at least one bin by taking the single
                // strongest one (total energy is nonzero at this point).
                if let Some((max_bin, _)) = bin_energies
                    .iter()
                    .enumerate()
                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                {
                    candidate_votes.insert(max_bin, 1);
                }
            }
        }
        // Materialize candidates with their exact coefficients, keep the
        // `max_sparsity` strongest, then present them in ascending bin order.
        let mut candidate_list: Vec<(usize, Complex64)> = candidate_votes
            .keys()
            .filter(|&&b| b < n_spectrum)
            .map(|&b| (b, full_spectrum[b]))
            .collect();
        candidate_list.sort_by(|a, b| {
            let ea = a.1.re * a.1.re + a.1.im * a.1.im;
            let eb = b.1.re * b.1.re + b.1.im * b.1.im;
            eb.partial_cmp(&ea).unwrap_or(std::cmp::Ordering::Equal)
        });
        candidate_list.truncate(config.max_sparsity);
        candidate_list.sort_by_key(|(idx, _)| *idx);
        let frequencies: Vec<usize> = candidate_list.iter().map(|(i, _)| *i).collect();
        let coefficients: Vec<Complex64> = candidate_list.iter().map(|(_, c)| *c).collect();
        let captured_energy: f64 = coefficients.iter().map(|c| c.re * c.re + c.im * c.im).sum();
        // Clamp: kept-bin energy can slightly exceed 1.0 only via rounding.
        let captured_fraction = (captured_energy / total_energy).min(1.0);
        Ok(AdaptiveSfftResult {
            estimated_sparsity: frequencies.len(),
            iterations: actual_iterations,
            total_energy,
            captured_energy_fraction: captured_fraction,
            frequencies,
            coefficients,
        })
    }
}
impl Default for AdaptiveSparseFft {
fn default() -> Self {
Self::new()
}
}
/// Fisher–Yates shuffle of `0..n`, driven by an inline xorshift64 generator.
///
/// `state` is advanced in place so successive calls yield different
/// permutations. A zero state is remapped to a fixed nonzero constant first:
/// 0 is a fixed point of xorshift64, so a zero seed would otherwise never
/// advance and every call would return the same degenerate shuffle.
fn generate_permutation(n: usize, state: &mut u64) -> Vec<usize> {
    // xorshift64 requires a nonzero state; remap 0 to an arbitrary constant
    // (the 64-bit golden-ratio constant) so `with_seed(0)` still randomizes.
    if *state == 0 {
        *state = 0x9e3779b97f4a7c15;
    }
    let mut perm: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        // One xorshift64 step (Marsaglia 2003): full-period over nonzero u64s.
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        // Modulo has a tiny bias when (i + 1) is not a power of two;
        // acceptable here since the permutation only randomizes bucketing.
        let j = (*state as usize) % (i + 1);
        perm.swap(i, j);
    }
    perm
}
/// Builds the inverse mapping of `perm`: if `perm[i] == p`, then the result
/// satisfies `inverse[p] == i`. Assumes `perm` is a permutation of `0..len`.
fn invert_permutation(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0usize; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(src, &dst)| inverse[dst] = src);
    inverse
}
/// Chooses the number of buckets for the subsampled FFT: roughly n / (4k),
/// rounded up to the next power of two and clamped to the range [1, n].
fn compute_bucket_count(n: usize, k: usize) -> usize {
    let sparsity = k.max(1);
    let raw_buckets = std::cmp::max(n / (4 * sparsity), 1);
    raw_buckets.next_power_of_two().min(n)
}
/// Downsamples `signal` to `target_len` points by averaging consecutive blocks
/// of `len / target_len` samples. When `target_len >= len` the signal is
/// returned unchanged. Trailing samples that do not fill a block are dropped.
fn subsample_avg(signal: &[f64], target_len: usize) -> Vec<f64> {
    let len = signal.len();
    if target_len >= len {
        return signal.to_vec();
    }
    let block_size = len / target_len;
    signal
        .chunks(block_size)
        .take(target_len)
        .map(|block| block.iter().sum::<f64>() / block.len() as f64)
        .collect()
}
/// Returns the `k` largest-magnitude bins of `spectrum` as `(bin, magnitude)`
/// pairs, sorted by descending magnitude. Ties keep ascending bin order
/// (the sort is stable). Returns fewer than `k` entries for short spectra.
fn find_top_peaks(spectrum: &[Complex64], k: usize) -> Vec<(usize, f64)> {
    let mut ranked: Vec<(usize, f64)> = spectrum
        .iter()
        .map(|c| (c.re * c.re + c.im * c.im).sqrt())
        .enumerate()
        .collect();
    // Stable descending sort; NaN magnitudes compare as equal.
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked.truncate(k);
    ranked
}
/// Subtracts the time-domain reconstruction of the given frequency components
/// from `signal`, returning the residual.
///
/// `candidates` maps bin index -> complex coefficient; bins at or beyond `n`
/// are silently ignored. Only the real part of the inverse FFT is subtracted,
/// since the input signal is real. With no candidates the signal is returned
/// as-is, skipping the inverse transform entirely.
fn subtract_components(
    signal: &[f64],
    candidates: &HashMap<usize, Complex64>,
    n: usize,
) -> FFTResult<Vec<f64>> {
    if candidates.is_empty() {
        return Ok(signal.to_vec());
    }
    // Build a length-n spectrum that is zero everywhere except the candidates.
    let mut spectrum = vec![Complex64::new(0.0, 0.0); n];
    for (&bin, &coeff) in candidates.iter() {
        if bin < n {
            spectrum[bin] = coeff;
        }
    }
    let reconstruction = ifft(&spectrum, None)?;
    let residual = signal
        .iter()
        .zip(reconstruction.iter())
        .map(|(&sample, approx)| sample - approx.re)
        .collect();
    Ok(residual)
}
/// Convenience entry point: runs the adaptive sparse FFT on `signal` with the
/// default configuration and the default (fixed) seed.
pub fn adaptive_sparse_fft_auto(signal: &[f64]) -> FFTResult<AdaptiveSfftResult> {
    AdaptiveSparseFft::new().compute(signal, &AdaptiveSfftConfig::default())
}