use axonml_data::Dataset;
use axonml_tensor::Tensor;
use rand::{Rng, SeedableRng};
use std::f32::consts::PI;

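/// Dataset of pre-loaded audio waveforms paired with integer class labels.
///
/// Each item is a `(waveform, label)` pair; the label is returned one-hot
/// encoded as a tensor of length `num_classes`.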
pub struct AudioClassificationDataset {
    /// One waveform tensor per example.
    waveforms: Vec<Tensor<f32>>,
    /// Integer class index for each waveform.
    labels: Vec<usize>,
    /// Sample rate of the stored waveforms, in Hz.
    sample_rate: usize,
    /// Total number of target classes.
    num_classes: usize,
}

impl AudioClassificationDataset {
    #[must_use]
    pub fn new(
        waveforms: Vec<Tensor<f32>>,
        labels: Vec<usize>,
        sample_rate: usize,
        num_classes: usize,
    ) -> Self {
        Self {
            waveforms,
            labels,
            sample_rate,
            num_classes,
        }
    }

    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
}

impl Dataset for AudioClassificationDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.waveforms.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        let waveform = self.waveforms[index].clone();

        // One-hot encode the integer class label.
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((waveform, label))
    }
}

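/// Procedurally generated "spoken command" style dataset for smoke tests and
/// benchmarks; no audio files are required on disk.
///
/// Each class is synthesized as a tone with a class-dependent fundamental
/// frequency plus harmonics, shaped by an attack/release envelope with a
/// little noise mixed in. Items are `(waveform, one_hot_label)` pairs.
///
/// Minimal usage sketch (marked `ignore` because the crate/module path is not
/// shown here; `len` and `get` come from the `Dataset` trait):
///
/// ```ignore
/// let dataset = SyntheticCommandDataset::small();
/// assert_eq!(dataset.len(), 100);
/// let (waveform, label) = dataset.get(0).unwrap();
/// assert_eq!(label.shape(), &[dataset.num_classes()]);
/// ```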
pub struct SyntheticCommandDataset {
    /// Number of examples exposed by the dataset.
    num_samples: usize,
    /// Sample rate of generated audio, in Hz.
    sample_rate: usize,
    /// Duration of each clip, in seconds.
    duration: f32,
    /// Number of command classes (at least 2).
    num_classes: usize,
}

impl SyntheticCommandDataset {
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_classes: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_classes: num_classes.max(2),
        }
    }

    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 16000, 0.5, 10)
    }

    #[must_use]
    pub fn medium() -> Self {
        Self::new(1000, 16000, 0.5, 10)
    }

    #[must_use]
    pub fn large() -> Self {
        Self::new(10000, 16000, 0.5, 35)
    }

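    /// Synthesizes one clip for `class`, deterministically seeded by `seed`.
    ///
    /// The signal is a class-dependent fundamental plus two harmonics, shaped
    /// by a short attack/release envelope, with uniform noise added.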
    fn generate_waveform(&self, class: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Each class gets its own fundamental frequency, with a small random detune.
        let base_freq = 200.0 + (class as f32 * 100.0);
        let freq_variation = rng.gen_range(0.9..1.1);
        let freq = base_freq * freq_variation;

        let harmonic_weight = 0.3 + (class as f32 * 0.05);

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;
                let fundamental = (2.0 * PI * freq * t).sin();
                let harmonic1 = harmonic_weight * (2.0 * PI * freq * 2.0 * t).sin();
                let harmonic2 = harmonic_weight * 0.5 * (2.0 * PI * freq * 3.0 * t).sin();

                // Linear attack over the first 50 ms and release over the last 100 ms.
                let envelope = if t < 0.05 {
                    t / 0.05
                } else if t > self.duration - 0.1 {
                    (self.duration - t) / 0.1
                } else {
                    1.0
                };

                let noise: f32 = rng.gen_range(-0.05..0.05);

                (fundamental + harmonic1 + harmonic2 + noise) * envelope * 0.5
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
}

impl Dataset for SyntheticCommandDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        // Assign classes round-robin and seed generation with the index for determinism.
        let class = index % self.num_classes;
        let waveform = self.generate_waveform(class, index as u64);

        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[class] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((waveform, label))
    }
}

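/// Procedurally generated "music" dataset for genre-classification smoke tests.
///
/// Each genre maps to a tempo (BPM) range and a base frequency; waveforms mix
/// a melody, a bass line, and a per-beat rhythm click in genre-dependent
/// proportions. Items are `(waveform, one_hot_label)` pairs.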
pub struct SyntheticMusicDataset {
    /// Number of examples exposed by the dataset.
    num_samples: usize,
    /// Sample rate of generated audio, in Hz.
    sample_rate: usize,
    /// Duration of each clip, in seconds.
    duration: f32,
    /// Number of genres (at least 2).
    num_genres: usize,
}

impl SyntheticMusicDataset {
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_genres: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_genres: num_genres.max(2),
        }
    }

    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 22050, 1.0, 5)
    }

    #[must_use]
    pub fn medium() -> Self {
        Self::new(500, 22050, 2.0, 10)
    }

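    /// Synthesizes one clip for `genre`, deterministically seeded by `seed`.
    ///
    /// Tempo is drawn from a genre-dependent BPM range, and melody, bass, and
    /// rhythm components are mixed in genre-dependent proportions.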
    fn generate_waveform(&self, genre: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Genre selects a tempo band; jitter keeps clips within a class distinct.
        let bpm = match genre % 5 {
            0 => 60.0 + rng.gen_range(-5.0..5.0),
            1 => 90.0 + rng.gen_range(-10.0..10.0),
            2 => 120.0 + rng.gen_range(-10.0..10.0),
            3 => 140.0 + rng.gen_range(-15.0..15.0),
            _ => 180.0 + rng.gen_range(-20.0..20.0),
        };

        let beat_duration = 60.0 / bpm;
        let base_freq = 220.0 + (genre as f32 * 50.0);

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;
                let beat_phase = (t / beat_duration).fract();

                // Short decaying click at the start of every beat.
                let rhythm = if beat_phase < 0.1 {
                    1.0 - beat_phase / 0.1
                } else {
                    0.0
                };

                // Melody with slow vibrato around the base frequency.
                let melody_freq = base_freq * (1.0 + 0.2 * (t * 2.0 * PI / beat_duration).sin());
                let melody = (2.0 * PI * melody_freq * t).sin();

                // Bass line one octave below the base frequency.
                let bass = 0.5 * (2.0 * PI * base_freq * 0.5 * t).sin();

                // Genre-dependent mix of the three components.
                let mix = match genre % 5 {
                    0 => melody * 0.8 + bass * 0.2,
                    1 => melody * 0.6 + bass * 0.3 + rhythm * 0.1,
                    2 => melody * 0.5 + bass * 0.3 + rhythm * 0.2,
                    3 => melody * 0.3 + bass * 0.4 + rhythm * 0.3,
                    _ => melody * 0.4 + bass * 0.5 + rhythm * 0.3,
                };

                let noise: f32 = rng.gen_range(-0.02..0.02);

                (mix + noise) * 0.5
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    #[must_use]
    pub fn num_genres(&self) -> usize {
        self.num_genres
    }
}

impl Dataset for SyntheticMusicDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        // Assign genres round-robin and seed generation with the index for determinism.
        let genre = index % self.num_genres;
        let waveform = self.generate_waveform(genre, index as u64);

        let mut label_vec = vec![0.0f32; self.num_genres];
        label_vec[genre] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_genres]).unwrap();

        Some((waveform, label))
    }
}

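/// Procedurally generated dataset for speaker-identification smoke tests.
///
/// Each speaker is modelled with a distinct fundamental frequency (F0) and
/// formant set, excited by a simple glottal pulse train. Items are
/// `(waveform, one_hot_label)` pairs.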
pub struct SyntheticSpeakerDataset {
    /// Number of examples exposed by the dataset.
    num_samples: usize,
    /// Sample rate of generated audio, in Hz.
    sample_rate: usize,
    /// Duration of each clip, in seconds.
    duration: f32,
    /// Number of speakers (at least 2).
    num_speakers: usize,
}

impl SyntheticSpeakerDataset {
    #[must_use]
    pub fn new(num_samples: usize, sample_rate: usize, duration: f32, num_speakers: usize) -> Self {
        Self {
            num_samples,
            sample_rate,
            duration,
            num_speakers: num_speakers.max(2),
        }
    }

    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 16000, 0.5, 5)
    }

    #[must_use]
    pub fn medium() -> Self {
        Self::new(500, 16000, 1.0, 20)
    }

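    /// Synthesizes one clip for `speaker`, deterministically seeded by `seed`.
    ///
    /// A glottal pulse train at a speaker-dependent F0 is enriched with three
    /// speaker-dependent formant frequencies, with slow amplitude variation
    /// and a small amount of noise.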
    fn generate_waveform(&self, speaker: usize, seed: u64) -> Tensor<f32> {
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n_samples = (self.sample_rate as f32 * self.duration) as usize;

        // Speaker-dependent fundamental frequency (pitch) with a little jitter.
        let f0 = 80.0 + (speaker as f32 * 15.0) + rng.gen_range(-10.0..10.0);

        // Three formant frequencies, offset per speaker.
        let formants = [
            f0 * 5.0 + (speaker as f32 * 20.0),
            f0 * 10.0 + (speaker as f32 * 30.0),
            f0 * 25.0 + (speaker as f32 * 10.0),
        ];

        let data: Vec<f32> = (0..n_samples)
            .map(|i| {
                let t = i as f32 / self.sample_rate as f32;

                // Simple glottal pulse: a half-sine burst over the first 30% of each period.
                let pulse_phase = (t * f0).fract();
                let glottal = if pulse_phase < 0.3 {
                    (pulse_phase * PI / 0.3).sin()
                } else {
                    0.0
                };

                // Add formant energy on top of the glottal source.
                let mut signal = glottal;
                for &formant in &formants {
                    signal += 0.2 * glottal * (2.0 * PI * formant * t).sin();
                }

                // Slow amplitude variation to mimic natural loudness changes.
                let variation = 1.0 + 0.1 * (t * 5.0 * PI).sin();

                let noise: f32 = rng.gen_range(-0.03..0.03);

                signal * variation * 0.3 + noise
            })
            .collect();

        Tensor::from_vec(data, &[n_samples]).unwrap()
    }

    #[must_use]
    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    #[must_use]
    pub fn num_speakers(&self) -> usize {
        self.num_speakers
    }
}

impl Dataset for SyntheticSpeakerDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.num_samples
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.num_samples {
            return None;
        }

        // Assign speakers round-robin and seed generation with the index for determinism.
        let speaker = index % self.num_speakers;
        let waveform = self.generate_waveform(speaker, index as u64);

        let mut label_vec = vec![0.0f32; self.num_speakers];
        label_vec[speaker] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_speakers]).unwrap();

        Some((waveform, label))
    }
}

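/// Sequence-to-sequence audio dataset of `(source, target)` waveform pairs,
/// for example a noisy input mapped to a clean target for denoising.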
pub struct AudioSeq2SeqDataset {
    /// Input waveforms (e.g. noisy audio).
    sources: Vec<Tensor<f32>>,
    /// Target waveforms aligned with `sources` (e.g. clean audio).
    targets: Vec<Tensor<f32>>,
}

impl AudioSeq2SeqDataset {
    #[must_use]
    pub fn new(sources: Vec<Tensor<f32>>, targets: Vec<Tensor<f32>>) -> Self {
        Self { sources, targets }
    }

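    /// Builds a synthetic denoising task: each source is a sine tone with
    /// uniform noise added, and the matching target is the clean tone.
    ///
    /// Sketch of the intended use (marked `ignore`; the crate/module path is
    /// not shown and the numbers are illustrative):
    ///
    /// ```ignore
    /// let dataset = AudioSeq2SeqDataset::noise_reduction_task(10, 16000, 0.1);
    /// let (noisy, clean) = dataset.get(0).unwrap();
    /// assert_eq!(noisy.shape(), clean.shape());
    /// ```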
    #[must_use]
    pub fn noise_reduction_task(num_samples: usize, sample_rate: usize, duration: f32) -> Self {
        let n_samples_per = (sample_rate as f32 * duration) as usize;
        let mut sources = Vec::with_capacity(num_samples);
        let mut targets = Vec::with_capacity(num_samples);

        for i in 0..num_samples {
            let mut rng = rand::rngs::StdRng::seed_from_u64(i as u64);
            let freq = 200.0 + (i as f32 * 50.0) % 800.0;

            // Clean reference: a pure sine tone at a per-example frequency.
            let clean: Vec<f32> = (0..n_samples_per)
                .map(|j| {
                    let t = j as f32 / sample_rate as f32;
                    (2.0 * PI * freq * t).sin() * 0.5
                })
                .collect();

            // Noisy source: the clean tone plus uniform noise.
            let noisy: Vec<f32> = clean
                .iter()
                .map(|&x| x + rng.gen_range(-0.2..0.2))
                .collect();

            sources.push(Tensor::from_vec(noisy, &[n_samples_per]).unwrap());
            targets.push(Tensor::from_vec(clean, &[n_samples_per]).unwrap());
        }

        Self { sources, targets }
    }
}

impl Dataset for AudioSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.sources.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        Some((self.sources[index].clone(), self.targets[index].clone()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_audio_classification_dataset() {
        let waveforms = vec![
            Tensor::from_vec(vec![0.0; 16000], &[16000]).unwrap(),
            Tensor::from_vec(vec![0.0; 16000], &[16000]).unwrap(),
        ];
        let labels = vec![0, 1];

        let dataset = AudioClassificationDataset::new(waveforms, labels, 16000, 2);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.sample_rate(), 16000);
        assert_eq!(dataset.num_classes(), 2);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape(), &[16000]);
        assert_eq!(label.shape(), &[2]);
    }

    #[test]
    fn test_synthetic_command_dataset() {
        let dataset = SyntheticCommandDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_classes(), 10);
        assert_eq!(dataset.sample_rate(), 16000);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 8000);
        assert_eq!(label.shape(), &[10]);

        // A one-hot label should sum to exactly one.
        let label_sum: f32 = label.to_vec().iter().sum();
        assert!((label_sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_synthetic_command_dataset_different_classes() {
        let dataset = SyntheticCommandDataset::small();

        let (_, label0) = dataset.get(0).unwrap();
        let (_, label1) = dataset.get(1).unwrap();

        let label0_vec = label0.to_vec();
        let label1_vec = label1.to_vec();

        let class0 = label0_vec.iter().position(|&x| x > 0.5).unwrap();
        let class1 = label1_vec.iter().position(|&x| x > 0.5).unwrap();

        assert_eq!(class0, 0);
        assert_eq!(class1, 1);
    }

    #[test]
    fn test_synthetic_music_dataset() {
        let dataset = SyntheticMusicDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_genres(), 5);
        assert_eq!(dataset.sample_rate(), 22050);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 22050);
        assert_eq!(label.shape(), &[5]);
    }

    #[test]
    fn test_synthetic_speaker_dataset() {
        let dataset = SyntheticSpeakerDataset::small();

        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_speakers(), 5);
        assert_eq!(dataset.sample_rate(), 16000);

        let (wave, label) = dataset.get(0).unwrap();
        assert_eq!(wave.shape()[0], 8000);
        assert_eq!(label.shape(), &[5]);
    }

    #[test]
    fn test_audio_seq2seq_dataset() {
        let dataset = AudioSeq2SeqDataset::noise_reduction_task(10, 16000, 0.1);

        assert_eq!(dataset.len(), 10);

        let (source, target) = dataset.get(0).unwrap();
        assert_eq!(source.shape(), target.shape());
    }

    #[test]
    fn test_dataset_bounds() {
        let dataset = SyntheticCommandDataset::small();

        assert!(dataset.get(99).is_some());
        assert!(dataset.get(100).is_none());
    }

    #[test]
    fn test_waveform_values_in_range() {
        let dataset = SyntheticCommandDataset::small();

        let (wave, _) = dataset.get(0).unwrap();
        let data = wave.to_vec();

        for &val in &data {
            assert!(val.abs() <= 1.0, "Waveform value {val} out of range");
        }
    }

    #[test]
    fn test_music_dataset_different_genres() {
        let dataset = SyntheticMusicDataset::small();

        let (wave0, _) = dataset.get(0).unwrap();
        let (wave1, _) = dataset.get(1).unwrap();

        assert_ne!(wave0.to_vec(), wave1.to_vec());
    }
}