1use axonml_data::Transform;
9use axonml_tensor::Tensor;
10use rand::Rng;
11use std::f32::consts::PI;
12
/// Changes the sampling rate of a waveform from `orig_freq` Hz to `new_freq` Hz.
#[derive(Debug, Clone)]
pub struct Resample {
    /// Sampling rate of the input waveform, in Hz.
    orig_freq: usize,
    /// Sampling rate of the produced waveform, in Hz.
    new_freq: usize,
}

impl Resample {
    /// Creates a resampler converting audio sampled at `orig_freq` Hz to `new_freq` Hz.
    #[must_use]
    pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Self {
            orig_freq,
            new_freq,
        }
    }
}
32
33impl Transform for Resample {
34 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
35 if self.orig_freq == self.new_freq {
36 return input.clone();
37 }
38
39 let data = input.to_vec();
40 let orig_len = data.len();
41 let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;
42
43 if new_len == 0 {
44 return Tensor::from_vec(vec![], &[0]).unwrap();
45 }
46
47 let mut resampled = Vec::with_capacity(new_len);
48 let ratio = orig_len as f64 / new_len as f64;
49
50 for i in 0..new_len {
51 let src_idx = i as f64 * ratio;
52 let idx0 = src_idx.floor() as usize;
53 let idx1 = (idx0 + 1).min(orig_len - 1);
54 let frac = (src_idx - idx0 as f64) as f32;
55
56 let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
57 resampled.push(value);
58 }
59
60 Tensor::from_vec(resampled, &[new_len]).unwrap()
61 }
62}
63
/// Log-mel spectrogram transform.
///
/// The waveform is processed in these steps:
/// 1. Split it into frames of `n_fft` samples spaced `hop_length` samples apart.
/// 2. Apply a symmetric Hann window to each frame.
/// 3. Compute the magnitude spectrum of each frame with a DFT.
/// 4. Weight the power spectrum with `n_mels` triangular filters. The filters are
///    spaced evenly on the mel scale between 0 Hz and `sample_rate / 2`.
/// 5. Store the natural log of each filter energy, offset by `1e-10` to avoid `ln(0)`.
#[derive(Debug, Clone)]
pub struct MelSpectrogram {
    /// Sampling rate of the input waveform, in Hz.
    sample_rate: usize,
    /// Frame length and DFT size, in samples.
    n_fft: usize,
    /// Distance between the starts of consecutive frames, in samples.
    hop_length: usize,
    /// Number of mel filters (rows of the output).
    n_mels: usize,
}

impl MelSpectrogram {
    /// Creates a transform with `n_fft = 2048`, `hop_length = 512` and `n_mels = 128`.
    #[must_use]
    pub fn new(sample_rate: usize) -> Self {
        Self::with_params(sample_rate, 2048, 512, 128)
    }

    /// Creates a transform with explicit frame, hop and filter-bank sizes.
    ///
    /// `hop_length` must be non-zero when the transform is applied to signals of at
    /// least `n_fft` samples, because the frame count divides by it.
    #[must_use]
    pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
        Self {
            sample_rate,
            n_fft,
            hop_length,
            n_mels,
        }
    }

    /// Converts a frequency in Hz to mels: `2595 * log10(1 + hz / 700)`.
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }

    /// Inverse of [`Self::hz_to_mel`].
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
    }

    /// Builds the `n_mels x (n_fft / 2 + 1)` triangular filter bank.
    ///
    /// Filter `m` rises linearly from 0 at DFT bin `edges[m]` to 1 at `edges[m + 1]`,
    /// then falls back towards 0 before `edges[m + 2]`. When two edge bins coincide,
    /// which happens for narrow low-frequency filters with a small `n_fft`, that slope
    /// is empty and the filter can be all zeros.
    fn mel_filterbank(&self) -> Vec<Vec<f32>> {
        let n_bins = self.n_fft / 2 + 1;
        let mel_min = Self::hz_to_mel(0.0);
        let mel_max = Self::hz_to_mel(self.sample_rate as f32 / 2.0);

        // n_mels + 2 edge frequencies, evenly spaced in mel, mapped down to DFT bins.
        let edges: Vec<usize> = (0..self.n_mels + 2)
            .map(|i| {
                let mel = mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32;
                let hz = Self::mel_to_hz(mel);
                ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize
            })
            .collect();

        let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];
        for (m, filter) in filterbank.iter_mut().enumerate() {
            let (left, center, right) = (edges[m], edges[m + 1], edges[m + 2]);

            // Rising slope over [left, center); non-empty only when center > left.
            for (k, weight) in filter.iter_mut().enumerate().take(center).skip(left) {
                *weight = (k - left) as f32 / (center - left) as f32;
            }
            // Falling slope over [center, right); non-empty only when right > center.
            for (k, weight) in filter.iter_mut().enumerate().take(right).skip(center) {
                *weight = (right - k) as f32 / (right - center) as f32;
            }
        }

        filterbank
    }

    /// Symmetric Hann window of `size` samples, where `w[0] = w[size - 1] = 0`.
    ///
    /// A single-sample window is `[1.0]`, because the closed-form expression would
    /// divide by `size - 1 = 0` and produce NaN. A zero-size window is empty.
    fn hann_window(size: usize) -> Vec<f32> {
        if size <= 1 {
            return vec![1.0; size];
        }
        let denom = (size - 1) as f32;
        (0..size)
            .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / denom).cos()))
            .collect()
    }

    /// Magnitude of the DFT of `signal` for bins `0..=n/2`, where `n = signal.len()`.
    ///
    /// The phase `2π·k·t/n` is reduced as the integer `k·t mod n` and then looked up
    /// in a twiddle table computed in f64. This keeps the phase exact, whereas a direct
    /// f32 evaluation grows to about 1e7 radians for `n = 2048` and loses precision.
    /// It also avoids two trig calls per multiply-add. The cost is still O(n²) per call.
    fn dft(signal: &[f32]) -> Vec<f32> {
        let n = signal.len();
        let n_out = n / 2 + 1;
        if n == 0 {
            return vec![0.0; n_out];
        }

        let (cos_table, sin_table): (Vec<f32>, Vec<f32>) = (0..n)
            .map(|j| {
                let angle = 2.0 * std::f64::consts::PI * j as f64 / n as f64;
                (angle.cos() as f32, angle.sin() as f32)
            })
            .unzip();

        (0..n_out)
            .map(|k| {
                let mut real = 0.0f32;
                let mut imag = 0.0f32;
                // `phase` tracks (k * t) mod n. Because k < n, one subtraction keeps it in range.
                let mut phase = 0usize;
                for &x in signal {
                    real += x * cos_table[phase];
                    imag -= x * sin_table[phase];
                    phase += k;
                    if phase >= n {
                        phase -= n;
                    }
                }
                (real * real + imag * imag).sqrt()
            })
            .collect()
    }
}
180
181impl Transform for MelSpectrogram {
182 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
183 let data = input.to_vec();
184 let window = Self::hann_window(self.n_fft);
185 let filterbank = self.mel_filterbank();
186
187 let n_frames = if data.len() >= self.n_fft {
189 (data.len() - self.n_fft) / self.hop_length + 1
190 } else {
191 0
192 };
193
194 if n_frames == 0 {
195 return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
196 }
197
198 let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];
199
200 for frame_idx in 0..n_frames {
201 let start = frame_idx * self.hop_length;
202 let end = (start + self.n_fft).min(data.len());
203
204 let mut frame: Vec<f32> = data[start..end].to_vec();
206 frame.resize(self.n_fft, 0.0);
207
208 for (i, w) in window.iter().enumerate() {
209 frame[i] *= w;
210 }
211
212 let spectrum = Self::dft(&frame);
214
215 for (m, filter) in filterbank.iter().enumerate() {
217 let mut mel_energy = 0.0;
218 for (k, &mag) in spectrum.iter().enumerate() {
219 if k < filter.len() {
220 mel_energy += mag * mag * filter[k];
221 }
222 }
223 mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
225 }
226 }
227
228 Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
229 }
230}
231
232pub struct MFCC {
238 mel_spec: MelSpectrogram,
239 n_mfcc: usize,
240}
241
242impl MFCC {
243 #[must_use] pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
245 Self {
246 mel_spec: MelSpectrogram::new(sample_rate),
247 n_mfcc,
248 }
249 }
250
251 fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
253 let n = input.len();
254 let mut output = vec![0.0f32; n_out];
255
256 for k in 0..n_out {
257 let mut sum = 0.0;
258 for (i, &x) in input.iter().enumerate() {
259 sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
260 }
261 output[k] = sum;
262 }
263
264 output
265 }
266}
267
268impl Transform for MFCC {
269 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
270 let mel = self.mel_spec.apply(input);
272 let mel_data = mel.to_vec();
273 let mel_shape = mel.shape();
274
275 if mel_shape.len() != 2 {
276 return input.clone();
277 }
278
279 let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
280 let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];
281
282 for frame in 0..n_frames {
284 let frame_data: Vec<f32> = (0..n_mels)
285 .map(|m| mel_data[m * n_frames + frame])
286 .collect();
287
288 let coeffs = Self::dct(&frame_data, self.n_mfcc);
289
290 for (k, &c) in coeffs.iter().enumerate() {
291 mfcc[k * n_frames + frame] = c;
292 }
293 }
294
295 Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
296 }
297}
298
/// Changes the duration of a waveform by a factor of `1 / rate`.
#[derive(Debug, Clone)]
pub struct TimeStretch {
    /// Playback-rate factor, always within `[0.1, 10.0]`.
    rate: f32,
}

impl TimeStretch {
    /// Creates a time-stretch transform.
    ///
    /// `rate` is clamped to `[0.1, 10.0]`; values above 1 shorten the signal. A NaN
    /// rate becomes `0.1`, because `f32::max` returns the non-NaN operand.
    #[must_use]
    pub fn new(rate: f32) -> Self {
        Self {
            rate: rate.max(0.1).min(10.0),
        }
    }
}
317
318impl Transform for TimeStretch {
319 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
320 let data = input.to_vec();
321 let new_len = (data.len() as f32 / self.rate) as usize;
322
323 if new_len == 0 {
324 return Tensor::from_vec(vec![], &[0]).unwrap();
325 }
326
327 let mut stretched = Vec::with_capacity(new_len);
328
329 for i in 0..new_len {
330 let src_idx = i as f32 * self.rate;
331 let idx0 = src_idx.floor() as usize;
332 let idx1 = (idx0 + 1).min(data.len() - 1);
333 let frac = src_idx - idx0 as f32;
334
335 if idx0 < data.len() {
336 let value =
337 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
338 stretched.push(value);
339 }
340 }
341
342 let len = stretched.len();
343 Tensor::from_vec(stretched, &[len]).unwrap()
344 }
345}
346
/// Shifts the pitch of a waveform by a number of equal-tempered semitones.
#[derive(Debug, Clone)]
pub struct PitchShift {
    /// Shift in semitones; positive raises the pitch, negative lowers it.
    semitones: f32,
}

impl PitchShift {
    /// Creates a pitch shift of `semitones`. The frequency factor is `2^(semitones / 12)`.
    #[must_use]
    pub fn new(semitones: f32) -> Self {
        Self { semitones }
    }
}
363
364impl Transform for PitchShift {
365 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
366 let rate = 2.0_f32.powf(self.semitones / 12.0);
369 let data = input.to_vec();
370 let orig_len = data.len();
371
372 let resampled_len = (orig_len as f32 / rate) as usize;
374 if resampled_len == 0 {
375 return input.clone();
376 }
377
378 let mut resampled = Vec::with_capacity(resampled_len);
379 for i in 0..resampled_len {
380 let src_idx = i as f32 * rate;
381 let idx0 = src_idx.floor() as usize;
382 let idx1 = (idx0 + 1).min(orig_len - 1);
383 let frac = src_idx - idx0 as f32;
384
385 if idx0 < orig_len {
386 let value =
387 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
388 resampled.push(value);
389 }
390 }
391
392 let mut result = Vec::with_capacity(orig_len);
394 for i in 0..orig_len {
395 let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
396 let idx0 = src_idx.floor() as usize;
397 let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
398 let frac = src_idx - idx0 as f32;
399
400 if idx0 < resampled.len() {
401 let value = resampled[idx0] * (1.0 - frac)
402 + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
403 result.push(value);
404 } else {
405 result.push(0.0);
406 }
407 }
408
409 Tensor::from_vec(result, &[orig_len]).unwrap()
410 }
411}
412
/// Adds white Gaussian noise at a target signal-to-noise ratio.
#[derive(Debug, Clone)]
pub struct AddNoise {
    /// Target signal-to-noise ratio, in decibels.
    snr_db: f32,
}

impl AddNoise {
    /// Creates a noise transform targeting `snr_db` decibels of SNR.
    #[must_use]
    pub fn new(snr_db: f32) -> Self {
        Self { snr_db }
    }
}
428
429impl Transform for AddNoise {
430 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
431 let data = input.to_vec();
432 let mut rng = rand::thread_rng();
433
434 let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;
436
437 let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
439 let noise_std = noise_power.sqrt();
440
441 let noisy: Vec<f32> = data
443 .iter()
444 .map(|&x| {
445 let u1: f32 = rng.gen();
447 let u2: f32 = rng.gen();
448 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
449 x + z * noise_std
450 })
451 .collect();
452
453 Tensor::from_vec(noisy, input.shape()).unwrap()
454 }
455}
456
/// Peak-normalizes a waveform so that its largest absolute sample becomes 1.0.
#[derive(Debug, Clone, Copy, Default)]
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates the (stateless) normalization transform.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
476
477impl Transform for NormalizeAudio {
478 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
479 let data = input.to_vec();
480 let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
481
482 if max_val < 1e-10 {
483 return input.clone();
484 }
485
486 let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
487 Tensor::from_vec(normalized, input.shape()).unwrap()
488 }
489}
490
/// Removes leading and trailing samples whose amplitude is at or below a threshold.
#[derive(Debug, Clone)]
pub struct TrimSilence {
    /// Threshold in dBFS, converted to the linear amplitude `10^(threshold_db / 20)`.
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a trimmer with the given threshold in decibels (e.g. `-40.0`).
    #[must_use]
    pub fn new(threshold_db: f32) -> Self {
        Self { threshold_db }
    }

    /// Creates a trimmer with a -60 dB threshold (linear amplitude 0.001).
    #[must_use]
    pub fn default_threshold() -> Self {
        Self::new(-60.0)
    }
}
511
512impl Transform for TrimSilence {
513 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
514 let data = input.to_vec();
515 let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
516
517 let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);
519
520 let end = data
522 .iter()
523 .rposition(|&x| x.abs() > threshold)
524 .map_or(data.len(), |i| i + 1);
525
526 if start >= end {
527 return Tensor::from_vec(vec![], &[0]).unwrap();
528 }
529
530 let trimmed = data[start..end].to_vec();
531 let len = trimmed.len();
532 Tensor::from_vec(trimmed, &[len]).unwrap()
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `duration` seconds of a unit-amplitude sine at `freq` Hz.
    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let n_samples = (sample_rate as f32 * duration) as usize;
        let data: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();
        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 8000);

        let resampled = resample.apply(&audio);

        assert_eq!(resampled.shape()[0], audio.shape()[0] / 2);
    }

    #[test]
    fn test_resample_upsample() {
        let audio = create_sine_wave(440.0, 8000, 0.1);
        let resampled = Resample::new(8000, 16000).apply(&audio);

        assert_eq!(resampled.shape()[0], audio.shape()[0] * 2);
    }

    #[test]
    fn test_resample_zero_target_rate_is_empty() {
        let audio = create_sine_wave(440.0, 16000, 0.01);
        let resampled = Resample::new(16000, 0).apply(&audio);

        assert_eq!(resampled.shape()[0], 0);
    }

    #[test]
    fn test_resample_same_rate() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let resample = Resample::new(16000, 16000);

        let result = resample.apply(&audio);
        assert_eq!(result.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mel = MelSpectrogram::with_params(16000, 512, 256, 40);

        let spec = mel.apply(&audio);

        assert_eq!(spec.shape()[0], 40);
        assert!(spec.shape()[1] > 0);
    }

    #[test]
    fn test_mel_spectrogram_short_input_is_single_frame() {
        // Fewer samples than n_fft.
        let audio = create_sine_wave(440.0, 16000, 0.01);
        let spec = MelSpectrogram::with_params(16000, 512, 256, 40).apply(&audio);

        assert_eq!(spec.shape()[0], 40);
        assert_eq!(spec.shape()[1], 1);
    }

    #[test]
    fn test_hann_window_is_symmetric() {
        let w = MelSpectrogram::hann_window(8);

        assert_eq!(w.len(), 8);
        assert!(w[0].abs() < 1e-6);
        for n in 0..8 {
            assert!((w[n] - w[7 - n]).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dft_impulse_is_flat() {
        let spectrum = MelSpectrogram::dft(&[1.0, 0.0, 0.0, 0.0]);

        assert_eq!(spectrum.len(), 3);
        assert!(spectrum.iter().all(|&m| (m - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let mfcc = MFCC::new(16000, 13);

        let coeffs = mfcc.apply(&audio);

        assert_eq!(coeffs.shape()[0], 13);
        // 8000 samples with the default n_fft = 2048 and hop_length = 512.
        assert_eq!(coeffs.shape()[1], (8000 - 2048) / 512 + 1);
    }

    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let stretch = TimeStretch::new(2.0);
        let stretched = stretch.apply(&audio);

        assert!(stretched.shape()[0] < orig_len);
    }

    #[test]
    fn test_time_stretch_slow_down() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let stretched = TimeStretch::new(0.5).apply(&audio);

        assert!(stretched.shape()[0] > orig_len);
    }

    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let orig_len = audio.shape()[0];

        let shift = PitchShift::new(2.0);
        let shifted = shift.apply(&audio);

        assert_eq!(shifted.shape()[0], orig_len);
    }

    #[test]
    fn test_pitch_shift_zero_semitones_is_identity() {
        let audio = create_sine_wave(440.0, 16000, 0.1);

        let shifted = PitchShift::new(0.0).apply(&audio);

        assert_eq!(shifted.shape()[0], audio.shape()[0]);
        for (a, b) in shifted.to_vec().iter().zip(audio.to_vec().iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let add_noise = AddNoise::new(20.0);
        let noisy = add_noise.apply(&audio);

        assert_eq!(noisy.shape(), audio.shape());
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_normalize_audio() {
        let data = vec![0.1, -0.5, 0.3, -0.2];
        let audio = Tensor::from_vec(data, &[4]).unwrap();

        let normalize = NormalizeAudio::new();
        let normalized = normalize.apply(&audio);

        let max_val = normalized
            .to_vec()
            .iter()
            .map(|x| x.abs())
            .fold(0.0f32, f32::max);
        assert!((max_val - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_normalize_silence_unchanged() {
        let audio = Tensor::from_vec(vec![0.0; 4], &[4]).unwrap();

        let normalized = NormalizeAudio::new().apply(&audio);

        assert_eq!(normalized.to_vec(), vec![0.0; 4]);
    }

    #[test]
    fn test_trim_silence() {
        let data = vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0];
        let audio = Tensor::from_vec(data, &[6]).unwrap();

        let trim = TrimSilence::new(-20.0);
        let trimmed = trim.apply(&audio);

        assert_eq!(trimmed.shape()[0], 2);
    }

    #[test]
    fn test_trim_silence_keeps_interior_quiet_samples() {
        let audio = Tensor::from_vec(vec![0.0, 0.5, 0.0, 0.3, 0.0], &[5]).unwrap();

        let trimmed = TrimSilence::new(-20.0).apply(&audio);

        assert_eq!(trimmed.to_vec(), vec![0.5, 0.0, 0.3]);
    }

    #[test]
    fn test_hz_to_mel_conversion() {
        let hz = 1000.0;
        let mel = MelSpectrogram::hz_to_mel(hz);
        let back = MelSpectrogram::mel_to_hz(mel);

        assert!((hz - back).abs() < 0.1);
    }
}