use std::path::Path;
use objc2::AnyThread;
use objc2::rc::Retained;
use objc2::runtime::{AnyObject, ProtocolObject};
use objc2_core_ml::{
MLComputeUnits, MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue, MLModel,
MLModelConfiguration, MLMultiArray,
};
use objc2_foundation::{NSCopying, NSMutableDictionary, NSString, NSURL};
use super::{CoreMlError, GpuPrecision};
/// Loads the Core ML model at `path` with the requested compute units and GPU precision.
///
/// The path is converted lossily to an `NSString` and wrapped in a file URL that is flagged
/// as a directory. `GpuPrecision::Low` enables low-precision accumulation on the GPU; every
/// other precision setting leaves it disabled.
///
/// NOTE(review): the directory flag presumably targets a compiled `.mlmodelc` bundle —
/// confirm with callers.
///
/// # Errors
///
/// Returns [`CoreMlError::LoadFailed`] carrying the formatted `NSError` when Core ML fails
/// to load the model.
pub(super) fn load_model(
    path: &Path,
    compute_units: MLComputeUnits,
    gpu_precision: GpuPrecision,
) -> Result<Retained<MLModel>, CoreMlError> {
    let allow_low_precision = matches!(gpu_precision, GpuPrecision::Low);

    let path_string = NSString::from_str(&path.to_string_lossy());
    let model_url = NSURL::fileURLWithPath_isDirectory(&path_string, true);

    // SAFETY: `configuration` is created here and is not shared before it is handed to
    // Core ML. `model_url` and `configuration` stay alive for the whole load call.
    let loaded = unsafe {
        let configuration = MLModelConfiguration::new();
        configuration.setComputeUnits(compute_units);
        configuration.setAllowLowPrecisionAccumulationOnGPU(allow_low_precision);
        MLModel::modelWithContentsOfURL_configuration_error(&model_url, &configuration)
    };

    loaded.map_err(|err| CoreMlError::LoadFailed(err.to_string()))
}
/// Wraps `multi_array` in an `MLFeatureValue` and stores it in `input_dict` under `key_copy`.
///
/// The feature value is viewed as a plain `AnyObject` so it matches the dictionary's
/// declared value type.
pub(super) fn insert_input_feature(
    input_dict: &NSMutableDictionary<NSString, AnyObject>,
    key_copy: &ProtocolObject<dyn NSCopying>,
    multi_array: &MLMultiArray,
) {
    // SAFETY: `multi_array`, `input_dict` and `key_copy` are live borrows for the whole
    // block. The stored value is an `AnyObject`, which is the dictionary's value type.
    unsafe {
        let wrapped = MLFeatureValue::featureValueWithMultiArray(multi_array);
        let as_object = feature_value_as_any_object(&wrapped);
        input_dict.setObject_forKey(as_object, key_copy);
    }
}
/// Builds an `MLDictionaryFeatureProvider` backed by the entries of `input_dict`.
///
/// # Errors
///
/// Returns [`CoreMlError::PredictionFailed`], whose message is prefixed with
/// `feature provider: `, when the provider cannot be initialised from the dictionary.
pub(super) fn build_feature_provider(
    input_dict: &NSMutableDictionary<NSString, AnyObject>,
) -> Result<Retained<MLDictionaryFeatureProvider>, CoreMlError> {
    // SAFETY: the fresh allocation is consumed exactly once by the initializer, and
    // `input_dict` is borrowed for the whole call.
    let initialised = unsafe {
        let storage = MLDictionaryFeatureProvider::alloc();
        MLDictionaryFeatureProvider::initWithDictionary_error(storage, input_dict)
    };

    initialised.map_err(|err| CoreMlError::PredictionFailed(format!("feature provider: {err}")))
}
/// Runs `model` on `input_ref` and extracts the multi-array output stored under `output_key`.
///
/// `output_name` is only used to label the error when the output cannot be extracted.
///
/// # Errors
///
/// - [`CoreMlError::PredictionFailed`] carrying the formatted `NSError` when the prediction
///   itself fails.
/// - Whatever [`output_multi_array`] returns when the output is missing or is not a
///   multi-array.
pub(super) fn predict_output(
    model: &MLModel,
    input_ref: &ProtocolObject<dyn MLFeatureProvider>,
    output_key: &NSString,
    output_name: &str,
) -> Result<Retained<MLMultiArray>, CoreMlError> {
    // SAFETY: `model` and `input_ref` are live borrows for the duration of the call.
    let prediction = unsafe { model.predictionFromFeatures_error(input_ref) };

    match prediction {
        Ok(features) => output_multi_array(&features, output_key, output_name),
        Err(err) => Err(CoreMlError::PredictionFailed(err.to_string())),
    }
}
/// Looks up the feature named `output_key` in `output` and returns it as an `MLMultiArray`.
///
/// # Errors
///
/// Returns [`CoreMlError::OutputNotFound`] carrying `output_name` in two cases: the feature
/// is absent, or the feature exists but holds no multi-array value.
pub(super) fn output_multi_array(
    output: &ProtocolObject<dyn MLFeatureProvider>,
    output_key: &NSString,
    output_name: &str,
) -> Result<Retained<MLMultiArray>, CoreMlError> {
    // One error constructor shared by both failure paths. It only captures `&str`, so the
    // closure is `Copy` and can be handed out twice.
    let missing = || CoreMlError::OutputNotFound(output_name.to_owned());

    // SAFETY: `output` and `output_key` are live borrows for the duration of the call.
    let feature = unsafe { output.featureValueForName(output_key) }.ok_or_else(missing)?;

    // SAFETY: `feature` is an owned, retained feature value that is alive here.
    unsafe { feature.multiArrayValue() }.ok_or_else(missing)
}
/// Views an `MLFeatureValue` as a plain Objective-C object, so it can be stored in an
/// `NSMutableDictionary<NSString, AnyObject>`.
///
/// objc2 class types `Deref` to their superclass. Deref coercion therefore resolves the
/// chain `MLFeatureValue -> NSObject -> AnyObject` in the tail expression. That gives a
/// checked, zero-cost upcast with the same lifetime as the input borrow, with no raw
/// pointer cast and no `unsafe`.
fn feature_value_as_any_object(feature_value: &MLFeatureValue) -> &AnyObject {
    feature_value
}