//! Audio Transforms - Signal Processing and Augmentation
//!
//! # File
//! `crates/axonml-audio/src/transforms.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_data::Transform;
use axonml_tensor::Tensor;
// External crates: `rand` for noise generation, `rustfft` for O(n log n) FFT.
use rand::Rng;
use rustfft::{FftPlanner, num_complex::Complex};
use std::f32::consts::PI;

// =============================================================================
// Resample
// =============================================================================

/// Resamples audio to a target sample rate using linear interpolation.
pub struct Resample {
    orig_freq: usize,
    new_freq: usize,
}

impl Resample {
    /// Builds a resampler that converts audio from `orig_freq` Hz to
    /// `new_freq` Hz.
    #[must_use]
    pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Resample { orig_freq, new_freq }
    }
}

45impl Transform for Resample {
46    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
47        if self.orig_freq == self.new_freq {
48            return input.clone();
49        }
50
51        let data = input.to_vec();
52        let orig_len = data.len();
53        let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;
54
55        if new_len == 0 {
56            return Tensor::from_vec(vec![], &[0]).unwrap();
57        }
58
59        let mut resampled = Vec::with_capacity(new_len);
60        let ratio = orig_len as f64 / new_len as f64;
61
62        for i in 0..new_len {
63            let src_idx = i as f64 * ratio;
64            let idx0 = src_idx.floor() as usize;
65            let idx1 = (idx0 + 1).min(orig_len - 1);
66            let frac = (src_idx - idx0 as f64) as f32;
67
68            let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
69            resampled.push(value);
70        }
71
72        Tensor::from_vec(resampled, &[new_len]).unwrap()
73    }
74}
75
76// =============================================================================
77// MelSpectrogram
78// =============================================================================
79
80/// Computes a mel spectrogram from audio waveform.
81pub struct MelSpectrogram {
82    sample_rate: usize,
83    n_fft: usize,
84    hop_length: usize,
85    n_mels: usize,
86}
87
88impl MelSpectrogram {
89    /// Creates a new `MelSpectrogram` transform with default parameters.
90    #[must_use]
91    pub fn new(sample_rate: usize) -> Self {
92        Self {
93            sample_rate,
94            n_fft: 2048,
95            hop_length: 512,
96            n_mels: 128,
97        }
98    }
99
100    /// Creates a `MelSpectrogram` with custom parameters.
101    #[must_use]
102    pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
103        Self {
104            sample_rate,
105            n_fft,
106            hop_length,
107            n_mels,
108        }
109    }
110
111    /// Converts frequency to mel scale.
112    fn hz_to_mel(hz: f32) -> f32 {
113        2595.0 * (1.0 + hz / 700.0).log10()
114    }
115
116    /// Converts mel to frequency.
117    fn mel_to_hz(mel: f32) -> f32 {
118        700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
119    }
120
121    /// Creates mel filterbank.
122    fn mel_filterbank(&self) -> Vec<Vec<f32>> {
123        let fmax = self.sample_rate as f32 / 2.0;
124        let mel_min = Self::hz_to_mel(0.0);
125        let mel_max = Self::hz_to_mel(fmax);
126
127        // Create mel points
128        let mel_points: Vec<f32> = (0..=self.n_mels + 1)
129            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32)
130            .collect();
131
132        // Convert to Hz
133        let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
134
135        // Convert to FFT bins
136        let bin_points: Vec<usize> = hz_points
137            .iter()
138            .map(|&hz| ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize)
139            .collect();
140
141        // Create filterbank
142        let n_bins = self.n_fft / 2 + 1;
143        let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];
144
145        for m in 0..self.n_mels {
146            let left = bin_points[m];
147            let center = bin_points[m + 1];
148            let right = bin_points[m + 2];
149
150            for k in left..center {
151                if center > left && k < n_bins {
152                    filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
153                }
154            }
155            for k in center..right {
156                if right > center && k < n_bins {
157                    filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
158                }
159            }
160        }
161
162        filterbank
163    }
164
165    /// Applies Hann window.
166    fn hann_window(size: usize) -> Vec<f32> {
167        (0..size)
168            .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / (size - 1) as f32).cos()))
169            .collect()
170    }
171
172    /// Computes magnitude spectrum using FFT (O(n log n) via rustfft).
173    fn fft_magnitude(signal: &[f32]) -> Vec<f32> {
174        let n = signal.len();
175        let n_out = n / 2 + 1;
176
177        // Convert real signal to complex
178        let mut buffer: Vec<Complex<f32>> = signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
179
180        // Run FFT in-place
181        let mut planner = FftPlanner::new();
182        let fft = planner.plan_fft_forward(n);
183        fft.process(&mut buffer);
184
185        // Extract magnitudes of the first n/2+1 bins (positive frequencies)
186        buffer[..n_out].iter().map(|c| c.norm()).collect()
187    }
188}
189
impl Transform for MelSpectrogram {
    /// Converts a 1-D waveform into a log-mel spectrogram of shape
    /// `[n_mels, n_frames]` containing natural-log filterbank energies.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        let data = input.to_vec();
        let window = Self::hann_window(self.n_fft);
        let filterbank = self.mel_filterbank();

        // Number of complete windows that fit; the tail is not padded.
        let n_frames = if data.len() >= self.n_fft {
            (data.len() - self.n_fft) / self.hop_length + 1
        } else {
            0
        };

        // Input shorter than one window: return a single all-zero frame.
        if n_frames == 0 {
            return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
        }

        // Row-major [n_mels, n_frames] buffer.
        let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];

        for frame_idx in 0..n_frames {
            let start = frame_idx * self.hop_length;
            let end = (start + self.n_fft).min(data.len());

            // Extract the frame; `resize` zero-pads if `end` was clamped
            // (defensive — the `n_frames` bound above guarantees a full window).
            let mut frame: Vec<f32> = data[start..end].to_vec();
            frame.resize(self.n_fft, 0.0);

            // Apply the Hann window to reduce spectral leakage.
            for (i, w) in window.iter().enumerate() {
                frame[i] *= w;
            }

            // Compute magnitude spectrum via FFT
            let spectrum = Self::fft_magnitude(&frame);

            // Accumulate power (|X|^2) through each triangular mel filter.
            for (m, filter) in filterbank.iter().enumerate() {
                let mut mel_energy = 0.0;
                for (k, &mag) in spectrum.iter().enumerate() {
                    if k < filter.len() {
                        mel_energy += mag * mag * filter[k];
                    }
                }
                // Natural log with a small floor to avoid ln(0).
                mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
            }
        }

        Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
    }
}

241// =============================================================================
242// MFCC
243// =============================================================================
244
245/// Computes Mel-frequency cepstral coefficients.
246pub struct MFCC {
247    mel_spec: MelSpectrogram,
248    n_mfcc: usize,
249}
250
251impl MFCC {
252    /// Creates a new MFCC transform.
253    #[must_use]
254    pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
255        Self {
256            mel_spec: MelSpectrogram::new(sample_rate),
257            n_mfcc,
258        }
259    }
260
261    /// DCT-II for MFCC computation.
262    fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
263        let n = input.len();
264        let mut output = vec![0.0f32; n_out];
265
266        for k in 0..n_out {
267            let mut sum = 0.0;
268            for (i, &x) in input.iter().enumerate() {
269                sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
270            }
271            output[k] = sum;
272        }
273
274        output
275    }
276}
277
278impl Transform for MFCC {
279    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
280        // First compute mel spectrogram
281        let mel = self.mel_spec.apply(input);
282        let mel_data = mel.to_vec();
283        let mel_shape = mel.shape();
284
285        if mel_shape.len() != 2 {
286            return input.clone();
287        }
288
289        let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
290        let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];
291
292        // Apply DCT to each frame
293        for frame in 0..n_frames {
294            let frame_data: Vec<f32> = (0..n_mels)
295                .map(|m| mel_data[m * n_frames + frame])
296                .collect();
297
298            let coeffs = Self::dct(&frame_data, self.n_mfcc);
299
300            for (k, &c) in coeffs.iter().enumerate() {
301                mfcc[k * n_frames + frame] = c;
302            }
303        }
304
305        Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
306    }
307}
308
// =============================================================================
// TimeStretch
// =============================================================================

/// Time stretches audio without changing pitch.
pub struct TimeStretch {
    rate: f32, // playback-speed factor, clamped to [0.1, 10.0]
}

impl TimeStretch {
    /// Creates a new `TimeStretch` transform.
    /// `rate > 1.0` speeds up, `rate < 1.0` slows down.
    /// The rate is clamped to `[0.1, 10.0]` to avoid degenerate output lengths.
    #[must_use]
    pub fn new(rate: f32) -> Self {
        Self {
            // Idiomatic clamp (was `.max(0.1).min(10.0)`; clippy: manual_clamp).
            rate: rate.clamp(0.1, 10.0),
        }
    }
}

329impl Transform for TimeStretch {
330    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
331        let data = input.to_vec();
332        let new_len = (data.len() as f32 / self.rate) as usize;
333
334        if new_len == 0 {
335            return Tensor::from_vec(vec![], &[0]).unwrap();
336        }
337
338        let mut stretched = Vec::with_capacity(new_len);
339
340        for i in 0..new_len {
341            let src_idx = i as f32 * self.rate;
342            let idx0 = src_idx.floor() as usize;
343            let idx1 = (idx0 + 1).min(data.len() - 1);
344            let frac = src_idx - idx0 as f32;
345
346            if idx0 < data.len() {
347                let value =
348                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
349                stretched.push(value);
350            }
351        }
352
353        let len = stretched.len();
354        Tensor::from_vec(stretched, &[len]).unwrap()
355    }
356}
357
// =============================================================================
// PitchShift
// =============================================================================

/// Shifts the pitch of audio.
pub struct PitchShift {
    semitones: f32,
}

impl PitchShift {
    /// Creates a new `PitchShift` transform.
    /// Positive `semitones` shift the pitch up, negative shift it down.
    #[must_use]
    pub fn new(semitones: f32) -> Self {
        PitchShift { semitones }
    }
}

impl Transform for PitchShift {
    /// Shifts pitch by resampling (which changes both pitch and duration),
    /// then linearly stretching back so the output length equals the input.
    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
        // Simplified pitch shift using resampling
        // Real implementation would use phase vocoder
        // 12 semitones = one octave = 2x frequency.
        let rate = 2.0_f32.powf(self.semitones / 12.0);
        let data = input.to_vec();
        let orig_len = data.len();

        // Stage 1: resample by `rate` to change pitch (length becomes len/rate).
        let resampled_len = (orig_len as f32 / rate) as usize;
        if resampled_len == 0 {
            return input.clone();
        }

        let mut resampled = Vec::with_capacity(resampled_len);
        for i in 0..resampled_len {
            // Linear interpolation between the two nearest source samples.
            let src_idx = i as f32 * rate;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(orig_len - 1);
            let frac = src_idx - idx0 as f32;

            if idx0 < orig_len {
                let value =
                    data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
                resampled.push(value);
            }
        }

        // Stage 2: linearly stretch back to the original length so callers
        // see a same-length tensor with only the pitch changed.
        let mut result = Vec::with_capacity(orig_len);
        for i in 0..orig_len {
            let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
            let idx0 = src_idx.floor() as usize;
            let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
            let frac = src_idx - idx0 as f32;

            if idx0 < resampled.len() {
                let value = resampled[idx0] * (1.0 - frac)
                    + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
                result.push(value);
            } else {
                // Past the end of the intermediate buffer: pad with silence.
                result.push(0.0);
            }
        }

        Tensor::from_vec(result, &[orig_len]).unwrap()
    }
}

// =============================================================================
// AddNoise
// =============================================================================

/// Adds random noise to audio.
pub struct AddNoise {
    snr_db: f32,
}

impl AddNoise {
    /// Creates an `AddNoise` transform targeting the given signal-to-noise
    /// ratio in decibels.
    #[must_use]
    pub fn new(snr_db: f32) -> Self {
        AddNoise { snr_db }
    }
}

442impl Transform for AddNoise {
443    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
444        let data = input.to_vec();
445        let mut rng = rand::thread_rng();
446
447        // Calculate signal power
448        let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;
449
450        // Calculate noise power from SNR
451        let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
452        let noise_std = noise_power.sqrt();
453
454        // Add Gaussian noise
455        let noisy: Vec<f32> = data
456            .iter()
457            .map(|&x| {
458                // Box-Muller transform for Gaussian noise
459                let u1: f32 = rng.r#gen();
460                let u2: f32 = rng.r#gen();
461                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
462                x + z * noise_std
463            })
464            .collect();
465
466        Tensor::from_vec(noisy, input.shape()).unwrap()
467    }
468}
469
// =============================================================================
// Normalize Audio
// =============================================================================

/// Normalizes audio to have maximum amplitude of 1.0.
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates a new `NormalizeAudio` transform.
    #[must_use]
    pub fn new() -> Self {
        NormalizeAudio
    }
}

impl Default for NormalizeAudio {
    fn default() -> Self {
        NormalizeAudio
    }
}

491impl Transform for NormalizeAudio {
492    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
493        let data = input.to_vec();
494        let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
495
496        if max_val < 1e-10 {
497            return input.clone();
498        }
499
500        let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
501        Tensor::from_vec(normalized, input.shape()).unwrap()
502    }
503}
504
// =============================================================================
// Trim Silence
// =============================================================================

/// Trims silence from the beginning and end of audio.
pub struct TrimSilence {
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a `TrimSilence` transform; samples whose amplitude falls at or
    /// below `threshold_db` (relative to full scale) count as silence.
    #[must_use]
    pub fn new(threshold_db: f32) -> Self {
        TrimSilence { threshold_db }
    }

    /// Convenience constructor using the customary -60 dB threshold.
    #[must_use]
    pub fn default_threshold() -> Self {
        TrimSilence { threshold_db: -60.0 }
    }
}

528impl Transform for TrimSilence {
529    fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
530        let data = input.to_vec();
531        let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
532
533        // Find first non-silent sample
534        let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);
535
536        // Find last non-silent sample
537        let end = data
538            .iter()
539            .rposition(|&x| x.abs() > threshold)
540            .map_or(data.len(), |i| i + 1);
541
542        if start >= end {
543            return Tensor::from_vec(vec![], &[0]).unwrap();
544        }
545
546        let trimmed = data[start..end].to_vec();
547        let len = trimmed.len();
548        Tensor::from_vec(trimmed, &[len]).unwrap()
549    }
550}
551
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pure sine tone — a deterministic fixture for the transforms.
    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let n_samples = (sample_rate as f32 * duration) as usize;
        let data: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();
        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 8000);

        let resampled = resample.apply(&audio);

        // Should be half the length
        assert_eq!(resampled.shape()[0], audio.shape()[0] / 2);
    }

    #[test]
    fn test_resample_same_rate() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 16000);

        // Identical rates must be a no-op.
        let result = resample.apply(&audio);
        assert_eq!(result.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);

        let spec = mel.apply(&audio);

        assert_eq!(spec.shape()[0], 40); // n_mels
        assert!(spec.shape()[1] > 0); // n_frames
    }

    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mfcc = MFCC::new(16000, 13);

        let coeffs = mfcc.apply(&audio);

        assert_eq!(coeffs.shape()[0], 13); // n_mfcc
    }

    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        // Speed up 2x
        let stretch = TimeStretch::new(2.0);
        let stretched = stretch.apply(&audio);

        assert!(stretched.shape()[0] < orig_len);
    }

    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let shift = PitchShift::new(2.0); // Shift up 2 semitones
        let shifted = shift.apply(&audio);

        // Length should remain the same
        assert_eq!(shifted.shape()[0], orig_len);
    }

    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let add_noise = AddNoise::new(20.0); // 20dB SNR

        let noisy = add_noise.apply(&audio);

        assert_eq!(noisy.shape(), audio.shape());
        // Values should be different (noise added)
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_normalize_audio() {
        let data = vec![0.1, -0.5, 0.3, -0.2];
        let audio = Tensor::from_vec(data, &[4]).unwrap();

        let normalize = NormalizeAudio::new();
        let normalized = normalize.apply(&audio);

        // After normalization the peak amplitude must be exactly 1.0.
        let max_val = normalized
            .to_vec()
            .iter()
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);
        assert!((max_val - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_trim_silence() {
        let data = vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0];
        let audio = Tensor::from_vec(data, &[6]).unwrap();

        let trim = TrimSilence::new(-20.0);
        let trimmed = trim.apply(&audio);

        assert_eq!(trimmed.shape()[0], 2); // Only [0.5, 0.3]
    }

    #[test]
    fn test_hz_to_mel_conversion() {
        // hz -> mel -> hz should round-trip.
        let hz = 1000.0;
        let mel = MelSpectrogram::hz_to_mel(hz);
        let back = MelSpectrogram::mel_to_hz(mel);

        assert!((hz - back).abs() < 0.1);
    }
}