use super::{FeatureExtractor, FeatureExtractorConfig};
use crate::auto::types::{FeatureInput, FeatureOutput};
use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Extracts frame-level feature vectors from raw audio samples.
///
/// The extractor resamples and (optionally) peak-normalizes incoming audio
/// to match its configuration before computing features.
#[derive(Debug, Clone)]
pub struct AudioFeatureExtractor {
    // Extraction parameters: target sampling rate, frame geometry,
    // normalization flag, and optional batch limit.
    config: AudioFeatureConfig,
}
impl AudioFeatureExtractor {
pub fn new(config: AudioFeatureConfig) -> Self {
Self { config }
}
pub fn get_config(&self) -> &AudioFeatureConfig {
&self.config
}
fn preprocess_audio(&self, samples: &[f32], sample_rate: u32) -> Result<Vec<f32>> {
let mut processed = samples.to_vec();
if sample_rate != self.config.sampling_rate {
processed = self.resample(samples, sample_rate, self.config.sampling_rate)?;
}
if self.config.normalize {
let max_val = processed.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
if max_val > 0.0 {
for sample in &mut processed {
*sample /= max_val;
}
}
}
Ok(processed)
}
fn resample(&self, samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
let ratio = to_rate as f32 / from_rate as f32;
let new_length = (samples.len() as f32 * ratio) as usize;
let mut resampled = Vec::with_capacity(new_length);
for i in 0..new_length {
let src_index = i as f32 / ratio;
let index = src_index as usize;
if index + 1 < samples.len() {
let fraction = src_index - index as f32;
let interpolated =
samples[index] * (1.0 - fraction) + samples[index + 1] * fraction;
resampled.push(interpolated);
} else {
resampled.push(samples[samples.len() - 1]);
}
}
Ok(resampled)
}
fn extract_audio_features(&self, audio: &[f32]) -> Result<Vec<f32>> {
let n_frames = audio.len() / self.config.hop_length;
let n_features = self.config.feature_size;
let mut features = Vec::with_capacity(n_frames * n_features);
for frame_idx in 0..n_frames {
let start = frame_idx * self.config.hop_length;
let end = std::cmp::min(start + self.config.n_fft, audio.len());
for feat_idx in 0..n_features {
let mut feature_val = 0.0f32;
for sample_idx in start..end {
feature_val += audio[sample_idx] * (feat_idx as f32 / n_features as f32).cos();
}
features.push(feature_val / (end - start) as f32);
}
}
Ok(features)
}
}
impl FeatureExtractor for AudioFeatureExtractor {
    /// Extracts frame-level features from an audio input, attaching
    /// sample-rate/length metadata (plus duration, channels, and bit depth
    /// when audio metadata is present) to the output.
    ///
    /// # Errors
    /// Returns an invalid-input error for any non-audio `FeatureInput` variant.
    fn extract_features(&self, input: &FeatureInput) -> Result<FeatureOutput> {
        match input {
            FeatureInput::Audio {
                samples,
                sample_rate,
                metadata,
            } => {
                let audio = self.preprocess_audio(samples, *sample_rate)?;
                let features = self.extract_audio_features(&audio)?;

                let mut info = HashMap::new();
                info.insert(
                    "sample_rate".to_string(),
                    serde_json::Value::Number((*sample_rate).into()),
                );
                info.insert(
                    "original_length".to_string(),
                    serde_json::Value::Number(samples.len().into()),
                );
                info.insert(
                    "processed_length".to_string(),
                    serde_json::Value::Number(audio.len().into()),
                );

                if let Some(meta) = metadata {
                    // Duration is an f64; fall back to a string representation
                    // when it is not a finite JSON number (NaN/inf).
                    let duration_value = match serde_json::Number::from_f64(meta.duration) {
                        Some(num) => serde_json::Value::Number(num),
                        None => serde_json::Value::String(format!("{}", meta.duration)),
                    };
                    info.insert("duration".to_string(), duration_value);
                    info.insert(
                        "channels".to_string(),
                        serde_json::Value::Number(meta.channels.into()),
                    );
                    if let Some(bit_depth) = meta.bit_depth {
                        info.insert(
                            "bit_depth".to_string(),
                            serde_json::Value::Number(bit_depth.into()),
                        );
                    }
                }

                let frame_count = audio.len() / self.config.hop_length;
                Ok(FeatureOutput {
                    features,
                    shape: vec![frame_count, self.config.feature_size],
                    metadata: info,
                    attention_mask: None,
                    special_tokens: vec![],
                })
            },
            _ => Err(TrustformersError::invalid_input_simple(
                "Audio feature extractor requires audio input".to_string(),
            )),
        }
    }

    /// Returns the configuration as a trait object.
    fn config(&self) -> &dyn FeatureExtractorConfig {
        &self.config
    }
}
/// Configuration for [`AudioFeatureExtractor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioFeatureConfig {
    /// Target sampling rate in Hz; inputs at other rates are resampled to it.
    pub sampling_rate: u32,
    /// Number of feature values produced per frame.
    pub feature_size: usize,
    /// Analysis window length in samples (upper bound on the samples per frame).
    pub n_fft: usize,
    /// Step between consecutive frames, in samples.
    pub hop_length: usize,
    /// Whether to peak-normalize audio before feature extraction.
    pub normalize: bool,
    /// Optional upper bound on batch size; `None` means no limit.
    pub max_batch_size: Option<usize>,
}
impl AudioFeatureConfig {
    /// Builds a config from a JSON object, falling back to defaults
    /// (16 kHz, 80 features, 512-point window, hop of 160, normalization on)
    /// for any missing or mistyped field.
    pub fn from_config(config: &serde_json::Value) -> Result<Self> {
        Ok(Self {
            sampling_rate: config
                .get("sampling_rate")
                .and_then(|v| v.as_u64())
                .unwrap_or(16000) as u32,
            feature_size: config
                .get("feature_size")
                .and_then(|v| v.as_u64())
                .unwrap_or(80) as usize,
            n_fft: config.get("n_fft").and_then(|v| v.as_u64()).unwrap_or(512) as usize,
            hop_length: config
                .get("hop_length")
                .and_then(|v| v.as_u64())
                .unwrap_or(160) as usize,
            normalize: config
                .get("normalize")
                .and_then(|v| v.as_bool())
                .unwrap_or(true),
            max_batch_size: config
                .get("max_batch_size")
                .and_then(|v| v.as_u64())
                .map(|v| v as usize),
        })
    }

    /// Validates every field, returning a descriptive error for the first
    /// violated constraint.
    ///
    /// # Errors
    /// Returns an invalid-input error when any size/rate field is zero,
    /// when `hop_length > n_fft`, or when `max_batch_size` is `Some(0)`.
    pub fn validate_config(&self) -> Result<()> {
        if self.sampling_rate == 0 {
            return Err(TrustformersError::invalid_input(
                "sampling_rate must be greater than 0".to_string(),
                Some("sampling_rate".to_string()),
                Some("positive integer > 0".to_string()),
                Some("0".to_string()),
            ));
        }
        if self.feature_size == 0 {
            return Err(TrustformersError::invalid_input(
                "feature_size must be greater than 0".to_string(),
                Some("feature_size".to_string()),
                Some("positive integer > 0".to_string()),
                Some("0".to_string()),
            ));
        }
        if self.n_fft == 0 {
            return Err(TrustformersError::invalid_input(
                "n_fft must be greater than 0".to_string(),
                Some("n_fft".to_string()),
                Some("positive integer > 0".to_string()),
                Some("0".to_string()),
            ));
        }
        if self.hop_length == 0 {
            return Err(TrustformersError::invalid_input(
                "hop_length must be greater than 0".to_string(),
                Some("hop_length".to_string()),
                Some("positive integer > 0".to_string()),
                Some("0".to_string()),
            ));
        }
        // Frames should overlap or tile, never leave gaps between windows.
        if self.hop_length > self.n_fft {
            return Err(TrustformersError::invalid_input(
                "hop_length should not exceed n_fft".to_string(),
                Some("hop_length".to_string()),
                Some("value <= n_fft".to_string()),
                Some("value > n_fft".to_string()),
            ));
        }
        if let Some(batch_size) = self.max_batch_size {
            if batch_size == 0 {
                return Err(TrustformersError::invalid_input(
                    "max_batch_size must be greater than 0 if specified".to_string(),
                    Some("max_batch_size".to_string()),
                    Some("positive integer > 0".to_string()),
                    Some("0".to_string()),
                ));
            }
        }
        // A non-power-of-two n_fft is deliberately accepted (it is merely
        // suboptimal for FFT-based backends, not invalid). An empty
        // `if !self.n_fft.is_power_of_two() {}` branch previously sat here
        // as dead code; removed.
        Ok(())
    }

    /// Returns the output feature-matrix shape `(n_frames, feature_size)` for
    /// an input of `input_length` samples. Frame count is floor division by
    /// `hop_length`, matching the extractor's framing.
    pub fn get_output_shape(&self, input_length: usize) -> (usize, usize) {
        let n_frames = input_length / self.hop_length;
        (n_frames, self.feature_size)
    }

    /// Estimates peak memory in bytes for processing `input_length` samples:
    /// input copy + preprocessed copy + output matrix + one window-sized
    /// scratch buffer, all as f32 (4 bytes each).
    pub fn estimate_memory_usage(&self, input_length: usize) -> usize {
        let (n_frames, feature_size) = self.get_output_shape(input_length);
        let input_memory = input_length * 4;
        let processed_memory = input_length * 4;
        let output_memory = n_frames * feature_size * 4;
        let temp_memory = self.n_fft * 4;
        input_memory + processed_memory + output_memory + temp_memory
    }
}
impl FeatureExtractorConfig for AudioFeatureConfig {
    /// Number of feature values per frame.
    fn feature_size(&self) -> usize {
        self.feature_size
    }

    /// Audio extraction always supports batching.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Optional configured batch-size limit.
    fn max_batch_size(&self) -> Option<usize> {
        self.max_batch_size
    }

    /// Exposes the remaining extraction parameters as a JSON-value map.
    fn additional_params(&self) -> HashMap<String, serde_json::Value> {
        HashMap::from([
            (
                "sampling_rate".to_string(),
                serde_json::Value::Number(self.sampling_rate.into()),
            ),
            (
                "n_fft".to_string(),
                serde_json::Value::Number(self.n_fft.into()),
            ),
            (
                "hop_length".to_string(),
                serde_json::Value::Number(self.hop_length.into()),
            ),
            (
                "normalize".to_string(),
                serde_json::Value::Bool(self.normalize),
            ),
        ])
    }

    /// Delegates to the inherent `validate_config`.
    fn validate(&self) -> Result<()> {
        self.validate_config()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::types::AudioMetadata;

    /// Standard 16 kHz / 80-feature configuration shared by the tests below
    /// (previously duplicated verbatim in every test body).
    fn test_config() -> AudioFeatureConfig {
        AudioFeatureConfig {
            sampling_rate: 16000,
            feature_size: 80,
            n_fft: 512,
            hop_length: 160,
            normalize: true,
            max_batch_size: Some(16),
        }
    }

    #[test]
    fn test_audio_feature_extractor_creation() {
        let extractor = AudioFeatureExtractor::new(test_config());
        assert_eq!(extractor.config().feature_size(), 80);
        assert!(extractor.config().supports_batching());
        assert_eq!(extractor.config().max_batch_size(), Some(16));
    }

    #[test]
    fn test_audio_feature_extraction() {
        let extractor = AudioFeatureExtractor::new(test_config());
        // One second of a low-frequency sine at 16 kHz.
        let samples: Vec<f32> = (0..16000).map(|i| (i as f32 * 0.001).sin()).collect();
        let input = FeatureInput::Audio {
            samples,
            sample_rate: 16000,
            metadata: Some(AudioMetadata {
                duration: 1.0,
                channels: 1,
                bit_depth: Some(16),
            }),
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_ok());
        let output = result.expect("operation failed in test");
        let expected_frames = 16000 / 160;
        assert_eq!(output.shape, vec![expected_frames, 80]);
        assert_eq!(output.features.len(), expected_frames * 80);
    }

    #[test]
    fn test_resampling() {
        let extractor = AudioFeatureExtractor::new(test_config());
        // Upsampling 8 kHz -> 16 kHz should exactly double the length.
        let samples_8k: Vec<f32> = (0..8000).map(|i| (i as f32 * 0.001).sin()).collect();
        let resampled = extractor
            .resample(&samples_8k, 8000, 16000)
            .expect("operation failed in test");
        assert_eq!(resampled.len(), 16000);
    }

    #[test]
    fn test_audio_config_from_json() {
        let config_json = serde_json::json!({
            "sampling_rate": 22050,
            "feature_size": 128,
            "n_fft": 1024,
            "hop_length": 256,
            "normalize": false
        });
        let config =
            AudioFeatureConfig::from_config(&config_json).expect("operation failed in test");
        assert_eq!(config.sampling_rate, 22050);
        assert_eq!(config.feature_size, 128);
        assert_eq!(config.n_fft, 1024);
        assert_eq!(config.hop_length, 256);
        assert!(!config.normalize);
    }

    #[test]
    fn test_config_validation() {
        assert!(test_config().validate_config().is_ok());

        // hop_length larger than n_fft must be rejected.
        let invalid_config = AudioFeatureConfig {
            hop_length: 1024,
            ..test_config()
        };
        assert!(invalid_config.validate_config().is_err());

        // A zero sampling rate must be rejected.
        let zero_rate_config = AudioFeatureConfig {
            sampling_rate: 0,
            ..test_config()
        };
        assert!(zero_rate_config.validate_config().is_err());
    }

    #[test]
    fn test_output_shape_calculation() {
        let (n_frames, feature_size) = test_config().get_output_shape(16000);
        assert_eq!(n_frames, 100);
        assert_eq!(feature_size, 80);
    }

    #[test]
    fn test_memory_estimation() {
        let config = test_config();
        let memory_usage = config.estimate_memory_usage(16000);
        assert!(memory_usage > 0);
        // Two input-sized copies + output matrix + one window scratch buffer.
        let expected_min = 16000 * 4 * 2 + 100 * 80 * 4 + 512 * 4;
        assert!(memory_usage >= expected_min);
    }

    #[test]
    fn test_invalid_input_type() {
        let extractor = AudioFeatureExtractor::new(test_config());
        let input = FeatureInput::Text {
            content: "This is not audio".to_string(),
            metadata: None,
        };
        let result = extractor.extract_features(&input);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            TrustformersError::InvalidInput { .. }
        ));
    }
}