use std::{fmt::Display, time::Instant};
use log::{info, trace};
use strum::EnumIter;
use transcript::{Transcript, Utterance};
use whisper_rs::{
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperError,
};
mod tests;
mod transcode;
pub mod transcript;
/// A loaded Whisper speech-to-text model.
///
/// Construct with [`Model::new`] (from a file on disk) or
/// [`Model::download`] (fetched over HTTP), then transcribe audio via
/// [`Model::transcribe_audio`] or [`Model::transcribe_pcm_s16le`].
pub struct Model {
    // Handle to the underlying whisper.cpp context created by whisper_rs.
    context: WhisperContext,
}
impl Model {
pub fn new(path: &str) -> Result<Self, WhisperError> {
trace!("Loading model {}", path);
let path_converted = std::path::Path::new(path);
if !path_converted.exists() {
return Err(WhisperError::InitError);
}
let params: WhisperContextParameters = WhisperContextParameters::default();
Ok({
Self {
context: WhisperContext::new_with_params(path, params)?,
}
})
}
pub fn download(model: &ModelType) -> Result<Self, ModelError> {
trace!("Downloading model {}", model);
let resp = ureq::get(&model.to_string())
.call()
.map_err(|e| ModelError::DownloadError(Box::new(e)))?;
assert!(resp.has("Content-Length"));
let len: usize = resp
.header("Content-Length")
.unwrap()
.parse()
.unwrap_or_default();
trace!("Model length: {}", len);
let mut bytes: Vec<u8> = Vec::with_capacity(len);
resp.into_reader()
.read_to_end(&mut bytes)
.map_err(ModelError::IoError)?;
assert_eq!(bytes.len(), len);
info!("Downloaded model: {}", model);
let params: WhisperContextParameters = WhisperContextParameters::default();
Ok({
Self {
context: WhisperContext::new_from_buffer_with_params(&bytes, params)
.map_err(ModelError::WhisperError)?,
}
})
}
pub fn transcribe_audio(
&self,
audio: impl AsRef<[u8]>,
translate: bool,
word_timestamps: bool,
initial_prompt: Option<&str>,
language: Option<&str>,
threads: Option<u16>,
) -> Result<Transcript, ModelError> {
trace!("Decoding audio.");
let samples = transcode::decode(audio.as_ref().to_vec())?;
trace!("Transcribing audio.");
self.transcribe_pcm_s16le(
&samples,
translate,
word_timestamps,
initial_prompt,
language,
threads,
)
}
pub fn transcribe_pcm_s16le(
&self,
audio: &[f32],
translate: bool,
word_timestamps: bool,
initial_prompt: Option<&str>,
language: Option<&str>,
threads: Option<u16>,
) -> Result<Transcript, ModelError> {
trace!(
"Transcribing audio: {} with translate: {translate} and timestamps: {word_timestamps}",
audio.len()
);
let mut params = FullParams::new(SamplingStrategy::BeamSearch {
beam_size: 5,
patience: 1.0,
});
if let Some(prompt) = initial_prompt {
params.set_initial_prompt(prompt);
}
params.set_language(language);
params.set_translate(translate);
params.set_print_special(false);
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
params.set_token_timestamps(word_timestamps);
params.set_split_on_word(true);
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
let threads = threads.map_or_else(|| num_cpus::get() as i32, i32::from);
trace!("Using {} threads", threads);
params.set_n_threads(threads);
let st = Instant::now();
let mut state = self.context.create_state().expect("failed to create state");
trace!("Transcribing audio with WhisperState");
state.full(params, audio).expect("failed to transcribe");
let num_segments = state.full_n_segments().expect("failed to get segments");
trace!("Number of segments: {}", num_segments);
let mut words = Vec::new();
let mut utterances = Vec::new();
for segment_idx in 0..num_segments {
let text = state
.full_get_segment_text(segment_idx)
.map_err(ModelError::WhisperError)?;
let start = state
.full_get_segment_t0(segment_idx)
.map_err(ModelError::WhisperError)?;
let stop = state
.full_get_segment_t1(segment_idx)
.map_err(ModelError::WhisperError)?;
utterances.push(Utterance { start, stop, text });
if !word_timestamps {
trace!("Skipping word timestamps");
continue;
}
trace!("Getting word timestamps for segment {}", segment_idx);
let num_tokens = state
.full_n_tokens(segment_idx)
.map_err(ModelError::WhisperError)?;
for t in 0..num_tokens {
let text = state
.full_get_token_text(segment_idx, t)
.map_err(ModelError::WhisperError)?;
let token_data = state
.full_get_token_data(segment_idx, t)
.map_err(ModelError::WhisperError)?;
if text.starts_with("[_") {
continue;
}
words.push(Utterance {
text,
start: token_data.t0,
stop: token_data.t1,
});
}
}
Ok(Transcript {
utterances,
processing_time: Instant::now().duration_since(st),
word_utterances: if word_timestamps { Some(words) } else { None },
})
}
}
#[derive(Debug)]
pub enum ModelError {
WhisperError(WhisperError),
DownloadError(Box<ureq::Error>),
IoError(std::io::Error),
AudioDecodeError,
}
/// The published ggml Whisper model variants, from smallest/fastest to
/// largest/most accurate. `En` variants are English-only models.
///
/// [`Display`] renders the Hugging Face download URL for the variant,
/// which is what [`Model::download`] fetches.
#[derive(Debug, EnumIter)]
pub enum ModelType {
    TinyEn,
    Tiny,
    BaseEn,
    Base,
    SmallEn,
    Small,
    MediumEn,
    Medium,
    LargeV1,
    LargeV2,
    LargeV3,
}
impl Display for ModelType {
    /// Formats the variant as its full Hugging Face download URL
    /// (`…/resolve/main/ggml-<name>.bin`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each variant to its ggml file name, then emit the URL in one go.
        let file_name = match self {
            Self::TinyEn => "tiny.en.bin",
            Self::Tiny => "tiny.bin",
            Self::BaseEn => "base.en.bin",
            Self::Base => "base.bin",
            Self::SmallEn => "small.en.bin",
            Self::Small => "small.bin",
            Self::MediumEn => "medium.en.bin",
            Self::Medium => "medium.bin",
            Self::LargeV1 => "large-v1.bin",
            Self::LargeV2 => "large-v2.bin",
            Self::LargeV3 => "large-v3.bin",
        };
        write!(
            f,
            "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-{file_name}"
        )
    }
}