#![allow(clippy::needless_range_loop)]
use plotly::common::Title;
use plotly::{common::Mode, layout::Axis, Layout, Plot, Scatter};
use scirs2_core::Complex64;
use scirs2_fft::{
reconstruct_filtered, reconstruct_high_resolution, reconstruct_time_domain, sparse_fft,
sparse_fft::SparseFFTAlgorithm,
};
use std::f64::consts::PI;
/// Demonstrates sparse-FFT analysis of a noisy 3-tone signal followed by several
/// reconstruction strategies (plain, 2x high-resolution, low-pass and band-pass
/// filtered), then writes comparison plots to HTML files.
#[allow(dead_code)]
fn main() {
    println!("Sparse FFT Reconstruction Example");
    println!("=================================\n");

    let n = 1024;
    println!("Creating a signal with n = {n} samples and 3 frequency components");
    let frequencies = [(3, 1.0), (7, 0.5), (15, 0.25)];
    let clean_signal = create_sparse_signal(n, &frequencies);

    // Centre each uniform sample so the additive noise is zero-mean; an uncentred
    // `0.1 * random()` would shift the whole signal up by 0.05 (a spurious DC term).
    // NOTE(review): assumes `random::<f64>()` is uniform on [0, 1) like `rand::random`.
    let noisy_signal: Vec<f64> = clean_signal
        .iter()
        .map(|&x| x + 0.1 * (scirs2_core::random::random::<f64>() - 0.5))
        .collect();

    println!("\nPerforming sparse FFT on noisy signal...");
    // Presumably the 2nd argument is the number of components to keep: 3 real tones
    // occupy 6 complex DFT bins (+f and -f for each) — TODO confirm against sparse_fft docs.
    let sparse_result = sparse_fft(
        &noisy_signal,
        6,
        Some(SparseFFTAlgorithm::SpectralFlatness),
        Some(42),
    )
    .expect("Operation failed");
    println!(
        "Found {} significant frequency components",
        sparse_result.values.len()
    );

    println!(
        "\nBasic reconstruction: Converting sparse frequency components back to time domain..."
    );
    let reconstructed = reconstruct_time_domain(&sparse_result, n).expect("Operation failed");
    let clean_error = compute_error(&clean_signal, &reconstructed);
    println!("Error between original clean signal and reconstruction: {clean_error:.6}");
    let noisy_error = compute_error(&noisy_signal, &reconstructed);
    println!("Error between noisy signal and reconstruction: {noisy_error:.6}");
    println!("(Lower error with clean signal shows noise reduction effect)");

    println!("\nHigh-resolution reconstruction: Enhancing frequency resolution 2x...");
    let target_length = n * 2;
    let high_res =
        reconstruct_high_resolution(&sparse_result, n, target_length).expect("Operation failed");
    println!("Original signal length: {n}");
    println!("High-resolution signal length: {}", high_res.len());

    println!("\nFiltered reconstruction: Applying low-pass filter to keep only lowest 10% frequencies...");
    // Gain per bin: 1.0 for bins within 10% of Nyquist (bins above Nyquist are
    // folded back onto their mirrored positive frequency), 0.0 otherwise.
    let lowpass = |idx: usize, n: usize| -> f64 {
        let nyquist = n / 2;
        let cutoff = nyquist / 10;
        let freq_idx = if idx <= nyquist { idx } else { n - idx };
        if freq_idx <= cutoff {
            1.0
        } else {
            0.0
        }
    };
    let lowpass_signal =
        reconstruct_filtered(&sparse_result, n, lowpass).expect("Operation failed");

    println!("\nFiltered reconstruction: Applying band-pass filter (30-70% of Nyquist)...");
    // Gain per bin: 1.0 inside [30%, 70%] of Nyquist (mirrored bins folded), else 0.0.
    let bandpass = |idx: usize, n: usize| -> f64 {
        let nyquist = n / 2;
        let low_cutoff = (nyquist as f64 * 0.3) as usize;
        let high_cutoff = (nyquist as f64 * 0.7) as usize;
        let freq_idx = if idx <= nyquist { idx } else { n - idx };
        if (low_cutoff..=high_cutoff).contains(&freq_idx) {
            1.0
        } else {
            0.0
        }
    };
    let bandpass_signal =
        reconstruct_filtered(&sparse_result, n, bandpass).expect("Operation failed");

    println!("\nCreating visualization...");
    create_plots(
        &noisy_signal,
        &reconstructed,
        &high_res,
        &lowpass_signal,
        &bandpass_signal,
    );
    println!("\nExample completed successfully!");
}
/// Synthesizes `n` samples of a sum of sine waves.
///
/// Each `(freq, amp)` pair contributes `amp * sin(2π · freq · i / n)` at sample `i`,
/// so `freq` is the number of whole cycles spanned by the `n` samples.
/// Returns an empty vector when `n == 0` and all zeros when `frequencies` is empty.
#[allow(dead_code)]
fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
    let period = n as f64;
    (0..n)
        .map(|sample| {
            let phase = 2.0 * PI * (sample as f64) / period;
            // Accumulate tones in the given order, starting from 0.0.
            frequencies.iter().fold(0.0, |acc, &(cycles, amplitude)| {
                acc + amplitude * (cycles as f64 * phase).sin()
            })
        })
        .collect()
}
/// Normalized error between a real signal and a complex reconstruction.
///
/// Both inputs are first scaled to unit energy (L2 norm), so the metric compares
/// waveform *shape* and deliberately ignores any overall amplitude difference
/// (e.g. from differing FFT normalization conventions). If the lengths differ,
/// only the common prefix is compared.
///
/// Returns `sqrt(Σ |a_i - b_i|² / (2 · len))` over the normalized vectors.
/// Degenerate inputs give finite results instead of NaN/inf: an empty overlap
/// returns `0.0`, and an all-zero input is left unscaled (treated as the zero vector).
#[allow(dead_code)]
fn compute_error(original: &[f64], reconstructed: &[Complex64]) -> f64 {
    let len = original.len().min(reconstructed.len());
    if len == 0 {
        return 0.0;
    }
    let (original, reconstructed) = (&original[..len], &reconstructed[..len]);

    // 1/||v|| for normalization, or 0 for an all-zero vector to avoid dividing by zero.
    let inv_norm = |energy: f64| if energy > 0.0 { 1.0 / energy.sqrt() } else { 0.0 };
    let orig_scale = inv_norm(original.iter().map(|&x| x * x).sum());
    let recon_scale = inv_norm(reconstructed.iter().map(|z| z.norm_sqr()).sum());

    let error_sum: f64 = original
        .iter()
        .zip(reconstructed)
        .map(|(&o, &r)| (Complex64::new(o * orig_scale, 0.0) - r * recon_scale).norm_sqr())
        .sum();
    (error_sum / (2.0 * len as f64)).sqrt()
}
/// Writes two interactive HTML plots comparing the noisy input with its reconstructions.
///
/// * `sparse_recon_time_domain.html`: the first (up to) 200 samples of the noisy
///   signal alongside the real parts of the basic, low-pass and band-pass reconstructions.
/// * `sparse_recon_high_res.html`: the first (up to) 100 samples of the noisy signal
///   against the high-resolution reconstruction. That reconstruction is assumed to hold
///   `HIRES_FACTOR` (2) output samples per original sample.
///
/// Each plot window is clamped to the shortest series it draws. A shorter-than-expected
/// reconstruction therefore shrinks the plot instead of panicking on an out-of-range slice.
#[allow(dead_code)]
fn create_plots(
    noisy_signal: &[f64],
    basic_recon: &[Complex64],
    high_res: &[Complex64],
    lowpass: &[Complex64],
    bandpass: &[Complex64],
) {
    // Output samples per original sample in `high_res` (a 2x-length reconstruction).
    const HIRES_FACTOR: usize = 2;

    let real_parts = |zs: &[Complex64]| -> Vec<f64> { zs.iter().map(|z| z.re).collect() };
    let basic_recon_real = real_parts(basic_recon);
    let high_res_real = real_parts(high_res);
    let lowpass_real = real_parts(lowpass);
    let bandpass_real = real_parts(bandpass);

    // --- Time-domain comparison -------------------------------------------------
    let slice_len = 200
        .min(noisy_signal.len())
        .min(basic_recon_real.len())
        .min(lowpass_real.len())
        .min(bandpass_real.len());
    let sample_idx: Vec<usize> = (0..slice_len).collect();
    let line_trace = |ys: &[f64], name: &str| {
        Scatter::new(sample_idx.clone(), ys[..slice_len].to_vec())
            .mode(Mode::Lines)
            .name(name)
    };

    let mut time_plot = Plot::new();
    time_plot.add_trace(line_trace(noisy_signal, "Noisy Signal"));
    time_plot.add_trace(line_trace(&basic_recon_real, "Basic Reconstruction"));
    time_plot.add_trace(line_trace(&lowpass_real, "Lowpass Filtered"));
    time_plot.add_trace(line_trace(&bandpass_real, "Bandpass Filtered"));
    time_plot.set_layout(
        Layout::new()
            .title(Title::with_text("Time Domain Signal Comparison"))
            .x_axis(Axis::new().title(Title::with_text("Sample Index")))
            .y_axis(Axis::new().title(Title::with_text("Amplitude"))),
    );
    time_plot.write_html("sparse_recon_time_domain.html");

    // --- High-resolution comparison ---------------------------------------------
    // Both series share the original sample-index axis; high-res points fall at
    // fractional positions i / HIRES_FACTOR.
    let hires_len = 100
        .min(noisy_signal.len())
        .min(high_res_real.len() / HIRES_FACTOR);
    let orig_times: Vec<f64> = (0..hires_len).map(|i| i as f64).collect();
    let hires_times: Vec<f64> = (0..HIRES_FACTOR * hires_len)
        .map(|i| i as f64 / HIRES_FACTOR as f64)
        .collect();

    let orig_trace = Scatter::new(orig_times, noisy_signal[..hires_len].to_vec())
        .mode(Mode::Lines)
        .name("Original Signal");
    let hires_trace = Scatter::new(
        hires_times,
        high_res_real[..HIRES_FACTOR * hires_len].to_vec(),
    )
    .mode(Mode::Lines)
    .name("High-Resolution");

    let mut highres_plot = Plot::new();
    highres_plot.add_trace(orig_trace);
    highres_plot.add_trace(hires_trace);
    highres_plot.set_layout(
        Layout::new()
            .title(Title::with_text("High-Resolution Reconstruction"))
            .x_axis(Axis::new().title(Title::with_text("Sample Index")))
            .y_axis(Axis::new().title(Title::with_text("Amplitude"))),
    );
    highres_plot.write_html("sparse_recon_high_res.html");
    println!("Plots saved as 'sparse_recon_time_domain.html' and 'sparse_recon_high_res.html'");
}