#![cfg(feature = "coreml")]
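//! Native CoreML loading for the embedding pipeline: resolves precompiled
//! `.mlmodelc` bundles (tail, fbank, multi-mask, and fused chunk-embedding
//! sessions) that live alongside the model file, returning `None` when a
//! bundle is absent so callers can fall back to the ORT path.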
use std::path::Path;
use std::sync::Arc;
use objc2_core_ml::MLComputeUnits;
use crate::inference::ExecutionMode;
use crate::inference::coreml::{CachedInputShape, CoreMlModel, GpuPrecision, SharedCoreMlModel};
use super::super::{
ChunkEmbeddingSession, ChunkSessionSpec, EmbeddingModel, FBANK_FEATURES, MASK_FRAMES,
fp32_coreml_path, split_fbank_batched_model_path, split_fbank_model_path,
split_tail_model_path,
};
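
/// Loads a [`SharedCoreMlModel`] from `path`, downgrading any load error to
/// a `tracing` warning prefixed with `error_context` and returning `None`.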
fn load_shared_or_warn(
path: &Path,
compute_units: MLComputeUnits,
error_context: &str,
) -> Option<SharedCoreMlModel> {
match SharedCoreMlModel::load(path, compute_units, "output", GpuPrecision::Low) {
Ok(model) => Some(model),
Err(e) => {
tracing::warn!("{error_context}: {e}");
None
}
}
}
impl EmbeddingModel {
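    /// Loads the precompiled fp32 CoreML tail model for `batch_size`.
    /// Returns `None` when the mode is not a CoreML mode, the bundle is
    /// missing, or loading fails, so the caller can fall back to ORT on CPU.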
pub(in crate::inference::embedding) fn load_native_tail(
model_path: &Path,
mode: ExecutionMode,
batch_size: usize,
) -> Option<CoreMlModel> {
        if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
            return None;
        }
        let compute_units = CoreMlModel::default_compute_units();
let tail_onnx = split_tail_model_path(model_path, batch_size);
let coreml_path = fp32_coreml_path(&tail_onnx);
if !coreml_path.exists() {
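            // Warn only for the batch-size-1 variant; batched lookups fail
            // silently so a missing family of bundles logs a single message.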
if batch_size == 1 {
tracing::warn!(
path = %coreml_path.display(),
"Native CoreML tail model not found, falling back to ORT CPU",
);
}
return None;
}
match CoreMlModel::load(&coreml_path, compute_units, "output", GpuPrecision::Low) {
Ok(model) => Some(model),
Err(e) => {
tracing::warn!(batch_size, "Failed to load native CoreML tail: {e}");
None
}
}
}
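
    /// Cheap existence check for the CoreML tail bundle; mirrors
    /// [`Self::load_native_tail`] without touching the model.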
pub(in crate::inference::embedding) fn has_native_tail_model(
model_path: &Path,
mode: ExecutionMode,
batch_size: usize,
) -> bool {
        if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
            return false;
        }
let tail_onnx = split_tail_model_path(model_path, batch_size);
fp32_coreml_path(&tail_onnx).exists()
}
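
    /// Loads the native CoreML fbank model, choosing the single-item or
    /// batched ONNX-derived bundle based on `batch_size`.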
pub(in crate::inference::embedding) fn load_native_fbank(
model_path: &Path,
mode: ExecutionMode,
batch_size: usize,
) -> Option<SharedCoreMlModel> {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return None;
}
let fbank_onnx = if batch_size == 1 {
split_fbank_model_path(model_path)
} else {
split_fbank_batched_model_path(model_path)
};
let coreml_path = fp32_coreml_path(&fbank_onnx);
if !coreml_path.exists() {
return None;
}
load_shared_or_warn(
&coreml_path,
CoreMlModel::default_compute_units(),
&format!("Failed to load native CoreML fbank (batch_size={batch_size})"),
)
}
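
    /// Cheap existence check mirroring [`Self::load_native_fbank`].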
pub(in crate::inference::embedding) fn has_native_fbank_model(
model_path: &Path,
mode: ExecutionMode,
batch_size: usize,
) -> bool {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return false;
}
let fbank_onnx = if batch_size == 1 {
split_fbank_model_path(model_path)
} else {
split_fbank_batched_model_path(model_path)
};
fp32_coreml_path(&fbank_onnx).exists()
}
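
    /// Loads the fixed 30-second fbank model, pinned to `CPUAndNeuralEngine`
    /// compute units; unlike the other loaders, the bundle name is fixed
    /// rather than derived from an ONNX path.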
pub(in crate::inference::embedding) fn load_native_fbank_30s(
model_path: &Path,
mode: ExecutionMode,
) -> Option<SharedCoreMlModel> {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return None;
}
let coreml_path = model_path.with_file_name("wespeaker-fbank-30s.mlmodelc");
if !coreml_path.exists() {
return None;
}
let model = load_shared_or_warn(
&coreml_path,
MLComputeUnits::CPUAndNeuralEngine,
"Failed to load 30s fbank model",
)?;
tracing::info!("Loaded 30s fbank model (CPUAndNeuralEngine)");
Some(model)
}
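
    /// Loads the batch-32 multi-mask tail model from its precompiled fp32
    /// bundle.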
pub(in crate::inference::embedding) fn load_native_multi_mask(
model_path: &Path,
mode: ExecutionMode,
) -> Option<SharedCoreMlModel> {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return None;
}
let onnx_path = model_path.with_file_name("wespeaker-multimask-tail-b32.onnx");
let coreml_path = fp32_coreml_path(&onnx_path);
if !coreml_path.exists() {
return None;
}
load_shared_or_warn(
&coreml_path,
CoreMlModel::default_compute_units(),
"Failed to load native CoreML multi-mask",
)
}
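
    /// Cheap existence check mirroring [`Self::load_native_multi_mask`].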
pub(in crate::inference::embedding) fn has_native_multi_mask_model(
model_path: &Path,
mode: ExecutionMode,
) -> bool {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return false;
}
let onnx_path = model_path.with_file_name("wespeaker-multimask-tail-b32.onnx");
fp32_coreml_path(&onnx_path).exists()
}
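
    /// Static table of fused chunk-embedding configurations per execution
    /// mode. Each tuple is `(step_resnet, num_windows, fbank_frames,
    /// num_masks)`: the first two select the bundle's file name, the last
    /// two size its inputs.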
fn chunk_session_config(mode: ExecutionMode) -> &'static [(usize, usize, usize, usize)] {
match mode {
ExecutionMode::CoreMlFast => &[
(25, 11, 3000, 33),
(25, 16, 4000, 48),
(25, 21, 5000, 63),
(25, 26, 6000, 78),
(25, 36, 8000, 108),
(25, 46, 10000, 138),
(25, 56, 12000, 168),
],
_ => &[
(12, 22, 3016, 66),
(12, 37, 4456, 111),
(12, 53, 5992, 159),
(12, 84, 8968, 252),
(12, 116, 12040, 348),
],
}
}
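
    /// Resolves the chunk-embedding configuration table for `mode` to the
    /// `.mlmodelc` bundles that actually exist next to `model_path`;
    /// missing bundles are skipped rather than treated as errors.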
pub(in crate::inference::embedding) fn chunk_session_specs(
model_path: &Path,
mode: ExecutionMode,
) -> Vec<ChunkSessionSpec> {
if !matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
return Vec::new();
}
Self::chunk_session_config(mode)
.iter()
.filter_map(|&(step_resnet, num_windows, fbank_frames, num_masks)| {
let stem = format!("wespeaker-chunk-emb-s{step_resnet}-w{num_windows}");
let w8a16_path = model_path.with_file_name(format!("{stem}-w8a16.mlmodelc"));
let fp32_path = model_path.with_file_name(format!("{stem}.mlmodelc"));
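                // Prefer the fp32 bundle, falling back to the quantized
                // w8a16 build when only that exists.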
let coreml_path = if fp32_path.exists() {
fp32_path
} else if w8a16_path.exists() {
w8a16_path
} else {
return None;
};
Some(ChunkSessionSpec {
coreml_path,
num_windows,
fbank_frames,
num_masks,
})
})
.collect()
}
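
    /// Loads a single chunk-embedding session from a spec produced by
    /// [`Self::chunk_session_specs`] and caches the fixed `fbank` and
    /// `masks` input shapes for reuse across inferences.
    ///
    /// A minimal call-site sketch (assuming `model_path`, `mode`, and
    /// `compute_units` are in scope; load errors dropped for brevity):
    ///
    /// ```ignore
    /// let sessions: Vec<ChunkEmbeddingSession> =
    ///     EmbeddingModel::chunk_session_specs(model_path, mode)
    ///         .iter()
    ///         .filter_map(|spec| {
    ///             EmbeddingModel::load_chunk_session(spec, compute_units).ok()
    ///         })
    ///         .collect();
    /// ```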
pub(in crate::inference::embedding) fn load_chunk_session(
spec: &ChunkSessionSpec,
compute_units: MLComputeUnits,
) -> Result<ChunkEmbeddingSession, crate::inference::coreml::CoreMlError> {
let model = SharedCoreMlModel::load(
&spec.coreml_path,
compute_units,
"output",
GpuPrecision::Low,
)?;
Ok(ChunkEmbeddingSession {
model: Arc::new(model),
num_windows: spec.num_windows,
fbank_frames: spec.fbank_frames,
num_masks: spec.num_masks,
cached_fbank_shape: Arc::new(CachedInputShape::new(
"fbank",
&[1, spec.fbank_frames, FBANK_FEATURES],
)),
cached_masks_shape: Arc::new(CachedInputShape::new(
"masks",
&[spec.num_masks, MASK_FRAMES],
)),
})
}
}