1use std::f32::consts::PI;
6
7use std::sync::Arc;
8
9use ndarray::{Array2, ArrayViewMutD};
10use realfft::{RealFftPlanner, RealToComplex};
11
/// Parameters controlling log-mel spectrogram extraction.
///
/// All fields are plain integers/booleans, so the type derives `Debug`, `Clone`,
/// `PartialEq` and `Eq` to make configurations easy to log, duplicate and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MelConfig {
    /// Sampling rate of the input audio, in Hz; half of it is the top edge of the mel filterbank.
    pub sample_rate: usize,
    /// FFT size; every analysis frame spans this many samples.
    pub n_fft: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_length: usize,
    /// Number of non-zero Hann window taps; the window is zero-filled up to `n_fft`.
    pub win_length: usize,
    /// Number of mel bands (rows of the filterbank).
    pub n_mels: usize,
    /// When `true`, the waveform is reflect-padded by `n_fft / 2` samples on each side before framing.
    pub center: bool,
}
21
22pub struct MelSpectrogram {
24 r2c: Arc<dyn RealToComplex<f32>>,
25 mel_fb: Array2<f32>,
26 window: Vec<f32>,
27 n_fft: usize,
28 hop_length: usize,
29 center: bool,
30}
31
32impl MelSpectrogram {
33 pub fn new(config: &MelConfig) -> Self {
34 let n_fft = config.n_fft;
35
36 let mut planner = RealFftPlanner::<f32>::new();
37 let r2c = planner.plan_fft_forward(n_fft);
38
39 let window = hann_window(config.n_fft, config.win_length);
40
41 let mel_fb = build_mel_filterbank(config.n_mels, n_fft, config.sample_rate as f32);
42
43 Self { r2c, mel_fb, window, n_fft, hop_length: config.hop_length, center: config.center }
44 }
45
46 pub fn n_mels(&self) -> usize {
47 self.mel_fb.nrows()
48 }
49
50 pub fn num_frames(&self, waveform_len: usize) -> usize {
51 let signal_len = if self.center { waveform_len + self.n_fft } else { waveform_len };
52 if signal_len >= self.n_fft { (signal_len - self.n_fft) / self.hop_length + 1 } else { 0 }
53 }
54
55 pub fn forward_into(&self, waveform: &[f32], out: &mut ArrayViewMutD<'_, f32>) {
56 let n_fft = self.n_fft;
57 let signal: &[f32];
58 let signal_owned: Vec<f32>;
59
60 if self.center {
61 let pad = n_fft / 2;
62 signal_owned = reflect_pad(waveform, pad);
63 signal = &signal_owned;
64 } else {
65 signal = waveform;
66 }
67
68 let n_frames = if signal.len() >= n_fft { (signal.len() - n_fft) / self.hop_length + 1 } else { 0 };
69 let n_bins = n_fft / 2 + 1;
70 let n_mels = self.mel_fb.nrows();
71
72 debug_assert!(
73 {
74 let shape = out.shape();
75 shape.len() >= 2
76 && shape[shape.len() - 2] == n_mels
77 && shape[shape.len() - 1] == n_frames
78 && shape[..shape.len() - 2].iter().all(|&d| d == 1)
79 },
80 "forward_into: expected output trailing dims [.., {n_mels}, {n_frames}] with leading 1s, got {:?}",
81 out.shape(),
82 );
83
84 let out_slice = out.as_slice_mut().expect("output must be contiguous");
85
86 out_slice[..n_mels * n_frames].fill(0.0);
87
88 let mut indata = self.r2c.make_input_vec();
89 let mut outdata = self.r2c.make_output_vec();
90 let mut power = vec![0.0f32; n_bins];
91
92 for frame_idx in 0..n_frames {
93 let start = frame_idx * self.hop_length;
94 for i in 0..n_fft {
95 indata[i] = signal[start + i] * self.window[i];
96 }
97 self.r2c.process(&mut indata, &mut outdata).expect("FFT failed");
98
99 for (i, c) in outdata.iter().enumerate() {
100 power[i] = c.re * c.re + c.im * c.im;
101 }
102
103 for mel_idx in 0..n_mels {
104 let mut sum = 0.0f32;
105 for (bin, &p) in power.iter().enumerate() {
106 sum += self.mel_fb[[mel_idx, bin]] * p;
107 }
108 out_slice[mel_idx * n_frames + frame_idx] = sum.clamp(1e-9, 1e9).ln();
109 }
110 }
111 }
112}
113
/// Returns a vector of length `n_fft` whose first `win_length` entries (at most `n_fft` of them)
/// hold the periodic Hann window `0.5 * (1 - cos(2*pi*i / win_length))`; the rest are zero.
pub(crate) fn hann_window(n_fft: usize, win_length: usize) -> Vec<f32> {
    // Only taps that fit inside the FFT frame are ever evaluated.
    let active = win_length.min(n_fft);
    let period = win_length as f32;

    let mut taps: Vec<f32> =
        (0..active).map(|k| 0.5 * (1.0 - (2.0 * PI * k as f32 / period).cos())).collect();
    // Zero-fill the tail up to the FFT size.
    taps.resize(n_fft, 0.0);
    taps
}
124
125fn build_mel_filterbank(n_mels: usize, n_fft: usize, sample_rate: f32) -> Array2<f32> {
127 let n_bins = n_fft / 2 + 1;
128 let f_max = sample_rate / 2.0;
129
130 let hz_to_mel = |f: f32| 2595.0 * (1.0 + f / 700.0).log10();
131 let mel_to_hz = |m: f32| 700.0 * (10.0f32.powf(m / 2595.0) - 1.0);
132
133 let mel_min = hz_to_mel(0.0);
134 let mel_max = hz_to_mel(f_max);
135
136 let mel_points: Vec<f32> =
137 (0..n_mels + 2).map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32).collect();
138 let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
139 let bin_points: Vec<f32> = hz_points.iter().map(|&f| f * n_fft as f32 / sample_rate).collect();
140
141 let mut fb = Array2::zeros((n_mels, n_bins));
142 for i in 0..n_mels {
143 let left = bin_points[i];
144 let center = bin_points[i + 1];
145 let right = bin_points[i + 2];
146
147 for j in 0..n_bins {
148 let freq = j as f32;
149 if freq >= left && freq <= center && center > left {
150 fb[[i, j]] = (freq - left) / (center - left);
151 } else if freq > center && freq <= right && right > center {
152 fb[[i, j]] = (right - freq) / (right - center);
153 }
154 }
155 }
156 fb
157}
158
/// Returns `signal` extended by `pad` mirrored samples on each side.
///
/// The mirror excludes the edge sample itself: the left extension is `signal[pad], ..., signal[1]`
/// and the right extension is `signal[len - 2], ..., signal[len - 1 - pad]`.
///
/// # Panics
///
/// Panics unless `pad < signal.len()` (which also rejects an empty signal).
pub(crate) fn reflect_pad(signal: &[f32], pad: usize) -> Vec<f32> {
    let len = signal.len();
    assert!(
        pad < len,
        "reflect_pad requires pad ({pad}) < signal length ({len}); multi-bounce reflection is not supported",
    );

    // Both mirrored regions are contiguous sub-slices read back to front.
    let left_mirror = &signal[1..pad + 1];
    let right_mirror = &signal[len - 1 - pad..len - 1];

    let mut out = Vec::with_capacity(len + 2 * pad);
    out.extend(left_mirror.iter().rev());
    out.extend_from_slice(signal);
    out.extend(right_mirror.iter().rev());
    out
}