use candle_core::{Device, Tensor, Error as CandleError};
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};

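/// Configuration for a [`CoreMLModel`]: the named input features the model
/// expects, the output feature to read back, and basic model metadata.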
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub input_names: Vec<String>,
    pub output_name: String,
    pub max_sequence_length: usize,
    pub vocab_size: usize,
    pub model_type: String,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            input_names: vec!["input_ids".to_string()],
            output_name: "logits".to_string(),
            max_sequence_length: 128,
            vocab_size: 32000,
            model_type: "coreml".to_string(),
        }
    }
}

impl Config {
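    /// Convenience constructor for BERT-style models, which take
    /// `input_ids`, `token_type_ids`, and `attention_mask` as inputs.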
    pub fn bert_config(output_name: &str, max_seq_len: usize, vocab_size: usize) -> Self {
        Self {
            input_names: vec![
                "input_ids".to_string(),
                "token_type_ids".to_string(),
                "attention_mask".to_string(),
            ],
            output_name: output_name.to_string(),
            max_sequence_length: max_seq_len,
            vocab_size,
            model_type: "bert".to_string(),
        }
    }
}

#[cfg(target_os = "macos")]
use objc2::rc::{autoreleasepool, Retained};
#[cfg(target_os = "macos")]
use objc2_core_ml::{MLModel, MLMultiArray, MLDictionaryFeatureProvider, MLFeatureProvider};
#[cfg(target_os = "macos")]
use objc2_foundation::{NSString, NSURL};
#[cfg(target_os = "macos")]
use objc2::runtime::ProtocolObject;
#[cfg(target_os = "macos")]
use objc2::AnyThread;
#[cfg(target_os = "macos")]
use block2::StackBlock;

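/// A CoreML model that accepts and returns Candle tensors.
///
/// On macOS this wraps a retained `MLModel`; on other platforms the type
/// still compiles, but loading and inference return errors.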
pub struct CoreMLModel {
    #[cfg(target_os = "macos")]
    inner: Retained<MLModel>,
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    config: Config,
}

impl std::fmt::Debug for CoreMLModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CoreMLModel")
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}

impl CoreMLModel {
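    /// Load a compiled CoreML model (`.mlmodelc`) using the default [`Config`].
    ///
    /// A minimal sketch of typical usage; the path and `input` tensor are
    /// placeholders:
    ///
    /// ```ignore
    /// let model = CoreMLModel::load("models/test.mlmodelc")?;
    /// let output = model.forward_single(&input)?;
    /// ```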
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
        let config = Config::default();
        Self::load_from_file(path, &config)
    }

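    /// Load a compiled CoreML model from `path` with an explicit [`Config`].
    ///
    /// CoreML loads *compiled* models (`.mlmodelc` directories); an
    /// `.mlpackage` must be compiled first, e.g. with `xcrun coremlcompiler`.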
    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
        #[cfg(target_os = "macos")]
        {
            let path = path.as_ref();
            if !path.exists() {
                return Err(CandleError::Msg(format!(
                    "Model file not found: {}",
                    path.display()
                )));
            }

            autoreleasepool(|_| {
                let url = unsafe {
                    NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy()))
                };

                match unsafe { MLModel::modelWithContentsOfURL_error(&url) } {
                    Ok(model) => Ok(CoreMLModel {
                        inner: model,
                        config: config.clone(),
                    }),
                    Err(err) => Err(CandleError::Msg(format!(
                        "Failed to load CoreML model: {:?}",
                        err
                    ))),
                }
            })
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = (path, config);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }

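    /// Run inference on a single input tensor; shorthand for
    /// [`forward`](Self::forward) with a one-element slice.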
    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
        self.forward(&[input])
    }

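    /// Run inference on multiple input tensors, ordered to match
    /// `config.input_names`. Inputs must live on a CPU or Metal device.
    ///
    /// A sketch for a BERT-style model; the tensor bindings are assumed to
    /// exist in the caller's scope:
    ///
    /// ```ignore
    /// let config = Config::bert_config("logits", 128, 30522);
    /// let model = CoreMLModel::load_from_file("bert.mlmodelc", &config)?;
    /// let logits = model.forward(&[&input_ids, &token_type_ids, &attention_mask])?;
    /// ```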
    pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
        if inputs.len() != self.config.input_names.len() {
            return Err(CandleError::Msg(format!(
                "Expected {} inputs, got {}. Input names: {:?}",
                self.config.input_names.len(),
                inputs.len(),
                self.config.input_names
            )));
        }

        for (i, input) in inputs.iter().enumerate() {
            match input.device() {
                // CPU and Metal tensors are accepted.
                Device::Cpu | Device::Metal(_) => {}
                Device::Cuda(_) => {
                    return Err(CandleError::Msg(format!(
                        "CoreML models do not support CUDA tensors. Input {} '{}' is on a CUDA device; move it to a CPU or Metal device first.",
                        i, self.config.input_names[i]
                    )));
                }
            }
        }

        #[cfg(target_os = "macos")]
        {
            self.forward_impl(inputs)
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = inputs;
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }


    /// Returns the [`Config`] this model was loaded with.
    pub fn config(&self) -> &Config {
        &self.config
    }

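    /// macOS implementation: convert inputs to `MLMultiArray`s, wrap them in
    /// a feature provider, run prediction, and extract the configured output.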
    #[cfg(target_os = "macos")]
    fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
        autoreleasepool(|_| {
            let mut ml_arrays = Vec::with_capacity(inputs.len());
            for input in inputs {
                let ml_array = self.tensor_to_mlmultiarray(input)?;
                ml_arrays.push(ml_array);
            }

            let provider =
                self.create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;

            let prediction = self.run_prediction(&provider)?;

            let output_tensor =
                self.extract_output(&prediction, &self.config.output_name, inputs[0].device())?;

            Ok(output_tensor)
        })
    }

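    /// Convert a Candle tensor into an `MLMultiArray`.
    ///
    /// F32 tensors map to `Float32`. I64 tensors are narrowed to `Int32`,
    /// since `MLMultiArrayDataType` has no 64-bit integer variant; values
    /// outside the `i32` range will wrap.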
    #[cfg(target_os = "macos")]
    pub fn tensor_to_mlmultiarray(&self, tensor: &Tensor) -> Result<Retained<MLMultiArray>, CandleError> {
        use objc2_core_ml::MLMultiArrayDataType;
        use objc2_foundation::{NSArray, NSNumber};
        use candle_core::DType;

        let contiguous_tensor = if tensor.is_contiguous() {
            tensor.clone()
        } else {
            tensor.contiguous()?
        };

        let element_count = tensor.elem_count();
        let dims = tensor.dims();
        let mut shape = Vec::with_capacity(dims.len());
        for &dim in dims {
            shape.push(NSNumber::new_usize(dim));
        }
        let shape_nsarray = NSArray::from_retained_slice(&shape);

        let (ml_data_type, element_size) = match tensor.dtype() {
            DType::F32 => (MLMultiArrayDataType::Float32, std::mem::size_of::<f32>()),
            // CoreML has no 64-bit integer arrays, so I64 is narrowed to Int32.
            DType::I64 => (MLMultiArrayDataType::Int32, std::mem::size_of::<i32>()),
            _ => {
                return Err(CandleError::Msg(format!(
                    "Unsupported tensor dtype {:?} for CoreML conversion. Only F32 and I64 tensors are supported.",
                    tensor.dtype()
                )))
            }
        };

        let multi_array_result = unsafe {
            MLMultiArray::initWithShape_dataType_error(
                MLMultiArray::alloc(),
                &shape_nsarray,
                ml_data_type,
            )
        };

        match multi_array_result {
            Ok(ml_array) => {
                use std::sync::atomic::{AtomicBool, Ordering};
                let copied = AtomicBool::new(false);

                let flattened_tensor = contiguous_tensor.flatten_all()?;

                match tensor.dtype() {
                    DType::F32 => {
                        let data_vec = flattened_tensor.to_vec1::<f32>()?;
                        unsafe {
                            ml_array.getMutableBytesWithHandler(&StackBlock::new(
                                |ptr: std::ptr::NonNull<std::ffi::c_void>, len, _| {
                                    let dst = ptr.as_ptr() as *mut f32;
                                    let src = data_vec.as_ptr();
                                    let copy_elements = element_count.min(len as usize / element_size);

                                    if copy_elements > 0 && len as usize >= copy_elements * element_size {
                                        std::ptr::copy_nonoverlapping(src, dst, copy_elements);
                                        copied.store(true, Ordering::Relaxed);
                                    }
                                },
                            ));
                        }
                    }
                    DType::I64 => {
                        let data_vec = flattened_tensor.to_vec1::<i64>()?;
                        let i32_data: Vec<i32> = data_vec.into_iter()
                            .map(|x| x as i32)
                            .collect();

                        unsafe {
                            ml_array.getMutableBytesWithHandler(&StackBlock::new(
                                |ptr: std::ptr::NonNull<std::ffi::c_void>, len, _| {
                                    let dst = ptr.as_ptr() as *mut i32;
                                    let src = i32_data.as_ptr();
                                    let copy_elements = element_count.min(len as usize / element_size);

                                    if copy_elements > 0 && len as usize >= copy_elements * element_size {
                                        std::ptr::copy_nonoverlapping(src, dst, copy_elements);
                                        copied.store(true, Ordering::Relaxed);
                                    }
                                },
                            ));
                        }
                    }
                    // All other dtypes were rejected above.
                    _ => unreachable!(),
                }

                if copied.load(Ordering::Relaxed) {
                    Ok(ml_array)
                } else {
                    Err(CandleError::Msg("Failed to copy data to MLMultiArray".to_string()))
                }
            }
            Err(err) => Err(CandleError::Msg(format!(
                "Failed to create MLMultiArray: {:?}",
                err
            ))),
        }
    }

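    /// Build an `MLDictionaryFeatureProvider` mapping each input name to its
    /// corresponding `MLMultiArray`.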
    #[cfg(target_os = "macos")]
    fn create_multi_feature_provider(
        &self,
        input_names: &[String],
        input_arrays: &[Retained<MLMultiArray>],
    ) -> Result<Retained<MLDictionaryFeatureProvider>, CandleError> {
        use objc2_core_ml::MLFeatureValue;
        use objc2_foundation::{NSDictionary, NSString};
        use objc2::runtime::AnyObject;

        autoreleasepool(|_| {
            let mut keys = Vec::with_capacity(input_names.len());
            let mut values: Vec<Retained<MLFeatureValue>> = Vec::with_capacity(input_arrays.len());

            for (name, array) in input_names.iter().zip(input_arrays.iter()) {
                let key = NSString::from_str(name);
                let value = unsafe { MLFeatureValue::featureValueWithMultiArray(array) };
                keys.push(key);
                values.push(value);
            }

            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
            let value_refs: Vec<&AnyObject> = values.iter().map(|v| v.as_ref() as &AnyObject).collect();
            let dict: Retained<NSDictionary<NSString, AnyObject>> =
                NSDictionary::from_slices::<NSString>(&key_refs, &value_refs);

            unsafe {
                MLDictionaryFeatureProvider::initWithDictionary_error(
                    MLDictionaryFeatureProvider::alloc(),
                    dict.as_ref(),
                )
            }
            .map_err(|e| CandleError::Msg(format!("CoreML initWithDictionary_error: {:?}", e)))
        })
    }

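    /// Run a synchronous prediction against a prepared feature provider.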
    #[cfg(target_os = "macos")]
    fn run_prediction(
        &self,
        provider: &MLDictionaryFeatureProvider,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        autoreleasepool(|_| unsafe {
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_error(protocol_provider)
                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {:?}", e)))
        })
    }

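    /// Copy the named output feature out of a prediction into a Candle tensor
    /// on `input_device`, preserving the `MLMultiArray`'s shape.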
    #[cfg(target_os = "macos")]
    pub fn extract_output(
        &self,
        prediction: &ProtocolObject<dyn MLFeatureProvider>,
        output_name: &str,
        input_device: &Device,
    ) -> Result<Tensor, CandleError> {
        autoreleasepool(|_| unsafe {
            let name = NSString::from_str(output_name);
            let value = prediction
                .featureValueForName(&name)
                .ok_or_else(|| CandleError::Msg(format!("Output '{}' not found", output_name)))?;

            let marray = value.multiArrayValue().ok_or_else(|| {
                CandleError::Msg(format!("Output '{}' is not MLMultiArray", output_name))
            })?;

            // The output buffer is read as f32; other output dtypes are not handled here.
            let count = marray.count() as usize;
            let mut buf = vec![0.0f32; count];

            use std::cell::RefCell;
            let buf_cell = RefCell::new(&mut buf);

            marray.getBytesWithHandler(&StackBlock::new(
                |ptr: std::ptr::NonNull<std::ffi::c_void>, len: isize| {
                    let src = ptr.as_ptr() as *const f32;
                    let copy_elements = count.min(len as usize / std::mem::size_of::<f32>());
                    if copy_elements > 0 && len as usize >= copy_elements * std::mem::size_of::<f32>() {
                        if let Ok(mut buf_ref) = buf_cell.try_borrow_mut() {
                            std::ptr::copy_nonoverlapping(src, buf_ref.as_mut_ptr(), copy_elements);
                        }
                    }
                },
            ));

            let shape_nsarray = marray.shape();
            let shape_count = shape_nsarray.count();
            let mut shape = Vec::with_capacity(shape_count);

            for i in 0..shape_count {
                let dim_number = shape_nsarray.objectAtIndex(i);
                let dim_value = dim_number.integerValue() as usize;
                shape.push(dim_value);
            }

            Tensor::from_vec(buf, shape, input_device)
                .map_err(|e| CandleError::Msg(format!("Failed to create output tensor: {}", e)))
        })
    }
}

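/// Builder that pairs a CoreML model file with its [`Config`], optionally
/// resolving both from the Hugging Face Hub.
///
/// A sketch of hub-based loading; the repo id is a placeholder:
///
/// ```ignore
/// let builder = CoreMLModelBuilder::load_from_hub("someuser/some-coreml-model", None, None)?;
/// let model = builder.build_model()?;
/// ```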
pub struct CoreMLModelBuilder {
    config: Config,
    model_filename: PathBuf,
}

impl CoreMLModelBuilder {
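    /// Create a builder from a local model path and an explicit config.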
    pub fn new<P: AsRef<Path>>(model_path: P, config: Config) -> Self {
        Self {
            config,
            model_filename: model_path.as_ref().to_path_buf(),
        }
    }

438
439 pub fn load_from_hub(
441 model_id: &str,
442 model_filename: Option<&str>,
443 config_filename: Option<&str>,
444 ) -> Result<Self, CandleError> {
445 use crate::get_local_or_remote_file;
446 use hf_hub::{api::sync::Api, Repo, RepoType};
447
448 let api = Api::new().map_err(|e| CandleError::Msg(format!("Failed to create HF API: {}", e)))?;
449 let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, "main".to_string()));
450
451 let config_path = match config_filename {
453 Some(filename) => get_local_or_remote_file(filename, &repo)
454 .map_err(|e| CandleError::Msg(format!("Failed to get config file: {}", e)))?,
455 None => get_local_or_remote_file("config.json", &repo)
456 .map_err(|e| CandleError::Msg(format!("Failed to get config.json: {}", e)))?,
457 };
458
459 let config_str = std::fs::read_to_string(config_path)
460 .map_err(|e| CandleError::Msg(format!("Failed to read config file: {}", e)))?;
461 let config: Config = serde_json::from_str(&config_str)
462 .map_err(|e| CandleError::Msg(format!("Failed to parse config: {}", e)))?;
463
464 let model_path = match model_filename {
466 Some(filename) => get_local_or_remote_file(filename, &repo)
467 .map_err(|e| CandleError::Msg(format!("Failed to get model file: {}", e)))?,
468 None => {
469 for filename in &["model.mlmodelc", "model.mlpackage"] {
471 if let Ok(path) = get_local_or_remote_file(filename, &repo) {
472 return Ok(Self::new(path, config));
473 }
474 }
475 return Err(CandleError::Msg("No CoreML model file found".to_string()));
476 }
477 };
478
479 Ok(Self::new(model_path, config))
480 }

    /// Load the model from `model_filename` using the stored config.
    pub fn build_model(&self) -> Result<CoreMLModel, CandleError> {
        CoreMLModel::load_from_file(&self.model_filename, &self.config)
    }

    /// Returns the [`Config`] this builder will attach to the model.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        let model_path = "models/test.mlmodelc";
        // Skip the test when the fixture model is not present locally.
        if !std::path::Path::new(model_path).exists() {
            return;
        }

        let config = Config::default();
        let device = Device::Cpu;

        let model = CoreMLModel::load_from_file(model_path, &config)
            .expect("Failed to load model");

        assert_eq!(model.config().input_names[0], "input_ids");

        let input = Tensor::ones((1, 10), candle_core::DType::F32, &device)
            .expect("Failed to create input tensor");

        // The fixture model's I/O may not match the default config, so only
        // exercise the call path without asserting on the result.
        let _result = model.forward_single(&input);
    }
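
    // A platform-independent sanity check of the BERT convenience
    // constructor; the output name and sizes are arbitrary example values.
    #[test]
    fn test_bert_config() {
        let config = Config::bert_config("logits", 128, 30522);
        assert_eq!(
            config.input_names,
            vec!["input_ids", "token_type_ids", "attention_mask"]
        );
        assert_eq!(config.output_name, "logits");
        assert_eq!(config.max_sequence_length, 128);
        assert_eq!(config.vocab_size, 30522);
        assert_eq!(config.model_type, "bert");
    }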
}