// dasp_rs/signal_processing/time_frequency.rs

1use rustfft::FftPlanner;
2use num_complex::Complex;
3use ndarray::{Array1, Array2, s};
4use crate::{utils::frequency::fft_frequencies, AudioData, AudioError};
5use std::f32::consts::{PI, SQRT_2};
6
7/// Computes the Short-Time Fourier Transform (STFT) of a signal.
8///
9/// # Arguments
10/// * `y` - Input signal as a slice of `f32`
11/// * `n_fft` - Optional FFT window size (defaults to 2048)
12/// * `hop_length` - Optional hop length in samples (defaults to n_fft/4, minimum 1)
13/// * `win_length` - Optional window length in samples (defaults to n_fft)
14///
15/// # Returns
16/// Returns a `Result` containing an `Array2<Complex<f32>>` representing the STFT spectrogram,
17/// with shape `(n_fft/2 + 1, n_frames)`, or an `AudioError` if array shaping fails.
18///
19/// # Examples
20/// ```
21/// let signal = vec![1.0, 2.0, 3.0, 4.0];
22/// let spectrogram = stft(&signal, None, None, None).unwrap();
23/// ```
24pub fn stft(
25    y: &[f32],
26    n_fft: Option<usize>,
27    hop_length: Option<usize>,
28    win_length: Option<usize>,
29) -> Result<Array2<Complex<f32>>, AudioError> {
30    let n = n_fft.unwrap_or(2048);
31    let hop = hop_length.unwrap_or(n / 4).max(1);
32    let win = win_length.unwrap_or(n);
33    let mut planner = FftPlanner::new();
34    let fft = planner.plan_fft_forward(n);
35    let mut buffer = vec![Complex::new(0.0, 0.0); n];
36    let mut spectrogram = Vec::new();
37
38    if y.len() < n {
39        let mut padded = vec![0.0; n];
40        padded[..y.len()].copy_from_slice(y);
41        buffer[..n].copy_from_slice(&padded.iter().map(|&x| Complex::new(x * hamming(0, win), 0.0)).collect::<Vec<_>>());
42        fft.process(&mut buffer);
43        spectrogram.push(buffer.clone());
44    } else {
45        for i in (0..y.len()).step_by(hop) {
46            let end = std::cmp::min(i + n, y.len());
47            buffer.fill(Complex::new(0.0, 0.0));
48            for (j, &sample) in y[i..end].iter().enumerate() {
49                buffer[j] = Complex::new(sample * hamming(j, win), 0.0);
50            }
51            fft.process(&mut buffer);
52            spectrogram.push(buffer.clone());
53        }
54    }
55
56    let n_frames = spectrogram.len();
57    Ok(Array2::from_shape_vec((n / 2 + 1, n_frames), spectrogram.into_iter().flat_map(|v| v.into_iter().take(n / 2 + 1)).collect())?)
58}
59
60/// Computes the inverse Short-Time Fourier Transform (iSTFT) to reconstruct a signal.
61///
62/// # Arguments
63/// * `stft_matrix` - STFT spectrogram as an `Array2<Complex<f32>>`
64/// * `hop_length` - Optional hop length in samples (defaults to n_fft/4, minimum 1)
65/// * `win_length` - Optional window length in samples (defaults to n_fft)
66/// * `length` - Optional output signal length in samples (defaults to maximum possible length)
67///
68/// # Returns
69/// Returns a `Vec<f32>` containing the reconstructed time-domain signal.
70///
71/// # Examples
72/// ```
73/// use ndarray::arr2;
74/// let stft_data = arr2(&[[Complex::new(1.0, 0.0)], [Complex::new(0.5, 0.0)]]);
75/// let signal = istft(&stft_data, None, None, None);
76/// ```
77pub fn istft(
78    stft_matrix: &Array2<Complex<f32>>,
79    hop_length: Option<usize>,
80    win_length: Option<usize>,
81    length: Option<usize>,
82) -> Vec<f32> {
83    let n_fft = (stft_matrix.shape()[0] - 1) * 2;
84    let hop = hop_length.unwrap_or(n_fft / 4).max(1);
85    let win = win_length.unwrap_or(n_fft);
86    let n_frames = stft_matrix.shape()[1];
87    let mut planner = FftPlanner::new();
88    let fft = planner.plan_fft_inverse(n_fft);
89
90    let max_len = hop * (n_frames - 1) + n_fft;
91    let target_len = length.unwrap_or(max_len);
92    let mut signal = vec![0.0; max_len];
93    let mut window_sum = vec![0.0; max_len];
94    let window = hamming_vec(win);
95
96    for (frame_idx, frame) in stft_matrix.axis_iter(ndarray::Axis(1)).enumerate() {
97        let mut buffer: Vec<Complex<f32>> = frame.to_vec();
98        buffer.extend(vec![Complex::new(0.0, 0.0); n_fft - buffer.len()]);
99        fft.process(&mut buffer);
100        let start = frame_idx * hop;
101        for (i, &val) in buffer.iter().enumerate().take(win) {
102            if start + i < signal.len() {
103                signal[start + i] += val.re * window[i];
104                window_sum[start + i] += window[i];
105            }
106        }
107    }
108
109    for (i, &sum) in window_sum.iter().enumerate() {
110        if sum > 1e-6 {
111            signal[i] /= sum;
112        }
113    }
114
115    signal.resize(target_len, 0.0);
116    signal
117}
118
/// Computes the Hamming window value at a given sample index.
///
/// Evaluates `0.54 - 0.46 * cos(2*pi*n / (win_length - 1))`. Windows of length
/// 0 or 1 have no `win_length - 1` to divide by, so the coefficient is defined
/// as 1.0 there instead of underflowing or producing NaN.
///
/// # Arguments
/// * `n` - Sample index
/// * `win_length` - Total window length
///
/// # Returns
/// Returns a `f32` representing the Hamming window coefficient.
///
/// # Examples
/// ```ignore
/// let value = hamming(0, 10);
/// assert!(value > 0.0 && value <= 1.0);
/// ```
fn hamming(n: usize, win_length: usize) -> f32 {
    if win_length <= 1 {
        return 1.0;
    }
    0.54 - 0.46 * (2.0 * std::f32::consts::PI * n as f32 / (win_length - 1) as f32).cos()
}
136
137/// Generates a Hamming window vector.
138///
139/// # Arguments
140/// * `win_length` - Length of the window
141///
142/// # Returns
143/// Returns a `Vec<f32>` containing the Hamming window coefficients.
144///
145/// # Examples
146/// ```
147/// let window = hamming_vec(5);
148/// assert_eq!(window.len(), 5);
149/// ```
150fn hamming_vec(win_length: usize) -> Vec<f32> {
151    (0..win_length).map(|n| hamming(n, win_length)).collect()
152}
153
154/// Separates magnitude and phase from a complex spectrogram.
155///
156/// # Arguments
157/// * `D` - Input spectrogram as an `Array2<Complex<f32>>`
158/// * `power` - Optional power to raise the magnitude (defaults to 1.0)
159///
160/// # Returns
161/// Returns a tuple `(magnitude, phase)` where:
162/// - `magnitude` is an `Array2<f32>` of magnitude values
163/// - `phase` is an `Array2<Complex<f32>>` of unit-magnitude phase values
164///
165/// # Examples
166/// ```
167/// use ndarray::arr2;
168/// let spectrogram = arr2(&[[Complex::new(3.0, 4.0)]]);
169/// let (mag, phase) = magphase(&spectrogram, None);
170/// assert_eq!(mag[[0, 0]], 5.0); // sqrt(3^2 + 4^2)
171/// ```
172pub fn magphase(d: &Array2<Complex<f32>>, power: Option<f32>) -> (Array2<f32>, Array2<Complex<f32>>) {
173    let power_val = power.unwrap_or(1.0);
174    let magnitude = d.mapv(|x| x.norm().powf(power_val));
175    let phase = d.mapv(|x| x / x.norm());
176    (magnitude, phase)
177}
178
179/// Computes a reassigned spectrogram for improved time-frequency resolution.
180///
181/// # Arguments
182/// * `y` - Input signal as a slice of `f32`
183/// * `sr` - Optional sample rate in Hz (defaults to 44100)
184/// * `n_fft` - Optional FFT window size (defaults to 2048)
185///
186/// # Returns
187/// Returns a `Result` containing an `Array2<f32>` representing the reassigned spectrogram,
188/// or an `AudioError` if computation fails.
189///
190/// # Errors
191/// * `AudioError::InsufficientData` - If signal length is less than `n_fft`.
192/// * `AudioError::ComputationFailed` - If STFT computation fails.
193///
194/// # Examples
195/// ```
196/// let signal = vec![1.0; 4096];
197/// let reassigned = reassigned_spectrogram(&signal, None, None).unwrap();
198/// ```
199pub fn reassigned_spectrogram(
200    y: &[f32],
201    sr: Option<u32>,
202    n_fft: Option<usize>,
203) -> Result<Array2<f32>, AudioError> {
204    let sr = sr.unwrap_or(44100);
205    let n_fft = n_fft.unwrap_or(2048);
206    let hop_length = n_fft / 4;
207
208    if y.len() < n_fft {
209        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
210    }
211
212    let s = stft(y, Some(n_fft), Some(hop_length), None)
213        .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
214    let s_time = stft_with_derivative(y, Some(n_fft), Some(hop_length), true)?;
215    let s_freq = stft_with_derivative(y, Some(n_fft), Some(hop_length), false)?;
216
217    let mut reassigned = Array2::zeros(s.dim());
218    let freqs = fft_frequencies(Some(sr), Some(n_fft));
219    let times = Array1::linspace(0.0, (y.len() as f32 - 1.0) / sr as f32, s.shape()[1]);
220
221    for t in 0..s.shape()[1] {
222        for f in 0..s.shape()[0] {
223            let mag = s[[f, t]].norm();
224            if mag > 1e-6 {
225                let dphi_dt = s_time[[f, t]].im / mag;
226                let t_reassigned = times[t] - dphi_dt * hop_length as f32 / sr as f32;
227                let dphi_df = s_freq[[f, t]].im / mag;
228                let f_reassigned = freqs[f] + dphi_df * sr as f32 / n_fft as f32;
229
230                let t_idx = ((t_reassigned * sr as f32 / hop_length as f32).round() as usize).min(s.shape()[1] - 1);
231                let f_idx = freqs.iter().position(|&x| x >= f_reassigned).unwrap_or(f).min(s.shape()[0] - 1);
232                reassigned[[f_idx, t_idx]] += mag;
233            }
234        }
235    }
236
237    Ok(reassigned)
238}
239
240/// Computes the Constant-Q Transform (CQT) of a signal.
241///
242/// # Arguments
243/// * `y` - Input signal as a slice of `f32`
244/// * `sr` - Optional sample rate in Hz (defaults to 44100)
245/// * `hop_length` - Optional hop length in samples (defaults to 512)
246/// * `fmin` - Optional minimum frequency in Hz (defaults to 32.70, C1)
247/// * `n_bins` - Optional number of frequency bins (defaults to 84)
248///
249/// # Returns
250/// Returns a `Result` containing an `Array2<Complex<f32>>` representing the CQT spectrogram,
251/// or an `AudioError` if computation fails.
252///
253/// # Errors
254/// * `AudioError::InsufficientData` - If signal length is less than `hop_length`.
255/// * `AudioError::InvalidInput` - If `fmin` is not positive.
256/// * `AudioError::ComputationFailed` - If STFT computation fails.
257///
258/// # Examples
259/// ```
260/// let signal = vec![1.0; 1024];
261/// let cqt_result = cqt(&signal, None, None, None, None).unwrap();
262/// ```
263pub fn cqt(
264    signal: &AudioData,
265    hop_length: Option<usize>,
266    fmin: Option<f32>,
267    n_bins: Option<usize>,
268) -> Result<Array2<Complex<f32>>, AudioError> {
269    let sr = signal.sample_rate;
270    let y = &signal.samples;
271    let hop_length = hop_length.unwrap_or(512);
272    let fmin = fmin.unwrap_or(32.70);
273    let n_bins = n_bins.unwrap_or(84);
274    let bins_per_octave = 12;
275
276    if y.len() < hop_length {
277        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
278    }
279    if fmin <= 0.0 {
280        return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
281    }
282
283    let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
284    let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
285        .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
286    let n_frames = s_stft.shape()[1];
287    let mut s_cqt = Array2::zeros((n_bins, n_frames));
288
289    let mut planner = FftPlanner::new();
290    let fft = planner.plan_fft_forward(n_fft);
291    for k in 0..n_bins {
292        let fk = fmin * 2.0f32.powf(k as f32 / bins_per_octave as f32);
293        let n = (sr as f32 / fk).round() as usize;
294        let mut kernel = Array1::zeros(n_fft);
295        let window = hann_window(n);
296        for i in 0..n {
297            let phase = 2.0 * PI * fk * i as f32 / sr as f32;
298            kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
299        }
300        fft.process(&mut kernel.to_vec());
301
302        for t in 0..n_frames {
303            let stft_frame = s_stft.slice(s![.., t]);
304            s_cqt[[k, t]] = stft_frame.iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
305        }
306    }
307
308    Ok(s_cqt)
309}
310
311/// Computes the inverse Constant-Q Transform (iCQT) to reconstruct a signal.
312///
313/// # Arguments
314/// * `C` - CQT spectrogram as an `Array2<Complex<f32>>`
315/// * `sr` - Optional sample rate in Hz (defaults to 44100)
316/// * `hop_length` - Optional hop length in samples (defaults to 512)
317/// * `fmin` - Optional minimum frequency in Hz (defaults to 32.70, C1)
318///
319/// # Returns
320/// Returns a `Result` containing a `Vec<f32>` of the reconstructed signal,
321/// or an `AudioError` if computation fails.
322///
323/// # Errors
324/// * `AudioError::InvalidInput` - If `fmin` is not positive.
325///
326/// # Examples
327/// ```
328/// use ndarray::arr2;
329/// let cqt_data = arr2(&[[Complex::new(1.0, 0.0)]]);
330/// let signal = icqt(&cqt_data, None, None, None).unwrap();
331/// ```
332pub fn icqt(
333    c: &Array2<Complex<f32>>,
334    sr: Option<u32>,
335    hop_length: Option<usize>,
336    fmin: Option<f32>,
337) -> Result<Vec<f32>, AudioError> {
338    let sr = sr.unwrap_or(44100);
339    let hop_length = hop_length.unwrap_or(512);
340    let fmin = fmin.unwrap_or(32.70);
341    let n_bins = c.shape()[0];
342    let n_frames = c.shape()[1];
343    let bins_per_octave = 12;
344
345    if fmin <= 0.0 {
346        return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
347    }
348
349    let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
350    let n_samples = n_frames * hop_length;
351    let mut y = vec![0.0; n_samples];
352    let mut planner = FftPlanner::new();
353    let ifft = planner.plan_fft_inverse(n_fft);
354
355    for k in 0..n_bins {
356        let fk = fmin * 2.0f32.powf(k as f32 / bins_per_octave as f32);
357        let n = (sr as f32 / fk).round() as usize;
358        let window = hann_window(n);
359        let mut kernel = Array1::zeros(n_fft);
360        for i in 0..n {
361            let phase = 2.0 * PI * fk * i as f32 / sr as f32;
362            kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
363        }
364        ifft.process(&mut kernel.to_vec());
365
366        for t in 0..n_frames {
367            let mut frame = vec![Complex::new(c[[k, t]].re, c[[k, t]].im) * Complex::conj(&kernel[0]); n_fft];
368            ifft.process(&mut frame);
369            let start = t * hop_length;
370            for i in 0..n.min(n_samples - start) {
371                y[start + i] += frame[i].re * window[i];
372            }
373        }
374    }
375
376    let mut overlap = vec![0.0; n_samples];
377    for t in 0..n_frames {
378        let start = t * hop_length;
379        for i in 0..n_fft.min(n_samples - start) {
380            overlap[start + i] += hann_window(n_fft)[i].powi(2);
381        }
382    }
383    for i in 0..n_samples {
384        if overlap[i] > 1e-6 {
385            y[i] /= overlap[i];
386        }
387    }
388
389    Ok(y)
390}
391
392/// Computes a hybrid Constant-Q Transform (CQT) combining STFT and CQT properties.
393///
394/// # Arguments
395/// * `y` - Input signal as a slice of `f32`
396/// * `sr` - Optional sample rate in Hz (defaults to 44100)
397/// * `hop_length` - Optional hop length in samples (defaults to 512)
398/// * `fmin` - Optional minimum frequency in Hz (defaults to 32.70, C1)
399///
400/// # Returns
401/// Returns a `Result` containing an `Array2<Complex<f32>>` representing the hybrid CQT,
402/// or an `AudioError` if computation fails.
403///
404/// # Errors
405/// * `AudioError::InsufficientData` - If signal length is less than `n_fft`.
406/// * `AudioError::InvalidInput` - If `fmin` is not positive.
407/// * `AudioError::ComputationFailed` - If STFT computation fails.
408///
409/// # Examples
410/// ```
411/// let signal = vec![1.0; 1024];
412/// let hybrid = hybrid_cqt(&signal, None, None, None).unwrap();
413/// ```
414pub fn hybrid_cqt(
415    y: &[f32],
416    sr: Option<u32>,
417    hop_length: Option<usize>,
418    fmin: Option<f32>,
419) -> Result<Array2<Complex<f32>>, AudioError> {
420    let sr = sr.unwrap_or(44100);
421    let hop_length = hop_length.unwrap_or(512);
422    let fmin = fmin.unwrap_or(32.70);
423    let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
424    let n_bins = 84;
425
426    if y.len() < n_fft {
427        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
428    }
429    if fmin <= 0.0 {
430        return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
431    }
432
433    let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
434        .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
435    let mut s_hybrid = Array2::zeros((n_bins, s_stft.shape()[1]));
436    let mut planner = FftPlanner::new();
437    let fft = planner.plan_fft_forward(n_fft);
438
439    for k in 0..n_bins {
440        let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
441        let n = (sr as f32 / fk).round() as usize;
442        let mut kernel = Array1::zeros(n_fft);
443        let window = hann_window(n);
444        for i in 0..n {
445            let phase = 2.0 * PI * fk * i as f32 / sr as f32;
446            kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
447        }
448        fft.process(&mut kernel.to_vec());
449
450        for t in 0..s_stft.shape()[1] {
451            s_hybrid[[k, t]] = s_stft.slice(s![.., t]).iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
452        }
453    }
454
455    Ok(s_hybrid)
456}
457
458/// Computes a pseudo Constant-Q Transform (CQT) using STFT bin mapping.
459///
460/// # Arguments
461/// * `y` - Input signal as a slice of `f32`
462/// * `sr` - Optional sample rate in Hz (defaults to 44100)
463/// * `hop_length` - Optional hop length in samples (defaults to 512)
464/// * `fmin` - Optional minimum frequency in Hz (defaults to 32.70, C1)
465///
466/// # Returns
467/// Returns a `Result` containing an `Array2<Complex<f32>>` representing the pseudo CQT,
468/// or an `AudioError` if computation fails.
469///
470/// # Errors
471/// * `AudioError::InsufficientData` - If signal length is less than `n_fft`.
472/// * `AudioError::InvalidInput` - If `fmin` is not positive.
473/// * `AudioError::ComputationFailed` - If STFT computation fails.
474///
475/// # Examples
476/// ```
477/// let signal = vec![1.0; 1024];
478/// let pseudo = pseudo_cqt(&signal, None, None, None).unwrap();
479/// ```
480pub fn pseudo_cqt(
481    y: &[f32],
482    sr: Option<u32>,
483    hop_length: Option<usize>,
484    fmin: Option<f32>,
485) -> Result<Array2<Complex<f32>>, AudioError> {
486    let sr = sr.unwrap_or(44100);
487    let hop_length = hop_length.unwrap_or(512);
488    let fmin = fmin.unwrap_or(32.70);
489    let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
490    let n_bins = 84;
491
492    if y.len() < n_fft {
493        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
494    }
495    if fmin <= 0.0 {
496        return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
497    }
498
499    let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
500        .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
501    let mut s_pseudo = Array2::zeros((n_bins, s_stft.shape()[1]));
502    let freqs = fft_frequencies(Some(sr), Some(n_fft));
503
504    for t in 0..s_stft.shape()[1] {
505        for k in 0..n_bins {
506            let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
507            let idx = freqs.iter().position(|&f| f >= fk).unwrap_or(0);
508            s_pseudo[[k, t]] = s_stft[[idx.min(s_stft.shape()[0] - 1), t]];
509        }
510    }
511
512    Ok(s_pseudo)
513}
514
515/// Computes the Variable-Q Transform (VQT) of a signal.
516///
517/// # Arguments
518/// * `y` - Input signal as a slice of `f32`
519/// * `sr` - Optional sample rate in Hz (defaults to 44100)
520/// * `hop_length` - Optional hop length in samples (defaults to 512)
521/// * `fmin` - Optional minimum frequency in Hz (defaults to 32.70, C1)
522/// * `n_bins` - Optional number of frequency bins (defaults to 84)
523///
524/// # Returns
525/// Returns a `Result` containing an `Array2<Complex<f32>>` representing the VQT,
526/// or an `AudioError` if computation fails.
527///
528/// # Errors
529/// * `AudioError::InsufficientData` - If signal length is less than `hop_length`.
530/// * `AudioError::InvalidInput` - If `fmin` is not positive.
531/// * `AudioError::ComputationFailed` - If STFT computation fails.
532///
533/// # Examples
534/// ```
535/// let signal = vec![1.0; 1024];
536/// let vqt_result = vqt(&signal, None, None, None, None).unwrap();
537/// ```
538pub fn vqt(
539    y: &[f32],
540    sr: Option<u32>,
541    hop_length: Option<usize>,
542    fmin: Option<f32>,
543    n_bins: Option<usize>,
544) -> Result<Array2<Complex<f32>>, AudioError> {
545    let sr = sr.unwrap_or(44100);
546    let hop_length = hop_length.unwrap_or(512);
547    let fmin = fmin.unwrap_or(32.70);
548    let n_bins = n_bins.unwrap_or(84);
549    let gamma = 24.0;
550
551    if y.len() < hop_length {
552        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
553    }
554    if fmin <= 0.0 {
555        return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
556    }
557
558    let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
559    let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
560        .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
561    let mut s_vqt = Array2::zeros((n_bins, s_stft.shape()[1]));
562    let mut planner = FftPlanner::new();
563    let fft = planner.plan_fft_forward(n_fft);
564
565    for k in 0..n_bins {
566        let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
567        let q = gamma / (2.0f32.powf(1.0 / 12.0) - 1.0);
568        let n = (sr as f32 * q / fk).round() as usize;
569        let mut kernel = Array1::zeros(n_fft);
570        let window = hann_window(n);
571        for i in 0..n {
572            let phase = 2.0 * PI * fk * i as f32 / sr as f32;
573            kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
574        }
575        fft.process(&mut kernel.to_vec());
576
577        for t in 0..s_stft.shape()[1] {
578            s_vqt[[k, t]] = s_stft.slice(s![.., t]).iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
579        }
580    }
581
582    Ok(s_vqt)
583}
584
585/// Computes the Fourier Modulation Transform (FMT) of a signal.
586///
587/// # Arguments
588/// * `y` - Input signal as a slice of `f32`
589/// * `t_min` - Optional minimum time period in seconds (defaults to 0.005)
590/// * `n_fmt` - Optional number of modulation frequencies (defaults to 5)
591/// * `kind` - Optional transform kind ("cos" or others, defaults to "cos")
592/// * `beta` - Optional power for magnitude scaling (defaults to 2.0)
593///
594/// # Returns
595/// Returns a `Result` containing an `Array2<f32>` representing the FMT spectrogram,
596/// or an `AudioError` if computation fails.
597///
598/// # Errors
599/// * `AudioError::InsufficientData` - If signal length is less than `hop_length`.
600/// * `AudioError::InvalidInput` - If `t_min` is not positive.
601///
602/// # Examples
603/// ```
604/// let signal = vec![1.0; 1024];
605/// let fmt_result = fmt(&signal, None, None, None, None).unwrap();
606/// ```
607pub fn fmt(
608    y: &[f32],
609    t_min: Option<f32>,
610    n_fmt: Option<usize>,
611    kind: Option<&str>,
612    beta: Option<f32>,
613) -> Result<Array2<f32>, AudioError> {
614    let sr = 44100;
615    let t_min = t_min.unwrap_or(0.005);
616    let n_fmt = n_fmt.unwrap_or(5);
617    let _kind = kind.unwrap_or("cos");
618    let beta = beta.unwrap_or(2.0);
619    let hop_length = (sr as f32 * t_min).round() as usize;
620
621    if y.len() < hop_length {
622        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
623    }
624    if t_min <= 0.0 {
625        return Err(AudioError::InvalidInput("t_min must be positive".to_string()));
626    }
627
628    let n_frames = (y.len() - hop_length) / hop_length + 1;
629    let mut s = Array2::zeros((n_fmt, n_frames));
630    let window = hann_window(hop_length);
631
632    for t in 0..n_frames {
633        let start = t * hop_length;
634        let frame = &y[start..(start + hop_length).min(y.len())];
635        for k in 0..n_fmt {
636            let freq = (k + 1) as f32 / t_min;
637            let mut sum_re = 0.0;
638            let mut sum_im = 0.0;
639            for (i, &sample) in frame.iter().enumerate() {
640                let phase = 2.0 * PI * freq * i as f32 / sr as f32;
641                let w = window[i];
642                sum_re += sample * w * phase.cos();
643                sum_im += sample * w * phase.sin();
644            }
645            let mag = Complex::new(sum_re, sum_im).norm() / hop_length as f32;
646            s[[k, t]] = mag.powf(beta);
647        }
648    }
649
650    Ok(s)
651}
652
/// Generates a Hann window vector.
///
/// Coefficient `i` is `0.5 * (1 - cos(2*pi*i / (n - 1)))`. A window of length 1
/// has no `n - 1` to divide by, so it is defined as `[1.0]` rather than NaN; a
/// window of length 0 is empty.
///
/// # Arguments
/// * `n` - Length of the window
///
/// # Returns
/// Returns a `Vec<f32>` containing the Hann window coefficients.
///
/// # Examples
/// ```ignore
/// let window = hann_window(5);
/// assert_eq!(window.len(), 5);
/// ```
fn hann_window(n: usize) -> Vec<f32> {
    match n {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (n - 1) as f32;
            (0..n)
                .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / denom).cos()))
                .collect()
        }
    }
}
669
670/// Computes STFT with time or frequency derivative for reassignment.
671///
672/// # Arguments
673/// * `y` - Input signal as a slice of `f32`
674/// * `n_fft` - Optional FFT window size (defaults to 2048)
675/// * `hop_length` - Optional hop length in samples (defaults to n_fft/4)
676/// * `time_derivative` - If true, computes time derivative; if false, frequency derivative
677///
678/// # Returns
679/// Returns a `Result` containing an `Array2<Complex<f32>>` with derivative information,
680/// or an `AudioError` if computation fails.
681///
682/// # Examples
683/// ```
684/// let signal = vec![1.0; 2048];
685/// let deriv = stft_with_derivative(&signal, None, None, true).unwrap();
686/// ```
687fn stft_with_derivative(
688    y: &[f32],
689    n_fft: Option<usize>,
690    hop_length: Option<usize>,
691    time_derivative: bool,
692) -> Result<Array2<Complex<f32>>, AudioError> {
693    let n_fft = n_fft.unwrap_or(2048);
694    let hop_length = hop_length.unwrap_or(n_fft / 4);
695    let n_frames = (y.len() - n_fft) / hop_length + 1;
696    let mut planner = FftPlanner::new();
697    let fft = planner.plan_fft_forward(n_fft);
698    let mut s = Array2::zeros((n_fft / 2 + 1, n_frames));
699    let window = hann_window(n_fft);
700    let deriv_window = if time_derivative {
701        (0..n_fft).map(|i| i as f32 * window[i]).collect::<Vec<_>>()
702    } else {
703        (0..n_fft).map(|i| window[i] * (2.0 * PI * i as f32 / n_fft as f32).sin()).collect::<Vec<_>>()
704    };
705
706    for t in 0..n_frames {
707        let start = t * hop_length;
708        let frame = &y[start..(start + n_fft).min(y.len())];
709        let mut buffer = frame.iter().zip(deriv_window.iter()).map(|(&x, &w)| Complex::new(x * w, 0.0)).collect::<Vec<_>>();
710        buffer.resize(n_fft, Complex::new(0.0, 0.0));
711        fft.process(&mut buffer);
712        for f in 0..n_fft / 2 + 1 {
713            s[[f, t]] = buffer[f];
714        }
715    }
716    Ok(s)
717}
718
719/// Designs a Butterworth bandpass filter.
720///
721/// # Arguments
722/// * `lowcut` - Lower cutoff frequency in Hz
723/// * `highcut` - Upper cutoff frequency in Hz
724/// * `fs` - Sampling frequency in Hz
725/// * `order` - Optional filter order (defaults to 2)
726///
727/// # Returns
728/// Returns a `Result` containing a tuple `(b, a)` of numerator and denominator coefficients,
729/// or an `AudioError` if frequencies are invalid.
730///
731/// # Errors
732/// * `AudioError::InvalidInput` - If `lowcut` <= 0, `highcut` <= `lowcut`, or `highcut` >= `fs/2`.
733///
734/// # Examples
735/// ```
736/// let (b, a) = butterworth_bandpass(100.0, 1000.0, 44100.0, None).unwrap();
737/// ```
738fn butterworth_bandpass(lowcut: f32, highcut: f32, fs: f32, order: Option<usize>) -> Result<(Vec<f32>, Vec<f32>), AudioError> {
739    if lowcut <= 0.0 || highcut <= lowcut || highcut >= fs / 2.0 {
740        return Err(AudioError::InvalidInput(format!(
741            "Invalid frequencies: lowcut={} must be > 0, highcut={} must be > lowcut and < fs/2={}",
742            lowcut, highcut, fs / 2.0
743        )));
744    }
745
746    let order = order.unwrap_or(2);
747    let n = order as i32;
748
749    let w_low = 2.0 * fs * (lowcut * PI / fs).tan();
750    let w_high = 2.0 * fs * (highcut * PI / fs).tan();
751    let w0 = (w_high * w_low).sqrt();
752    let bw = w_high - w_low;
753
754    let mut poles = Vec::new();
755    for k in 0..n {
756        let theta = PI * (2.0 * k as f32 + 1.0 + n as f32) / (2.0 * n as f32);
757        let real = -bw / 2.0 * theta.sin();
758        let imag = w0 * theta.cos();
759        poles.push(Complex::new(real, imag));
760        poles.push(Complex::new(real, -imag));
761    }
762
763    let mut z_poles = Vec::new();
764    let fs2 = 2.0 * fs;
765    for p in poles {
766        let pz = (fs2 + p) / (fs2 - p);
767        z_poles.push(pz);
768    }
769
770    let mut b = vec![1.0];
771    let mut a = vec![1.0];
772    for p in z_poles.iter() {
773        b = convolve(&b, &[1.0, -p.re]);
774        a = convolve(&a, &[1.0, -p.re]);
775    }
776    for _ in 0..n {
777        b = convolve(&b, &[1.0, 0.0]);
778    }
779
780    let w_center = 2.0 * PI * (lowcut + highcut) / 2.0 / fs;
781    let gain = evaluate_filter(&b, &a, w_center).norm();
782    for b_k in b.iter_mut() {
783        *b_k /= gain;
784    }
785
786    Ok((b, a))
787}
788
/// Convolves two vectors (full linear convolution).
///
/// # Arguments
/// * `a` - First input vector
/// * `b` - Second input vector
///
/// # Returns
/// Returns a `Vec<f32>` of length `a.len() + b.len() - 1` containing the
/// convolution result, or an empty vector if either input is empty.
///
/// # Examples
/// ```ignore
/// let result = convolve(&[1.0, 2.0], &[3.0, 4.0]);
/// assert_eq!(result, vec![3.0, 10.0, 8.0]);
/// ```
fn convolve(a: &[f32], b: &[f32]) -> Vec<f32> {
    // With an empty operand the length formula below would underflow or
    // produce a vector of meaningless zeros.
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut result = vec![0.0; a.len() + b.len() - 1];
    for (i, &x) in a.iter().enumerate() {
        // `result[i..]` is at least `b.len()` long, so every product has a slot.
        for (slot, &y) in result[i..].iter_mut().zip(b) {
            *slot += x * y;
        }
    }
    result
}
812
813/// Evaluates a digital filter's frequency response at a given frequency.
814///
815/// # Arguments
816/// * `b` - Numerator coefficients
817/// * `a` - Denominator coefficients
818/// * `w` - Frequency in radians/sample
819///
820/// # Returns
821/// Returns a `Complex<f32>` representing the filter's response.
822///
823/// # Examples
824/// ```
825/// let response = evaluate_filter(&[1.0], &[1.0, -0.5], 0.1);
826/// ```
827fn evaluate_filter(b: &[f32], a: &[f32], w: f32) -> Complex<f32> {
828    let mut num = Complex::new(0.0, 0.0);
829    let mut den = Complex::new(0.0, 0.0);
830    for (k, &bk) in b.iter().enumerate() {
831        let phase = -w * k as f32;
832        num += Complex::new(bk * phase.cos(), bk * phase.sin());
833    }
834    for (k, &ak) in a.iter().enumerate() {
835        let phase = -w * k as f32;
836        den += Complex::new(ak * phase.cos(), ak * phase.sin());
837    }
838    num / den
839}
840
841/// Computes the Instantaneous Impulse Response Transform (IIRT) using bandpass filtering.
842///
843/// # Arguments
844/// * `y` - Input signal as a slice of `f32`
845/// * `sr` - Optional sample rate in Hz (defaults to 44100)
846/// * `win_length` - Optional window length in samples (defaults to 2048)
847/// * `hop_length` - Optional hop length in samples (defaults to win_length/4)
848///
849/// # Returns
850/// Returns a `Result` containing an `Array2<f32>` representing the IIRT spectrogram,
851/// or an `AudioError` if computation fails.
852///
853/// # Errors
854/// * `AudioError::InsufficientData` - If signal length is less than `win_length`.
855/// * `AudioError::InvalidInput` - If bandpass filter frequencies are invalid.
856///
857/// # Examples
858/// ```
859/// let signal = vec![1.0; 4096];
860/// let iirt_result = iirt(&signal, None, None, None).unwrap();
861/// ```
862pub fn iirt(
863    y: &[f32],
864    sr: Option<u32>,
865    win_length: Option<usize>,
866    hop_length: Option<usize>,
867) -> Result<Array2<f32>, AudioError> {
868    let sr = sr.unwrap_or(44100);
869    let win_length = win_length.unwrap_or(2048);
870    let hop_length = hop_length.unwrap_or(win_length / 4);
871    let n_bands = 12;
872
873    if y.len() < win_length {
874        return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), win_length)));
875    }
876
877    let n_frames = (y.len() - win_length) / hop_length + 1;
878    let mut s = Array2::zeros((n_bands, n_frames));
879    let fmin = 32.70;
880
881    for b in 0..n_bands {
882        let fc = fmin * 2.0f32.powf(b as f32);
883        let bw = fc / SQRT_2;
884        let (b_coeffs, a_coeffs) = butterworth_bandpass(fc - bw / 2.0, fc + bw / 2.0, sr as f32, Some(4))?;
885        
886        for t in 0..n_frames {
887            let start = t * hop_length;
888            let frame = &y[start..(start + win_length).min(y.len())];
889            let filtered = filter(frame, &b_coeffs, &a_coeffs);
890            s[[b, t]] = filtered.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt() / win_length as f32;
891        }
892    }
893
894    Ok(s)
895}
896
/// Applies an IIR filter to a signal (direct form I, zero initial state).
///
/// Implements
/// `a[0]*y[n] = sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]`
/// for coefficient lists of any length. Samples before the start of the signal
/// are treated as zero.
///
/// # Arguments
/// * `x` - Input signal as a slice of `f32`
/// * `b` - Numerator (feed-forward) coefficients
/// * `a` - Denominator (feedback) coefficients; `a[0]` normalises the output.
///   A missing or zero `a[0]` is treated as 1.
///
/// # Returns
/// Returns a `Vec<f32>` containing the filtered signal, the same length as `x`.
///
/// # Examples
/// ```ignore
/// let signal = vec![1.0, 2.0, 3.0];
/// let filtered = filter(&signal, &[1.0, 0.0, 0.0], &[1.0, -0.5, 0.0]);
/// ```
fn filter(x: &[f32], b: &[f32], a: &[f32]) -> Vec<f32> {
    let a0 = match a.first() {
        Some(&lead) if lead != 0.0 => lead,
        _ => 1.0,
    };
    let mut y = vec![0.0f32; x.len()];
    for n in 0..x.len() {
        let mut acc = 0.0f32;
        // Feed-forward taps: only k <= n refers to samples that exist.
        for (k, &bk) in b.iter().enumerate().take(n + 1) {
            acc += bk * x[n - k];
        }
        // Feedback taps start at k = 1 and likewise stop at k = n.
        for (k, &ak) in a.iter().enumerate().skip(1).take(n) {
            acc -= ak * y[n - k];
        }
        y[n] = acc / a0;
    }
    y
}