1use crate::error::{Error, ErrorKind, Result};
8use crate::tensor::AsMultiArray;
9
#[cfg(target_vendor = "apple")]
/// A batch of named-tensor input sets packaged for CoreML (Apple targets only).
///
/// Wraps an `MLArrayBatchProvider`; build one with `BatchProvider::new`, which turns each
/// input set into one feature provider inside the batch.
pub struct BatchProvider {
    /// Strong reference to the underlying CoreML batch provider.
    pub(crate) inner: objc2::rc::Retained<objc2_core_ml::MLArrayBatchProvider>,
}
#[cfg(target_vendor = "apple")]
impl BatchProvider {
    /// Builds a CoreML batch from a list of input sets.
    ///
    /// Every element of `inputs` is a slice of `(name, tensor)` pairs. Each slice becomes one
    /// `MLDictionaryFeatureProvider` whose keys are the names and whose values wrap the
    /// tensors as `MLFeatureValue`s. The providers are then placed, in the same order, into
    /// an `MLArrayBatchProvider`.
    ///
    /// # Errors
    /// Returns an [`ErrorKind::Prediction`] error if CoreML refuses to build the dictionary
    /// feature provider for any input set. Processing stops at the first failure.
    pub fn new(inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
        use objc2::AnyThread;
        use objc2::rc::Retained;
        use objc2::runtime::{AnyObject, ProtocolObject};
        use objc2_core_ml::{
            MLArrayBatchProvider, MLDictionaryFeatureProvider, MLFeatureProvider, MLFeatureValue,
        };
        use objc2_foundation::{NSArray, NSDictionary, NSString};

        let feature_sets = inputs
            .iter()
            .map(|input_set| -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>> {
                // Keys and values are produced pairwise so position i in each vector matches.
                let (names, features): (Vec<Retained<NSString>>, Vec<Retained<MLFeatureValue>>) =
                    input_set
                        .iter()
                        .map(|&(name, tensor)| {
                            let key = crate::ffi::str_to_nsstring(name);
                            // SAFETY: `as_ml_multi_array` is assumed (per the `AsMultiArray`
                            // contract — confirm) to return a live MLMultiArray for the
                            // duration of this call.
                            let value = unsafe {
                                MLFeatureValue::featureValueWithMultiArray(
                                    tensor.as_ml_multi_array(),
                                )
                            };
                            (key, value)
                        })
                        .unzip();

                let name_refs: Vec<&NSString> = names.iter().map(|s| &**s).collect();
                let feature_refs: Vec<&MLFeatureValue> = features.iter().map(|v| &**v).collect();
                let typed_dict: Retained<NSDictionary<NSString, MLFeatureValue>> =
                    NSDictionary::from_slices(&name_refs, &feature_refs);

                // objc2's NSDictionary generics are compile-time only, so the same object can
                // be viewed with an erased value type, which is what the initializer expects.
                let erased_ptr = (&*typed_dict as *const NSDictionary<NSString, MLFeatureValue>)
                    .cast::<NSDictionary<NSString, AnyObject>>();
                // SAFETY: `erased_ptr` points at `typed_dict`, which outlives this borrow; only
                // the phantom type parameter differs.
                let erased_dict: &NSDictionary<NSString, AnyObject> = unsafe { &*erased_ptr };

                // SAFETY: a freshly allocated provider is initialized with a valid dictionary.
                let dict_provider = unsafe {
                    MLDictionaryFeatureProvider::initWithDictionary_error(
                        MLDictionaryFeatureProvider::alloc(),
                        erased_dict,
                    )
                }
                .map_err(|e| Error::from_nserror(ErrorKind::Prediction, &e))?;

                Ok(ProtocolObject::from_retained(dict_provider))
            })
            .collect::<Result<Vec<_>>>()?;

        let set_refs: Vec<&ProtocolObject<dyn MLFeatureProvider>> =
            feature_sets.iter().map(|p| &**p).collect();
        let provider_array = NSArray::from_slice(&set_refs);

        // SAFETY: a freshly allocated batch provider is initialized from a valid array.
        let inner = unsafe {
            MLArrayBatchProvider::initWithFeatureProviderArray(
                MLArrayBatchProvider::alloc(),
                &provider_array,
            )
        };

        Ok(Self { inner })
    }

    /// Number of input sets held by this batch.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        // SAFETY: reads the `count` property of a live, retained provider.
        let raw = unsafe { self.inner.count() };
        raw as usize
    }
}
104
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchProvider {
    /// Renders as `BatchProvider { count: N }`; the CoreML object itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let batch_len = self.count();
        let mut builder = f.debug_struct("BatchProvider");
        builder.field("count", &batch_len);
        builder.finish()
    }
}
113
#[cfg(target_vendor = "apple")]
// SAFETY: `BatchProvider` is not `Sync`, and the methods visible in this module only touch
// `inner` through `&self` reads (`count`). Moving the value therefore moves access to another
// thread rather than sharing it.
// NOTE(review): this assumes `MLArrayBatchProvider` may be used from a thread other than the
// one that created it, and that no crate-internal code keeps a second `Retained` clone of
// `inner` on the original thread — confirm both against Apple's CoreML thread-safety guidance.
unsafe impl Send for BatchProvider {}
117
#[cfg(not(target_vendor = "apple"))]
/// Non-Apple stand-in for the CoreML batch provider.
///
/// The private unit field prevents construction outside this module. On these targets
/// `BatchProvider::new` always returns an `UnsupportedPlatform` error.
#[derive(Debug)]
pub struct BatchProvider {
    _private: (),
}
125
126#[cfg(not(target_vendor = "apple"))]
127impl BatchProvider {
128 pub fn new(_inputs: &[&[(&str, &dyn AsMultiArray)]]) -> Result<Self> {
129 Err(Error::new(
130 ErrorKind::UnsupportedPlatform,
131 "CoreML requires Apple platform",
132 ))
133 }
134
135 pub fn count(&self) -> usize {
136 0
137 }
138}
139
#[cfg(target_vendor = "apple")]
/// Per-entry results of a batched CoreML call (Apple targets only).
///
/// Holds a CoreML batch provider with one feature provider per batch entry, each exposing
/// named outputs. NOTE(review): there is no constructor in this file; presumably a model's
/// batch-predict method elsewhere in the crate fills `inner`. Confirm.
pub struct BatchPrediction {
    /// Strong reference to the CoreML batch provider holding the per-entry outputs.
    pub(crate) inner:
        objc2::rc::Retained<objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLBatchProvider>>,
}
150
#[cfg(target_vendor = "apple")]
impl BatchPrediction {
    /// Number of batch entries (result sets) held by this prediction.
    pub fn count(&self) -> usize {
        use objc2_core_ml::MLBatchProvider;
        // SAFETY: reads the `count` property of a live, retained provider.
        (unsafe { self.inner.count() }) as usize
    }

    /// Fails with a `Prediction` error unless `index` addresses an existing batch entry.
    fn check_index(&self, index: usize) -> Result<()> {
        let len = self.count();
        if index < len {
            Ok(())
        } else {
            Err(Error::new(
                ErrorKind::Prediction,
                format!("batch index {index} out of range (count: {len})"),
            ))
        }
    }

    /// Copies output `output_name` of batch entry `index` into a `Vec<f32>`.
    ///
    /// Returns the values in row-major order together with the array's shape. Non-`f32`
    /// element types are widened or narrowed with `as f32`; half floats go through
    /// `crate::f16_to_f32`. The multi-array's strides are honoured, so padded or otherwise
    /// non-contiguous CoreML outputs are read element by element instead of as a flat
    /// prefix of the backing buffer.
    ///
    /// # Errors
    /// Returns an [`ErrorKind::Prediction`] error in any of these cases:
    /// - `index` is out of range;
    /// - the output is missing or is not a multi-array;
    /// - its stride metadata does not match its rank;
    /// - its data type is not recognised.
    // `dataPointer` is deprecated by Apple but remains the direct way to reach the buffer here.
    #[allow(deprecated)]
    pub fn get_f32(&self, index: usize, output_name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
        use crate::tensor::DataType;
        use objc2_core_ml::{MLBatchProvider, MLFeatureProvider};

        self.check_index(index)?;

        objc2::rc::autoreleasepool(|_pool| {
            // SAFETY: `index` was bounds-checked against `count()` above.
            let provider = unsafe { self.inner.featuresAtIndex(index as isize) };

            let ns_name = crate::ffi::str_to_nsstring(output_name);
            let feature_val =
                unsafe { provider.featureValueForName(&ns_name) }.ok_or_else(|| {
                    Error::new(
                        ErrorKind::Prediction,
                        format!("output '{output_name}' not found at batch index {index}"),
                    )
                })?;

            let array = unsafe { feature_val.multiArrayValue() }.ok_or_else(|| {
                Error::new(
                    ErrorKind::Prediction,
                    format!(
                        "output '{output_name}' is not a multi-array at batch index {index}"
                    ),
                )
            })?;

            let shape = crate::ffi::nsarray_to_shape(unsafe { &array.shape() });
            // Strides are an NSArray<NSNumber> just like the shape, in elements, not bytes.
            let strides = crate::ffi::nsarray_to_shape(unsafe { &array.strides() });
            if strides.len() != shape.len() {
                return Err(Error::new(
                    ErrorKind::Prediction,
                    format!(
                        "output '{output_name}' has {} strides for {} dimensions at batch index {index}",
                        strides.len(),
                        shape.len()
                    ),
                ));
            }

            let len = crate::tensor::element_count(&shape);
            let data_type = crate::ffi::ml_to_datatype(unsafe { array.dataType() }.0);

            let mut buf = vec![0.0f32; len];
            // SAFETY: `dataPointer` addresses the backing store of `array`, which stays alive
            // for this whole closure. Every offset read is derived from CoreML's own
            // shape/strides for that buffer. The pointee type matches the reported `dataType`.
            unsafe {
                let base = array.dataPointer().as_ptr();
                match data_type {
                    Some(DataType::Float32) if Self::is_row_major_contiguous(&shape, &strides) => {
                        std::ptr::copy_nonoverlapping(base as *const f32, buf.as_mut_ptr(), len);
                    }
                    Some(DataType::Float32) => {
                        Self::gather(base as *const f32, &shape, &strides, &mut buf, |v: f32| v)
                    }
                    Some(DataType::Float16) => Self::gather(
                        base as *const u16,
                        &shape,
                        &strides,
                        &mut buf,
                        crate::f16_to_f32,
                    ),
                    Some(DataType::Float64) => Self::gather(
                        base as *const f64,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: f64| v as f32,
                    ),
                    Some(DataType::Int32) => Self::gather(
                        base as *const i32,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i32| v as f32,
                    ),
                    Some(DataType::Int16) => Self::gather(
                        base as *const i16,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i16| v as f32,
                    ),
                    Some(DataType::Int8) => Self::gather(
                        base as *const i8,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: i8| v as f32,
                    ),
                    Some(DataType::UInt32) => Self::gather(
                        base as *const u32,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u32| v as f32,
                    ),
                    Some(DataType::UInt16) => Self::gather(
                        base as *const u16,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u16| v as f32,
                    ),
                    Some(DataType::UInt8) => Self::gather(
                        base as *const u8,
                        &shape,
                        &strides,
                        &mut buf,
                        |v: u8| v as f32,
                    ),
                    None => {
                        return Err(Error::new(
                            ErrorKind::Prediction,
                            "unsupported output data type",
                        ));
                    }
                }
            }

            Ok((buf, shape))
        })
    }

    /// Returns `true` when `strides` describe a dense row-major layout of `shape`.
    ///
    /// The stride of a dimension of extent 0 or 1 does not matter, because it never moves
    /// the read position.
    fn is_row_major_contiguous(shape: &[usize], strides: &[usize]) -> bool {
        let mut expected = 1usize;
        for (&dim, &stride) in shape.iter().zip(strides).rev() {
            if dim > 1 && stride != expected {
                return false;
            }
            expected *= dim;
        }
        true
    }

    /// Maps the row-major logical index `flat` to an element offset under `strides`.
    fn strided_offset(mut flat: usize, shape: &[usize], strides: &[usize]) -> usize {
        let mut offset = 0usize;
        for (&dim, &stride) in shape.iter().zip(strides).rev() {
            offset += (flat % dim) * stride;
            flat /= dim;
        }
        offset
    }

    /// Fills `out` (row-major order) from a possibly strided buffer, converting each element.
    ///
    /// # Safety
    /// `base` must be valid and aligned for reads of `T` at every element offset reachable
    /// from `shape`/`strides` for the first `out.len()` logical indices.
    unsafe fn gather<T: Copy>(
        base: *const T,
        shape: &[usize],
        strides: &[usize],
        out: &mut [f32],
        convert: impl Fn(T) -> f32,
    ) {
        let contiguous = Self::is_row_major_contiguous(shape, strides);
        for (flat, slot) in out.iter_mut().enumerate() {
            let offset = if contiguous {
                flat
            } else {
                Self::strided_offset(flat, shape, strides)
            };
            // SAFETY: guaranteed by this function's caller contract.
            *slot = convert(unsafe { *base.add(offset) });
        }
    }

    /// Returns the raw CoreML feature provider for batch entry `index`.
    ///
    /// # Errors
    /// Returns an [`ErrorKind::Prediction`] error when `index` is out of range.
    pub fn feature_provider(
        &self,
        index: usize,
    ) -> Result<
        objc2::rc::Retained<
            objc2::runtime::ProtocolObject<dyn objc2_core_ml::MLFeatureProvider>,
        >,
    > {
        use objc2_core_ml::MLBatchProvider;

        self.check_index(index)?;
        // SAFETY: `index` was bounds-checked against `count()` above.
        Ok(unsafe { self.inner.featuresAtIndex(index as isize) })
    }
}
298
#[cfg(target_vendor = "apple")]
impl std::fmt::Debug for BatchPrediction {
    /// Renders as `BatchPrediction { count: N }`; individual outputs are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.count();
        let mut builder = f.debug_struct("BatchPrediction");
        builder.field("count", &entries);
        builder.finish()
    }
}
307
#[cfg(target_vendor = "apple")]
// SAFETY: `BatchPrediction` is not `Sync`, and its methods reach `inner` only through
// `&self` (`count`, `featuresAtIndex`). Sending it moves access to another thread rather
// than sharing it.
// NOTE(review): `feature_provider` hands out `Retained` clones of per-entry providers, which
// may then live on other threads. This impl assumes the CoreML provider objects tolerate
// cross-thread use. Confirm against Apple's CoreML thread-safety guidance.
unsafe impl Send for BatchPrediction {}
311
#[cfg(not(target_vendor = "apple"))]
/// Non-Apple stand-in for batched prediction results.
///
/// The private unit field prevents construction outside this module. On these targets
/// `count` is always 0 and `get_f32` always returns an `UnsupportedPlatform` error.
#[derive(Debug)]
pub struct BatchPrediction {
    _private: (),
}
319
320#[cfg(not(target_vendor = "apple"))]
321impl BatchPrediction {
322 pub fn count(&self) -> usize {
323 0
324 }
325
326 pub fn get_f32(
327 &self,
328 _index: usize,
329 _output_name: &str,
330 ) -> Result<(Vec<f32>, Vec<usize>)> {
331 Err(Error::new(
332 ErrorKind::UnsupportedPlatform,
333 "CoreML requires Apple platform",
334 ))
335 }
336}
337
#[cfg(test)]
mod tests {
    // Only the non-Apple tests below use these names; gating the import avoids an
    // unused-import warning on Apple targets.
    #[cfg(not(target_vendor = "apple"))]
    use super::*;

    /// Off Apple platforms, constructing a batch must report the platform as unsupported.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_provider_fails_on_non_apple() {
        let no_inputs: &[&[(&str, &dyn AsMultiArray)]] = &[];
        let error = BatchProvider::new(no_inputs)
            .expect_err("BatchProvider::new must fail off Apple platforms");
        assert_eq!(error.kind(), &ErrorKind::UnsupportedPlatform);
    }

    /// The stub prediction is empty and refuses output reads with `UnsupportedPlatform`.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn batch_prediction_fails_on_non_apple() {
        let stub = BatchPrediction { _private: () };
        assert_eq!(stub.count(), 0);

        let error = stub
            .get_f32(0, "output")
            .expect_err("get_f32 must fail off Apple platforms");
        assert_eq!(error.kind(), &ErrorKind::UnsupportedPlatform);
    }
}