//! Audio Transforms - Signal Processing and Augmentation
//!
//! # File
//! `crates/axonml-audio/src/transforms.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_data::Transform;
use axonml_tensor::Tensor;
use rand::Rng;
use std::f32::consts::PI;

// =============================================================================
// Resample
// =============================================================================

/// Linearly-interpolating sample-rate converter.
///
/// Converts audio sampled at `orig_freq` Hz into audio sampled at
/// `new_freq` Hz.
pub struct Resample {
    /// Sample rate of the incoming audio, in Hz.
    orig_freq: usize,
    /// Sample rate to convert to, in Hz.
    new_freq: usize,
}

impl Resample {
    /// Creates a new Resample transform.
    #[must_use]
    pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Self { orig_freq, new_freq }
    }
}

43impl Transform for Resample {
44    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
45        if self.orig_freq == self.new_freq {
46            return input.clone();
47        }
48
49        let data = input.to_vec();
50        let orig_len = data.len();
51        let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;
52
53        if new_len == 0 {
54            return Tensor::from_vec(vec![], &[0]).unwrap();
55        }
56
57        let mut resampled = Vec::with_capacity(new_len);
58        let ratio = orig_len as f64 / new_len as f64;
59
60        for i in 0..new_len {
61            let src_idx = i as f64 * ratio;
62            let idx0 = src_idx.floor() as usize;
63            let idx1 = (idx0 + 1).min(orig_len - 1);
64            let frac = (src_idx - idx0 as f64) as f32;
65
66            let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
67            resampled.push(value);
68        }
69
70        Tensor::from_vec(resampled, &[new_len]).unwrap()
71    }
72}
73
// =============================================================================
// MelSpectrogram
// =============================================================================

/// Computes a mel spectrogram from audio waveform.
///
/// The waveform is cut into overlapping frames of `n_fft` samples spaced
/// `hop_length` apart; each frame is Hann-windowed, transformed to a
/// magnitude spectrum, and its power is pooled through `n_mels` triangular
/// mel-scale filters.
pub struct MelSpectrogram {
    /// Input sample rate in Hz.
    sample_rate: usize,
    /// Frame length (and DFT size) in samples.
    n_fft: usize,
    /// Hop between consecutive frames in samples.
    hop_length: usize,
    /// Number of mel filterbank channels.
    n_mels: usize,
}

impl MelSpectrogram {
    /// Creates a new `MelSpectrogram` transform with default parameters
    /// (`n_fft = 2048`, `hop_length = 512`, `n_mels = 128`).
    #[must_use]
    pub fn new(sample_rate: usize) -> Self {
        // Delegate so the defaults live in exactly one place.
        Self::with_params(sample_rate, 2048, 512, 128)
    }

    /// Creates a `MelSpectrogram` with custom parameters.
    #[must_use]
    pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
        Self {
            sample_rate,
            n_fft,
            hop_length,
            n_mels,
        }
    }

    /// Converts frequency in Hz to the mel scale (HTK-style formula).
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }

    /// Converts mel back to frequency in Hz (inverse of [`Self::hz_to_mel`]).
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
    }

    /// Builds the triangular mel filterbank: `n_mels` rows of weights over
    /// the `n_fft / 2 + 1` linear-frequency DFT bins.
    fn mel_filterbank(&self) -> Vec<Vec<f32>> {
        let fmax = self.sample_rate as f32 / 2.0; // Nyquist frequency
        let mel_min = Self::hz_to_mel(0.0);
        let mel_max = Self::hz_to_mel(fmax);

        // n_mels + 2 points equally spaced on the mel scale (edges included).
        let mel_points: Vec<f32> = (0..=self.n_mels + 1)
            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32)
            .collect();

        // Convert back to Hz.
        let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();

        // Map each edge frequency onto a DFT bin index.
        let bin_points: Vec<usize> = hz_points
            .iter()
            .map(|&hz| ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize)
            .collect();

        // Triangular filters: ramp up from `left` to `center`, then back
        // down to `right`. Empty ranges (coincident edges) produce no
        // iterations, so the divisions below are always by a positive value.
        let n_bins = self.n_fft / 2 + 1;
        let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];

        for m in 0..self.n_mels {
            let left = bin_points[m];
            let center = bin_points[m + 1];
            let right = bin_points[m + 2];

            for k in left..center {
                if k < n_bins {
                    filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
                }
            }
            for k in center..right {
                if k < n_bins {
                    filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
                }
            }
        }

        filterbank
    }

    /// Symmetric Hann window of the given length.
    ///
    /// Degenerate sizes are handled explicitly: size 0 yields an empty
    /// vector and size 1 yields `[1.0]`. (The closed-form expression divides
    /// by `size - 1`, which produced NaN for `size == 1`.)
    fn hann_window(size: usize) -> Vec<f32> {
        if size <= 1 {
            return vec![1.0; size];
        }
        (0..size)
            .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / (size - 1) as f32).cos()))
            .collect()
    }

    /// Magnitude spectrum of `signal` via a naive O(n^2) DFT over the
    /// non-negative frequency bins (fine for demonstration; a real
    /// implementation would use an FFT).
    fn dft(signal: &[f32]) -> Vec<f32> {
        let n = signal.len();
        let n_out = n / 2 + 1;
        let mut magnitude = vec![0.0f32; n_out];

        for k in 0..n_out {
            let mut real = 0.0;
            let mut imag = 0.0;

            for (t, &x) in signal.iter().enumerate() {
                let angle = 2.0 * PI * k as f32 * t as f32 / n as f32;
                real += x * angle.cos();
                imag -= x * angle.sin();
            }

            magnitude[k] = (real * real + imag * imag).sqrt();
        }

        magnitude
    }
}

193impl Transform for MelSpectrogram {
194    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
195        let data = input.to_vec();
196        let window = Self::hann_window(self.n_fft);
197        let filterbank = self.mel_filterbank();
198
199        // Calculate number of frames
200        let n_frames = if data.len() >= self.n_fft {
201            (data.len() - self.n_fft) / self.hop_length + 1
202        } else {
203            0
204        };
205
206        if n_frames == 0 {
207            return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
208        }
209
210        let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];
211
212        for frame_idx in 0..n_frames {
213            let start = frame_idx * self.hop_length;
214            let end = (start + self.n_fft).min(data.len());
215
216            // Extract and window frame
217            let mut frame: Vec<f32> = data[start..end].to_vec();
218            frame.resize(self.n_fft, 0.0);
219
220            for (i, w) in window.iter().enumerate() {
221                frame[i] *= w;
222            }
223
224            // Compute magnitude spectrum
225            let spectrum = Self::dft(&frame);
226
227            // Apply mel filterbank
228            for (m, filter) in filterbank.iter().enumerate() {
229                let mut mel_energy = 0.0;
230                for (k, &mag) in spectrum.iter().enumerate() {
231                    if k < filter.len() {
232                        mel_energy += mag * mag * filter[k];
233                    }
234                }
235                // Convert to log scale
236                mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
237            }
238        }
239
240        Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
241    }
242}
243
244// =============================================================================
245// MFCC
246// =============================================================================
247
248/// Computes Mel-frequency cepstral coefficients.
249pub struct MFCC {
250    mel_spec: MelSpectrogram,
251    n_mfcc: usize,
252}
253
254impl MFCC {
255    /// Creates a new MFCC transform.
256    #[must_use]
257    pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
258        Self {
259            mel_spec: MelSpectrogram::new(sample_rate),
260            n_mfcc,
261        }
262    }
263
264    /// DCT-II for MFCC computation.
265    fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
266        let n = input.len();
267        let mut output = vec![0.0f32; n_out];
268
269        for k in 0..n_out {
270            let mut sum = 0.0;
271            for (i, &x) in input.iter().enumerate() {
272                sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
273            }
274            output[k] = sum;
275        }
276
277        output
278    }
279}
280
281impl Transform for MFCC {
282    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
283        // First compute mel spectrogram
284        let mel = self.mel_spec.apply(input);
285        let mel_data = mel.to_vec();
286        let mel_shape = mel.shape();
287
288        if mel_shape.len() != 2 {
289            return input.clone();
290        }
291
292        let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
293        let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];
294
295        // Apply DCT to each frame
296        for frame in 0..n_frames {
297            let frame_data: Vec<f32> = (0..n_mels)
298                .map(|m| mel_data[m * n_frames + frame])
299                .collect();
300
301            let coeffs = Self::dct(&frame_data, self.n_mfcc);
302
303            for (k, &c) in coeffs.iter().enumerate() {
304                mfcc[k * n_frames + frame] = c;
305            }
306        }
307
308        Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
309    }
310}
311
// =============================================================================
// TimeStretch
// =============================================================================

/// Time stretches audio without changing pitch.
pub struct TimeStretch {
    /// Playback-speed factor, limited to `[0.1, 10.0]`.
    rate: f32,
}

impl TimeStretch {
    /// Creates a new `TimeStretch` transform.
    /// rate > 1.0 speeds up, rate < 1.0 slows down.
    ///
    /// The rate is bounded to the range `[0.1, 10.0]`.
    #[must_use]
    pub fn new(rate: f32) -> Self {
        let bounded = rate.max(0.1).min(10.0);
        Self { rate: bounded }
    }
}

332impl Transform for TimeStretch {
333    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
334        let data = input.to_vec();
335        let new_len = (data.len() as f32 / self.rate) as usize;
336
337        if new_len == 0 {
338            return Tensor::from_vec(vec![], &[0]).unwrap();
339        }
340
341        let mut stretched = Vec::with_capacity(new_len);
342
343        for i in 0..new_len {
344            let src_idx = i as f32 * self.rate;
345            let idx0 = src_idx.floor() as usize;
346            let idx1 = (idx0 + 1).min(data.len() - 1);
347            let frac = src_idx - idx0 as f32;
348
349            if idx0 < data.len() {
350                let value =
351                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
352                stretched.push(value);
353            }
354        }
355
356        let len = stretched.len();
357        Tensor::from_vec(stretched, &[len]).unwrap()
358    }
359}
360
// =============================================================================
// PitchShift
// =============================================================================

/// Shifts the pitch of audio.
pub struct PitchShift {
    /// Pitch offset in semitones; positive shifts up, negative shifts down.
    semitones: f32,
}

impl PitchShift {
    /// Creates a new `PitchShift` transform.
    /// Positive semitones shift up, negative shift down.
    #[must_use]
    pub fn new(semitones: f32) -> Self {
        Self { semitones }
    }
}

impl Transform for PitchShift {
    /// Shifts pitch by `semitones` while keeping the output the same length
    /// as the input.
    ///
    /// Two stages: (1) resample by `2^(semitones / 12)`, which changes both
    /// pitch and duration; (2) linearly interpolate back to the original
    /// length, restoring duration but keeping the pitch change.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        // Simplified pitch shift using resampling
        // Real implementation would use phase vocoder
        let rate = 2.0_f32.powf(self.semitones / 12.0);
        let data = input.to_vec();
        let orig_len = data.len();

        // Resample to change pitch
        let resampled_len = (orig_len as f32 / rate) as usize;
        if resampled_len == 0 {
            // Empty (or vanishingly short) input: nothing to shift.
            return input.clone();
        }

        let mut resampled = Vec::with_capacity(resampled_len);
        for i in 0..resampled_len {
            let src_idx = i as f32 * rate;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(orig_len - 1);
            let frac = src_idx - idx0 as f32;

            if idx0 < orig_len {
                // Linear blend of the two neighbouring source samples.
                let value =
                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
                resampled.push(value);
            }
        }

        // Time stretch back to original length
        let mut result = Vec::with_capacity(orig_len);
        for i in 0..orig_len {
            let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
            let frac = src_idx - idx0 as f32;

            if idx0 < resampled.len() {
                let value = resampled[idx0] * (1.0 - frac)
                    + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
                result.push(value);
            } else {
                // Past the end of the intermediate buffer: pad with silence
                // so the output length still matches the input.
                result.push(0.0);
            }
        }

        Tensor::from_vec(result, &[orig_len]).unwrap()
    }
}

// =============================================================================
// AddNoise
// =============================================================================

/// Adds random noise to audio.
pub struct AddNoise {
    /// Target signal-to-noise ratio in decibels.
    snr_db: f32,
}

impl AddNoise {
    /// Creates a new `AddNoise` transform with specified SNR in dB.
    #[must_use]
    pub fn new(snr_db: f32) -> Self {
        Self { snr_db }
    }
}

445impl Transform for AddNoise {
446    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
447        let data = input.to_vec();
448        let mut rng = rand::thread_rng();
449
450        // Calculate signal power
451        let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;
452
453        // Calculate noise power from SNR
454        let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
455        let noise_std = noise_power.sqrt();
456
457        // Add Gaussian noise
458        let noisy: Vec<f32> = data
459            .iter()
460            .map(|&x| {
461                // Box-Muller transform for Gaussian noise
462                let u1: f32 = rng.r#gen();
463                let u2: f32 = rng.r#gen();
464                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
465                x + z * noise_std
466            })
467            .collect();
468
469        Tensor::from_vec(noisy, input.shape()).unwrap()
470    }
471}
472
// =============================================================================
// Normalize Audio
// =============================================================================

/// Normalizes audio to have maximum amplitude of 1.0.
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates a new `NormalizeAudio` transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for NormalizeAudio {
    /// Equivalent to [`NormalizeAudio::new`].
    fn default() -> Self {
        Self::new()
    }
}

494impl Transform for NormalizeAudio {
495    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
496        let data = input.to_vec();
497        let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
498
499        if max_val < 1e-10 {
500            return input.clone();
501        }
502
503        let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
504        Tensor::from_vec(normalized, input.shape()).unwrap()
505    }
506}
507
// =============================================================================
// Trim Silence
// =============================================================================

/// Trims silence from the beginning and end of audio.
pub struct TrimSilence {
    /// Amplitude threshold in dB below which samples count as silence.
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a `TrimSilence` transform with specified threshold in dB.
    #[must_use]
    pub fn new(threshold_db: f32) -> Self {
        Self { threshold_db }
    }

    /// Creates with default -60dB threshold.
    #[must_use]
    pub fn default_threshold() -> Self {
        Self::new(-60.0)
    }
}

531impl Transform for TrimSilence {
532    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
533        let data = input.to_vec();
534        let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
535
536        // Find first non-silent sample
537        let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);
538
539        // Find last non-silent sample
540        let end = data
541            .iter()
542            .rposition(|&x| x.abs() > threshold)
543            .map_or(data.len(), |i| i + 1);
544
545        if start >= end {
546            return Tensor::from_vec(vec![], &[0]).unwrap();
547        }
548
549        let trimmed = data[start..end].to_vec();
550        let len = trimmed.len();
551        Tensor::from_vec(trimmed, &[len]).unwrap()
552    }
553}
554
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D sine tone of `freq` Hz sampled at `sample_rate` for
    /// `duration` seconds.
    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let n_samples = (sample_rate as f32 * duration) as usize;
        let data: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();
        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 8000);

        let resampled = resample.apply(&audio);

        // Should be half the length
        assert_eq!(resampled.shape()[0], audio.shape()[0] / 2);
    }

    #[test]
    fn test_resample_same_rate() {
        // Equal source and target rates must be a no-op.
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 16000);

        let result = resample.apply(&audio);
        assert_eq!(result.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);

        let spec = mel.apply(&audio);

        assert_eq!(spec.shape()[0], 40); // n_mels
        assert!(spec.shape()[1] > 0); // n_frames
    }

    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mfcc = MFCC::new(16000, 13);

        let coeffs = mfcc.apply(&audio);

        assert_eq!(coeffs.shape()[0], 13); // n_mfcc
    }

    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        // Speed up 2x
        let stretch = TimeStretch::new(2.0);
        let stretched = stretch.apply(&audio);

        assert!(stretched.shape()[0] < orig_len);
    }

    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let shift = PitchShift::new(2.0); // Shift up 2 semitones
        let shifted = shift.apply(&audio);

        // Length should remain the same
        assert_eq!(shifted.shape()[0], orig_len);
    }

    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let add_noise = AddNoise::new(20.0); // 20dB SNR

        let noisy = add_noise.apply(&audio);

        assert_eq!(noisy.shape(), audio.shape());
        // Values should be different (noise added)
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_normalize_audio() {
        let data = vec![0.1, -0.5, 0.3, -0.2];
        let audio = Tensor::from_vec(data, &[4]).unwrap();

        let normalize = NormalizeAudio::new();
        let normalized = normalize.apply(&audio);

        // Peak absolute amplitude should now be 1.0.
        let max_val = normalized
            .to_vec()
            .iter()
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);
        assert!((max_val - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_trim_silence() {
        // -20 dB threshold is 0.1 linear amplitude.
        let data = vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0];
        let audio = Tensor::from_vec(data, &[6]).unwrap();

        let trim = TrimSilence::new(-20.0);
        let trimmed = trim.apply(&audio);

        assert_eq!(trimmed.shape()[0], 2); // Only [0.5, 0.3]
    }

    #[test]
    fn test_hz_to_mel_conversion() {
        // mel_to_hz must invert hz_to_mel.
        let hz = 1000.0;
        let mel = MelSpectrogram::hz_to_mel(hz);
        let back = MelSpectrogram::mel_to_hz(mel);

        assert!((hz - back).abs() < 0.1);
    }
}