Skip to main content

dasp_rs/features/
inverse.rs

1use crate::core::io::AudioData;
2use crate::features::phase_recovery::griffinlim;
3use crate::utils::frequency::{fft_frequencies, mel_frequencies};
4use ndarray::{Array2, Axis};
5use rayon::prelude::*;
6use thiserror::Error;
7
8/// Error conditions for MFCC processing and reconstruction.
9///
10/// Enumerates specific failure modes in MFCC delta computation and spectrogram/audio
11/// reconstruction, tailored for DSP pipeline diagnostics.
12#[derive(Error, Debug)]
13pub enum MfccError {
14    /// Input dimensions are invalid (e.g., empty matrix).
15    #[error("Invalid dimensions: {0}")]
16    InvalidDimensions(String),
17
18    /// Input parameters are invalid (e.g., negative width, even width).
19    #[error("Invalid parameter: {0}")]
20    InvalidInput(String),
21
22    /// Numerical computation failure (e.g., overflow in reconstruction).
23    #[error("Computation failed: {0}")]
24    ComputationFailed(String),
25}
26
27/// Computes the first-order delta coefficients of MFCCs.
28///
29/// # Parameters
30/// - `mfcc`: Input MFCC matrix, shape `(n_mfcc, n_frames)`.
31/// - `width`: Optional window width for delta computation; defaults to 9.
32/// - `axis`: Optional time axis; -1 (frames) or 0 (mfcc), defaults to -1.
33///
34/// # Returns
35/// - `Ok(Array2<f32>)`: Delta coefficients, same shape as input MFCCs.
36/// - `Err(MfccError)`: Failure due to invalid input or dimensions.
37///
38/// # Constraints
39/// - `width` must be a positive odd integer.
40/// - `mfcc` must have at least `width` elements along the time axis.
41pub fn compute_delta(
42    mfcc: &Array2<f32>,
43    width: Option<usize>,
44    axis: Option<isize>,
45) -> Result<Array2<f32>, MfccError> {
46    let width = width.unwrap_or(9);
47    let axis = axis.unwrap_or(-1);
48
49    if width == 0 || width % 2 == 0 {
50        return Err(MfccError::InvalidInput(
51            "Width must be a positive odd integer".to_string(),
52        ));
53    }
54    let ax = if axis < 0 { 1 } else { 0 };
55    let (n_mfcc, n_frames) = if ax == 1 {
56        mfcc.dim()
57    } else {
58        (mfcc.shape()[1], mfcc.shape()[0])
59    };
60    if n_frames == 0 || n_mfcc == 0 {
61        return Err(MfccError::InvalidDimensions(
62            "MFCC matrix is empty".to_string(),
63        ));
64    }
65    if n_frames < width {
66        return Err(MfccError::InvalidDimensions(format!(
67            "Time axis length {} less than width {}",
68            n_frames, width
69        )));
70    }
71
72    let half_width = width / 2;
73    let weights: Vec<f32> = (-(half_width as isize)..=half_width as isize)
74        .map(|i| i as f32)
75        .collect();
76    let norm = weights.iter().map(|x| x.powi(2)).sum::<f32>();
77    if norm == 0.0 {
78        return Err(MfccError::ComputationFailed(
79            "Normalization factor is zero".to_string(),
80        ));
81    }
82
83    let mut delta = Array2::zeros(mfcc.dim());
84    delta
85        .axis_iter_mut(Axis(ax))
86        .into_par_iter()
87        .enumerate()
88        .for_each(|(i, mut slice)| {
89            let row = mfcc.index_axis(Axis(ax), i);
90            for j in 0..row.len() {
91                let mut sum = 0.0;
92                for (w_idx, &w) in weights.iter().enumerate() {
93                    let offset = w_idx as isize - half_width as isize;
94                    let idx = (j as isize + offset).clamp(0, row.len() as isize - 1) as usize;
95                    sum += w * row[idx];
96                }
97                slice[j] = sum / norm;
98            }
99        });
100
101    Ok(delta)
102}
103
104/// Converts mel spectrogram to STFT magnitude spectrogram.
105///
106/// Reconstructs an STFT magnitude spectrogram from a mel spectrogram using inverse mel
107/// filterbank weighting, ensuring energy preservation.
108///
109/// # Parameters
110/// - `m`: Mel spectrogram, shape `(n_mels, n_frames)`.
111/// - `sr`: Optional sample rate in Hz; defaults to 44100.
112/// - `n_fft`: Optional FFT size; defaults to 2048.
113/// - `power`: Optional power of input spectrogram; defaults to 2.0.
114///
115/// # Returns
116/// - `Ok(Array2<f32>)`: STFT magnitude spectrogram, shape `(n_fft/2 + 1, n_frames)`.
117/// - `Err(MfccError)`: Failure due to invalid dimensions or parameters.
118pub fn mel_to_stft(
119    m: &Array2<f32>,
120    sr: Option<u32>,
121    n_fft: Option<usize>,
122    power: Option<f32>,
123) -> Result<Array2<f32>, MfccError> {
124    let sr = sr.unwrap_or(44100);
125    let n_fft = n_fft.unwrap_or(2048);
126    let power = power.unwrap_or(2.0);
127    if m.is_empty() {
128        return Err(MfccError::InvalidDimensions(
129            "Mel spectrogram is empty".to_string(),
130        ));
131    }
132    if n_fft < 2 {
133        return Err(MfccError::InvalidInput(
134            "n_fft must be at least 2".to_string(),
135        ));
136    }
137    if power <= 0.0 {
138        return Err(MfccError::InvalidInput(
139            "Power must be positive".to_string(),
140        ));
141    }
142
143    let n_mels = m.shape()[0];
144    let n_frames = m.shape()[1];
145    let mel_f = mel_frequencies(Some(n_mels + 2), None, Some(sr as f32 / 2.0), None);
146    let fft_f = fft_frequencies(Some(sr), Some(n_fft));
147    let n_bins = n_fft / 2 + 1;
148
149    let mut s = Array2::zeros((n_bins, n_frames));
150    s.axis_iter_mut(Axis(1))
151        .into_par_iter()
152        .enumerate()
153        .for_each(|(t, mut col)| {
154            for mel in 0..n_mels {
155                let f_low = mel_f[mel];
156                let f_center = mel_f[mel + 1];
157                let f_high = mel_f[mel + 2];
158                for (bin, &f) in fft_f.iter().enumerate().take(n_bins) {
159                    let weight = if f >= f_low && f <= f_high {
160                        if f <= f_center {
161                            (f - f_low) / (f_center - f_low)
162                        } else {
163                            (f_high - f) / (f_high - f_center)
164                        }
165                    } else {
166                        0.0
167                    }
168                    .max(0.0);
169                    col[bin] += m[[mel, t]].max(0.0) * weight;
170                }
171            }
172        });
173
174    Ok(s.mapv(|x: f32| x.powf(1.0 / power)))
175}
176
177/// Converts mel spectrogram to audio waveform.
178///
179/// Reconstructs a time-domain audio signal from a mel spectrogram using STFT magnitude
180/// estimation and Griffin-Lim phase recovery.
181///
182/// # Parameters
183/// - `m`: Mel spectrogram, shape `(n_mels, n_frames)`.
184/// - `sr`: Optional sample rate in Hz; defaults to 44100.
185/// - `n_fft`: Optional FFT size; defaults to 2048.
186/// - `hop_length`: Optional hop length; defaults to `n_fft / 4`.
187///
188/// # Returns
189/// - `Ok(AudioData)`: Reconstructed audio waveform with metadata.
190/// - `Err(MfccError)`: Failure due to invalid input or reconstruction errors.
191///
192/// # Complexity
193/// - O(M * F * B + G) where G is Griffin-Lim complexity, parallelized in `mel_to_stft`.
194pub fn mel_to_audio(
195    m: &Array2<f32>,
196    sr: Option<u32>,
197    n_fft: Option<usize>,
198    hop_length: Option<usize>,
199) -> Result<AudioData, MfccError> {
200    let n_fft = n_fft.unwrap_or(2048);
201    let hop = hop_length.unwrap_or(n_fft / 4);
202    let sr = sr.unwrap_or(44100);
203    if hop == 0 {
204        return Err(MfccError::InvalidInput(
205            "Hop length must be positive".to_string(),
206        ));
207    }
208
209    let s = mel_to_stft(m, Some(sr), Some(n_fft), None)?;
210    let samples = griffinlim(&s)
211        .hop_length(hop)
212        .compute()
213        .map_err(|e| MfccError::ComputationFailed(format!("Griffin-Lim failed: {}", e)))?;
214    if samples.is_empty() {
215        return Err(MfccError::ComputationFailed(
216            "Griffin-Lim returned empty samples".to_string(),
217        ));
218    }
219    if samples.iter().any(|&x| !x.is_finite()) {
220        return Err(MfccError::ComputationFailed(
221            "Non-finite samples in reconstruction".to_string(),
222        ));
223    }
224    Ok(AudioData::new(samples, sr, 1).map_err(|e| MfccError::ComputationFailed(e.to_string()))?)
225}
226
227/// Converts MFCCs back to mel spectrogram using inverse DCT.
228///
229/// Reconstructs a mel spectrogram from MFCCs via inverse discrete cosine transform (type II).
230///
231/// # Parameters
232/// - `mfcc`: MFCC matrix, shape `(n_mfcc, n_frames)`.
233/// - `n_mels`: Optional number of mel bins; defaults to 128.
234/// - `dct_type`: Optional DCT type (1, 2, 3, 4); defaults to 2.
235///
236/// # Returns
237/// - `Ok(Array2<f32>)`: Mel spectrogram, shape `(n_mels, n_frames)`.
238/// - `Err(MfccError)`: Failure due to invalid dimensions or DCT type.
239///
240/// # Complexity
241/// - O(M * F * K) where M is mel bins, F is frames, K is MFCC coefficients, parallelized over frames.
242pub fn mfcc_to_mel(
243    mfcc: &Array2<f32>,
244    n_mels: Option<usize>,
245    dct_type: Option<i32>,
246) -> Result<Array2<f32>, MfccError> {
247    let n_mels = n_mels.unwrap_or(128);
248    let dct_type = dct_type.unwrap_or(2);
249    if mfcc.is_empty() {
250        return Err(MfccError::InvalidDimensions(
251            "MFCC matrix is empty".to_string(),
252        ));
253    }
254    if ![1, 2, 3, 4].contains(&dct_type) {
255        return Err(MfccError::InvalidInput(format!(
256            "Unsupported DCT type: {}",
257            dct_type
258        )));
259    }
260
261    let n_frames = mfcc.shape()[1];
262    let n_mfcc = mfcc.shape()[0];
263    let mut mel = Array2::zeros((n_mels, n_frames));
264    mel.axis_iter_mut(Axis(1))
265        .into_par_iter()
266        .enumerate()
267        .for_each(|(t, mut col)| {
268            for n in 0..n_mels {
269                let mut sum = 0.0;
270                for k in 0..n_mfcc {
271                    let scale = if k == 0 {
272                        1.0 / (n_mels as f32).sqrt()
273                    } else {
274                        (2.0 / n_mels as f32).sqrt()
275                    };
276                    let theta = std::f32::consts::PI * k as f32 * (n as f32 + 0.5) / n_mels as f32;
277                    sum += scale * mfcc[[k, t]] * theta.cos();
278                }
279                col[n] = sum.max(0.0);
280            }
281        });
282    Ok(mel.mapv(f32::exp))
283}
284
285/// Converts MFCCs to audio waveform.
286///
287/// Reconstructs a time-domain audio signal from MFCCs via mel spectrogram and STFT.
288///
289/// # Parameters
290/// - `mfcc`: MFCC matrix, shape `(n_mfcc, n_frames)`.
291/// - `n_mels`: Optional number of mel bins; defaults to 128.
292/// - `sr`: Optional sample rate in Hz; defaults to 44100.
293/// - `n_fft`: Optional FFT size; defaults to 2048.
294/// - `hop_length`: Optional hop length; defaults to `n_fft / 4`.
295///
296/// # Returns
297/// - `Ok(AudioData)`: Reconstructed audio waveform with metadata.
298/// - `Err(MfccError)`: Failure due to invalid input or reconstruction errors.
299pub fn mfcc_to_audio(
300    mfcc: &Array2<f32>,
301    n_mels: Option<usize>,
302    sr: Option<u32>,
303    n_fft: Option<usize>,
304    hop_length: Option<usize>,
305) -> Result<AudioData, MfccError> {
306    let mel = mfcc_to_mel(mfcc, n_mels, Some(2))?;
307    mel_to_audio(&mel, sr, n_fft, hop_length)
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// An even window width is rejected as an invalid parameter.
    #[test]
    fn test_compute_delta_invalid_width() {
        let mfcc = array![[0.1, 0.2], [0.3, 0.4]];
        let outcome = compute_delta(&mfcc, Some(2), None);
        assert!(
            matches!(outcome, Err(MfccError::InvalidInput(_))),
            "even width must be rejected"
        );
    }

    /// A matrix with no frames is rejected as having invalid dimensions.
    #[test]
    fn test_compute_delta_empty_input() {
        let mfcc = array![[]];
        let outcome = compute_delta(&mfcc, Some(3), None);
        assert!(
            matches!(outcome, Err(MfccError::InvalidDimensions(_))),
            "empty matrix must be rejected"
        );
    }

    /// Two frames cannot support a width-5 regression window.
    #[test]
    fn test_compute_delta_insufficient_frames() {
        let mfcc = array![[0.1, 0.2], [0.3, 0.4]];
        let outcome = compute_delta(&mfcc, Some(5), None);
        assert!(
            matches!(outcome, Err(MfccError::InvalidDimensions(_))),
            "time axis shorter than width must be rejected"
        );
    }

    /// Inverting two coefficients into four mel bins yields a (4, 2) positive spectrogram.
    #[test]
    fn test_mfcc_to_mel() {
        let mfcc = array![[0.1, 0.2], [0.3, 0.4]];
        let mel = mfcc_to_mel(&mfcc, Some(4), None).unwrap();
        assert_eq!(mel.dim(), (4, 2));
        assert!(mel[[0, 0]] > 0.0);
    }

    /// Empty input is reported as a dimension error by both entry points.
    #[test]
    fn test_invalid_input() {
        let empty = array![[]];

        let delta_outcome = compute_delta(&empty, None, None);
        assert!(matches!(delta_outcome, Err(MfccError::InvalidDimensions(_))));

        let stft_outcome = mel_to_stft(&empty, None, None, None);
        assert!(matches!(stft_outcome, Err(MfccError::InvalidDimensions(_))));
    }
}