1use rustfft::FftPlanner;
2use num_complex::Complex;
3use ndarray::{Array1, Array2, s};
4use crate::{utils::frequency::fft_frequencies, AudioData, AudioError};
5use std::f32::consts::{PI, SQRT_2};
6
7pub fn stft(
25 y: &[f32],
26 n_fft: Option<usize>,
27 hop_length: Option<usize>,
28 win_length: Option<usize>,
29) -> Result<Array2<Complex<f32>>, AudioError> {
30 let n = n_fft.unwrap_or(2048);
31 let hop = hop_length.unwrap_or(n / 4).max(1);
32 let win = win_length.unwrap_or(n);
33 let mut planner = FftPlanner::new();
34 let fft = planner.plan_fft_forward(n);
35 let mut buffer = vec![Complex::new(0.0, 0.0); n];
36 let mut spectrogram = Vec::new();
37
38 if y.len() < n {
39 let mut padded = vec![0.0; n];
40 padded[..y.len()].copy_from_slice(y);
41 buffer[..n].copy_from_slice(&padded.iter().map(|&x| Complex::new(x * hamming(0, win), 0.0)).collect::<Vec<_>>());
42 fft.process(&mut buffer);
43 spectrogram.push(buffer.clone());
44 } else {
45 for i in (0..y.len()).step_by(hop) {
46 let end = std::cmp::min(i + n, y.len());
47 buffer.fill(Complex::new(0.0, 0.0));
48 for (j, &sample) in y[i..end].iter().enumerate() {
49 buffer[j] = Complex::new(sample * hamming(j, win), 0.0);
50 }
51 fft.process(&mut buffer);
52 spectrogram.push(buffer.clone());
53 }
54 }
55
56 let n_frames = spectrogram.len();
57 Ok(Array2::from_shape_vec((n / 2 + 1, n_frames), spectrogram.into_iter().flat_map(|v| v.into_iter().take(n / 2 + 1)).collect())?)
58}
59
60pub fn istft(
78 stft_matrix: &Array2<Complex<f32>>,
79 hop_length: Option<usize>,
80 win_length: Option<usize>,
81 length: Option<usize>,
82) -> Vec<f32> {
83 let n_fft = (stft_matrix.shape()[0] - 1) * 2;
84 let hop = hop_length.unwrap_or(n_fft / 4).max(1);
85 let win = win_length.unwrap_or(n_fft);
86 let n_frames = stft_matrix.shape()[1];
87 let mut planner = FftPlanner::new();
88 let fft = planner.plan_fft_inverse(n_fft);
89
90 let max_len = hop * (n_frames - 1) + n_fft;
91 let target_len = length.unwrap_or(max_len);
92 let mut signal = vec![0.0; max_len];
93 let mut window_sum = vec![0.0; max_len];
94 let window = hamming_vec(win);
95
96 for (frame_idx, frame) in stft_matrix.axis_iter(ndarray::Axis(1)).enumerate() {
97 let mut buffer: Vec<Complex<f32>> = frame.to_vec();
98 buffer.extend(vec![Complex::new(0.0, 0.0); n_fft - buffer.len()]);
99 fft.process(&mut buffer);
100 let start = frame_idx * hop;
101 for (i, &val) in buffer.iter().enumerate().take(win) {
102 if start + i < signal.len() {
103 signal[start + i] += val.re * window[i];
104 window_sum[start + i] += window[i];
105 }
106 }
107 }
108
109 for (i, &sum) in window_sum.iter().enumerate() {
110 if sum > 1e-6 {
111 signal[i] /= sum;
112 }
113 }
114
115 signal.resize(target_len, 0.0);
116 signal
117}
118
/// Value of the symmetric Hamming window of length `win_length` at index `n`:
/// `0.54 - 0.46 * cos(2πn / (win_length - 1))`.
///
/// A window of length 0 or 1 has no defined cosine period (the denominator would be
/// zero or underflow), so it returns `1.0`, matching the usual one-point window.
fn hamming(n: usize, win_length: usize) -> f32 {
    if win_length <= 1 {
        return 1.0;
    }
    0.54 - 0.46 * (2.0 * std::f32::consts::PI * n as f32 / (win_length - 1) as f32).cos()
}
136
137fn hamming_vec(win_length: usize) -> Vec<f32> {
151 (0..win_length).map(|n| hamming(n, win_length)).collect()
152}
153
154pub fn magphase(d: &Array2<Complex<f32>>, power: Option<f32>) -> (Array2<f32>, Array2<Complex<f32>>) {
173 let power_val = power.unwrap_or(1.0);
174 let magnitude = d.mapv(|x| x.norm().powf(power_val));
175 let phase = d.mapv(|x| x / x.norm());
176 (magnitude, phase)
177}
178
179pub fn reassigned_spectrogram(
200 y: &[f32],
201 sr: Option<u32>,
202 n_fft: Option<usize>,
203) -> Result<Array2<f32>, AudioError> {
204 let sr = sr.unwrap_or(44100);
205 let n_fft = n_fft.unwrap_or(2048);
206 let hop_length = n_fft / 4;
207
208 if y.len() < n_fft {
209 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
210 }
211
212 let s = stft(y, Some(n_fft), Some(hop_length), None)
213 .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
214 let s_time = stft_with_derivative(y, Some(n_fft), Some(hop_length), true)?;
215 let s_freq = stft_with_derivative(y, Some(n_fft), Some(hop_length), false)?;
216
217 let mut reassigned = Array2::zeros(s.dim());
218 let freqs = fft_frequencies(Some(sr), Some(n_fft));
219 let times = Array1::linspace(0.0, (y.len() as f32 - 1.0) / sr as f32, s.shape()[1]);
220
221 for t in 0..s.shape()[1] {
222 for f in 0..s.shape()[0] {
223 let mag = s[[f, t]].norm();
224 if mag > 1e-6 {
225 let dphi_dt = s_time[[f, t]].im / mag;
226 let t_reassigned = times[t] - dphi_dt * hop_length as f32 / sr as f32;
227 let dphi_df = s_freq[[f, t]].im / mag;
228 let f_reassigned = freqs[f] + dphi_df * sr as f32 / n_fft as f32;
229
230 let t_idx = ((t_reassigned * sr as f32 / hop_length as f32).round() as usize).min(s.shape()[1] - 1);
231 let f_idx = freqs.iter().position(|&x| x >= f_reassigned).unwrap_or(f).min(s.shape()[0] - 1);
232 reassigned[[f_idx, t_idx]] += mag;
233 }
234 }
235 }
236
237 Ok(reassigned)
238}
239
240pub fn cqt(
264 signal: &AudioData,
265 hop_length: Option<usize>,
266 fmin: Option<f32>,
267 n_bins: Option<usize>,
268) -> Result<Array2<Complex<f32>>, AudioError> {
269 let sr = signal.sample_rate;
270 let y = &signal.samples;
271 let hop_length = hop_length.unwrap_or(512);
272 let fmin = fmin.unwrap_or(32.70);
273 let n_bins = n_bins.unwrap_or(84);
274 let bins_per_octave = 12;
275
276 if y.len() < hop_length {
277 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
278 }
279 if fmin <= 0.0 {
280 return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
281 }
282
283 let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
284 let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
285 .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
286 let n_frames = s_stft.shape()[1];
287 let mut s_cqt = Array2::zeros((n_bins, n_frames));
288
289 let mut planner = FftPlanner::new();
290 let fft = planner.plan_fft_forward(n_fft);
291 for k in 0..n_bins {
292 let fk = fmin * 2.0f32.powf(k as f32 / bins_per_octave as f32);
293 let n = (sr as f32 / fk).round() as usize;
294 let mut kernel = Array1::zeros(n_fft);
295 let window = hann_window(n);
296 for i in 0..n {
297 let phase = 2.0 * PI * fk * i as f32 / sr as f32;
298 kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
299 }
300 fft.process(&mut kernel.to_vec());
301
302 for t in 0..n_frames {
303 let stft_frame = s_stft.slice(s![.., t]);
304 s_cqt[[k, t]] = stft_frame.iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
305 }
306 }
307
308 Ok(s_cqt)
309}
310
311pub fn icqt(
333 c: &Array2<Complex<f32>>,
334 sr: Option<u32>,
335 hop_length: Option<usize>,
336 fmin: Option<f32>,
337) -> Result<Vec<f32>, AudioError> {
338 let sr = sr.unwrap_or(44100);
339 let hop_length = hop_length.unwrap_or(512);
340 let fmin = fmin.unwrap_or(32.70);
341 let n_bins = c.shape()[0];
342 let n_frames = c.shape()[1];
343 let bins_per_octave = 12;
344
345 if fmin <= 0.0 {
346 return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
347 }
348
349 let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
350 let n_samples = n_frames * hop_length;
351 let mut y = vec![0.0; n_samples];
352 let mut planner = FftPlanner::new();
353 let ifft = planner.plan_fft_inverse(n_fft);
354
355 for k in 0..n_bins {
356 let fk = fmin * 2.0f32.powf(k as f32 / bins_per_octave as f32);
357 let n = (sr as f32 / fk).round() as usize;
358 let window = hann_window(n);
359 let mut kernel = Array1::zeros(n_fft);
360 for i in 0..n {
361 let phase = 2.0 * PI * fk * i as f32 / sr as f32;
362 kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
363 }
364 ifft.process(&mut kernel.to_vec());
365
366 for t in 0..n_frames {
367 let mut frame = vec![Complex::new(c[[k, t]].re, c[[k, t]].im) * Complex::conj(&kernel[0]); n_fft];
368 ifft.process(&mut frame);
369 let start = t * hop_length;
370 for i in 0..n.min(n_samples - start) {
371 y[start + i] += frame[i].re * window[i];
372 }
373 }
374 }
375
376 let mut overlap = vec![0.0; n_samples];
377 for t in 0..n_frames {
378 let start = t * hop_length;
379 for i in 0..n_fft.min(n_samples - start) {
380 overlap[start + i] += hann_window(n_fft)[i].powi(2);
381 }
382 }
383 for i in 0..n_samples {
384 if overlap[i] > 1e-6 {
385 y[i] /= overlap[i];
386 }
387 }
388
389 Ok(y)
390}
391
392pub fn hybrid_cqt(
415 y: &[f32],
416 sr: Option<u32>,
417 hop_length: Option<usize>,
418 fmin: Option<f32>,
419) -> Result<Array2<Complex<f32>>, AudioError> {
420 let sr = sr.unwrap_or(44100);
421 let hop_length = hop_length.unwrap_or(512);
422 let fmin = fmin.unwrap_or(32.70);
423 let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
424 let n_bins = 84;
425
426 if y.len() < n_fft {
427 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
428 }
429 if fmin <= 0.0 {
430 return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
431 }
432
433 let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
434 .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
435 let mut s_hybrid = Array2::zeros((n_bins, s_stft.shape()[1]));
436 let mut planner = FftPlanner::new();
437 let fft = planner.plan_fft_forward(n_fft);
438
439 for k in 0..n_bins {
440 let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
441 let n = (sr as f32 / fk).round() as usize;
442 let mut kernel = Array1::zeros(n_fft);
443 let window = hann_window(n);
444 for i in 0..n {
445 let phase = 2.0 * PI * fk * i as f32 / sr as f32;
446 kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
447 }
448 fft.process(&mut kernel.to_vec());
449
450 for t in 0..s_stft.shape()[1] {
451 s_hybrid[[k, t]] = s_stft.slice(s![.., t]).iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
452 }
453 }
454
455 Ok(s_hybrid)
456}
457
458pub fn pseudo_cqt(
481 y: &[f32],
482 sr: Option<u32>,
483 hop_length: Option<usize>,
484 fmin: Option<f32>,
485) -> Result<Array2<Complex<f32>>, AudioError> {
486 let sr = sr.unwrap_or(44100);
487 let hop_length = hop_length.unwrap_or(512);
488 let fmin = fmin.unwrap_or(32.70);
489 let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
490 let n_bins = 84;
491
492 if y.len() < n_fft {
493 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), n_fft)));
494 }
495 if fmin <= 0.0 {
496 return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
497 }
498
499 let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
500 .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
501 let mut s_pseudo = Array2::zeros((n_bins, s_stft.shape()[1]));
502 let freqs = fft_frequencies(Some(sr), Some(n_fft));
503
504 for t in 0..s_stft.shape()[1] {
505 for k in 0..n_bins {
506 let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
507 let idx = freqs.iter().position(|&f| f >= fk).unwrap_or(0);
508 s_pseudo[[k, t]] = s_stft[[idx.min(s_stft.shape()[0] - 1), t]];
509 }
510 }
511
512 Ok(s_pseudo)
513}
514
515pub fn vqt(
539 y: &[f32],
540 sr: Option<u32>,
541 hop_length: Option<usize>,
542 fmin: Option<f32>,
543 n_bins: Option<usize>,
544) -> Result<Array2<Complex<f32>>, AudioError> {
545 let sr = sr.unwrap_or(44100);
546 let hop_length = hop_length.unwrap_or(512);
547 let fmin = fmin.unwrap_or(32.70);
548 let n_bins = n_bins.unwrap_or(84);
549 let gamma = 24.0;
550
551 if y.len() < hop_length {
552 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
553 }
554 if fmin <= 0.0 {
555 return Err(AudioError::InvalidInput("fmin must be positive".to_string()));
556 }
557
558 let n_fft = ((sr as f32 / fmin * 2.0) as u32).next_power_of_two() as usize;
559 let s_stft = stft(y, Some(n_fft), Some(hop_length), None)
560 .map_err(|e| AudioError::ComputationFailed(format!("STFT failed: {}", e)))?;
561 let mut s_vqt = Array2::zeros((n_bins, s_stft.shape()[1]));
562 let mut planner = FftPlanner::new();
563 let fft = planner.plan_fft_forward(n_fft);
564
565 for k in 0..n_bins {
566 let fk = fmin * 2.0f32.powf(k as f32 / 12.0);
567 let q = gamma / (2.0f32.powf(1.0 / 12.0) - 1.0);
568 let n = (sr as f32 * q / fk).round() as usize;
569 let mut kernel = Array1::zeros(n_fft);
570 let window = hann_window(n);
571 for i in 0..n {
572 let phase = 2.0 * PI * fk * i as f32 / sr as f32;
573 kernel[i] = Complex::new(window[i] * phase.cos(), window[i] * phase.sin()) / n as f32;
574 }
575 fft.process(&mut kernel.to_vec());
576
577 for t in 0..s_stft.shape()[1] {
578 s_vqt[[k, t]] = s_stft.slice(s![.., t]).iter().zip(kernel.iter()).map(|(&s, &k)| s * k.conj()).sum();
579 }
580 }
581
582 Ok(s_vqt)
583}
584
585pub fn fmt(
608 y: &[f32],
609 t_min: Option<f32>,
610 n_fmt: Option<usize>,
611 kind: Option<&str>,
612 beta: Option<f32>,
613) -> Result<Array2<f32>, AudioError> {
614 let sr = 44100;
615 let t_min = t_min.unwrap_or(0.005);
616 let n_fmt = n_fmt.unwrap_or(5);
617 let _kind = kind.unwrap_or("cos");
618 let beta = beta.unwrap_or(2.0);
619 let hop_length = (sr as f32 * t_min).round() as usize;
620
621 if y.len() < hop_length {
622 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), hop_length)));
623 }
624 if t_min <= 0.0 {
625 return Err(AudioError::InvalidInput("t_min must be positive".to_string()));
626 }
627
628 let n_frames = (y.len() - hop_length) / hop_length + 1;
629 let mut s = Array2::zeros((n_fmt, n_frames));
630 let window = hann_window(hop_length);
631
632 for t in 0..n_frames {
633 let start = t * hop_length;
634 let frame = &y[start..(start + hop_length).min(y.len())];
635 for k in 0..n_fmt {
636 let freq = (k + 1) as f32 / t_min;
637 let mut sum_re = 0.0;
638 let mut sum_im = 0.0;
639 for (i, &sample) in frame.iter().enumerate() {
640 let phase = 2.0 * PI * freq * i as f32 / sr as f32;
641 let w = window[i];
642 sum_re += sample * w * phase.cos();
643 sum_im += sample * w * phase.sin();
644 }
645 let mag = Complex::new(sum_re, sum_im).norm() / hop_length as f32;
646 s[[k, t]] = mag.powf(beta);
647 }
648 }
649
650 Ok(s)
651}
652
/// Symmetric Hann window of `n` points: `0.5 * (1 - cos(2πi / (n - 1)))`.
///
/// `n == 0` yields an empty window. `n == 1` yields `[1.0]`, because the general
/// formula would compute `0 / 0`.
fn hann_window(n: usize) -> Vec<f32> {
    match n {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (n - 1) as f32;
            (0..n)
                .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / denom).cos()))
                .collect()
        }
    }
}
669
670fn stft_with_derivative(
688 y: &[f32],
689 n_fft: Option<usize>,
690 hop_length: Option<usize>,
691 time_derivative: bool,
692) -> Result<Array2<Complex<f32>>, AudioError> {
693 let n_fft = n_fft.unwrap_or(2048);
694 let hop_length = hop_length.unwrap_or(n_fft / 4);
695 let n_frames = (y.len() - n_fft) / hop_length + 1;
696 let mut planner = FftPlanner::new();
697 let fft = planner.plan_fft_forward(n_fft);
698 let mut s = Array2::zeros((n_fft / 2 + 1, n_frames));
699 let window = hann_window(n_fft);
700 let deriv_window = if time_derivative {
701 (0..n_fft).map(|i| i as f32 * window[i]).collect::<Vec<_>>()
702 } else {
703 (0..n_fft).map(|i| window[i] * (2.0 * PI * i as f32 / n_fft as f32).sin()).collect::<Vec<_>>()
704 };
705
706 for t in 0..n_frames {
707 let start = t * hop_length;
708 let frame = &y[start..(start + n_fft).min(y.len())];
709 let mut buffer = frame.iter().zip(deriv_window.iter()).map(|(&x, &w)| Complex::new(x * w, 0.0)).collect::<Vec<_>>();
710 buffer.resize(n_fft, Complex::new(0.0, 0.0));
711 fft.process(&mut buffer);
712 for f in 0..n_fft / 2 + 1 {
713 s[[f, t]] = buffer[f];
714 }
715 }
716 Ok(s)
717}
718
719fn butterworth_bandpass(lowcut: f32, highcut: f32, fs: f32, order: Option<usize>) -> Result<(Vec<f32>, Vec<f32>), AudioError> {
739 if lowcut <= 0.0 || highcut <= lowcut || highcut >= fs / 2.0 {
740 return Err(AudioError::InvalidInput(format!(
741 "Invalid frequencies: lowcut={} must be > 0, highcut={} must be > lowcut and < fs/2={}",
742 lowcut, highcut, fs / 2.0
743 )));
744 }
745
746 let order = order.unwrap_or(2);
747 let n = order as i32;
748
749 let w_low = 2.0 * fs * (lowcut * PI / fs).tan();
750 let w_high = 2.0 * fs * (highcut * PI / fs).tan();
751 let w0 = (w_high * w_low).sqrt();
752 let bw = w_high - w_low;
753
754 let mut poles = Vec::new();
755 for k in 0..n {
756 let theta = PI * (2.0 * k as f32 + 1.0 + n as f32) / (2.0 * n as f32);
757 let real = -bw / 2.0 * theta.sin();
758 let imag = w0 * theta.cos();
759 poles.push(Complex::new(real, imag));
760 poles.push(Complex::new(real, -imag));
761 }
762
763 let mut z_poles = Vec::new();
764 let fs2 = 2.0 * fs;
765 for p in poles {
766 let pz = (fs2 + p) / (fs2 - p);
767 z_poles.push(pz);
768 }
769
770 let mut b = vec![1.0];
771 let mut a = vec![1.0];
772 for p in z_poles.iter() {
773 b = convolve(&b, &[1.0, -p.re]);
774 a = convolve(&a, &[1.0, -p.re]);
775 }
776 for _ in 0..n {
777 b = convolve(&b, &[1.0, 0.0]);
778 }
779
780 let w_center = 2.0 * PI * (lowcut + highcut) / 2.0 / fs;
781 let gain = evaluate_filter(&b, &a, w_center).norm();
782 for b_k in b.iter_mut() {
783 *b_k /= gain;
784 }
785
786 Ok((b, a))
787}
788
/// Full linear convolution of `a` and `b` (equivalently, polynomial multiplication).
///
/// The result has `a.len() + b.len() - 1` elements. If either input is empty the
/// result is empty; computing the length then would underflow.
fn convolve(a: &[f32], b: &[f32]) -> Vec<f32> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut result = vec![0.0; a.len() + b.len() - 1];
    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            result[i + j] += ai * bj;
        }
    }
    result
}
812
813fn evaluate_filter(b: &[f32], a: &[f32], w: f32) -> Complex<f32> {
828 let mut num = Complex::new(0.0, 0.0);
829 let mut den = Complex::new(0.0, 0.0);
830 for (k, &bk) in b.iter().enumerate() {
831 let phase = -w * k as f32;
832 num += Complex::new(bk * phase.cos(), bk * phase.sin());
833 }
834 for (k, &ak) in a.iter().enumerate() {
835 let phase = -w * k as f32;
836 den += Complex::new(ak * phase.cos(), ak * phase.sin());
837 }
838 num / den
839}
840
841pub fn iirt(
863 y: &[f32],
864 sr: Option<u32>,
865 win_length: Option<usize>,
866 hop_length: Option<usize>,
867) -> Result<Array2<f32>, AudioError> {
868 let sr = sr.unwrap_or(44100);
869 let win_length = win_length.unwrap_or(2048);
870 let hop_length = hop_length.unwrap_or(win_length / 4);
871 let n_bands = 12;
872
873 if y.len() < win_length {
874 return Err(AudioError::InsufficientData(format!("Signal too short: {} < {}", y.len(), win_length)));
875 }
876
877 let n_frames = (y.len() - win_length) / hop_length + 1;
878 let mut s = Array2::zeros((n_bands, n_frames));
879 let fmin = 32.70;
880
881 for b in 0..n_bands {
882 let fc = fmin * 2.0f32.powf(b as f32);
883 let bw = fc / SQRT_2;
884 let (b_coeffs, a_coeffs) = butterworth_bandpass(fc - bw / 2.0, fc + bw / 2.0, sr as f32, Some(4))?;
885
886 for t in 0..n_frames {
887 let start = t * hop_length;
888 let frame = &y[start..(start + win_length).min(y.len())];
889 let filtered = filter(frame, &b_coeffs, &a_coeffs);
890 s[[b, t]] = filtered.iter().map(|&x| x.powi(2)).sum::<f32>().sqrt() / win_length as f32;
891 }
892 }
893
894 Ok(s)
895}
896
/// Applies the IIR filter `(b, a)` to `x` in direct form I, with zero initial state:
///
/// `y[n] = (Σ_k b[k]·x[n-k] − Σ_{k≥1} a[k]·y[n-k]) / a[0]`
///
/// Coefficient vectors of any length are used in full; an empty `a` acts as `[1.0]`.
/// Samples before the start of `x` are treated as zero. `a[0]` is assumed non-zero.
fn filter(x: &[f32], b: &[f32], a: &[f32]) -> Vec<f32> {
    let a0 = a.first().copied().unwrap_or(1.0);
    let mut y = vec![0.0f32; x.len()];
    for n in 0..x.len() {
        // Restricting k to 0..=n avoids the usize underflow that `n - 1` hit at n = 0.
        let feedforward: f32 = b
            .iter()
            .take(n + 1)
            .enumerate()
            .map(|(k, &bk)| bk * x[n - k])
            .sum();
        let feedback: f32 = a
            .iter()
            .enumerate()
            .skip(1)
            .take(n)
            .map(|(k, &ak)| ak * y[n - k])
            .sum();
        y[n] = (feedforward - feedback) / a0;
    }
    y
}