use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use rubato::{FftFixedIn, Resampler};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use crate::stt::SttProvider;
use crate::{Result, VoiceConfig, VoiceError};
/// Sample rate of the audio handed to the whisper context. `transcribe` passes input through
/// unchanged when its rate equals this value and resamples it to this rate otherwise.
const WHISPER_SAMPLE_RATE: u32 = 16_000;
/// Base URL for model downloads; `download_model` appends `/ggml-<model_id>.bin` to it.
const WHISPER_MODEL_BASE_URL: &str = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
/// Speech-to-text provider that transcribes audio through a loaded [`WhisperContext`].
///
/// Instances are built with `WhisperCppSttProvider::from_config`.
pub struct WhisperCppSttProvider {
/// Loaded whisper context. Held in an `Arc` so `transcribe` can clone a handle and move it
/// into the blocking task that runs the transcription.
ctx: Arc<WhisperContext>,
/// Language passed to `FullParams::set_language` for every transcription; copied from
/// `VoiceConfig::language` when the provider is constructed.
language: String,
}
/// Debug formatting that prints only the `language` field; `ctx` is left out of the output.
impl std::fmt::Debug for WhisperCppSttProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WhisperCppSttProvider");
        builder.field("language", &self.language);
        builder.finish()
    }
}
impl WhisperCppSttProvider {
    /// Builds a provider from `config`.
    ///
    /// The model file location is derived from `config.whisper_cpp_model` via
    /// `resolve_model_path`. If no file exists there yet, it is fetched with `download_model`
    /// before the whisper context is created. The transcription language is copied from
    /// `config.language`.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::Stt`] when the model path cannot be resolved, the download fails,
    /// the path is not valid UTF-8, or the whisper context cannot be initialised.
    pub fn from_config(config: &VoiceConfig) -> Result<Self> {
        let model_id = &config.whisper_cpp_model;
        let model_path = resolve_model_path(model_id)?;

        let already_cached = model_path.exists();
        if !already_cached {
            tracing::info!(
                "[whisper-cpp] model file missing at {}, downloading from {}",
                model_path.display(),
                WHISPER_MODEL_BASE_URL,
            );
            download_model(model_id, &model_path)?;
        }

        tracing::info!(
            "[whisper-cpp] loading {} ({} MB on disk)",
            model_path.display(),
            std::fs::metadata(&model_path)
                .map(|m| m.len() / (1024 * 1024))
                .unwrap_or(0),
        );

        // The context constructor takes the path as `&str`, so a non-UTF-8 path is an error.
        let path_str = match model_path.to_str() {
            Some(s) => s,
            None => return Err(VoiceError::Stt("model path is not utf-8".into())),
        };

        let context = match WhisperContext::new_with_params(
            path_str,
            WhisperContextParameters::default(),
        ) {
            Ok(context) => context,
            Err(e) => return Err(VoiceError::Stt(format!("whisper ctx init: {e}"))),
        };

        Ok(Self {
            ctx: Arc::new(context),
            language: config.language.clone(),
        })
    }
}
#[async_trait]
impl SttProvider for WhisperCppSttProvider {
    /// Transcribes `samples`, recorded at `sample_rate`, into text.
    ///
    /// Input already at [`WHISPER_SAMPLE_RATE`] is copied as-is; any other rate goes through
    /// `resample_to_16k` first. The transcription itself runs on tokio's blocking pool via
    /// `spawn_blocking`, which is why the samples, context handle and language are moved into
    /// the closure as owned values.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::Stt`] if resampling fails, if the blocking task cannot be joined,
    /// or if `run_transcription` itself reports an error.
    async fn transcribe(&self, samples: &[f32], sample_rate: u32) -> Result<String> {
        let samples_16k = match sample_rate {
            WHISPER_SAMPLE_RATE => samples.to_vec(),
            other_rate => resample_to_16k(samples, other_rate)?,
        };

        let ctx = Arc::clone(&self.ctx);
        let lang = self.language.clone();
        let handle =
            tokio::task::spawn_blocking(move || run_transcription(&ctx, &samples_16k, &lang));

        match handle.await {
            Ok(transcription) => transcription,
            Err(e) => Err(VoiceError::Stt(format!("whisper join: {e}"))),
        }
    }
}
/// Runs one whisper pass over `samples` and returns the text of all segments, concatenated in
/// order and trimmed of surrounding whitespace.
///
/// A fresh state is created from `ctx` for every call. Decoding uses the greedy strategy with
/// `best_of: 1` and the given `language`. The thread count is read from the
/// `WHISPER_N_THREADS` environment variable and falls back to `1` when the variable is unset,
/// does not parse as an `i32`, or is smaller than 1.
///
/// # Errors
///
/// Returns [`VoiceError::Stt`] if the state cannot be created, the pass fails, or a segment
/// cannot be read back.
fn run_transcription(ctx: &WhisperContext, samples: &[f32], language: &str) -> Result<String> {
    let mut state = match ctx.create_state() {
        Ok(state) => state,
        Err(e) => return Err(VoiceError::Stt(format!("whisper state: {e}"))),
    };

    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    params.set_language(Some(language));
    params.set_no_context(true);
    params.set_suppress_nst(true);
    params.set_no_speech_thold(0.6);
    // Switch off every print_* option of the library.
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_special(false);
    params.set_print_timestamps(false);

    // Only a value that parses and is at least 1 overrides the single-thread default.
    let n_threads = match std::env::var("WHISPER_N_THREADS") {
        Ok(raw) => match raw.trim().parse::<i32>() {
            Ok(parsed) if parsed >= 1 => parsed,
            _ => 1,
        },
        Err(_) => 1,
    };
    params.set_n_threads(n_threads);

    if let Err(e) = state.full(params, samples) {
        return Err(VoiceError::Stt(format!("whisper full: {e}")));
    }

    let segment_count = state
        .full_n_segments()
        .map_err(|e| VoiceError::Stt(format!("whisper n_segments: {e}")))?;

    // Collecting into `Result<String>` concatenates the segments and stops at the first error.
    let transcript = (0..segment_count)
        .map(|index| {
            state
                .full_get_segment_text(index)
                .map_err(|e| VoiceError::Stt(format!("whisper get_segment_text: {e}")))
        })
        .collect::<Result<String>>()?;

    Ok(transcript.trim().to_string())
}
/// Returns the on-disk location for the model named `model_id`:
/// `<home>/.tokhn/whisper/ggml-<model_id>.bin`.
///
/// The cache directory is created (including missing parents) as a side effect; the model file
/// itself is neither created nor checked for existence.
///
/// # Errors
///
/// Returns [`VoiceError::Stt`] when no home directory can be determined or the cache directory
/// cannot be created.
fn resolve_model_path(model_id: &str) -> Result<PathBuf> {
    let mut path = match dirs::home_dir() {
        Some(home) => home,
        None => return Err(VoiceError::Stt("no home dir".into())),
    };
    path.push(".tokhn");
    path.push("whisper");

    if let Err(e) = std::fs::create_dir_all(&path) {
        return Err(VoiceError::Stt(format!("create cache dir: {e}")));
    }

    path.push(format!("ggml-{model_id}.bin"));
    Ok(path)
}
/// Downloads `ggml-<model_id>.bin` from [`WHISPER_MODEL_BASE_URL`] and stores it at `dest`.
///
/// The response body is first written to a sibling file obtained with
/// `dest.with_extension("bin.partial")` and renamed onto `dest` only after the copy has
/// succeeded, so `dest` appears only once the whole body has been written. If creating,
/// writing or renaming the temporary file fails, the temporary file is removed (best effort)
/// so that a failed download does not leave a partial model behind.
///
/// The whole request is bounded by a 10 minute timeout.
///
/// NOTE(review): this uses `reqwest::blocking`, which reqwest documents as unsuitable for use
/// inside an async runtime — confirm callers invoke this from a non-async context.
///
/// # Errors
///
/// Returns [`VoiceError::Stt`] if the HTTP client cannot be built, the request fails, the
/// server answers with a non-success status, or any of the file operations fail.
fn download_model(model_id: &str, dest: &Path) -> Result<()> {
    let file_name = format!("ggml-{model_id}.bin");
    let url = format!("{WHISPER_MODEL_BASE_URL}/{file_name}");
    tracing::info!("[whisper-cpp] downloading {url}");

    let client = reqwest::blocking::Client::builder()
        .timeout(std::time::Duration::from_secs(60 * 10))
        .build()
        .map_err(|e| VoiceError::Stt(format!("http client: {e}")))?;
    let mut resp = client
        .get(&url)
        .send()
        .map_err(|e| VoiceError::Stt(format!("download {url}: {e}")))?;
    if !resp.status().is_success() {
        return Err(VoiceError::Stt(format!(
            "download {url} returned {}",
            resp.status()
        )));
    }

    let tmp = dest.with_extension("bin.partial");

    // Run all steps that touch the temporary file in one fallible unit so that a single
    // cleanup path below covers every failure.
    let staged = (|| -> Result<()> {
        let mut f = std::fs::File::create(&tmp)
            .map_err(|e| VoiceError::Stt(format!("create tmp: {e}")))?;
        resp.copy_to(&mut f)
            .map_err(|e| VoiceError::Stt(format!("write: {e}")))?;
        // Close the handle before renaming, as the original scoped block did.
        drop(f);
        std::fs::rename(&tmp, dest).map_err(|e| VoiceError::Stt(format!("rename: {e}")))
    })();

    if let Err(err) = staged {
        // Best effort: the original error is what the caller needs; a failure to delete the
        // leftover (for example because it was never created) is deliberately ignored.
        let _ = std::fs::remove_file(&tmp);
        return Err(err);
    }

    tracing::info!(
        "[whisper-cpp] saved to {} ({} MB)",
        dest.display(),
        std::fs::metadata(dest)
            .map(|m| m.len() / (1024 * 1024))
            .unwrap_or(0),
    );
    Ok(())
}
/// Resamples mono `samples` from `source_rate` to [`WHISPER_SAMPLE_RATE`].
///
/// The input is fed to the resampler in fixed chunks of 1024 samples. A final chunk shorter
/// than that is zero-padded to the full chunk size before processing, so the output can
/// contain trailing samples that stem from the padding.
///
/// Empty input yields an empty vector without constructing a resampler.
///
/// # Errors
///
/// Returns [`VoiceError::Stt`] when `source_rate` is zero (for non-empty input), when the
/// resampler cannot be constructed, or when processing a chunk fails.
fn resample_to_16k(samples: &[f32], source_rate: u32) -> Result<Vec<f32>> {
    const CHUNK_SIZE: usize = 1024;

    if samples.is_empty() {
        return Ok(Vec::new());
    }
    // A zero rate would otherwise reach the capacity computation below and divide by zero.
    if source_rate == 0 {
        return Err(VoiceError::Stt(
            "resampler init: source sample rate must be non-zero".into(),
        ));
    }

    // The trailing `1, 1` are taken to be sub-chunks and channel count per rubato's
    // `FftFixedIn::new` signature — confirm against the pinned rubato version.
    let mut resampler = FftFixedIn::<f32>::new(
        source_rate as usize,
        WHISPER_SAMPLE_RATE as usize,
        CHUNK_SIZE,
        1,
        1,
    )
    .map_err(|e| VoiceError::Stt(format!("resampler init: {e}")))?;

    // Expected output length plus one chunk of slack for the padded tail.
    let mut out = Vec::with_capacity(
        samples.len() * WHISPER_SAMPLE_RATE as usize / source_rate as usize + CHUNK_SIZE,
    );

    // Full chunks are passed as borrowed slices; no per-chunk copy is needed.
    let mut chunks = samples.chunks_exact(CHUNK_SIZE);
    for chunk in chunks.by_ref() {
        let output = resampler
            .process(&[chunk], None)
            .map_err(|e| VoiceError::Stt(format!("resampler: {e}")))?;
        out.extend_from_slice(&output[0]);
    }

    // The resampler is fed one full chunk at a time, so pad the remainder with silence.
    let tail = chunks.remainder();
    if !tail.is_empty() {
        let mut padded = tail.to_vec();
        padded.resize(CHUNK_SIZE, 0.0);
        let output = resampler
            .process(&[padded], None)
            .map_err(|e| VoiceError::Stt(format!("resampler tail: {e}")))?;
        out.extend_from_slice(&output[0]);
    }

    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The resolved path must be `ggml-<id>.bin` inside a `.tokhn/whisper` directory.
    ///
    /// The directory check compares path components rather than a `/`-joined string, so it
    /// does not depend on the platform's path separator.
    #[test]
    fn resolve_model_path_uses_home_cache() {
        let p = resolve_model_path("tiny").unwrap();
        assert!(p.ends_with("ggml-tiny.bin"));
        let cache_dir = p.parent().expect("model path has a parent directory");
        assert!(
            cache_dir.ends_with(std::path::Path::new(".tokhn").join("whisper")),
            "unexpected cache dir: {}",
            cache_dir.display()
        );
    }

    /// Empty input must come back empty.
    #[test]
    fn resample_empty_is_empty() {
        let out = resample_to_16k(&[], 44_100).unwrap();
        assert!(out.is_empty());
    }

    /// With identical source and target rates the output length must stay close to the input
    /// length; the tolerance of 1200 samples is the bound this test has always used.
    #[test]
    fn resample_keeps_rate_when_equal() {
        let input: Vec<f32> = (0..2048).map(|i| (i as f32 / 100.0).sin()).collect();
        let out = resample_to_16k(&input, 16_000).unwrap();
        let diff = (out.len() as i64 - input.len() as i64).abs();
        assert!(
            diff < 1200,
            "out.len={} input.len={}",
            out.len(),
            input.len()
        );
    }
}