#[cfg(feature = "audio")]
use std::collections::HashMap;
#[cfg(feature = "audio")]
use std::fs;
#[cfg(feature = "audio")]
use std::path::{Path, PathBuf};
#[cfg(feature = "audio")]
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "audio")]
use crate::Dataset;
#[cfg(feature = "audio")]
/// Kind of feature representation requested from the audio pipeline.
///
/// NOTE(review): nothing in this file reads the selected variant — loaded samples are
/// returned as-is by `AudioDataset::get`. Confirm whether extraction for the non-`Raw`
/// variants is implemented elsewhere.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FeatureType {
    /// Raw waveform samples.
    Raw,
    /// MFCC features (coefficient count presumably taken from `AudioConfig::n_mfcc`).
    MFCC,
    /// Mel spectrogram (band count presumably taken from `AudioConfig::n_mels`).
    MelSpectrogram,
    /// Log-magnitude spectrogram.
    LogSpectrogram,
    /// Chroma (pitch-class) features.
    Chroma,
}
#[cfg(feature = "audio")]
/// Configuration for discovering, labelling and loading audio files.
#[derive(Debug, Clone)]
pub struct AudioConfig {
    /// Sample rate in Hz; reported for discovered files and used when producing samples.
    pub sample_rate: u32,
    /// Optional maximum duration in seconds. The loader in this file uses it as the
    /// length of the produced sample buffer (1 second when `None`).
    pub max_duration: Option<f32>,
    /// Optional minimum duration in seconds.
    /// NOTE(review): not read anywhere in this file; no files are filtered by it here.
    pub min_duration: Option<f32>,
    /// When true, loaded samples are divided by their peak absolute value.
    pub normalize: bool,
    /// Requested feature representation.
    /// NOTE(review): not consulted by the loading path in this file.
    pub feature_type: FeatureType,
    /// Number of MFCC coefficients (not read by the loading code in this file).
    pub n_mfcc: usize,
    /// Number of mel bands (not read by the loading code in this file).
    pub n_mels: usize,
    /// FFT size in samples, presumably (not read by the loading code in this file).
    pub n_fft: usize,
    /// Hop between frames in samples, presumably (not read by the loading code in this file).
    pub hop_length: usize,
    /// Extensions accepted during discovery. Each file's extension is lower-cased before
    /// being looked up here, so entries should be lower-case and have no leading dot.
    pub supported_extensions: Vec<String>,
    /// When true, every file is loaded into memory when the dataset is constructed.
    pub cache_audio: bool,
    /// How a label is derived for each discovered file.
    pub label_strategy: AudioLabelStrategy,
    /// Map from file name (including extension) to label, used by
    /// `AudioLabelStrategy::FromMapping`.
    pub label_mapping: Option<HashMap<String, String>>,
}
#[cfg(feature = "audio")]
/// How a label is derived for each discovered audio file.
///
/// Derives the same trait set as `FeatureType` (all variants are unit variants), so
/// strategies can be copied and compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AudioLabelStrategy {
    /// First `_`-separated segment of the file stem (`dog_01.wav` -> `dog`).
    FromFilename,
    /// Name of the file's parent directory.
    FromDirectory,
    /// Lookup of the file name (including extension) in `AudioConfig::label_mapping`.
    FromMapping,
    /// Files are left unlabeled.
    None,
}
#[cfg(feature = "audio")]
impl Default for AudioConfig {
    /// Default settings: `sample_rate` 16000, no duration limits, normalization on,
    /// `FeatureType::Raw`, `n_mfcc` 13, `n_mels` 80, `n_fft` 1024, `hop_length` 512,
    /// the wav/flac/mp3/ogg/m4a extensions, no caching, and labels taken from the
    /// parent directory name.
    fn default() -> Self {
        let supported_extensions: Vec<String> = ["wav", "flac", "mp3", "ogg", "m4a"]
            .iter()
            .map(|ext| (*ext).to_string())
            .collect();
        Self {
            // Signal handling.
            sample_rate: 16000,
            max_duration: None,
            min_duration: None,
            normalize: true,
            // Feature extraction parameters.
            feature_type: FeatureType::Raw,
            n_mfcc: 13,
            n_mels: 80,
            n_fft: 1024,
            hop_length: 512,
            // Discovery, caching and labelling.
            supported_extensions,
            cache_audio: false,
            label_strategy: AudioLabelStrategy::FromDirectory,
            label_mapping: None,
        }
    }
}
#[cfg(feature = "audio")]
impl AudioConfig {
    /// Sets the sample rate in Hz.
    #[must_use]
    pub fn with_sample_rate(mut self, sample_rate: u32) -> Self {
        self.sample_rate = sample_rate;
        self
    }

    /// Sets `max_duration`, in seconds.
    #[must_use]
    pub fn with_max_duration(mut self, duration: f32) -> Self {
        self.max_duration = Some(duration);
        self
    }

    /// Sets `min_duration`, in seconds.
    #[must_use]
    pub fn with_min_duration(mut self, duration: f32) -> Self {
        self.min_duration = Some(duration);
        self
    }

    /// Enables or disables peak normalization of loaded samples.
    #[must_use]
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Sets the requested feature representation.
    #[must_use]
    pub fn with_feature_extraction(mut self, feature_type: FeatureType) -> Self {
        self.feature_type = feature_type;
        self
    }

    /// Sets the number of MFCC coefficients.
    #[must_use]
    pub fn with_n_mfcc(mut self, n_mfcc: usize) -> Self {
        self.n_mfcc = n_mfcc;
        self
    }

    /// Sets the number of mel bands.
    #[must_use]
    pub fn with_n_mels(mut self, n_mels: usize) -> Self {
        self.n_mels = n_mels;
        self
    }

    /// Sets `n_fft`.
    #[must_use]
    pub fn with_n_fft(mut self, n_fft: usize) -> Self {
        self.n_fft = n_fft;
        self
    }

    /// Sets `hop_length`.
    #[must_use]
    pub fn with_hop_length(mut self, hop_length: usize) -> Self {
        self.hop_length = hop_length;
        self
    }

    /// Replaces the list of file extensions accepted during discovery.
    ///
    /// Each entry is lower-cased and stripped of leading dots (`".WAV"` becomes `"wav"`).
    /// Discovery lower-cases each file's extension (without the dot) before looking it
    /// up, so entries in any other form would never match.
    #[must_use]
    pub fn with_supported_extensions<I, S>(mut self, extensions: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.supported_extensions = extensions
            .into_iter()
            .map(|ext| ext.as_ref().trim_start_matches('.').to_lowercase())
            .collect();
        self
    }

    /// Sets how labels are derived for each file.
    #[must_use]
    pub fn with_label_strategy(mut self, strategy: AudioLabelStrategy) -> Self {
        self.label_strategy = strategy;
        self
    }

    /// Sets the file-name-to-label map used by `AudioLabelStrategy::FromMapping`.
    #[must_use]
    pub fn with_label_mapping(mut self, mapping: HashMap<String, String>) -> Self {
        self.label_mapping = Some(mapping);
        self
    }

    /// Enables or disables loading every file into memory at dataset construction.
    #[must_use]
    pub fn with_cache_audio(mut self, cache: bool) -> Self {
        self.cache_audio = cache;
        self
    }
}
#[cfg(feature = "audio")]
/// Metadata recorded for one discovered audio file.
#[derive(Debug, Clone)]
pub struct AudioInfo {
    /// Path of the file, as produced by scanning the dataset directory.
    pub path: PathBuf,
    /// Sample rate in Hz (copied from `AudioConfig::sample_rate` by this file's discovery code).
    pub sample_rate: u32,
    /// Channel count (a fixed placeholder of 1 in this file's discovery code).
    pub channels: usize,
    /// Duration in seconds (a fixed placeholder of 1.0 in this file's discovery code).
    pub duration: f32,
    /// Number of samples, computed as `sample_rate * duration`.
    pub num_samples: usize,
    /// File size in bytes, from filesystem metadata.
    pub file_size: u64,
    /// File extension used as the format name, or `"unknown"`.
    pub format: String,
    /// Label derived via the configured `AudioLabelStrategy`, if any.
    pub label: Option<String>,
}
#[cfg(feature = "audio")]
/// Summary of a scanned directory, computed once when an `AudioDataset` is built.
#[derive(Debug, Clone)]
pub struct AudioDatasetInfo {
    /// Directory that was scanned.
    pub directory: PathBuf,
    /// Number of audio files found (equal to `file_info.len()`).
    pub num_files: usize,
    /// Sum of the per-file `AudioInfo::duration` values.
    pub total_duration: f32,
    /// `total_duration / num_files`.
    pub avg_duration: f32,
    /// Distinct labels, sorted; a label's position in this list is its class index.
    pub labels: Vec<String>,
    /// Number of files carrying each label (unlabeled files are not counted).
    pub label_counts: HashMap<String, usize>,
    /// Per-file metadata, in dataset index order.
    pub file_info: Vec<AudioInfo>,
}
#[cfg(feature = "audio")]
/// Builder for `AudioDataset`; a directory must be set before calling `build`.
pub struct AudioDatasetBuilder {
    /// Directory to scan; `build` returns an error while this is `None`.
    directory: Option<PathBuf>,
    /// Configuration handed to the dataset on `build`.
    config: AudioConfig,
}
#[cfg(feature = "audio")]
impl Default for AudioDatasetBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "audio")]
impl AudioDatasetBuilder {
pub fn new() -> Self {
Self {
directory: None,
config: AudioConfig::default(),
}
}
pub fn directory<P: AsRef<Path>>(mut self, path: P) -> Self {
self.directory = Some(path.as_ref().to_path_buf());
self
}
pub fn config(mut self, config: AudioConfig) -> Self {
self.config = config;
self
}
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
self.config.sample_rate = sample_rate;
self
}
pub fn feature_type(mut self, feature_type: FeatureType) -> Self {
self.config.feature_type = feature_type;
self
}
pub fn build(self) -> Result<AudioDataset> {
let directory = self.directory.ok_or_else(|| {
TensorError::invalid_argument("Directory must be specified".to_string())
})?;
AudioDataset::from_directory_with_config(&directory, self.config)
}
}
#[cfg(feature = "audio")]
/// Dataset of the audio files found in a single directory.
///
/// Each item (via the `Dataset<f32>` impl) is a file's sample buffer plus its class index.
pub struct AudioDataset {
    /// Configuration used for discovery and loading.
    config: AudioConfig,
    /// Discovery summary and per-file metadata.
    info: AudioDatasetInfo,
    /// Loaded sample buffers in `info.file_info` order. `Some` only when
    /// `config.cache_audio` was set at construction.
    cached_audio: Option<Vec<Vec<f32>>>,
    /// Labels parallel to `cached_audio` (`"unknown"` for unlabeled files).
    cached_labels: Option<Vec<String>>,
    /// Maps each label to its index in the sorted `info.labels` list.
    label_to_idx: HashMap<String, usize>,
}
#[cfg(feature = "audio")]
impl AudioDataset {
    /// Builds a dataset from the supported audio files directly inside `directory`,
    /// using `AudioConfig::default()`.
    ///
    /// # Errors
    ///
    /// Same as `AudioDataset::from_directory_with_config`.
    pub fn from_directory<P: AsRef<Path>>(directory: P) -> Result<Self> {
        Self::from_directory_with_config(directory, AudioConfig::default())
    }

    /// Builds a dataset from the supported audio files directly inside `directory`.
    ///
    /// Per-file labels are collected into a sorted list of distinct names, and each
    /// name's position in that list becomes its class index. When `config.cache_audio`
    /// is set, every file is loaded up front.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `directory` does not exist, is not a directory, or cannot be read;
    /// - it contains no supported audio files;
    /// - caching is enabled and a file fails to load.
    pub fn from_directory_with_config<P: AsRef<Path>>(
        directory: P,
        config: AudioConfig,
    ) -> Result<Self> {
        let dir_path = directory.as_ref().to_path_buf();
        if !dir_path.exists() {
            return Err(TensorError::invalid_argument(format!(
                "Directory not found: {}",
                dir_path.display()
            )));
        }
        if !dir_path.is_dir() {
            return Err(TensorError::invalid_argument(format!(
                "Path is not a directory: {}",
                dir_path.display()
            )));
        }

        let file_info = discover_audio_files(&dir_path, &config)?;
        if file_info.is_empty() {
            return Err(TensorError::invalid_argument(
                "No supported audio files found in directory".to_string(),
            ));
        }

        let num_files = file_info.len();
        let total_duration: f32 = file_info.iter().map(|info| info.duration).sum();
        // The divisor is non-zero: the empty case returned above.
        let avg_duration = total_duration / num_files as f32;

        // Count files per label in one pass. The distinct labels are then exactly the
        // map's keys, which avoids a linear membership scan per file.
        let mut label_counts: HashMap<String, usize> = HashMap::new();
        for label in file_info.iter().filter_map(|info| info.label.as_ref()) {
            *label_counts.entry(label.clone()).or_insert(0) += 1;
        }
        // Sorted so class indices do not depend on hash or discovery order.
        let mut labels: Vec<String> = label_counts.keys().cloned().collect();
        labels.sort_unstable();

        let label_to_idx: HashMap<String, usize> = labels
            .iter()
            .enumerate()
            .map(|(idx, label)| (label.clone(), idx))
            .collect();

        let dataset_info = AudioDatasetInfo {
            directory: dir_path,
            num_files,
            total_duration,
            avg_duration,
            labels,
            label_counts,
            file_info,
        };
        let mut dataset = Self {
            config,
            info: dataset_info,
            cached_audio: None,
            cached_labels: None,
            label_to_idx,
        };
        if dataset.config.cache_audio {
            dataset.load_audio()?;
        }
        Ok(dataset)
    }

    /// Returns the discovery summary and per-file metadata.
    pub fn info(&self) -> &AudioDatasetInfo {
        &self.info
    }

    /// Loads every file into `cached_audio`, with labels into `cached_labels`.
    ///
    /// Unlabeled files are cached with the placeholder label `"unknown"`.
    ///
    /// # Errors
    ///
    /// Propagates the first `load_audio_file` error. The caches are left untouched in
    /// that case.
    fn load_audio(&mut self) -> Result<()> {
        let file_count = self.info.file_info.len();
        let mut cached_audio = Vec::with_capacity(file_count);
        let mut cached_labels = Vec::with_capacity(file_count);
        for file_info in &self.info.file_info {
            cached_audio.push(load_audio_file(&file_info.path, &self.config)?);
            cached_labels.push(
                file_info
                    .label
                    .clone()
                    .unwrap_or_else(|| "unknown".to_string()),
            );
        }
        self.cached_audio = Some(cached_audio);
        self.cached_labels = Some(cached_labels);
        Ok(())
    }

    /// Number of distinct labels found during discovery.
    pub fn num_classes(&self) -> usize {
        self.info.labels.len()
    }

    /// Distinct label names, sorted; a name's index is its class index.
    pub fn label_names(&self) -> &[String] {
        &self.info.labels
    }
}
#[cfg(feature = "audio")]
impl Dataset<f32> for AudioDataset {
    /// Number of discovered audio files.
    fn len(&self) -> usize {
        self.info.num_files
    }

    /// Returns `(samples, label)` for the file at `index`.
    ///
    /// - `samples` is a 1-D tensor holding the file's sample buffer. It comes from the
    ///   cache when one was built; otherwise the file is loaded on demand.
    /// - `label` is a scalar tensor holding the class index as `f32`.
    ///
    /// Labels missing from the label index fall back to class 0. This includes the
    /// `"unknown"` placeholder used for unlabeled files.
    ///
    /// # Errors
    ///
    /// Returns an error for an out-of-range `index`, or when loading or tensor
    /// construction fails.
    fn get(&self, index: usize) -> Result<(Tensor<f32>, Tensor<f32>)> {
        let dataset_len = self.len();
        if index >= dataset_len {
            return Err(TensorError::invalid_argument(format!(
                "Index {index} out of bounds for dataset of length {dataset_len}"
            )));
        }

        let (samples, label_name) = match &self.cached_audio {
            Some(cache) => {
                let samples = cache[index].clone();
                let name = self
                    .cached_labels
                    .as_ref()
                    .and_then(|all| all.get(index))
                    .map(String::as_str)
                    .unwrap_or("unknown")
                    .to_string();
                (samples, name)
            }
            None => {
                let entry = &self.info.file_info[index];
                let samples = load_audio_file(&entry.path, &self.config)?;
                let name = entry.label.as_deref().unwrap_or("unknown").to_string();
                (samples, name)
            }
        };

        let sample_count = samples.len();
        let features = Tensor::from_vec(samples, &[sample_count])?;
        let class_index = self.label_to_idx.get(&label_name).map_or(0, |&idx| idx);
        let target = Tensor::from_vec(vec![class_index as f32], &[])?;
        Ok((features, target))
    }
}
#[cfg(feature = "audio")]
fn discover_audio_files(directory: &Path, config: &AudioConfig) -> Result<Vec<AudioInfo>> {
let mut file_info = Vec::new();
for entry in fs::read_dir(directory)
.map_err(|e| TensorError::invalid_argument(format!("Failed to read directory: {e}")))?
{
let entry = entry.map_err(|e| {
TensorError::invalid_argument(format!("Failed to read directory entry: {e}"))
})?;
let path = entry.path();
if path.is_file() {
if let Some(extension) = path.extension() {
let ext_str = extension.to_string_lossy().to_lowercase();
if config.supported_extensions.contains(&ext_str) {
match get_audio_info(&path, config) {
Ok(info) => file_info.push(info),
Err(_) => continue, }
}
}
}
}
Ok(file_info)
}
#[cfg(feature = "audio")]
/// Builds the `AudioInfo` record for one audio file.
///
/// Only `file_size` (from filesystem metadata), `format` (the extension) and `label`
/// are derived from the file itself; the file contents are not decoded. The other
/// fields are filled in without reading the audio:
/// - `sample_rate` is copied from `config`;
/// - `channels` is a fixed placeholder of 1;
/// - `duration` is a fixed placeholder of 1.0;
/// - `num_samples` is `sample_rate * duration`.
///
/// # Errors
///
/// Returns an error if the file's metadata cannot be read.
fn get_audio_info(path: &Path, config: &AudioConfig) -> Result<AudioInfo> {
    let file_size = fs::metadata(path)
        .map_err(|e| TensorError::invalid_argument(format!("Failed to get file metadata: {e}")))?
        .len();

    // Placeholder stream parameters: the file is not decoded here.
    let sample_rate = config.sample_rate;
    let channels = 1;
    let duration = 1.0;
    let num_samples = (sample_rate as f32 * duration) as usize;

    // Lower-cased so the reported format agrees with discovery, which matches extensions
    // case-insensitively (e.g. `CLIP.WAV` is reported as `wav`).
    let format = path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_lowercase)
        .unwrap_or_else(|| "unknown".to_string());

    let label = extract_label(path, &config.label_strategy, &config.label_mapping);

    Ok(AudioInfo {
        path: path.to_path_buf(),
        sample_rate,
        channels,
        duration,
        num_samples,
        file_size,
        format,
        label,
    })
}
#[cfg(feature = "audio")]
/// Derives a file's label according to `strategy`.
///
/// - `FromFilename`: the part of the file stem before the first `_`
///   (`bark_01.wav` -> `bark`).
/// - `FromDirectory`: the name of the file's parent directory.
/// - `FromMapping`: the value stored under the file name (including extension) in
///   `mapping`; `None` when no mapping is configured.
/// - `None`: always unlabeled.
///
/// Non-UTF-8 names yield `None`. An empty label (e.g. `_take1.wav` under `FromFilename`,
/// or an empty mapping value) is also reported as `None`, so it does not become a
/// nameless class of its own.
fn extract_label(
    path: &Path,
    strategy: &AudioLabelStrategy,
    mapping: &Option<HashMap<String, String>>,
) -> Option<String> {
    let label = match strategy {
        AudioLabelStrategy::FromFilename => path
            .file_stem()
            .and_then(|stem| stem.to_str())
            .and_then(|stem| stem.split('_').next())
            .map(str::to_string),
        AudioLabelStrategy::FromDirectory => path
            .parent()
            .and_then(|parent| parent.file_name())
            .and_then(|name| name.to_str())
            .map(str::to_string),
        AudioLabelStrategy::FromMapping => mapping.as_ref().and_then(|map| {
            path.file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| map.get(name))
                .cloned()
        }),
        AudioLabelStrategy::None => None,
    };
    label.filter(|name| !name.is_empty())
}
#[cfg(feature = "audio")]
/// Produces the sample buffer for one audio file.
///
/// NOTE(review): placeholder — `_path` is ignored and no file is decoded. The result
/// is a synthetic 440 Hz sine wave:
/// - sampled at `config.sample_rate`;
/// - lasting `config.max_duration` seconds (1 second when unset);
/// - peak-normalized when `config.normalize` is true.
///
/// # Errors
///
/// Never fails as written.
fn load_audio_file(_path: &Path, config: &AudioConfig) -> Result<Vec<f32>> {
    const TONE_HZ: f32 = 440.0;
    let rate = config.sample_rate as f32;
    let seconds = config.max_duration.unwrap_or(1.0);
    let total = (seconds * rate) as usize;

    let mut samples: Vec<f32> = (0..total)
        .map(|n| {
            let t = n as f32 / rate;
            (2.0 * std::f32::consts::PI * TONE_HZ * t).sin()
        })
        .collect();

    if config.normalize {
        normalize_audio(&mut samples);
    }
    Ok(samples)
}
#[cfg(feature = "audio")]
/// Divides every sample in `audio` by the buffer's largest absolute value, in place.
///
/// Empty or all-zero buffers are left unchanged. NaN samples do not raise the
/// computed peak, because `f32::max` returns the non-NaN operand.
fn normalize_audio(audio: &mut [f32]) {
    let peak = audio
        .iter()
        .fold(0.0f32, |loudest, &sample| loudest.max(sample.abs()));
    if peak > 0.0 {
        audio.iter_mut().for_each(|sample| *sample /= peak);
    }
}
/// Placeholder type exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub struct AudioConfig;
/// Placeholder type exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub struct AudioDatasetInfo;
/// Placeholder type exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub struct AudioDatasetBuilder;
/// Placeholder type exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub struct AudioDataset;
/// Placeholder type exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub struct AudioInfo;
/// Placeholder enum exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub enum FeatureType {
    /// Sole variant available without the `audio` feature.
    Raw,
}
/// Placeholder enum exported when the `audio` feature is disabled.
#[cfg(not(feature = "audio"))]
pub enum AudioLabelStrategy {
    /// Sole variant available without the `audio` feature.
    None,
}
#[cfg(test)]
#[cfg(feature = "audio")]
mod tests {
    use super::*;

    #[test]
    fn test_audio_config_default() {
        let config = AudioConfig::default();
        assert_eq!(config.sample_rate, 16000);
        assert_eq!(config.feature_type, FeatureType::Raw);
        assert!(config.normalize);
        assert_eq!(config.n_mfcc, 13);
        assert_eq!(config.n_mels, 80);
    }

    #[test]
    fn test_audio_config_default_discovery_settings() {
        let config = AudioConfig::default();
        assert_eq!(config.n_fft, 1024);
        assert_eq!(config.hop_length, 512);
        assert!(!config.cache_audio);
        assert!(config.max_duration.is_none());
        assert!(config.min_duration.is_none());
        assert!(matches!(
            config.label_strategy,
            AudioLabelStrategy::FromDirectory
        ));
        for ext in ["wav", "flac", "mp3", "ogg", "m4a"] {
            assert!(config.supported_extensions.contains(&ext.to_string()));
        }
    }

    #[test]
    fn test_audio_config_builder() {
        let config = AudioConfig::default()
            .with_sample_rate(22050)
            .with_max_duration(5.0)
            .with_feature_extraction(FeatureType::MFCC)
            .with_n_mfcc(20)
            .with_normalize(false);
        assert_eq!(config.sample_rate, 22050);
        assert_eq!(config.max_duration, Some(5.0));
        assert_eq!(config.feature_type, FeatureType::MFCC);
        assert_eq!(config.n_mfcc, 20);
        assert!(!config.normalize);
    }

    #[test]
    fn test_audio_dataset_builder() {
        let builder = AudioDatasetBuilder::new()
            .sample_rate(16000)
            .feature_type(FeatureType::MelSpectrogram);
        assert_eq!(builder.config.sample_rate, 16000);
        assert_eq!(builder.config.feature_type, FeatureType::MelSpectrogram);
    }

    #[test]
    fn test_builder_without_directory_fails() {
        assert!(AudioDatasetBuilder::new().build().is_err());
    }

    #[test]
    fn test_from_directory_missing_path_fails() {
        let missing =
            std::env::temp_dir().join("tenflowers_audio_test_dir_that_does_not_exist");
        assert!(AudioDataset::from_directory(&missing).is_err());
    }

    #[test]
    fn test_extract_label_strategies() {
        let path = std::path::Path::new("data/dog/bark_01.wav");
        assert_eq!(
            extract_label(path, &AudioLabelStrategy::FromDirectory, &None),
            Some("dog".to_string())
        );
        assert_eq!(
            extract_label(path, &AudioLabelStrategy::FromFilename, &None),
            Some("bark".to_string())
        );
        assert_eq!(extract_label(path, &AudioLabelStrategy::None, &None), None);

        let mut mapping = std::collections::HashMap::new();
        mapping.insert("bark_01.wav".to_string(), "canine".to_string());
        assert_eq!(
            extract_label(path, &AudioLabelStrategy::FromMapping, &Some(mapping)),
            Some("canine".to_string())
        );
        assert_eq!(
            extract_label(path, &AudioLabelStrategy::FromMapping, &None),
            None
        );
    }

    #[test]
    fn test_load_audio_file_length_and_peak() {
        let config = AudioConfig::default()
            .with_sample_rate(1000)
            .with_max_duration(0.25);
        let audio = match load_audio_file(std::path::Path::new("unused.wav"), &config) {
            Ok(audio) => audio,
            Err(_) => panic!("load_audio_file unexpectedly failed"),
        };
        assert_eq!(audio.len(), 250);
        let peak = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
        assert!((peak - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_audio() {
        let mut audio = vec![0.5, -1.0, 0.25, -0.5];
        normalize_audio(&mut audio);
        let max_abs = audio.iter().map(|&x| x.abs()).fold(0.0f32, |a, b| a.max(b));
        assert!((max_abs - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_audio_silence_and_empty() {
        let mut silence = vec![0.0f32; 4];
        normalize_audio(&mut silence);
        assert!(silence.iter().all(|&x| x == 0.0));

        let mut empty: Vec<f32> = Vec::new();
        normalize_audio(&mut empty);
        assert!(empty.is_empty());
    }
}