use std::path::Path;
use std::sync::Arc;
use async_trait::async_trait;
use sherpa_onnx::{
OfflineRecognizer, OfflineRecognizerConfig, OfflineSenseVoiceModelConfig,
OfflineWhisperModelConfig,
};
use crate::config::{SherpaSttConfig, SherpaSttKind};
use crate::voice::providers::sherpa_download::{
self, BundleCategory, ensure_bundle, path_string, pick_file,
};
use crate::voice::{PIPELINE_SAMPLE_RATE, SttProvider};
pub(crate) struct SherpaOnnxStt {
name: String,
recognizer: Arc<OfflineRecognizer>,
language: Option<String>,
kind: SherpaSttKind,
}
unsafe impl Send for SherpaOnnxStt {}
unsafe impl Sync for SherpaOnnxStt {}
impl SherpaOnnxStt {
pub(crate) fn new(name: String, cfg: SherpaSttConfig) -> anyhow::Result<Self> {
let dir = ensure_bundle(
cfg.model_dir.as_deref(),
cfg.model.as_deref(),
BundleCategory::Asr,
)?;
let recognizer = build_recognizer(&dir, &cfg)
.map_err(|e| anyhow::anyhow!("stt_provider '{name}': {e:#}"))?;
Ok(Self {
name,
recognizer: Arc::new(recognizer),
language: cfg.language.clone(),
kind: cfg.kind,
})
}
}
fn build_recognizer(dir: &Path, cfg: &SherpaSttConfig) -> anyhow::Result<OfflineRecognizer> {
let mut rec_cfg = OfflineRecognizerConfig::default();
rec_cfg.model_config.num_threads = cfg.num_threads;
rec_cfg.model_config.provider = Some(cfg.provider.clone());
match cfg.kind {
SherpaSttKind::SenseVoice => {
let model = pick_file(dir, &["model.int8.onnx", "model.onnx"])?;
let tokens = pick_file(dir, &["tokens.txt"])?;
rec_cfg.model_config.sense_voice = OfflineSenseVoiceModelConfig {
model: Some(path_string(&model)),
language: cfg.language.clone().or_else(|| Some("auto".into())),
use_itn: true,
};
rec_cfg.model_config.tokens = Some(path_string(&tokens));
}
SherpaSttKind::Whisper => {
let encoder = find_by_suffix(dir, "-encoder.int8.onnx")
.or_else(|_| find_by_suffix(dir, "-encoder.onnx"))?;
let decoder = find_by_suffix(dir, "-decoder.int8.onnx")
.or_else(|_| find_by_suffix(dir, "-decoder.onnx"))?;
let tokens = find_by_suffix(dir, "-tokens.txt")?;
rec_cfg.model_config.whisper = OfflineWhisperModelConfig {
encoder: Some(path_string(&encoder)),
decoder: Some(path_string(&decoder)),
language: cfg.language.clone(),
task: Some("transcribe".into()),
tail_paddings: 0,
enable_token_timestamps: false,
enable_segment_timestamps: false,
};
rec_cfg.model_config.tokens = Some(path_string(&tokens));
}
}
OfflineRecognizer::create(&rec_cfg)
.ok_or_else(|| anyhow::anyhow!("OfflineRecognizer::create failed"))
}
fn find_by_suffix(dir: &Path, suffix: &str) -> anyhow::Result<std::path::PathBuf> {
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let name = entry.file_name();
if name.to_string_lossy().ends_with(suffix) {
return Ok(entry.path());
}
}
anyhow::bail!("no file ending with '{suffix}' in {}", dir.display())
}
#[async_trait]
impl SttProvider for SherpaOnnxStt {
fn name(&self) -> &str {
&self.name
}
fn sample_rate(&self) -> u32 {
PIPELINE_SAMPLE_RATE
}
async fn transcribe(&self, pcm: &[i16], language: Option<&str>) -> anyhow::Result<String> {
let samples: Vec<f32> = pcm.iter().map(|s| *s as f32 / 32768.0).collect();
let recognizer = Arc::clone(&self.recognizer);
let configured_lang = self.language.clone();
let kind = self.kind;
let override_lang = language.map(String::from);
tokio::task::spawn_blocking(move || -> anyhow::Result<String> {
let stream = recognizer.create_stream();
if matches!(kind, SherpaSttKind::SenseVoice) {
let lang = override_lang
.as_deref()
.or(configured_lang.as_deref())
.unwrap_or("auto");
stream.set_option("language", lang);
}
stream.accept_waveform(PIPELINE_SAMPLE_RATE as i32, &samples);
recognizer.decode(&stream);
let result = stream
.get_result()
.ok_or_else(|| anyhow::anyhow!("recognizer returned no result"))?;
Ok(result.text.trim().to_string())
})
.await
.map_err(|e| anyhow::anyhow!("sherpa STT blocking task panicked: {e}"))?
}
}
#[allow(unused_imports)]
pub(crate) use sherpa_download::cache_dir;