use crate::whisper::{config::WhisperConfig, model::WhisperForConditionalGeneration};
use std::fmt;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Layer};
/// Errors produced by the Whisper task layer (decoding, language detection,
/// and classification entry points).
#[derive(Debug)]
pub enum WhisperError {
    /// The mel-spectrogram input was missing its time axis or had zero frames.
    EmptyInput,
    /// Beam search was requested with `beam_size == 0`.
    InvalidBeamSize,
    /// The underlying model's forward pass failed; payload is the source message.
    ForwardError(String),
    /// Language detection produced no probability entries.
    LanguageDetectionFailed,
    /// Token decoding failed; payload describes the cause.
    DecodingFailed(String),
}

impl fmt::Display for WhisperError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed-text variants use `write_str`; payload variants interpolate.
        match self {
            Self::EmptyInput => f.write_str("Whisper: empty mel-spectrogram input"),
            Self::InvalidBeamSize => f.write_str("Whisper: beam_size must be at least 1"),
            Self::ForwardError(msg) => write!(f, "Whisper: forward pass error: {msg}"),
            Self::LanguageDetectionFailed => {
                f.write_str("Whisper: language detection returned no probabilities")
            },
            Self::DecodingFailed(msg) => write!(f, "Whisper: decoding failed: {msg}"),
        }
    }
}

impl std::error::Error for WhisperError {}
/// A transcribed text segment with millisecond start/end bounds.
#[derive(Debug, Clone, PartialEq)]
pub struct WhisperTimestamp {
    /// Segment start, in milliseconds from the beginning of the audio.
    pub start_ms: f32,
    /// Segment end, in milliseconds from the beginning of the audio.
    pub end_ms: f32,
    /// The transcribed text for this segment.
    pub text: String,
}

impl WhisperTimestamp {
    /// Creates a segment; `text` accepts anything convertible into a `String`.
    pub fn new(start_ms: f32, end_ms: f32, text: impl Into<String>) -> Self {
        let text = text.into();
        Self { start_ms, end_ms, text }
    }

    /// Length of the segment in milliseconds (`end - start`).
    pub fn duration_ms(&self) -> f32 {
        self.end_ms - self.start_ms
    }
}

impl fmt::Display for WhisperTimestamp {
    /// Renders as `[<start>ms – <end>ms] <text>` with times rounded to whole ms.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { start_ms, end_ms, text } = self;
        write!(f, "[{start_ms:.0}ms – {end_ms:.0}ms] {text}")
    }
}
/// High-level speech-to-text task that owns a Whisper seq2seq model and
/// exposes greedy/beam decoding, language detection, and chunked
/// timestamped transcription.
pub struct SpeechRecognitionTask {
    // The wrapped encoder-decoder model; all inference goes through it.
    model: WhisperForConditionalGeneration,
}
impl SpeechRecognitionTask {
    /// Wraps a newly constructed Whisper model.
    pub fn new(config: WhisperConfig) -> Result<Self> {
        let model = WhisperForConditionalGeneration::new(config)?;
        Ok(Self { model })
    }

    /// Runs a full encoder + decoder forward pass and returns decoder logits.
    pub fn forward(&self, mel: &Tensor, decoder_input_ids: &[u32]) -> Result<Tensor> {
        self.model.forward(mel, decoder_input_ids)
    }

    /// Borrows the underlying seq2seq model.
    pub fn model(&self) -> &WhisperForConditionalGeneration {
        &self.model
    }

    /// Borrows the model configuration.
    pub fn config(&self) -> &WhisperConfig {
        &self.model.model.config
    }

    /// Greedy (argmax) decoding: repeatedly feeds the growing token prefix
    /// through the model and appends the highest-scoring next token.
    ///
    /// Returns the generated token ids joined by spaces (detokenization is
    /// out of scope here). Stops early when the assumed EOS token appears.
    ///
    /// # Errors
    /// - `EmptyInput` if `mel` has rank < 3 or an empty time axis.
    /// - `ForwardError` if the model forward pass fails.
    /// - `DecodingFailed` on malformed logits or extraction failure.
    pub fn transcribe_greedy(
        &self,
        mel: &Tensor,
        start_token: u32,
        max_new_tokens: usize,
    ) -> std::result::Result<String, WhisperError> {
        let shape = mel.shape().to_vec();
        // Expected layout is [batch, mel_bins, time]; an empty time axis is rejected.
        if shape.len() < 3 || shape[2] == 0 {
            return Err(WhisperError::EmptyInput);
        }
        let vocab_size = self.model.model.config.vocab_size;
        // NOTE(review): EOS is assumed to be the last vocabulary id — confirm
        // against the actual tokenizer's special-token layout.
        let eos_token = (vocab_size.saturating_sub(1)) as u32;
        let mut decoder_ids: Vec<u32> = vec![start_token];
        let mut generated: Vec<u32> = Vec::new();
        for _ in 0..max_new_tokens {
            // No KV cache: every step re-runs the full prefix through the model.
            let logits = self
                .model
                .forward(mel, &decoder_ids)
                .map_err(|e| WhisperError::ForwardError(e.to_string()))?;
            let logits_shape = logits.shape().to_vec();
            if logits_shape.len() < 3 {
                return Err(WhisperError::DecodingFailed(
                    "unexpected logits rank".to_string(),
                ));
            }
            // Fixed: `logits_shape[1] - 1` panicked on usize underflow when the
            // decoder returned an empty sequence dimension.
            let seq_pos = logits_shape[1].checked_sub(1).ok_or_else(|| {
                WhisperError::DecodingFailed("logits have empty sequence dimension".to_string())
            })?;
            let v = logits_shape[2];
            let next_token = extract_argmax_at_position(&logits, seq_pos, v)
                .map_err(|e| WhisperError::DecodingFailed(e.to_string()))?;
            if next_token == eos_token {
                break;
            }
            generated.push(next_token);
            decoder_ids.push(next_token);
        }
        // Render token ids as a space-separated string.
        let text = if generated.is_empty() {
            String::new()
        } else {
            generated.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(" ")
        };
        Ok(text)
    }

    /// Beam-search decoding with cumulative logit scoring. Returns up to
    /// `beam_size` hypotheses (token ids joined by spaces), best first.
    ///
    /// # Errors
    /// - `InvalidBeamSize` if `beam_size == 0`.
    /// - `EmptyInput` for a missing/empty mel time axis.
    /// - `ForwardError` / `DecodingFailed` on model failures or when no
    ///   hypothesis could be produced.
    pub fn transcribe_beam(
        &self,
        mel: &Tensor,
        start_token: u32,
        beam_size: usize,
        max_new_tokens: usize,
    ) -> std::result::Result<Vec<String>, WhisperError> {
        if beam_size == 0 {
            return Err(WhisperError::InvalidBeamSize);
        }
        let shape = mel.shape().to_vec();
        if shape.len() < 3 || shape[2] == 0 {
            return Err(WhisperError::EmptyInput);
        }
        let vocab_size = self.model.model.config.vocab_size;
        // NOTE(review): same last-vocab-id EOS assumption as `transcribe_greedy`.
        let eos_token = (vocab_size.saturating_sub(1)) as u32;
        // Each beam is (token sequence including start_token, cumulative score).
        // Scores accumulate raw logits, not normalized log-probabilities.
        let mut beams: Vec<(Vec<u32>, f32)> = vec![(vec![start_token], 0.0)];
        let mut completed: Vec<(Vec<u32>, f32)> = Vec::new();
        for _ in 0..max_new_tokens {
            if beams.is_empty() {
                break;
            }
            let mut next_beams: Vec<(Vec<u32>, f32)> = Vec::new();
            for (seq, log_prob) in &beams {
                let logits = self
                    .model
                    .forward(mel, seq)
                    .map_err(|e| WhisperError::ForwardError(e.to_string()))?;
                let logits_shape = logits.shape().to_vec();
                if logits_shape.len() < 3 {
                    return Err(WhisperError::DecodingFailed(
                        "unexpected logits rank".to_string(),
                    ));
                }
                // Fixed: same underflow guard as in `transcribe_greedy`.
                let seq_pos = logits_shape[1].checked_sub(1).ok_or_else(|| {
                    WhisperError::DecodingFailed(
                        "logits have empty sequence dimension".to_string(),
                    )
                })?;
                let v = logits_shape[2];
                let top_tokens = extract_top_k_at_position(&logits, seq_pos, v, beam_size)
                    .map_err(|e| WhisperError::DecodingFailed(e.to_string()))?;
                for (token, logit) in top_tokens {
                    let new_log_prob = log_prob + logit;
                    let mut new_seq = seq.clone();
                    new_seq.push(token);
                    if token == eos_token {
                        // Finished hypotheses leave the active beam set.
                        completed.push((new_seq, new_log_prob));
                    } else {
                        next_beams.push((new_seq, new_log_prob));
                    }
                }
            }
            // Prune to the best `beam_size` unfinished candidates.
            next_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            next_beams.truncate(beam_size);
            beams = next_beams;
        }
        // Unfinished beams still count as hypotheses at the length limit.
        completed.extend(beams);
        completed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        completed.truncate(beam_size);
        if completed.is_empty() {
            return Err(WhisperError::DecodingFailed(
                "no hypotheses generated".to_string(),
            ));
        }
        let results: Vec<String> = completed
            .into_iter()
            .map(|(seq, _)| {
                // Drop the leading start token before rendering.
                let tokens = &seq[1..];
                if tokens.is_empty() {
                    String::new()
                } else {
                    tokens.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(" ")
                }
            })
            .collect();
        Ok(results)
    }

    /// Scores the first decoding position and returns the top-5 entries as
    /// placeholder language labels (`lang_<idx>`) with softmax probabilities.
    ///
    /// NOTE(review): uses token id 0 as the start-of-transcript token and
    /// treats the whole vocab row as language logits — confirm both against
    /// the real tokenizer before relying on the labels.
    pub fn detect_language(
        &self,
        mel: &Tensor,
    ) -> std::result::Result<Vec<(String, f32)>, WhisperError> {
        let shape = mel.shape().to_vec();
        if shape.len() < 3 || shape[2] == 0 {
            return Err(WhisperError::EmptyInput);
        }
        let sot_token = 0u32;
        let logits = self
            .model
            .forward(mel, &[sot_token])
            .map_err(|e| WhisperError::ForwardError(e.to_string()))?;
        let logits_shape = logits.shape().to_vec();
        if logits_shape.len() < 3 {
            return Err(WhisperError::LanguageDetectionFailed);
        }
        let v = logits_shape[2];
        // Position 0 is the only decoded position (single input token).
        let raw_logits = extract_slice_at_position(&logits, 0, v)
            .map_err(|_| WhisperError::LanguageDetectionFailed)?;
        let probs = softmax_f32(&raw_logits);
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.truncate(5);
        let top5: Vec<(String, f32)> =
            indexed.into_iter().map(|(idx, prob)| (format!("lang_{idx}"), prob)).collect();
        if top5.is_empty() {
            return Err(WhisperError::LanguageDetectionFailed);
        }
        Ok(top5)
    }

    /// Chunked transcription: splits the mel time axis into `chunk_frames`
    /// windows, greedy-decodes each chunk, and attaches millisecond bounds.
    ///
    /// `chunk_frames == 0` means "one chunk covering the whole input".
    ///
    /// NOTE(review): assumes a fixed 20 ms per mel frame — confirm against
    /// the feature extractor's hop length.
    pub fn transcribe_with_timestamps(
        &self,
        mel: &Tensor,
        start_token: u32,
        chunk_frames: usize,
        max_new_tokens_per_chunk: usize,
    ) -> std::result::Result<Vec<WhisperTimestamp>, WhisperError> {
        let shape = mel.shape().to_vec();
        if shape.len() < 3 || shape[2] == 0 {
            return Err(WhisperError::EmptyInput);
        }
        let total_frames = shape[2];
        let ms_per_frame = 20.0f32;
        let effective_chunk = if chunk_frames == 0 { total_frames } else { chunk_frames };
        let num_chunks = total_frames.div_ceil(effective_chunk);
        let mut timestamps: Vec<WhisperTimestamp> = Vec::with_capacity(num_chunks);
        for chunk_idx in 0..num_chunks {
            let start_frame = chunk_idx * effective_chunk;
            // The last chunk may be shorter than `effective_chunk`.
            let end_frame = (start_frame + effective_chunk).min(total_frames);
            let start_ms = start_frame as f32 * ms_per_frame;
            let end_ms = end_frame as f32 * ms_per_frame;
            let chunk_mel = slice_mel_time(mel, start_frame, end_frame)
                .map_err(|e| WhisperError::ForwardError(e.to_string()))?;
            let text = self.transcribe_greedy(&chunk_mel, start_token, max_new_tokens_per_chunk)?;
            timestamps.push(WhisperTimestamp::new(start_ms, end_ms, text));
        }
        Ok(timestamps)
    }
}
/// Audio classifier: Whisper encoder + mean pooling over time + linear head.
pub struct WhisperForAudioClassification {
    // Full seq2seq model; only its encoder is used in `forward`.
    model: WhisperForConditionalGeneration,
    // Linear head: row-major [num_labels, d_model] weight and per-label bias,
    // both zero-initialized in `new`.
    classifier_weight: Vec<f32>, classifier_bias: Vec<f32>, num_labels: usize,
    // Expected encoder hidden width, captured from the config at construction.
    d_model: usize,
}
impl WhisperForAudioClassification {
    /// Builds a classifier on top of the Whisper encoder with a
    /// zero-initialized linear head of `num_labels` outputs.
    pub fn new(config: WhisperConfig, num_labels: usize) -> Result<Self> {
        let d_model = config.d_model;
        let model = WhisperForConditionalGeneration::new(config)?;
        // Row-major [num_labels, d_model] weight matrix plus per-label bias.
        let classifier_weight = vec![0.0f32; num_labels * d_model];
        let classifier_bias = vec![0.0f32; num_labels];
        Ok(Self {
            model,
            classifier_weight,
            classifier_bias,
            num_labels,
            d_model,
        })
    }

    /// Number of output classes of the linear head.
    pub fn num_labels(&self) -> usize {
        self.num_labels
    }

    /// Encodes `mel`, mean-pools encoder states over time, and applies the
    /// linear head. Returns logits flattened as `[batch * num_labels]`.
    ///
    /// # Errors
    /// - `EmptyInput` for a missing/empty mel time axis.
    /// - `ForwardError` on encoder failure, non-F32 output, unexpected rank,
    ///   a zero-length encoder sequence, or a hidden-width mismatch.
    pub fn forward(&self, mel: &Tensor) -> std::result::Result<Vec<f32>, WhisperError> {
        let shape = mel.shape().to_vec();
        if shape.len() < 3 || shape[2] == 0 {
            return Err(WhisperError::EmptyInput);
        }
        let batch = shape[0];
        let encoder_out = self
            .model
            .model
            .encoder
            .forward(mel)
            .map_err(|e| WhisperError::ForwardError(e.to_string()))?;
        let enc_shape = encoder_out.shape().to_vec();
        if enc_shape.len() < 3 {
            return Err(WhisperError::ForwardError(
                "encoder output has unexpected rank".to_string(),
            ));
        }
        let seq = enc_shape[1];
        let d = enc_shape[2];
        // Fixed: mean pooling divides by `seq`; a zero-length sequence
        // silently produced NaN logits.
        if seq == 0 {
            return Err(WhisperError::ForwardError(
                "encoder output has zero time steps".to_string(),
            ));
        }
        // Fixed: the classifier loop strides pooled features by `self.d_model`
        // while `pooled` is laid out with stride `d`; a mismatched encoder
        // width previously indexed out of bounds or read garbage.
        if d != self.d_model {
            return Err(WhisperError::ForwardError(format!(
                "encoder hidden size {d} does not match classifier d_model {}",
                self.d_model
            )));
        }
        let enc_data = match &encoder_out {
            Tensor::F32(arr) => arr.iter().copied().collect::<Vec<f32>>(),
            _ => {
                return Err(WhisperError::ForwardError(
                    "encoder output must be F32".to_string(),
                ))
            },
        };
        // Mean-pool over time: pooled[b, c] = mean over t of enc[b, t, c].
        let mut pooled = vec![0.0f32; batch * d];
        for b in 0..batch {
            for t in 0..seq {
                for c in 0..d {
                    pooled[b * d + c] += enc_data[b * seq * d + t * d + c];
                }
            }
            for c in 0..d {
                pooled[b * d + c] /= seq as f32;
            }
        }
        // Linear head: logits[b, l] = pooled[b, :] . weight[l, :] + bias[l].
        let mut logits = vec![0.0f32; batch * self.num_labels];
        for b in 0..batch {
            for label in 0..self.num_labels {
                let w_offset = label * self.d_model;
                let dot: f32 = (0..self.d_model)
                    .map(|i| pooled[b * d + i] * self.classifier_weight[w_offset + i])
                    .sum();
                logits[b * self.num_labels + label] = dot + self.classifier_bias[label];
            }
        }
        Ok(logits)
    }
}
/// Thin wrapper exposing decoder-only forward passes over precomputed
/// encoder hidden states.
pub struct WhisperDecoderWrapper {
    // Full model; `decode` uses only its decoder and output projection.
    inner: WhisperForConditionalGeneration,
}
impl WhisperDecoderWrapper {
    /// Builds the wrapper around a newly constructed Whisper model.
    pub fn new(config: WhisperConfig) -> Result<Self> {
        Ok(Self {
            inner: WhisperForConditionalGeneration::new(config)?,
        })
    }

    /// Runs the decoder over `decoder_input_ids`, cross-attending to the
    /// given precomputed `encoder_hidden_states`, then projects the hidden
    /// states to vocabulary logits.
    pub fn decode(
        &self,
        encoder_hidden_states: &Tensor,
        decoder_input_ids: &[u32],
    ) -> Result<Tensor> {
        let hidden = self
            .inner
            .model
            .decoder
            .forward(decoder_input_ids, encoder_hidden_states)?;
        self.inner.proj_out.forward(hidden)
    }

    /// Borrows the wrapped model's configuration.
    pub fn config(&self) -> &WhisperConfig {
        &self.inner.model.config
    }
}
/// Returns the argmax token id within the `vocab_size`-wide row starting at
/// flat offset `pos * vocab_size` of an F32 logits tensor.
///
/// NOTE(review): the flat offset ignores any batch dimension, so this is only
/// correct for batch size 1 (or when `pos` is already a flattened row index) —
/// confirm callers never pass batched logits.
fn extract_argmax_at_position(logits: &Tensor, pos: usize, vocab_size: usize) -> Result<u32> {
    use trustformers_core::errors::TrustformersError;
    match logits {
        Tensor::F32(arr) => {
            let offset = pos * vocab_size;
            // Fixed: the old code copied the ENTIRE tensor into a Vec on every
            // call (O(total elements) per decode step). Scan just the one row.
            let mut best: Option<(u32, f32)> = None;
            let mut seen = 0usize;
            for (idx, val) in arr.iter().copied().skip(offset).take(vocab_size).enumerate() {
                seen += 1;
                // `!(val < b)` keeps the last of tied maxima, matching the
                // previous `max_by`-with-Equal-fallback behavior.
                if best.map_or(true, |(_, b)| !(val < b)) {
                    best = Some((idx as u32, val));
                }
            }
            if seen != vocab_size {
                return Err(TrustformersError::shape_error("logit slice OOB".to_string()));
            }
            best.map(|(idx, _)| idx)
                .ok_or_else(|| TrustformersError::shape_error("empty logit slice".to_string()))
        },
        _ => Err(trustformers_core::errors::TrustformersError::shape_error(
            "logits tensor must be F32".to_string(),
        )),
    }
}
/// Returns the top-`k` `(token_id, logit)` pairs, descending by logit, from
/// the `vocab_size`-wide row at flat offset `pos * vocab_size`.
///
/// NOTE(review): like `extract_argmax_at_position`, the flat offset ignores
/// any batch dimension — only valid for batch size 1.
fn extract_top_k_at_position(
    logits: &Tensor,
    pos: usize,
    vocab_size: usize,
    k: usize,
) -> Result<Vec<(u32, f32)>> {
    use trustformers_core::errors::TrustformersError;
    match logits {
        Tensor::F32(arr) => {
            let offset = pos * vocab_size;
            // Fixed: collect only the row being ranked, not the whole tensor
            // (the old code flattened every element on each call).
            let mut indexed: Vec<(u32, f32)> = arr
                .iter()
                .copied()
                .skip(offset)
                .take(vocab_size)
                .enumerate()
                .map(|(i, v)| (i as u32, v))
                .collect();
            // A short row means `offset + vocab_size` ran past the data.
            if indexed.len() != vocab_size {
                return Err(TrustformersError::shape_error("logit slice OOB".to_string()));
            }
            // Descending by logit; incomparable values (NaN) compare Equal.
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            indexed.truncate(k);
            Ok(indexed)
        },
        _ => Err(TrustformersError::shape_error(
            "logits tensor must be F32".to_string(),
        )),
    }
}
/// Copies the `vocab_size`-wide row of logits at flat offset
/// `pos * vocab_size` into an owned `Vec<f32>`.
///
/// NOTE(review): the flat offset ignores any batch dimension — only valid
/// for batch size 1.
fn extract_slice_at_position(logits: &Tensor, pos: usize, vocab_size: usize) -> Result<Vec<f32>> {
    use trustformers_core::errors::TrustformersError;
    match logits {
        Tensor::F32(arr) => {
            let offset = pos * vocab_size;
            // Fixed: copy only the requested row instead of flattening the
            // whole tensor first (the old code allocated O(total elements)).
            let row: Vec<f32> = arr.iter().copied().skip(offset).take(vocab_size).collect();
            if row.len() != vocab_size {
                return Err(TrustformersError::shape_error("logit slice OOB".to_string()));
            }
            Ok(row)
        },
        _ => Err(TrustformersError::shape_error(
            "logits tensor must be F32".to_string(),
        )),
    }
}
/// Numerically stable softmax over an `f32` slice.
///
/// Subtracts the maximum before exponentiating so large logits do not
/// overflow. An empty input yields an empty vector; if every exponential
/// underflows to zero, a uniform distribution is returned instead.
fn softmax_f32(x: &[f32]) -> Vec<f32> {
    let n = x.len();
    if n == 0 {
        return Vec::new();
    }
    // Running maximum (NaN entries never win a `>` comparison).
    let mut peak = f32::NEG_INFINITY;
    for &v in x {
        if v > peak {
            peak = v;
        }
    }
    // Shifted exponentials and their sum, accumulated left to right.
    let mut exps = Vec::with_capacity(n);
    let mut total = 0.0f32;
    for &v in x {
        let e = (v - peak).exp();
        total += e;
        exps.push(e);
    }
    if total == 0.0 {
        // Degenerate case: fall back to a uniform distribution.
        return vec![1.0 / n as f32; n];
    }
    for e in &mut exps {
        *e /= total;
    }
    exps
}
/// Copies the time window `[start, end)` out of a rank-3 F32 mel tensor
/// shaped `[batch, mel_bins, time]`, returning a new tensor with the same
/// batch and mel dimensions and a `end - start` time axis.
///
/// # Errors
/// Shape error if the tensor is not F32, has rank < 3, or the range is
/// invalid (`start > end` or `end` past the time axis).
fn slice_mel_time(mel: &Tensor, start: usize, end: usize) -> Result<Tensor> {
    use trustformers_core::errors::TrustformersError;
    match mel {
        Tensor::F32(arr) => {
            let shape = arr.shape();
            // Fixed: the old code indexed shape[0..2] unconditionally,
            // panicking on tensors of rank < 3.
            if shape.len() < 3 {
                return Err(TrustformersError::shape_error(
                    "mel tensor must have rank 3".to_string(),
                ));
            }
            let (batch, mel_bins, total_time) = (shape[0], shape[1], shape[2]);
            // Fixed: `end - start` underflowed (panic) when start > end, and
            // an out-of-range `end` silently zero-padded the chunk.
            if start > end || end > total_time {
                return Err(TrustformersError::shape_error(
                    "invalid mel time slice range".to_string(),
                ));
            }
            let chunk_time = end - start;
            let flat: Vec<f32> = arr.iter().copied().collect();
            let mut chunk_data = vec![0.0f32; batch * mel_bins * chunk_time];
            for b in 0..batch {
                for m in 0..mel_bins {
                    // Row offsets hoisted out of the innermost loop.
                    let src_row = (b * mel_bins + m) * total_time;
                    let dst_row = (b * mel_bins + m) * chunk_time;
                    for t in 0..chunk_time {
                        let src_idx = src_row + start + t;
                        // Defensive: the range check above should make this
                        // always true; skip rather than panic if data is short.
                        if src_idx < flat.len() {
                            chunk_data[dst_row + t] = flat[src_idx];
                        }
                    }
                }
            }
            Tensor::from_vec(chunk_data, &[batch, mel_bins, chunk_time])
        },
        _ => Err(TrustformersError::shape_error(
            "mel tensor must be F32".to_string(),
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::whisper::config::WhisperConfig;
    use trustformers_core::tensor::Tensor;

    // Minimal model configuration so construction stays cheap in tests.
    fn tiny_config() -> WhisperConfig {
        WhisperConfig {
            num_mel_bins: 16,
            max_source_positions: 4,
            encoder_layers: 1,
            encoder_attention_heads: 2,
            d_model: 16,
            encoder_ffn_dim: 32,
            vocab_size: 16,
            max_target_positions: 8,
            decoder_layers: 1,
            decoder_attention_heads: 2,
            decoder_ffn_dim: 32,
            ..WhisperConfig::default()
        }
    }

    // Constant-valued mel tensor of shape [1, num_mel_bins, time_frames].
    fn make_mel(cfg: &WhisperConfig, time_frames: usize) -> Tensor {
        let data = vec![0.1_f32; cfg.num_mel_bins * time_frames];
        Tensor::from_vec(data, &[1, cfg.num_mel_bins, time_frames])
            .expect("mel tensor creation should succeed")
    }

    #[test]
    fn test_timestamp_new() {
        let ts = WhisperTimestamp::new(100.0, 300.0, "hello");
        assert!((ts.start_ms - 100.0).abs() < 1e-4, "start_ms should be 100");
        assert!((ts.end_ms - 300.0).abs() < 1e-4, "end_ms should be 300");
        assert_eq!(ts.text, "hello");
    }

    #[test]
    fn test_timestamp_duration() {
        let ts = WhisperTimestamp::new(200.0, 800.0, "test");
        assert!(
            (ts.duration_ms() - 600.0).abs() < 1e-3,
            "duration should be 600ms"
        );
    }

    #[test]
    fn test_timestamp_display() {
        let ts = WhisperTimestamp::new(0.0, 1000.0, "word");
        let s = format!("{ts}");
        assert!(s.contains("word"), "display should include the text");
    }

    #[test]
    fn test_speech_task_creation() {
        let cfg = tiny_config();
        SpeechRecognitionTask::new(cfg).expect("SpeechRecognitionTask creation should succeed");
    }

    #[test]
    fn test_config_d_model_divisible_by_heads() {
        let cfg = tiny_config();
        assert_eq!(
            cfg.d_model % cfg.encoder_attention_heads,
            0,
            "d_model must be divisible by encoder_attention_heads"
        );
    }

    #[test]
    fn test_default_config_num_mel_bins() {
        let cfg = WhisperConfig::default();
        assert_eq!(cfg.num_mel_bins, 80, "default num_mel_bins should be 80");
    }

    #[test]
    fn test_default_config_vocab_size() {
        let cfg = WhisperConfig::default();
        assert_eq!(
            cfg.vocab_size, 51865,
            "default Whisper vocab_size should be 51865"
        );
    }

    #[test]
    fn test_transcribe_greedy_empty_mel_fails() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let empty_mel = Tensor::from_vec(vec![0.0_f32; 0], &[1, cfg.num_mel_bins, 0])
            .expect("tensor creation should succeed");
        let result = task.transcribe_greedy(&empty_mel, 0, 10);
        assert!(
            matches!(result, Err(WhisperError::EmptyInput)),
            "empty mel should return EmptyInput error"
        );
    }

    #[test]
    fn test_transcribe_greedy_returns_string() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        // Smoke test only: with untrained weights either outcome is acceptable,
        // we just require that the call does not panic.
        match task.transcribe_greedy(&mel, 0, 5) {
            Ok(_) => {
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_transcribe_beam_zero_size_fails() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        let result = task.transcribe_beam(&mel, 0, 0, 5);
        assert!(
            matches!(result, Err(WhisperError::InvalidBeamSize)),
            "beam_size=0 should return InvalidBeamSize error"
        );
    }

    #[test]
    fn test_transcribe_beam_returns_hypotheses() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        // Only assert hypothesis count when the forward pass succeeds.
        match task.transcribe_beam(&mel, 0, 2, 5) {
            Ok(hypotheses) => {
                assert!(
                    !hypotheses.is_empty(),
                    "beam search should produce at least one hypothesis"
                );
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_beam_hypotheses_count_at_most_beam_size() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        let beam_size = 3;
        match task.transcribe_beam(&mel, 0, beam_size, 5) {
            Ok(hypotheses) => {
                assert!(
                    hypotheses.len() <= beam_size,
                    "number of hypotheses must not exceed beam_size"
                );
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_detect_language_empty_fails() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let empty = Tensor::from_vec(vec![], &[1, cfg.num_mel_bins, 0])
            .expect("tensor creation should succeed");
        let result = task.detect_language(&empty);
        assert!(
            matches!(result, Err(WhisperError::EmptyInput)),
            "empty input should fail with EmptyInput"
        );
    }

    #[test]
    fn test_detect_language_returns_top5() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        match task.detect_language(&mel) {
            Ok(detections) => {
                assert!(
                    !detections.is_empty(),
                    "language detection should return results"
                );
                assert!(detections.len() <= 5, "should return at most 5 detections");
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_detect_language_probs_sum_to_one() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        match task.detect_language(&mel) {
            Ok(detections) => {
                // Top-5 is a subset of the full distribution, so the sum is
                // bounded above by 1 rather than equal to it.
                let total: f32 = detections.iter().map(|(_, p)| p).sum();
                assert!(
                    total <= 1.0 + 1e-4,
                    "top-5 probs must sum to <= 1.0, got {total}"
                );
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_transcribe_with_timestamps_empty_fails() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let empty = Tensor::from_vec(vec![], &[1, cfg.num_mel_bins, 0])
            .expect("tensor creation should succeed");
        let result = task.transcribe_with_timestamps(&empty, 0, 30, 5);
        assert!(
            matches!(result, Err(WhisperError::EmptyInput)),
            "empty mel should return EmptyInput"
        );
    }

    #[test]
    fn test_transcribe_with_timestamps_returns_chunks() {
        let cfg = tiny_config();
        let task = SpeechRecognitionTask::new(cfg.clone()).expect("task creation should succeed");
        let mel = make_mel(&cfg, 4);
        match task.transcribe_with_timestamps(&mel, 0, 2, 5) {
            Ok(timestamps) => {
                assert_eq!(timestamps.len(), 2, "4 frames / chunk_size 2 -> 2 segments");
            },
            Err(_) => {
            },
        }
    }

    #[test]
    fn test_audio_classification_creation() {
        let cfg = tiny_config();
        WhisperForAudioClassification::new(cfg, 5)
            .expect("audio classification model creation should succeed");
    }

    #[test]
    fn test_audio_classification_num_labels() {
        let cfg = tiny_config();
        let clf = WhisperForAudioClassification::new(cfg, 7).expect("creation should succeed");
        assert_eq!(clf.num_labels(), 7, "num_labels should be 7");
    }

    #[test]
    fn test_softmax_f32_sums_to_one() {
        let logits = vec![1.0_f32, 2.0, 3.0];
        let probs = softmax_f32(&logits);
        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn test_softmax_f32_empty_returns_empty() {
        let probs = softmax_f32(&[]);
        assert!(probs.is_empty(), "softmax of empty slice should be empty");
    }
}