1#![warn(missing_docs)]
26#![warn(clippy::all)]
27#![warn(clippy::pedantic)]
28#![allow(clippy::cast_possible_truncation)]
30#![allow(clippy::cast_sign_loss)]
31#![allow(clippy::cast_precision_loss)]
32#![allow(clippy::cast_possible_wrap)]
33#![allow(clippy::missing_errors_doc)]
34#![allow(clippy::missing_panics_doc)]
35#![allow(clippy::must_use_candidate)]
36#![allow(clippy::module_name_repetitions)]
37#![allow(clippy::similar_names)]
38#![allow(clippy::many_single_char_names)]
39#![allow(clippy::too_many_arguments)]
40#![allow(clippy::doc_markdown)]
41#![allow(clippy::cast_lossless)]
42#![allow(clippy::needless_pass_by_value)]
43#![allow(clippy::redundant_closure_for_method_calls)]
44#![allow(clippy::uninlined_format_args)]
45#![allow(clippy::ptr_arg)]
46#![allow(clippy::return_self_not_must_use)]
47#![allow(clippy::not_unsafe_ptr_arg_deref)]
48#![allow(clippy::items_after_statements)]
49#![allow(clippy::unreadable_literal)]
50#![allow(clippy::if_same_then_else)]
51#![allow(clippy::needless_range_loop)]
52#![allow(clippy::trivially_copy_pass_by_ref)]
53#![allow(clippy::unnecessary_wraps)]
54#![allow(clippy::match_same_arms)]
55#![allow(clippy::unused_self)]
56#![allow(clippy::too_many_lines)]
57#![allow(clippy::single_match_else)]
58#![allow(clippy::fn_params_excessive_bools)]
59#![allow(clippy::struct_excessive_bools)]
60#![allow(clippy::format_push_string)]
61#![allow(clippy::erasing_op)]
62#![allow(clippy::type_repetition_in_bounds)]
63#![allow(clippy::iter_without_into_iter)]
64#![allow(clippy::should_implement_trait)]
65#![allow(clippy::use_debug)]
66#![allow(clippy::case_sensitive_file_extension_comparisons)]
67#![allow(clippy::large_enum_variant)]
68#![allow(clippy::panic)]
69#![allow(clippy::struct_field_names)]
70#![allow(clippy::missing_fields_in_debug)]
71#![allow(clippy::upper_case_acronyms)]
72#![allow(clippy::assigning_clones)]
73#![allow(clippy::option_if_let_else)]
74#![allow(clippy::manual_let_else)]
75#![allow(clippy::explicit_iter_loop)]
76#![allow(clippy::default_trait_access)]
77#![allow(clippy::only_used_in_recursion)]
78#![allow(clippy::manual_clamp)]
79#![allow(clippy::ref_option)]
80#![allow(clippy::multiple_bound_locations)]
81#![allow(clippy::comparison_chain)]
82#![allow(clippy::manual_assert)]
83#![allow(clippy::unnecessary_debug_formatting)]
84
85pub mod datasets;
86pub mod transforms;
87
88pub use transforms::{
93 AddNoise, MelSpectrogram, NormalizeAudio, PitchShift, Resample, TimeStretch, TrimSilence, MFCC,
94};
95
96pub use datasets::{
97 AudioClassificationDataset, AudioSeq2SeqDataset, SyntheticCommandDataset,
98 SyntheticMusicDataset, SyntheticSpeakerDataset,
99};
100
101pub mod prelude {
107 pub use crate::{
108 AddNoise,
109 AudioClassificationDataset,
111 AudioSeq2SeqDataset,
112 MelSpectrogram,
113 NormalizeAudio,
114 PitchShift,
115 Resample,
117 SyntheticCommandDataset,
118 SyntheticMusicDataset,
119 SyntheticSpeakerDataset,
120 TimeStretch,
121 TrimSilence,
122 MFCC,
123 };
124
125 pub use axonml_data::{DataLoader, Dataset, Transform};
126 pub use axonml_tensor::Tensor;
127}
128
#[cfg(test)]
mod tests {
    // Each test pulls sample 0 from one of the synthetic datasets and pushes it
    // through one (or several) of the re-exported transforms, checking only the
    // resulting shape or amplitude.

    use super::*;
    use axonml_data::{DataLoader, Dataset, Transform};

    /// The mel spectrogram's first dimension equals the last `with_params`
    /// argument (40) and its second dimension is non-empty.
    #[test]
    fn test_transform_on_dataset() {
        let commands = SyntheticCommandDataset::small();
        let mel_transform = MelSpectrogram::with_params(16000, 512, 256, 40);

        let (audio, _label) = commands.get(0).unwrap();
        let mel_output = mel_transform.apply(&audio);
        let dims = mel_output.shape();

        assert_eq!(dims[0], 40);
        assert!(dims[1] > 0);
    }

    /// The MFCC output's first dimension equals the second constructor
    /// argument (13).
    #[test]
    fn test_mfcc_on_dataset() {
        let commands = SyntheticCommandDataset::small();
        let mfcc_transform = MFCC::new(16000, 13);

        let (audio, _) = commands.get(0).unwrap();
        let coefficients = mfcc_transform.apply(&audio);

        assert_eq!(coefficients.shape()[0], 13);
    }

    /// An 11025-element sample resampled with `Resample::new(22050, 16000)`
    /// comes out with 8000 elements.
    #[test]
    fn test_resample_on_dataset() {
        let commands = SyntheticCommandDataset::new(10, 22050, 0.5, 5);
        let resampler = Resample::new(22050, 16000);

        let (audio, _) = commands.get(0).unwrap();
        let converted = resampler.apply(&audio);

        let input_len = audio.shape()[0];
        let output_len = converted.shape()[0];
        assert_eq!(input_len, 11025);
        assert_eq!(output_len, 8000);
    }

    /// Adding noise must not change the shape of the signal.
    #[test]
    fn test_noise_augmentation() {
        let music = SyntheticMusicDataset::small();
        // NOTE(review): 30.0 is presumably an SNR in dB; confirm against `AddNoise`.
        let noise = AddNoise::new(30.0);

        let (audio, _) = music.get(0).unwrap();
        let with_noise = noise.apply(&audio);

        assert_eq!(with_noise.shape(), audio.shape());
    }

    /// After normalization the largest absolute sample is within 0.01 of 1.0.
    #[test]
    fn test_normalize_audio() {
        let speakers = SyntheticSpeakerDataset::small();
        let normalizer = NormalizeAudio::new();

        let (audio, _) = speakers.get(0).unwrap();
        let scaled = normalizer.apply(&audio);

        // Running maximum of |sample| over the whole signal, starting from 0.
        let samples = scaled.to_vec();
        let peak = samples
            .iter()
            .fold(0.0f32, |running, sample| f32::max(running, sample.abs()));

        let deviation = (peak - 1.0).abs();
        assert!(deviation < 0.01);
    }

    /// Resample -> normalize -> mel spectrogram, applied back to back; the
    /// final first dimension is 40.
    #[test]
    fn test_pipeline() {
        let commands = SyntheticCommandDataset::small();

        let resampler = Resample::new(16000, 8000);
        let normalizer = NormalizeAudio::new();
        let mel_transform = MelSpectrogram::with_params(8000, 256, 128, 40);

        let (audio, _) = commands.get(0).unwrap();

        // Each stage consumes the previous stage's output.
        let stage = resampler.apply(&audio);
        let stage = normalizer.apply(&stage);
        let stage = mel_transform.apply(&stage);

        assert_eq!(stage.shape()[0], 40);
    }

    /// `TimeStretch::new(1.0)` leaves the first dimension unchanged.
    #[test]
    fn test_time_stretch_preserves_audio_characteristics() {
        let music = SyntheticMusicDataset::small();
        // NOTE(review): 1.0 is presumably a unity stretch rate; confirm against `TimeStretch`.
        let stretcher = TimeStretch::new(1.0);

        let (audio, _) = music.get(0).unwrap();
        let stretched = stretcher.apply(&audio);

        let original_len = audio.shape()[0];
        assert_eq!(stretched.shape()[0], original_len);
    }

    /// `PitchShift::new(0.0)` leaves the first dimension unchanged.
    #[test]
    fn test_pitch_shift() {
        let commands = SyntheticCommandDataset::small();
        // NOTE(review): 0.0 is presumably a zero-step shift; confirm against `PitchShift`.
        let shifter = PitchShift::new(0.0);

        let (audio, _) = commands.get(0).unwrap();
        let shifted = shifter.apply(&audio);

        let original_len = audio.shape()[0];
        assert_eq!(shifted.shape()[0], original_len);
    }

    /// The first three batches from a loader built with 16 each have a first
    /// dimension of 16, and three batches are actually produced.
    #[test]
    fn test_dataset_with_dataloader() {
        let commands = SyntheticCommandDataset::small();
        let loader = DataLoader::new(commands, 16);

        let mut batches_seen = 0;
        loader.iter().take(3).for_each(|batch| {
            assert_eq!(batch.data.shape()[0], 16);
            batches_seen += 1;
        });

        assert_eq!(batches_seen, 3);
    }
}