use crate::error::{Error, ErrorKind, Result};
use crate::{ComputeUnits, Model};
use std::path::PathBuf;
/// A Core ML model that can be explicitly unloaded and later reloaded.
///
/// Both states remember the model's path and the compute units it is associated
/// with, so an `Unloaded` handle carries everything `reload()` needs to load the
/// model again.
pub enum ModelHandle {
/// The model is resident and can be queried / used for prediction.
Loaded {
/// The loaded model; its path is reported through `Model::path()`.
model: Model,
/// Compute units recorded for this handle. For handles built with
/// `from_model` this is caller-supplied and not checked against how
/// `model` was actually loaded.
compute_units: ComputeUnits,
},
/// The model has been released; only the information needed to reload it is kept.
Unloaded {
/// Location the model will be reloaded from.
path: PathBuf,
/// Compute units that will be passed to `Model::load` on reload.
compute_units: ComputeUnits,
},
}
impl ModelHandle {
    /// Loads the model at `path` and returns it wrapped in a `Loaded` handle.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Model::load`].
    pub fn load(
        path: impl AsRef<std::path::Path>,
        compute_units: ComputeUnits,
    ) -> Result<Self> {
        let model = Model::load(&path, compute_units)?;
        Ok(Self::Loaded {
            model,
            compute_units,
        })
    }

    /// Wraps an already-loaded `model` in a `Loaded` handle.
    ///
    /// `compute_units` is stored as given; it is not checked against the
    /// configuration `model` was actually loaded with.
    pub fn from_model(model: Model, compute_units: ComputeUnits) -> Self {
        Self::Loaded {
            model,
            compute_units,
        }
    }

    /// Returns `true` if the handle currently holds a loaded model.
    pub fn is_loaded(&self) -> bool {
        matches!(self, Self::Loaded { .. })
    }

    /// Returns the model's path.
    ///
    /// For a loaded handle this is `Model::path()`; for an unloaded one it is
    /// the path that `reload` will load from.
    pub fn path(&self) -> &std::path::Path {
        match self {
            Self::Loaded { model, .. } => model.path(),
            Self::Unloaded { path, .. } => path,
        }
    }

    /// Returns the compute units recorded for this handle (in either state).
    pub fn compute_units(&self) -> ComputeUnits {
        match self {
            Self::Loaded { compute_units, .. } | Self::Unloaded { compute_units, .. } => {
                *compute_units
            }
        }
    }

    /// Borrows the loaded model.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ModelLoad`] error if the handle is unloaded.
    pub fn model(&self) -> Result<&Model> {
        match self {
            Self::Loaded { model, .. } => Ok(model),
            Self::Unloaded { .. } => Err(Error::new(
                ErrorKind::ModelLoad,
                "model is unloaded; call reload() first",
            )),
        }
    }

    /// Releases the model, returning an `Unloaded` handle with the same path
    /// and compute units.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ModelLoad`] error if the handle is already
    /// unloaded. Because this method takes `self` by value, the handle is
    /// consumed in that case; use [`ModelHandle::unload_in_place`] to keep it.
    pub fn unload(mut self) -> Result<Self> {
        self.unload_in_place()?;
        Ok(self)
    }

    /// Releases the model in place, switching `self` to the `Unloaded` state.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ModelLoad`] error if the handle is already
    /// unloaded; `self` is left unchanged.
    pub fn unload_in_place(&mut self) -> Result<()> {
        match self {
            Self::Loaded {
                model,
                compute_units,
            } => {
                let path = model.path().to_path_buf();
                let compute_units = *compute_units;
                // Overwriting `*self` drops the previous `Loaded` value, releasing the model.
                *self = Self::Unloaded {
                    path,
                    compute_units,
                };
                Ok(())
            }
            Self::Unloaded { .. } => Err(Error::new(
                ErrorKind::ModelLoad,
                "model is already unloaded",
            )),
        }
    }

    /// Loads the model again from the stored path, returning a `Loaded` handle.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ModelLoad`] error if the handle is already
    /// loaded, or propagates any error from [`Model::load`]. Because this
    /// method takes `self` by value, the handle (including an already-loaded
    /// model, or the stored path) is consumed on error; use
    /// [`ModelHandle::reload_in_place`] to keep it.
    pub fn reload(mut self) -> Result<Self> {
        self.reload_in_place()?;
        Ok(self)
    }

    /// Loads the model again from the stored path, switching `self` to the
    /// `Loaded` state.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::ModelLoad`] error if the handle is already
    /// loaded, or propagates any error from [`Model::load`]. In both cases
    /// `self` is left unchanged, so a failed load can simply be retried.
    pub fn reload_in_place(&mut self) -> Result<()> {
        match self {
            Self::Unloaded {
                path,
                compute_units,
            } => {
                let compute_units = *compute_units;
                // `?` returns before `*self` is written, so a failed load keeps the handle intact.
                let model = Model::load(&*path, compute_units)?;
                *self = Self::Loaded {
                    model,
                    compute_units,
                };
                Ok(())
            }
            Self::Loaded { .. } => Err(Error::new(
                ErrorKind::ModelLoad,
                "model is already loaded",
            )),
        }
    }

    /// Runs a prediction on the loaded model.
    ///
    /// # Errors
    ///
    /// Fails if the handle is unloaded (see [`ModelHandle::model`]); otherwise
    /// propagates any error from `Model::predict`.
    pub fn predict(
        &self,
        inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
    ) -> Result<crate::Prediction> {
        self.model()?.predict(inputs)
    }

    /// Returns the loaded model's input feature descriptions.
    ///
    /// # Errors
    ///
    /// Fails if the handle is unloaded (see [`ModelHandle::model`]).
    pub fn inputs(&self) -> Result<Vec<crate::FeatureDescription>> {
        Ok(self.model()?.inputs())
    }

    /// Returns the loaded model's output feature descriptions.
    ///
    /// # Errors
    ///
    /// Fails if the handle is unloaded (see [`ModelHandle::model`]).
    pub fn outputs(&self) -> Result<Vec<crate::FeatureDescription>> {
        Ok(self.model()?.outputs())
    }

    /// Returns the loaded model's metadata.
    ///
    /// # Errors
    ///
    /// Fails if the handle is unloaded (see [`ModelHandle::model`]).
    pub fn metadata(&self) -> Result<crate::ModelMetadata> {
        Ok(self.model()?.metadata())
    }
}
impl std::fmt::Debug for ModelHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Loaded {
model,
compute_units,
} => f
.debug_struct("ModelHandle")
.field("state", &"Loaded")
.field("path", &model.path())
.field("compute_units", compute_units)
.finish(),
Self::Unloaded {
path,
compute_units,
} => f
.debug_struct("ModelHandle")
.field("state", &"Unloaded")
.field("path", path)
.field("compute_units", compute_units)
.finish(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Unloaded` handle directly, so unloaded-state behaviour can be
    /// exercised without a real model file (and therefore on any platform).
    fn unloaded(path: &str, compute_units: ComputeUnits) -> ModelHandle {
        ModelHandle::Unloaded {
            path: PathBuf::from(path),
            compute_units,
        }
    }

    #[test]
    fn unloaded_handle_is_not_loaded() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        assert!(!handle.is_loaded());
    }

    #[test]
    fn unloaded_handle_preserves_path() {
        let handle = unloaded("/models/my_model.mlmodelc", ComputeUnits::CpuAndGpu);
        assert_eq!(
            handle.path(),
            std::path::Path::new("/models/my_model.mlmodelc")
        );
    }

    #[test]
    fn unloaded_handle_preserves_compute_units() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::CpuAndNeuralEngine);
        assert_eq!(handle.compute_units(), ComputeUnits::CpuAndNeuralEngine);
    }

    #[test]
    fn unloaded_handle_rejects_model_access() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        let err = handle.model().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("unloaded"));
    }

    #[test]
    fn unloaded_handle_rejects_double_unload() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        let err = handle.unload().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
        assert!(err.message().contains("already unloaded"));
    }

    #[test]
    fn load_nonexistent_model_fails() {
        let result = ModelHandle::load("/nonexistent.mlmodelc", ComputeUnits::All);
        assert!(result.is_err());
    }

    #[test]
    fn debug_format_unloaded() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        let debug = format!("{:?}", handle);
        // Match the quoted state value rather than a bare substring.
        assert!(debug.contains("\"Unloaded\""));
        assert!(debug.contains("/test.mlmodelc"));
    }

    #[test]
    fn unloaded_handle_rejects_inputs() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        assert!(handle.inputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_outputs() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        assert!(handle.outputs().is_err());
    }

    #[test]
    fn unloaded_handle_rejects_metadata() {
        let handle = unloaded("/test.mlmodelc", ComputeUnits::All);
        assert!(handle.metadata().is_err());
    }

    #[cfg(target_vendor = "apple")]
    mod apple_tests {
        use super::*;

        /// Returns the compiled fixture model's path, or `None` if it is absent.
        ///
        /// Callers return early on `None`. A note is printed to stderr (visible
        /// with `--nocapture`) so a skipped run is distinguishable from a pass.
        fn fixture_path() -> Option<std::path::PathBuf> {
            let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
                .join("tests/fixtures/test_linear.mlmodelc");
            if path.exists() {
                Some(path)
            } else {
                eprintln!("skipping: test fixture not found at {}", path.display());
                None
            }
        }

        #[test]
        fn load_unload_reload_cycle() {
            let model_path = match fixture_path() {
                Some(path) => path,
                None => return,
            };
            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());
            let handle = handle.unload().unwrap();
            assert!(!handle.is_loaded());
            assert!(handle.model().is_err());
            assert_eq!(handle.path(), model_path);
            let handle = handle.reload().unwrap();
            assert!(handle.is_loaded());
            assert!(handle.model().is_ok());
        }

        #[test]
        fn loaded_handle_rejects_double_reload() {
            let model_path = match fixture_path() {
                Some(path) => path,
                None => return,
            };
            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            let err = handle.reload().unwrap_err();
            assert_eq!(err.kind(), &ErrorKind::ModelLoad);
            assert!(err.message().contains("already loaded"));
        }

        #[test]
        fn from_model_wraps_existing() {
            let model_path = match fixture_path() {
                Some(path) => path,
                None => return,
            };
            let model = Model::load(&model_path, ComputeUnits::All).unwrap();
            let handle = ModelHandle::from_model(model, ComputeUnits::All);
            assert!(handle.is_loaded());
            assert_eq!(handle.compute_units(), ComputeUnits::All);
        }

        #[test]
        fn debug_format_loaded() {
            let model_path = match fixture_path() {
                Some(path) => path,
                None => return,
            };
            let handle = ModelHandle::load(&model_path, ComputeUnits::All).unwrap();
            let debug = format!("{:?}", handle);
            // A bare `contains("Loaded")` would also match `"Unloaded"`, so
            // require the quoted state value.
            assert!(debug.contains("\"Loaded\""));
            assert!(debug.contains("test_linear.mlmodelc"));
        }
    }
}