use crate::{
audio::{
load::{LoadedAudioModel, base_load_model},
stt::model::Model as SttModel,
},
error::Result,
};
/// Sample-rate constant for speech-to-text (STT) audio: 16 000.
///
/// NOTE(review): presumably expressed in samples per second (Hz); the consumers
/// of this constant are not visible in this file — confirm the unit there.
pub const STT_SAMPLE_RATE: u32 = 16_000;
/// Table of `(from, to)` model-name pairs.
///
/// Most entries map a name onto itself; the only pairs whose two halves differ
/// are `"glm"` → `"glmasr"` and `"vibevoice"` → `"vibevoice_asr"`.
///
/// NOTE(review): presumably the first element is the `model_type` found in a
/// model's configuration and the second is the name of the implementation to
/// use. The lookup code is not part of this file, so confirm that — and whether
/// entry order is significant — before reordering or pruning entries.
pub const MODEL_REMAPPING: &[(&str, &str)] = &[
    ("cohere_asr", "cohere_asr"),
    ("fireredasr2", "fireredasr2"),
    // Renamed: the short alias resolves to a differently spelled target.
    ("glm", "glmasr"),
    ("sensevoice", "sensevoice"),
    ("voxtral", "voxtral"),
    ("voxtral_realtime", "voxtral_realtime"),
    // Renamed: the short alias resolves to a differently spelled target.
    ("vibevoice", "vibevoice_asr"),
    ("qwen3_asr", "qwen3_asr"),
    ("canary", "canary"),
    ("moonshine", "moonshine"),
    ("mms", "mms"),
    ("granite_speech", "granite_speech"),
    ("qwen2_audio", "qwen2_audio"),
];
/// Loads the audio model bundle at `path` and builds an STT model from it.
///
/// The bundle is produced by [`base_load_model`] and then handed, by value, to
/// `constructor`, whose result is returned unchanged.
///
/// # Errors
///
/// Returns the error from [`base_load_model`] if loading the bundle fails (in
/// which case `constructor` is never invoked), or whatever error `constructor`
/// itself returns.
pub fn load_model<F>(path: &str, constructor: F) -> Result<Box<dyn SttModel>>
where
    F: FnOnce(LoadedAudioModel) -> Result<Box<dyn SttModel>>,
{
    constructor(base_load_model(path)?)
}
/// Convenience alias for [`load_model`].
///
/// Forwards `path` and `constructor` to [`load_model`] untouched and returns
/// its result as-is.
///
/// # Errors
///
/// Returns whatever error [`load_model`] returns.
pub fn load<F>(path: &str, constructor: F) -> Result<Box<dyn SttModel>>
where
    F: FnOnce(LoadedAudioModel) -> Result<Box<dyn SttModel>>,
{
    load_model(path, constructor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        array::Array,
        audio::stt::model::MelConfig,
        error::{Error, InvariantViolationPayload},
        lm::{cache::KvCache, model::Model as LmModel},
    };
    use std::{
        fs,
        path::{Path, PathBuf},
    };

    /// Minimal stand-in model used to prove that the loader returns exactly
    /// what the caller-supplied factory builds.
    struct FakeStt;

    impl LmModel for FakeStt {
        /// Never expected to run in these tests; always reports an invariant
        /// violation so an accidental call is loud.
        fn forward(&self, _tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
            Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "FakeStt::forward",
                "test stub — unreachable in this test",
            )))
        }
    }

    impl SttModel for FakeStt {
        /// Returns a fixed `(1, 4)` array of zeros regardless of the input.
        fn encode_audio(&self, _mel: &Array) -> Result<Array> {
            Array::from_slice::<f32>(&[0.0_f32; 4], &(1, 4))
        }

        fn mel_config(&self) -> MelConfig {
            MelConfig::whisper_default()
        }

        fn bos_token(&self) -> u32 {
            50258
        }

        fn eos_token(&self) -> u32 {
            50257
        }
    }

    /// Scratch directory that is deleted again when the guard goes out of
    /// scope, so neither passing nor panicking tests leave files behind in the
    /// system temp directory.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Location of the scratch directory on disk.
        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test's own result.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh, empty scratch directory unique to this process and
    /// `name`, wiping any stale directory of the same name first.
    fn temp_dir(name: &str) -> TempDir {
        let dir = std::env::temp_dir().join(format!(
            "mlxrs_audio_stt_load_{}_{}",
            std::process::id(),
            name
        ));
        // Ignore the result: the directory usually does not exist yet.
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        TempDir(dir)
    }

    #[test]
    fn load_stt_constructs_via_factory() {
        let dir = temp_dir("constructs_via_factory");
        let body = r#"{ "model_type": "whisper", "n_mels": 80 }"#;
        fs::write(dir.path().join("config.json"), body).unwrap();

        // The factory is `FnOnce`, so it can write straight into a local; the
        // mutable borrow ends as soon as `load` returns.
        let mut captured: Option<PathBuf> = None;
        let model = load(&dir.path().to_string_lossy(), |bundle| {
            captured = Some(bundle.model_path().to_path_buf());
            Ok(Box::new(FakeStt))
        })
        .expect("load constructs via the supplied factory");

        // The factory must have been called with a bundle rooted at `dir`.
        assert_eq!(captured.as_deref(), Some(dir.path()));
        assert_eq!(model.bos_token(), 50258);
        assert_eq!(model.eos_token(), 50257);
        assert_eq!(model.mel_config().sample_rate(), 16_000);
    }
}