use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline};
use crate::{AutoModel, AutoTokenizer};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
/// The audio payload accepted by the speech-to-text pipeline.
#[derive(Debug, Clone)]
pub enum AudioInput {
    /// Path to an audio file on disk.
    FilePath(String),
    /// Raw PCM samples plus the rate they were captured at, in Hz.
    RawAudio { samples: Vec<f32>, sample_rate: u32 },
    /// Base64-encoded audio bytes (the pipeline currently decodes these as WAV).
    Base64(String),
    /// Encoded audio bytes with an explicit container format and sample rate.
    Bytes {
        data: Vec<u8>,
        format: AudioFormat,
        sample_rate: u32,
    },
}
/// Audio container/codec formats recognized by the pipeline.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AudioFormat {
    Wav,
    Flac,
    Mp3,
    M4a,
    Ogg,
    WebM,
}
impl AudioFormat {
pub fn from_extension(ext: &str) -> Option<Self> {
match ext.to_lowercase().as_str() {
"wav" => Some(Self::Wav),
"flac" => Some(Self::Flac),
"mp3" => Some(Self::Mp3),
"m4a" => Some(Self::M4a),
"ogg" => Some(Self::Ogg),
"webm" => Some(Self::WebM),
_ => None,
}
}
}
/// Result of one transcription run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechToTextOutput {
    /// The transcribed (or translated) text.
    pub text: String,
    /// Overall confidence in `text`, when available.
    pub confidence: Option<f32>,
    /// Per-word timing, populated only when timestamps were requested.
    pub word_timestamps: Option<Vec<WordTimestamp>>,
    /// Language tag, copied from the pipeline configuration.
    pub language: Option<String>,
    /// Wall-clock processing time; the pipeline overwrites this after the run.
    pub processing_time_ms: Option<u64>,
}
/// Timing and confidence for a single transcribed word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTimestamp {
    pub word: String,
    /// Start offset within the clip, in seconds.
    pub start_time: f64,
    /// End offset within the clip, in seconds.
    pub end_time: f64,
    /// Per-word confidence score.
    pub confidence: f32,
}
/// Tuning knobs for [`SpeechToTextPipeline`].
///
/// NOTE(review): the decoding parameters (`num_beams` through
/// `no_repeat_ngram_size`) and the chunking fields are not yet consumed by the
/// placeholder inference path in this file.
#[derive(Clone, Debug)]
pub struct SpeechToTextConfig {
    /// Target sample rate incoming audio is resampled to, in Hz.
    pub sample_rate: u32,
    /// Hard limit on clip length in seconds; longer input is rejected.
    pub max_duration: Option<f64>,
    /// Whether word-level timestamps are included in the output.
    pub return_timestamps: bool,
    /// Expected/forced language, if known.
    pub language: Option<String>,
    /// Transcription vs. translation.
    pub task: SpeechTask,
    /// Beam width for decoding (1 = greedy).
    pub num_beams: usize,
    /// Sampling temperature (0.0 = deterministic).
    pub temperature: f32,
    /// Length penalty applied during decoding.
    pub length_penalty: f32,
    /// Repetition penalty applied during decoding.
    pub repetition_penalty: f32,
    /// N-gram size that may not repeat during decoding (0 = disabled).
    pub no_repeat_ngram_size: usize,
    /// Chunk size for long-form audio, in seconds.
    pub chunk_length_s: Option<f64>,
    /// Overlap between consecutive chunks, in seconds.
    pub stride_length_s: Option<f64>,
}
impl Default for SpeechToTextConfig {
fn default() -> Self {
Self {
sample_rate: 16000, max_duration: Some(30.0), return_timestamps: false,
language: None, task: SpeechTask::Transcribe,
num_beams: 1, temperature: 0.0, length_penalty: 1.0,
repetition_penalty: 1.0,
no_repeat_ngram_size: 0,
chunk_length_s: Some(30.0), stride_length_s: Some(5.0), }
}
}
/// What the model should do with the recognized speech.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SpeechTask {
    /// Emit text in the spoken language.
    Transcribe,
    /// Translate the speech while transcribing (target language depends on the model).
    Translate,
}
/// Pipeline that turns audio into text.
#[derive(Clone)]
pub struct SpeechToTextPipeline {
    // Model + tokenizer pair; not yet consumed by the placeholder inference path.
    base: BasePipeline<AutoModel, AutoTokenizer>,
    config: SpeechToTextConfig,
    // Shared via Arc so clones of the pipeline reuse the same extractor.
    feature_extractor: Arc<AudioFeatureExtractor>,
}
impl SpeechToTextPipeline {
    /// Builds a pipeline from a model and tokenizer, using
    /// [`SpeechToTextConfig::default`] and a feature extractor at the default
    /// sample rate.
    ///
    /// # Errors
    /// Returns any error produced by [`AudioFeatureExtractor::new`].
    pub fn new(model: AutoModel, tokenizer: AutoTokenizer) -> Result<Self> {
        let base = BasePipeline::new(model, tokenizer);
        let config = SpeechToTextConfig::default();
        let feature_extractor = Arc::new(AudioFeatureExtractor::new(config.sample_rate)?);
        Ok(Self {
            base,
            config,
            feature_extractor,
        })
    }

    /// Replaces the whole configuration (builder style).
    ///
    /// NOTE(review): the feature extractor is NOT rebuilt here, so a config
    /// with a different `sample_rate` leaves the extractor at its original
    /// rate — confirm whether it should be recreated when the rate changes.
    pub fn with_config(mut self, config: SpeechToTextConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the expected spoken language (builder style).
    pub fn with_language(mut self, language: String) -> Self {
        self.config.language = Some(language);
        self
    }

    /// Enables or disables word-level timestamps in the output (builder style).
    pub fn with_timestamps(mut self, enable: bool) -> Self {
        self.config.return_timestamps = enable;
        self
    }

    /// Selects transcription vs. translation (builder style).
    pub fn with_task(mut self, task: SpeechTask) -> Self {
        self.config.task = task;
        self
    }

    /// Sets the chunk length, in seconds, for long-form audio (builder style).
    pub fn with_chunk_length(mut self, chunk_length_s: f64) -> Self {
        self.config.chunk_length_s = Some(chunk_length_s);
        self
    }

    /// Transcribes an audio file on disk.
    pub fn transcribe_file<P: AsRef<Path>>(&self, audio_path: P) -> Result<SpeechToTextOutput> {
        let input = AudioInput::FilePath(audio_path.as_ref().to_string_lossy().to_string());
        self.__call__(input)
    }

    /// Transcribes raw PCM samples captured at `sample_rate` Hz.
    pub fn transcribe_samples(
        &self,
        samples: Vec<f32>,
        sample_rate: u32,
    ) -> Result<SpeechToTextOutput> {
        let input = AudioInput::RawAudio {
            samples,
            sample_rate,
        };
        self.__call__(input)
    }

    /// Transcribes a chunk of streaming audio, assumed to already be at the
    /// configured sample rate.
    pub fn transcribe_streaming(&self, audio_chunk: &[f32]) -> Result<SpeechToTextOutput> {
        let input = AudioInput::RawAudio {
            samples: audio_chunk.to_vec(),
            sample_rate: self.config.sample_rate,
        };
        self.__call__(input)
    }

    /// Converts any supported [`AudioInput`] into model-ready features,
    /// resampling raw audio to the configured rate when needed.
    fn preprocess_audio(&self, input: &AudioInput) -> Result<AudioFeatures> {
        match input {
            // NOTE(review): file inputs are not resampled here; confirm that
            // `load_and_extract` yields audio at the configured rate.
            AudioInput::FilePath(path) => self.feature_extractor.load_and_extract(path),
            AudioInput::RawAudio {
                samples,
                sample_rate,
            } => {
                // Only resample when the caller's rate differs from ours.
                let resampled = if *sample_rate != self.config.sample_rate {
                    self.feature_extractor.resample(
                        samples,
                        *sample_rate,
                        self.config.sample_rate,
                    )?
                } else {
                    samples.clone()
                };
                self.feature_extractor.extract_features(&resampled)
            },
            AudioInput::Base64(encoded) => {
                let decoded = base64::decode(encoded).map_err(|e| {
                    TrustformersError::invalid_input_simple(format!(
                        "Failed to decode base64 audio: {}",
                        e
                    ))
                })?;
                // Base64 payloads are assumed to be WAV; presumably the
                // caller should be able to specify the format — TODO confirm.
                self.feature_extractor.decode_and_extract(&decoded, AudioFormat::Wav)
            },
            AudioInput::Bytes {
                data,
                format,
                // The caller-supplied rate is currently unused: decoding sets
                // the rate and `resample_to` only targets the configured one.
                sample_rate: _,
            } => self
                .feature_extractor
                .decode_and_extract(data, *format)?
                .resample_to(self.config.sample_rate),
        }
    }

    /// Decodes model output into a [`SpeechToTextOutput`].
    ///
    /// Placeholder implementation: emits fixed text/confidence (and fixed
    /// timestamps when enabled) until real decoding is wired in. The leading
    /// underscores mark the parameters the placeholder does not yet consume.
    fn postprocess_output(
        &self,
        _model_output: &crate::core::tensor::Tensor,
        _audio_duration: f64,
    ) -> Result<SpeechToTextOutput> {
        let text = "Transcribed text placeholder".to_string();
        let confidence = Some(0.95);
        let word_timestamps = if self.config.return_timestamps {
            Some(vec![
                WordTimestamp {
                    word: "Transcribed".to_string(),
                    start_time: 0.0,
                    end_time: 0.5,
                    confidence: 0.95,
                },
                WordTimestamp {
                    word: "text".to_string(),
                    start_time: 0.5,
                    end_time: 1.0,
                    confidence: 0.90,
                },
            ])
        } else {
            None
        };
        Ok(SpeechToTextOutput {
            text,
            confidence,
            word_timestamps,
            language: self.config.language.clone(),
            // Overwritten by `__call__` with the measured elapsed time.
            processing_time_ms: Some(100),
        })
    }
}
impl Pipeline for SpeechToTextPipeline {
    type Input = AudioInput;
    type Output = SpeechToTextOutput;

    /// Runs the full pipeline: preprocess, (placeholder) inference, postprocess.
    fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
        let started = std::time::Instant::now();

        let audio_features = self.preprocess_audio(&input)?;
        let audio_duration = audio_features.duration();

        // Reject clips longer than the configured ceiling, if one is set.
        match self.config.max_duration {
            Some(limit) if audio_duration > limit => {
                return Err(TrustformersError::invalid_input_simple(format!(
                    "Audio duration ({:.2}s) exceeds maximum allowed ({:.2}s)",
                    audio_duration, limit
                )));
            },
            _ => {},
        }

        // Placeholder inference: the input tensor stands in for model output.
        let input_tensor = audio_features.to_tensor()?;
        let model_output = input_tensor;

        let mut result = self.postprocess_output(&model_output, audio_duration)?;
        result.processing_time_ms = Some(started.elapsed().as_millis() as u64);
        Ok(result)
    }
}
/// Computes mel-style features from PCM audio.
pub struct AudioFeatureExtractor {
    // Rate (Hz) the extractor expects samples at.
    sample_rate: u32,
    // FFT window size; not yet used by the placeholder feature code.
    n_fft: usize,
    // Samples between successive frames.
    hop_length: usize,
    // Number of mel bins per frame.
    n_mels: usize,
}
impl AudioFeatureExtractor {
pub fn new(sample_rate: u32) -> Result<Self> {
Ok(Self {
sample_rate,
n_fft: 400, hop_length: 160, n_mels: 80, })
}
pub fn load_and_extract(&self, path: &str) -> Result<AudioFeatures> {
Ok(AudioFeatures {
features: vec![vec![0.0; self.n_mels]; 100], sample_rate: self.sample_rate,
duration_s: 5.0, })
}
pub fn extract_features(&self, samples: &[f32]) -> Result<AudioFeatures> {
let duration_s = samples.len() as f64 / self.sample_rate as f64;
let n_frames = (samples.len() / self.hop_length) + 1;
Ok(AudioFeatures {
features: vec![vec![0.0; self.n_mels]; n_frames], sample_rate: self.sample_rate,
duration_s,
})
}
pub fn resample(&self, samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
if from_rate == to_rate {
return Ok(samples.to_vec());
}
let ratio = to_rate as f64 / from_rate as f64;
let new_len = (samples.len() as f64 * ratio) as usize;
let mut resampled = Vec::with_capacity(new_len);
for i in 0..new_len {
let original_idx = (i as f64 / ratio) as usize;
if original_idx < samples.len() {
resampled.push(samples[original_idx]);
} else {
resampled.push(0.0);
}
}
Ok(resampled)
}
pub fn decode_and_extract(&self, data: &[u8], format: AudioFormat) -> Result<AudioFeatures> {
match format {
AudioFormat::Wav => {
self.extract_features(&[0.0; 16000]) },
_ => {
self.extract_features(&[0.0; 16000]) },
}
}
}
/// Extracted audio features: one `n_mels`-long vector per frame.
#[derive(Debug)]
pub struct AudioFeatures {
    /// Frame-major feature matrix (`features[frame][mel_bin]`).
    pub features: Vec<Vec<f32>>,
    /// Sample rate (Hz) of the audio the features were computed from.
    pub sample_rate: u32,
    /// Source clip duration in seconds.
    pub duration_s: f64,
}
impl AudioFeatures {
    /// Returns the clip duration in seconds, as recorded at extraction time.
    pub fn duration(&self) -> f64 {
        self.duration_s
    }

    /// Flattens the `[frames][mels]` matrix into a rank-3 tensor of shape
    /// `[1, n_frames, n_mels]` (leading batch dimension).
    ///
    /// # Errors
    /// Propagates tensor-construction failures.
    pub fn to_tensor(&self) -> Result<crate::core::tensor::Tensor> {
        use crate::core::tensor::Tensor;
        // `first().map_or(...)` avoids the `features[0]` indexing that
        // panicked when the feature matrix was empty.
        let n_mels = self.features.first().map_or(0, |frame| frame.len());
        let flat_features: Vec<f32> = self.features.iter().flatten().copied().collect();
        let shape = vec![1, self.features.len(), n_mels];
        Tensor::from_vec(flat_features, &shape).map_err(Into::into)
    }

    /// Intended to resample the underlying audio to `target_rate`.
    ///
    /// Placeholder: for differing rates the features are currently returned
    /// unchanged and the stored `sample_rate` is NOT updated — only the
    /// same-rate case is a true no-op.
    pub fn resample_to(self, target_rate: u32) -> Result<Self> {
        if self.sample_rate == target_rate {
            return Ok(self);
        }
        // TODO: real feature-domain resampling; passthrough for now.
        Ok(self)
    }
}
/// Minimal, dependency-free base64 decoder (RFC 4648, standard alphabet) for
/// `AudioInput::Base64` payloads. Replaces the previous stub, which silently
/// returned an empty byte vector for every input.
mod base64 {
    /// Maps one standard-alphabet base64 byte to its 6-bit value.
    fn sextet(b: u8) -> Result<u32, String> {
        match b {
            b'A'..=b'Z' => Ok((b - b'A') as u32),
            b'a'..=b'z' => Ok((b - b'a' + 26) as u32),
            b'0'..=b'9' => Ok((b - b'0' + 52) as u32),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err(format!("invalid base64 byte: 0x{:02x}", b)),
        }
    }

    /// Decodes a standard-alphabet base64 string.
    ///
    /// ASCII whitespace is ignored; `'='` padding is required so the cleaned
    /// length is a multiple of 4. Returns the decoded bytes, or a description
    /// of the first problem found.
    pub fn decode(input: &str) -> Result<Vec<u8>, String> {
        let bytes: Vec<u8> = input
            .bytes()
            .filter(|b| !b.is_ascii_whitespace())
            .collect();
        if bytes.is_empty() {
            return Ok(Vec::new());
        }
        if bytes.len() % 4 != 0 {
            return Err("base64 length must be a multiple of 4".to_string());
        }
        let padding = bytes.iter().rev().take_while(|&&b| b == b'=').count();
        if padding > 2 {
            return Err("too many '=' padding characters".to_string());
        }
        let data = &bytes[..bytes.len() - padding];
        let mut out = Vec::with_capacity(data.len() * 3 / 4);
        let mut acc: u32 = 0;
        let mut bits: u32 = 0;
        for &b in data {
            // Shift each 6-bit group into the accumulator and emit whole bytes.
            acc = (acc << 6) | sextet(b)?;
            bits += 6;
            if bits >= 8 {
                bits -= 8;
                out.push((acc >> bits) as u8);
            }
        }
        Ok(out)
    }
}
// Unit tests for the building blocks in this module. Pipeline-level paths are
// not covered here because they need a real model/tokenizer pair.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AudioFormat::from_extension ----

    #[test]
    fn test_audio_format_from_extension_wav() {
        let fmt = AudioFormat::from_extension("wav");
        assert!(matches!(fmt, Some(AudioFormat::Wav)));
    }

    #[test]
    fn test_audio_format_from_extension_flac() {
        let fmt = AudioFormat::from_extension("flac");
        assert!(matches!(fmt, Some(AudioFormat::Flac)));
    }

    #[test]
    fn test_audio_format_from_extension_mp3() {
        let fmt = AudioFormat::from_extension("mp3");
        assert!(matches!(fmt, Some(AudioFormat::Mp3)));
    }

    #[test]
    fn test_audio_format_from_extension_case_insensitive() {
        // Upper-case extensions must resolve the same as lower-case.
        let fmt = AudioFormat::from_extension("WAV");
        assert!(matches!(fmt, Some(AudioFormat::Wav)));
    }

    #[test]
    fn test_audio_format_from_extension_unknown() {
        let fmt = AudioFormat::from_extension("xyz");
        assert!(fmt.is_none());
    }

    #[test]
    fn test_audio_format_all_variants() {
        // Every supported extension must map to some variant.
        let exts = ["wav", "flac", "mp3", "m4a", "ogg", "webm"];
        for ext in &exts {
            assert!(
                AudioFormat::from_extension(ext).is_some(),
                "missing: {}",
                ext
            );
        }
    }

    // ---- SpeechToTextConfig defaults ----

    #[test]
    fn test_config_default_values() {
        let cfg = SpeechToTextConfig::default();
        assert_eq!(cfg.sample_rate, 16000);
        assert_eq!(cfg.max_duration, Some(30.0));
        assert!(!cfg.return_timestamps);
        assert!(cfg.language.is_none());
        assert!(matches!(cfg.task, SpeechTask::Transcribe));
        assert_eq!(cfg.num_beams, 1);
        assert!((cfg.temperature - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_config_chunk_length() {
        let cfg = SpeechToTextConfig::default();
        assert_eq!(cfg.chunk_length_s, Some(30.0));
        assert_eq!(cfg.stride_length_s, Some(5.0));
    }

    // ---- AudioFeatureExtractor ----

    #[test]
    fn test_extractor_creates_successfully() {
        let extractor = AudioFeatureExtractor::new(16000).expect("extractor creation succeeded");
        assert_eq!(extractor.sample_rate, 16000);
    }

    #[test]
    fn test_extract_features_duration_calculation() {
        // 16000 samples at 16 kHz is exactly one second.
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 16000];
        let features = extractor.extract_features(&samples).expect("ok");
        assert!((features.duration_s - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_extract_features_frame_count() {
        // 1600 samples / 160 hop + 1 = 11 frames.
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 1600];
        let features = extractor.extract_features(&samples).expect("ok");
        assert_eq!(features.features.len(), 11);
    }

    #[test]
    fn test_extract_features_mel_dims() {
        // Each frame must carry the configured 80 mel bins.
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 3200];
        let features = extractor.extract_features(&samples).expect("ok");
        for frame in &features.features {
            assert_eq!(frame.len(), 80);
        }
    }

    #[test]
    fn test_resample_same_rate_no_op() {
        // Equal rates must return the samples unchanged.
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.1_f32, 0.2, 0.3];
        let resampled = extractor.resample(&samples, 16000, 16000).expect("ok");
        assert_eq!(resampled.len(), samples.len());
        for (a, b) in resampled.iter().zip(samples.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_resample_upsample_increases_length() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 100];
        let resampled = extractor.resample(&samples, 8000, 16000).expect("ok");
        assert!(
            resampled.len() > samples.len(),
            "upsampled should be longer: {}",
            resampled.len()
        );
    }

    #[test]
    fn test_resample_downsample_decreases_length() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let samples = vec![0.0f32; 200];
        let resampled = extractor.resample(&samples, 16000, 8000).expect("ok");
        assert!(
            resampled.len() < samples.len(),
            "downsampled should be shorter: {}",
            resampled.len()
        );
    }

    // ---- AudioFeatures ----

    #[test]
    fn test_audio_features_duration() {
        let af = AudioFeatures {
            features: vec![vec![0.0; 80]; 50],
            sample_rate: 16000,
            duration_s: 3.5,
        };
        assert!((af.duration() - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_audio_features_to_tensor_shape() {
        // The tensor gains a leading batch dimension of 1.
        let n_frames = 10;
        let n_mels = 80;
        let af = AudioFeatures {
            features: vec![vec![0.1; n_mels]; n_frames],
            sample_rate: 16000,
            duration_s: 1.0,
        };
        let tensor = af.to_tensor().expect("tensor creation succeeded");
        let shape = tensor.shape();
        assert_eq!(shape[0], 1);
        assert_eq!(shape[1], n_frames);
        assert_eq!(shape[2], n_mels);
    }

    #[test]
    fn test_audio_features_resample_to_same_rate() {
        let af = AudioFeatures {
            features: vec![vec![0.0; 80]; 5],
            sample_rate: 16000,
            duration_s: 1.0,
        };
        let result = af.resample_to(16000).expect("ok");
        assert_eq!(result.sample_rate, 16000);
    }

    // ---- WordTimestamp / SpeechTask ----

    #[test]
    fn test_word_timestamp_time_ordering() {
        let ts = WordTimestamp {
            word: "hello".to_string(),
            start_time: 0.0,
            end_time: 0.5,
            confidence: 0.95,
        };
        assert!(ts.start_time < ts.end_time);
        assert!(ts.confidence >= 0.0 && ts.confidence <= 1.0);
    }

    #[test]
    fn test_word_timestamp_confidence_range() {
        let ts = WordTimestamp {
            word: "world".to_string(),
            start_time: 0.5,
            end_time: 1.0,
            confidence: 0.87,
        };
        assert!(ts.confidence >= 0.0 && ts.confidence <= 1.0);
    }

    #[test]
    fn test_speech_task_transcribe_variant() {
        let task = SpeechTask::Transcribe;
        assert!(matches!(task, SpeechTask::Transcribe));
    }

    #[test]
    fn test_speech_task_translate_variant() {
        let task = SpeechTask::Translate;
        assert!(matches!(task, SpeechTask::Translate));
    }

    #[test]
    fn test_frame_level_processing_non_empty() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        // Deterministic pseudo-random samples via an LCG (no external RNG crate).
        let mut seed = 12345u64;
        let samples: Vec<f32> = (0..4800)
            .map(|_| {
                seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                ((seed >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0
            })
            .collect();
        let features = extractor.extract_features(&samples).expect("ok");
        assert!(!features.features.is_empty());
        assert_eq!(features.features[0].len(), 80);
    }

    #[test]
    fn test_load_and_extract_returns_features() {
        // load_and_extract is a placeholder: any path yields non-empty features.
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let af = extractor.load_and_extract("dummy_path.wav").expect("ok");
        assert!(!af.features.is_empty());
        assert!(af.duration_s > 0.0);
    }

    #[test]
    fn test_decode_and_extract_wav() {
        let extractor = AudioFeatureExtractor::new(16000).expect("ok");
        let dummy_data = vec![0u8; 512];
        let af = extractor.decode_and_extract(&dummy_data, AudioFormat::Wav).expect("ok");
        assert!(!af.features.is_empty());
    }
}