use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::ptr::NonNull;
use block2::RcBlock;
use objc2::AnyThread;
use objc2::rc::{Retained, autoreleasepool};
use objc2::runtime::{AnyObject, ProtocolObject};
use objc2_core_ml::{
MLComputeUnits, MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue, MLModel,
MLModelConfiguration, MLMultiArray,
};
use objc2_foundation::{NSCopying, NSMutableDictionary, NSString, NSURL};
use crate::coreml::array::{
extract_output, multi_array_f32, multi_array_f32_cached, multi_array_i32,
multi_array_i32_cached,
};
use crate::coreml::{CachedCoreMlInput, CoreMlInput, CoreMlTensor};
use crate::error::TranscriptionError;
/// Loads the CoreML model at `path`, restricted to the given `compute_units`.
///
/// The path is converted with `to_string_lossy` and opened as a directory URL, which is what
/// a compiled model bundle expects.
///
/// # Errors
///
/// Returns [`TranscriptionError::CoreMl`] when CoreML refuses to load the model; the message
/// embeds the underlying `NSError`.
pub(super) fn load_model(
    path: &Path,
    compute_units: MLComputeUnits,
) -> Result<Retained<MLModel>, TranscriptionError> {
    let model_url = {
        let path_text = NSString::from_str(&path.to_string_lossy());
        NSURL::fileURLWithPath_isDirectory(&path_text, true)
    };

    // SAFETY: the configuration and URL are freshly created, valid Objective-C objects that
    // stay alive for the duration of the load call.
    let loaded = unsafe {
        let configuration = MLModelConfiguration::new();
        configuration.setComputeUnits(compute_units);
        MLModel::modelWithContentsOfURL_configuration_error(&model_url, &configuration)
    };

    loaded.map_err(|err| TranscriptionError::CoreMl(format!("failed to load model: {err}")))
}
pub(super) fn predict(
model: &MLModel,
input_dict: &NSMutableDictionary<NSString, AnyObject>,
deallocator: &RcBlock<dyn Fn(NonNull<c_void>)>,
inputs: &[CoreMlInput<'_>],
output_names: &[&str],
) -> Result<HashMap<String, CoreMlTensor>, TranscriptionError> {
autoreleasepool(|_| {
let mut arrays = Vec::with_capacity(inputs.len());
for input in inputs {
let (name, array) = match input {
CoreMlInput::F32 {
name,
values,
shape,
} => (*name, multi_array_f32(values, shape, deallocator)?),
CoreMlInput::I32 {
name,
values,
shape,
} => (*name, multi_array_i32(values, shape, deallocator)?),
};
let key = NSString::from_str(name);
let key_copy: &ProtocolObject<dyn NSCopying> = ProtocolObject::from_ref(&*key);
insert_input_feature(input_dict, key_copy, &array);
arrays.push(array);
}
let output_provider = predict_features(
model,
ProtocolObject::from_ref(&*build_feature_provider(input_dict)?),
)?;
let mut outputs = HashMap::with_capacity(output_names.len());
for output_name in output_names {
let key = NSString::from_str(output_name);
let array = output_multi_array(&output_provider, &key, output_name)?;
let (data, shape) = extract_output(&array)?;
outputs.insert((*output_name).to_owned(), CoreMlTensor { data, shape });
}
Ok(outputs)
})
}
pub(super) fn insert_cached_input(
input_dict: &NSMutableDictionary<NSString, AnyObject>,
deallocator: &RcBlock<dyn Fn(NonNull<c_void>)>,
input: CachedCoreMlInput<'_>,
) -> Result<Retained<MLMultiArray>, TranscriptionError> {
let (cached, array) = match input {
CachedCoreMlInput::F32 { cached, values } => {
debug_assert_eq!(values.len(), cached.total_elements);
(cached, multi_array_f32_cached(values, cached, deallocator)?)
}
CachedCoreMlInput::I32 { cached, values } => {
debug_assert_eq!(values.len(), cached.total_elements);
(cached, multi_array_i32_cached(values, cached, deallocator)?)
}
};
let key_copy: &ProtocolObject<dyn NSCopying> = ProtocolObject::from_ref(&*cached.name);
insert_input_feature(input_dict, key_copy, &array);
Ok(array)
}
/// Stores `multi_array` in `input_dict` under `key_copy`, wrapped as an `MLFeatureValue`.
///
/// This uses `setObject:forKey:`, so an existing entry under the same key is replaced.
pub(super) fn insert_input_feature(
    input_dict: &NSMutableDictionary<NSString, AnyObject>,
    key_copy: &ProtocolObject<dyn NSCopying>,
    multi_array: &MLMultiArray,
) {
    // SAFETY: all arguments are live, borrowed Objective-C objects. The dictionary holds
    // `AnyObject` values, and the feature value is inserted through that upcast.
    unsafe {
        let wrapped = MLFeatureValue::featureValueWithMultiArray(multi_array);
        let as_object = feature_value_as_any_object(&wrapped);
        input_dict.setObject_forKey(as_object, key_copy);
    }
}
/// Creates an `MLDictionaryFeatureProvider` from the current contents of `input_dict`.
///
/// # Errors
///
/// Returns [`TranscriptionError::CoreMl`] when CoreML rejects the dictionary. The message
/// embeds the underlying `NSError`.
pub(super) fn build_feature_provider(
    input_dict: &NSMutableDictionary<NSString, AnyObject>,
) -> Result<Retained<MLDictionaryFeatureProvider>, TranscriptionError> {
    let uninitialized = MLDictionaryFeatureProvider::alloc();
    // SAFETY: `uninitialized` is a fresh allocation consumed by the initializer, and
    // `input_dict` is a valid dictionary borrowed for the duration of the call.
    let provider = unsafe {
        MLDictionaryFeatureProvider::initWithDictionary_error(uninitialized, input_dict)
    };
    provider.map_err(|err| TranscriptionError::CoreMl(format!("feature provider failed: {err}")))
}
/// Runs `model` on the features supplied by `input_ref` and returns the output provider.
///
/// # Errors
///
/// Returns [`TranscriptionError::CoreMl`] when CoreML reports a prediction failure. The
/// message embeds the underlying `NSError`.
pub(super) fn predict_features(
    model: &MLModel,
    input_ref: &ProtocolObject<dyn MLFeatureProvider>,
) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, TranscriptionError> {
    // SAFETY: `model` and `input_ref` are valid Objective-C objects borrowed for the call.
    let prediction = unsafe { model.predictionFromFeatures_error(input_ref) };
    prediction.map_err(|err| TranscriptionError::CoreMl(format!("prediction failed: {err}")))
}
/// Looks up the feature `output_key` in `output` and returns it as an `MLMultiArray`.
///
/// `output_name` is used only to build error messages.
///
/// # Errors
///
/// Returns [`TranscriptionError::CoreMl`] in two cases:
/// - the provider has no feature with that name;
/// - the feature is not a multi-array.
pub(super) fn output_multi_array(
    output: &ProtocolObject<dyn MLFeatureProvider>,
    output_key: &NSString,
    output_name: &str,
) -> Result<Retained<MLMultiArray>, TranscriptionError> {
    // SAFETY: `output` and `output_key` are valid Objective-C objects borrowed for the call.
    let Some(feature) = (unsafe { output.featureValueForName(output_key) }) else {
        return Err(TranscriptionError::CoreMl(format!(
            "missing CoreML output `{output_name}`"
        )));
    };

    // SAFETY: `feature` is a retained, valid `MLFeatureValue`.
    match unsafe { feature.multiArrayValue() } {
        Some(array) => Ok(array),
        None => Err(TranscriptionError::CoreMl(format!(
            "CoreML output `{output_name}` was not an array"
        ))),
    }
}
/// Upcasts an `MLFeatureValue` reference to `&AnyObject` so it can be stored in an
/// `NSMutableDictionary<NSString, AnyObject>`.
///
/// `MLFeatureValue` derefs to `NSObject`, which derefs to `AnyObject`. Deref coercion at the
/// return site performs the upcast, so no `unsafe` raw-pointer cast is needed.
fn feature_value_as_any_object(feature_value: &MLFeatureValue) -> &AnyObject {
    feature_value
}