sevensense_audio/
spectrogram.rs

//! Mel spectrogram computation for audio feature extraction.
//!
//! This module computes spectrograms with a Hann-windowed, real-valued FFT
//! (STFT) followed by a mel-scale filterbank, producing features suitable for ML models.
5
6use ndarray::{Array2, Axis};
7use rayon::prelude::*;
8use realfft::RealFftPlanner;
9use std::f32::consts::PI;
10use tracing::{debug, instrument};
11
12use crate::AudioError;
13
/// Parameters controlling how a mel spectrogram is computed.
///
/// Every field is public, so configs can be built directly, via
/// `SpectrogramConfig::default()`, or through the helper constructors.
#[derive(Debug, Clone)]
pub struct SpectrogramConfig {
    /// How many mel bands (output rows) to produce.
    pub n_mels: usize,
    /// Length of each FFT window, in samples.
    pub n_fft: usize,
    /// Distance between the starts of consecutive frames, in samples.
    pub hop_length: usize,
    /// Sample rate of the audio being analysed, in Hz.
    pub sample_rate: u32,
    /// Lower edge of the mel filterbank, in Hz.
    pub f_min: f32,
    /// Upper edge of the mel filterbank, in Hz.
    pub f_max: f32,
    /// When `true`, mel energies are converted to decibels.
    pub log_scale: bool,
    /// Reference the energies are divided by before the decibel conversion.
    pub ref_db: f32,
    /// Floor applied to mel energies so the logarithm never sees zero.
    pub min_value: f32,
}
36
37impl Default for SpectrogramConfig {
38    fn default() -> Self {
39        Self {
40            n_mels: 128,
41            n_fft: 2048,
42            hop_length: 512,
43            sample_rate: 32_000,
44            f_min: 0.0,
45            f_max: 16_000.0, // Nyquist for 32kHz
46            log_scale: true,
47            ref_db: 1.0,
48            min_value: 1e-10,
49        }
50    }
51}
52
53impl SpectrogramConfig {
54    /// Creates a config optimized for 5-second segments producing 500 frames.
55    ///
56    /// For 32kHz audio:
57    /// - 5s = 160,000 samples
58    /// - hop_length = 320 gives ~500 frames
59    #[must_use]
60    pub fn for_5s_segment() -> Self {
61        Self {
62            n_mels: 128,
63            n_fft: 2048,
64            hop_length: 320, // 160000 / 320 = 500 frames
65            sample_rate: 32_000,
66            f_min: 500.0,    // Filter out very low frequencies
67            f_max: 15_000.0, // Most bird calls below 15kHz
68            log_scale: true,
69            ref_db: 1.0,
70            min_value: 1e-10,
71        }
72    }
73
74    /// Creates a config for variable-length audio.
75    #[must_use]
76    pub fn with_target_frames(target_frames: usize, duration_ms: u64, sample_rate: u32) -> Self {
77        let total_samples = (duration_ms as usize * sample_rate as usize) / 1000;
78        let hop_length = total_samples / target_frames;
79
80        Self {
81            hop_length: hop_length.max(1),
82            sample_rate,
83            ..Self::default()
84        }
85    }
86}
87
88/// A computed mel spectrogram.
89#[derive(Debug, Clone)]
90pub struct MelSpectrogram {
91    /// Spectrogram data (n_mels x n_frames).
92    pub data: Array2<f32>,
93    /// Configuration used to compute this spectrogram.
94    pub config: SpectrogramConfig,
95    /// Duration of the source audio in milliseconds.
96    pub duration_ms: u64,
97}
98
99impl MelSpectrogram {
100    /// Computes a mel spectrogram from audio samples.
101    ///
102    /// # Arguments
103    /// * `samples` - Mono audio samples
104    /// * `config` - Spectrogram configuration
105    ///
106    /// # Returns
107    /// A MelSpectrogram with shape (n_mels, n_frames).
108    #[instrument(skip(samples), fields(samples_len = samples.len()))]
109    pub fn compute(samples: &[f32], config: SpectrogramConfig) -> Result<Self, AudioError> {
110        if samples.is_empty() {
111            return Err(AudioError::invalid_data("Cannot compute spectrogram of empty audio"));
112        }
113
114        let duration_ms = (samples.len() as u64 * 1000) / u64::from(config.sample_rate);
115
116        // Compute STFT
117        let stft = Self::stft(samples, config.n_fft, config.hop_length)?;
118
119        // Compute mel filterbank
120        let mel_filterbank = Self::create_mel_filterbank(
121            config.n_mels,
122            config.n_fft,
123            config.sample_rate,
124            config.f_min,
125            config.f_max,
126        );
127
128        // Apply mel filterbank
129        let n_frames = stft.ncols();
130        let mut mel_spec = Array2::zeros((config.n_mels, n_frames));
131
132        for (frame_idx, frame) in stft.axis_iter(Axis(1)).enumerate() {
133            for (mel_idx, filter) in mel_filterbank.axis_iter(Axis(0)).enumerate() {
134                let energy: f32 = frame
135                    .iter()
136                    .zip(filter.iter())
137                    .map(|(s, f)| s * f)
138                    .sum();
139                mel_spec[[mel_idx, frame_idx]] = energy.max(config.min_value);
140            }
141        }
142
143        // Apply log scaling if requested
144        if config.log_scale {
145            mel_spec.mapv_inplace(|x| 10.0 * (x / config.ref_db).log10());
146        }
147
148        debug!(
149            n_mels = config.n_mels,
150            n_frames = n_frames,
151            duration_ms = duration_ms,
152            "Spectrogram computed"
153        );
154
155        Ok(Self {
156            data: mel_spec,
157            config,
158            duration_ms,
159        })
160    }
161
162    /// Returns the shape as (n_mels, n_frames).
163    #[must_use]
164    pub fn shape(&self) -> (usize, usize) {
165        (self.data.nrows(), self.data.ncols())
166    }
167
168    /// Returns the number of mel bands.
169    #[must_use]
170    pub fn n_mels(&self) -> usize {
171        self.data.nrows()
172    }
173
174    /// Returns the number of time frames.
175    #[must_use]
176    pub fn n_frames(&self) -> usize {
177        self.data.ncols()
178    }
179
180    /// Extracts a time slice of the spectrogram.
181    #[must_use]
182    pub fn slice_frames(&self, start: usize, end: usize) -> Array2<f32> {
183        let end = end.min(self.n_frames());
184        let start = start.min(end);
185        self.data.slice(ndarray::s![.., start..end]).to_owned()
186    }
187
188    /// Normalizes the spectrogram to zero mean and unit variance per mel band.
189    pub fn normalize(&mut self) {
190        for mut row in self.data.axis_iter_mut(Axis(0)) {
191            let mean = row.mean().unwrap_or(0.0);
192            let std = row.std(0.0);
193            if std > 1e-6 {
194                row.mapv_inplace(|x| (x - mean) / std);
195            } else {
196                row.mapv_inplace(|x| x - mean);
197            }
198        }
199    }
200
201    /// Returns the raw data as a flat vector (row-major order).
202    #[must_use]
203    pub fn to_vec(&self) -> Vec<f32> {
204        self.data.iter().copied().collect()
205    }
206
207    /// Computes Short-Time Fourier Transform.
208    fn stft(
209        samples: &[f32],
210        n_fft: usize,
211        hop_length: usize,
212    ) -> Result<Array2<f32>, AudioError> {
213        let n_frames = (samples.len().saturating_sub(n_fft)) / hop_length + 1;
214        if n_frames == 0 {
215            return Err(AudioError::invalid_data(
216                "Audio too short for FFT window size",
217            ));
218        }
219
220        let n_bins = n_fft / 2 + 1;
221        let mut planner = RealFftPlanner::<f32>::new();
222        let fft = planner.plan_fft_forward(n_fft);
223
224        // Pre-compute Hann window
225        let window: Vec<f32> = (0..n_fft)
226            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
227            .collect();
228
229        // Compute STFT frames in parallel
230        let frames: Vec<Vec<f32>> = (0..n_frames)
231            .into_par_iter()
232            .map(|frame_idx| {
233                let start = frame_idx * hop_length;
234                let mut input = vec![0.0f32; n_fft];
235
236                // Copy and window the input
237                for (i, &w) in window.iter().enumerate() {
238                    if start + i < samples.len() {
239                        input[i] = samples[start + i] * w;
240                    }
241                }
242
243                // Perform FFT
244                let mut spectrum = fft.make_output_vec();
245                let mut scratch = fft.make_scratch_vec();
246
247                // Clone fft for thread safety
248                let fft = RealFftPlanner::<f32>::new().plan_fft_forward(n_fft);
249                fft.process_with_scratch(&mut input, &mut spectrum, &mut scratch)
250                    .ok();
251
252                // Compute magnitude spectrum
253                spectrum
254                    .iter()
255                    .take(n_bins)
256                    .map(|c| (c.re * c.re + c.im * c.im).sqrt())
257                    .collect()
258            })
259            .collect();
260
261        // Assemble into 2D array
262        let mut stft = Array2::zeros((n_bins, n_frames));
263        for (frame_idx, frame) in frames.into_iter().enumerate() {
264            for (bin_idx, &value) in frame.iter().enumerate() {
265                stft[[bin_idx, frame_idx]] = value;
266            }
267        }
268
269        Ok(stft)
270    }
271
272    /// Creates a mel filterbank matrix.
273    fn create_mel_filterbank(
274        n_mels: usize,
275        n_fft: usize,
276        sample_rate: u32,
277        f_min: f32,
278        f_max: f32,
279    ) -> Array2<f32> {
280        let n_bins = n_fft / 2 + 1;
281
282        // Convert frequency to mel scale
283        let mel_min = Self::hz_to_mel(f_min);
284        let mel_max = Self::hz_to_mel(f_max);
285
286        // Create mel points equally spaced in mel scale
287        let mel_points: Vec<f32> = (0..=n_mels + 1)
288            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
289            .collect();
290
291        // Convert back to Hz
292        let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
293
294        // Convert to FFT bin indices
295        let bin_points: Vec<usize> = hz_points
296            .iter()
297            .map(|&f| {
298                let bin = (f * n_fft as f32 / sample_rate as f32).round() as usize;
299                bin.min(n_bins - 1)
300            })
301            .collect();
302
303        // Create filterbank matrix
304        let mut filterbank = Array2::zeros((n_mels, n_bins));
305
306        for m in 0..n_mels {
307            let left = bin_points[m];
308            let center = bin_points[m + 1];
309            let right = bin_points[m + 2];
310
311            // Rising slope
312            for k in left..center {
313                if center != left {
314                    filterbank[[m, k]] = (k - left) as f32 / (center - left) as f32;
315                }
316            }
317
318            // Falling slope
319            for k in center..=right {
320                if right != center {
321                    filterbank[[m, k]] = (right - k) as f32 / (right - center) as f32;
322                }
323            }
324        }
325
326        filterbank
327    }
328
329    /// Converts frequency from Hz to mel scale.
330    fn hz_to_mel(hz: f32) -> f32 {
331        2595.0 * (1.0 + hz / 700.0).log10()
332    }
333
334    /// Converts frequency from mel scale to Hz.
335    fn mel_to_hz(mel: f32) -> f32 {
336        700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
337    }
338}
339
340/// Batch spectrogram computation for multiple segments.
341pub struct SpectrogramBatch;
342
343impl SpectrogramBatch {
344    /// Computes spectrograms for multiple audio segments in parallel.
345    pub fn compute_batch(
346        segments: &[Vec<f32>],
347        config: &SpectrogramConfig,
348    ) -> Result<Vec<MelSpectrogram>, AudioError> {
349        segments
350            .par_iter()
351            .map(|samples| MelSpectrogram::compute(samples, config.clone()))
352            .collect()
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_RATE: u32 = 32_000;

    /// Builds a unit-amplitude sine tone of `freq` Hz lasting `duration_s` seconds.
    fn generate_sine_wave(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<f32> {
        let rate = sample_rate as f32;
        let len = (duration_s * rate) as usize;
        let mut tone = Vec::with_capacity(len);
        for n in 0..len {
            let t = n as f32 / rate;
            tone.push((2.0 * PI * freq * t).sin());
        }
        tone
    }

    #[test]
    fn test_spectrogram_config_default() {
        let SpectrogramConfig { n_mels, n_fft, .. } = SpectrogramConfig::default();
        assert_eq!(n_mels, 128);
        assert_eq!(n_fft, 2048);
    }

    #[test]
    fn test_spectrogram_5s_config() {
        assert_eq!(SpectrogramConfig::for_5s_segment().hop_length, 320);
    }

    #[test]
    fn test_mel_conversion() {
        let original_hz = 1000.0;
        let round_trip = MelSpectrogram::mel_to_hz(MelSpectrogram::hz_to_mel(original_hz));
        assert!((original_hz - round_trip).abs() < 0.01);
    }

    #[test]
    fn test_spectrogram_computation() {
        let tone = generate_sine_wave(1000.0, 1.0, SAMPLE_RATE);

        let spec = MelSpectrogram::compute(&tone, SpectrogramConfig::default())
            .expect("a 1 s tone should produce a spectrogram");

        assert_eq!(spec.n_mels(), 128);
        assert!(spec.n_frames() > 0);
    }

    #[test]
    fn test_spectrogram_5s_segment() {
        // 5 s at 32 kHz = 160,000 samples.
        let tone = generate_sine_wave(2000.0, 5.0, SAMPLE_RATE);

        let spec = MelSpectrogram::compute(&tone, SpectrogramConfig::for_5s_segment())
            .expect("a 5 s tone should produce a spectrogram");

        assert_eq!(spec.n_mels(), 128);
        // Expect roughly 500 frames.
        let deviation = (spec.n_frames() as i32 - 500).abs();
        assert!(deviation < 10);
    }

    #[test]
    fn test_spectrogram_normalization() {
        let tone = generate_sine_wave(1000.0, 1.0, SAMPLE_RATE);
        let mut spec = MelSpectrogram::compute(&tone, SpectrogramConfig::default())
            .expect("a 1 s tone should produce a spectrogram");

        spec.normalize();

        // The first mel band should now be centred on zero.
        let mean = spec.data.row(0).mean().unwrap_or(1.0);
        assert!(mean.abs() < 0.1);
    }

    #[test]
    fn test_spectrogram_slice() {
        let tone = generate_sine_wave(1000.0, 2.0, SAMPLE_RATE);
        let spec = MelSpectrogram::compute(&tone, SpectrogramConfig::default())
            .expect("a 2 s tone should produce a spectrogram");

        let slice = spec.slice_frames(0, 10);

        assert_eq!((slice.nrows(), slice.ncols()), (spec.n_mels(), 10));
    }

    #[test]
    fn test_empty_input_error() {
        assert!(MelSpectrogram::compute(&[], SpectrogramConfig::default()).is_err());
    }

    #[test]
    fn test_batch_computation() {
        let segments: Vec<Vec<f32>> = [1000.0, 2000.0]
            .iter()
            .map(|&freq| generate_sine_wave(freq, 1.0, SAMPLE_RATE))
            .collect();

        let specs = SpectrogramBatch::compute_batch(&segments, &SpectrogramConfig::default())
            .expect("both segments should produce spectrograms");

        assert_eq!(specs.len(), 2);
    }
}