use std::path::Path;
use anyhow::Result;
use burn::tensor::{Int, Tensor, TensorData, backend::Backend};
use crate::{
audio::{
constants::{N_FRAMES, N_SAMPLES},
mel_spectrogram::log_mel_spectrogram,
preprocessing::load_audio,
},
model::{config::WhisperConfig, whisper::WhisperModel},
tokenizer::whisper_tokenizer::WhisperTokenizer,
};
/// Speech-to-text driver that pairs a Whisper model with its tokenizer.
///
/// Generic over the `burn` backend `B`, which determines the tensor
/// implementation and the device type tensors are placed on.
pub struct Transcriber<B>
where
    B: Backend,
{
    /// Model whose `encode` / `decode` methods are invoked during transcription.
    model: WhisperModel<B>,
    /// Supplies the initial decoding prompt and the end-of-text id, and turns
    /// generated token ids back into text.
    tokenizer: WhisperTokenizer,
    /// Model hyper-parameters; `num_mel_bins` and `max_target_positions` are
    /// read while transcribing.
    config: WhisperConfig,
    /// Device on which the input tensors are created.
    device: B::Device,
}
impl<B: Backend> Transcriber<B> {
pub fn new(
model: WhisperModel<B>,
tokenizer: WhisperTokenizer,
config: WhisperConfig,
device: B::Device,
) -> Self {
Self {
model,
tokenizer,
config,
device,
}
}
pub fn transcribe<P: AsRef<Path>>(&self, path: P, language: Option<&str>) -> Result<String> {
let samples = load_audio(path)?;
let mut parts = Vec::new();
for chunk in samples.chunks(N_SAMPLES) {
let text = self.transcribe_chunk(chunk, language)?;
let trimmed = text.trim().to_string();
if !trimmed.is_empty() {
parts.push(trimmed);
}
}
Ok(parts.join(" "))
}
fn transcribe_chunk(&self, samples: &[f32], language: Option<&str>) -> Result<String> {
let mel = log_mel_spectrogram(samples, self.config.num_mel_bins)?;
assert_eq!(mel.len(), self.config.num_mel_bins);
assert_eq!(mel.first().unwrap().len(), N_FRAMES);
let n_mels = mel.len();
let flat: Vec<f32> = mel.into_iter().flatten().collect();
let mel_tensor =
Tensor::<B, 3>::from_data(TensorData::new(flat, [1, n_mels, N_FRAMES]), &self.device);
let encoder_out = self.model.encode(mel_tensor);
let tokens = self.greedy_decode(encoder_out, language)?;
self.tokenizer.decode(&tokens)
}
fn greedy_decode(&self, encoder_out: Tensor<B, 3>, language: Option<&str>) -> Result<Vec<u32>> {
let initial = self.tokenizer.initial_tokens(language);
let prompt_len = initial.len();
let mut tokens: Vec<i64> = initial.iter().map(|&t| t as i64).collect();
for _ in 0..(self.config.max_target_positions - prompt_len) {
let seq_len = tokens.len();
let token_tensor = Tensor::<B, 2, Int>::from_data(
TensorData::new(tokens.clone(), [1, seq_len]),
&self.device,
);
let logits = self.model.decode(token_tensor, encoder_out.clone());
let vocab = logits.shape().dims[2];
let last_logits = logits
.slice([0..1, (seq_len - 1)..seq_len, 0..vocab])
.into_data();
let mut vals: Vec<f32> = last_logits.to_vec::<f32>().unwrap();
vals[50363] = 0.0;
let next_tok = vals
.iter()
.copied()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, v)| {
if v > bv { (i, v) } else { (bi, bv) }
})
.0 as u32;
if next_tok == self.tokenizer.eot {
break;
}
tokens.push(next_tok as i64);
}
Ok(tokens[prompt_len..].iter().map(|&t| t as u32).collect())
}
}