axonml_audio/datasets.rs

//! Audio Datasets - Dataset implementations for audio processing tasks
//!
//! Provides datasets for audio classification, speech recognition, and music tasks.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
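//!
//! # Example
//!
//! A minimal usage sketch (module path assumed; marked `ignore` so it is not
//! run as a doc-test):
//!
//! ```ignore
//! use axonml_audio::datasets::SyntheticCommandDataset;
//! use axonml_data::Dataset;
//!
//! let dataset = SyntheticCommandDataset::small();
//! let (waveform, label) = dataset.get(0).unwrap();
//! assert_eq!(waveform.shape()[0], 8000); // 0.5 s at 16 kHz
//! assert_eq!(label.shape(), &[10]);      // one-hot over 10 classes
//! ```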

use axonml_data::Dataset;
use axonml_tensor::Tensor;
use rand::{Rng, SeedableRng};
use std::f32::consts::PI;

// =============================================================================
// Audio Classification Dataset
// =============================================================================

/// A dataset for audio classification tasks.
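///
/// # Example
///
/// A minimal construction sketch (module path assumed; marked `ignore` so it
/// is not run as a doc-test):
///
/// ```ignore
/// use axonml_audio::datasets::AudioClassificationDataset;
/// use axonml_data::Dataset;
/// use axonml_tensor::Tensor;
///
/// // Two silent 1 s clips at 16 kHz, labeled 0 and 1.
/// let waveforms = vec![
///     Tensor::from_vec(vec![0.0f32; 16000], &[16000]).unwrap(),
///     Tensor::from_vec(vec![0.0f32; 16000], &[16000]).unwrap(),
/// ];
/// let dataset = AudioClassificationDataset::new(waveforms, vec![0, 1], 16000, 2);
/// assert_eq!(dataset.len(), 2);
/// let (_wave, label) = dataset.get(1).unwrap();
/// assert_eq!(label.shape(), &[2]); // one-hot label
/// ```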
pub struct AudioClassificationDataset {
    waveforms: Vec<Tensor<f32>>,
    labels: Vec<usize>,
    sample_rate: usize,
    num_classes: usize,
}

impl AudioClassificationDataset {
    /// Creates a new audio classification dataset from waveforms and labels.
    ///
    /// `waveforms` and `labels` are expected to have equal lengths, and each
    /// label must lie in `0..num_classes`; `get` panics when a label is
    /// missing or out of range.
    #[must_use]
    pub fn new(
        waveforms: Vec<Tensor<f32>>,
        labels: Vec<usize>,
        sample_rate: usize,
        num_classes: usize,
    ) -> Self {
        Self {
            waveforms,
            labels,
            sample_rate,
            num_classes,
        }
    }

    /// Returns the sample rate.
    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    /// Returns the number of classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
}

impl Dataset for AudioClassificationDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.waveforms.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        let waveform = self.waveforms[index].clone();

        // One-hot encode label
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((waveform, label))
    }
}

// =============================================================================
// Synthetic Audio Command Dataset
// =============================================================================

/// A synthetic dataset simulating audio commands (like "yes", "no", "stop", etc.).
/// Uses different frequency patterns to represent different commands.
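///
/// # Example
///
/// A minimal usage sketch (module path assumed; marked `ignore` so it is not
/// run as a doc-test):
///
/// ```ignore
/// use axonml_audio::datasets::SyntheticCommandDataset;
/// use axonml_data::Dataset;
///
/// let dataset = SyntheticCommandDataset::new(200, 16000, 0.5, 10);
/// assert_eq!(dataset.len(), 200);
/// let (waveform, label) = dataset.get(3).unwrap();
/// assert_eq!(waveform.shape()[0], 8000); // 0.5 s at 16 kHz
/// assert_eq!(label.shape(), &[10]);      // class 3 % 10 is hot
/// ```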
pub struct SyntheticCommandDataset {
    num_samples: usize,
    sample_rate: usize,
    duration: f32,
    num_classes: usize,
}

impl SyntheticCommandDataset {
    /// Creates a new synthetic command dataset.
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_classes: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_classes: num_classes.max(2),
        }
    }

    /// Creates a small dataset with 100 samples.
    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 16000, 0.5, 10)
    }

    /// Creates a medium dataset with 1000 samples.
    #[must_use]
    pub fn medium() -> Self {
        Self::new(1000, 16000, 0.5, 10)
    }

    /// Creates a large dataset with 10000 samples.
    #[must_use]
    pub fn large() -> Self {
        Self::new(10000, 16000, 0.5, 35)
    }

    /// Generates a synthetic waveform for a given class.
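    ///
    /// Signal model: a sine fundamental at `200 + 100 * class` Hz with up to
    /// ±10% random detuning, plus class-weighted second and third harmonics,
    /// shaped by a 50 ms linear attack and 100 ms release envelope, with a
    /// little uniform noise mixed in.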
    fn generate_waveform(&self, class: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Different classes have different frequency patterns
        let base_freq = 200.0 + (class as f32 * 100.0);
        let freq_variation = rng.gen_range(0.9..1.1);
        let freq = base_freq * freq_variation;

        // Add some harmonics based on class
        let harmonic_weight = 0.3 + (class as f32 * 0.05);

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;
                let fundamental = (2.0 * PI * freq * t).sin();
                let harmonic1 = harmonic_weight * (2.0 * PI * freq * 2.0 * t).sin();
                let harmonic2 = harmonic_weight * 0.5 * (2.0 * PI * freq * 3.0 * t).sin();

                // Add envelope
                let envelope = if t < 0.05 {
                    t / 0.05
                } else if t > self.duration - 0.1 {
                    (self.duration - t) / 0.1
                } else {
                    1.0
                };

                // Add small noise
                let noise: f32 = rng.gen_range(-0.05..0.05);

                (fundamental + harmonic1 + harmonic2 + noise) * envelope * 0.5
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    /// Returns the sample rate.
    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    /// Returns the number of classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
}

impl Dataset for SyntheticCommandDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        let class = index % self.num_classes;
        let waveform = self.generate_waveform(class, index as u64);

        // One-hot encode label
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[class] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((waveform, label))
    }
}

// =============================================================================
// Synthetic Music Genre Dataset
// =============================================================================

/// A synthetic dataset for music genre classification.
/// Simulates different genres with distinct rhythm and frequency patterns.
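///
/// # Example
///
/// A minimal usage sketch (module path assumed; marked `ignore` so it is not
/// run as a doc-test):
///
/// ```ignore
/// use axonml_audio::datasets::SyntheticMusicDataset;
/// use axonml_data::Dataset;
///
/// let dataset = SyntheticMusicDataset::small(); // 100 clips, 5 genres
/// let (waveform, label) = dataset.get(0).unwrap();
/// assert_eq!(waveform.shape()[0], 22050); // 1.0 s at 22.05 kHz
/// assert_eq!(label.shape(), &[5]);        // one-hot genre label
/// ```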
pub struct SyntheticMusicDataset {
    num_samples: usize,
    sample_rate: usize,
    duration: f32,
    num_genres: usize,
}

impl SyntheticMusicDataset {
    /// Creates a new synthetic music dataset.
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_genres: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_genres: num_genres.max(2),
        }
    }

    /// Creates a small dataset.
    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 22050, 1.0, 5)
    }

    /// Creates a medium dataset.
    #[must_use]
    pub fn medium() -> Self {
        Self::new(500, 22050, 2.0, 10)
    }

    /// Generates a synthetic waveform for a music genre.
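    ///
    /// Each genre index maps to a characteristic tempo (about 60 BPM for
    /// "classical" up to 180 BPM for "metal", with random jitter) and a base
    /// pitch of `220 + 50 * genre` Hz; a melody line, a bass an octave below,
    /// and a per-beat rhythm click are mixed in genre-specific proportions.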
    fn generate_waveform(&self, genre: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Different genres have different characteristics
        let bpm = match genre % 5 {
            0 => 60.0 + rng.gen_range(-5.0..5.0),    // Classical - slow
            1 => 90.0 + rng.gen_range(-10.0..10.0),  // Jazz
            2 => 120.0 + rng.gen_range(-10.0..10.0), // Pop
            3 => 140.0 + rng.gen_range(-15.0..15.0), // Electronic
            _ => 180.0 + rng.gen_range(-20.0..20.0), // Metal
        };

        let beat_duration = 60.0 / bpm;
        let base_freq = 220.0 + (genre as f32 * 50.0);

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;
                let beat_phase = (t / beat_duration).fract();

                // Create rhythm pattern
                let rhythm = if beat_phase < 0.1 {
                    1.0 - beat_phase / 0.1
                } else {
                    0.0
                };

                // Melodic content
                let melody_freq = base_freq * (1.0 + 0.2 * (t * 2.0 * PI / beat_duration).sin());
                let melody = (2.0 * PI * melody_freq * t).sin();

                // Bass
                let bass = 0.5 * (2.0 * PI * base_freq * 0.5 * t).sin();

                // Combine with genre-specific mixing
                let mix = match genre % 5 {
                    0 => melody * 0.8 + bass * 0.2,
                    1 => melody * 0.6 + bass * 0.3 + rhythm * 0.1,
                    2 => melody * 0.5 + bass * 0.3 + rhythm * 0.2,
                    3 => melody * 0.3 + bass * 0.4 + rhythm * 0.3,
                    _ => melody * 0.4 + bass * 0.5 + rhythm * 0.3,
                };

                // Add noise for texture
                let noise: f32 = rng.gen_range(-0.02..0.02);

                (mix + noise) * 0.5
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    /// Returns the sample rate.
    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    /// Returns the number of genres.
    #[must_use]
    pub fn num_genres(&self) -> usize {
        self.num_genres
    }
}

impl Dataset for SyntheticMusicDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        let genre = index % self.num_genres;
        let waveform = self.generate_waveform(genre, index as u64);

        // One-hot encode label
        let mut label_vec = vec![0.0f32; self.num_genres];
        label_vec[genre] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_genres]).unwrap();

        Some((waveform, label))
    }
}

// =============================================================================
// Synthetic Speech Dataset
// =============================================================================

/// A synthetic dataset for speaker identification.
/// Simulates different speakers with distinct vocal characteristics.
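///
/// # Example
///
/// A minimal usage sketch (module path assumed; marked `ignore` so it is not
/// run as a doc-test):
///
/// ```ignore
/// use axonml_audio::datasets::SyntheticSpeakerDataset;
/// use axonml_data::Dataset;
///
/// let dataset = SyntheticSpeakerDataset::medium(); // 500 clips, 20 speakers
/// let (waveform, label) = dataset.get(7).unwrap();
/// assert_eq!(waveform.shape()[0], 16000); // 1.0 s at 16 kHz
/// assert_eq!(label.shape(), &[20]);       // speaker 7 is hot
/// ```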
pub struct SyntheticSpeakerDataset {
    num_samples: usize,
    sample_rate: usize,
    duration: f32,
    num_speakers: usize,
}

impl SyntheticSpeakerDataset {
    /// Creates a new synthetic speaker dataset.
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_speakers: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_speakers: num_speakers.max(2),
        }
    }

    /// Creates a small dataset.
    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 16000, 0.5, 5)
    }

    /// Creates a medium dataset.
    #[must_use]
    pub fn medium() -> Self {
        Self::new(500, 16000, 1.0, 20)
    }

    /// Generates a synthetic waveform for a speaker.
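    ///
    /// A simplified source-filter sketch: a glottal pulse train at a
    /// speaker-specific fundamental (`80 + 15 * speaker` Hz plus jitter) is
    /// enriched with three speaker-dependent formant bands, slow amplitude
    /// variation, and light noise for breathiness.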
    fn generate_waveform(&self, speaker: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Different speakers have different fundamental frequencies
        let f0 = 80.0 + (speaker as f32 * 15.0) + rng.gen_range(-10.0..10.0);

        // Formant frequencies (simplified vocal tract model)
        let formants = [
            f0 * 5.0 + (speaker as f32 * 20.0),
            f0 * 10.0 + (speaker as f32 * 30.0),
            f0 * 25.0 + (speaker as f32 * 10.0),
        ];

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;

                // Glottal pulse train
                let pulse_phase = (t * f0).fract();
                let glottal = if pulse_phase < 0.3 {
                    (pulse_phase * PI / 0.3).sin()
                } else {
                    0.0
                };

                // Add formants
                let mut signal = glottal;
                for &formant in &formants {
                    signal += 0.2 * glottal * (2.0 * PI * formant * t).sin();
                }

                // Add some variation
                let variation = 1.0 + 0.1 * (t * 5.0 * PI).sin();

                // Add noise for breathiness
                let noise: f32 = rng.gen_range(-0.03..0.03);

                signal * variation * 0.3 + noise
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    /// Returns the sample rate.
    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    /// Returns the number of speakers.
    #[must_use]
    pub fn num_speakers(&self) -> usize {
        self.num_speakers
    }
}

impl Dataset for SyntheticSpeakerDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        let speaker = index % self.num_speakers;
        let waveform = self.generate_waveform(speaker, index as u64);

        // One-hot encode label
        let mut label_vec = vec![0.0f32; self.num_speakers];
        label_vec[speaker] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_speakers]).unwrap();

        Some((waveform, label))
    }
}

// =============================================================================
// Sequence-to-Sequence Audio Dataset
// =============================================================================

/// A dataset for audio sequence-to-sequence tasks, pairing source waveforms
/// with target waveforms (e.g. noisy inputs with clean outputs).
pub struct AudioSeq2SeqDataset {
    sources: Vec<Tensor<f32>>,
    targets: Vec<Tensor<f32>>,
}

impl AudioSeq2SeqDataset {
    /// Creates a new audio seq2seq dataset.
    ///
    /// `sources` and `targets` are expected to have equal lengths; `get`
    /// panics if a target is missing for an in-range source index.
    #[must_use]
    pub fn new(sources: Vec<Tensor<f32>>, targets: Vec<Tensor<f32>>) -> Self {
        Self { sources, targets }
    }

    /// Creates a synthetic noise-reduction dataset: each target is a clean
    /// sinusoid and the paired source is the same sinusoid with uniform
    /// noise added.
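    ///
    /// # Example
    ///
    /// A minimal usage sketch (module path assumed; marked `ignore` so it is
    /// not run as a doc-test):
    ///
    /// ```ignore
    /// use axonml_audio::datasets::AudioSeq2SeqDataset;
    /// use axonml_data::Dataset;
    ///
    /// // 10 pairs of (noisy, clean) 0.1 s clips at 16 kHz.
    /// let dataset = AudioSeq2SeqDataset::noise_reduction_task(10, 16000, 0.1);
    /// let (noisy, clean) = dataset.get(0).unwrap();
    /// assert_eq!(noisy.shape(), clean.shape()); // both [1600]
    /// ```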
    #[must_use]
    pub fn noise_reduction_task(num_samples: usize, sample_rate: usize, duration: f32) -> Self {
        let n_samples_per = (sample_rate as f32 * duration) as usize;
        let mut sources = Vec::with_capacity(num_samples);
        let mut targets = Vec::with_capacity(num_samples);

        for i in 0..num_samples {
            let mut rng = rand::rngs::StdRng::seed_from_u64(i as u64);
            let freq = 200.0 + (i as f32 * 50.0) % 800.0;

            // Clean signal
            let clean: Vec<f32> = (0..n_samples_per)
                .map(|j| {
                    let t = j as f32 / sample_rate as f32;
                    (2.0 * PI * freq * t).sin() * 0.5
                })
                .collect();

            // Noisy signal
            let noisy: Vec<f32> = clean
                .iter()
                .map(|&x| x + rng.gen_range(-0.2..0.2))
                .collect();

            sources.push(Tensor::from_vec(noisy, &[n_samples_per]).unwrap());
            targets.push(Tensor::from_vec(clean, &[n_samples_per]).unwrap());
        }

        Self { sources, targets }
    }
}

impl Dataset for AudioSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.sources.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        Some((self.sources[index].clone(), self.targets[index].clone()))
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_classification_dataset() {
        let waveforms = vec![
            Tensor::from_vec(vec![0.0; 16000], &[16000]).unwrap(),
            Tensor::from_vec(vec![0.0; 16000], &[16000]).unwrap(),
        ];
        let labels = vec![0, 1];

        let dataset = AudioClassificationDataset::new(waveforms, labels, 16000, 2);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_rate(), 16000);
        assert_eq!(dataset.num_classes(), 2);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape(), &[16000]);
        assert_eq!(label.shape(), &[2]);
    }

    #[test]
    fn test_synthetic_command_dataset() {
        let dataset = SyntheticCommandDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
        assert_eq!(dataset.sample_rate(), 16000);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 8000); // 0.5s at 16000Hz
        assert_eq!(label.shape(), &[10]);

        // Check label is one-hot
        let label_sum: f32 = label.to_vec().iter().sum();
        assert!((label_sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_synthetic_command_dataset_different_classes() {
        let dataset = SyntheticCommandDataset::small();

        // Different indices should produce different class labels
        let (_, label0) = dataset.get(0).unwrap();
        let (_, label1) = dataset.get(1).unwrap();

        let label0_vec = label0.to_vec();
        let label1_vec = label1.to_vec();

        let class0 = label0_vec.iter().position(|&x| x > 0.5).unwrap();
        let class1 = label1_vec.iter().position(|&x| x > 0.5).unwrap();

        assert_eq!(class0, 0);
        assert_eq!(class1, 1);
    }

    #[test]
    fn test_synthetic_music_dataset() {
        let dataset = SyntheticMusicDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_genres(), 5);
        assert_eq!(dataset.sample_rate(), 22050);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 22050); // 1.0s at 22050Hz
        assert_eq!(label.shape(), &[5]);
    }

    #[test]
    fn test_synthetic_speaker_dataset() {
        let dataset = SyntheticSpeakerDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_speakers(), 5);
        assert_eq!(dataset.sample_rate(), 16000);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 8000); // 0.5s at 16000Hz
        assert_eq!(label.shape(), &[5]);
    }

    #[test]
    fn test_audio_seq2seq_dataset() {
        let dataset = AudioSeq2SeqDataset::noise_reduction_task(10, 16000, 0.1);

        assert_eq!(dataset.len(), 10);

        let (source, target) = dataset.get(0).unwrap();
        assert_eq!(source.shape(), target.shape());
    }

    #[test]
    fn test_dataset_bounds() {
        let dataset = SyntheticCommandDataset::small();

        assert!(dataset.get(99).is_some());
        assert!(dataset.get(100).is_none());
    }

    #[test]
    fn test_waveform_values_in_range() {
        let dataset = SyntheticCommandDataset::small();

        let (wave, _) = dataset.get(0).unwrap();
        let data = wave.to_vec();

        // All values should be in reasonable range
        for &val in &data {
            assert!(val.abs() <= 1.0, "Waveform value {val} out of range");
        }
    }

    #[test]
    fn test_music_dataset_different_genres() {
        let dataset = SyntheticMusicDataset::small();

        // Get waveforms from two different genres
        let (wave0, _) = dataset.get(0).unwrap();
        let (wave1, _) = dataset.get(1).unwrap();

        // The waveforms should be different
        assert_ne!(wave0.to_vec(), wave1.to_vec());
    }
}