use std::ffi::OsString;
use std::fmt::{Debug, Formatter};
use std::path::Path;
use anyhow::{anyhow, Result};
use cxx::UniquePtr;
use super::{
config, storage_view, vec_ffi_vecstr, Config, StorageView, VecStr, VecString, VecUSize,
};
use self::ffi::VecDetectionResult;
pub use self::ffi::{
DetectionResult, WhisperAlignmentResult, WhisperOptions, WhisperTokenAlignment,
};
use crate::sys::{model_memory_reader, ModelMemoryReader};
// FFI bridge generated by `cxx`. It declares the structs shared between Rust
// and C++ and the C++ `Whisper` type together with the functions this file
// calls on it; the C++ declarations come from `ct2rs/include/whisper.h`.
#[cxx::bridge]
mod ffi {
/// Generation options passed by reference to the bridged `generate` call.
///
/// The defaults quoted on each field are the values set by the
/// `Default` impl for `WhisperOptions` in this file.
///
/// NOTE(review): the exact effect of each option is defined by the C++
/// library behind `ct2rs/include/whisper.h`, not by this file; confirm
/// there before relying on the short field descriptions.
#[derive(Clone, Debug)]
pub struct WhisperOptions {
/// Beam size (default: 5).
pub beam_size: usize,
/// Patience factor (default: 1.0).
pub patience: f32,
/// Length penalty (default: 1.0).
pub length_penalty: f32,
/// Repetition penalty (default: 1.0).
pub repetition_penalty: f32,
/// No-repeat n-gram size (default: 0).
pub no_repeat_ngram_size: usize,
/// Maximum length (default: 448).
pub max_length: usize,
/// Sampling top-k (default: 1).
pub sampling_topk: usize,
/// Sampling temperature (default: 1.0).
pub sampling_temperature: f32,
/// Number of hypotheses (default: 1).
pub num_hypotheses: usize,
/// Whether scores are requested (default: false).
pub return_scores: bool,
/// Whether vocabulary logits are requested (default: false).
pub return_logits_vocab: bool,
/// Whether the no-speech probability is requested (default: false).
pub return_no_speech_prob: bool,
/// Maximum initial timestamp index (default: 50).
pub max_initial_timestamp_index: usize,
/// Whether blank output is suppressed (default: true).
pub suppress_blank: bool,
/// Token ids to suppress (default: `[-1]`).
/// NOTE(review): `-1` is presumably a sentinel interpreted by the C++
/// library; confirm.
pub suppress_tokens: Vec<i32>,
}
/// Bridge-level generation result as returned by the C++ `generate` call.
///
/// Converted into the public `WhisperGenerationResult` of the parent
/// module by the `From` impl defined there.
struct WhisperGenerationResult {
/// Generated sequences as wrapped string vectors.
sequences: Vec<VecString>,
/// Generated sequences as wrapped id vectors.
sequences_ids: Vec<VecUSize>,
/// Scores attached to the result (may be empty).
scores: Vec<f32>,
/// No-speech probability reported with the result.
no_speech_prob: f32,
}
/// A language string paired with a probability, as produced by
/// `detect_language`.
#[derive(PartialEq, Clone, Debug)]
pub struct DetectionResult {
/// Language identifier.
language: String,
/// Probability attached to `language`.
probability: f32,
}
/// Wrapper around a vector of `DetectionResult`, used as the element type
/// of the vector returned by `detect_language`.
///
/// NOTE(review): presumably exists because the bridge cannot express a
/// nested `Vec<Vec<_>>` directly; confirm.
#[derive(PartialEq, Clone)]
struct VecDetectionResult {
/// The wrapped detection results.
v: Vec<DetectionResult>,
}
/// One element of `WhisperAlignmentResult::alignments`.
///
/// NOTE(review): the field names suggest a (text token, audio frame)
/// index pair; confirm against the C++ header.
#[derive(Clone, Debug)]
pub struct WhisperTokenAlignment {
/// Token-side value of the pair.
pub token_x: i64,
/// Frame-side value of the pair.
pub frame_x: i64,
}
/// One element of the vector returned by the bridged `align` call.
#[derive(Debug)]
pub struct WhisperAlignmentResult {
/// Token alignments reported for this item.
pub alignments: Vec<WhisperTokenAlignment>,
/// Probabilities of the text tokens reported for this item.
pub text_token_probs: Vec<f32>,
}
unsafe extern "C++" {
// C++ declarations for this bridge, plus the header generated for the
// shared helper types.
include!("ct2rs/include/whisper.h");
include!("ct2rs/src/sys/types.rs.h");
// Types defined in sibling modules and reused here by alias.
type VecStr<'a> = super::VecStr<'a>;
type VecString = super::VecString;
type VecUSize = super::VecUSize;
type Config = super::config::ffi::Config;
type ModelMemoryReader = super::model_memory_reader::ffi::ModelMemoryReader;
type StorageView = super::storage_view::ffi::StorageView;
/// Opaque C++ Whisper object; only handled through `UniquePtr` or `&`.
type Whisper;
/// Creates a `Whisper` from a model path and a configuration.
fn whisper(model_path: &str, config: UniquePtr<Config>) -> Result<UniquePtr<Whisper>>;
/// Encodes `features` and returns the result as a new `StorageView`.
/// NOTE(review): `to_cpu` presumably requests that the output be placed
/// in host memory; confirm against the C++ header.
fn encode(
self: &Whisper,
features: &StorageView,
to_cpu: bool,
) -> Result<UniquePtr<StorageView>>;
/// Creates a `Whisper` from a model memory reader and a configuration.
fn whisper_from_memory(
model_memory_reader: Pin<&mut ModelMemoryReader>,
config: UniquePtr<Config>,
) -> Result<UniquePtr<Whisper>>;
/// Runs generation for `features` with the given `prompts` and `options`.
fn generate(
self: &Whisper,
features: &StorageView,
prompts: &[VecStr],
options: &WhisperOptions,
) -> Result<Vec<WhisperGenerationResult>>;
/// Returns, per result entry, a list of language/probability pairs.
fn detect_language(
self: &Whisper,
features: &StorageView,
) -> Result<Vec<VecDetectionResult>>;
/// Computes alignment results for the given token sequences.
fn align(
self: &Whisper,
features: &StorageView,
start_sequence: &[usize],
text_tokens: &[Vec<usize>],
num_frames: &[usize],
median_filter_width: i64,
) -> Result<Vec<WhisperAlignmentResult>>;
// Plain accessors on the C++ object.
fn is_multilingual(self: &Whisper) -> bool;
fn n_mels(self: &Whisper) -> usize;
fn num_languages(self: &Whisper) -> usize;
fn num_queued_batches(self: &Whisper) -> usize;
fn num_active_batches(self: &Whisper) -> usize;
fn num_replicas(self: &Whisper) -> usize;
}
}
impl Default for WhisperOptions {
    /// Returns the option set used when the caller customises nothing.
    ///
    /// Each value is bound to a local named after its field and the struct is
    /// then assembled with field-init shorthand.
    fn default() -> Self {
        // Search-related settings.
        let beam_size = 5;
        let patience = 1.0;
        let length_penalty = 1.0;
        let repetition_penalty = 1.0;
        let no_repeat_ngram_size = 0;
        let max_length = 448;

        // Sampling-related settings.
        let sampling_topk = 1;
        let sampling_temperature = 1.0;
        let num_hypotheses = 1;

        // Optional outputs are all switched off.
        let return_scores = false;
        let return_logits_vocab = false;
        let return_no_speech_prob = false;

        // Timestamp and suppression settings.
        let max_initial_timestamp_index = 50;
        let suppress_blank = true;
        let suppress_tokens = vec![-1];

        Self {
            beam_size,
            patience,
            length_penalty,
            repetition_penalty,
            no_repeat_ngram_size,
            max_length,
            sampling_topk,
            sampling_temperature,
            num_hypotheses,
            return_scores,
            return_logits_vocab,
            return_no_speech_prob,
            max_initial_timestamp_index,
            suppress_blank,
            suppress_tokens,
        }
    }
}
/// Owned generation result exposed by this module.
///
/// Values of this type are built from the bridge-level
/// `ffi::WhisperGenerationResult` through the `From` impl in this file, which
/// turns the wrapped vectors into plain nested `Vec`s.
#[derive(Debug, Clone)]
pub struct WhisperGenerationResult {
    /// Generated sequences, one inner vector of strings per sequence.
    pub sequences: Vec<Vec<String>>,
    /// Generated sequences, one inner vector of ids per sequence.
    pub sequences_ids: Vec<Vec<usize>>,
    /// Scores attached to the result; empty when none were produced
    /// (see `has_scores`).
    pub scores: Vec<f32>,
    /// No-speech probability copied from the bridge-level result.
    pub no_speech_prob: f32,
}
impl WhisperGenerationResult {
    /// Returns how many entries `sequences` holds.
    #[inline]
    pub fn num_sequences(&self) -> usize {
        let Self { sequences, .. } = self;
        sequences.len()
    }

    /// Reports whether `scores` contains at least one value.
    #[inline]
    pub fn has_scores(&self) -> bool {
        let no_scores = self.scores.is_empty();
        !no_scores
    }
}
impl From<ffi::WhisperGenerationResult> for WhisperGenerationResult {
    /// Converts a bridge-level result into the owned public representation.
    ///
    /// Each wrapped string/id vector is unwrapped through its own `From`
    /// conversion; `scores` and `no_speech_prob` are carried over unchanged.
    fn from(r: ffi::WhisperGenerationResult) -> Self {
        let sequences: Vec<Vec<String>> = r
            .sequences
            .into_iter()
            .map(Vec::<String>::from)
            .collect();
        let sequences_ids: Vec<Vec<usize>> = r
            .sequences_ids
            .into_iter()
            .map(Vec::<usize>::from)
            .collect();
        let scores = r.scores;
        let no_speech_prob = r.no_speech_prob;

        Self {
            sequences,
            sequences_ids,
            scores,
            no_speech_prob,
        }
    }
}
impl From<VecDetectionResult> for Vec<DetectionResult> {
fn from(value: VecDetectionResult) -> Self {
value.v
}
}
/// Safe Rust handle to the C++ `Whisper` object exposed through the cxx bridge.
///
/// Instances are created with `Whisper::new` (from a model path) or
/// `Whisper::new_from_memory` (from a `ModelMemoryReader`).
pub struct Whisper {
/// Label identifying the model: the final component of the model path when
/// built with `new`, or the reader's model id when built with
/// `new_from_memory`. Only used by the `Debug` impl within this file.
model: OsString,
/// Owning pointer to the bridged C++ object that every method forwards to.
ptr: UniquePtr<ffi::Whisper>,
}
impl Whisper {
pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
let model_path = model_path.as_ref();
Ok(Self {
model: model_path
.file_name()
.map(|s| s.to_os_string())
.unwrap_or_default(),
ptr: ffi::whisper(
model_path
.to_str()
.ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
config.to_ffi(),
)?,
})
}
pub fn encode(&self, features: &StorageView, to_cpu: bool) -> Result<StorageView<'static>> {
Ok(StorageView::from_cxx(self.ptr.encode(features, to_cpu)?))
}
pub fn new_from_memory(
model_memory_reader: &mut ModelMemoryReader,
config: Config,
) -> Result<Self> {
Ok(Self {
model: OsString::from(model_memory_reader.get_model_id()),
ptr: ffi::whisper_from_memory(model_memory_reader.pin_mut_impl(), config.to_ffi())?,
})
}
pub fn generate<T: AsRef<str>>(
&self,
features: &StorageView,
prompts: &[Vec<T>],
options: &WhisperOptions,
) -> Result<Vec<WhisperGenerationResult>> {
self.ptr
.generate(features, &vec_ffi_vecstr(prompts), options)
.map(|res| res.into_iter().map(WhisperGenerationResult::from).collect())
.map_err(|e| anyhow!("failed to generate: {e}"))
}
pub fn detect_language(&self, features: &StorageView) -> Result<Vec<Vec<DetectionResult>>> {
self.ptr
.detect_language(features)
.map(|res| res.into_iter().map(VecDetectionResult::into).collect())
.map_err(|e| anyhow!("failed to detect language: {e}"))
}
pub fn align(
&self,
encoder_output: &StorageView,
start_sequence: &[usize],
text_tokens: &[Vec<usize>],
num_frames: &[usize],
median_filter_width: i64,
) -> Result<Vec<ffi::WhisperAlignmentResult>> {
Ok(self.ptr.align(
encoder_output,
start_sequence,
text_tokens,
num_frames,
median_filter_width,
)?)
}
#[inline]
pub fn is_multilingual(&self) -> bool {
self.ptr.is_multilingual()
}
#[inline]
pub fn n_mels(&self) -> usize {
self.ptr.n_mels()
}
#[inline]
pub fn num_languages(&self) -> usize {
self.ptr.num_languages()
}
#[inline]
pub fn num_queued_batches(&self) -> usize {
self.ptr.num_queued_batches()
}
#[inline]
pub fn num_active_batches(&self) -> usize {
self.ptr.num_active_batches()
}
#[inline]
pub fn num_replicas(&self) -> usize {
self.ptr.num_replicas()
}
}
impl Debug for Whisper {
    /// Formats the handle as a `Whisper { .. }` struct showing the model label
    /// followed by the values of the accessor methods, queried at call time.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Whisper");
        out.field("model", &self.model);
        out.field("multilingual", &self.is_multilingual());
        out.field("mels", &self.n_mels());
        out.field("languages", &self.num_languages());
        out.field("queued_batches", &self.num_queued_batches());
        out.field("active_batches", &self.num_active_batches());
        out.field("replicas", &self.num_replicas());
        out.finish()
    }
}
// Custom teardown compiled only on Windows; on other targets the `UniquePtr`
// field is dropped normally.
#[cfg(target_os = "windows")]
impl Drop for Whisper {
fn drop(&mut self) {
// Take the pointer out of `self` and leave a null `UniquePtr` behind, so the
// field's own destructor has nothing left to release afterwards.
let ptr = std::mem::replace(&mut self.ptr, UniquePtr::null());
// NOTE(review): `into_raw` gives up ownership as a raw pointer and
// `drop_in_place` only runs the Rust-side drop glue of the opaque
// `ffi::Whisper` type. Presumably that does not invoke the C++ destructor,
// so the native object looks intentionally leaked on Windows; confirm the
// reason (e.g. a crash during destruction) and record it here.
// NOTE(review): `drop_in_place` requires a non-null pointer; this relies on
// `self.ptr` never being null when `drop` runs, which appears to hold
// because both constructors only store a pointer returned by a successful
// bridge call. Confirm that those calls never return null.
unsafe {
std::ptr::drop_in_place(ptr.into_raw());
}
}
}
// Marks the opaque C++ `Whisper` type as transferable to, and shareable
// between, threads; this is what lets `UniquePtr<ffi::Whisper>` (and so the
// `Whisper` wrapper) cross thread boundaries.
// SAFETY: NOTE(review): nothing in this file proves the C++ object is
// thread-safe. These impls presumably rely on the C++ library guarding its
// own state for the `&self` methods declared in the bridge; confirm against
// the C++ implementation.
unsafe impl Send for ffi::Whisper {}
unsafe impl Sync for ffi::Whisper {}
#[cfg(test)]
mod tests {
    use super::{ffi, WhisperGenerationResult, WhisperOptions};

    /// Pins every value produced by `WhisperOptions::default()`.
    #[test]
    fn test_default_options() {
        let WhisperOptions {
            beam_size,
            patience,
            length_penalty,
            repetition_penalty,
            no_repeat_ngram_size,
            max_length,
            sampling_topk,
            sampling_temperature,
            num_hypotheses,
            return_scores,
            return_logits_vocab,
            return_no_speech_prob,
            max_initial_timestamp_index,
            suppress_blank,
            suppress_tokens,
        } = WhisperOptions::default();

        assert_eq!(beam_size, 5);
        assert_eq!(patience, 1.);
        assert_eq!(length_penalty, 1.);
        assert_eq!(repetition_penalty, 1.);
        assert_eq!(no_repeat_ngram_size, 0);
        assert_eq!(max_length, 448);
        assert_eq!(sampling_topk, 1);
        assert_eq!(sampling_temperature, 1.);
        assert_eq!(num_hypotheses, 1);
        assert!(!return_scores);
        assert!(!return_logits_vocab);
        assert!(!return_no_speech_prob);
        assert_eq!(max_initial_timestamp_index, 50);
        assert!(suppress_blank);
        assert_eq!(suppress_tokens, vec![-1]);
    }

    /// A populated bridge-level result converts field by field.
    #[test]
    fn test_generation_result() {
        let sequences: Vec<Vec<String>> = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
        ];
        let sequences_ids: Vec<Vec<usize>> = vec![vec![1, 2], vec![5, 6, 7]];
        let scores: Vec<f32> = vec![9., 8., 7.];
        let no_speech_prob: f32 = 10.;

        let raw = ffi::WhisperGenerationResult {
            sequences: sequences
                .iter()
                .cloned()
                .map(ffi::VecString::from)
                .collect(),
            sequences_ids: sequences_ids
                .iter()
                .cloned()
                .map(ffi::VecUSize::from)
                .collect(),
            scores: scores.clone(),
            no_speech_prob,
        };
        let res = WhisperGenerationResult::from(raw);

        assert_eq!(res.sequences, sequences);
        assert_eq!(res.sequences_ids, sequences_ids);
        assert_eq!(res.scores, scores);
        assert_eq!(res.no_speech_prob, no_speech_prob);
        assert_eq!(res.num_sequences(), sequences.len());
        assert!(res.has_scores());
    }

    /// An empty bridge-level result converts to an empty public result.
    #[test]
    fn test_empty_result() {
        let raw = ffi::WhisperGenerationResult {
            sequences: vec![],
            sequences_ids: vec![],
            scores: vec![],
            no_speech_prob: 0.,
        };
        let res = WhisperGenerationResult::from(raw);

        assert!(res.sequences.is_empty());
        assert!(res.sequences_ids.is_empty());
        assert!(res.scores.is_empty());
        assert_eq!(res.no_speech_prob, 0.);
        assert_eq!(res.num_sequences(), 0);
        assert!(!res.has_scores());
    }

    /// Tests that need a model downloaded from the hub.
    #[cfg(feature = "hub")]
    mod hub {
        use crate::download_model;
        use crate::sys::Whisper;

        const MODEL_ID: &str = "jkawamoto/whisper-tiny-ct2";

        /// The `Debug` output must mention the model directory name.
        #[test]
        #[ignore]
        fn test_whisper_debug() {
            let model_path = download_model(MODEL_ID).unwrap();
            let whisper = Whisper::new(&model_path, Default::default()).unwrap();

            let rendered = format!("{:?}", whisper);
            let expected_name = model_path.file_name().unwrap().to_str().unwrap();
            assert!(rendered.contains(expected_name));
        }
    }
}