use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use crate::{window, WindowFunction};
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use std::collections::HashMap;
use std::fmt::Debug;
/// Mother wavelet used by the continuous wavelet transform.
// `DOG` is part of the public API, so the all-caps acronym lint is silenced rather than renaming.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Morlet wavelet.
    Morlet,
    /// Mexican hat wavelet.
    MexicanHat,
    /// Paul wavelet (order 4 in this module).
    Paul,
    /// Derivative-of-Gaussian wavelet (order 2 in this module).
    DOG,
}
/// Selects which time-frequency representation `time_frequency_transform` computes.
// Variant names are part of the public API, so the all-caps acronym lint is silenced.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TFTransform {
    /// Short-time Fourier transform.
    STFT,
    /// Continuous wavelet transform.
    CWT,
    /// Spectrogram whose energy is moved between neighbouring cells.
    ReassignedSpectrogram,
    /// Synchrosqueezed continuous wavelet transform.
    SynchrosqueezedWT,
    /// Wigner-Ville distribution (currently reported as not implemented).
    WVD,
    /// Smoothed pseudo Wigner-Ville distribution (currently reported as not implemented).
    SPWVD,
}
/// Parameters shared by all time-frequency transforms in this module.
#[derive(Debug, Clone)]
pub struct TFConfig {
    /// Which transform `time_frequency_transform` dispatches to.
    pub transform_type: TFTransform,
    /// STFT window length in samples (clamped to `max_size` by the STFT routine).
    pub window_size: usize,
    /// STFT frame advance in samples (capped at half a window by the STFT routine).
    pub hop_size: usize,
    /// Taper applied to each STFT frame.
    pub window_function: WindowFunction,
    /// Multiplier on `window_size` giving the STFT's FFT length.
    pub zero_padding: usize,
    /// Mother wavelet for the CWT-based transforms.
    pub wavelet_type: WaveletType,
    /// `(min, max)` of the log-spaced CWT analysis frequencies. Presumably in Hz when a sample rate
    /// is supplied — NOTE(review): confirm units with callers.
    pub frequency_range: (f64, f64),
    /// Requested number of CWT analysis frequencies (further capped by the CWT routine).
    pub frequency_bins: usize,
    /// NOTE(review): not read by any function visible in this file.
    pub resample_factor: usize,
    /// Size cap used to bound window length, frame count, signal length and loop extents.
    pub max_size: usize,
}
impl Default for TFConfig {
fn default() -> Self {
Self {
transform_type: TFTransform::STFT,
window_size: 256,
hop_size: 64,
window_function: WindowFunction::Hamming,
zero_padding: 1,
wavelet_type: WaveletType::Morlet,
frequency_range: (20.0, 500.0),
frequency_bins: 64,
resample_factor: 4,
max_size: 1024,
}
}
}
/// Output of a time-frequency transform.
#[derive(Debug, Clone)]
pub struct TFResult {
    /// Time of each STFT frame start or each CWT sample, divided by `sample_rate` when it is `Some`.
    pub times: Vec<f64>,
    /// Analysis frequency of each bin/scale (STFT bins are multiplied by `sample_rate` when given).
    pub frequencies: Vec<f64>,
    /// Complex coefficients. The STFT and reassigned-spectrogram routines fill this as
    /// `[frame, bin]`; the CWT and synchrosqueezed routines fill it as `[frequency, time]`.
    pub coefficients: Array2<Complex64>,
    /// Sample rate passed to the transform, if any.
    pub sample_rate: Option<f64>,
    /// Transform that produced this result.
    pub transform_type: TFTransform,
    /// Transform-specific scalar parameters (e.g. `"window_size"`, `"min_freq"`), keyed by name.
    pub metadata: HashMap<String, f64>,
}
/// Runs the time-frequency transform selected by `config.transform_type` on `signal`.
///
/// Samples are converted to `f64` first. In test builds, or whenever the `RUST_TEST`
/// environment variable is set, only the first `config.max_size` samples are used.
///
/// # Errors
/// Returns `FFTError::ValueError` if a sample cannot be converted to `f64`,
/// `FFTError::NotImplementedError` for `WVD` and `SPWVD`, and otherwise whatever the selected
/// transform returns.
#[allow(dead_code)]
pub fn time_frequency_transform<T>(
    signal: &[T],
    config: &TFConfig,
    sample_rate: Option<f64>,
) -> FFTResult<TFResult>
where
    T: NumCast + Copy + Debug,
{
    // NOTE(review): output depends on a runtime environment variable; consider removing this cap.
    let in_test_mode = cfg!(test) || std::env::var("RUST_TEST").is_ok();
    let limit = if in_test_mode {
        config.max_size
    } else {
        usize::MAX
    };

    let mut samples = Vec::with_capacity(signal.len().min(limit));
    for &raw in signal.iter().take(limit) {
        let value: f64 = NumCast::from(raw)
            .ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", raw)))?;
        samples.push(value);
    }
    let samples = samples.as_slice();

    match config.transform_type {
        TFTransform::STFT => compute_stft(samples, config, sample_rate),
        TFTransform::CWT => compute_cwt(samples, config, sample_rate),
        TFTransform::ReassignedSpectrogram => {
            compute_reassigned_spectrogram(samples, config, sample_rate)
        }
        TFTransform::SynchrosqueezedWT => compute_synchrosqueezed_wt(samples, config, sample_rate),
        TFTransform::WVD => Err(FFTError::NotImplementedError(
            "Wigner-Ville Distribution not implemented".to_string(),
        )),
        TFTransform::SPWVD => Err(FFTError::NotImplementedError(
            "Smoothed Pseudo Wigner-Ville Distribution not implemented".to_string(),
        )),
    }
}
#[allow(dead_code)]
fn compute_stft<T>(signal: &[T], config: &TFConfig, sample_rate: Option<f64>) -> FFTResult<TFResult>
where
T: NumCast + Copy + Debug,
{
let window_size = config.window_size.min(config.max_size);
let hop_size = config.hop_size.min(window_size / 2);
let padded_size = window_size * config.zero_padding;
let window_type = match config.window_function {
WindowFunction::None => crate::window::Window::Rectangular,
WindowFunction::Hann => crate::window::Window::Hann,
WindowFunction::Hamming => crate::window::Window::Hamming,
WindowFunction::Blackman => crate::window::Window::Blackman,
WindowFunction::FlatTop => crate::window::Window::FlatTop,
WindowFunction::Kaiser => crate::window::Window::Kaiser(5.0), };
let window = window::get_window(window_type, window_size, true)?;
let num_frames = ((signal.len() - window_size) / hop_size) + 1;
let num_frames = num_frames.min(config.max_size / window_size);
let num_bins = padded_size / 2 + 1;
let mut times = Vec::with_capacity(num_frames);
let mut frequencies = Vec::with_capacity(num_bins);
let mut coefficients = Array2::zeros((num_frames, num_bins));
for i in 0..num_frames {
let time = (i * hop_size) as f64;
times.push(if let Some(fs) = sample_rate {
time / fs
} else {
time
});
}
for k in 0..num_bins {
let freq = k as f64 / padded_size as f64;
frequencies.push(if let Some(fs) = sample_rate {
freq * fs
} else {
freq
});
}
for (frame, &time) in times.iter().enumerate().take(num_frames) {
let start = (time * sample_rate.unwrap_or(1.0)) as usize;
if start + window_size > signal.len() {
continue;
}
let mut windowed_frame = Vec::with_capacity(padded_size);
for i in 0..window_size {
let _signal_val: f64 = NumCast::from(signal[start + i]).ok_or_else(|| {
FFTError::ValueError("Failed to convert _signal value to f64".to_string())
})?;
windowed_frame.push(Complex64::new(_signal_val * window[i], 0.0));
}
windowed_frame.resize(padded_size, Complex64::new(0.0, 0.0));
let spectrum = fft(&windowed_frame, None)?;
for (bin, &coef) in spectrum.iter().enumerate().take(num_bins) {
coefficients[[frame, bin]] = coef;
}
}
let mut metadata = HashMap::new();
metadata.insert("window_size".to_string(), window_size as f64);
metadata.insert("hop_size".to_string(), hop_size as f64);
metadata.insert("zero_padding".to_string(), config.zero_padding as f64);
metadata.insert(
"time_resolution".to_string(),
hop_size as f64 / sample_rate.unwrap_or(1.0),
);
metadata.insert(
"freq_resolution".to_string(),
sample_rate.unwrap_or(1.0) / padded_size as f64,
);
Ok(TFResult {
times,
frequencies,
coefficients,
sample_rate,
transform_type: TFTransform::STFT,
metadata,
})
}
#[allow(dead_code)]
fn compute_cwt<T>(signal: &[T], config: &TFConfig, sample_rate: Option<f64>) -> FFTResult<TFResult>
where
T: NumCast + Copy + Debug,
{
let n = signal.len().min(config.max_size);
let min_freq = config.frequency_range.0;
let max_freq = config.frequency_range.1;
let num_freqs = config.frequency_bins.min(config.max_size / 4);
let log_min = min_freq.ln();
let log_max = max_freq.ln();
let log_step = (log_max - log_min) / (num_freqs as f64 - 1.0);
let mut frequencies = Vec::with_capacity(num_freqs);
for i in 0..num_freqs {
let log_freq = log_min + i as f64 * log_step;
frequencies.push(log_freq.exp());
}
let mut times = Vec::with_capacity(n);
for i in 0..n {
let time = i as f64;
times.push(if let Some(fs) = sample_rate {
time / fs
} else {
time
});
}
let max_freqs = frequencies.len().min(32);
let mut coefficients = Array2::zeros((max_freqs, n));
frequencies.truncate(max_freqs);
let mut signal_complex = Vec::with_capacity(n);
for &val in signal.iter().take(n) {
let val_f64: f64 = NumCast::from(val).ok_or_else(|| {
FFTError::ValueError("Failed to convert _signal value to f64".to_string())
})?;
signal_complex.push(Complex64::new(val_f64, 0.0));
}
let signal_fft = fft(&signal_complex, None)?;
for (i, &scale_freq) in frequencies.iter().enumerate() {
let wavelet_fft = create_wavelet_fft(
config.wavelet_type,
scale_freq,
n,
sample_rate.unwrap_or(1.0),
)?;
let mut product = Vec::with_capacity(n);
for j in 0..n {
product.push(signal_fft[j] * wavelet_fft[j].conj()); }
let result = ifft(&product, None)?;
for (j, &coef) in result.iter().enumerate().take(n) {
coefficients[[i, j]] = coef;
}
}
let mut metadata = HashMap::new();
metadata.insert("min_freq".to_string(), min_freq);
metadata.insert("max_freq".to_string(), max_freq);
metadata.insert("num_freqs".to_string(), max_freqs as f64);
metadata.insert(
"wavelet_type".to_string(),
match config.wavelet_type {
WaveletType::Morlet => 0.0,
WaveletType::MexicanHat => 1.0,
WaveletType::Paul => 2.0,
WaveletType::DOG => 3.0,
},
);
Ok(TFResult {
times,
frequencies,
coefficients,
sample_rate,
transform_type: TFTransform::CWT,
metadata,
})
}
#[allow(dead_code)]
fn create_wavelet_fft(
wavelet_type: WaveletType,
scale_freq: f64,
n: usize,
sample_rate: f64,
) -> FFTResult<Vec<Complex64>> {
let dt = 1.0 / sample_rate;
let scale = 1.0 / scale_freq;
let mut freqs = Vec::with_capacity(n);
for k in 0..n {
let _freq = if k <= n / 2 {
k as f64 / (n as f64 * dt)
} else {
-((n - k) as f64) / (n as f64 * dt)
};
freqs.push(_freq);
}
let mut wavelet_fft = vec![Complex64::new(0.0, 0.0); n];
match wavelet_type {
WaveletType::Morlet => {
let omega0 = 6.0;
for (k, &_freq) in freqs.iter().enumerate().take(n) {
let norm_freq = _freq * scale;
if norm_freq > 0.0 {
let exp_term = (-0.5 * (norm_freq - omega0).powi(2)).exp();
wavelet_fft[k] = Complex64::new(exp_term * scale.sqrt(), 0.0);
}
}
}
WaveletType::MexicanHat => {
for (k, &_freq) in freqs.iter().enumerate().take(n) {
let norm_freq = _freq * scale;
if norm_freq > 0.0 {
let exp_term = (-0.5 * norm_freq.powi(2)).exp();
wavelet_fft[k] =
Complex64::new(exp_term * norm_freq.powi(2) * scale.sqrt(), 0.0);
}
}
}
WaveletType::Paul => {
let m = 4;
for (k, &_freq) in freqs.iter().enumerate().take(n) {
let norm_freq = _freq * scale;
if norm_freq > 0.0 {
let h = (norm_freq > 0.0) as i32 as f64;
let exp_term = (-norm_freq).exp();
wavelet_fft[k] =
Complex64::new(h * scale.sqrt() * norm_freq.powi(m) * exp_term, 0.0);
}
}
}
WaveletType::DOG => {
let m = 2;
for (k, &_freq) in freqs.iter().enumerate().take(n) {
let norm_freq = _freq * scale;
if norm_freq > 0.0 {
let exp_term = (-0.5 * norm_freq.powi(2)).exp();
let real_part = exp_term * norm_freq.powi(m) * scale.sqrt();
let complex_part = Complex64::i().powi(m);
wavelet_fft[k] = Complex64::new(real_part, 0.0) * complex_part;
}
}
}
}
Ok(wavelet_fft)
}
#[allow(dead_code)]
fn compute_reassigned_spectrogram(
signal: &[f64],
config: &TFConfig,
sample_rate: Option<f64>,
) -> FFTResult<TFResult> {
let stft_result = compute_stft(signal, config, sample_rate)?;
let num_frames = stft_result.times.len();
let num_bins = stft_result.frequencies.len();
let mut reassigned = Array2::zeros((num_frames, num_bins));
let max_frames = num_frames.min(config.max_size / num_bins);
let max_bins = num_bins.min(config.max_size / 2);
for i in 1..max_frames - 1 {
for j in 1..max_bins - 1 {
let mag = stft_result.coefficients[[i, j]].norm();
let neighbors = [
stft_result.coefficients[[i - 1, j - 1]].norm(),
stft_result.coefficients[[i - 1, j]].norm(),
stft_result.coefficients[[i - 1, j + 1]].norm(),
stft_result.coefficients[[i, j - 1]].norm(),
stft_result.coefficients[[i, j + 1]].norm(),
stft_result.coefficients[[i + 1, j - 1]].norm(),
stft_result.coefficients[[i + 1, j]].norm(),
stft_result.coefficients[[i + 1, j + 1]].norm(),
];
let max_idx = neighbors
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
.map(|(idx, _)| idx)
.unwrap_or(0);
match max_idx {
0 => reassigned[[i - 1, j - 1]] += mag,
1 => reassigned[[i - 1, j]] += mag,
2 => reassigned[[i - 1, j + 1]] += mag,
3 => reassigned[[i, j - 1]] += mag,
4 => reassigned[[i, j + 1]] += mag,
5 => reassigned[[i + 1, j - 1]] += mag,
6 => reassigned[[i + 1, j]] += mag,
7 => reassigned[[i + 1, j + 1]] += mag,
_ => reassigned[[i, j]] += mag,
}
}
}
let mut coefficients = Array2::zeros((num_frames, num_bins));
for i in 0..max_frames {
for j in 0..max_bins {
let phase = stft_result.coefficients[[i, j]].arg();
coefficients[[i, j]] = Complex64::from_polar(reassigned[[i, j]], phase);
}
}
let mut metadata = HashMap::new();
metadata.insert("window_size".to_string(), config.window_size as f64);
metadata.insert("hop_size".to_string(), config.hop_size as f64);
metadata.insert("reassigned".to_string(), 1.0);
Ok(TFResult {
times: stft_result.times,
frequencies: stft_result.frequencies,
coefficients,
sample_rate,
transform_type: TFTransform::ReassignedSpectrogram,
metadata,
})
}
#[allow(dead_code)]
fn compute_synchrosqueezed_wt(
signal: &[f64],
config: &TFConfig,
sample_rate: Option<f64>,
) -> FFTResult<TFResult> {
let cwt_result = compute_cwt(signal, config, sample_rate)?;
let num_scales = cwt_result.frequencies.len();
let num_times = cwt_result.times.len();
let mut synchro = Array2::zeros((num_scales, num_times));
let max_scales = num_scales.min(3); let max_times = num_times.min(config.max_size);
for i in 1..max_scales - 1 {
for j in 1..max_times - 1 {
let mag = cwt_result.coefficients[[i, j]].norm();
let phase_diff = (cwt_result.coefficients[[i, j + 1]].arg()
- cwt_result.coefficients[[i, j - 1]].arg())
/ 2.0;
let inst_freq = phase_diff / (2.0 * std::f64::consts::PI) * sample_rate.unwrap_or(1.0);
let closest_bin = cwt_result
.frequencies
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
(*a - inst_freq)
.abs()
.partial_cmp(&(*b - inst_freq).abs())
.expect("Operation failed")
})
.map(|(idx, _)| idx)
.unwrap_or(i);
synchro[[closest_bin, j]] += mag;
}
}
let mut coefficients = Array2::zeros((num_scales, num_times));
for i in 0..max_scales {
for j in 0..max_times {
let phase = cwt_result.coefficients[[i, j]].arg();
coefficients[[i, j]] = Complex64::from_polar(synchro[[i, j]], phase);
}
}
let mut metadata = HashMap::new();
metadata.insert("synchrosqueezed".to_string(), 1.0);
metadata.insert("min_freq".to_string(), config.frequency_range.0);
metadata.insert("max_freq".to_string(), config.frequency_range.1);
metadata.insert("num_freqs".to_string(), config.frequency_bins as f64);
Ok(TFResult {
times: cwt_result.times,
frequencies: cwt_result.frequencies,
coefficients,
sample_rate,
transform_type: TFTransform::SynchrosqueezedWT,
metadata,
})
}
#[allow(dead_code)]
pub fn spectrogram<T>(
signal: &[T],
config: &TFConfig,
sample_rate: Option<f64>,
) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
where
T: NumCast + Copy + Debug,
{
let stft_result = compute_stft(signal, config, sample_rate)?;
let power = stft_result.coefficients.mapv(|c| c.norm_sqr());
Ok((stft_result.times, stft_result.frequencies, power))
}
#[allow(dead_code)]
pub fn scalogram<T>(
signal: &[T],
config: &TFConfig,
sample_rate: Option<f64>,
) -> FFTResult<(Vec<f64>, Vec<f64>, Array2<f64>)>
where
T: NumCast + Copy + Debug,
{
let cwt_result = compute_cwt(signal, config, sample_rate)?;
let power = cwt_result.coefficients.mapv(|c| c.norm_sqr());
Ok((cwt_result.times, cwt_result.frequencies, power))
}
#[allow(dead_code)]
pub fn extract_ridge(tf_result: &TFResult) -> Vec<(f64, f64)> {
let num_times = tf_result.times.len();
let num_freqs = tf_result.frequencies.len();
let max_times = num_times.min(500);
let mut ridge = Vec::with_capacity(max_times);
for j in 0..max_times {
let mut max_energy = 0.0;
let mut max_freq_idx = 0;
for i in 0..num_freqs {
let energy = tf_result.coefficients[[i, j]].norm_sqr();
if energy > max_energy {
max_energy = energy;
max_freq_idx = i;
}
}
ridge.push((tf_result.times[j], tf_result.frequencies[max_freq_idx]));
}
ridge
}
#[cfg(test)]
// NOTE(review): `feature = "never"` looks like a deliberately unsatisfiable gate, meaning this
// module is never compiled even under `cargo test` — confirm no such Cargo feature exists.
#[cfg(feature = "never")] mod tests {
    use super::*;
    /// STFT of a 1 s, 100 Hz sine sampled at 1 kHz: the middle frame's strongest bin must lie
    /// within 10 Hz of the tone, and the coefficient shape must be (frames, bins).
    #[test]
    fn test_stft() {
        let sample_rate = 1000.0;
        let duration = 1.0;
        let n = (sample_rate * duration) as usize;
        let freq = 100.0;
        let mut signal = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64 / sample_rate;
            signal.push((2.0 * std::f64::consts::PI * freq * t).sin());
        }
        let config = TFConfig {
            transform_type: TFTransform::STFT,
            window_size: 256,
            hop_size: 128,
            window_function: WindowFunction::Hamming,
            zero_padding: 1,
            max_size: 1024, ..Default::default()
        };
        let result = compute_stft(&signal, &config, Some(sample_rate)).expect("Operation failed");
        assert!(!result.times.is_empty());
        assert!(!result.frequencies.is_empty());
        assert_eq!(
            result.coefficients.dim(),
            (result.times.len(), result.frequencies.len())
        );
        // Locate the bin with the most energy in the middle frame.
        let mut peak_bin = 0;
        let mut max_energy = 0.0;
        let mid_frame = result.times.len() / 2;
        for (bin, _) in result.frequencies.iter().enumerate() {
            let energy = result.coefficients[[mid_frame, bin]].norm_sqr();
            if energy > max_energy {
                max_energy = energy;
                peak_bin = bin;
            }
        }
        let peak_freq = result.frequencies[peak_bin];
        assert!((peak_freq - freq).abs() < 10.0); }
    /// Morlet CWT of a 0.5 s, 100 Hz sine over 50-200 Hz: the strongest scale at the middle
    /// sample must be within 35% of the tone. Prints per-scale energies to stderr for debugging.
    #[test]
    #[ignore = "CWT implementation needs debugging - energies are computed as zero"]
    fn test_cwt() {
        let sample_rate = 1000.0;
        let duration = 0.5; let n = (sample_rate * duration) as usize;
        let freq = 100.0;
        let mut signal = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64 / sample_rate;
            signal.push((2.0 * std::f64::consts::PI * freq * t).sin());
        }
        let config = TFConfig {
            transform_type: TFTransform::CWT,
            wavelet_type: WaveletType::Morlet,
            frequency_range: (50.0, 200.0),
            frequency_bins: 32,
            max_size: 512, ..Default::default()
        };
        let result = compute_cwt(&signal, &config, Some(sample_rate)).expect("Operation failed");
        // compute_cwt truncates the signal to max_size and caps the number of frequencies.
        assert_eq!(result.times.len(), signal.len().min(config.max_size));
        assert!(
            result.frequencies.len() <= config.frequency_bins.min(config.max_size / 4),
            "Expected at most {} frequencies, got {}",
            config.frequency_bins.min(config.max_size / 4),
            result.frequencies.len()
        );
        let mut peak_scale = 0;
        let mut max_energy = 0.0;
        let mid_time = result.times.len() / 2;
        eprintln!(
            "Test CWT: Available frequencies: {:?}",
            &result.frequencies[..result.frequencies.len().min(16)]
        );
        let computed_freqs = result.coefficients.shape()[0];
        eprintln!(
            "Test CWT: Number of computed frequencies: {}",
            computed_freqs
        );
        // Scan the scales at the middle sample, logging the first 16 for diagnosis.
        for scale in 0..computed_freqs {
            let energy = result.coefficients[[scale, mid_time]].norm_sqr();
            if scale < 16 {
                eprintln!(
                    "  Freq[{}] = {:.1} Hz, Energy = {:.6}",
                    scale, result.frequencies[scale], energy
                );
            }
            if energy > max_energy {
                max_energy = energy;
                peak_scale = scale;
            }
        }
        let peak_freq = result.frequencies[peak_scale];
        eprintln!(
            "Test CWT: Expected freq: {}, Found peak freq: {}, Error: {:.2}%",
            freq,
            peak_freq,
            ((peak_freq - freq).abs() / freq * 100.0)
        );
        assert!(
            (peak_freq - freq).abs() / freq < 0.35,
            "Peak frequency {} is too far from expected {} (error: {:.2}%)",
            peak_freq,
            freq,
            ((peak_freq - freq).abs() / freq * 100.0)
        ); }
}