use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use super::{Result, SttError, TranscribeConfig};
/// Process-wide cache of loaded whisper contexts, keyed by model file path.
///
/// The outer `OnceCell` lets the map be created lazily on first use (see
/// `load_model`); the async `Mutex` guards the map itself. Values are
/// `Arc`-shared so callers can keep using a context after releasing the lock.
/// Nothing in this file removes entries once they are inserted.
static MODEL_CACHE: tokio::sync::OnceCell<Mutex<HashMap<PathBuf, Arc<WhisperContext>>>> =
tokio::sync::OnceCell::const_new();
/// Returns the shared whisper context for the model file at `path`, loading
/// and caching it on first use.
///
/// The cache lock is released while the model is being loaded on a blocking
/// thread, so two concurrent first-time callers may both load the same file.
/// Whichever one inserts first wins; the other caller receives the already
/// cached context and its own freshly loaded copy is dropped on return.
///
/// # Errors
///
/// * `SttError::ModelMissing` when `path` is not an existing regular file.
/// * `SttError::Whisper` when context initialisation fails or the blocking
///   task cannot be joined.
async fn load_model(path: &Path) -> Result<Arc<WhisperContext>> {
    if !path.is_file() {
        return Err(SttError::ModelMissing(path.display().to_string()));
    }

    let cache = MODEL_CACHE
        .get_or_init(|| async { Mutex::new(HashMap::new()) })
        .await;

    // Fast path: the guard is a temporary of this statement, so the lock is
    // already released by the time we decide whether to load.
    let cached = cache.lock().await.get(path).cloned();
    if let Some(hit) = cached {
        return Ok(hit);
    }

    tracing::info!(model = %path.display(), "stt: loading whisper model");

    // Model loading is blocking work; keep it off the async worker threads.
    let model_file = path.to_string_lossy().into_owned();
    let joined = tokio::task::spawn_blocking(move || {
        WhisperContext::new_with_params(&model_file, WhisperContextParameters::default())
            .map_err(|e| SttError::Whisper(format!("context init: {e}")))
    })
    .await;
    let fresh = match joined {
        Ok(loaded) => Arc::new(loaded?),
        Err(e) => return Err(SttError::Whisper(format!("context init join: {e}"))),
    };

    // Re-check under the lock: another task may have inserted the same model
    // while we were loading. The first inserted context is the one returned.
    let mut models = cache.lock().await;
    let shared = match models.entry(path.to_path_buf()) {
        std::collections::hash_map::Entry::Occupied(slot) => Arc::clone(slot.get()),
        std::collections::hash_map::Entry::Vacant(slot) => Arc::clone(&*slot.insert(fresh)),
    };
    Ok(shared)
}
/// Transcribes the audio file at `path` into text using the whisper model
/// referenced by `cfg.model_path`.
///
/// The audio is decoded through `super::audio::decode_to_pcm_mono`, converted
/// to `f32` samples, and run through whisper on a blocking thread with greedy
/// sampling. The text of all segments is concatenated and trimmed.
///
/// If `cfg.lang_hint` is set to anything other than `"auto"`, it is passed to
/// whisper as the language; otherwise the language parameter is left at the
/// library default.
/// NOTE(review): confirm that the library default matches what callers expect
/// for `"auto"` (i.e. auto-detection rather than a fixed language).
///
/// # Errors
///
/// * Any error returned by `decode_to_pcm_mono` or `load_model` is propagated.
/// * `SttError::EmptyAudio` when decoding yields no PCM data.
/// * `SttError::Whisper` when a whisper call fails or the blocking task cannot
///   be joined.
/// * `SttError::EmptyTranscript` when the trimmed transcript is empty.
pub async fn transcribe_file(path: &Path, cfg: &TranscribeConfig) -> Result<String> {
    let started = std::time::Instant::now();
    let pcm = super::audio::decode_to_pcm_mono(path, cfg).await?;
    if pcm.is_empty() {
        return Err(SttError::EmptyAudio);
    }
    let samples = super::audio::pcm_s16_to_f32(&pcm);
    let lang = cfg.lang_hint.clone();
    let context = load_model(&cfg.model_path).await?;
    let transcript = tokio::task::spawn_blocking(move || -> Result<String> {
        // The hint is owned by this closure and bound before `params`, so a
        // plain borrow outlives the parameters. This replaces a `Box::leak`
        // that leaked one heap string per call with a language hint.
        let lang = lang.filter(|l| l != "auto");
        let mut state = context
            .create_state()
            .map_err(|e| SttError::Whisper(format!("create_state: {e}")))?;
        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
        params.set_print_special(false);
        params.set_print_progress(false);
        params.set_print_realtime(false);
        params.set_print_timestamps(false);
        // Only override the language when a concrete hint is present; with no
        // hint the library default is deliberately left untouched.
        if let Some(l) = lang.as_deref() {
            params.set_language(Some(l));
        }
        state
            .full(params, &samples)
            .map_err(|e| SttError::Whisper(format!("full: {e}")))?;
        let n = state
            .full_n_segments()
            .map_err(|e| SttError::Whisper(format!("segments: {e}")))?;
        let mut out = String::new();
        for i in 0..n {
            let seg = state
                .full_get_segment_text(i)
                .map_err(|e| SttError::Whisper(format!("segment text {i}: {e}")))?;
            out.push_str(&seg);
        }
        Ok(out.trim().to_string())
    })
    .await
    .map_err(|e| SttError::Whisper(format!("transcribe join: {e}")))??;
    // Saturate instead of silently truncating the u128 millisecond count.
    let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
    // NOTE: this line is emitted even when the transcript is empty and the
    // function returns `EmptyTranscript` just below.
    tracing::info!(
        path = %path.display(),
        transcript_len = transcript.len(),
        elapsed_ms,
        "stt: transcription ok",
    );
    if transcript.is_empty() {
        return Err(SttError::EmptyTranscript);
    }
    Ok(transcript)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum deviation accepted when comparing converted samples to ±1.0.
    const TOLERANCE: f32 = 0.001;

    /// The conversion must map the extreme 16-bit values to roughly -1.0 and
    /// 1.0, and a zero sample to exactly 0.0.
    #[test]
    fn pcm_round_trip_handles_extreme_values() {
        // Three 2-byte samples; the assertions below expect them to decode as
        // the minimum, zero and maximum signed 16-bit values respectively.
        let raw: Vec<u8> = Vec::from([0x00, 0x80, 0x00, 0x00, 0xFF, 0x7F]);
        let floats = super::super::audio::pcm_s16_to_f32(&raw);

        assert_eq!(floats.len(), 3);
        assert!((floats[0] + 1.0).abs() < TOLERANCE);
        assert_eq!(floats[1], 0.0);
        assert!((floats[2] - 1.0).abs() < TOLERANCE);
    }

    /// A model path that does not exist must be reported as `ModelMissing`.
    #[tokio::test]
    async fn load_model_surfaces_missing_file_as_typed_error() {
        let missing_model = Path::new("/nonexistent/whisper-model-for-tests.bin");
        let outcome = load_model(missing_model).await;
        assert!(matches!(outcome, Err(SttError::ModelMissing(_))));
    }

    /// A missing audio file must be reported as an I/O error.
    #[tokio::test]
    async fn transcribe_file_surfaces_missing_audio_as_io_error() {
        let default_cfg = TranscribeConfig::default();
        let missing_audio = Path::new("/nonexistent/voice.ogg");
        let outcome = transcribe_file(missing_audio, &default_cfg).await;
        assert!(matches!(outcome, Err(SttError::Io(_))));
    }
}