use std::path::{Path, PathBuf};
#[cfg(feature = "online")]
use crate::inference::ExecutionMode;
/// File name of the segmentation ONNX model, joined onto the models directory
/// by `ModelBundle::from_dir`.
const SEGMENTATION_ONNX: &str = "segmentation-3.0.onnx";
/// File name of the speaker-embedding ONNX model, joined onto the models
/// directory by `ModelBundle::from_dir`.
const EMBEDDING_ONNX: &str = "wespeaker-voxceleb-resnet34.onnx";
/// Locations of the model assets that live under one models directory.
///
/// The paths are computed, not verified: building a bundle does not check
/// that any of the files exist.
#[derive(Debug, Clone)]
pub struct ModelBundle {
/// Path of the segmentation ONNX model (`<models_dir>/segmentation-3.0.onnx`).
segmentation_onnx: PathBuf,
/// Path of the embedding ONNX model
/// (`<models_dir>/wespeaker-voxceleb-resnet34.onnx`).
embedding_onnx: PathBuf,
/// The models directory itself, exposed through `plda_dir()`.
plda_dir: PathBuf,
}
impl ModelBundle {
    /// Describes the models stored in `models_dir`.
    ///
    /// The ONNX paths are `models_dir` joined with the fixed model file names,
    /// and the PLDA directory is `models_dir` itself. Nothing is read from
    /// disk, so the returned paths may point at files that do not exist.
    pub fn from_dir(models_dir: impl Into<PathBuf>) -> Self {
        let plda_dir: PathBuf = models_dir.into();
        let segmentation_onnx = plda_dir.join(SEGMENTATION_ONNX);
        let embedding_onnx = plda_dir.join(EMBEDDING_ONNX);
        Self {
            segmentation_onnx,
            embedding_onnx,
            plda_dir,
        }
    }

    #[cfg(feature = "online")]
    /// Makes sure every file required for `mode` is available through a
    /// default `ModelManager`, then builds the bundle from the directory that
    /// `ModelManager::ensure` reports.
    ///
    /// # Errors
    ///
    /// Returns the `ApiError` raised while creating the manager or while
    /// resolving any required file.
    pub fn from_pretrained(mode: ExecutionMode) -> Result<Self, hf_hub::api::sync::ApiError> {
        let models_dir = ModelManager::new()?.ensure(mode)?;
        Ok(Self::from_dir(models_dir))
    }

    /// Path of the segmentation ONNX model.
    pub fn segmentation_path(&self) -> &Path {
        self.segmentation_onnx.as_path()
    }

    /// Path of the speaker-embedding ONNX model.
    pub fn embedding_path(&self) -> &Path {
        self.embedding_onnx.as_path()
    }

    /// Directory the bundle was created from; the PLDA files are looked up
    /// under this directory.
    pub fn plda_dir(&self) -> &Path {
        self.plda_dir.as_path()
    }
}
#[cfg(feature = "online")]
/// Hugging Face Hub model repository id that `ModelManager` resolves files from.
const HF_REPO: &str = "avencera/speakrs-models";
#[cfg(feature = "online")]
/// Fetches model files from the `HF_REPO` repository through `hf_hub`'s
/// synchronous API.
pub struct ModelManager {
/// Handle on the model repository; every lookup goes through it.
repo: hf_hub::api::sync::ApiRepo,
}
#[cfg(feature = "online")]
impl ModelManager {
    /// Binds an already-configured Hub client to the `HF_REPO` repository.
    fn from_api(api: hf_hub::api::sync::Api) -> Self {
        Self {
            repo: api.model(HF_REPO.to_string()),
        }
    }

    /// Creates a manager backed by `hf_hub`'s default client configuration.
    ///
    /// # Errors
    ///
    /// Returns the `ApiError` produced while building the client.
    pub fn new() -> Result<Self, hf_hub::api::sync::ApiError> {
        let api = hf_hub::api::sync::Api::new()?;
        Ok(Self::from_api(api))
    }

    /// Creates a manager whose client uses `cache_dir` as its cache location.
    ///
    /// # Errors
    ///
    /// Returns the `ApiError` produced while building the client.
    pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, hf_hub::api::sync::ApiError> {
        let cache = hf_hub::Cache::new(cache_dir);
        let api = hf_hub::api::sync::ApiBuilder::from_cache(cache).build()?;
        Ok(Self::from_api(api))
    }

    /// Resolves a single repository file to a local path via `ApiRepo::get`.
    ///
    /// # Errors
    ///
    /// Returns the `ApiError` reported by `hf_hub` for this file.
    pub fn get(&self, filename: impl AsRef<str>) -> Result<PathBuf, hf_hub::api::sync::ApiError> {
        self.repo.get(filename.as_ref())
    }

    /// Resolves every file listed by `required_files(mode)` and returns the
    /// directory that holds the first of them.
    ///
    /// The first listed file is a top-level repository file, so its parent is
    /// the directory the other required files are expected to sit under. If
    /// that path has no parent, the path itself is returned.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first `ApiError` raised while resolving a
    /// required file.
    ///
    /// # Panics
    ///
    /// Panics if `required_files` returns an empty list. It never does: every
    /// mode's list starts with the PLDA files.
    pub fn ensure(&self, mode: ExecutionMode) -> Result<PathBuf, hf_hub::api::sync::ApiError> {
        // Remember the first resolved path while walking the list, so the
        // repository is not queried a second time for it afterwards.
        let mut first_path: Option<PathBuf> = None;
        for file in required_files(mode) {
            let path = self.repo.get(&file)?;
            if first_path.is_none() {
                first_path = Some(path);
            }
        }

        let first = first_path.expect("required_files always lists at least the PLDA files");
        let containing_dir = first.parent().map(Path::to_path_buf);
        Ok(containing_dir.unwrap_or(first))
    }
}
#[cfg(feature = "online")]
/// Files that `required_files` lists first for every execution mode: the six
/// `plda_*.npy` files plus the `min_num_samples` text file.
const PLDA_FILES: &[&str] = &[
"plda_lda.npy",
"plda_tr.npy",
"plda_mu.npy",
"plda_psi.npy",
"plda_mean1.npy",
"plda_mean2.npy",
"wespeaker-voxceleb-resnet34.min_num_samples.txt",
];
#[cfg(feature = "online")]
/// Baseline ONNX files: the segmentation model, the embedding model and the
/// embedding model's `.onnx.data` companion file. `required_files` includes
/// these three names for every execution mode.
const ONNX_FILES: &[&str] = &[
"segmentation-3.0.onnx",
"wespeaker-voxceleb-resnet34.onnx",
"wespeaker-voxceleb-resnet34.onnx.data",
];
#[cfg(feature = "online")]
/// Lists the files requested for the `.mlmodelc` bundle directory `name`.
///
/// Each returned entry is `"{name}/<member>"`, in a fixed order: `model.mil`,
/// `coremldata.bin`, `weights/weight.bin`, `analytics/coremldata.bin`.
fn mlmodelc_files(name: &str) -> Vec<String> {
    // Paths of the bundle members, relative to the bundle directory.
    const BUNDLE_MEMBERS: [&str; 4] = [
        "model.mil",
        "coremldata.bin",
        "weights/weight.bin",
        "analytics/coremldata.bin",
    ];

    BUNDLE_MEMBERS
        .iter()
        .map(|member| format!("{name}/{member}"))
        .collect()
}
#[cfg(feature = "online")]
/// Appends to `files` the member files of every `.mlmodelc` bundle in `names`,
/// bundle by bundle and in the order `mlmodelc_files` returns them.
fn extend_mlmodelc_files(files: &mut Vec<String>, names: &[&str]) {
    files.extend(names.iter().flat_map(|bundle| mlmodelc_files(bundle)));
}
#[cfg(feature = "online")]
/// `.mlmodelc` bundle directories that `required_files` requests for both
/// `ExecutionMode::CoreMl` and `ExecutionMode::CoreMlFast`. Each entry is
/// expanded into its member files by `extend_mlmodelc_files`.
const COREML_COMMON_MODEL_STEMS: &[&str] = &[
"segmentation-3.0.mlmodelc",
"segmentation-3.0-b32.mlmodelc",
"segmentation-3.0-b64.mlmodelc",
"wespeaker-fbank.mlmodelc",
"wespeaker-fbank-b32.mlmodelc",
"wespeaker-fbank-30s.mlmodelc",
"wespeaker-multimask-tail-b32.mlmodelc",
"wespeaker-voxceleb-resnet34-tail.mlmodelc",
"wespeaker-voxceleb-resnet34-tail-b3.mlmodelc",
"wespeaker-voxceleb-resnet34-tail-b32.mlmodelc",
];
#[cfg(feature = "online")]
/// Chunk-embedding `.mlmodelc` bundle directories that `required_files`
/// requests only for `ExecutionMode::CoreMl`.
const COREML_CHUNK_MODEL_STEMS: &[&str] = &[
"wespeaker-chunk-emb-s12-w22.mlmodelc",
"wespeaker-chunk-emb-s12-w37.mlmodelc",
"wespeaker-chunk-emb-s12-w53.mlmodelc",
"wespeaker-chunk-emb-s12-w84.mlmodelc",
"wespeaker-chunk-emb-s12-w116.mlmodelc",
];
#[cfg(feature = "online")]
/// `w8a16` segmentation `.mlmodelc` bundle directories that `required_files`
/// requests only for `ExecutionMode::CoreMlFast`.
const COREML_FAST_SEGMENTATION_MODEL_STEMS: &[&str] = &[
"segmentation-3.0-w8a16.mlmodelc",
"segmentation-3.0-b32-w8a16.mlmodelc",
"segmentation-3.0-b64-w8a16.mlmodelc",
];
#[cfg(feature = "online")]
/// Chunk-embedding `.mlmodelc` bundle directories that `required_files`
/// requests only for `ExecutionMode::CoreMlFast`.
const COREML_FAST_CHUNK_MODEL_STEMS: &[&str] = &[
"wespeaker-chunk-emb-s25-w11.mlmodelc",
"wespeaker-chunk-emb-s25-w16.mlmodelc",
"wespeaker-chunk-emb-s25-w21.mlmodelc",
"wespeaker-chunk-emb-s25-w26.mlmodelc",
"wespeaker-chunk-emb-s25-w36.mlmodelc",
"wespeaker-chunk-emb-s25-w46.mlmodelc",
"wespeaker-chunk-emb-s25-w56.mlmodelc",
];
#[cfg(feature = "online")]
/// Lists, in download order, every repository file needed to run in `mode`.
///
/// Every list starts with `PLDA_FILES` followed by `ONNX_FILES`. After that:
///
/// * `Cpu` adds nothing.
/// * `Cuda` / `CudaFast` add six further ONNX files.
/// * `CoreMl` / `CoreMlFast` add six further ONNX files and the members of the
///   common `.mlmodelc` bundles, followed by the chunk bundles (`CoreMl`) or
///   the fast segmentation and fast chunk bundles (`CoreMlFast`).
fn required_files(mode: ExecutionMode) -> Vec<String> {
    // ONNX files needed on top of `ONNX_FILES` by both CUDA modes.
    const CUDA_EXTRA_ONNX: &[&str] = &[
        "wespeaker-fbank.onnx",
        "wespeaker-fbank-b32.onnx",
        "wespeaker-multimask-tail.onnx",
        "wespeaker-multimask-tail-b32.onnx",
        "segmentation-3.0-b32.onnx",
        "wespeaker-voxceleb-resnet34-b64.onnx",
    ];
    // ONNX files needed on top of `ONNX_FILES` by both CoreML modes.
    const COREML_EXTRA_ONNX: &[&str] = &[
        "segmentation-3.0-b32.onnx",
        "wespeaker-fbank.onnx",
        "wespeaker-fbank-b32.onnx",
        "wespeaker-voxceleb-resnet34-tail.onnx",
        "wespeaker-voxceleb-resnet34-tail-b3.onnx",
        "wespeaker-voxceleb-resnet34-tail-b32.onnx",
    ];

    // Appends owned copies of `names` to `files`, preserving their order.
    fn append(files: &mut Vec<String>, names: &[&str]) {
        files.extend(names.iter().map(|name| String::from(*name)));
    }

    let mut files = Vec::new();
    append(&mut files, PLDA_FILES);
    append(&mut files, ONNX_FILES);

    match mode {
        ExecutionMode::Cpu => {}
        ExecutionMode::Cuda | ExecutionMode::CudaFast => {
            append(&mut files, CUDA_EXTRA_ONNX);
        }
        ExecutionMode::CoreMl | ExecutionMode::CoreMlFast => {
            append(&mut files, COREML_EXTRA_ONNX);
            extend_mlmodelc_files(&mut files, COREML_COMMON_MODEL_STEMS);
            if matches!(mode, ExecutionMode::CoreMlFast) {
                extend_mlmodelc_files(&mut files, COREML_FAST_SEGMENTATION_MODEL_STEMS);
                extend_mlmodelc_files(&mut files, COREML_FAST_CHUNK_MODEL_STEMS);
            } else {
                extend_mlmodelc_files(&mut files, COREML_CHUNK_MODEL_STEMS);
            }
        }
    }

    files
}
#[cfg(all(test, feature = "online"))]
mod tests {
    use super::*;

    /// The `CoreMl` list must contain `model.mil` for the b64 segmentation,
    /// `fbank-30s`, `multimask-tail-b32` and `s12-w116` chunk bundles.
    #[test]
    fn coreml_required_files_include_chunk_fast_path_assets() {
        let listed = required_files(ExecutionMode::CoreMl);
        let has = |name: &str| listed.iter().any(|file| file == name);

        assert!(has("segmentation-3.0-b64.mlmodelc/model.mil"));
        assert!(has("wespeaker-fbank-30s.mlmodelc/model.mil"));
        assert!(has("wespeaker-multimask-tail-b32.mlmodelc/model.mil"));
        assert!(has("wespeaker-chunk-emb-s12-w116.mlmodelc/model.mil"));
    }

    /// The `CoreMlFast` list must contain `model.mil` for the `w8a16`
    /// segmentation bundles and the `s25-w56` chunk bundle.
    #[test]
    fn coreml_fast_required_files_include_fast_assets() {
        let listed = required_files(ExecutionMode::CoreMlFast);
        let has = |name: &str| listed.iter().any(|file| file == name);

        assert!(has("segmentation-3.0-w8a16.mlmodelc/model.mil"));
        assert!(has("segmentation-3.0-b64-w8a16.mlmodelc/model.mil"));
        assert!(has("wespeaker-chunk-emb-s25-w56.mlmodelc/model.mil"));
    }
}