//! Batch prediction support via MLArrayBatchProvider.
//!
//! More efficient than calling `predict()` in a loop when you have
//! multiple inputs to process. CoreML can optimize execution across
//! the entire batch.

use crate::error::{Error, ErrorKind, Result};
use crate::tensor::AsMultiArray;

// ─── BatchProvider ──────────────────────────────────────────────────────────

/// A batch of input feature sets for bulk prediction.
///
/// Wraps `MLArrayBatchProvider` to collect multiple input sets that can
/// be submitted to `Model::predict_batch` in a single call.
///
/// # Example
///
/// ```ignore
/// let inputs_0: &[(&str, &dyn AsMultiArray)] = &[("x", &tensor_a)];
/// let inputs_1: &[(&str, &dyn AsMultiArray)] = &[("x", &tensor_b)];
/// let batch = BatchProvider::new(&[&inputs_0[..], &inputs_1[..]])?;
/// assert_eq!(batch.count(), 2);
/// ```
#[cfg(target_vendor = "apple")]
pub struct BatchProvider {
    // Retained Objective-C batch provider; handed to CoreML at predict time.
    pub(crate) inner: objc2::rc::Retained<objc2_core_ml::MLArrayBatchProvider>,
}

#[cfg(target_vendor = "apple")]
impl BatchProvider {
    /// Create a batch provider from a slice of named input sets.
    ///
    /// Each element in `inputs` is a complete set of named tensors for one prediction.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Prediction` if CoreML rejects a feature dictionary
    /// while building an `MLDictionaryFeatureProvider`.
    pub fn new(inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
        use objc2::AnyThread;
        use objc2::runtime::ProtocolObject;
        use objc2_core_ml::{
            MLArrayBatchProvider, MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue,
        };
        use objc2_foundation::{NSDictionary, NSString};

        // One feature provider per input set, preserving the caller's order.
        let mut providers: Vec<
            objc2::rc::Retained<ProtocolObject<dyn MLFeatureProvider>>,
        > = Vec::with_capacity(inputs.len());

        for input_set in inputs {
            // Parallel key/value vectors keep the retained objects alive until
            // the NSDictionary below takes its own references.
            let mut keys: Vec<objc2::rc::Retained<NSString>> =
                Vec::with_capacity(input_set.len());
            let mut vals: Vec<objc2::rc::Retained<MLFeatureValue>> =
                Vec::with_capacity(input_set.len());

            for &(name, tensor) in *input_set {
                keys.push(crate::ffi::str_to_nsstring(name));
                vals.push(unsafe {
                    MLFeatureValue::featureValueWithMultiArray(tensor.as_ml_multi_array())
                });
            }

            // Borrowed views with the reference shape `from_slices` expects.
            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
            let val_refs: Vec<&MLFeatureValue> = vals.iter().map(|v| &**v).collect();

            let dict: objc2::rc::Retained<NSDictionary<NSString, MLFeatureValue>> =
                NSDictionary::from_slices(&key_refs, &val_refs);

            // Safety: MLFeatureValue is an NSObject subclass, so the pointer cast is valid.
            let dict_any: &NSDictionary<NSString, objc2::runtime::AnyObject> = unsafe {
                &*((&*dict) as *const NSDictionary<NSString, MLFeatureValue>
                    as *const NSDictionary<NSString, objc2::runtime::AnyObject>)
            };

            let provider = unsafe {
                MLDictionaryFeatureProvider::initWithDictionary_error(
                    MLDictionaryFeatureProvider::alloc(),
                    dict_any,
                )
            }
            .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;

            // Erase the concrete class to the MLFeatureProvider protocol object.
            let proto = ProtocolObject::from_retained(provider);
            providers.push(proto);
        }

        let provider_refs: Vec<&ProtocolObject<dyn MLFeatureProvider>> =
            providers.iter().map(|p| &**p).collect();
        let array = objc2_foundation::NSArray::from_slice(&provider_refs);

        let batch = unsafe {
            MLArrayBatchProvider::initWithFeatureProviderArray(
                MLArrayBatchProvider::alloc(),
                &array,
            )
        };

        Ok(Self { inner: batch })
    }

    /// Returns the number of input sets in this batch.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        // NSInteger -> usize cast; the ObjC `count` is non-negative in practice.
        (unsafe { self.inner.count() }) as usize
    }
}

#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchProvider {
    /// Shows only the batch size; the underlying ObjC object is opaque.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BatchProvider");
        builder.field("count", &self.count());
        builder.finish()
    }
}

// Safety: MLArrayBatchProvider holds immutable data after construction.
// NOTE(review): this also assumes the provider's reference counting and any
// reachable internal state are thread-safe — confirm no interior mutability
// is exposed through `inner` before relying on Send across threads.
#[cfg(target_vendor = "apple")]
unsafe impl Send for BatchProvider {}

// ─── Non-Apple stub ─────────────────────────────────────────────────────────

/// Stub `BatchProvider` for non-Apple targets; construction always fails.
#[cfg(not(target_vendor = "apple"))]
#[derive(Debug)]
pub struct BatchProvider {
    // Zero-sized marker that keeps the struct non-constructible outside the crate.
    _private: (),
}

126#[cfg(not(target_vendor = "apple"))]
127impl BatchProvider {
128    pub fn new(_inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
129        Err(Error::new(
130            ErrorKind::UnsupportedPlatform,
131            "CoreML requires Apple platform",
132        ))
133    }
134
135    pub fn count(&self) -> usize {
136        0
137    }
138}
139
// ─── BatchPrediction ────────────────────────────────────────────────────────

/// Result of a batch prediction, wrapping an `MLBatchProvider`.
///
/// Provides indexed access to individual prediction results.
#[cfg(target_vendor = "apple")]
pub struct BatchPrediction {
    // Retained protocol object produced by CoreML's batch predict call.
    pub(crate) inner:
        objc2::rc::Retained<objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLBatchProvider>>,
}

#[cfg(target_vendor = "apple")]
impl BatchPrediction {
    /// Number of results in the batch.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        // NSInteger -> usize cast; the ObjC `count` is non-negative in practice.
        (unsafe { self.inner.count() }) as usize
    }

    /// Get an output tensor as `(Vec<f32>, shape)` from the result at `index`.
    ///
    /// This is the batch equivalent of `Prediction::get_f32`. Half/double
    /// floats and all supported integer widths are converted to `f32`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Prediction` when `index` is out of range, when the
    /// named output does not exist, is not a multi-array, or has an
    /// unsupported data type.
    #[allow(deprecated)] // MLMultiArray::dataPointer is deprecated but remains the accessor used here.
    pub fn get_f32(&self, index: usize, output_name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        use objc2_core_ml::{MLBatchProvider, MLFeatureProvider};

        /// Read `count` elements of `T` starting at `ptr`, converting each to `f32`.
        ///
        /// # Safety
        ///
        /// `ptr` must be non-null, aligned for `T`, and valid for reads of
        /// `count` contiguous elements.
        unsafe fn widen<T: Copy>(
            ptr: *const T,
            count: usize,
            convert: impl Fn(T) -> f32,
        ) -> Vec<f32> {
            std::slice::from_raw_parts(ptr, count)
                .iter()
                .map(|&x| convert(x))
                .collect()
        }

        let total = self.count();
        if index >= total {
            return Err(Error::new(
                ErrorKind::Prediction,
                format!("batch index {index} out of range (count: {total})"),
            ));
        }

        objc2::rc::autoreleasepool(|_pool| {
            let provider = unsafe { self.inner.featuresAtIndex(index as isize) };

            let ns_name = crate::ffi::str_to_nsstring(output_name);
            let feature_val =
                unsafe { provider.featureValueForName(&ns_name) }.ok_or_else(|| {
                    Error::new(
                        ErrorKind::Prediction,
                        format!("output '{output_name}' not found at batch index {index}"),
                    )
                })?;

            let array = unsafe { feature_val.multiArrayValue() }.ok_or_else(|| {
                Error::new(
                    ErrorKind::Prediction,
                    format!(
                        "output '{output_name}' is not a multi-array at batch index {index}"
                    ),
                )
            })?;

            let shape = crate::ffi::nsarray_to_shape(unsafe { &array.shape() });
            let count = crate::tensor::element_count(&shape);
            let dt_raw = unsafe { array.dataType() };
            let data_type = crate::ffi::ml_to_datatype(dt_raw.0);

            // NOTE(review): reading via dataPointer assumes the multi-array's
            // storage is densely packed (strides match the shape) — confirm for
            // models that can emit non-contiguous outputs.
            let ptr = unsafe { array.dataPointer() };
            use crate::tensor::DataType;
            let buf = match data_type {
                // Safety (all arms): `count` is derived from the array's own
                // shape and `ptr` points at its backing storage.
                Some(DataType::Float32) => unsafe {
                    widen(ptr.as_ptr() as *const f32, count, |x| x)
                },
                Some(DataType::Float16) => unsafe {
                    widen(ptr.as_ptr() as *const u16, count, crate::f16_to_f32)
                },
                Some(DataType::Float64) => unsafe {
                    widen(ptr.as_ptr() as *const f64, count, |x| x as f32)
                },
                Some(DataType::Int32) => unsafe {
                    widen(ptr.as_ptr() as *const i32, count, |x| x as f32)
                },
                Some(DataType::Int16) => unsafe {
                    widen(ptr.as_ptr() as *const i16, count, |x| x as f32)
                },
                Some(DataType::Int8) => unsafe {
                    widen(ptr.as_ptr() as *const i8, count, |x| x as f32)
                },
                Some(DataType::UInt32) => unsafe {
                    widen(ptr.as_ptr() as *const u32, count, |x| x as f32)
                },
                Some(DataType::UInt16) => unsafe {
                    widen(ptr.as_ptr() as *const u16, count, |x| x as f32)
                },
                Some(DataType::UInt8) => unsafe {
                    widen(ptr.as_ptr() as *const u8, count, |x| x as f32)
                },
                None => {
                    return Err(Error::new(
                        ErrorKind::Prediction,
                        "unsupported output data type",
                    ));
                }
            };

            Ok((buf, shape))
        })
    }

    /// Get the feature provider at the given index (for advanced use).
    ///
    /// Returns a retained protocol object that can be queried for any output features.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::Prediction` when `index >= self.count()`.
    pub fn feature_provider(
        &self,
        index: usize,
    ) -> Result<
        objc2::rc::Retained<
            objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLFeatureProvider>,
        >,
    > {
        use objc2_core_ml::MLBatchProvider;

        let total = self.count();
        if index >= total {
            return Err(Error::new(
                ErrorKind::Prediction,
                format!("batch index {index} out of range (count: {total})"),
            ));
        }

        Ok(unsafe { self.inner.featuresAtIndex(index as isize) })
    }
}

#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchPrediction {
    /// Debug output exposes just the result count; the ObjC provider is opaque.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("BatchPrediction");
        builder.field("count", &self.count());
        builder.finish()
    }
}

// Safety: the retained MLBatchProvider is reference-counted and read-only after creation.
// NOTE(review): Send soundness also rests on the concrete provider class having
// no thread-affine interior mutability — confirm before sending across threads.
#[cfg(target_vendor = "apple")]
unsafe impl Send for BatchPrediction {}

// ─── Non-Apple stub ─────────────────────────────────────────────────────────

/// Stub `BatchPrediction` for non-Apple targets.
#[cfg(not(target_vendor = "apple"))]
#[derive(Debug)]
pub struct BatchPrediction {
    // Zero-sized marker; instances are only created inside this crate (tests).
    _private: (),
}

320#[cfg(not(target_vendor = "apple"))]
321impl BatchPrediction {
322    pub fn count(&self) -> usize {
323        0
324    }
325
326    pub fn get_f32(
327        &self,
328        _index: usize,
329        _output_name: &str,
330    ) -> Result<(Vec<f32>, Vec<usize>)> {
331        Err(Error::new(
332            ErrorKind::UnsupportedPlatform,
333            "CoreML requires Apple platform",
334        ))
335    }
336}
337
#[cfg(test)]
mod tests {
    #[cfg(not(target_vendor = "apple"))]
    use super::*;

    /// On non-Apple targets, constructing a batch provider must report the
    /// platform as unsupported.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_provider_fails_on_non_apple() {
        let empty: &[&[(&str, &dyn AsMultiArray)]] = &[];
        let failure = BatchProvider::new(empty).unwrap_err();
        assert_eq!(failure.kind(), &ErrorKind::UnsupportedPlatform);
    }

    /// The stub prediction reports zero results and errors on any access.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_prediction_fails_on_non_apple() {
        let stub = BatchPrediction { _private: () };
        assert_eq!(stub.count(), 0);
        let failure = stub.get_f32(0, "output").unwrap_err();
        assert_eq!(failure.kind(), &ErrorKind::UnsupportedPlatform);
    }
}