torsh-data 0.1.2

Data loading and preprocessing utilities for ToRSh
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
//! Audio dataset implementations

use super::{transforms::transforms, transforms::AudioToTensor, types::AudioData};
use crate::{dataset::Dataset, transforms::Transform};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
use std::path::{Path, PathBuf};

/// Audio file dataset for loading audio from directories
pub struct AudioFolder {
    #[allow(dead_code)]
    root: PathBuf,
    samples: Vec<(PathBuf, usize)>,
    classes: Vec<String>,
    transform: Option<Box<dyn Transform<AudioData, Output = Tensor<f32>>>>,
    sample_rate: Option<u32>,
}

impl AudioFolder {
    /// Create a new audio folder dataset
    pub fn new<P: AsRef<Path>>(root: P, sample_rate: Option<u32>) -> Result<Self> {
        let root = root.as_ref().to_path_buf();

        if !root.exists() {
            return Err(TorshError::IoError(format!(
                "Directory does not exist: {:?}",
                root
            )));
        }

        let mut classes = Vec::new();
        let mut samples = Vec::new();

        // Scan subdirectories for classes
        for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
            let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
            let path = entry.path();

            if path.is_dir() {
                let class_name = path
                    .file_name()
                    .and_then(|n| n.to_str())
                    .ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
                    .to_string();

                let class_idx = classes.len();
                classes.push(class_name);

                // Scan audio files in class directory
                for audio_entry in
                    std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
                {
                    let audio_entry =
                        audio_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
                    let audio_path = audio_entry.path();

                    if Self::is_audio_file(&audio_path) {
                        samples.push((audio_path, class_idx));
                    }
                }
            }
        }

        Ok(Self {
            root,
            samples,
            classes,
            transform: None,
            sample_rate,
        })
    }

    /// Set transform to apply to audio
    pub fn with_transform<T>(mut self, transform: T) -> Self
    where
        T: Transform<AudioData, Output = Tensor<f32>> + 'static,
    {
        self.transform = Some(Box::new(transform));
        self
    }

    /// Get class names
    pub fn classes(&self) -> &[String] {
        &self.classes
    }

    /// Check if file is a supported audio format
    fn is_audio_file(path: &Path) -> bool {
        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
            matches!(
                extension.to_lowercase().as_str(),
                "wav" | "mp3" | "flac" | "ogg" | "m4a" | "aac"
            )
        } else {
            false
        }
    }

    /// Load audio from path with fallback to dummy data
    fn load_audio(&self, path: &Path) -> Result<AudioData> {
        // Try to load actual audio file if it exists
        if path.exists() {
            // Check file extension to determine loading strategy
            if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
                match extension.to_lowercase().as_str() {
                    "wav" => {
                        // For WAV files, we could use a simple WAV parser
                        // For now, we'll use a basic implementation or fall back
                        if let Ok(audio) = Self::load_wav_file(path, self.sample_rate) {
                            return Ok(audio);
                        }
                    }
                    "flac" | "mp3" | "ogg" | "m4a" | "aac" => {
                        // These formats require specialized libraries:
                        // - symphonia for pure Rust audio decoding (all formats)
                        // - rodio for audio playback and simple loading
                        // - mp3 crate for MP3 specifically
                        // - ogg/vorbis crates for OGG

                        // For now, create format-appropriate dummy data
                        let sample_rate = self.sample_rate.unwrap_or(22050);
                        let duration_seconds = 3.0; // 3 second dummy audio
                        let samples_count = (sample_rate as f32 * duration_seconds) as usize;
                        let samples: Vec<f32> = (0..samples_count)
                            .map(|i| {
                                // Create more realistic audio signal for different formats
                                match extension {
                                    "flac" => {
                                        (i as f32 * 220.0 * 2.0 * std::f32::consts::PI
                                            / sample_rate as f32)
                                            .sin()
                                            * 0.15
                                    }
                                    "mp3" => {
                                        (i as f32 * 880.0 * 2.0 * std::f32::consts::PI
                                            / sample_rate as f32)
                                            .sin()
                                            * 0.12
                                    }
                                    _ => {
                                        (i as f32 * 440.0 * 2.0 * std::f32::consts::PI
                                            / sample_rate as f32)
                                            .sin()
                                            * 0.1
                                    }
                                }
                            })
                            .collect();

                        return Ok(AudioData::new(samples, sample_rate, 1));
                    }
                    _ => {
                        return Err(TorshError::IoError(format!(
                            "Unsupported audio format: {}",
                            extension
                        )));
                    }
                }
            }
        }

        // Fallback: return dummy audio data when file doesn't exist or can't be loaded
        let sample_rate = self.sample_rate.unwrap_or(22050);
        let duration_seconds = 3.0; // 3 second dummy audio
        let samples_count = (sample_rate as f32 * duration_seconds) as usize;
        let samples: Vec<f32> = (0..samples_count)
            .map(|i| {
                (i as f32 * 440.0 * 2.0 * std::f32::consts::PI / sample_rate as f32).sin() * 0.1
            })
            .collect();

        Ok(AudioData::new(samples, sample_rate, 1))
    }

    /// Basic WAV file loader (simplified implementation)
    fn load_wav_file(path: &Path, target_sample_rate: Option<u32>) -> Result<AudioData> {
        // This is a very basic WAV file parser
        // In production, you'd use the 'hound' crate or similar

        let file_data = std::fs::read(path)
            .map_err(|e| TorshError::IoError(format!("Failed to read WAV file: {}", e)))?;

        if file_data.len() < 44 {
            return Err(TorshError::IoError(
                "Invalid WAV file: too small".to_string(),
            ));
        }

        // Check RIFF header
        if &file_data[0..4] != b"RIFF" || &file_data[8..12] != b"WAVE" {
            return Err(TorshError::IoError(
                "Invalid WAV file: missing RIFF/WAVE header".to_string(),
            ));
        }

        // Basic header parsing (simplified)
        // Real implementation would parse all chunks properly
        let channels = u16::from_le_bytes([file_data[22], file_data[23]]) as usize;
        let sample_rate =
            u32::from_le_bytes([file_data[24], file_data[25], file_data[26], file_data[27]]);
        let bits_per_sample = u16::from_le_bytes([file_data[34], file_data[35]]);

        if bits_per_sample != 16 {
            return Err(TorshError::IoError(format!(
                "Unsupported bit depth: {} (only 16-bit supported)",
                bits_per_sample
            )));
        }

        // Find data chunk (simplified search)
        let data_start = 44; // Assuming standard 44-byte header
        let data_size = file_data.len() - data_start;
        let sample_count = data_size / 2; // 16-bit samples

        // Convert 16-bit samples to f32
        let mut samples = Vec::with_capacity(sample_count);
        for i in (data_start..file_data.len()).step_by(2) {
            if i + 1 < file_data.len() {
                let sample_i16 = i16::from_le_bytes([file_data[i], file_data[i + 1]]);
                let sample_f32 = sample_i16 as f32 / 32768.0; // Normalize to [-1, 1]
                samples.push(sample_f32);
            }
        }

        let final_sample_rate = target_sample_rate.unwrap_or(sample_rate);
        let audio_data = AudioData::new(samples, final_sample_rate, channels);

        // Resample if needed
        if final_sample_rate != sample_rate {
            let resample = transforms::Resample::new(final_sample_rate);
            Ok(resample.transform(audio_data)?)
        } else {
            Ok(audio_data)
        }
    }
}

impl Dataset for AudioFolder {
    type Item = (Tensor<f32>, usize);

    fn len(&self) -> usize {
        self.samples.len()
    }

    fn get(&self, index: usize) -> Result<Self::Item> {
        if index >= self.samples.len() {
            return Err(TorshError::IndexOutOfBounds {
                index,
                size: self.samples.len(),
            });
        }

        let (ref path, class_idx) = self.samples[index];
        let audio = self.load_audio(path)?;

        let tensor = if let Some(ref transform) = self.transform {
            transform.transform(audio)?
        } else {
            // Default: convert to tensor
            AudioToTensor.transform(audio)?
        };

        Ok((tensor, class_idx))
    }
}

/// Common speech/audio datasets
pub struct LibriSpeech {
    #[allow(dead_code)]
    root: PathBuf,
    #[allow(dead_code)]
    subset: String,
    transform: Option<Box<dyn Transform<AudioData, Output = Tensor<f32>>>>,
    samples: Vec<(PathBuf, String)>, // (audio_path, transcript)
}

impl LibriSpeech {
    pub fn new<P: AsRef<Path>>(root: P, subset: &str) -> Result<Self> {
        let root = root.as_ref().to_path_buf();

        if !root.exists() {
            return Err(TorshError::IoError(format!(
                "LibriSpeech root directory does not exist: {:?}",
                root
            )));
        }

        // LibriSpeech dataset structure:
        // root/
        //   train-clean-100/
        //     speaker_id/
        //       chapter_id/
        //         speaker_id-chapter_id.trans.txt (transcriptions)
        //         speaker_id-chapter_id-utterance_id.flac (audio files)
        let mut samples = Vec::new();
        let subset_path = root.join(subset);

        if subset_path.exists() {
            samples = Self::scan_librispeech_directory(&subset_path)?;
        } else {
            // If subset directory doesn't exist, try to find it in subdirectories
            let mut found = false;
            for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
                let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
                let path = entry.path();

                if path.is_dir() {
                    let dir_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");

                    // Check if this directory matches the subset pattern
                    if dir_name.contains(subset) || subset.contains(dir_name) {
                        samples.extend(Self::scan_librispeech_directory(&path)?);
                        found = true;
                    }
                }
            }

            if !found {
                // Create dummy samples for demonstration
                for i in 0..100 {
                    let dummy_path = root.join(format!("dummy_audio_{}.flac", i));
                    let dummy_transcript = format!("This is dummy transcript number {}", i);
                    samples.push((dummy_path, dummy_transcript));
                }
            }
        }

        Ok(Self {
            root,
            subset: subset.to_string(),
            transform: None,
            samples,
        })
    }

    /// Set transform to apply to audio
    pub fn with_transform<T>(mut self, transform: T) -> Self
    where
        T: Transform<AudioData, Output = Tensor<f32>> + 'static,
    {
        self.transform = Some(Box::new(transform));
        self
    }

    /// Scan LibriSpeech directory structure
    fn scan_librispeech_directory(subset_path: &Path) -> Result<Vec<(PathBuf, String)>> {
        let mut samples = Vec::new();

        // Walk through speaker directories
        for speaker_entry in
            std::fs::read_dir(subset_path).map_err(|e| TorshError::IoError(e.to_string()))?
        {
            let speaker_entry = speaker_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
            let speaker_path = speaker_entry.path();

            if speaker_path.is_dir() {
                // Walk through chapter directories
                for chapter_entry in std::fs::read_dir(&speaker_path)
                    .map_err(|e| TorshError::IoError(e.to_string()))?
                {
                    let chapter_entry =
                        chapter_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
                    let chapter_path = chapter_entry.path();

                    if chapter_path.is_dir() {
                        // Load transcription file
                        let transcription_file = format!(
                            "{}-{}.trans.txt",
                            speaker_path
                                .file_name()
                                .expect("speaker path should have file name")
                                .to_str()
                                .expect("path should be valid UTF-8"),
                            chapter_path
                                .file_name()
                                .expect("chapter path should have file name")
                                .to_str()
                                .expect("path should be valid UTF-8")
                        );
                        let transcription_path = chapter_path.join(&transcription_file);

                        let transcriptions = Self::load_transcriptions(&transcription_path)?;

                        // Match audio files with transcriptions
                        for entry in std::fs::read_dir(&chapter_path)
                            .map_err(|e| TorshError::IoError(e.to_string()))?
                        {
                            let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
                            let file_path = entry.path();

                            if file_path.extension().and_then(|ext| ext.to_str()) == Some("flac") {
                                let file_stem = file_path
                                    .file_stem()
                                    .and_then(|stem| stem.to_str())
                                    .unwrap_or("");

                                if let Some(transcript) = transcriptions.get(file_stem) {
                                    samples.push((file_path, transcript.clone()));
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(samples)
    }

    /// Load transcription file
    fn load_transcriptions(path: &Path) -> Result<std::collections::HashMap<String, String>> {
        use std::collections::HashMap;

        let mut transcriptions = HashMap::new();

        if path.exists() {
            let content =
                std::fs::read_to_string(path).map_err(|e| TorshError::IoError(e.to_string()))?;

            for line in content.lines() {
                if let Some((id, transcript)) = line.split_once(' ') {
                    transcriptions.insert(id.to_string(), transcript.to_string());
                }
            }
        }

        Ok(transcriptions)
    }
}

impl Dataset for LibriSpeech {
    type Item = (Tensor<f32>, String);

    fn len(&self) -> usize {
        self.samples.len()
    }

    fn get(&self, index: usize) -> Result<Self::Item> {
        if index >= self.samples.len() {
            return Err(TorshError::IndexOutOfBounds {
                index,
                size: self.samples.len(),
            });
        }

        let (ref _path, ref transcript) = self.samples[index];

        // Create dummy audio for demonstration
        let sample_rate = 16000;
        let duration_seconds = 2.0;
        let samples_count = (sample_rate as f32 * duration_seconds) as usize;
        let samples: Vec<f32> = (0..samples_count)
            .map(|i| {
                (i as f32 * 440.0 * 2.0 * std::f32::consts::PI / sample_rate as f32).sin() * 0.1
            })
            .collect();

        let audio = AudioData::new(samples, sample_rate, 1);

        let tensor = if let Some(ref transform) = self.transform {
            transform.transform(audio)?
        } else {
            AudioToTensor.transform(audio)?
        };

        Ok((tensor, transcript.clone()))
    }
}