1use ndarray::{Array2, Axis};
7use rayon::prelude::*;
8use realfft::RealFftPlanner;
9use std::f32::consts::PI;
10use tracing::{debug, instrument};
11
12use crate::AudioError;
13
/// Parameters that control how a mel spectrogram is computed.
#[derive(Debug, Clone)]
pub struct SpectrogramConfig {
    /// Number of mel bands (rows of the resulting spectrogram).
    pub n_mels: usize,
    /// FFT window size, in samples.
    pub n_fft: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_length: usize,
    /// Sample rate of the input audio, in Hz.
    pub sample_rate: u32,
    /// Lower frequency edge of the mel filterbank, in Hz.
    pub f_min: f32,
    /// Upper frequency edge of the mel filterbank, in Hz.
    pub f_max: f32,
    /// When `true`, mel energies are converted with `10 * log10(x / ref_db)`.
    pub log_scale: bool,
    /// Reference value the mel energies are divided by before taking the logarithm.
    pub ref_db: f32,
    /// Floor applied to every mel energy before the optional log conversion.
    pub min_value: f32,
}

impl Default for SpectrogramConfig {
    /// 128 mel bands, a 2048-sample FFT and a 512-sample hop at 32 kHz,
    /// covering 0 Hz to 16 kHz, with log scaling enabled.
    fn default() -> Self {
        Self {
            n_mels: 128,
            n_fft: 2048,
            hop_length: 512,
            sample_rate: 32_000,
            f_min: 0.0,
            f_max: 16_000.0,
            log_scale: true,
            ref_db: 1.0,
            min_value: 1e-10,
        }
    }
}

impl SpectrogramConfig {
    /// Preset for 5-second segments sampled at 32 kHz.
    ///
    /// A hop of 320 samples at 32 kHz is 10 ms per frame, so a 5-second
    /// segment yields roughly 500 frames. The filterbank is restricted to
    /// 500 Hz – 15 kHz.
    #[must_use]
    pub fn for_5s_segment() -> Self {
        Self {
            n_mels: 128,
            n_fft: 2048,
            hop_length: 320,
            sample_rate: 32_000,
            f_min: 500.0,
            f_max: 15_000.0,
            log_scale: true,
            ref_db: 1.0,
            min_value: 1e-10,
        }
    }

    /// Builds a configuration whose hop length spreads `duration_ms` of audio
    /// at `sample_rate` over approximately `target_frames` frames.
    ///
    /// All other fields come from [`SpectrogramConfig::default`]. The hop
    /// length is never smaller than 1. A `target_frames` of zero is treated
    /// as one frame instead of dividing by zero.
    #[must_use]
    pub fn with_target_frames(target_frames: usize, duration_ms: u64, sample_rate: u32) -> Self {
        // Work in u64 with saturation so the product cannot overflow `usize`
        // on 32-bit targets.
        let total_samples = duration_ms.saturating_mul(u64::from(sample_rate)) / 1000;
        let frames = target_frames.max(1) as u64;
        let hop = (total_samples / frames).min(usize::MAX as u64) as usize;

        Self {
            hop_length: hop.max(1),
            sample_rate,
            ..Self::default()
        }
    }
}
87
88#[derive(Debug, Clone)]
90pub struct MelSpectrogram {
91 pub data: Array2<f32>,
93 pub config: SpectrogramConfig,
95 pub duration_ms: u64,
97}
98
99impl MelSpectrogram {
100 #[instrument(skip(samples), fields(samples_len = samples.len()))]
109 pub fn compute(samples: &[f32], config: SpectrogramConfig) -> Result<Self, AudioError> {
110 if samples.is_empty() {
111 return Err(AudioError::invalid_data("Cannot compute spectrogram of empty audio"));
112 }
113
114 let duration_ms = (samples.len() as u64 * 1000) / u64::from(config.sample_rate);
115
116 let stft = Self::stft(samples, config.n_fft, config.hop_length)?;
118
119 let mel_filterbank = Self::create_mel_filterbank(
121 config.n_mels,
122 config.n_fft,
123 config.sample_rate,
124 config.f_min,
125 config.f_max,
126 );
127
128 let n_frames = stft.ncols();
130 let mut mel_spec = Array2::zeros((config.n_mels, n_frames));
131
132 for (frame_idx, frame) in stft.axis_iter(Axis(1)).enumerate() {
133 for (mel_idx, filter) in mel_filterbank.axis_iter(Axis(0)).enumerate() {
134 let energy: f32 = frame
135 .iter()
136 .zip(filter.iter())
137 .map(|(s, f)| s * f)
138 .sum();
139 mel_spec[[mel_idx, frame_idx]] = energy.max(config.min_value);
140 }
141 }
142
143 if config.log_scale {
145 mel_spec.mapv_inplace(|x| 10.0 * (x / config.ref_db).log10());
146 }
147
148 debug!(
149 n_mels = config.n_mels,
150 n_frames = n_frames,
151 duration_ms = duration_ms,
152 "Spectrogram computed"
153 );
154
155 Ok(Self {
156 data: mel_spec,
157 config,
158 duration_ms,
159 })
160 }
161
162 #[must_use]
164 pub fn shape(&self) -> (usize, usize) {
165 (self.data.nrows(), self.data.ncols())
166 }
167
168 #[must_use]
170 pub fn n_mels(&self) -> usize {
171 self.data.nrows()
172 }
173
174 #[must_use]
176 pub fn n_frames(&self) -> usize {
177 self.data.ncols()
178 }
179
180 #[must_use]
182 pub fn slice_frames(&self, start: usize, end: usize) -> Array2<f32> {
183 let end = end.min(self.n_frames());
184 let start = start.min(end);
185 self.data.slice(ndarray::s![.., start..end]).to_owned()
186 }
187
188 pub fn normalize(&mut self) {
190 for mut row in self.data.axis_iter_mut(Axis(0)) {
191 let mean = row.mean().unwrap_or(0.0);
192 let std = row.std(0.0);
193 if std > 1e-6 {
194 row.mapv_inplace(|x| (x - mean) / std);
195 } else {
196 row.mapv_inplace(|x| x - mean);
197 }
198 }
199 }
200
201 #[must_use]
203 pub fn to_vec(&self) -> Vec<f32> {
204 self.data.iter().copied().collect()
205 }
206
207 fn stft(
209 samples: &[f32],
210 n_fft: usize,
211 hop_length: usize,
212 ) -> Result<Array2<f32>, AudioError> {
213 let n_frames = (samples.len().saturating_sub(n_fft)) / hop_length + 1;
214 if n_frames == 0 {
215 return Err(AudioError::invalid_data(
216 "Audio too short for FFT window size",
217 ));
218 }
219
220 let n_bins = n_fft / 2 + 1;
221 let mut planner = RealFftPlanner::<f32>::new();
222 let fft = planner.plan_fft_forward(n_fft);
223
224 let window: Vec<f32> = (0..n_fft)
226 .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
227 .collect();
228
229 let frames: Vec<Vec<f32>> = (0..n_frames)
231 .into_par_iter()
232 .map(|frame_idx| {
233 let start = frame_idx * hop_length;
234 let mut input = vec![0.0f32; n_fft];
235
236 for (i, &w) in window.iter().enumerate() {
238 if start + i < samples.len() {
239 input[i] = samples[start + i] * w;
240 }
241 }
242
243 let mut spectrum = fft.make_output_vec();
245 let mut scratch = fft.make_scratch_vec();
246
247 let fft = RealFftPlanner::<f32>::new().plan_fft_forward(n_fft);
249 fft.process_with_scratch(&mut input, &mut spectrum, &mut scratch)
250 .ok();
251
252 spectrum
254 .iter()
255 .take(n_bins)
256 .map(|c| (c.re * c.re + c.im * c.im).sqrt())
257 .collect()
258 })
259 .collect();
260
261 let mut stft = Array2::zeros((n_bins, n_frames));
263 for (frame_idx, frame) in frames.into_iter().enumerate() {
264 for (bin_idx, &value) in frame.iter().enumerate() {
265 stft[[bin_idx, frame_idx]] = value;
266 }
267 }
268
269 Ok(stft)
270 }
271
272 fn create_mel_filterbank(
274 n_mels: usize,
275 n_fft: usize,
276 sample_rate: u32,
277 f_min: f32,
278 f_max: f32,
279 ) -> Array2<f32> {
280 let n_bins = n_fft / 2 + 1;
281
282 let mel_min = Self::hz_to_mel(f_min);
284 let mel_max = Self::hz_to_mel(f_max);
285
286 let mel_points: Vec<f32> = (0..=n_mels + 1)
288 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
289 .collect();
290
291 let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
293
294 let bin_points: Vec<usize> = hz_points
296 .iter()
297 .map(|&f| {
298 let bin = (f * n_fft as f32 / sample_rate as f32).round() as usize;
299 bin.min(n_bins - 1)
300 })
301 .collect();
302
303 let mut filterbank = Array2::zeros((n_mels, n_bins));
305
306 for m in 0..n_mels {
307 let left = bin_points[m];
308 let center = bin_points[m + 1];
309 let right = bin_points[m + 2];
310
311 for k in left..center {
313 if center != left {
314 filterbank[[m, k]] = (k - left) as f32 / (center - left) as f32;
315 }
316 }
317
318 for k in center..=right {
320 if right != center {
321 filterbank[[m, k]] = (right - k) as f32 / (right - center) as f32;
322 }
323 }
324 }
325
326 filterbank
327 }
328
329 fn hz_to_mel(hz: f32) -> f32 {
331 2595.0 * (1.0 + hz / 700.0).log10()
332 }
333
334 fn mel_to_hz(mel: f32) -> f32 {
336 700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
337 }
338}
339
340pub struct SpectrogramBatch;
342
343impl SpectrogramBatch {
344 pub fn compute_batch(
346 segments: &[Vec<f32>],
347 config: &SpectrogramConfig,
348 ) -> Result<Vec<MelSpectrogram>, AudioError> {
349 segments
350 .par_iter()
351 .map(|samples| MelSpectrogram::compute(samples, config.clone()))
352 .collect()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `duration_s` seconds of a unit-amplitude sine at `freq` Hz.
    fn generate_sine_wave(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<f32> {
        let num_samples = (duration_s * sample_rate as f32) as usize;
        let mut wave = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            let t = i as f32 / sample_rate as f32;
            wave.push((2.0 * PI * freq * t).sin());
        }
        wave
    }

    #[test]
    fn test_spectrogram_config_default() {
        let defaults = SpectrogramConfig::default();
        assert_eq!(defaults.n_mels, 128);
        assert_eq!(defaults.n_fft, 2048);
    }

    #[test]
    fn test_spectrogram_5s_config() {
        let preset = SpectrogramConfig::for_5s_segment();
        assert_eq!(preset.hop_length, 320);
    }

    #[test]
    fn test_mel_conversion() {
        // Hz -> mel -> Hz must round-trip.
        let hz = 1000.0;
        let round_trip = MelSpectrogram::mel_to_hz(MelSpectrogram::hz_to_mel(hz));
        assert!((hz - round_trip).abs() < 0.01);
    }

    #[test]
    fn test_spectrogram_computation() {
        let audio = generate_sine_wave(1000.0, 1.0, 32000);

        let spec = MelSpectrogram::compute(&audio, SpectrogramConfig::default()).unwrap();

        assert_eq!(spec.n_mels(), 128);
        assert!(spec.n_frames() > 0);
    }

    #[test]
    fn test_spectrogram_5s_segment() {
        let audio = generate_sine_wave(2000.0, 5.0, 32000);

        let spec = MelSpectrogram::compute(&audio, SpectrogramConfig::for_5s_segment()).unwrap();

        assert_eq!(spec.n_mels(), 128);
        // The frame count should land close to 500.
        let frame_count = spec.n_frames() as i32;
        assert!((frame_count - 500).abs() < 10);
    }

    #[test]
    fn test_spectrogram_normalization() {
        let audio = generate_sine_wave(1000.0, 1.0, 32000);

        let mut spec = MelSpectrogram::compute(&audio, SpectrogramConfig::default()).unwrap();
        spec.normalize();

        // After normalization the first band must be centered around zero.
        let mean = spec.data.row(0).mean().unwrap_or(1.0);
        assert!(mean.abs() < 0.1);
    }

    #[test]
    fn test_spectrogram_slice() {
        let audio = generate_sine_wave(1000.0, 2.0, 32000);

        let spec = MelSpectrogram::compute(&audio, SpectrogramConfig::default()).unwrap();
        let window = spec.slice_frames(0, 10);

        assert_eq!(window.ncols(), 10);
        assert_eq!(window.nrows(), spec.n_mels());
    }

    #[test]
    fn test_empty_input_error() {
        let outcome = MelSpectrogram::compute(&[], SpectrogramConfig::default());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_batch_computation() {
        let segments = vec![
            generate_sine_wave(1000.0, 1.0, 32000),
            generate_sine_wave(2000.0, 1.0, 32000),
        ];
        let config = SpectrogramConfig::default();

        let specs = SpectrogramBatch::compute_batch(&segments, &config).unwrap();

        assert_eq!(specs.len(), 2);
    }
}