use std::path::{Path, PathBuf};
use ndarray::Array2;
use ort::session::Session;
#[cfg(feature = "coreml")]
use crate::inference::coreml::{CachedInputShape, SharedCoreMlModel};
use crate::inference::{ExecutionMode, ModelLoadError, ensure_ort_ready, with_execution_mode};
#[cfg(feature = "coreml")]
mod native;
#[cfg(feature = "coreml")]
mod parallel;
mod run;
mod tensor;
#[derive(Debug, thiserror::Error)]
pub enum SegmentationError {
#[error(transparent)]
Ort(#[from] ort::Error),
#[error("receiver disconnected")]
Disconnected(#[from] crossbeam_channel::SendError<Array2<f32>>),
#[error("{context}: {message}")]
Invariant {
context: &'static str,
message: String,
},
#[error("{worker} thread panicked")]
WorkerPanic {
worker: String,
},
}
const PRIMARY_BATCH_SIZE: usize = 32;
#[cfg(feature = "coreml")]
const LARGE_BATCH_SIZE: usize = 64;
pub struct SegmentationModel {
model_path: PathBuf,
mode: ExecutionMode,
session: Session,
primary_batched_session: Option<Session>,
#[cfg(feature = "coreml")]
native_session: Option<SharedCoreMlModel>,
#[cfg(feature = "coreml")]
native_batched_session: Option<SharedCoreMlModel>,
#[cfg(feature = "coreml")]
native_large_batched_session: Option<SharedCoreMlModel>,
#[cfg(feature = "coreml")]
cached_single_input_shape: CachedInputShape,
#[cfg(feature = "coreml")]
cached_batch_input_shape: CachedInputShape,
input_buffer: ndarray::Array3<f32>,
primary_batch_input_buffer: ndarray::Array3<f32>,
window_samples: usize,
step_samples: usize,
sample_rate: usize,
}
#[cfg(feature = "coreml")]
unsafe impl Send for SegmentationModel {}
impl SegmentationModel {
pub fn new(model_path: impl AsRef<Path>, step_duration: f32) -> Result<Self, ModelLoadError> {
Self::with_mode(model_path, step_duration, ExecutionMode::Cpu)
}
pub fn with_mode(
model_path: impl AsRef<Path>,
step_duration: f32,
mode: ExecutionMode,
) -> Result<Self, ModelLoadError> {
mode.validate()?;
ensure_ort_ready()?;
let model_path = model_path.as_ref();
let sample_rate = 16000;
let window_duration = 10.0;
let window_samples = (window_duration * sample_rate as f32) as usize;
let step_samples = (step_duration * sample_rate as f32) as usize;
#[cfg(feature = "coreml")]
if matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
Self::validate_native_coreml_assets(model_path, mode)?;
}
macro_rules! timed {
($expr:expr) => {{
let start = std::time::Instant::now();
let value = $expr;
(value, start.elapsed())
}};
}
let (session, session_elapsed) = timed!(Self::build_session(model_path, mode)?);
let (primary_batched_session, primary_batched_elapsed) = timed!(
batched_model_path(model_path, PRIMARY_BATCH_SIZE)
.filter(|path| path.exists())
.map(|path| Self::build_session(&path, mode))
.transpose()?
);
#[cfg(feature = "coreml")]
let (native_session, native_session_elapsed) =
timed!(Self::load_native_coreml(model_path, mode)?);
#[cfg(feature = "coreml")]
let (native_batched_session, native_batched_elapsed) =
timed!(Self::load_native_coreml_batched(model_path, mode)?);
#[cfg(feature = "coreml")]
let (native_large_batched_session, native_large_batched_elapsed) =
timed!(Self::load_native_coreml_large_batched(model_path, mode)?);
#[cfg(feature = "coreml")]
if matches!(mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
if native_session.is_none() {
return Err(ModelLoadError::MissingNativeAsset {
mode,
path: Self::resolve_coreml_path(model_path, mode)
.unwrap_or_else(|| model_path.to_path_buf()),
});
}
if native_batched_session.is_none() {
return Err(ModelLoadError::MissingNativeAsset {
mode,
path: Self::resolve_batched_coreml_path(model_path, mode, PRIMARY_BATCH_SIZE)
.unwrap_or_else(|| model_path.to_path_buf()),
});
}
if native_large_batched_session.is_none() {
return Err(ModelLoadError::MissingNativeAsset {
mode,
path: Self::resolve_batched_coreml_path(model_path, mode, LARGE_BATCH_SIZE)
.unwrap_or_else(|| model_path.to_path_buf()),
});
}
}
#[cfg(feature = "coreml")]
{
let total_ms = (session_elapsed
+ primary_batched_elapsed
+ native_session_elapsed
+ native_batched_elapsed
+ native_large_batched_elapsed)
.as_millis();
tracing::trace!(
ort_single_ms = session_elapsed.as_millis(),
ort_batched_ms = primary_batched_elapsed.as_millis(),
native_single_ms = native_session_elapsed.as_millis(),
native_b32_ms = native_batched_elapsed.as_millis(),
native_b64_ms = native_large_batched_elapsed.as_millis(),
total_ms,
"Segmentation model init",
);
}
#[cfg(not(feature = "coreml"))]
{
let total_ms = (session_elapsed + primary_batched_elapsed).as_millis();
tracing::trace!(
ort_single_ms = session_elapsed.as_millis(),
ort_batched_ms = primary_batched_elapsed.as_millis(),
total_ms,
"Segmentation model init",
);
}
Ok(Self {
model_path: model_path.to_path_buf(),
mode,
session,
primary_batched_session,
#[cfg(feature = "coreml")]
native_session,
#[cfg(feature = "coreml")]
native_batched_session,
#[cfg(feature = "coreml")]
native_large_batched_session,
#[cfg(feature = "coreml")]
cached_single_input_shape: CachedInputShape::new("input", &[1, 1, window_samples]),
#[cfg(feature = "coreml")]
cached_batch_input_shape: CachedInputShape::new(
"input",
&[PRIMARY_BATCH_SIZE, 1, window_samples],
),
input_buffer: ndarray::Array3::zeros((1, 1, window_samples)),
primary_batch_input_buffer: ndarray::Array3::zeros((
PRIMARY_BATCH_SIZE,
1,
window_samples,
)),
window_samples,
step_samples,
sample_rate,
})
}
fn build_session(model_path: &Path, mode: ExecutionMode) -> Result<Session, ort::Error> {
let builder = Session::builder()?
.with_independent_thread_pool()?
.with_intra_threads(Self::available_threads().min(6))?
.with_inter_threads(1)?
.with_memory_pattern(true)?;
let mut builder = with_execution_mode(builder, mode)?;
builder.commit_from_file(model_path)
}
fn available_threads() -> usize {
std::thread::available_parallelism()
.map(usize::from)
.unwrap_or(1)
}
pub fn sample_rate(&self) -> usize {
self.sample_rate
}
pub fn window_samples(&self) -> usize {
self.window_samples
}
pub fn step_samples(&self) -> usize {
self.step_samples
}
pub fn step_seconds(&self) -> f64 {
self.step_samples as f64 / self.sample_rate as f64
}
pub fn mode(&self) -> ExecutionMode {
self.mode
}
pub fn reset_session(&mut self) -> Result<(), ort::Error> {
self.session = Self::build_session(&self.model_path, self.mode)?;
self.primary_batched_session = batched_model_path(&self.model_path, PRIMARY_BATCH_SIZE)
.filter(|path| path.exists())
.map(|path| Self::build_session(&path, self.mode))
.transpose()?;
#[cfg(feature = "coreml")]
{
if matches!(self.mode, ExecutionMode::CoreMl | ExecutionMode::CoreMlFast) {
Self::validate_native_coreml_assets(&self.model_path, self.mode)
.map_err(|error| ort::Error::new(error.to_string()))?;
}
self.native_session = Self::load_native_coreml(&self.model_path, self.mode)
.map_err(|error| ort::Error::new(error.to_string()))?;
self.native_batched_session =
Self::load_native_coreml_batched(&self.model_path, self.mode)
.map_err(|error| ort::Error::new(error.to_string()))?;
self.native_large_batched_session =
Self::load_native_coreml_large_batched(&self.model_path, self.mode)
.map_err(|error| ort::Error::new(error.to_string()))?;
if self.native_session.is_none()
|| self.native_batched_session.is_none()
|| self.native_large_batched_session.is_none()
{
return Err(ort::Error::new(format!(
"{} native CoreML sessions failed to load",
self.mode
)));
}
}
Ok(())
}
}
fn batched_model_path(model_path: &Path, batch_size: usize) -> Option<PathBuf> {
let path = model_path;
let file_name = path.file_name()?.to_str()?;
let stem = file_name.strip_suffix(".onnx")?;
Some(path.with_file_name(format!("{stem}-b{batch_size}.onnx")))
}