1use crate::async_bridge::{self, CompletionFuture};
7use crate::error::{Error, ErrorKind, Result};
8use crate::{ComputeUnits, Prediction};
9
/// Maps the crate's [`ComputeUnits`] onto Core ML's `MLComputeUnits`.
///
/// Named constants are used deliberately: Core ML defines `CPUOnly = 0`,
/// `CPUAndGPU = 1`, `All = 2` and `CPUAndNeuralEngine = 3`, so raw integer
/// literals are easy to get wrong and silently select a different device set.
#[cfg(target_vendor = "apple")]
fn ml_compute_units(compute_units: ComputeUnits) -> objc2_core_ml::MLComputeUnits {
    use objc2_core_ml::MLComputeUnits;

    match compute_units {
        ComputeUnits::CpuOnly => MLComputeUnits::CPUOnly,
        ComputeUnits::CpuAndGpu => MLComputeUnits::CPUAndGPU,
        ComputeUnits::CpuAndNeuralEngine => MLComputeUnits::CPUAndNeuralEngine,
        ComputeUnits::All => MLComputeUnits::All,
    }
}

/// Builds a fresh `MLModelConfiguration` restricted to `compute_units`.
#[cfg(target_vendor = "apple")]
fn model_configuration(
    compute_units: ComputeUnits,
) -> objc2::rc::Retained<objc2_core_ml::MLModelConfiguration> {
    // SAFETY: plain Objective-C `new` on a Foundation-style class; no preconditions.
    let config = unsafe { objc2_core_ml::MLModelConfiguration::new() };
    // SAFETY: `config` is a valid, uniquely owned configuration object and the
    // value comes from the named `MLComputeUnits` constants.
    unsafe { config.setComputeUnits(ml_compute_units(compute_units)) };
    config
}

/// Converts the `NSError *` handed to a Core ML completion handler into an [`Error`].
///
/// Falls back to `fallback` when Core ML supplied neither a result nor an error.
#[cfg(target_vendor = "apple")]
fn error_from_handler(
    kind: ErrorKind,
    error_ptr: *mut objc2_foundation::NSError,
    fallback: &'static str,
) -> Error {
    // SAFETY: Core ML passes either null or a pointer to an NSError that is valid
    // for the duration of the completion handler; it is only borrowed here.
    match unsafe { error_ptr.as_ref() } {
        Some(err) => Error::from_nserror(kind, err),
        None => Error::new(kind, fallback),
    }
}

/// Creates a completion channel plus the Objective-C block that resolves it.
///
/// The block is meant to be passed as the completion handler of a Core ML model
/// load. On success the future yields a [`crate::Model`] whose `path` is `path`.
/// If Core ML returns a null model, the future yields an [`ErrorKind::ModelLoad`]
/// error built from the supplied `NSError`, or from `null_message` when there is none.
#[cfg(target_vendor = "apple")]
fn model_load_handler(
    path: std::path::PathBuf,
    null_message: &'static str,
) -> (
    block2::RcBlock<dyn Fn(*mut objc2_core_ml::MLModel, *mut objc2_foundation::NSError)>,
    CompletionFuture<crate::Model>,
) {
    use objc2_core_ml::MLModel;
    use objc2_foundation::NSError;

    let (sender, future) = async_bridge::completion_channel();
    // The block must be `Fn`, but the sender and path are consumed on the first
    // call; the Cell lets us move them out through a shared reference.
    let pending = std::cell::Cell::new(Some((sender, path)));

    let block = block2::RcBlock::new(move |model_ptr: *mut MLModel, error_ptr: *mut NSError| {
        // A second invocation would violate Core ML's contract. Ignore it rather
        // than panicking inside a callback driven by foreign (Objective-C) code.
        let Some((sender, path)) = pending.take() else {
            return;
        };
        // SAFETY: `model_ptr` is either null or a valid MLModel; `retain` takes our
        // own +1 reference and maps null to `None`.
        let result = match unsafe { objc2::rc::Retained::retain(model_ptr) } {
            Some(inner) => Ok(crate::Model { inner, path }),
            None => Err(error_from_handler(
                ErrorKind::ModelLoad,
                error_ptr,
                null_message,
            )),
        };
        sender.send(result);
    });

    (block, future)
}

#[cfg(target_vendor = "apple")]
impl crate::Model {
    /// Starts loading a compiled Core ML model (e.g. an `.mlmodelc` directory) from
    /// `path` without blocking the caller.
    ///
    /// The returned future resolves once Core ML invokes its completion handler.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::ModelLoad`] immediately if `path` is not valid UTF-8.
    /// Load failures reported by Core ML are delivered through the future, also as
    /// [`ErrorKind::ModelLoad`].
    pub fn load_async(
        path: impl AsRef<std::path::Path>,
        compute_units: ComputeUnits,
    ) -> Result<CompletionFuture<Self>> {
        use objc2_core_ml::MLModel;

        let path = path.as_ref();
        let path_str = path.to_str().ok_or_else(|| {
            Error::new(ErrorKind::ModelLoad, "path contains non-UTF8 characters")
        })?;

        let url =
            objc2_foundation::NSURL::fileURLWithPath(&crate::ffi::str_to_nsstring(path_str));
        let config = model_configuration(compute_units);
        let (block, future) =
            model_load_handler(path.to_path_buf(), "model load returned null with no error");

        // SAFETY: `url`, `config` and `block` are valid for the duration of the call.
        // Core ML copies the completion block before returning, since the block is
        // invoked after this call completes.
        unsafe {
            MLModel::loadContentsOfURL_configuration_completionHandler(&url, &config, &block);
        }

        Ok(future)
    }

    /// Starts loading a model from an in-memory model specification (`data`).
    ///
    /// The resulting [`crate::Model`] reports `<in-memory>` as its path.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::ModelLoad`] immediately if Core ML rejects `data` when
    /// building the `MLModelAsset`. Load failures reported asynchronously are
    /// delivered through the future, also as [`ErrorKind::ModelLoad`].
    pub fn load_from_bytes(
        data: &[u8],
        compute_units: ComputeUnits,
    ) -> Result<CompletionFuture<Self>> {
        use objc2_core_ml::{MLModel, MLModelAsset};

        // Copies `data`, so the caller's slice need not outlive the load.
        let ns_data = objc2_foundation::NSData::with_bytes(data);

        // SAFETY: `ns_data` is a valid NSData; errors are surfaced as `Err`.
        let asset = unsafe { MLModelAsset::modelAssetWithSpecificationData_error(&ns_data) }
            .map_err(|e| Error::from_nserror(ErrorKind::ModelLoad, &e))?;

        let config = model_configuration(compute_units);
        let (block, future) = model_load_handler(
            std::path::PathBuf::from("<in-memory>"),
            "model load from bytes returned null with no error",
        );

        // SAFETY: `asset`, `config` and `block` are valid for the duration of the call.
        // Core ML copies the completion block before returning.
        unsafe {
            MLModel::loadModelAsset_configuration_completionHandler(&asset, &config, &block);
        }

        Ok(future)
    }

    /// Runs an asynchronous prediction with the named `inputs`.
    ///
    /// Each `(name, tensor)` pair becomes one `MLFeatureValue` in an
    /// `MLDictionaryFeatureProvider` keyed by `name`.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorKind::Prediction`] immediately if Core ML rejects the input
    /// dictionary. Prediction failures reported by Core ML are delivered through
    /// the future, also as [`ErrorKind::Prediction`].
    pub fn predict_async(
        &self,
        inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
    ) -> Result<CompletionFuture<Prediction>> {
        use objc2::rc::Retained;
        use objc2::runtime::{AnyObject, ProtocolObject};
        use objc2::AnyThread;
        use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
        use objc2_foundation::{NSDictionary, NSError, NSString};

        // Build the input provider inside a pool so that autoreleased temporaries
        // created while converting the inputs are drained before we return.
        let provider = objc2::rc::autoreleasepool(|_pool| {
            let mut keys: Vec<Retained<NSString>> = Vec::with_capacity(inputs.len());
            let mut vals: Vec<Retained<MLFeatureValue>> = Vec::with_capacity(inputs.len());

            for &(name, tensor) in inputs {
                keys.push(crate::ffi::str_to_nsstring(name));
                // SAFETY: `as_ml_multi_array` yields a live MLMultiArray for this call.
                // NOTE(review): if the array wraps caller-owned memory without copying,
                // that memory must outlive the asynchronous prediction; confirm.
                vals.push(unsafe {
                    MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array())
                });
            }

            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
            let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();

            let dict: Retained<NSDictionary<NSString, MLFeatureValue>> =
                NSDictionary::from_slices(&key_refs, &val_refs);

            // SAFETY: objc2's NSDictionary generic parameters are type-level only, so a
            // dictionary of MLFeatureValue may be viewed as one of AnyObject (every
            // MLFeatureValue is an AnyObject). The borrow does not outlive `dict`.
            let dict_any: &NSDictionary<NSString, AnyObject> = unsafe {
                &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
                    as *const NSDictionary<NSString, AnyObject>)
            };

            // SAFETY: freshly allocated receiver and a valid dictionary; errors are
            // surfaced as `Err`.
            unsafe {
                MLDictionaryFeatureProvider::initWithDictionary_error(
                    MLDictionaryFeatureProvider::alloc(),
                    dict_any,
                )
            }
            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))
        })?;

        let provider_ref: &ProtocolObject<dyn MLFeatureProvider> =
            ProtocolObject::from_ref(&*provider);

        let (sender, future) = async_bridge::completion_channel();
        // The block must be `Fn`, but the sender is consumed on the first call.
        let pending = std::cell::Cell::new(Some(sender));

        let block = block2::RcBlock::new(
            move |result_ptr: *mut ProtocolObject<dyn MLFeatureProvider>,
                  error_ptr: *mut NSError| {
                // Ignore a contract-violating second call instead of panicking
                // inside a callback driven by Objective-C code.
                let Some(sender) = pending.take() else {
                    return;
                };
                // SAFETY: `result_ptr` is null or a valid feature provider; `retain`
                // takes our own +1 reference and maps null to `None`.
                let result = match unsafe { Retained::retain(result_ptr) } {
                    Some(inner) => Ok(Prediction { inner }),
                    None => Err(error_from_handler(
                        ErrorKind::Prediction,
                        error_ptr,
                        "async prediction returned null with no error",
                    )),
                };
                sender.send(result);
            },
        );

        // SAFETY: `provider_ref` and `block` are valid for the duration of the call.
        // Core ML copies the completion block, and retains anything it needs past
        // this call.
        unsafe {
            self.inner
                .predictionFromFeatures_completionHandler(provider_ref, &block);
        }

        Ok(future)
    }
}
264
265#[cfg(not(target_vendor = "apple"))]
267impl crate::Model {
268 pub fn load_async(
270 _path: impl AsRef<std::path::Path>,
271 _compute_units: ComputeUnits,
272 ) -> Result<CompletionFuture<Self>> {
273 Err(Error::new(
274 ErrorKind::UnsupportedPlatform,
275 "CoreML requires Apple platform",
276 ))
277 }
278
279 pub fn load_from_bytes(
281 _data: &[u8],
282 _compute_units: ComputeUnits,
283 ) -> Result<CompletionFuture<Self>> {
284 Err(Error::new(
285 ErrorKind::UnsupportedPlatform,
286 "CoreML requires Apple platform",
287 ))
288 }
289
290 pub fn predict_async(
292 &self,
293 _inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
294 ) -> Result<CompletionFuture<Prediction>> {
295 Err(Error::new(
296 ErrorKind::UnsupportedPlatform,
297 "CoreML requires Apple platform",
298 ))
299 }
300}
301
#[cfg(test)]
mod tests {
    #[cfg(not(target_vendor = "apple"))]
    use crate::{ComputeUnits, ErrorKind, Model};

    /// Asserts that `result` is an error of kind [`ErrorKind::UnsupportedPlatform`].
    #[cfg(not(target_vendor = "apple"))]
    #[track_caller]
    fn expect_unsupported<T: std::fmt::Debug>(result: crate::error::Result<T>) {
        let err = result.unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn load_async_fails_on_non_apple() {
        expect_unsupported(Model::load_async("/tmp/fake.mlmodelc", ComputeUnits::All));
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn load_from_bytes_fails_on_non_apple() {
        let bytes = [0u8; 10];
        expect_unsupported(Model::load_from_bytes(&bytes, ComputeUnits::All));
    }
}