use std::path::{Path, PathBuf};
use voirs_dataset::{
loaders::LjSpeechLoader,
processing::features::{extract_mel_spectrogram, MelSpectrogramConfig},
traits::Dataset,
DatasetSample,
};
use voirs_sdk::Result;
/// In-memory dataset loader that serves batches of waveforms paired with
/// their mel spectrograms for vocoder training.
///
/// Samples are drawn sequentially. When the end of the sample list is
/// reached, the cursor wraps back to the first sample.
pub struct VocoderDataLoader {
    /// Every sample that was successfully loaded from the dataset.
    samples: Vec<DatasetSample>,
    /// Parameters (`n_mels`, `n_fft`, `hop_length`) passed to
    /// [`extract_mel_spectrogram`] for each sample.
    mel_config: MelSpectrogramConfig,
    /// Index of the next sample that [`VocoderDataLoader::get_batch`] will use.
    current_index: usize,
}

impl VocoderDataLoader {
    /// Loads every sample of the dataset at `data_dir` into memory.
    ///
    /// Only the LJSpeech layout is currently supported. Individual samples
    /// that fail to load are reported on stderr and skipped. The loader is
    /// created with `MelSpectrogramConfig::default()`.
    ///
    /// # Errors
    ///
    /// Returns a configuration error in these cases:
    /// - `data_dir` is not recognised as an LJSpeech dataset.
    /// - The dataset cannot be opened.
    /// - Not a single sample could be loaded.
    pub async fn load<P: AsRef<Path>>(data_dir: P) -> Result<Self> {
        // Reject unsupported layouts before handing the path to the loader.
        if !LjSpeechLoader::is_valid_dataset(data_dir.as_ref()) {
            return Err(voirs_sdk::VoirsError::config_error(format!(
                "Unsupported dataset format at {:?}. Currently only LJSpeech is supported.",
                data_dir.as_ref()
            )));
        }

        let dataset = LjSpeechLoader::load(data_dir).await.map_err(|e| {
            voirs_sdk::VoirsError::config_error(format!("Failed to load dataset: {}", e))
        })?;

        let num_samples = dataset.len();
        let mut samples = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            match dataset.get(i).await {
                Ok(sample) => samples.push(sample),
                // Best effort: one unreadable sample should not abort the whole load.
                Err(e) => eprintln!("Warning: Failed to load sample {}: {}", i, e),
            }
        }

        if samples.is_empty() {
            return Err(voirs_sdk::VoirsError::config_error(
                "No valid samples found in dataset".to_string(),
            ));
        }

        Ok(Self {
            samples,
            mel_config: MelSpectrogramConfig::default(),
            current_index: 0,
        })
    }

    /// Number of samples held by the loader.
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `true` if the loader holds no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Returns the next `batch_size` samples as raw waveforms plus their mel
    /// spectrograms. The cursor wraps around to the first sample when the end
    /// of the dataset is reached.
    ///
    /// The cursor moves past each sample *before* its mel spectrogram is
    /// computed. If extraction fails, the next call therefore resumes after
    /// the failing sample instead of retrying it.
    ///
    /// # Errors
    ///
    /// Returns a configuration error in these cases:
    /// - A non-empty batch is requested from a loader with no samples.
    /// - Mel spectrogram extraction fails.
    pub fn get_batch(&mut self, batch_size: usize) -> Result<VocoderBatch> {
        if batch_size > 0 && self.samples.is_empty() {
            // `load` never yields an empty loader. This keeps the indexing
            // below from panicking if that invariant is ever broken.
            return Err(voirs_sdk::VoirsError::config_error(
                "Cannot draw a batch from an empty dataset".to_string(),
            ));
        }

        let mut batch_audio = Vec::with_capacity(batch_size);
        let mut batch_mels = Vec::with_capacity(batch_size);
        for _ in 0..batch_size {
            if self.current_index >= self.samples.len() {
                self.current_index = 0;
            }
            let sample = &self.samples[self.current_index];
            self.current_index += 1;

            let mel_result = extract_mel_spectrogram(
                &sample.audio,
                self.mel_config.n_mels,
                self.mel_config.n_fft,
                self.mel_config.hop_length,
            )
            .map_err(|e| {
                voirs_sdk::VoirsError::config_error(format!(
                    "Failed to extract mel spectrogram: {}",
                    e
                ))
            })?;
            let mel_matrix = mel_result.as_matrix();

            batch_audio.push(sample.audio.samples().to_vec());
            batch_mels.push(mel_matrix);
        }

        Ok(VocoderBatch {
            audio: batch_audio,
            mels: batch_mels,
        })
    }

    /// Rewinds the cursor so the next batch starts at the first sample.
    pub fn reset(&mut self) {
        self.current_index = 0;
    }

    /// Index of the next sample [`VocoderDataLoader::get_batch`] will use.
    pub fn current_index(&self) -> usize {
        self.current_index
    }

    /// Moves the cursor to `index`, clamped to the number of samples.
    ///
    /// Passing `len()` or any larger value makes the next
    /// [`VocoderDataLoader::get_batch`] start again from the first sample.
    pub fn set_index(&mut self, index: usize) {
        self.current_index = index.min(self.samples.len());
    }
}
/// A batch of vocoder training examples.
///
/// `audio[i]` and `mels[i]` describe the same example. When the batch is
/// produced by [`VocoderDataLoader::get_batch`], `mels[i]` is the matrix
/// returned by `as_matrix()` on the mel spectrogram of `audio[i]`. The matrix
/// orientation is whatever `as_matrix` yields; confirm before relying on it.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct VocoderBatch {
    /// Raw waveform samples, one vector per example.
    pub audio: Vec<Vec<f32>>,
    /// Mel spectrogram matrices, one per example, in the same order as `audio`.
    pub mels: Vec<Vec<Vec<f32>>>,
}

impl VocoderBatch {
    /// Number of examples in the batch, counted from `audio`.
    pub fn len(&self) -> usize {
        self.audio.len()
    }

    /// Returns `true` if the batch contains no examples.
    pub fn is_empty(&self) -> bool {
        self.audio.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Resolves the LJSpeech root directory.
    ///
    /// Uses `$LJSPEECH_PATH` when it is set. Otherwise falls back to
    /// `<temp_dir>/voirs/datasets/LJSpeech-1.1`.
    fn resolve_ljspeech_path() -> PathBuf {
        env::var("LJSPEECH_PATH")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                env::temp_dir()
                    .join("voirs")
                    .join("datasets")
                    .join("LJSpeech-1.1")
            })
    }

    /// Loads the local LJSpeech dataset for integration tests.
    ///
    /// Returns `None` (after printing a skip notice) when the dataset
    /// directory does not exist.
    ///
    /// # Panics
    ///
    /// Panics if the directory exists but cannot be loaded. A broken local
    /// dataset is then reported as a failure rather than silently skipped.
    async fn load_or_skip() -> Option<VocoderDataLoader> {
        let ljspeech_path = resolve_ljspeech_path();
        if !ljspeech_path.exists() {
            eprintln!(
                "Skipping test: LJSpeech dataset not found at {}",
                ljspeech_path.display()
            );
            return None;
        }
        let loader = VocoderDataLoader::load(&ljspeech_path)
            .await
            .expect("Failed to load dataset");
        Some(loader)
    }

    #[tokio::test]
    async fn test_vocoder_data_loader_basic() {
        let loader = match load_or_skip().await {
            Some(loader) => loader,
            None => return,
        };
        assert!(!loader.is_empty(), "Dataset should not be empty");
        assert_eq!(loader.current_index(), 0, "A fresh loader starts at index 0");
    }

    #[tokio::test]
    async fn test_batch_generation() {
        let mut loader = match load_or_skip().await {
            Some(loader) => loader,
            None => return,
        };
        let batch_size = 4;
        let batch = loader
            .get_batch(batch_size)
            .expect("Batch generation failed");
        assert_eq!(batch.len(), batch_size, "Batch size should match");
        assert_eq!(
            batch.audio.len(),
            batch_size,
            "Audio batch size should match"
        );
        assert_eq!(batch.mels.len(), batch_size, "Mel batch size should match");
        for mel in &batch.mels {
            assert!(!mel.is_empty(), "Mel spectrogram should not be empty");
            assert!(!mel[0].is_empty(), "Mel spectrogram should have features");
        }
    }

    #[tokio::test]
    async fn test_batch_wraparound() {
        let mut loader = match load_or_skip().await {
            Some(loader) => loader,
            None => return,
        };
        let total_samples = loader.len();
        let batch_size = 4;
        // Enough batches to run past the end of the dataset at least once.
        let num_batches = total_samples / batch_size + 2;
        for i in 0..num_batches {
            let batch = loader.get_batch(batch_size).unwrap_or_else(|e| {
                panic!("Batch generation failed at iteration {}: {:?}", i, e)
            });
            assert_eq!(batch.len(), batch_size);
            assert!(
                loader.current_index() <= total_samples,
                "Cursor must stay within the dataset after wrapping"
            );
        }
    }

    #[test]
    fn test_vocoder_batch_properties() {
        let batch = VocoderBatch {
            audio: vec![vec![0.0; 100]; 4],
            mels: vec![vec![vec![0.0; 80]; 10]; 4],
        };
        assert_eq!(batch.len(), 4);
        assert!(!batch.is_empty());
        let empty_batch = VocoderBatch {
            audio: vec![],
            mels: vec![],
        };
        assert_eq!(empty_batch.len(), 0);
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_set_index_clamps_and_reset() {
        // Built directly (no dataset needed) to exercise cursor bookkeeping.
        let mut loader = VocoderDataLoader {
            samples: Vec::new(),
            mel_config: MelSpectrogramConfig::default(),
            current_index: 0,
        };
        assert!(loader.is_empty());
        loader.set_index(10);
        assert_eq!(
            loader.current_index(),
            0,
            "Index must be clamped to the sample count"
        );
        loader.reset();
        assert_eq!(loader.current_index(), 0);
    }

    #[tokio::test]
    async fn test_invalid_dataset_path() {
        let invalid_path = "/nonexistent/path/to/dataset";
        let result = VocoderDataLoader::load(invalid_path).await;
        assert!(result.is_err(), "Should fail with invalid path");
    }

    #[tokio::test]
    async fn test_mel_spectrogram_shape() {
        let mut loader = match load_or_skip().await {
            Some(loader) => loader,
            None => return,
        };
        let batch = loader.get_batch(1).expect("Batch generation failed");
        assert_eq!(batch.mels.len(), 1);
        let mel = &batch.mels[0];
        assert!(!mel.is_empty(), "Mel spectrogram should have frames");
        // 80 is expected to equal `MelSpectrogramConfig::default().n_mels`; confirm
        // if the default config changes.
        for frame in mel {
            assert_eq!(frame.len(), 80, "Each frame should have 80 mel bins");
        }
    }
}