use byteorder::{ByteOrder, LittleEndian};
use candle_core::{Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::whisper::{self as m, audio, Config};
use std::path::Path;
use thiserror::Error;
use tokenizers::Tokenizer;
/// Errors reported by the Whisper loading and transcription code in this file.
///
/// The `String`-carrying variants hold a human-readable description; the
/// `#[from]` variants wrap the underlying library error so `?` can convert it.
#[derive(Error, Debug)]
pub enum WhisperError {
/// Configuration-related failure, described by the contained message.
#[error("Config error: {0}")]
Config(String),
/// Tokenizer could not be loaded/downloaded, or decoding token ids failed.
#[error("Tokenizer error: {0}")]
Tokenizer(String),
/// Mel filter bank file is missing or `num_mel_bins` is not supported.
#[error("Mel filters error: {0}")]
MelFilters(String),
/// Model weights could not be obtained, described by the contained message.
#[error("Model weights error: {0}")]
Weights(String),
/// The named token has no id in the tokenizer vocabulary.
#[error("Token '{0}' not found in vocabulary")]
TokenNotFound(String),
/// Error propagated from `candle_core`.
#[error("Candle error: {0}")]
Candle(#[from] candle_core::Error),
/// Error propagated from `std::io` (e.g. reading files from the model directory).
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Error propagated from `serde_json` (e.g. parsing `config.json`).
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
}
/// Result alias whose error type is [`WhisperError`].
pub type WhisperResult<T> = Result<T, WhisperError>;
/// User-facing options that control how a [`WhisperModel`] builds its decoder prompt.
#[derive(Debug, Clone)]
pub struct WhisperConfig {
/// Model variant this configuration refers to.
pub model_size: WhisperSize,
/// Language code used to look up the `<|code|>` prompt token; `None` omits
/// the language token from the prompt.
pub language: Option<String>,
/// Which task token (transcribe or translate) is placed in the prompt.
pub task: Task,
/// When `false`, the no-timestamps token is appended to the prompt.
pub timestamps: bool,
}
impl Default for WhisperConfig {
fn default() -> Self {
Self {
model_size: WhisperSize::Tiny,
language: Some("en".to_string()),
task: Task::Transcribe,
timestamps: false,
}
}
}
/// Whisper checkpoint variants this module knows by name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WhisperSize {
    /// The `tiny` checkpoint.
    Tiny,
    /// The `base` checkpoint.
    Base,
    /// The `small` checkpoint.
    Small,
    /// The `medium` checkpoint.
    Medium,
    /// The `large` checkpoint.
    Large,
    /// The `large-v2` checkpoint.
    LargeV2,
    /// The `large-v3` checkpoint.
    LargeV3,
    /// The `large-v3-turbo` checkpoint.
    LargeV3Turbo,
}

impl WhisperSize {
    /// Returns the short, lowercase name of the variant (e.g. `"large-v3"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Tiny => "tiny",
            Self::Base => "base",
            Self::Small => "small",
            Self::Medium => "medium",
            Self::Large => "large",
            Self::LargeV2 => "large-v2",
            Self::LargeV3 => "large-v3",
            Self::LargeV3Turbo => "large-v3-turbo",
        }
    }

    /// Returns the Hugging Face model id (`openai/whisper-…`) for the variant.
    pub fn hf_model_id(&self) -> &'static str {
        match *self {
            Self::Tiny => "openai/whisper-tiny",
            Self::Base => "openai/whisper-base",
            Self::Small => "openai/whisper-small",
            Self::Medium => "openai/whisper-medium",
            Self::Large => "openai/whisper-large",
            Self::LargeV2 => "openai/whisper-large-v2",
            Self::LargeV3 => "openai/whisper-large-v3",
            Self::LargeV3Turbo => "openai/whisper-large-v3-turbo",
        }
    }
}
/// Decoding task; selects which task token is placed in the decoder prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Task {
/// Use the transcribe task token (the default).
#[default]
Transcribe,
/// Use the translate task token.
Translate,
}
/// A loaded Whisper model together with everything needed to run it:
/// tokenizer, model configuration, mel filter bank and cached special-token ids.
pub struct WhisperModel {
// The candle Whisper network (encoder + decoder).
model: m::model::Whisper,
// Tokenizer used to resolve special tokens and to decode generated ids.
tokenizer: Tokenizer,
// Model configuration parsed from `config.json`.
config: Config,
// Device on which token and mel tensors are created.
device: Device,
// Mel filter bank coefficients decoded from a little-endian `f32` file.
mel_filters: Vec<f32>,
// Id of the start-of-transcript token (`m::SOT_TOKEN`).
sot_token: u32,
// Id of the end-of-text token (`m::EOT_TOKEN`); generation stops when it is produced.
eot_token: u32,
// Id of the transcribe task token (`m::TRANSCRIBE_TOKEN`).
transcribe_token: u32,
// Id of the translate task token (`m::TRANSLATE_TOKEN`).
translate_token: u32,
// Id of the no-timestamps token (`m::NO_TIMESTAMPS_TOKEN`).
no_timestamps_token: u32,
// Id of the `<|lang|>` token, if a language was configured and the
// tokenizer knows it; `None` otherwise.
language_token: Option<u32>,
// Options supplied by the caller at load time.
user_config: WhisperConfig,
}
impl WhisperModel {
/// Loads a model from `model_dir` onto `device` using [`WhisperConfig::default`].
///
/// This is a thin wrapper around [`WhisperModel::load_with_config`]; see that
/// method for the files that are read and the errors that can be returned.
pub fn load(model_dir: &Path, device: &Device) -> WhisperResult<Self> {
    let defaults = WhisperConfig::default();
    Self::load_with_config(model_dir, device, defaults)
}
pub fn load_with_config(
model_dir: &Path,
device: &Device,
user_config: WhisperConfig,
) -> WhisperResult<Self> {
let config_path = model_dir.join("config.json");
let config: Config = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| WhisperError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;
let mel_filters_path = model_dir.join("melfilters.bytes");
let mel_filters = if mel_filters_path.exists() {
let mel_bytes = std::fs::read(&mel_filters_path)?;
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
LittleEndian::read_f32_into(&mel_bytes, &mut mel_filters);
mel_filters
} else {
match config.num_mel_bins {
80 => {
return Err(WhisperError::MelFilters(format!(
"melfilters.bytes not found at {:?}. Please download from Candle examples.",
mel_filters_path
)));
}
128 => {
return Err(WhisperError::MelFilters(format!(
"melfilters128.bytes not found at {:?}. Please download from Candle examples.",
mel_filters_path
)));
}
n => {
return Err(WhisperError::MelFilters(format!(
"Unsupported num_mel_bins: {}",
n
)));
}
}
};
let weights_path = model_dir.join("model.safetensors");
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, device)?
};
let model = m::model::Whisper::load(&vb, config.clone())?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
let language_token = if let Some(ref lang) = user_config.language {
let lang_token = format!("<|{}|>", lang);
token_id(&tokenizer, &lang_token).ok()
} else {
None
};
Ok(Self {
model,
tokenizer,
config,
device: device.clone(),
mel_filters,
sot_token,
eot_token,
transcribe_token,
translate_token,
no_timestamps_token,
language_token,
user_config,
})
}
/// Downloads the files for `size` from the Hugging Face Hub and loads the model.
///
/// `config.json`, `tokenizer.json` and `model.safetensors` are fetched from
/// the repository named by [`WhisperSize::hf_model_id`]; the directory that
/// contains the downloaded weights is then passed to
/// [`WhisperModel::load_with_config`] with `model_size` set to `size` and all
/// other options at their defaults.
///
/// NOTE(review): no mel filter file is fetched here, yet `load_with_config`
/// requires one (`melfilters.bytes`) in the same directory — confirm how that
/// file is expected to be provisioned for hub downloads.
///
/// # Errors
/// Returns [`WhisperError::Config`], [`WhisperError::Tokenizer`] or
/// [`WhisperError::Weights`] when the API cannot be created or a download
/// fails, plus any error from `load_with_config`.
#[cfg(feature = "candle-hub")]
pub fn from_hf(size: WhisperSize, device: &Device) -> WhisperResult<Self> {
    use hf_hub::{api::sync::Api, Repo, RepoType};

    let api = Api::new().map_err(|e| WhisperError::Config(format!("HF API error: {}", e)))?;
    let repo = api.repo(Repo::new(size.hf_model_id().to_string(), RepoType::Model));

    // The config and tokenizer paths are not used directly; fetching them
    // ensures the files sit next to the weights for `load_with_config`.
    let _config_path = repo
        .get("config.json")
        .map_err(|e| WhisperError::Config(format!("Failed to download config: {}", e)))?;
    let _tokenizer_path = repo
        .get("tokenizer.json")
        .map_err(|e| WhisperError::Tokenizer(format!("Failed to download tokenizer: {}", e)))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| WhisperError::Weights(format!("Failed to download weights: {}", e)))?;

    // Report a path without a parent as an error rather than panicking.
    let model_dir = weights_path.parent().ok_or_else(|| {
        WhisperError::Weights(format!(
            "Downloaded weights path {:?} has no parent directory",
            weights_path
        ))
    })?;

    Self::load_with_config(
        model_dir,
        device,
        WhisperConfig {
            model_size: size,
            ..Default::default()
        },
    )
}
/// Runs the audio encoder on `mel` and returns its output tensor.
///
/// The encoder is invoked with its second (flush) argument set to `true`.
pub fn encode(&mut self, mel: &Tensor) -> candle_core::Result<Tensor> {
    let encoder = &mut self.model.encoder;
    encoder.forward(mel, true)
}
pub fn transcribe(&mut self, mel: &Tensor) -> WhisperResult<String> {
let audio_features = self.model.encoder.forward(mel, true)?;
let mut tokens = vec![self.sot_token];
if let Some(lang_token) = self.language_token {
tokens.push(lang_token);
}
match self.user_config.task {
Task::Transcribe => tokens.push(self.transcribe_token),
Task::Translate => tokens.push(self.translate_token),
}
if !self.user_config.timestamps {
tokens.push(self.no_timestamps_token);
}
let sample_len = self.config.max_target_positions / 2;
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), &self.device)?;
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = self
.model
.decoder
.forward(&tokens_t, &audio_features, i == 0)?;
let (_, seq_len, _) = ys.dims3()?;
let logits = self
.model
.decoder
.final_linear(&ys.i((.., seq_len - 1.., ..))?)?
.i(0)?
.i(0)?;
let next_token = logits.argmax(candle_core::D::Minus1)?.to_scalar::<u32>()?;
if next_token == self.eot_token || tokens.len() > self.config.max_target_positions {
break;
}
tokens.push(next_token);
}
let text = self
.tokenizer
.decode(&tokens, true)
.map_err(|e| WhisperError::Tokenizer(format!("Tokenizer decode error: {}", e)))?;
Ok(text)
}
/// Returns the model configuration held by this instance.
pub fn config(&self) -> &Config {
&self.config
}
/// Returns the device this instance creates its tensors on.
pub fn device(&self) -> &Device {
&self.device
}
/// Returns the mel filter bank coefficients as a flat slice.
pub fn mel_filters(&self) -> &[f32] {
&self.mel_filters
}
pub fn pcm_to_mel_tensor(&self, pcm_data: &[f32]) -> WhisperResult<Tensor> {
const MAX_SAMPLES_30S: usize = 16000 * 30; let pcm_data = if pcm_data.len() > MAX_SAMPLES_30S {
&pcm_data[..MAX_SAMPLES_30S]
} else {
pcm_data
};
let mel = audio::pcm_to_mel(&self.config, pcm_data, &self.mel_filters);
let mel_len = mel.len();
let n_mels = self.config.num_mel_bins;
Tensor::from_vec(mel, (1, n_mels, mel_len / n_mels), &self.device)
.map_err(WhisperError::from)
}
/// Convenience wrapper: converts `pcm_data` to a mel tensor with
/// [`WhisperModel::pcm_to_mel_tensor`] and passes it to
/// [`WhisperModel::transcribe`], returning the decoded text.
pub fn transcribe_pcm(&mut self, pcm_data: &[f32]) -> WhisperResult<String> {
    self.pcm_to_mel_tensor(pcm_data)
        .and_then(|mel| self.transcribe(&mel))
}
}
/// Looks up the vocabulary id of `token`.
///
/// Returns [`WhisperError::TokenNotFound`] carrying the token text when the
/// tokenizer has no entry for it.
fn token_id(tokenizer: &Tokenizer, token: &str) -> WhisperResult<u32> {
    match tokenizer.token_to_id(token) {
        Some(id) => Ok(id),
        None => Err(WhisperError::TokenNotFound(token.to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the short names of two model sizes.
    #[test]
    fn test_whisper_size_as_str() {
        let cases = [
            (WhisperSize::Tiny, "tiny"),
            (WhisperSize::LargeV3, "large-v3"),
        ];
        for (size, expected) in cases.iter() {
            assert_eq!(size.as_str(), *expected);
        }
    }

    /// Verifies every field of the default configuration.
    #[test]
    fn test_whisper_config_default() {
        let WhisperConfig {
            model_size,
            language,
            task,
            timestamps,
        } = WhisperConfig::default();
        assert_eq!(model_size, WhisperSize::Tiny);
        assert_eq!(language, Some("en".to_string()));
        assert_eq!(task, Task::Transcribe);
        assert!(!timestamps);
    }
}