use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use std::fmt::Debug;
use std::time::Instant;
use super::config::{SparseFFTAlgorithm, SparseFFTConfig};
use super::estimation::estimate_sparsity;
use super::windowing::apply_window;
/// Result of a sparse FFT computation.
///
/// As produced by `SparseFFT::sparse_fft`, `values[i]` is the complex coefficient of
/// frequency bin `indices[i]`, and both vectors have the same length.
#[derive(Debug, Clone)]
pub struct SparseFFTResult {
/// Complex coefficients of the retained frequency components.
pub values: Vec<Complex64>,
/// Frequency-bin index (position in the full spectrum) of each entry in `values`.
pub indices: Vec<usize>,
/// Sparsity estimate produced by the configured estimator before component selection.
pub estimated_sparsity: usize,
/// Wall-clock time spent in `SparseFFT::sparse_fft` (windowing, estimation and selection).
pub computation_time: std::time::Duration,
/// Algorithm that produced this result.
pub algorithm: SparseFFTAlgorithm,
}
/// Sparse FFT processor holding a configuration and a seeded random number generator.
pub struct SparseFFT {
/// Settings for truncation, windowing, sparsity estimation and algorithm choice.
config: SparseFFTConfig,
/// RNG seeded in `new` from `config.seed` (or a random seed when unset); within this
/// file it is only drawn from by the compressed-sensing path.
rng: scirs2_core::random::rngs::StdRng,
}
impl SparseFFT {
pub fn new(config: SparseFFTConfig) -> Self {
let seed = config.seed.unwrap_or_else(scirs2_core::random::random);
let rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
Self { config, rng }
}
pub fn with_default_config() -> Self {
Self::new(SparseFFTConfig::default())
}
pub fn estimate_sparsity<T>(&mut self, signal: &[T]) -> FFTResult<usize>
where
T: NumCast + Copy + Debug + 'static,
{
estimate_sparsity(signal, &self.config)
}
fn calculate_spectral_flatness(&self, magnitudes: &[f64]) -> f64 {
if magnitudes.is_empty() {
return 1.0; }
let epsilon = 1e-10;
let log_sum: f64 = magnitudes.iter().map(|&x| (x + epsilon).ln()).sum::<f64>();
let geometric_mean = (log_sum / magnitudes.len() as f64).exp();
let arithmetic_mean: f64 = magnitudes.iter().sum::<f64>() / magnitudes.len() as f64;
if arithmetic_mean < epsilon {
return 1.0; }
let flatness = geometric_mean / arithmetic_mean;
flatness.clamp(0.0, 1.0)
}
pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
where
T: NumCast + Copy + Debug + 'static,
{
let start = Instant::now();
let limit = signal.len().min(self.config.max_signal_size);
let limited_signal = &signal[..limit];
let windowed_signal = apply_window(
limited_signal,
self.config.window_function,
self.config.kaiser_beta,
)?;
let estimated_sparsity = self.estimate_sparsity(&windowed_signal)?;
let (values, indices) = match self.config.algorithm {
SparseFFTAlgorithm::Sublinear => {
self.sublinear_sfft(&windowed_signal, estimated_sparsity)?
}
SparseFFTAlgorithm::CompressedSensing => {
self.compressed_sensing_sfft(&windowed_signal, estimated_sparsity)?
}
SparseFFTAlgorithm::Iterative => {
self.iterative_sfft(&windowed_signal, estimated_sparsity)?
}
SparseFFTAlgorithm::Deterministic => {
self.deterministic_sfft(&windowed_signal, estimated_sparsity)?
}
SparseFFTAlgorithm::FrequencyPruning => {
self.frequency_pruning_sfft(&windowed_signal, estimated_sparsity)?
}
SparseFFTAlgorithm::SpectralFlatness => {
self.spectral_flatness_sfft(&windowed_signal, estimated_sparsity)?
}
};
let computation_time = start.elapsed();
Ok(SparseFFTResult {
values,
indices,
estimated_sparsity,
computation_time,
algorithm: self.config.algorithm,
})
}
pub fn sparse_fft_full<T>(&mut self, signal: &[T]) -> FFTResult<Vec<Complex64>>
where
T: NumCast + Copy + Debug + 'static,
{
let n = signal.len().min(self.config.max_signal_size);
let windowed_signal = apply_window(
&signal[..n],
self.config.window_function,
self.config.kaiser_beta,
)?;
let result = self.sparse_fft(&windowed_signal)?;
let mut spectrum = vec![Complex64::new(0.0, 0.0); n];
for (value, &index) in result.values.iter().zip(result.indices.iter()) {
spectrum[index] = *value;
}
Ok(spectrum)
}
pub fn reconstruct_signal(
&self,
sparse_result: &SparseFFTResult,
n: usize,
) -> FFTResult<Vec<Complex64>> {
let mut spectrum = vec![Complex64::new(0.0, 0.0); n];
for (value, &index) in sparse_result
.values
.iter()
.zip(sparse_result.indices.iter())
{
spectrum[index] = *value;
}
ifft(&spectrum, None)
}
fn sublinear_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
let signal_complex: Vec<Complex64> = signal
.iter()
.map(|&val| {
let val_f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {val:?} to f64"))
})?;
Ok(Complex64::new(val_f64, 0.0))
})
.collect::<FFTResult<Vec<_>>>()?;
let _n = signal_complex.len();
let spectrum = fft(&signal_complex, None)?;
let mut freq_with_magnitudes: Vec<(f64, usize, Complex64)> = spectrum
.iter()
.enumerate()
.map(|(i, &coef)| (coef.norm(), i, coef))
.collect();
freq_with_magnitudes
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut selected_indices = Vec::new();
let mut selected_values = Vec::new();
for &(_, idx, val) in freq_with_magnitudes.iter().take(k) {
selected_indices.push(idx);
selected_values.push(val);
}
Ok((selected_values, selected_indices))
}
fn compressed_sensing_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
let signal_complex: Vec<Complex64> = signal
.iter()
.map(|&val| {
let val_f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {val:?} to f64"))
})?;
Ok(Complex64::new(val_f64, 0.0))
})
.collect::<FFTResult<Vec<_>>>()?;
let n = signal_complex.len();
let m = (4 * k * (self.config.iterations as f64).log2() as usize).min(n);
let mut measurements = Vec::with_capacity(m);
let mut sample_indices = Vec::with_capacity(m);
for _ in 0..m {
let idx = self.rng.random_range(0..n);
sample_indices.push(idx);
measurements.push(signal_complex[idx]);
}
let spectrum = fft(&signal_complex, None)?;
let mut freq_with_magnitudes: Vec<(f64, usize, Complex64)> = spectrum
.iter()
.enumerate()
.map(|(i, &coef)| (coef.norm(), i, coef))
.collect();
freq_with_magnitudes
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut selected_indices = Vec::new();
let mut selected_values = Vec::new();
for &(_, idx, val) in freq_with_magnitudes.iter().take(k) {
selected_indices.push(idx);
selected_values.push(val);
}
Ok((selected_values, selected_indices))
}
fn iterative_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
let mut signal_complex: Vec<Complex64> = signal
.iter()
.map(|&val| {
let val_f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {val:?} to f64"))
})?;
Ok(Complex64::new(val_f64, 0.0))
})
.collect::<FFTResult<Vec<_>>>()?;
let mut selected_indices = Vec::new();
let mut selected_values = Vec::new();
for _ in 0..k.min(self.config.iterations) {
let spectrum = fft(&signal_complex, None)?;
let (best_idx, best_value) = spectrum
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.norm()
.partial_cmp(&b.norm())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(i, &val)| (i, val))
.ok_or_else(|| FFTError::ValueError("Empty spectrum".to_string()))?;
if best_value.norm() < 1e-10 {
break;
}
selected_indices.push(best_idx);
selected_values.push(best_value);
let n = signal_complex.len();
for (i, sample) in signal_complex.iter_mut().enumerate() {
let phase =
2.0 * std::f64::consts::PI * (best_idx as f64) * (i as f64) / (n as f64);
let component = best_value * Complex64::new(phase.cos(), phase.sin()) / (n as f64);
*sample -= component;
}
}
Ok((selected_values, selected_indices))
}
fn deterministic_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
self.sublinear_sfft(signal, k)
}
fn frequency_pruning_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
let signal_complex: Vec<Complex64> = signal
.iter()
.map(|&val| {
let val_f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {val:?} to f64"))
})?;
Ok(Complex64::new(val_f64, 0.0))
})
.collect::<FFTResult<Vec<_>>>()?;
let spectrum = fft(&signal_complex, None)?;
let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
let n = magnitudes.len();
let mean: f64 = magnitudes.iter().sum::<f64>() / n as f64;
let variance: f64 = magnitudes.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
let std_dev = variance.sqrt();
let threshold = mean + self.config.pruning_sensitivity * std_dev;
let mut candidates: Vec<(f64, usize, Complex64)> = spectrum
.iter()
.enumerate()
.filter(|(_, c)| c.norm() > threshold)
.map(|(i, &c)| (c.norm(), i, c))
.collect();
candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let selected_count = k.min(candidates.len());
let selected_indices: Vec<usize> = candidates[..selected_count]
.iter()
.map(|(_, i_, _)| *i_)
.collect();
let selected_values: Vec<Complex64> = candidates[..selected_count]
.iter()
.map(|(_, _, c)| *c)
.collect();
Ok((selected_values, selected_indices))
}
fn spectral_flatness_sfft<T>(
&mut self,
signal: &[T],
k: usize,
) -> FFTResult<(Vec<Complex64>, Vec<usize>)>
where
T: NumCast + Copy + Debug + 'static,
{
let signal_complex: Vec<Complex64> = signal
.iter()
.map(|&val| {
let val_f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError(format!("Could not convert {val:?} to f64"))
})?;
Ok(Complex64::new(val_f64, 0.0))
})
.collect::<FFTResult<Vec<_>>>()?;
let spectrum = fft(&signal_complex, None)?;
let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
let n = magnitudes.len();
let window_size = self.config.window_size.min(n);
let mut selected_indices = Vec::new();
let mut selected_values = Vec::new();
for start in (0..n).step_by(window_size / 2) {
let end = (start + window_size).min(n);
if start >= n {
break;
}
let window_mags = &magnitudes[start..end];
let flatness = self.calculate_spectral_flatness(window_mags);
if flatness < self.config.flatness_threshold {
if let Some((local_idx_, _)) = window_mags
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
{
let global_idx = start + local_idx_;
if !selected_indices.contains(&global_idx) {
selected_indices.push(global_idx);
selected_values.push(spectrum[global_idx]);
}
}
}
if selected_indices.len() >= k {
break;
}
}
if selected_indices.len() < k {
let mut remaining_candidates: Vec<(f64, usize, Complex64)> = spectrum
.iter()
.enumerate()
.filter(|(i_, _)| !selected_indices.contains(i_))
.map(|(i, &c)| (c.norm(), i, c))
.collect();
remaining_candidates
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let needed = k - selected_indices.len();
for (_, idx, val) in remaining_candidates.into_iter().take(needed) {
selected_indices.push(idx);
selected_values.push(val);
}
}
Ok((selected_values, selected_indices))
}
}
/// One-shot sparse FFT of `signal`, built on [`SparseFFT`].
///
/// * `k` is stored as `SparseFFTConfig::sparsity`. Whether it is honoured depends on the
///   configured estimation method, which is not visible here.
/// * `algorithm` defaults to `SparseFFTAlgorithm::Sublinear` when `None`.
/// * `seed` seeds the processor's RNG for reproducible runs when `Some`.
///
/// Every other setting comes from `SparseFFTConfig::default()`.
///
/// # Errors
/// Propagates errors from [`SparseFFT::sparse_fft`].
#[allow(dead_code)]
pub fn sparse_fft<T>(
    signal: &[T],
    k: usize,
    algorithm: Option<SparseFFTAlgorithm>,
    seed: Option<u64>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    let chosen_algorithm = algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear);
    SparseFFT::new(SparseFFTConfig {
        sparsity: k,
        algorithm: chosen_algorithm,
        seed,
        ..SparseFFTConfig::default()
    })
    .sparse_fft(signal)
}
/// Sparse FFT using the adaptive sparsity estimator.
///
/// `threshold` is written to both `SparseFFTConfig::threshold` and
/// `SparseFFTConfig::adaptivity_factor`; all other settings are defaults.
///
/// NOTE(review): using `threshold` as the adaptivity factor too may be intentional,
/// but confirm the two fields are meant to share one value.
///
/// # Errors
/// Propagates errors from [`SparseFFT::sparse_fft`].
#[allow(dead_code)]
pub fn adaptive_sparse_fft<T>(signal: &[T], threshold: f64) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    let adaptivity_factor = threshold;
    SparseFFT::new(SparseFFTConfig {
        estimation_method: super::config::SparsityEstimationMethod::Adaptive,
        threshold,
        adaptivity_factor,
        ..SparseFFTConfig::default()
    })
    .sparse_fft(signal)
}
/// Sparse FFT using frequency pruning for both sparsity estimation and component selection.
///
/// `sensitivity` becomes `SparseFFTConfig::pruning_sensitivity`: the number of standard
/// deviations above the mean magnitude that a bin must exceed to be kept. All other settings
/// are defaults.
///
/// # Errors
/// Propagates errors from [`SparseFFT::sparse_fft`].
#[allow(dead_code)]
pub fn frequency_pruning_sparse_fft<T>(
    signal: &[T],
    sensitivity: f64,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    // The parameter was previously named `_signal`. The leading underscore conventionally
    // marks an unused binding, which was misleading because the value is used below.
    let config = SparseFFTConfig {
        estimation_method: super::config::SparsityEstimationMethod::FrequencyPruning,
        algorithm: SparseFFTAlgorithm::FrequencyPruning,
        pruning_sensitivity: sensitivity,
        ..SparseFFTConfig::default()
    };
    let mut processor = SparseFFT::new(config);
    processor.sparse_fft(signal)
}
/// Sparse FFT using spectral flatness for both sparsity estimation and component selection.
///
/// * `flatness_threshold` — windows with flatness below this value contribute their peak bin.
/// * `window_size` — number of spectrum bins per analysis window.
///
/// All other settings are defaults.
///
/// # Errors
/// Propagates errors from [`SparseFFT::sparse_fft`].
#[allow(dead_code)]
pub fn spectral_flatness_sparse_fft<T>(
    signal: &[T],
    flatness_threshold: f64,
    window_size: usize,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    SparseFFT::new(SparseFFTConfig {
        estimation_method: super::config::SparsityEstimationMethod::SpectralFlatness,
        algorithm: SparseFFTAlgorithm::SpectralFlatness,
        flatness_threshold,
        window_size,
        ..SparseFFTConfig::default()
    })
    .sparse_fft(signal)
}
/// Two-dimensional sparse FFT. **Not implemented yet.**
///
/// # Errors
/// Always returns `FFTError::ValueError` with the message
/// `"2D sparse FFT not yet implemented"`.
#[allow(dead_code)]
pub fn sparse_fft2<T>(
    _signal: &[Vec<T>],
    _k: usize,
    _algorithm: Option<SparseFFTAlgorithm>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    let message = String::from("2D sparse FFT not yet implemented");
    Err(FFTError::ValueError(message))
}
/// N-dimensional sparse FFT over a flat buffer with the given `shape`. **Not implemented yet.**
///
/// # Errors
/// Always returns `FFTError::ValueError` with the message
/// `"N-dimensional sparse FFT not yet implemented"`.
#[allow(dead_code)]
pub fn sparse_fftn<T>(
    _signal: &[T],
    _shape: &[usize],
    _k: usize,
    _algorithm: Option<SparseFFTAlgorithm>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    let message = String::from("N-dimensional sparse FFT not yet implemented");
    Err(FFTError::ValueError(message))
}