use anyhow::{Context, Result};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{OnceLock, RwLock};
use whisper_rs::{WhisperContext, WhisperContextParameters, WhisperState};
// Process-wide cache holding at most one loaded whisper model.
// Lazily initialized on first access through `get_cache()`.
static MODEL_CACHE: OnceLock<RwLock<Option<CachedModel>>> = OnceLock::new();
// When true, `maybe_unload()` becomes a no-op and the model stays resident
// between transcriptions (see `set_keep_loaded`).
static KEEP_LOADED: AtomicBool = AtomicBool::new(false);
// A loaded whisper model paired with the file path it was loaded from.
// The path is compared on lookup so a request for a different model file
// triggers a reload instead of reusing the wrong model.
struct CachedModel {
    context: WhisperContext,
    path: String,
}
/// Return the global model cache slot, initializing it to empty on first use.
fn get_cache() -> &'static RwLock<Option<CachedModel>> {
    // `RwLock::default()` is `RwLock::new(None)` for `Option<CachedModel>`.
    MODEL_CACHE.get_or_init(RwLock::default)
}
pub fn get_model(path: &str) -> Result<ModelGuard> {
{
let cache = get_cache().read().unwrap();
if let Some(ref cached) = *cache
&& cached.path == path
{
let state = cached
.context
.create_state()
.context("Failed to create whisper state")?;
return Ok(ModelGuard { state });
}
}
let state = {
let mut cache = get_cache().write().unwrap();
if let Some(ref cached) = *cache
&& cached.path == path
{
let state = cached
.context
.create_state()
.context("Failed to create whisper state")?;
return Ok(ModelGuard { state });
}
if path.is_empty() {
anyhow::bail!(
"Whisper model path not configured. Set LOCAL_WHISPER_MODEL_PATH or use: whis config --whisper-model-path <path>"
);
}
if !std::path::Path::new(path).exists() {
anyhow::bail!(
"Whisper model not found at: {}\n\
Download a model from: https://huggingface.co/ggerganov/whisper.cpp/tree/main",
path
);
}
whisper_rs::install_logging_hooks();
crate::verbose!("Loading whisper model from: {}", path);
let context = WhisperContext::new_with_params(path, WhisperContextParameters::default())
.context("Failed to load whisper model")?;
crate::verbose!("Whisper model loaded successfully");
let state = context
.create_state()
.context("Failed to create whisper state")?;
*cache = Some(CachedModel {
context,
path: path.to_string(),
});
state
};
Ok(ModelGuard { state })
}
/// Evict the cached whisper model, if one is loaded, freeing its memory.
/// A no-op (and no log line) when the cache is already empty.
pub fn unload_model() {
    let mut slot = get_cache().write().unwrap();
    if slot.is_some() {
        crate::verbose!("Unloading whisper model from cache");
        // `take()` leaves `None` behind and drops the evicted model here.
        slot.take();
    }
}
/// Set whether the cached model should stay resident after use.
/// When `true`, `maybe_unload()` leaves the model in memory so subsequent
/// transcriptions skip the (expensive) reload.
pub fn set_keep_loaded(keep: bool) {
    KEEP_LOADED.store(keep, Ordering::SeqCst);
    crate::verbose!("Model cache keep_loaded set to: {}", keep);
}
/// Whether the cached model should be kept in memory after use
/// (the flag set by `set_keep_loaded`).
pub fn should_keep_loaded() -> bool {
    KEEP_LOADED.load(Ordering::SeqCst)
}
/// Unload the cached model unless the keep-loaded flag is set.
/// Intended to be called after each transcription finishes.
pub fn maybe_unload() {
    // Guard: caller asked us to keep the model resident.
    if should_keep_loaded() {
        return;
    }
    unload_model();
}
/// Warm the cache by loading the model at `path` on a background thread.
/// Skips the spawn entirely when that same model file is already cached;
/// a failed preload is logged and otherwise ignored (callers will retry
/// synchronously via `get_model`).
pub fn preload_model(path: &str) {
    // Probe under a short-lived read lock; drop it before spawning.
    let already_cached = {
        let cache = get_cache().read().unwrap();
        matches!(*cache, Some(ref cached) if cached.path == path)
    };
    if already_cached {
        crate::verbose!("Model already cached, skipping preload");
        return;
    }
    let owned_path = path.to_string();
    std::thread::spawn(move || {
        crate::verbose!("Preloading whisper model in background...");
        if let Err(e) = get_model(&owned_path) {
            crate::verbose!("Preload failed: {}", e);
        }
    });
}
/// Handle returned by `get_model`, wrapping a ready-to-use whisper decoding
/// state. The underlying model stays in the process-wide cache; only this
/// per-call state is owned by the guard.
pub struct ModelGuard {
    state: WhisperState,
}
impl ModelGuard {
    /// Mutable access to the wrapped whisper state, for running inference
    /// while the guard retains ownership.
    pub fn state_mut(&mut self) -> &mut WhisperState {
        &mut self.state
    }

    /// Consume the guard and take ownership of the whisper state.
    pub fn into_state(self) -> WhisperState {
        self.state
    }
}