1use crate::core::io::AudioData;
2use crate::features::phase_recovery::griffinlim;
3use crate::utils::frequency::{fft_frequencies, mel_frequencies};
4use ndarray::{Array2, Axis};
5use rayon::prelude::*;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
13pub enum MfccError {
14 #[error("Invalid dimensions: {0}")]
16 InvalidDimensions(String),
17
18 #[error("Invalid parameter: {0}")]
20 InvalidInput(String),
21
22 #[error("Computation failed: {0}")]
24 ComputationFailed(String),
25}
26
27pub fn compute_delta(
42 mfcc: &Array2<f32>,
43 width: Option<usize>,
44 axis: Option<isize>,
45) -> Result<Array2<f32>, MfccError> {
46 let width = width.unwrap_or(9);
47 let axis = axis.unwrap_or(-1);
48
49 if width == 0 || width % 2 == 0 {
50 return Err(MfccError::InvalidInput(
51 "Width must be a positive odd integer".to_string(),
52 ));
53 }
54 let ax = if axis < 0 { 1 } else { 0 };
55 let (n_mfcc, n_frames) = if ax == 1 {
56 mfcc.dim()
57 } else {
58 (mfcc.shape()[1], mfcc.shape()[0])
59 };
60 if n_frames == 0 || n_mfcc == 0 {
61 return Err(MfccError::InvalidDimensions(
62 "MFCC matrix is empty".to_string(),
63 ));
64 }
65 if n_frames < width {
66 return Err(MfccError::InvalidDimensions(format!(
67 "Time axis length {} less than width {}",
68 n_frames, width
69 )));
70 }
71
72 let half_width = width / 2;
73 let weights: Vec<f32> = (-(half_width as isize)..=half_width as isize)
74 .map(|i| i as f32)
75 .collect();
76 let norm = weights.iter().map(|x| x.powi(2)).sum::<f32>();
77 if norm == 0.0 {
78 return Err(MfccError::ComputationFailed(
79 "Normalization factor is zero".to_string(),
80 ));
81 }
82
83 let mut delta = Array2::zeros(mfcc.dim());
84 delta
85 .axis_iter_mut(Axis(ax))
86 .into_par_iter()
87 .enumerate()
88 .for_each(|(i, mut slice)| {
89 let row = mfcc.index_axis(Axis(ax), i);
90 for j in 0..row.len() {
91 let mut sum = 0.0;
92 for (w_idx, &w) in weights.iter().enumerate() {
93 let offset = w_idx as isize - half_width as isize;
94 let idx = (j as isize + offset).clamp(0, row.len() as isize - 1) as usize;
95 sum += w * row[idx];
96 }
97 slice[j] = sum / norm;
98 }
99 });
100
101 Ok(delta)
102}
103
104pub fn mel_to_stft(
119 m: &Array2<f32>,
120 sr: Option<u32>,
121 n_fft: Option<usize>,
122 power: Option<f32>,
123) -> Result<Array2<f32>, MfccError> {
124 let sr = sr.unwrap_or(44100);
125 let n_fft = n_fft.unwrap_or(2048);
126 let power = power.unwrap_or(2.0);
127 if m.is_empty() {
128 return Err(MfccError::InvalidDimensions(
129 "Mel spectrogram is empty".to_string(),
130 ));
131 }
132 if n_fft < 2 {
133 return Err(MfccError::InvalidInput(
134 "n_fft must be at least 2".to_string(),
135 ));
136 }
137 if power <= 0.0 {
138 return Err(MfccError::InvalidInput(
139 "Power must be positive".to_string(),
140 ));
141 }
142
143 let n_mels = m.shape()[0];
144 let n_frames = m.shape()[1];
145 let mel_f = mel_frequencies(Some(n_mels + 2), None, Some(sr as f32 / 2.0), None);
146 let fft_f = fft_frequencies(Some(sr), Some(n_fft));
147 let n_bins = n_fft / 2 + 1;
148
149 let mut s = Array2::zeros((n_bins, n_frames));
150 s.axis_iter_mut(Axis(1))
151 .into_par_iter()
152 .enumerate()
153 .for_each(|(t, mut col)| {
154 for mel in 0..n_mels {
155 let f_low = mel_f[mel];
156 let f_center = mel_f[mel + 1];
157 let f_high = mel_f[mel + 2];
158 for (bin, &f) in fft_f.iter().enumerate().take(n_bins) {
159 let weight = if f >= f_low && f <= f_high {
160 if f <= f_center {
161 (f - f_low) / (f_center - f_low)
162 } else {
163 (f_high - f) / (f_high - f_center)
164 }
165 } else {
166 0.0
167 }
168 .max(0.0);
169 col[bin] += m[[mel, t]].max(0.0) * weight;
170 }
171 }
172 });
173
174 Ok(s.mapv(|x: f32| x.powf(1.0 / power)))
175}
176
177pub fn mel_to_audio(
195 m: &Array2<f32>,
196 sr: Option<u32>,
197 n_fft: Option<usize>,
198 hop_length: Option<usize>,
199) -> Result<AudioData, MfccError> {
200 let n_fft = n_fft.unwrap_or(2048);
201 let hop = hop_length.unwrap_or(n_fft / 4);
202 let sr = sr.unwrap_or(44100);
203 if hop == 0 {
204 return Err(MfccError::InvalidInput(
205 "Hop length must be positive".to_string(),
206 ));
207 }
208
209 let s = mel_to_stft(m, Some(sr), Some(n_fft), None)?;
210 let samples = griffinlim(&s)
211 .hop_length(hop)
212 .compute()
213 .map_err(|e| MfccError::ComputationFailed(format!("Griffin-Lim failed: {}", e)))?;
214 if samples.is_empty() {
215 return Err(MfccError::ComputationFailed(
216 "Griffin-Lim returned empty samples".to_string(),
217 ));
218 }
219 if samples.iter().any(|&x| !x.is_finite()) {
220 return Err(MfccError::ComputationFailed(
221 "Non-finite samples in reconstruction".to_string(),
222 ));
223 }
224 Ok(AudioData::new(samples, sr, 1).map_err(|e| MfccError::ComputationFailed(e.to_string()))?)
225}
226
227pub fn mfcc_to_mel(
243 mfcc: &Array2<f32>,
244 n_mels: Option<usize>,
245 dct_type: Option<i32>,
246) -> Result<Array2<f32>, MfccError> {
247 let n_mels = n_mels.unwrap_or(128);
248 let dct_type = dct_type.unwrap_or(2);
249 if mfcc.is_empty() {
250 return Err(MfccError::InvalidDimensions(
251 "MFCC matrix is empty".to_string(),
252 ));
253 }
254 if ![1, 2, 3, 4].contains(&dct_type) {
255 return Err(MfccError::InvalidInput(format!(
256 "Unsupported DCT type: {}",
257 dct_type
258 )));
259 }
260
261 let n_frames = mfcc.shape()[1];
262 let n_mfcc = mfcc.shape()[0];
263 let mut mel = Array2::zeros((n_mels, n_frames));
264 mel.axis_iter_mut(Axis(1))
265 .into_par_iter()
266 .enumerate()
267 .for_each(|(t, mut col)| {
268 for n in 0..n_mels {
269 let mut sum = 0.0;
270 for k in 0..n_mfcc {
271 let scale = if k == 0 {
272 1.0 / (n_mels as f32).sqrt()
273 } else {
274 (2.0 / n_mels as f32).sqrt()
275 };
276 let theta = std::f32::consts::PI * k as f32 * (n as f32 + 0.5) / n_mels as f32;
277 sum += scale * mfcc[[k, t]] * theta.cos();
278 }
279 col[n] = sum.max(0.0);
280 }
281 });
282 Ok(mel.mapv(f32::exp))
283}
284
285pub fn mfcc_to_audio(
300 mfcc: &Array2<f32>,
301 n_mels: Option<usize>,
302 sr: Option<u32>,
303 n_fft: Option<usize>,
304 hop_length: Option<usize>,
305) -> Result<AudioData, MfccError> {
306 let mel = mfcc_to_mel(mfcc, n_mels, Some(2))?;
307 mel_to_audio(&mel, sr, n_fft, hop_length)
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// An even window width must be rejected as an invalid parameter.
    #[test]
    fn test_compute_delta_invalid_width() {
        let features = array![[0.1, 0.2], [0.3, 0.4]];
        assert!(matches!(
            compute_delta(&features, Some(2), None),
            Err(MfccError::InvalidInput(_))
        ));
    }

    /// A matrix without any frames is a dimension error.
    #[test]
    fn test_compute_delta_empty_input() {
        let no_frames = array![[]];
        assert!(matches!(
            compute_delta(&no_frames, Some(3), None),
            Err(MfccError::InvalidDimensions(_))
        ));
    }

    /// The window may not be longer than the number of frames.
    #[test]
    fn test_compute_delta_insufficient_frames() {
        let features = array![[0.1, 0.2], [0.3, 0.4]];
        assert!(matches!(
            compute_delta(&features, Some(5), None),
            Err(MfccError::InvalidDimensions(_))
        ));
    }

    /// Inversion yields an `(n_mels, n_frames)` matrix of positive values.
    #[test]
    fn test_mfcc_to_mel() {
        let coeffs = array![[0.1, 0.2], [0.3, 0.4]];
        let mel_spec = mfcc_to_mel(&coeffs, Some(4), None).unwrap();
        assert_eq!(mel_spec.shape(), &[4, 2]);
        assert!(mel_spec[[0, 0]] > 0.0);
    }

    /// Empty matrices are rejected by both the delta and the STFT routines.
    #[test]
    fn test_invalid_input() {
        let empty = array![[]];

        let delta_result = compute_delta(&empty, None, None);
        assert!(matches!(
            delta_result,
            Err(MfccError::InvalidDimensions(_))
        ));

        let stft_result = mel_to_stft(&empty, None, None, None);
        assert!(matches!(
            stft_result,
            Err(MfccError::InvalidDimensions(_))
        ));
    }
}