use std::fmt::{Debug, Formatter};
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use anyhow::{anyhow, Result};
use mel_spec::mel::{log_mel_spectrogram, mel, norm_mel};
use mel_spec::stft::Spectrogram;
use ndarray::{s, stack, Array2, Axis};
use serde::Deserialize;
pub use super::sys::WhisperOptions;
use super::tokenizers::hf;
use super::{sys, Config, Tokenizer};
/// File name of the Hugging Face preprocessor configuration shipped in a
/// converted model directory.
const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
/// High-level Whisper speech-to-text model.
///
/// Bundles the low-level `sys::Whisper` handle, the Hugging Face tokenizer
/// used to decode generated token ids, and the audio preprocessing
/// configuration read from the model directory.
pub struct Whisper {
// Low-level model bindings (see `sys`).
whisper: sys::Whisper,
// Tokenizer used to turn generated token sequences back into strings.
tokenizer: hf::Tokenizer,
// Parameters loaded from `preprocessor_config.json` (FFT sizes, mel filters, …).
config: PreprocessorConfig,
}
impl Whisper {
/// Loads a Whisper model, its tokenizer, and its preprocessor configuration
/// from the directory at `model_path`.
///
/// # Errors
/// Fails if the model, the tokenizer files, or `preprocessor_config.json`
/// cannot be read from the directory.
pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
Ok(Self {
whisper: sys::Whisper::new(&model_path, config)?,
tokenizer: hf::Tokenizer::new(&model_path)?,
config: PreprocessorConfig::read(model_path.as_ref().join(PREPROCESSOR_CONFIG_FILE))?,
})
}
/// Transcribes `samples`, returning one decoded string per audio chunk.
///
/// `samples` is split into chunks of `n_samples` each; every chunk becomes
/// one batch entry of a mel spectrogram, which is handed to the underlying
/// model. (Presumably mono audio at `sampling_rate()` — confirm with callers.)
///
/// * `language` — language code used in the `<|..|>` prompt token (e.g.
///   `"en"`); when `None`, the model's own language detection is used.
/// * `timestamp` — when `false`, the `<|notimestamps|>` token is appended so
///   the model omits timestamp tokens.
/// * `options` — generation options forwarded to the low-level model.
///
/// # Errors
/// Fails if language detection, generation, or token decoding fails.
pub fn generate(
&self,
samples: &[f32],
language: Option<&str>,
timestamp: bool,
options: &WhisperOptions,
) -> Result<Vec<String>> {
// Short-time Fourier transform producing one FFT frame per hop.
let mut stft = Spectrogram::new(self.config.n_fft, self.config.hop_length);
let mut mel_spectrogram_vec = vec![];
// Each `n_samples`-sized chunk becomes one batch entry; the final chunk may
// be shorter, leaving trailing frames zero-padded.
for chunk in samples.chunks(self.config.n_samples) {
// (feature_size, nb_max_frames) array, zero-initialized so that
// unwritten columns act as padding.
let mut mel_spectrogram_per_chunk =
Array2::zeros((self.config.feature_size, self.config.nb_max_frames));
// NOTE(review): `flame` looks like a typo for `frame`.
for (i, flame) in chunk.chunks(self.config.hop_length).enumerate() {
// `add` yields a frame only once the STFT window has filled up.
if let Some(fft_frame) = stft.add(flame) {
// Mel projection + log + normalization (see `mel_spec` crate),
// narrowed from f64 to f32 for the model input.
let mel = norm_mel(&log_mel_spectrogram(&fft_frame, &self.config.mel_filters))
.mapv(|v| v as f32);
// Write the single-column mel frame into column `i`.
mel_spectrogram_per_chunk
.slice_mut(s![.., i])
.assign(&mel.slice(s![.., 0]));
}
}
mel_spectrogram_vec.push(mel_spectrogram_per_chunk);
}
// Stack the per-chunk arrays into a (batch, feature, frame) tensor.
let mut mel_spectrogram = stack(
Axis(0),
&mel_spectrogram_vec
.iter()
.map(|a| a.view())
.collect::<Vec<_>>(),
)?;
// `as_slice_mut` below requires contiguous row-major storage.
if !mel_spectrogram.is_standard_layout() {
mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned()
}
let shape = mel_spectrogram.shape().to_vec();
// Zero-copy view over the spectrogram buffer handed to the C++ side.
let storage_view = sys::StorageView::new(
&shape,
mel_spectrogram.as_slice_mut().unwrap(),
Default::default(),
)?;
// Resolve the language token: caller-supplied language wins; otherwise take
// the top detection result for the first batch entry.
let lang_token = match language {
Some(lang) => {
format!("<|{}|>", lang)
}
None => {
let detection_result = self.whisper.detect_language(&storage_view)?;
detection_result
.into_iter()
.next()
.ok_or_else(|| anyhow!("failed to detect language"))?
.into_iter()
.next()
.ok_or_else(|| anyhow!("failed to detect language"))?
.language
}
};
// Standard Whisper decoding prompt; the same prompt is used for every chunk.
let mut prompt = vec!["<|startoftranscript|>", &lang_token, "<|transcribe|>"];
if !timestamp {
prompt.push("<|notimestamps|>");
}
self.whisper
.generate(
&storage_view,
&vec![prompt; mel_spectrogram_vec.len()],
options,
)?
.into_iter()
.map(|res| {
// Only the first (best) sequence per batch entry is decoded.
let r = res
.sequences
.into_iter()
.next()
.ok_or_else(|| anyhow!("failed to transcribe samples"))?;
self.tokenizer.decode(r)
})
.collect()
}
/// Sampling rate (Hz) the model expects, from the preprocessor config.
pub fn sampling_rate(&self) -> usize {
self.config.sampling_rate
}
/// Number of samples per processing chunk, from the preprocessor config.
pub fn n_samples(&self) -> usize {
self.config.n_samples
}
/// Whether the loaded model supports multiple languages.
#[inline]
pub fn is_multilingual(&self) -> bool {
self.whisper.is_multilingual()
}
/// Number of languages the model supports.
#[inline]
pub fn num_languages(&self) -> usize {
self.whisper.num_languages()
}
/// Number of batches waiting in the work queue.
#[inline]
pub fn num_queued_batches(&self) -> usize {
self.whisper.num_queued_batches()
}
/// Number of batches currently being processed.
#[inline]
pub fn num_active_batches(&self) -> usize {
self.whisper.num_active_batches()
}
/// Number of model replicas available for parallel execution.
#[inline]
pub fn num_replicas(&self) -> usize {
self.whisper.num_replicas()
}
}
impl Debug for Whisper {
    /// Delegates to the debug representation of the underlying `sys::Whisper`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(&self.whisper, f)
    }
}
#[derive(Debug)]
#[allow(dead_code)]
// Audio preprocessing parameters parsed from `preprocessor_config.json`.
// Field names mirror the Hugging Face preprocessor config schema.
struct PreprocessorConfig {
// Chunk length — presumably in seconds, as in HF configs; confirm.
chunk_length: usize,
// Feature extractor class name recorded in the config file.
feature_extractor_type: String,
// Number of mel bins per frame (rows of the mel filter bank).
feature_size: usize,
// Hop length in samples between successive STFT frames.
hop_length: usize,
// FFT window size in samples.
n_fft: usize,
// Number of audio samples per processing chunk.
n_samples: usize,
// Maximum number of STFT frames per chunk.
nb_max_frames: usize,
// Padding side recorded in the config (unused by this module's code).
padding_side: String,
// Padding value recorded in the config (unused by this module's code).
padding_value: f32,
// Processor class name recorded in the config file.
processor_class: String,
// Whether an attention mask should be returned (unused here).
return_attention_mask: bool,
// Audio sampling rate in Hz.
sampling_rate: usize,
// Mel filter bank matrix; read from the file or computed via `mel()`.
mel_filters: Array2<f64>,
}
impl PreprocessorConfig {
    /// Reads and parses a `preprocessor_config.json` file.
    ///
    /// When the file does not provide a `mel_filters` matrix, one is computed
    /// from the sampling rate, FFT size, and feature size via `mel()`.
    ///
    /// # Errors
    /// Fails if the file cannot be opened, the JSON does not match the
    /// expected schema, or a provided filter matrix has an invalid shape.
    fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
        // Mirror of the on-disk JSON document; `mel_filters` is optional.
        #[derive(Deserialize)]
        struct PreprocessorConfigAux {
            chunk_length: usize,
            feature_extractor_type: String,
            feature_size: usize,
            hop_length: usize,
            n_fft: usize,
            n_samples: usize,
            nb_max_frames: usize,
            padding_side: String,
            padding_value: f32,
            processor_class: String,
            return_attention_mask: bool,
            sampling_rate: usize,
            mel_filters: Option<Vec<Vec<f64>>>,
        }
        let reader = BufReader::new(File::open(path)?);
        let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;
        let mel_filters = match aux.mel_filters {
            // Filter bank shipped with the model: flatten the row vectors into
            // a dense (rows, cols) matrix.
            Some(filter_rows) => {
                let n_rows = filter_rows.len();
                let n_cols = filter_rows.first().map_or(0, |row| row.len());
                let flat: Vec<f64> = filter_rows.into_iter().flatten().collect();
                Array2::from_shape_vec((n_rows, n_cols), flat)?
            }
            // No filter bank in the file: compute one from the audio parameters.
            None => mel(
                aux.sampling_rate as f64,
                aux.n_fft,
                aux.feature_size,
                None,
                None,
                false,
                true,
            ),
        };
        Ok(Self {
            chunk_length: aux.chunk_length,
            feature_extractor_type: aux.feature_extractor_type,
            feature_size: aux.feature_size,
            hop_length: aux.hop_length,
            n_fft: aux.n_fft,
            n_samples: aux.n_samples,
            nb_max_frames: aux.nb_max_frames,
            padding_side: aux.padding_side,
            padding_value: aux.padding_value,
            processor_class: aux.processor_class,
            return_attention_mask: aux.return_attention_mask,
            sampling_rate: aux.sampling_rate,
            mel_filters,
        })
    }
}
#[cfg(test)]
#[cfg(feature = "hub")]
mod tests {
    use crate::{download_model, Config, Device, Whisper};

    // Small CTranslate2-converted Whisper model used for integration tests.
    const MODEL_ID: &str = "jkawamoto/whisper-tiny-ct2";

    /// The Debug output of `Whisper` should mention the model directory name.
    /// Ignored by default: downloads a model from the hub.
    #[test]
    #[ignore]
    fn test_whisper_debug() {
        let model_path = download_model(MODEL_ID).unwrap();
        let device = if cfg!(feature = "cuda") {
            Device::CUDA
        } else {
            Device::CPU
        };
        let config = Config {
            device,
            ..Default::default()
        };
        let w = Whisper::new(&model_path, config).unwrap();
        let dir_name = model_path.file_name().unwrap().to_str().unwrap();
        assert!(format!("{:?}", w).contains(dir_name));
    }
}