1use axonml_data::Transform;
19use axonml_tensor::Tensor;
20use rand::Rng;
21use rustfft::{FftPlanner, num_complex::Complex};
22use std::f32::consts::PI;
23
/// Sample-rate converter configuration.
///
/// Stores the source and target sampling frequencies. Only their ratio is
/// used when the transform is applied, so any unit works as long as both
/// values share it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Resample {
    /// Sampling frequency of the incoming signal.
    orig_freq: usize,
    /// Sampling frequency the signal is converted to.
    new_freq: usize,
}

impl Resample {
    /// Creates a resampler converting from `orig_freq` to `new_freq`.
    ///
    /// The values are stored exactly as given; no validation happens here.
    #[must_use]
    pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Self { orig_freq, new_freq }
    }
}
44
45impl Transform for Resample {
46 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
47 if self.orig_freq == self.new_freq {
48 return input.clone();
49 }
50
51 let data = input.to_vec();
52 let orig_len = data.len();
53 let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;
54
55 if new_len == 0 {
56 return Tensor::from_vec(vec![], &[0]).unwrap();
57 }
58
59 let mut resampled = Vec::with_capacity(new_len);
60 let ratio = orig_len as f64 / new_len as f64;
61
62 for i in 0..new_len {
63 let src_idx = i as f64 * ratio;
64 let idx0 = src_idx.floor() as usize;
65 let idx1 = (idx0 + 1).min(orig_len - 1);
66 let frac = (src_idx - idx0 as f64) as f32;
67
68 let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
69 resampled.push(value);
70 }
71
72 Tensor::from_vec(resampled, &[new_len]).unwrap()
73 }
74}
75
/// Configuration for a log-mel spectrogram transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MelSpectrogram {
    /// Sampling rate of the input signal in Hz; half of it is the upper edge
    /// of the mel filterbank.
    sample_rate: usize,
    /// Number of samples per analysis frame (FFT length).
    n_fft: usize,
    /// Number of samples between the starts of consecutive frames.
    hop_length: usize,
    /// Number of mel bands (rows of the output).
    n_mels: usize,
}
87
88impl MelSpectrogram {
89 #[must_use]
91 pub fn new(sample_rate: usize) -> Self {
92 Self {
93 sample_rate,
94 n_fft: 2048,
95 hop_length: 512,
96 n_mels: 128,
97 }
98 }
99
100 #[must_use]
102 pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
103 Self {
104 sample_rate,
105 n_fft,
106 hop_length,
107 n_mels,
108 }
109 }
110
111 fn hz_to_mel(hz: f32) -> f32 {
113 2595.0 * (1.0 + hz / 700.0).log10()
114 }
115
116 fn mel_to_hz(mel: f32) -> f32 {
118 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
119 }
120
121 fn mel_filterbank(&self) -> Vec<Vec<f32>> {
123 let fmax = self.sample_rate as f32 / 2.0;
124 let mel_min = Self::hz_to_mel(0.0);
125 let mel_max = Self::hz_to_mel(fmax);
126
127 let mel_points: Vec<f32> = (0..=self.n_mels + 1)
129 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32)
130 .collect();
131
132 let hz_points: Vec<f32> = mel_points.iter().map(|&m| Self::mel_to_hz(m)).collect();
134
135 let bin_points: Vec<usize> = hz_points
137 .iter()
138 .map(|&hz| ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize)
139 .collect();
140
141 let n_bins = self.n_fft / 2 + 1;
143 let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];
144
145 for m in 0..self.n_mels {
146 let left = bin_points[m];
147 let center = bin_points[m + 1];
148 let right = bin_points[m + 2];
149
150 for k in left..center {
151 if center > left && k < n_bins {
152 filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
153 }
154 }
155 for k in center..right {
156 if right > center && k < n_bins {
157 filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
158 }
159 }
160 }
161
162 filterbank
163 }
164
165 fn hann_window(size: usize) -> Vec<f32> {
167 (0..size)
168 .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / (size - 1) as f32).cos()))
169 .collect()
170 }
171
172 fn fft_magnitude(signal: &[f32]) -> Vec<f32> {
174 let n = signal.len();
175 let n_out = n / 2 + 1;
176
177 let mut buffer: Vec<Complex<f32>> = signal.iter().map(|&x| Complex::new(x, 0.0)).collect();
179
180 let mut planner = FftPlanner::new();
182 let fft = planner.plan_fft_forward(n);
183 fft.process(&mut buffer);
184
185 buffer[..n_out].iter().map(|c| c.norm()).collect()
187 }
188}
189
190impl Transform for MelSpectrogram {
191 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
192 let data = input.to_vec();
193 let window = Self::hann_window(self.n_fft);
194 let filterbank = self.mel_filterbank();
195
196 let n_frames = if data.len() >= self.n_fft {
198 (data.len() - self.n_fft) / self.hop_length + 1
199 } else {
200 0
201 };
202
203 if n_frames == 0 {
204 return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
205 }
206
207 let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];
208
209 for frame_idx in 0..n_frames {
210 let start = frame_idx * self.hop_length;
211 let end = (start + self.n_fft).min(data.len());
212
213 let mut frame: Vec<f32> = data[start..end].to_vec();
215 frame.resize(self.n_fft, 0.0);
216
217 for (i, w) in window.iter().enumerate() {
218 frame[i] *= w;
219 }
220
221 let spectrum = Self::fft_magnitude(&frame);
223
224 for (m, filter) in filterbank.iter().enumerate() {
226 let mut mel_energy = 0.0;
227 for (k, &mag) in spectrum.iter().enumerate() {
228 if k < filter.len() {
229 mel_energy += mag * mag * filter[k];
230 }
231 }
232 mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
234 }
235 }
236
237 Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
238 }
239}
240
/// Mel-frequency cepstral coefficient transform.
///
/// Wraps a [`MelSpectrogram`] and keeps `n_mfcc` DCT coefficients per frame.
pub struct MFCC {
    /// Spectrogram stage whose output is fed to the DCT.
    mel_spec: MelSpectrogram,
    /// Number of cepstral coefficients kept per frame.
    n_mfcc: usize,
}
250
251impl MFCC {
252 #[must_use]
254 pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
255 Self {
256 mel_spec: MelSpectrogram::new(sample_rate),
257 n_mfcc,
258 }
259 }
260
261 fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
263 let n = input.len();
264 let mut output = vec![0.0f32; n_out];
265
266 for k in 0..n_out {
267 let mut sum = 0.0;
268 for (i, &x) in input.iter().enumerate() {
269 sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
270 }
271 output[k] = sum;
272 }
273
274 output
275 }
276}
277
278impl Transform for MFCC {
279 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
280 let mel = self.mel_spec.apply(input);
282 let mel_data = mel.to_vec();
283 let mel_shape = mel.shape();
284
285 if mel_shape.len() != 2 {
286 return input.clone();
287 }
288
289 let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
290 let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];
291
292 for frame in 0..n_frames {
294 let frame_data: Vec<f32> = (0..n_mels)
295 .map(|m| mel_data[m * n_frames + frame])
296 .collect();
297
298 let coeffs = Self::dct(&frame_data, self.n_mfcc);
299
300 for (k, &c) in coeffs.iter().enumerate() {
301 mfcc[k * n_frames + frame] = c;
302 }
303 }
304
305 Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
306 }
307}
308
/// Speed-change transform configuration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TimeStretch {
    /// Playback-rate factor, always within `[0.1, 10.0]`.
    rate: f32,
}

impl TimeStretch {
    /// Creates a time-stretch transform with the given `rate`.
    ///
    /// The rate is limited to `[0.1, 10.0]`. A NaN rate becomes `0.1`.
    #[must_use]
    pub fn new(rate: f32) -> Self {
        Self {
            // Deliberately `max` then `min` rather than `clamp`: `f32::max`
            // discards a NaN operand, whereas `clamp` would let NaN through.
            rate: rate.max(0.1).min(10.0),
        }
    }
}
328
329impl Transform for TimeStretch {
330 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
331 let data = input.to_vec();
332 let new_len = (data.len() as f32 / self.rate) as usize;
333
334 if new_len == 0 {
335 return Tensor::from_vec(vec![], &[0]).unwrap();
336 }
337
338 let mut stretched = Vec::with_capacity(new_len);
339
340 for i in 0..new_len {
341 let src_idx = i as f32 * self.rate;
342 let idx0 = src_idx.floor() as usize;
343 let idx1 = (idx0 + 1).min(data.len() - 1);
344 let frac = src_idx - idx0 as f32;
345
346 if idx0 < data.len() {
347 let value =
348 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
349 stretched.push(value);
350 }
351 }
352
353 let len = stretched.len();
354 Tensor::from_vec(stretched, &[len]).unwrap()
355 }
356}
357
/// Pitch-shift transform configuration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PitchShift {
    /// Requested shift in semitones; converted to a rate of
    /// `2^(semitones / 12)` when the transform is applied.
    semitones: f32,
}

impl PitchShift {
    /// Creates a pitch-shift transform for the given number of `semitones`.
    ///
    /// The value is stored as given, without validation.
    #[must_use]
    pub fn new(semitones: f32) -> Self {
        Self { semitones }
    }
}
375
376impl Transform for PitchShift {
377 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
378 let rate = 2.0_f32.powf(self.semitones / 12.0);
381 let data = input.to_vec();
382 let orig_len = data.len();
383
384 let resampled_len = (orig_len as f32 / rate) as usize;
386 if resampled_len == 0 {
387 return input.clone();
388 }
389
390 let mut resampled = Vec::with_capacity(resampled_len);
391 for i in 0..resampled_len {
392 let src_idx = i as f32 * rate;
393 let idx0 = src_idx.floor() as usize;
394 let idx1 = (idx0 + 1).min(orig_len - 1);
395 let frac = src_idx - idx0 as f32;
396
397 if idx0 < orig_len {
398 let value =
399 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
400 resampled.push(value);
401 }
402 }
403
404 let mut result = Vec::with_capacity(orig_len);
406 for i in 0..orig_len {
407 let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
408 let idx0 = src_idx.floor() as usize;
409 let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
410 let frac = src_idx - idx0 as f32;
411
412 if idx0 < resampled.len() {
413 let value = resampled[idx0] * (1.0 - frac)
414 + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
415 result.push(value);
416 } else {
417 result.push(0.0);
418 }
419 }
420
421 Tensor::from_vec(result, &[orig_len]).unwrap()
422 }
423}
424
/// Additive-noise transform configuration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AddNoise {
    /// Target signal-to-noise ratio in decibels.
    snr_db: f32,
}

impl AddNoise {
    /// Creates a noise transform targeting a signal-to-noise ratio of
    /// `snr_db` decibels. The value is stored as given, without validation.
    #[must_use]
    pub fn new(snr_db: f32) -> Self {
        Self { snr_db }
    }
}
441
442impl Transform for AddNoise {
443 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
444 let data = input.to_vec();
445 let mut rng = rand::thread_rng();
446
447 let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;
449
450 let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
452 let noise_std = noise_power.sqrt();
453
454 let noisy: Vec<f32> = data
456 .iter()
457 .map(|&x| {
458 let u1: f32 = rng.r#gen();
460 let u2: f32 = rng.r#gen();
461 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
462 x + z * noise_std
463 })
464 .collect();
465
466 Tensor::from_vec(noisy, input.shape()).unwrap()
467 }
468}
469
/// Peak-normalisation transform. It carries no configuration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates the transform. Equivalent to `NormalizeAudio::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
490
491impl Transform for NormalizeAudio {
492 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
493 let data = input.to_vec();
494 let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
495
496 if max_val < 1e-10 {
497 return input.clone();
498 }
499
500 let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
501 Tensor::from_vec(normalized, input.shape()).unwrap()
502 }
503}
504
/// Silence-trimming transform configuration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrimSilence {
    /// Amplitude threshold in decibels; converted to a linear amplitude of
    /// `10^(threshold_db / 20)` when the transform is applied.
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a trimmer using `threshold_db` as the silence threshold.
    /// The value is stored as given, without validation.
    #[must_use]
    pub fn new(threshold_db: f32) -> Self {
        Self { threshold_db }
    }

    /// Creates a trimmer with a threshold of -60 dB.
    #[must_use]
    pub fn default_threshold() -> Self {
        Self::new(-60.0)
    }
}

impl Default for TrimSilence {
    /// Same as [`TrimSilence::default_threshold`].
    fn default() -> Self {
        Self::default_threshold()
    }
}
527
528impl Transform for TrimSilence {
529 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
530 let data = input.to_vec();
531 let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
532
533 let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);
535
536 let end = data
538 .iter()
539 .rposition(|&x| x.abs() > threshold)
540 .map_or(data.len(), |i| i + 1);
541
542 if start >= end {
543 return Tensor::from_vec(vec![], &[0]).unwrap();
544 }
545
546 let trimmed = data[start..end].to_vec();
547 let len = trimmed.len();
548 Tensor::from_vec(trimmed, &[len]).unwrap()
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-dimensional tensor holding a unit-amplitude sine wave of
    /// `freq` Hz sampled at `sample_rate` for `duration` seconds.
    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let n_samples = (sample_rate as f32 * duration) as usize;
        let data: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();
        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    /// Halving the sample rate halves the number of samples.
    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 8000);

        let resampled = resample.apply(&audio);

        assert_eq!(resampled.shape()[0], audio.shape()[0] / 2);
    }

    /// Equal source and target rates leave the samples untouched.
    #[test]
    fn test_resample_same_rate() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 16000);

        let result = resample.apply(&audio);
        assert_eq!(result.to_vec(), audio.to_vec());
    }

    /// The spectrogram has one row per mel band and at least one frame.
    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);

        let spec = mel.apply(&audio);

        assert_eq!(spec.shape()[0], 40);
        assert!(spec.shape()[1] > 0);
    }

    /// The MFCC output has one row per requested coefficient.
    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mfcc = MFCC::new(16000, 13);

        let coeffs = mfcc.apply(&audio);

        assert_eq!(coeffs.shape()[0], 13);
    }

    /// A rate above 1 produces fewer samples than the input.
    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let stretch = TimeStretch::new(2.0);
        let stretched = stretch.apply(&audio);

        assert!(stretched.shape()[0] < orig_len);
    }

    /// Pitch shifting keeps the number of samples.
    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let shift = PitchShift::new(2.0);
        let shifted = shift.apply(&audio);

        assert_eq!(shifted.shape()[0], orig_len);
    }

    /// Adding noise keeps the shape but changes the samples.
    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let add_noise = AddNoise::new(20.0);
        let noisy = add_noise.apply(&audio);

        assert_eq!(noisy.shape(), audio.shape());
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    /// After normalisation the largest absolute sample is 1.0.
    #[test]
    fn test_normalize_audio() {
        let data = vec![0.1, -0.5, 0.3, -0.2];
        let audio = Tensor::from_vec(data, &[4]).unwrap();

        let normalize = NormalizeAudio::new();
        let normalized = normalize.apply(&audio);

        let max_val = normalized
            .to_vec()
            .iter()
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);
        assert!((max_val - 1.0).abs() < 0.001);
    }

    /// A -20 dB threshold is an amplitude of 0.1, so only the 0.5 and 0.3
    /// samples survive trimming.
    #[test]
    fn test_trim_silence() {
        let data = vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0];
        let audio = Tensor::from_vec(data, &[6]).unwrap();

        let trim = TrimSilence::new(-20.0);
        let trimmed = trim.apply(&audio);

        assert_eq!(trimmed.shape()[0], 2);
    }

    /// Hz -> mel -> Hz round-trips to within 0.1 Hz.
    #[test]
    fn test_hz_to_mel_conversion() {
        let hz = 1000.0;
        let mel = MelSpectrogram::hz_to_mel(hz);
        let back = MelSpectrogram::mel_to_hz(mel);

        assert!((hz - back).abs() < 0.1);
    }
}