use super::audio::AudioData;
use crate::error::{AumateError, Result};
use crate::ml::{Device, DeviceConfig, get_device};
use byteorder::{ByteOrder, LittleEndian};
use candle_core::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::whisper::{self as m, Config};
use std::path::{Path, PathBuf};
use std::time::Instant;
use tokenizers::Tokenizer;
/// Raw bytes of the mel filter table used when the model config reports 80 mel bins.
/// Embedded at compile time; `get_mel_filters` decodes it as little-endian `f32` values.
const MEL_FILTERS_80: &[u8] = include_bytes!("melfilters.bytes");
/// Raw bytes of the mel filter table used when the model config reports 128 mel bins.
/// Embedded at compile time; `get_mel_filters` decodes it as little-endian `f32` values.
const MEL_FILTERS_128: &[u8] = include_bytes!("melfilters128.bytes");
/// Wrapper around the concrete Whisper model implementation the engine drives.
///
/// Only one variant exists in this file; the methods on this type dispatch on
/// the variant so callers do not touch the underlying model directly.
pub enum WhisperModel {
/// The `candle_transformers` Whisper model (`m::model::Whisper`).
Normal(m::model::Whisper),
}
impl WhisperModel {
    /// Runs the encoder over a mel spectrogram tensor and returns its output.
    ///
    /// `flush` is handed through unchanged to the underlying encoder's
    /// `forward` call (NOTE(review): presumably a cache-flush flag — confirm
    /// against the candle Whisper implementation).
    ///
    /// # Errors
    /// Returns `AumateError::Other` wrapping the underlying error message when
    /// the encoder forward pass fails.
    pub fn encoder_forward(&mut self, mel: &Tensor, flush: bool) -> Result<Tensor> {
        // Irrefutable while the enum has a single variant; adding a variant
        // turns this into a compile error, which is the desired reminder.
        let Self::Normal(whisper) = self;
        let encoded = whisper.encoder.forward(mel, flush);
        encoded.map_err(|err| AumateError::Other(format!("Encoder forward failed: {}", err)))
    }

    /// Runs the decoder on `tokens` against `audio_features`, then applies the
    /// decoder's final linear projection to the decoder output.
    ///
    /// `flush` is handed through unchanged to the underlying decoder's
    /// `forward` call.
    ///
    /// # Errors
    /// Returns `AumateError::Other` when either the decoder forward pass or
    /// the final linear projection fails.
    pub fn decoder_forward(
        &mut self,
        tokens: &Tensor,
        audio_features: &Tensor,
        flush: bool,
    ) -> Result<Tensor> {
        let Self::Normal(whisper) = self;
        let hidden = match whisper.decoder.forward(tokens, audio_features, flush) {
            Ok(hidden) => hidden,
            Err(err) => {
                return Err(AumateError::Other(format!("Decoder forward failed: {}", err)));
            }
        };
        match whisper.decoder.final_linear(&hidden) {
            Ok(logits) => Ok(logits),
            Err(err) => Err(AumateError::Other(format!("Final linear failed: {}", err))),
        }
    }

    /// Delegates to the wrapped model's `reset_kv_cache`.
    #[allow(dead_code)]
    pub fn reset_kv_cache(&mut self) {
        let Self::Normal(whisper) = self;
        whisper.reset_kv_cache()
    }
}
/// Outcome of a single `WhisperEngine::transcribe` call.
#[derive(Debug, Clone)]
pub struct TranscriptionResult {
/// Decoded transcription text (trimmed of surrounding whitespace by the decoder step).
pub text: String,
/// Copy of the language configured on the engine when the transcription ran;
/// this is the configured value, not a language detected from the audio.
pub language: Option<String>,
/// Wall-clock time in milliseconds spent computing the mel spectrogram and decoding.
pub duration_ms: u64,
}
/// Speech-to-text engine that owns a Whisper model and everything needed to run it.
///
/// The engine starts empty; `load_model` populates the model-related fields and
/// `unload_model` clears them again.
pub struct WhisperEngine {
// Loaded model; `None` until `load_model` succeeds and after `unload_model`.
model: Option<WhisperModel>,
// Model configuration parsed from `config.json` in the model directory.
config: Option<Config>,
// Tokenizer loaded from `tokenizer.json` in the model directory.
tokenizer: Option<Tokenizer>,
// Decoded mel filter coefficients; empty while no model is loaded.
mel_filters: Vec<f32>,
// Device used for loading weights and creating tensors.
device: Device,
// Directory the current model was loaded from.
model_path: Option<PathBuf>,
// Optional language code used to build the `<|lang|>` prompt token.
// Not touched by `unload_model`.
language: Option<String>,
}
impl WhisperEngine {
pub fn new() -> Self {
let device = get_device(&DeviceConfig::with_gpu()).unwrap_or(Device::Cpu);
Self {
model: None,
config: None,
tokenizer: None,
mel_filters: Vec::new(),
device,
model_path: None,
language: None,
}
}
/// Creates an engine with no model loaded on the device described by `config`.
///
/// # Errors
/// Propagates the error from `get_device` when the device cannot be obtained.
pub fn with_device(config: DeviceConfig) -> Result<Self> {
    let engine = Self {
        // Resolved first; on failure the error is returned before anything
        // else is constructed.
        device: get_device(&config)?,
        model: None,
        config: None,
        tokenizer: None,
        model_path: None,
        language: None,
        mel_filters: Vec::new(),
    };
    Ok(engine)
}
/// Loads a Whisper model from `model_dir` and makes it the engine's active model.
///
/// The directory must contain `config.json`, `tokenizer.json` and
/// `model.safetensors`. Weights are loaded as `f32` onto the engine's device.
///
/// Engine state is only modified after every step has succeeded, so a failed
/// load leaves any previously loaded model (and its mel filters) untouched.
///
/// # Errors
/// Returns an error when the directory or any required file is missing, when
/// the config or tokenizer cannot be parsed, when `num_mel_bins` is
/// unsupported, or when the weights/model cannot be loaded.
pub fn load_model(&mut self, model_dir: &Path) -> Result<()> {
    log::info!("Loading Whisper model from: {:?}", model_dir);
    if !model_dir.exists() {
        return Err(AumateError::Other(format!("Model directory not found: {:?}", model_dir)));
    }

    let config_path = model_dir.join("config.json");
    if !config_path.exists() {
        return Err(AumateError::Other("config.json not found".to_string()));
    }
    let config_str = std::fs::read_to_string(&config_path)?;
    let config: Config = serde_json::from_str(&config_str)
        .map_err(|e| AumateError::Other(format!("Failed to parse config: {}", e)))?;

    let tokenizer_path = model_dir.join("tokenizer.json");
    if !tokenizer_path.exists() {
        return Err(AumateError::Other("tokenizer.json not found".to_string()));
    }
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| AumateError::Other(format!("Failed to load tokenizer: {}", e)))?;

    // Kept in a local until the whole load has succeeded: writing this into
    // `self` early would pair an already-loaded model with the wrong
    // filterbank if a later step fails.
    let mel_filters = Self::get_mel_filters(config.num_mel_bins)?;

    let weights_path = model_dir.join("model.safetensors");
    if !weights_path.exists() {
        return Err(AumateError::Other("model.safetensors not found".to_string()));
    }
    // SAFETY: the weights file is memory-mapped by candle; soundness relies on
    // the file not being modified or truncated while the mapping is alive.
    // Nothing in this function enforces that — it is assumed that model files
    // are not rewritten while the engine is running.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[&weights_path], DType::F32, &self.device)
            .map_err(|e| AumateError::Other(format!("Failed to load weights: {}", e)))?
    };
    // `Whisper::load` takes the config by value, and the engine keeps its own copy.
    let model = m::model::Whisper::load(&vb, config.clone())
        .map_err(|e| AumateError::Other(format!("Failed to create model: {}", e)))?;

    // Commit: nothing below can fail.
    self.model = Some(WhisperModel::Normal(model));
    self.config = Some(config);
    self.tokenizer = Some(tokenizer);
    self.mel_filters = mel_filters;
    self.model_path = Some(model_dir.to_path_buf());
    log::info!("Whisper model loaded successfully on {:?}", self.device);
    Ok(())
}
/// Decodes the embedded mel filter table matching `num_mel_bins`.
///
/// The embedded bytes are interpreted as consecutive little-endian `f32`
/// values.
///
/// # Errors
/// Returns an error when `num_mel_bins` is neither 80 nor 128, or when the
/// embedded table's byte length is not a whole number of `f32` values.
fn get_mel_filters(num_mel_bins: usize) -> Result<Vec<f32>> {
    const F32_SIZE: usize = std::mem::size_of::<f32>();

    let mel_bytes = match num_mel_bins {
        80 => MEL_FILTERS_80,
        128 => MEL_FILTERS_128,
        _ => {
            return Err(AumateError::Other(format!(
                "Unsupported num_mel_bins: {}. Expected 80 or 128.",
                num_mel_bins
            )));
        }
    };

    // `read_f32_into` panics unless the source is exactly 4 bytes per
    // destination element, so reject a ragged table with a normal error.
    if mel_bytes.len() % F32_SIZE != 0 {
        return Err(AumateError::Other(format!(
            "Embedded mel filter table for {} bins has invalid length {} (not a multiple of {})",
            num_mel_bins,
            mel_bytes.len(),
            F32_SIZE
        )));
    }

    let mut mel_filters = vec![0f32; mel_bytes.len() / F32_SIZE];
    LittleEndian::read_f32_into(mel_bytes, &mut mel_filters);
    log::debug!("Loaded {} mel filters ({} floats)", num_mel_bins, mel_filters.len());
    Ok(mel_filters)
}
pub fn unload_model(&mut self) {
self.model = None;
self.config = None;
self.tokenizer = None;
self.mel_filters.clear();
self.model_path = None;
log::info!("Whisper model unloaded");
}
/// Returns `true` when the engine currently holds a model.
pub fn is_loaded(&self) -> bool {
self.model.is_some()
}
/// Returns the directory the current model was loaded from, or `None` when
/// no model path is recorded.
pub fn model_path(&self) -> Option<&Path> {
self.model_path.as_deref()
}
/// Sets (or clears, with `None`) the language code used when building the
/// decoder prompt. The value is stored as given; it is not validated here.
pub fn set_language(&mut self, language: Option<String>) {
self.language = language;
}
/// Returns the currently configured language code, if any.
pub fn language(&self) -> Option<&str> {
self.language.as_deref()
}
/// Returns the device this engine was created with.
pub fn device(&self) -> &Device {
&self.device
}
/// Transcribes `audio` with the loaded model.
///
/// The audio is first passed through `AudioData::prepare_for_whisper`, then
/// converted to a mel spectrogram and decoded. `duration_ms` in the result
/// covers the mel computation and decoding, not the audio preparation.
///
/// # Errors
/// Returns an error when no config, tokenizer or model is loaded, or when
/// mel-tensor creation or decoding fails.
pub fn transcribe(&mut self, audio: &AudioData) -> Result<TranscriptionResult> {
    /// Maximum number of characters of the transcription echoed to the log.
    const PREVIEW_CHARS: usize = 50;

    // Borrowed rather than cloned: the config is only needed for the mel
    // step, which finishes before `decode_audio` needs `&mut self`.
    let config = self
        .config
        .as_ref()
        .ok_or_else(|| AumateError::Other("No config loaded".to_string()))?;
    // Cloned because `decode_audio` takes `&mut self`, so the tokenizer
    // cannot stay borrowed from `self` across that call.
    let tokenizer = self
        .tokenizer
        .clone()
        .ok_or_else(|| AumateError::Other("No tokenizer loaded".to_string()))?;

    let prepared = audio.prepare_for_whisper();
    let start_time = Instant::now();
    let mel = self.pcm_to_mel(&prepared.samples, config)?;
    let text = self.decode_audio(&mel, &tokenizer)?;
    // Saturate instead of silently truncating the u128 millisecond count.
    let duration_ms = u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);

    log::info!(
        "Transcription completed in {}ms: \"{}\"",
        duration_ms,
        // Cut on a character boundary. Slicing at a fixed byte offset panics
        // when that offset falls inside a multi-byte UTF-8 character, which
        // happens for non-ASCII transcriptions.
        match text.char_indices().nth(PREVIEW_CHARS) {
            Some((cut, _)) => format!("{}...", &text[..cut]),
            None => text.clone(),
        }
    );
    Ok(TranscriptionResult { text, language: self.language.clone(), duration_ms })
}
fn decode_audio(&mut self, mel: &Tensor, tokenizer: &Tokenizer) -> Result<String> {
let model =
self.model.as_mut().ok_or_else(|| AumateError::Other("No model loaded".to_string()))?;
let sot_token = tokenizer.token_to_id("<|startoftranscript|>").unwrap_or(50258);
let eot_token = tokenizer.token_to_id("<|endoftext|>").unwrap_or(50257);
let transcribe_token = tokenizer.token_to_id("<|transcribe|>").unwrap_or(50359);
let no_timestamps_token = tokenizer.token_to_id("<|notimestamps|>").unwrap_or(50363);
let language_token =
self.language.as_ref().and_then(|lang| tokenizer.token_to_id(&format!("<|{}|>", lang)));
let audio_features = model.encoder_forward(mel, true)?;
let mut tokens = vec![sot_token];
if let Some(lang_token) = language_token {
tokens.push(lang_token);
}
tokens.push(transcribe_token);
tokens.push(no_timestamps_token);
let initial_len = tokens.len();
let max_tokens = 224;
for _ in 0..max_tokens {
let tokens_tensor = Tensor::new(tokens.as_slice(), &self.device)
.map_err(|e| AumateError::Other(format!("Failed to create tokens tensor: {}", e)))?
.unsqueeze(0)
.map_err(|e| AumateError::Other(format!("Failed to unsqueeze: {}", e)))?;
let logits = model.decoder_forward(
&tokens_tensor,
&audio_features,
tokens.len() == initial_len,
)?;
let seq_len = logits
.dim(1)
.map_err(|e| AumateError::Other(format!("Failed to get dim: {}", e)))?;
let last_logits = logits
.narrow(1, seq_len - 1, 1)
.map_err(|e| AumateError::Other(format!("Failed to narrow: {}", e)))?
.squeeze(1)
.map_err(|e| AumateError::Other(format!("Failed to squeeze: {}", e)))?;
let next_token_tensor = last_logits
.argmax(1)
.map_err(|e| AumateError::Other(format!("Failed to argmax: {}", e)))?;
let next_token = next_token_tensor
.squeeze(0)
.map_err(|e| AumateError::Other(format!("Failed to squeeze argmax: {}", e)))?
.to_scalar::<u32>()
.map_err(|e| AumateError::Other(format!("Failed to get scalar: {}", e)))?;
if next_token == eot_token {
break;
}
tokens.push(next_token);
}
let text_tokens: Vec<u32> =
tokens.into_iter().skip(initial_len).filter(|&t| t != eot_token).collect();
let text = tokenizer
.decode(&text_tokens, true)
.map_err(|e| AumateError::Other(format!("Failed to decode tokens: {}", e)))?;
Ok(text.trim().to_string())
}
/// Converts PCM samples into a mel spectrogram tensor on the engine's device.
///
/// The flat output of `m::audio::pcm_to_mel` is shaped as
/// `(1, num_mel_bins, len / num_mel_bins)`.
///
/// # Errors
/// Returns `AumateError::Other` when the tensor cannot be created from the
/// computed data.
fn pcm_to_mel(&self, samples: &[f32], config: &Config) -> Result<Tensor> {
    let n_mels = config.num_mel_bins;
    let spectrogram = m::audio::pcm_to_mel(config, samples, &self.mel_filters);
    let n_frames = spectrogram.len() / n_mels;
    match Tensor::from_vec(spectrogram, (1, n_mels, n_frames), &self.device) {
        Ok(tensor) => Ok(tensor),
        Err(err) => Err(AumateError::Other(format!("Failed to create mel tensor: {}", err))),
    }
}
}
/// `Default` delegates to `WhisperEngine::new`, so a default engine has no
/// model loaded and uses whatever device `new` selects.
impl Default for WhisperEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed engine holds no model and records no model path.
    #[test]
    fn test_engine_creation() {
        let fresh = WhisperEngine::new();
        assert_eq!(fresh.model_path(), None);
        assert!(!fresh.is_loaded());
    }

    /// The language can be set, read back, and cleared again.
    #[test]
    fn test_language_setting() {
        let mut engine = WhisperEngine::new();
        assert_eq!(engine.language(), None);

        engine.set_language(Some(String::from("en")));
        assert_eq!(engine.language(), Some("en"));

        engine.set_language(None);
        assert_eq!(engine.language(), None);
    }

    /// Transcribing before any model is loaded yields an error, not a panic.
    #[test]
    fn test_transcribe_without_model() {
        let silence = AudioData { samples: vec![0.0; 16000], sample_rate: 16000, channels: 1 };
        let mut engine = WhisperEngine::new();
        assert!(engine.transcribe(&silence).is_err());
    }
}