// axonml_audio/transforms.rs

//! Audio Transforms - Signal Processing and Augmentation
//!
//! Provides audio-specific transformations for preprocessing and data augmentation.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
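//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; the `use` paths are assumed and the
//! block is not compiled as a doc test):
//!
//! ```ignore
//! use axonml_data::Transform;
//! use axonml_tensor::Tensor;
//!
//! // One second of audio at 16 kHz, downsampled and converted to a mel spectrogram.
//! let audio = Tensor::from_vec(vec![0.0f32; 16_000], &[16_000]).unwrap();
//! let audio = Resample::new(16_000, 8_000).apply(&audio);
//! let mel = MelSpectrogram::new(8_000).apply(&audio);
//! assert_eq!(mel.shape()[0], 128); // default n_mels
//! ```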

use axonml_data::Transform;
use axonml_tensor::Tensor;
use rand::Rng;
use std::f32::consts::PI;

// =============================================================================
// Resample
// =============================================================================

/// Resamples audio to a target sample rate using linear interpolation.
pub struct Resample {
    orig_freq: usize,
    new_freq: usize,
}

impl Resample {
    /// Creates a new Resample transform.
    #[must_use] pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Self {
            orig_freq,
            new_freq,
        }
    }
}

impl Transform for Resample {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        if self.orig_freq == self.new_freq {
            return input.clone();
        }

        let data = input.to_vec();
        let orig_len = data.len();
        let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;

        if new_len == 0 {
            return Tensor::from_vec(vec![], &[0]).unwrap();
        }

        let mut resampled = Vec::with_capacity(new_len);
        let ratio = orig_len as f64 / new_len as f64;

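        // Linear interpolation: output sample i corresponds to fractional source
        // position i * ratio; the two neighbouring input samples are blended by `frac`.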
        for i in 0..new_len {
            let src_idx = i as f64 * ratio;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(orig_len - 1);
            let frac = (src_idx - idx0 as f64) as f32;

            let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
            resampled.push(value);
        }

        Tensor::from_vec(resampled, &[new_len]).unwrap()
    }
}

// =============================================================================
// MelSpectrogram
// =============================================================================

/// Computes a mel spectrogram from an audio waveform.
pub struct MelSpectrogram {
    sample_rate: usize,
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
}

impl MelSpectrogram {
    /// Creates a new `MelSpectrogram` transform with default parameters.
    #[must_use] pub fn new(sample_rate: usize) -> Self {
        Self {
            sample_rate,
            n_fft: 2048,
            hop_length: 512,
            n_mels: 128,
        }
    }

    /// Creates a `MelSpectrogram` with custom parameters.
    #[must_use] pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
        Self {
            sample_rate,
            n_fft,
            hop_length,
            n_mels,
        }
    }

    /// Converts a frequency in Hz to the mel scale.
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }

    /// Converts a mel value back to frequency in Hz.
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
    }
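
    // Both conversions use the common 2595 / 700 ("HTK-style") mel formula:
    //   mel(f) = 2595 * log10(1 + f / 700),  f(mel) = 700 * (10^(mel / 2595) - 1).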

    /// Creates mel filterbank.
    fn mel_filterbank(&self) -> Vec<Vec<f32>> {
        let fmax = self.sample_rate as f32 / 2.0;
        let mel_min = Self::hz_to_mel(0.0);
        let mel_max = Self::hz_to_mel(fmax);

        // Create mel points
        let mel_points: Vec<f32> = (0..=self.n_mels + 1)
            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32)
            .collect();

        // Convert to Hz
        let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();

        // Convert to FFT bins
        let bin_points: Vec<usize> = hz_points
            .iter()
            .map(|&hz| ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize)
            .collect();

        // Create filterbank
        let n_bins = self.n_fft / 2 + 1;
        let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];

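        // Each mel filter is a triangle over FFT bins: it rises from zero at
        // bin_points[m], peaks at 1.0 at bin_points[m + 1], and falls back to
        // zero at bin_points[m + 2].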
        for m in 0..self.n_mels {
            let left = bin_points[m];
            let center = bin_points[m + 1];
            let right = bin_points[m + 2];

            for k in left..center {
                if center > left && k < n_bins {
                    filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
                }
            }
            for k in center..right {
                if right > center && k < n_bins {
                    filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
                }
            }
        }

        filterbank
    }

    /// Builds a Hann window of the given size.
    fn hann_window(size: usize) -> Vec<f32> {
        (0..size)
            .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / (size - 1) as f32).cos()))
            .collect()
    }

    /// Naive O(n²) DFT magnitude spectrum (a real implementation would use an FFT);
    /// sufficient for demonstration purposes.
    fn dft(signal: &[f32]) -> Vec<f32> {
        let n = signal.len();
        let n_out = n / 2 + 1;
        let mut magnitude = vec![0.0f32; n_out];
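        // magnitude[k] = |Σ_t signal[t] · e^(-2πi·k·t / n)| for k = 0..=n/2.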

        for k in 0..n_out {
            let mut real = 0.0;
            let mut imag = 0.0;

            for (t, &x) in signal.iter().enumerate() {
                let angle = 2.0 * PI * k as f32 * t as f32 / n as f32;
                real += x * angle.cos();
                imag -= x * angle.sin();
            }

            magnitude[k] = (real * real + imag * imag).sqrt();
        }

        magnitude
    }
}

impl Transform for MelSpectrogram {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let window = Self::hann_window(self.n_fft);
        let filterbank = self.mel_filterbank();

        // Calculate number of frames
        let n_frames = if data.len() >= self.n_fft {
            (data.len() - self.n_fft) / self.hop_length + 1
        } else {
            0
        };
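        // e.g. 0.5 s of 16 kHz audio with n_fft = 512 and hop_length = 256 gives
        // (8000 - 512) / 256 + 1 = 30 frames.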

        if n_frames == 0 {
            return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
        }

        let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];

        for frame_idx in 0..n_frames {
            let start = frame_idx * self.hop_length;
            let end = (start + self.n_fft).min(data.len());

            // Extract and window frame
            let mut frame: Vec<f32> = data[start..end].to_vec();
            frame.resize(self.n_fft, 0.0);

            for (i, w) in window.iter().enumerate() {
                frame[i] *= w;
            }

            // Compute magnitude spectrum
            let spectrum = Self::dft(&frame);

            // Apply mel filterbank
            for (m, filter) in filterbank.iter().enumerate() {
                let mut mel_energy = 0.0;
                for (k, &mag) in spectrum.iter().enumerate() {
                    if k < filter.len() {
                        mel_energy += mag * mag * filter[k];
                    }
                }
                // Convert to log scale
                mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
            }
        }

        Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
    }
}

// =============================================================================
// MFCC
// =============================================================================

/// Computes Mel-frequency cepstral coefficients.
pub struct MFCC {
    mel_spec: MelSpectrogram,
    n_mfcc: usize,
}

impl MFCC {
    /// Creates a new MFCC transform.
    #[must_use] pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
        Self {
            mel_spec: MelSpectrogram::new(sample_rate),
            n_mfcc,
        }
    }

    /// DCT-II for MFCC computation.
    fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
        let n = input.len();
        let mut output = vec![0.0f32; n_out];

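        // DCT-II (unnormalized): output[k] = Σ_i input[i] · cos(π·k·(2i + 1) / (2n)).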
        for k in 0..n_out {
            let mut sum = 0.0;
            for (i, &x) in input.iter().enumerate() {
                sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
            }
            output[k] = sum;
        }

        output
    }
}

impl Transform for MFCC {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        // First compute mel spectrogram
        let mel = self.mel_spec.apply(input);
        let mel_data = mel.to_vec();
        let mel_shape = mel.shape();

        if mel_shape.len() != 2 {
            return input.clone();
        }

        let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
        let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];

        // Apply DCT to each frame
        for frame in 0..n_frames {
            let frame_data: Vec<f32> = (0..n_mels)
                .map(|m| mel_data[m * n_frames + frame])
                .collect();

            let coeffs = Self::dct(&frame_data, self.n_mfcc);

            for (k, &c) in coeffs.iter().enumerate() {
                mfcc[k * n_frames + frame] = c;
            }
        }

        Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
    }
}

// =============================================================================
// TimeStretch
// =============================================================================

/// Time-stretches audio by linear-interpolation resampling.
/// Note: unlike a phase-vocoder time stretch, this simple approach also shifts pitch.
pub struct TimeStretch {
    rate: f32,
}

impl TimeStretch {
    /// Creates a new `TimeStretch` transform.
    /// rate > 1.0 speeds up, rate < 1.0 slows down; values are clamped to [0.1, 10.0].
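    /// For example, `rate = 2.0` roughly halves the number of samples.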
    #[must_use] pub fn new(rate: f32) -> Self {
        Self {
            rate: rate.max(0.1).min(10.0),
        }
    }
}

impl Transform for TimeStretch {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let new_len = (data.len() as f32 / self.rate) as usize;

        if new_len == 0 {
            return Tensor::from_vec(vec![], &[0]).unwrap();
        }

        let mut stretched = Vec::with_capacity(new_len);

        for i in 0..new_len {
            let src_idx = i as f32 * self.rate;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(data.len() - 1);
            let frac = src_idx - idx0 as f32;

            if idx0 < data.len() {
                let value =
                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
                stretched.push(value);
            }
        }

        let len = stretched.len();
        Tensor::from_vec(stretched, &[len]).unwrap()
    }
}

// =============================================================================
// PitchShift
// =============================================================================

/// Shifts the pitch of audio.
pub struct PitchShift {
    semitones: f32,
}

impl PitchShift {
    /// Creates a new `PitchShift` transform.
    /// Positive semitones shift up, negative shift down.
    #[must_use] pub fn new(semitones: f32) -> Self {
        Self { semitones }
    }
}

impl Transform for PitchShift {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        // Simplified pitch shift: resample by the pitch ratio, then stretch back to the
        // original length. A real implementation would use a phase vocoder; here the two
        // interpolation passes largely cancel, so the effect is only approximate.
        let rate = 2.0_f32.powf(self.semitones / 12.0);
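        // e.g. +12 semitones doubles `rate` (one octave up), -12 halves it.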
        let data = input.to_vec();
        let orig_len = data.len();

        // Resample to change pitch
        let resampled_len = (orig_len as f32 / rate) as usize;
        if resampled_len == 0 {
            return input.clone();
        }

        let mut resampled = Vec::with_capacity(resampled_len);
        for i in 0..resampled_len {
            let src_idx = i as f32 * rate;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(orig_len - 1);
            let frac = src_idx - idx0 as f32;

            if idx0 < orig_len {
                let value =
                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
                resampled.push(value);
            }
        }

        // Time stretch back to original length
        let mut result = Vec::with_capacity(orig_len);
        for i in 0..orig_len {
            let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
            let frac = src_idx - idx0 as f32;

            if idx0 < resampled.len() {
                let value = resampled[idx0] * (1.0 - frac)
                    + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
                result.push(value);
            } else {
                result.push(0.0);
            }
        }

        Tensor::from_vec(result, &[orig_len]).unwrap()
    }
}

// =============================================================================
// AddNoise
// =============================================================================

/// Adds random Gaussian noise to audio at a given signal-to-noise ratio.
pub struct AddNoise {
    snr_db: f32,
}

impl AddNoise {
    /// Creates a new `AddNoise` transform with the specified SNR in dB.
    #[must_use] pub fn new(snr_db: f32) -> Self {
        Self { snr_db }
    }
}

impl Transform for AddNoise {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let mut rng = rand::thread_rng();

        // Calculate signal power
        let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;

        // Calculate noise power from SNR
        let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
        let noise_std = noise_power.sqrt();
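        // e.g. snr_db = 20.0 keeps the noise power at 1% of the signal power.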

        // Add Gaussian noise
        let noisy: Vec<f32> = data
            .iter()
            .map(|&x| {
                // Box-Muller transform for Gaussian noise; use 1.0 - gen() so the
                // argument to ln() lies in (0, 1] and can never be zero.
                let u1: f32 = 1.0 - rng.gen::<f32>();
                let u2: f32 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
                x + z * noise_std
            })
            .collect();

        Tensor::from_vec(noisy, input.shape()).unwrap()
    }
}

// =============================================================================
// Normalize Audio
// =============================================================================

/// Normalizes audio to have a maximum absolute amplitude of 1.0.
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates a new `NormalizeAudio` transform.
    #[must_use] pub fn new() -> Self {
        Self
    }
}

impl Default for NormalizeAudio {
    fn default() -> Self {
        Self::new()
    }
}

impl Transform for NormalizeAudio {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);

        if max_val < 1e-10 {
            return input.clone();
        }

        let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
        Tensor::from_vec(normalized, input.shape()).unwrap()
    }
}

// =============================================================================
// Trim Silence
// =============================================================================

/// Trims silence from the beginning and end of audio.
pub struct TrimSilence {
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a `TrimSilence` transform with the specified threshold in dB.
    #[must_use] pub fn new(threshold_db: f32) -> Self {
        Self { threshold_db }
    }

    /// Creates with the default -60 dB threshold.
    #[must_use] pub fn default_threshold() -> Self {
        Self::new(-60.0)
    }
}

impl Transform for TrimSilence {
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
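        // Converts the dB threshold to a linear amplitude via 10^(dB / 20),
        // e.g. -60 dB -> 0.001 and -20 dB -> 0.1.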

        // Find first non-silent sample
        let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);

        // Find last non-silent sample
        let end = data
            .iter()
            .rposition(|&x| x.abs() > threshold)
            .map_or(data.len(), |i| i + 1);

        if start >= end {
            return Tensor::from_vec(vec![], &[0]).unwrap();
        }

        let trimmed = data[start..end].to_vec();
        let len = trimmed.len();
        Tensor::from_vec(trimmed, &[len]).unwrap()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let n_samples = (sample_rate as f32 * duration) as usize;
        let data: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();
        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 8000);

        let resampled = resample.apply(&audio);

        // Should be half the length
        assert_eq!(resampled.shape()[0], audio.shape()[0] / 2);
    }

    #[test]
    fn test_resample_same_rate() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 16000);

        let result = resample.apply(&audio);
        assert_eq!(result.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);

        let spec = mel.apply(&audio);

        assert_eq!(spec.shape()[0], 40); // n_mels
        assert!(spec.shape()[1] > 0); // n_frames
    }

    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mfcc = MFCC::new(16000, 13);

        let coeffs = mfcc.apply(&audio);

        assert_eq!(coeffs.shape()[0], 13); // n_mfcc
    }

    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        // Speed up 2x
        let stretch = TimeStretch::new(2.0);
        let stretched = stretch.apply(&audio);

        assert!(stretched.shape()[0] < orig_len);
    }

    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let shift = PitchShift::new(2.0); // Shift up 2 semitones
        let shifted = shift.apply(&audio);

        // Length should remain the same
        assert_eq!(shifted.shape()[0], orig_len);
    }

    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let add_noise = AddNoise::new(20.0); // 20 dB SNR

        let noisy = add_noise.apply(&audio);

        assert_eq!(noisy.shape(), audio.shape());
        // Values should be different (noise added)
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_normalize_audio() {
        let data = vec![0.1, -0.5, 0.3, -0.2];
        let audio = Tensor::from_vec(data, &[4]).unwrap();

        let normalize = NormalizeAudio::new();
        let normalized = normalize.apply(&audio);

        let max_val = normalized
            .to_vec()
            .iter()
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);
        assert!((max_val - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_trim_silence() {
        let data = vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0];
        let audio = Tensor::from_vec(data, &[6]).unwrap();

        let trim = TrimSilence::new(-20.0);
        let trimmed = trim.apply(&audio);

        assert_eq!(trimmed.shape()[0], 2); // Only [0.5, 0.3]
    }

    #[test]
    fn test_hz_to_mel_conversion() {
        let hz = 1000.0;
        let mel = MelSpectrogram::hz_to_mel(hz);
        let back = MelSpectrogram::mel_to_hz(mel);

        assert!((hz - back).abs() < 0.1);
    }
}