use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use anyhow::{Context, Result, anyhow};
use ct2rs::Config;
use ct2rs::sys::{Device, StorageView, Whisper as SysWhisper};
use mel_spec::mel::{log_mel_spectrogram, mel, norm_mel};
use mel_spec::stft::Spectrogram;
use ndarray::{Array2, Array3, Axis, s, stack};
use serde::Deserialize;
use tokenizers::{Decoder, Tokenizer};
/// File name of the feature-extractor configuration inside the model directory.
const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
/// File name of the Hugging Face tokenizer definition inside the model directory.
const TOKENIZER_FILE: &str = "tokenizer.json";
// Re-exported so callers can construct decoding options without depending on
// `ct2rs` directly.
pub use ct2rs::sys::WhisperOptions;
/// Transcription output for a single audio chunk.
#[derive(Clone, Debug)]
pub struct WhisperChunkResult {
    /// Decoded transcript text for the chunk.
    pub text: String,
    /// Model-reported probability that the chunk contains no speech.
    pub no_speech_prob: f32,
}
/// High-level Whisper transcription pipeline backed by a CTranslate2 model.
pub struct Whisper {
    /// Low-level CTranslate2 Whisper model.
    inner: SysWhisper,
    /// Hugging Face tokenizer used to decode generated token sequences.
    tokenizer: Tokenizer,
    /// Feature-extraction parameters loaded from `preprocessor_config.json`.
    config: PreprocessorConfig,
}
impl Whisper {
    /// Loads a converted CTranslate2 Whisper model from `model_path`.
    ///
    /// The directory must also contain `tokenizer.json` and
    /// `preprocessor_config.json`.
    ///
    /// # Errors
    /// Fails when the model, tokenizer, or preprocessor config cannot be
    /// loaded; tokenizer/config errors name the offending file path.
    pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
        let model_path = model_path.as_ref();
        let tokenizer_path = model_path.join(TOKENIZER_FILE);
        let preprocessor_path = model_path.join(PREPROCESSOR_CONFIG_FILE);
        Ok(Self {
            inner: SysWhisper::new(model_path, config)?,
            tokenizer: Tokenizer::from_file(&tokenizer_path).map_err(|error| {
                // `tokenizers` errors don't carry the path, so attach it here.
                anyhow!(
                    "failed to load tokenizer from `{}`: {error}",
                    tokenizer_path.display()
                )
            })?,
            config: PreprocessorConfig::read(&preprocessor_path).with_context(|| {
                format!(
                    "failed to load Whisper preprocessor config from `{}`",
                    preprocessor_path.display()
                )
            })?,
        })
    }

    /// Transcribes `samples` and returns one text string per audio chunk.
    ///
    /// Convenience wrapper over [`Self::generate_detailed`] that discards the
    /// per-chunk no-speech probability.
    pub fn generate(
        &self,
        samples: &[f32],
        language: Option<&str>,
        timestamp: bool,
        options: &WhisperOptions,
    ) -> Result<Vec<String>> {
        self.generate_detailed(samples, language, timestamp, options)
            .map(|results| results.into_iter().map(|result| result.text).collect())
    }

    /// Transcribes `samples`, yielding text and a no-speech probability for
    /// each `n_samples`-sized chunk of audio.
    ///
    /// * `samples` — audio samples; presumably mono PCM at
    ///   [`Self::sampling_rate`] — not validated here, confirm at call sites.
    /// * `language` — language code or token (`"en"` or `"<|en|>"`); when
    ///   `None`, the language is auto-detected from the spectrogram.
    /// * `timestamp` — when `false`, `<|notimestamps|>` is appended to the
    ///   prompt to suppress timestamp tokens.
    /// * `options` — decoding options forwarded to CTranslate2.
    ///
    /// Returns an empty vector for empty input.
    pub fn generate_detailed(
        &self,
        samples: &[f32],
        language: Option<&str>,
        timestamp: bool,
        options: &WhisperOptions,
    ) -> Result<Vec<WhisperChunkResult>> {
        if samples.is_empty() {
            return Ok(Vec::new());
        }
        let mut prepared = self.prepare_mel_spectrogram(samples)?;
        // Capture the shape before the mutable borrow below.
        let shape = prepared.mel_spectrogram.shape().to_vec();
        let storage = StorageView::new(
            &shape,
            prepared
                .mel_spectrogram
                // Succeeds because `prepare_mel_spectrogram` guarantees a
                // standard-layout (contiguous) array.
                .as_slice_mut()
                .context("failed to access mel spectrogram buffer")?,
            Device::CPU,
        )?;
        let language_token = match language {
            Some(language) => normalize_language_token(language),
            None => self.detect_language(&storage)?,
        };
        // Standard Whisper decoder prompt: start token, language, task.
        let mut prompt = vec![
            String::from("<|startoftranscript|>"),
            language_token,
            String::from("<|transcribe|>"),
        ];
        if !timestamp {
            prompt.push(String::from("<|notimestamps|>"));
        }
        // Every chunk in the batch receives the same prompt.
        let prompts = vec![prompt; prepared.batch_count];
        self.inner
            .generate(&storage, &prompts, options)?
            .into_iter()
            .map(|result| {
                let no_speech_prob = result.no_speech_prob;
                Ok(WhisperChunkResult {
                    text: self.decode_result(result)?,
                    no_speech_prob,
                })
            })
            .collect()
    }

    /// Sampling rate (Hz) the model expects, from the preprocessor config.
    pub fn sampling_rate(&self) -> usize {
        self.config.sampling_rate
    }

    /// Number of audio samples per chunk, from the preprocessor config.
    pub fn n_samples(&self) -> usize {
        self.config.n_samples
    }

    /// Detects the spoken language from a mel spectrogram, returning the
    /// top-ranked language token of the first batch item.
    fn detect_language(&self, storage: &StorageView<'_>) -> Result<String> {
        self.inner
            .detect_language(storage)?
            .into_iter()
            .next()
            .and_then(|result| result.into_iter().next())
            .map(|result| result.language)
            .ok_or_else(|| anyhow!("failed to detect language"))
    }

    /// Decodes the first (best) hypothesis of a generation result into text.
    fn decode_result(&self, result: ct2rs::sys::WhisperGenerationResult) -> Result<String> {
        let tokens = result
            .sequences
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("failed to decode empty Whisper sequence"))?;
        let decoder = self
            .tokenizer
            .get_decoder()
            .ok_or_else(|| anyhow!("tokenizer does not provide a decoder"))?;
        decoder
            .decode(tokens)
            .map_err(|error| anyhow!("failed to decode Whisper tokens: {error}"))
    }

    /// Converts raw samples into a batched log-mel spectrogram of shape
    /// `(chunks, feature_size, nb_max_frames)` in standard (contiguous)
    /// layout.
    fn prepare_mel_spectrogram(&self, samples: &[f32]) -> Result<PreparedMelSpectrogram> {
        // A single STFT instance is reused for all chunks, so window state
        // carries over chunk boundaries of the continuous audio stream.
        let mut stft = Spectrogram::new(self.config.n_fft, self.config.hop_length);
        let mut mel_batches = Vec::new();
        // `.max(1)` guards against `n_samples == 0`, which would make
        // `chunks` panic.
        for chunk in samples.chunks(self.config.n_samples.max(1)) {
            // Pre-zeroed, so missing trailing frames act as zero padding.
            let mut mel_per_chunk =
                Array2::zeros((self.config.feature_size, self.config.nb_max_frames));
            for (frame_index, frame) in chunk.chunks(self.config.hop_length).enumerate() {
                // Drop any frames beyond the model's fixed frame budget.
                if frame_index >= self.config.nb_max_frames {
                    break;
                }
                // `add` yields a spectrum only once a full FFT window has
                // accumulated; until then it returns `None` and the zero
                // column is kept.
                if let Some(fft_frame) = stft.add(frame) {
                    let mel = norm_mel(&log_mel_spectrogram(&fft_frame, &self.config.mel_filters))
                        .mapv(|value| value as f32);
                    mel_per_chunk
                        .slice_mut(s![.., frame_index])
                        .assign(&mel.slice(s![.., 0]));
                }
            }
            mel_batches.push(mel_per_chunk);
        }
        // Stack the per-chunk matrices along a new leading batch axis.
        let mut mel_spectrogram = stack(
            Axis(0),
            &mel_batches
                .iter()
                .map(|batch| batch.view())
                .collect::<Vec<_>>(),
        )?;
        // `StorageView` needs a contiguous buffer (`as_slice_mut` in the
        // caller), so normalize the layout if stacking left it non-standard.
        if !mel_spectrogram.is_standard_layout() {
            mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned();
        }
        Ok(PreparedMelSpectrogram {
            batch_count: mel_batches.len(),
            mel_spectrogram,
        })
    }
}
/// Wraps a language code in Whisper's special-token delimiters.
///
/// Already-delimited tokens such as `<|en|>` pass through unchanged; a bare
/// code such as `en` (surrounding whitespace ignored) becomes `<|en|>`.
fn normalize_language_token(language: &str) -> String {
    let code = language.trim();
    let already_wrapped = code.starts_with("<|") && code.ends_with("|>");
    if already_wrapped {
        code.to_owned()
    } else {
        ["<|", code, "|>"].concat()
    }
}
/// Batched mel spectrogram ready to be wrapped in a CTranslate2 `StorageView`.
struct PreparedMelSpectrogram {
    /// Number of audio chunks (the length of `mel_spectrogram`'s leading axis).
    batch_count: usize,
    /// Features with shape `(batch_count, feature_size, nb_max_frames)`.
    mel_spectrogram: Array3<f32>,
}
/// Whisper feature-extractor settings, loaded from `preprocessor_config.json`.
///
/// Mirrors the Hugging Face preprocessor schema; several fields are kept only
/// so deserialization matches the file and are unused here (hence
/// `allow(dead_code)`).
#[derive(Debug)]
#[allow(dead_code)]
struct PreprocessorConfig {
    // Retained from the JSON config; not read by this module.
    chunk_length: usize,
    // Retained from the JSON config; not read by this module.
    feature_extractor_type: String,
    // Number of mel bins (rows of each spectrogram frame).
    feature_size: usize,
    // STFT hop length in samples.
    hop_length: usize,
    // FFT window size in samples.
    n_fft: usize,
    // Samples per audio chunk fed to the model.
    n_samples: usize,
    // Maximum spectrogram frames per chunk; excess frames are dropped.
    nb_max_frames: usize,
    // Retained from the JSON config; not read by this module.
    padding_side: String,
    // Retained from the JSON config; not read by this module.
    padding_value: f32,
    // Retained from the JSON config; not read by this module.
    processor_class: String,
    // Retained from the JSON config; not read by this module.
    return_attention_mask: bool,
    // Expected audio sampling rate in Hz.
    sampling_rate: usize,
    // Mel filter bank, `(feature_size, n_fft bins)` — deserialized from the
    // config when present, otherwise computed. TODO confirm exact column
    // count against `mel_spec::mel::mel`.
    mel_filters: Array2<f64>,
}
impl PreprocessorConfig {
    /// Reads a Whisper feature-extractor configuration from a JSON file.
    ///
    /// When the JSON supplies a serialized `mel_filters` matrix it is used
    /// directly (after checking that every row has the same length);
    /// otherwise a mel filter bank is computed from `sampling_rate`,
    /// `n_fft`, and `feature_size`.
    ///
    /// # Errors
    /// Fails when the file cannot be opened, the JSON does not match the
    /// expected schema, or the serialized filter bank is ragged.
    fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        // Mirror of the on-disk schema; `mel_filters` is optional because
        // exported configs may omit the precomputed filter bank.
        #[derive(Deserialize)]
        struct PreprocessorConfigAux {
            chunk_length: usize,
            feature_extractor_type: String,
            feature_size: usize,
            hop_length: usize,
            n_fft: usize,
            n_samples: usize,
            nb_max_frames: usize,
            padding_side: String,
            padding_value: f32,
            processor_class: String,
            return_attention_mask: bool,
            sampling_rate: usize,
            mel_filters: Option<Vec<Vec<f64>>>,
        }
        let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;
        let mel_filters = if let Some(mel_filters) = aux.mel_filters {
            let rows = mel_filters.len();
            let cols = mel_filters.first().map(|row| row.len()).unwrap_or_default();
            // `from_shape_vec` only checks the *total* element count, so a
            // ragged matrix whose row lengths happen to sum to
            // `rows * cols` would be silently misinterpreted. Reject any
            // row that deviates from the first row's length.
            if mel_filters.iter().any(|row| row.len() != cols) {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "mel_filters rows have inconsistent lengths",
                )
                .into());
            }
            Array2::from_shape_vec((rows, cols), mel_filters.into_iter().flatten().collect())?
        } else {
            // No serialized filter bank: derive one from the config values.
            mel(
                aux.sampling_rate as f64,
                aux.n_fft,
                aux.feature_size,
                None,
                None,
                false,
                true,
            )
        };
        Ok(Self {
            chunk_length: aux.chunk_length,
            feature_extractor_type: aux.feature_extractor_type,
            feature_size: aux.feature_size,
            hop_length: aux.hop_length,
            n_fft: aux.n_fft,
            n_samples: aux.n_samples,
            nb_max_frames: aux.nb_max_frames,
            padding_side: aux.padding_side,
            padding_value: aux.padding_value,
            processor_class: aux.processor_class,
            return_attention_mask: aux.return_attention_mask,
            sampling_rate: aux.sampling_rate,
            mel_filters,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::normalize_language_token;

    #[test]
    fn preserves_detected_language_tokens() {
        assert_eq!(normalize_language_token("<|en|>"), "<|en|>");
    }

    #[test]
    fn wraps_plain_language_codes() {
        assert_eq!(normalize_language_token("en"), "<|en|>");
    }

    // Coverage gap in the original suite: the `trim()` step was untested.
    #[test]
    fn trims_surrounding_whitespace() {
        assert_eq!(normalize_language_token("  en \n"), "<|en|>");
        assert_eq!(normalize_language_token(" <|fr|> "), "<|fr|>");
    }
}