1use axonml_data::Transform;
18use axonml_tensor::Tensor;
19use rand::Rng;
20use std::f32::consts::PI;
21
/// Sample-rate converter for 1-D signals.
pub struct Resample {
    /// Sample rate of the incoming signal (only its ratio to `new_freq` is used).
    orig_freq: usize,
    /// Sample rate the signal is converted to.
    new_freq: usize,
}

impl Resample {
    /// Creates a resampler converting from `orig_freq` to `new_freq`.
    #[must_use]
    pub fn new(orig_freq: usize, new_freq: usize) -> Self {
        Resample {
            orig_freq,
            new_freq,
        }
    }
}
42
43impl Transform for Resample {
44 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
45 if self.orig_freq == self.new_freq {
46 return input.clone();
47 }
48
49 let data = input.to_vec();
50 let orig_len = data.len();
51 let new_len = (orig_len as f64 * self.new_freq as f64 / self.orig_freq as f64) as usize;
52
53 if new_len == 0 {
54 return Tensor::from_vec(vec![], &[0]).unwrap();
55 }
56
57 let mut resampled = Vec::with_capacity(new_len);
58 let ratio = orig_len as f64 / new_len as f64;
59
60 for i in 0..new_len {
61 let src_idx = i as f64 * ratio;
62 let idx0 = src_idx.floor() as usize;
63 let idx1 = (idx0 + 1).min(orig_len - 1);
64 let frac = (src_idx - idx0 as f64) as f32;
65
66 let value = data[idx0] * (1.0 - frac) + data[idx1] * frac;
67 resampled.push(value);
68 }
69
70 Tensor::from_vec(resampled, &[new_len]).unwrap()
71 }
72}
73
/// Log-mel spectrogram transform.
///
/// Frames the signal, applies a Hann window, takes the magnitude spectrum, and projects
/// the power spectrum onto a bank of triangular mel filters.
pub struct MelSpectrogram {
    /// Sample rate of the input signal (Hz).
    sample_rate: usize,
    /// Frame / DFT length in samples.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_length: usize,
    /// Number of mel bands (rows) in the output.
    n_mels: usize,
}

impl MelSpectrogram {
    /// Creates a transform with `n_fft = 2048`, `hop_length = 512` and `n_mels = 128`.
    #[must_use]
    pub fn new(sample_rate: usize) -> Self {
        Self::with_params(sample_rate, 2048, 512, 128)
    }

    /// Creates a transform with an explicit frame length, hop length and band count.
    #[must_use]
    pub fn with_params(sample_rate: usize, n_fft: usize, hop_length: usize, n_mels: usize) -> Self {
        Self {
            sample_rate,
            n_fft,
            hop_length,
            n_mels,
        }
    }

    /// Converts a frequency in Hz to mels using `2595 * log10(1 + hz / 700)`.
    fn hz_to_mel(hz: f32) -> f32 {
        2595.0 * (1.0 + hz / 700.0).log10()
    }

    /// Inverse of [`Self::hz_to_mel`].
    fn mel_to_hz(mel: f32) -> f32 {
        700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
    }

    /// Builds `n_mels` triangular filters over the `n_fft / 2 + 1` DFT bins.
    ///
    /// Filter edges are equally spaced on the mel scale between 0 Hz and
    /// `sample_rate / 2`. Each filter is a triangle over its left/centre/right bins
    /// with peak weight 1 (not area-normalised). At low frequencies neighbouring edges
    /// can land on the same bin, which leaves a filter partly or entirely zero.
    fn mel_filterbank(&self) -> Vec<Vec<f32>> {
        let fmax = self.sample_rate as f32 / 2.0;
        let mel_min = Self::hz_to_mel(0.0);
        let mel_max = Self::hz_to_mel(fmax);

        // n_mels + 2 edge points: filter m uses points m, m + 1 and m + 2.
        let bin_points: Vec<usize> = (0..=self.n_mels + 1)
            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (self.n_mels + 1) as f32)
            .map(Self::mel_to_hz)
            .map(|hz| ((self.n_fft + 1) as f32 * hz / self.sample_rate as f32).floor() as usize)
            .collect();

        let n_bins = self.n_fft / 2 + 1;
        let mut filterbank = vec![vec![0.0f32; n_bins]; self.n_mels];

        for (m, filter) in filterbank.iter_mut().enumerate() {
            let (left, center, right) = (bin_points[m], bin_points[m + 1], bin_points[m + 2]);

            // Rising edge: 0 at `left`, approaching 1 just before `center`.
            for k in left..center.min(n_bins) {
                filter[k] = (k - left) as f32 / (center - left) as f32;
            }
            // Falling edge: 1 at `center`, approaching 0 just before `right`.
            for k in center..right.min(n_bins) {
                filter[k] = (right - k) as f32 / (right - center) as f32;
            }
        }

        filterbank
    }

    /// Symmetric Hann window of `size` points (divides by `size - 1`).
    ///
    /// Sizes 0 and 1 return `[]` and `[1.0]` respectively; the general formula would
    /// divide 0/0 for a single point and produce NaN.
    fn hann_window(size: usize) -> Vec<f32> {
        if size <= 1 {
            return vec![1.0; size];
        }
        let denom = (size - 1) as f32;
        (0..size)
            .map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / denom).cos()))
            .collect()
    }

    /// Magnitude of the one-sided DFT of `signal`: `n / 2 + 1` bins for `n` samples
    /// (a single `0.0` bin for an empty signal). Direct O(n²) evaluation, no FFT.
    fn dft(signal: &[f32]) -> Vec<f32> {
        let n = signal.len();
        let n_out = n / 2 + 1;

        // Twiddle table of (cos, sin)(2π·j/n). Indexing it with (k·t) mod n keeps every
        // angle in [0, 2π): evaluating cos/sin directly at 2π·k·t/n in f32 loses phase
        // precision as k·t grows, and needs 2·n·n_out trig calls instead of 2·n.
        let twiddles: Vec<(f32, f32)> = (0..n)
            .map(|j| {
                let angle = 2.0 * PI * j as f32 / n as f32;
                (angle.cos(), angle.sin())
            })
            .collect();

        (0..n_out)
            .map(|k| {
                let (mut real, mut imag) = (0.0f32, 0.0f32);
                // Running (k·t) mod n, updated incrementally so k·t never overflows.
                let mut idx = 0usize;
                for &x in signal {
                    let (c, s) = twiddles[idx];
                    real += x * c;
                    imag -= x * s;
                    idx += k;
                    if idx >= n {
                        idx -= n;
                    }
                }
                (real * real + imag * imag).sqrt()
            })
            .collect()
    }
}
192
193impl Transform for MelSpectrogram {
194 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
195 let data = input.to_vec();
196 let window = Self::hann_window(self.n_fft);
197 let filterbank = self.mel_filterbank();
198
199 let n_frames = if data.len() >= self.n_fft {
201 (data.len() - self.n_fft) / self.hop_length + 1
202 } else {
203 0
204 };
205
206 if n_frames == 0 {
207 return Tensor::from_vec(vec![0.0; self.n_mels], &[self.n_mels, 1]).unwrap();
208 }
209
210 let mut mel_spec = vec![0.0f32; self.n_mels * n_frames];
211
212 for frame_idx in 0..n_frames {
213 let start = frame_idx * self.hop_length;
214 let end = (start + self.n_fft).min(data.len());
215
216 let mut frame: Vec<f32> = data[start..end].to_vec();
218 frame.resize(self.n_fft, 0.0);
219
220 for (i, w) in window.iter().enumerate() {
221 frame[i] *= w;
222 }
223
224 let spectrum = Self::dft(&frame);
226
227 for (m, filter) in filterbank.iter().enumerate() {
229 let mut mel_energy = 0.0;
230 for (k, &mag) in spectrum.iter().enumerate() {
231 if k < filter.len() {
232 mel_energy += mag * mag * filter[k];
233 }
234 }
235 mel_spec[m * n_frames + frame_idx] = (mel_energy + 1e-10).ln();
237 }
238 }
239
240 Tensor::from_vec(mel_spec, &[self.n_mels, n_frames]).unwrap()
241 }
242}
243
244pub struct MFCC {
250 mel_spec: MelSpectrogram,
251 n_mfcc: usize,
252}
253
254impl MFCC {
255 #[must_use]
257 pub fn new(sample_rate: usize, n_mfcc: usize) -> Self {
258 Self {
259 mel_spec: MelSpectrogram::new(sample_rate),
260 n_mfcc,
261 }
262 }
263
264 fn dct(input: &[f32], n_out: usize) -> Vec<f32> {
266 let n = input.len();
267 let mut output = vec![0.0f32; n_out];
268
269 for k in 0..n_out {
270 let mut sum = 0.0;
271 for (i, &x) in input.iter().enumerate() {
272 sum += x * (PI * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos();
273 }
274 output[k] = sum;
275 }
276
277 output
278 }
279}
280
281impl Transform for MFCC {
282 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
283 let mel = self.mel_spec.apply(input);
285 let mel_data = mel.to_vec();
286 let mel_shape = mel.shape();
287
288 if mel_shape.len() != 2 {
289 return input.clone();
290 }
291
292 let (n_mels, n_frames) = (mel_shape[0], mel_shape[1]);
293 let mut mfcc = vec![0.0f32; self.n_mfcc * n_frames];
294
295 for frame in 0..n_frames {
297 let frame_data: Vec<f32> = (0..n_mels)
298 .map(|m| mel_data[m * n_frames + frame])
299 .collect();
300
301 let coeffs = Self::dct(&frame_data, self.n_mfcc);
302
303 for (k, &c) in coeffs.iter().enumerate() {
304 mfcc[k * n_frames + frame] = c;
305 }
306 }
307
308 Tensor::from_vec(mfcc, &[self.n_mfcc, n_frames]).unwrap()
309 }
310}
311
/// Speed-change augmentation (see the `Transform` impl for the exact behaviour).
pub struct TimeStretch {
    /// Speed factor clamped to `[0.1, 10.0]`; values above 1 shorten the signal.
    rate: f32,
}

impl TimeStretch {
    /// Creates a stretch with the given `rate`, clamped to `[0.1, 10.0]`.
    #[must_use]
    pub fn new(rate: f32) -> Self {
        // `max` is applied before `min`, so a NaN rate falls back to the lower bound.
        let clamped = rate.max(0.1).min(10.0);
        TimeStretch { rate: clamped }
    }
}
331
332impl Transform for TimeStretch {
333 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
334 let data = input.to_vec();
335 let new_len = (data.len() as f32 / self.rate) as usize;
336
337 if new_len == 0 {
338 return Tensor::from_vec(vec![], &[0]).unwrap();
339 }
340
341 let mut stretched = Vec::with_capacity(new_len);
342
343 for i in 0..new_len {
344 let src_idx = i as f32 * self.rate;
345 let idx0 = src_idx.floor() as usize;
346 let idx1 = (idx0 + 1).min(data.len() - 1);
347 let frac = src_idx - idx0 as f32;
348
349 if idx0 < data.len() {
350 let value =
351 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
352 stretched.push(value);
353 }
354 }
355
356 let len = stretched.len();
357 Tensor::from_vec(stretched, &[len]).unwrap()
358 }
359}
360
/// Pitch-shift augmentation by a number of semitones.
pub struct PitchShift {
    /// Shift in semitones; the frequency factor used is `2^(semitones / 12)`.
    semitones: f32,
}

impl PitchShift {
    /// Creates a pitch shift of `semitones`.
    #[must_use]
    pub fn new(semitones: f32) -> Self {
        PitchShift { semitones }
    }
}
378
379impl Transform for PitchShift {
380 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
381 let rate = 2.0_f32.powf(self.semitones / 12.0);
384 let data = input.to_vec();
385 let orig_len = data.len();
386
387 let resampled_len = (orig_len as f32 / rate) as usize;
389 if resampled_len == 0 {
390 return input.clone();
391 }
392
393 let mut resampled = Vec::with_capacity(resampled_len);
394 for i in 0..resampled_len {
395 let src_idx = i as f32 * rate;
396 let idx0 = src_idx.floor() as usize;
397 let idx1 = (idx0 + 1).min(orig_len - 1);
398 let frac = src_idx - idx0 as f32;
399
400 if idx0 < orig_len {
401 let value =
402 data[idx0] * (1.0 - frac) + data.get(idx1).copied().unwrap_or(0.0) * frac;
403 resampled.push(value);
404 }
405 }
406
407 let mut result = Vec::with_capacity(orig_len);
409 for i in 0..orig_len {
410 let src_idx = i as f32 * resampled.len() as f32 / orig_len as f32;
411 let idx0 = src_idx.floor() as usize;
412 let idx1 = (idx0 + 1).min(resampled.len().saturating_sub(1));
413 let frac = src_idx - idx0 as f32;
414
415 if idx0 < resampled.len() {
416 let value = resampled[idx0] * (1.0 - frac)
417 + resampled.get(idx1).copied().unwrap_or(0.0) * frac;
418 result.push(value);
419 } else {
420 result.push(0.0);
421 }
422 }
423
424 Tensor::from_vec(result, &[orig_len]).unwrap()
425 }
426}
427
/// Additive white Gaussian noise at a fixed signal-to-noise ratio.
pub struct AddNoise {
    /// Target signal-to-noise power ratio, in decibels.
    snr_db: f32,
}

impl AddNoise {
    /// Creates a transform whose noise sits `snr_db` decibels below the signal power.
    #[must_use]
    pub fn new(snr_db: f32) -> Self {
        AddNoise { snr_db }
    }
}
444
445impl Transform for AddNoise {
446 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
447 let data = input.to_vec();
448 let mut rng = rand::thread_rng();
449
450 let signal_power: f32 = data.iter().map(|&x| x * x).sum::<f32>() / data.len() as f32;
452
453 let noise_power = signal_power / 10.0_f32.powf(self.snr_db / 10.0);
455 let noise_std = noise_power.sqrt();
456
457 let noisy: Vec<f32> = data
459 .iter()
460 .map(|&x| {
461 let u1: f32 = rng.r#gen();
463 let u2: f32 = rng.r#gen();
464 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
465 x + z * noise_std
466 })
467 .collect();
468
469 Tensor::from_vec(noisy, input.shape()).unwrap()
470 }
471}
472
/// Peak normalisation: scales a signal so its largest absolute sample is 1.0.
#[derive(Default)]
pub struct NormalizeAudio;

impl NormalizeAudio {
    /// Creates the (stateless) normaliser.
    #[must_use]
    pub fn new() -> Self {
        NormalizeAudio
    }
}
493
494impl Transform for NormalizeAudio {
495 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
496 let data = input.to_vec();
497 let max_val = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
498
499 if max_val < 1e-10 {
500 return input.clone();
501 }
502
503 let normalized: Vec<f32> = data.iter().map(|&x| x / max_val).collect();
504 Tensor::from_vec(normalized, input.shape()).unwrap()
505 }
506}
507
/// Removes leading and trailing low-amplitude samples.
pub struct TrimSilence {
    /// Amplitude threshold in dB relative to 1.0, converted as `10^(threshold_db / 20)`.
    threshold_db: f32,
}

impl TrimSilence {
    /// Creates a trimmer with the given threshold in dB.
    #[must_use]
    pub fn new(threshold_db: f32) -> Self {
        TrimSilence { threshold_db }
    }

    /// Creates a trimmer with a -60 dB threshold (amplitude 0.001).
    #[must_use]
    pub fn default_threshold() -> Self {
        TrimSilence { threshold_db: -60.0 }
    }
}
530
531impl Transform for TrimSilence {
532 fn apply(&self, input: &Tensor<f32>) -> Tensor<f32> {
533 let data = input.to_vec();
534 let threshold = 10.0_f32.powf(self.threshold_db / 20.0);
535
536 let start = data.iter().position(|&x| x.abs() > threshold).unwrap_or(0);
538
539 let end = data
541 .iter()
542 .rposition(|&x| x.abs() > threshold)
543 .map_or(data.len(), |i| i + 1);
544
545 if start >= end {
546 return Tensor::from_vec(vec![], &[0]).unwrap();
547 }
548
549 let trimmed = data[start..end].to_vec();
550 let len = trimmed.len();
551 Tensor::from_vec(trimmed, &[len]).unwrap()
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthesises `duration` seconds of a unit-amplitude sine tone at `freq` Hz.
    fn create_sine_wave(freq: f32, sample_rate: usize, duration: f32) -> Tensor<f32> {
        let len = (sample_rate as f32 * duration) as usize;
        let mut samples = Vec::with_capacity(len);
        for i in 0..len {
            samples.push((2.0 * PI * freq * i as f32 / sample_rate as f32).sin());
        }
        Tensor::from_vec(samples, &[len]).unwrap()
    }

    #[test]
    fn test_resample() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let out = Resample::new(16000, 8000).apply(&audio);
        // Halving the rate halves the sample count.
        assert_eq!(out.shape()[0], audio.shape()[0] / 2);
    }

    #[test]
    fn test_resample_same_rate() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let out = Resample::new(16000, 16000).apply(&audio);
        assert_eq!(out.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let spec = MelSpectrogram::with_params(16000, 512, 256, 40).apply(&audio);
        let shape = spec.shape();
        assert_eq!(shape[0], 40, "one row per mel band");
        assert!(shape[1] > 0, "at least one frame");
    }

    #[test]
    fn test_mfcc() {
        let audio = create_sine_wave(440.0, 16000, 0.5);
        let coeffs = MFCC::new(16000, 13).apply(&audio);
        assert_eq!(coeffs.shape()[0], 13, "one row per coefficient");
    }

    #[test]
    fn test_time_stretch() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let stretched = TimeStretch::new(2.0).apply(&audio);
        // A rate above 1 speeds the signal up, so it gets shorter.
        assert!(stretched.shape()[0] < audio.shape()[0]);
    }

    #[test]
    fn test_pitch_shift() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let shifted = PitchShift::new(2.0).apply(&audio);
        assert_eq!(shifted.shape()[0], audio.shape()[0]);
    }

    #[test]
    fn test_add_noise() {
        let audio = create_sine_wave(440.0, 16000, 0.1);
        let noisy = AddNoise::new(20.0).apply(&audio);
        assert_eq!(noisy.shape(), audio.shape());
        assert_ne!(noisy.to_vec(), audio.to_vec());
    }

    #[test]
    fn test_normalize_audio() {
        let audio = Tensor::from_vec(vec![0.1, -0.5, 0.3, -0.2], &[4]).unwrap();
        let normalized = NormalizeAudio::new().apply(&audio);
        let peak = normalized
            .to_vec()
            .into_iter()
            .fold(0.0f32, |acc, x| acc.max(x.abs()));
        assert!((peak - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_trim_silence() {
        let audio = Tensor::from_vec(vec![0.0, 0.0, 0.5, 0.3, 0.0, 0.0], &[6]).unwrap();
        // -20 dB is an amplitude threshold of 0.1, so only 0.5 and 0.3 survive.
        let trimmed = TrimSilence::new(-20.0).apply(&audio);
        assert_eq!(trimmed.shape()[0], 2);
    }

    #[test]
    fn test_hz_to_mel_conversion() {
        let hz = 1000.0;
        let back = MelSpectrogram::mel_to_hz(MelSpectrogram::hz_to_mel(hz));
        assert!((hz - back).abs() < 0.1);
    }
}