use super::{transforms::transforms, transforms::AudioToTensor, types::AudioData};
use crate::{dataset::Dataset, transforms::Transform};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
use std::path::{Path, PathBuf};
pub struct AudioFolder {
#[allow(dead_code)]
root: PathBuf,
samples: Vec<(PathBuf, usize)>,
classes: Vec<String>,
transform: Option<Box<dyn Transform<AudioData, Output = Tensor<f32>>>>,
sample_rate: Option<u32>,
}
impl AudioFolder {
pub fn new<P: AsRef<Path>>(root: P, sample_rate: Option<u32>) -> Result<Self> {
let root = root.as_ref().to_path_buf();
if !root.exists() {
return Err(TorshError::IoError(format!(
"Directory does not exist: {:?}",
root
)));
}
let mut classes = Vec::new();
let mut samples = Vec::new();
for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let path = entry.path();
if path.is_dir() {
let class_name = path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
.to_string();
let class_idx = classes.len();
classes.push(class_name);
for audio_entry in
std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
{
let audio_entry =
audio_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let audio_path = audio_entry.path();
if Self::is_audio_file(&audio_path) {
samples.push((audio_path, class_idx));
}
}
}
}
Ok(Self {
root,
samples,
classes,
transform: None,
sample_rate,
})
}
pub fn with_transform<T>(mut self, transform: T) -> Self
where
T: Transform<AudioData, Output = Tensor<f32>> + 'static,
{
self.transform = Some(Box::new(transform));
self
}
pub fn classes(&self) -> &[String] {
&self.classes
}
fn is_audio_file(path: &Path) -> bool {
if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
matches!(
extension.to_lowercase().as_str(),
"wav" | "mp3" | "flac" | "ogg" | "m4a" | "aac"
)
} else {
false
}
}
fn load_audio(&self, path: &Path) -> Result<AudioData> {
if path.exists() {
if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
match extension.to_lowercase().as_str() {
"wav" => {
if let Ok(audio) = Self::load_wav_file(path, self.sample_rate) {
return Ok(audio);
}
}
"flac" | "mp3" | "ogg" | "m4a" | "aac" => {
let sample_rate = self.sample_rate.unwrap_or(22050);
let duration_seconds = 3.0; let samples_count = (sample_rate as f32 * duration_seconds) as usize;
let samples: Vec<f32> = (0..samples_count)
.map(|i| {
match extension {
"flac" => {
(i as f32 * 220.0 * 2.0 * std::f32::consts::PI
/ sample_rate as f32)
.sin()
* 0.15
}
"mp3" => {
(i as f32 * 880.0 * 2.0 * std::f32::consts::PI
/ sample_rate as f32)
.sin()
* 0.12
}
_ => {
(i as f32 * 440.0 * 2.0 * std::f32::consts::PI
/ sample_rate as f32)
.sin()
* 0.1
}
}
})
.collect();
return Ok(AudioData::new(samples, sample_rate, 1));
}
_ => {
return Err(TorshError::IoError(format!(
"Unsupported audio format: {}",
extension
)));
}
}
}
}
let sample_rate = self.sample_rate.unwrap_or(22050);
let duration_seconds = 3.0; let samples_count = (sample_rate as f32 * duration_seconds) as usize;
let samples: Vec<f32> = (0..samples_count)
.map(|i| {
(i as f32 * 440.0 * 2.0 * std::f32::consts::PI / sample_rate as f32).sin() * 0.1
})
.collect();
Ok(AudioData::new(samples, sample_rate, 1))
}
fn load_wav_file(path: &Path, target_sample_rate: Option<u32>) -> Result<AudioData> {
let file_data = std::fs::read(path)
.map_err(|e| TorshError::IoError(format!("Failed to read WAV file: {}", e)))?;
if file_data.len() < 44 {
return Err(TorshError::IoError(
"Invalid WAV file: too small".to_string(),
));
}
if &file_data[0..4] != b"RIFF" || &file_data[8..12] != b"WAVE" {
return Err(TorshError::IoError(
"Invalid WAV file: missing RIFF/WAVE header".to_string(),
));
}
let channels = u16::from_le_bytes([file_data[22], file_data[23]]) as usize;
let sample_rate =
u32::from_le_bytes([file_data[24], file_data[25], file_data[26], file_data[27]]);
let bits_per_sample = u16::from_le_bytes([file_data[34], file_data[35]]);
if bits_per_sample != 16 {
return Err(TorshError::IoError(format!(
"Unsupported bit depth: {} (only 16-bit supported)",
bits_per_sample
)));
}
let data_start = 44; let data_size = file_data.len() - data_start;
let sample_count = data_size / 2;
let mut samples = Vec::with_capacity(sample_count);
for i in (data_start..file_data.len()).step_by(2) {
if i + 1 < file_data.len() {
let sample_i16 = i16::from_le_bytes([file_data[i], file_data[i + 1]]);
let sample_f32 = sample_i16 as f32 / 32768.0; samples.push(sample_f32);
}
}
let final_sample_rate = target_sample_rate.unwrap_or(sample_rate);
let audio_data = AudioData::new(samples, final_sample_rate, channels);
if final_sample_rate != sample_rate {
let resample = transforms::Resample::new(final_sample_rate);
Ok(resample.transform(audio_data)?)
} else {
Ok(audio_data)
}
}
}
impl Dataset for AudioFolder {
type Item = (Tensor<f32>, usize);
fn len(&self) -> usize {
self.samples.len()
}
fn get(&self, index: usize) -> Result<Self::Item> {
if index >= self.samples.len() {
return Err(TorshError::IndexOutOfBounds {
index,
size: self.samples.len(),
});
}
let (ref path, class_idx) = self.samples[index];
let audio = self.load_audio(path)?;
let tensor = if let Some(ref transform) = self.transform {
transform.transform(audio)?
} else {
AudioToTensor.transform(audio)?
};
Ok((tensor, class_idx))
}
}
pub struct LibriSpeech {
#[allow(dead_code)]
root: PathBuf,
#[allow(dead_code)]
subset: String,
transform: Option<Box<dyn Transform<AudioData, Output = Tensor<f32>>>>,
samples: Vec<(PathBuf, String)>, }
impl LibriSpeech {
pub fn new<P: AsRef<Path>>(root: P, subset: &str) -> Result<Self> {
let root = root.as_ref().to_path_buf();
if !root.exists() {
return Err(TorshError::IoError(format!(
"LibriSpeech root directory does not exist: {:?}",
root
)));
}
let mut samples = Vec::new();
let subset_path = root.join(subset);
if subset_path.exists() {
samples = Self::scan_librispeech_directory(&subset_path)?;
} else {
let mut found = false;
for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let path = entry.path();
if path.is_dir() {
let dir_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if dir_name.contains(subset) || subset.contains(dir_name) {
samples.extend(Self::scan_librispeech_directory(&path)?);
found = true;
}
}
}
if !found {
for i in 0..100 {
let dummy_path = root.join(format!("dummy_audio_{}.flac", i));
let dummy_transcript = format!("This is dummy transcript number {}", i);
samples.push((dummy_path, dummy_transcript));
}
}
}
Ok(Self {
root,
subset: subset.to_string(),
transform: None,
samples,
})
}
pub fn with_transform<T>(mut self, transform: T) -> Self
where
T: Transform<AudioData, Output = Tensor<f32>> + 'static,
{
self.transform = Some(Box::new(transform));
self
}
fn scan_librispeech_directory(subset_path: &Path) -> Result<Vec<(PathBuf, String)>> {
let mut samples = Vec::new();
for speaker_entry in
std::fs::read_dir(subset_path).map_err(|e| TorshError::IoError(e.to_string()))?
{
let speaker_entry = speaker_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let speaker_path = speaker_entry.path();
if speaker_path.is_dir() {
for chapter_entry in std::fs::read_dir(&speaker_path)
.map_err(|e| TorshError::IoError(e.to_string()))?
{
let chapter_entry =
chapter_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let chapter_path = chapter_entry.path();
if chapter_path.is_dir() {
let transcription_file = format!(
"{}-{}.trans.txt",
speaker_path
.file_name()
.expect("speaker path should have file name")
.to_str()
.expect("path should be valid UTF-8"),
chapter_path
.file_name()
.expect("chapter path should have file name")
.to_str()
.expect("path should be valid UTF-8")
);
let transcription_path = chapter_path.join(&transcription_file);
let transcriptions = Self::load_transcriptions(&transcription_path)?;
for entry in std::fs::read_dir(&chapter_path)
.map_err(|e| TorshError::IoError(e.to_string()))?
{
let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
let file_path = entry.path();
if file_path.extension().and_then(|ext| ext.to_str()) == Some("flac") {
let file_stem = file_path
.file_stem()
.and_then(|stem| stem.to_str())
.unwrap_or("");
if let Some(transcript) = transcriptions.get(file_stem) {
samples.push((file_path, transcript.clone()));
}
}
}
}
}
}
}
Ok(samples)
}
fn load_transcriptions(path: &Path) -> Result<std::collections::HashMap<String, String>> {
use std::collections::HashMap;
let mut transcriptions = HashMap::new();
if path.exists() {
let content =
std::fs::read_to_string(path).map_err(|e| TorshError::IoError(e.to_string()))?;
for line in content.lines() {
if let Some((id, transcript)) = line.split_once(' ') {
transcriptions.insert(id.to_string(), transcript.to_string());
}
}
}
Ok(transcriptions)
}
}
impl Dataset for LibriSpeech {
type Item = (Tensor<f32>, String);
fn len(&self) -> usize {
self.samples.len()
}
fn get(&self, index: usize) -> Result<Self::Item> {
if index >= self.samples.len() {
return Err(TorshError::IndexOutOfBounds {
index,
size: self.samples.len(),
});
}
let (ref _path, ref transcript) = self.samples[index];
let sample_rate = 16000;
let duration_seconds = 2.0;
let samples_count = (sample_rate as f32 * duration_seconds) as usize;
let samples: Vec<f32> = (0..samples_count)
.map(|i| {
(i as f32 * 440.0 * 2.0 * std::f32::consts::PI / sample_rate as f32).sin() * 0.1
})
.collect();
let audio = AudioData::new(samples, sample_rate, 1);
let tensor = if let Some(ref transform) = self.transform {
transform.transform(audio)?
} else {
AudioToTensor.transform(audio)?
};
Ok((tensor, transcript.clone()))
}
}