// coreml_native/model_async.rs
1//! Async model loading and prediction.
2//!
3//! Extends [`Model`](crate::Model) with async variants of load and predict
4//! that use Apple's completion-handler-based APIs under the hood.
5
6use crate::async_bridge::{self, CompletionFuture};
7use crate::error::{Error, ErrorKind, Result};
8use crate::{ComputeUnits, Prediction};
9
#[cfg(target_vendor = "apple")]
impl crate::Model {
    /// Translate this crate's [`ComputeUnits`] into CoreML's `MLComputeUnits`.
    ///
    /// Always uses the named constants. The previous code mapped `CpuOnly` to
    /// `MLComputeUnits(1)` and `CpuAndNeuralEngine` to `MLComputeUnits(2)`,
    /// but CoreML's raw discriminants are CPUOnly = 0, CPUAndGPU = 1,
    /// All = 2, CPUAndNeuralEngine = 3 — so those raw values silently
    /// selected CPUAndGPU and All instead.
    fn to_ml_compute_units(compute_units: ComputeUnits) -> objc2_core_ml::MLComputeUnits {
        use objc2_core_ml::MLComputeUnits;
        match compute_units {
            ComputeUnits::CpuOnly => MLComputeUnits::CPUOnly,
            ComputeUnits::CpuAndGpu => MLComputeUnits::CPUAndGPU,
            ComputeUnits::CpuAndNeuralEngine => MLComputeUnits::CPUAndNeuralEngine,
            ComputeUnits::All => MLComputeUnits::All,
        }
    }

    /// Load a compiled model asynchronously.
    ///
    /// Returns a `CompletionFuture` that resolves when loading completes.
    /// Use `.await` in async contexts or `.block_on()` for synchronous callers.
    ///
    /// # Errors
    ///
    /// Fails immediately if `path` contains non-UTF-8 characters; load
    /// failures reported by CoreML are delivered through the returned future.
    ///
    /// Requires macOS 12+ / iOS 15+.
    pub fn load_async(
        path: impl AsRef<std::path::Path>,
        compute_units: ComputeUnits,
    ) -> Result<CompletionFuture<Self>> {
        use objc2_core_ml::{MLModel, MLModelConfiguration};

        let path = path.as_ref();
        let path_str = path.to_str().ok_or_else(|| {
            Error::new(ErrorKind::ModelLoad, "path contains non-UTF8 characters")
        })?;

        let url =
            objc2_foundation::NSURL::fileURLWithPath(&crate::ffi::str_to_nsstring(path_str));
        // SAFETY: freshly created, uniquely owned configuration object.
        let config = unsafe { MLModelConfiguration::new() };
        unsafe { config.setComputeUnits(Self::to_ml_compute_units(compute_units)) };

        let (sender, future) = async_bridge::completion_channel();
        // Cell<Option<_>> lets the block move the one-shot sender out on the
        // (single) invocation of the completion handler.
        let sender_cell = std::cell::Cell::new(Some(sender));
        let owned_path = path.to_path_buf();

        let block = block2::RcBlock::new(
            move |model_ptr: *mut MLModel, error_ptr: *mut objc2_foundation::NSError| {
                let sender = sender_cell
                    .take()
                    .expect("completion handler called more than once");
                if model_ptr.is_null() {
                    if error_ptr.is_null() {
                        // Defensive: CoreML should provide an error when the
                        // model is null, but don't trust that invariant.
                        sender.send(Err(Error::new(
                            ErrorKind::ModelLoad,
                            "model load returned null with no error",
                        )));
                    } else {
                        // SAFETY: error_ptr is non-null and valid for the
                        // duration of the completion handler call.
                        let err = unsafe { &*error_ptr };
                        sender.send(Err(Error::from_nserror(ErrorKind::ModelLoad, err)));
                    }
                } else {
                    // SAFETY: model_ptr is non-null. Use retain to get +1 refcount
                    // without consuming the handler's borrowed reference.
                    let retained = unsafe { objc2::rc::Retained::retain(model_ptr) };
                    match retained {
                        Some(inner) => {
                            sender.send(Ok(crate::Model {
                                inner,
                                path: owned_path.clone(),
                            }));
                        }
                        None => {
                            sender.send(Err(Error::new(
                                ErrorKind::ModelLoad,
                                "failed to retain MLModel pointer",
                            )));
                        }
                    }
                }
            },
        );

        // SAFETY: url, config, and block are valid for the call; CoreML copies
        // the block for the asynchronous completion.
        unsafe {
            MLModel::loadContentsOfURL_configuration_completionHandler(&url, &config, &block);
        }

        Ok(future)
    }

    /// Load a model from in-memory specification bytes asynchronously.
    ///
    /// Creates an `MLModelAsset` from the specification data (synchronously),
    /// then loads the model asynchronously via the completion handler API.
    ///
    /// The `data` parameter should contain the contents of a `.mlmodel` file
    /// (the protobuf specification, not a compiled `.mlmodelc`).
    ///
    /// # Errors
    ///
    /// Fails immediately if the asset cannot be created from `data`; load
    /// failures reported by CoreML are delivered through the returned future.
    ///
    /// Requires macOS 14.4+ / iOS 17.4+.
    pub fn load_from_bytes(
        data: &[u8],
        compute_units: ComputeUnits,
    ) -> Result<CompletionFuture<Self>> {
        use objc2_core_ml::{MLModel, MLModelAsset, MLModelConfiguration};
        use objc2_foundation::NSData;

        // Step 1: Create NSData from the byte slice (copies data).
        let ns_data = NSData::with_bytes(data);

        // Step 2: Create MLModelAsset synchronously.
        let asset =
            unsafe { MLModelAsset::modelAssetWithSpecificationData_error(&ns_data) }
                .map_err(|e| Error::from_nserror(ErrorKind::ModelLoad, &e))?;

        // Step 3: Configure compute units (named constants; see helper).
        let config = unsafe { MLModelConfiguration::new() };
        unsafe { config.setComputeUnits(Self::to_ml_compute_units(compute_units)) };

        // Step 4: Load asynchronously via completion handler.
        let (sender, future) = async_bridge::completion_channel();
        // One-shot sender moved out of the Cell on first (only) invocation.
        let sender_cell = std::cell::Cell::new(Some(sender));

        let block = block2::RcBlock::new(
            move |model_ptr: *mut MLModel, error_ptr: *mut objc2_foundation::NSError| {
                let sender = sender_cell
                    .take()
                    .expect("completion handler called more than once");
                if model_ptr.is_null() {
                    if error_ptr.is_null() {
                        sender.send(Err(Error::new(
                            ErrorKind::ModelLoad,
                            "model load from bytes returned null with no error",
                        )));
                    } else {
                        // SAFETY: error_ptr is non-null and valid for the
                        // duration of the completion handler call.
                        let err = unsafe { &*error_ptr };
                        sender.send(Err(Error::from_nserror(ErrorKind::ModelLoad, err)));
                    }
                } else {
                    // SAFETY: model_ptr is non-null; retain for +1 refcount.
                    let retained = unsafe { objc2::rc::Retained::retain(model_ptr) };
                    match retained {
                        Some(inner) => {
                            sender.send(Ok(crate::Model {
                                inner,
                                // No filesystem path exists for byte-loaded models.
                                path: std::path::PathBuf::from("<in-memory>"),
                            }));
                        }
                        None => {
                            sender.send(Err(Error::new(
                                ErrorKind::ModelLoad,
                                "failed to retain MLModel pointer",
                            )));
                        }
                    }
                }
            },
        );

        // SAFETY: asset, config, and block are valid for the call; CoreML
        // copies the block for the asynchronous completion.
        unsafe {
            MLModel::loadModelAsset_configuration_completionHandler(&asset, &config, &block);
        }

        Ok(future)
    }

    /// Run a prediction asynchronously.
    ///
    /// Builds the feature provider from the input tensors, then calls the
    /// async prediction API with a completion handler.
    ///
    /// # Errors
    ///
    /// Fails immediately if the feature provider cannot be built; prediction
    /// failures reported by CoreML are delivered through the returned future.
    ///
    /// Requires macOS 14+ / iOS 17+.
    pub fn predict_async(
        &self,
        inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
    ) -> Result<CompletionFuture<Prediction>> {
        use objc2::AnyThread;
        use objc2_core_ml::{MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue};
        use objc2_foundation::{NSDictionary, NSString};

        // Build the feature provider (same as sync predict).
        let provider = objc2::rc::autoreleasepool(|_pool| {
            let mut keys: Vec<objc2::rc::Retained<NSString>> =
                Vec::with_capacity(inputs.len());
            let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> =
                Vec::with_capacity(inputs.len());

            for &(name, tensor) in inputs {
                keys.push(crate::ffi::str_to_nsstring(name));
                vals.push(unsafe {
                    MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array())
                });
            }

            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
            let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();

            let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
                NSDictionary::from_slices(&key_refs, &val_refs);

            // SAFETY: MLFeatureValue is an AnyObject subclass; the dictionary
            // layout is identical, so this generic-parameter cast is sound.
            let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> = unsafe {
                &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
                    as *const NSDictionary<NSString, objc2::runtime::AnyObject>)
            };

            let provider = unsafe {
                MLDictionaryFeatureProvider::initWithDictionary_error(
                    MLDictionaryFeatureProvider::alloc(),
                    dict_any,
                )
            }
            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;

            Ok(provider)
        })?;

        let provider_ref: &objc2::runtime::ProtocolObject<dyn MLFeatureProvider> =
            objc2::runtime::ProtocolObject::from_ref(&*provider);

        let (sender, future) = async_bridge::completion_channel();
        // One-shot sender moved out of the Cell on first (only) invocation.
        let sender_cell = std::cell::Cell::new(Some(sender));

        let block = block2::RcBlock::new(
            move |result_ptr: *mut objc2::runtime::ProtocolObject<dyn MLFeatureProvider>,
                  error_ptr: *mut objc2_foundation::NSError| {
                let sender = sender_cell
                    .take()
                    .expect("completion handler called more than once");
                if result_ptr.is_null() {
                    if error_ptr.is_null() {
                        sender.send(Err(Error::new(
                            ErrorKind::Prediction,
                            "async prediction returned null with no error",
                        )));
                    } else {
                        // SAFETY: error_ptr is non-null and valid for the
                        // duration of the completion handler call.
                        let err = unsafe { &*error_ptr };
                        sender.send(Err(Error::from_nserror(ErrorKind::Prediction, err)));
                    }
                } else {
                    // SAFETY: result_ptr is non-null; retain for +1 refcount.
                    let retained = unsafe { objc2::rc::Retained::retain(result_ptr) };
                    match retained {
                        Some(inner) => {
                            sender.send(Ok(Prediction { inner }));
                        }
                        None => {
                            sender.send(Err(Error::new(
                                ErrorKind::Prediction,
                                "failed to retain prediction result pointer",
                            )));
                        }
                    }
                }
            },
        );

        // SAFETY: provider_ref and block are valid for the call; CoreML copies
        // the block for the asynchronous completion.
        unsafe {
            self.inner
                .predictionFromFeatures_completionHandler(provider_ref, &block);
        }

        Ok(future)
    }
}
264
265// Non-Apple stubs
266#[cfg(not(target_vendor = "apple"))]
267impl crate::Model {
268    /// Load a compiled model asynchronously (stub for non-Apple platforms).
269    pub fn load_async(
270        _path: impl AsRef<std::path::Path>,
271        _compute_units: ComputeUnits,
272    ) -> Result<CompletionFuture<Self>> {
273        Err(Error::new(
274            ErrorKind::UnsupportedPlatform,
275            "CoreML requires Apple platform",
276        ))
277    }
278
279    /// Load a model from in-memory bytes (stub for non-Apple platforms).
280    pub fn load_from_bytes(
281        _data: &[u8],
282        _compute_units: ComputeUnits,
283    ) -> Result<CompletionFuture<Self>> {
284        Err(Error::new(
285            ErrorKind::UnsupportedPlatform,
286            "CoreML requires Apple platform",
287        ))
288    }
289
290    /// Run a prediction asynchronously (stub for non-Apple platforms).
291    pub fn predict_async(
292        &self,
293        _inputs: &[(&str, &dyn crate::tensor::AsMultiArray)],
294    ) -> Result<CompletionFuture<Prediction>> {
295        Err(Error::new(
296            ErrorKind::UnsupportedPlatform,
297            "CoreML requires Apple platform",
298        ))
299    }
300}
301
#[cfg(test)]
mod tests {
    #[cfg(not(target_vendor = "apple"))]
    use crate::{ComputeUnits, ErrorKind, Model};

    /// On non-Apple targets, async loading must fail with UnsupportedPlatform.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn load_async_fails_on_non_apple() {
        let result = Model::load_async("/tmp/fake.mlmodelc", ComputeUnits::All);
        let err = result.expect_err("load_async should fail off-Apple");
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }

    /// Byte-based loading must fail the same way on non-Apple targets.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn load_from_bytes_fails_on_non_apple() {
        let result = Model::load_from_bytes(&[0u8; 10], ComputeUnits::All);
        let err = result.expect_err("load_from_bytes should fail off-Apple");
        assert_eq!(err.kind(), &ErrorKind::UnsupportedPlatform);
    }
}