candle_coreml/
model.rs

1//! CoreML model wrapper for Candle
2//!
3//! This module provides a high-level interface for CoreML models that integrates
4//! with Candle's tensor system.
5
6use candle_core::{Device, Tensor, Error as CandleError};
7use std::path::{Path, PathBuf};
8use serde::{Deserialize, Serialize};
9
/// Configuration for CoreML models
///
/// Serializable so it can be loaded from a repo-side `config.json`
/// (see `CoreMLModelBuilder::load_from_hub`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Input tensor names in order (e.g., ["input_ids", "token_type_ids", "attention_mask"]);
    /// `CoreMLModel::forward` requires exactly one tensor per name, in this order.
    pub input_names: Vec<String>,
    /// Output tensor name (e.g., "logits")
    pub output_name: String,
    /// Maximum sequence length
    pub max_sequence_length: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Model architecture name (e.g., "bert", "coreml")
    pub model_type: String,
}
24
25impl Default for Config {
26    fn default() -> Self {
27        Self {
28            input_names: vec!["input_ids".to_string()],
29            output_name: "logits".to_string(), 
30            max_sequence_length: 128,
31            vocab_size: 32000,
32            model_type: "coreml".to_string(),
33        }
34    }
35}
36
37impl Config {
38    /// Create BERT-style config with input_ids, token_type_ids, and attention_mask
39    pub fn bert_config(output_name: &str, max_seq_len: usize, vocab_size: usize) -> Self {
40        Self {
41            input_names: vec![
42                "input_ids".to_string(),
43                "token_type_ids".to_string(), 
44                "attention_mask".to_string(),
45            ],
46            output_name: output_name.to_string(),
47            max_sequence_length: max_seq_len,
48            vocab_size,
49            model_type: "bert".to_string(),
50        }
51    }
52}
53
54#[cfg(target_os = "macos")]
55use objc2::rc::{autoreleasepool, Retained};
56#[cfg(target_os = "macos")]
57use objc2_core_ml::{MLModel, MLMultiArray, MLDictionaryFeatureProvider, MLFeatureProvider};
58#[cfg(target_os = "macos")]
59use objc2_foundation::{NSString, NSURL};
60#[cfg(target_os = "macos")]
61use objc2::runtime::ProtocolObject;
62#[cfg(target_os = "macos")]
63use objc2::AnyThread;
64#[cfg(target_os = "macos")]
65use block2::StackBlock;
66
/// CoreML model wrapper that provides Candle tensor integration
///
/// On macOS this owns a retained `MLModel` handle; on other platforms the
/// type still compiles (so downstream code type-checks) but every operation
/// returns an error.
pub struct CoreMLModel {
    #[cfg(target_os = "macos")]
    inner: Retained<MLModel>,
    // Zero-sized stand-in so the struct is well-formed on non-macOS builds.
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    // Input/output names and metadata used to drive predictions.
    config: Config,
}
75
76impl std::fmt::Debug for CoreMLModel {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("CoreMLModel")
79            .field("config", &self.config)
80            .finish_non_exhaustive()
81    }
82}
83
84impl CoreMLModel {
85    /// Load a CoreML model from a .mlmodelc directory with default configuration
86    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
87        let config = Config::default();
88        Self::load_from_file(path, &config)
89    }
90
    /// Load a CoreML model from a .mlmodelc directory following standard Candle patterns
    ///
    /// Note: Unlike other Candle models, CoreML models are pre-compiled and don't use VarBuilder.
    /// This method provides a Candle-compatible interface while loading from CoreML files.
    ///
    /// # Errors
    /// Returns `CandleError::Msg` when the path does not exist, when CoreML
    /// fails to load the model, or unconditionally on non-macOS platforms.
    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
        #[cfg(target_os = "macos")]
        {
            let path = path.as_ref();
            if !path.exists() {
                return Err(CandleError::Msg(format!(
                    "Model file not found: {}",
                    path.display()
                )));
            }

            // Objective-C objects created during loading are released when the pool drains.
            autoreleasepool(|_| {
                let url = unsafe {
                    NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy()))
                };

                match unsafe { MLModel::modelWithContentsOfURL_error(&url) } {
                    Ok(model) => Ok(CoreMLModel {
                        inner: model,
                        config: config.clone(),
                    }),
                    Err(err) => Err(CandleError::Msg(format!(
                        "Failed to load CoreML model: {:?}",
                        err
                    ))),
                }
            })
        }

        #[cfg(not(target_os = "macos"))]
        {
            // Consume the arguments so non-macOS builds compile without warnings.
            let _ = (path, config);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }
132
    /// Convenience method for single-input models (backward compatibility)
    ///
    /// Equivalent to calling `forward` with a one-element slice; the config
    /// must therefore declare exactly one input name.
    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
        self.forward(&[input])
    }
145
146    pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
147        // Validate we have the expected number of inputs
148        if inputs.len() != self.config.input_names.len() {
149            return Err(CandleError::Msg(format!(
150                "Expected {} inputs, got {}. Input names: {:?}",
151                self.config.input_names.len(),
152                inputs.len(),
153                self.config.input_names
154            )));
155        }
156
157        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
158        for (i, input) in inputs.iter().enumerate() {
159            match input.device() {
160                Device::Cpu | Device::Metal(_) => {
161                    // Valid devices for CoreML
162                }
163                Device::Cuda(_) => {
164                    return Err(CandleError::Msg(format!(
165                        "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
166                        i, self.config.input_names[i]
167                    )));
168                }
169            }
170        }
171
172        #[cfg(target_os = "macos")]
173        {
174            self.forward_impl(inputs)
175        }
176        
177        #[cfg(not(target_os = "macos"))]
178        {
179            let _ = inputs;
180            Err(CandleError::Msg(
181                "CoreML is only available on macOS".to_string(),
182            ))
183        }
184    }
185
    /// Get the model configuration (input/output names, sequence length, etc.)
    pub fn config(&self) -> &Config {
        &self.config
    }
190
191    #[cfg(target_os = "macos")]
192    fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
193        autoreleasepool(|_| {
194            // Convert all Candle tensors to MLMultiArrays
195            let mut ml_arrays = Vec::with_capacity(inputs.len());
196            for input in inputs {
197                let ml_array = self.tensor_to_mlmultiarray(input)?;
198                ml_arrays.push(ml_array);
199            }
200            
201            // Create feature provider with all named inputs
202            let provider = self.create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
203            
204            // Run prediction
205            let prediction = self.run_prediction(&provider)?;
206            
207            // Extract output with configured output name (use first input device for output)
208            let output_tensor = self.extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
209            
210            Ok(output_tensor)
211        })
212    }
213
    /// Convert a Candle tensor into a freshly-allocated CoreML `MLMultiArray`.
    ///
    /// Supports F32 (copied as Float32) and I64 tensors (narrowed to Int32,
    /// since the code maps I64 to `MLMultiArrayDataType::Int32`). The `as i32`
    /// cast silently truncates values outside the i32 range —
    /// NOTE(review): confirm callers never pass ids beyond i32::MAX.
    ///
    /// # Errors
    /// Fails on unsupported dtypes, on MLMultiArray allocation failure, or if
    /// no bytes were copied into the array.
    #[cfg(target_os = "macos")]
    pub fn tensor_to_mlmultiarray(&self, tensor: &Tensor) -> Result<Retained<MLMultiArray>, CandleError> {
        use objc2_core_ml::MLMultiArrayDataType;
        use objc2_foundation::{NSArray, NSNumber};
        use candle_core::DType;

        // CoreML expects densely-packed data, so normalize the layout first.
        let contiguous_tensor = if tensor.is_contiguous() {
            tensor.clone()
        } else {
            tensor.contiguous()?
        };

        let element_count = tensor.elem_count();
        let dims = tensor.dims();
        // Shape is passed to CoreML as an NSArray of NSNumbers.
        let mut shape = Vec::with_capacity(dims.len());
        for &dim in dims {
            shape.push(NSNumber::new_usize(dim));
        }
        let shape_nsarray = NSArray::from_retained_slice(&shape);

        // Choose MLMultiArrayDataType based on tensor dtype
        let (ml_data_type, element_size) = match tensor.dtype() {
            DType::F32 => (MLMultiArrayDataType::Float32, std::mem::size_of::<f32>()),
            DType::I64 => (MLMultiArrayDataType::Int32, std::mem::size_of::<i32>()), // Convert I64 to Int32
            _ => return Err(CandleError::Msg(format!(
                "Unsupported tensor dtype {:?} for CoreML conversion. Only F32 and I64 tensors are supported.",
                tensor.dtype()
            ))),
        };

        // SAFETY: alloc + init pattern required by the objc2 bindings; the
        // shape array outlives the call.
        let multi_array_result = unsafe {
            MLMultiArray::initWithShape_dataType_error(
                MLMultiArray::alloc(),
                &shape_nsarray,
                ml_data_type,
            )
        };

        match multi_array_result {
            Ok(ml_array) => {
                use std::sync::atomic::{AtomicBool, Ordering};
                // Records whether the handler closure actually performed the copy;
                // atomic because the handler's execution context is CoreML's —
                // presumably synchronous, but not guaranteed here. TODO confirm.
                let copied = AtomicBool::new(false);

                let flattened_tensor = contiguous_tensor.flatten_all()?;

                // Handle different data types
                match tensor.dtype() {
                    DType::F32 => {
                        let data_vec = flattened_tensor.to_vec1::<f32>()?;
                        // SAFETY: dst points at the array's mutable storage of `len`
                        // bytes; copy_elements is clamped so the write stays in bounds.
                        unsafe {
                            ml_array.getMutableBytesWithHandler(&StackBlock::new(
                                |ptr: std::ptr::NonNull<std::ffi::c_void>, len, _| {
                                    let dst = ptr.as_ptr() as *mut f32;
                                    let src = data_vec.as_ptr();
                                    let copy_elements = element_count.min(len as usize / element_size);

                                    if copy_elements > 0 && len as usize >= copy_elements * element_size {
                                        std::ptr::copy_nonoverlapping(src, dst, copy_elements);
                                        copied.store(true, Ordering::Relaxed);
                                    }
                                },
                            ));
                        }
                    }
                    DType::I64 => {
                        // Convert I64 to I32 for CoreML (lossy for out-of-range values).
                        let data_vec = flattened_tensor.to_vec1::<i64>()?;
                        let i32_data: Vec<i32> = data_vec.into_iter()
                            .map(|x| x as i32)
                            .collect();

                        // SAFETY: same bounds-clamped copy as the F32 branch, with
                        // i32 elements.
                        unsafe {
                            ml_array.getMutableBytesWithHandler(&StackBlock::new(
                                |ptr: std::ptr::NonNull<std::ffi::c_void>, len, _| {
                                    let dst = ptr.as_ptr() as *mut i32;
                                    let src = i32_data.as_ptr();
                                    let copy_elements = element_count.min(len as usize / element_size);

                                    if copy_elements > 0 && len as usize >= copy_elements * element_size {
                                        std::ptr::copy_nonoverlapping(src, dst, copy_elements);
                                        copied.store(true, Ordering::Relaxed);
                                    }
                                },
                            ));
                        }
                    }
                    _ => unreachable!(), // Already handled above
                }

                if copied.load(Ordering::Relaxed) {
                    Ok(ml_array)
                } else {
                    Err(CandleError::Msg("Failed to copy data to MLMultiArray".to_string()))
                }
            }
            Err(err) => Err(CandleError::Msg(format!(
                "Failed to create MLMultiArray: {:?}",
                err
            ))),
        }
    }
315
316
    /// Build an `MLDictionaryFeatureProvider` mapping each input name to its
    /// corresponding `MLMultiArray`.
    ///
    /// Names and arrays are paired positionally via `zip`; if the slices have
    /// different lengths the extra entries are silently dropped (callers pass
    /// slices validated to match the config).
    #[cfg(target_os = "macos")]
    fn create_multi_feature_provider(
        &self,
        input_names: &[String],
        input_arrays: &[Retained<MLMultiArray>],
    ) -> Result<Retained<MLDictionaryFeatureProvider>, CandleError> {
        use objc2_core_ml::MLFeatureValue;
        use objc2_foundation::{NSDictionary, NSString};
        use objc2::runtime::AnyObject;

        autoreleasepool(|_| {
            // Keep the retained keys/values alive while building the dictionary.
            let mut keys = Vec::with_capacity(input_names.len());
            let mut values: Vec<Retained<MLFeatureValue>> = Vec::with_capacity(input_arrays.len());

            for (name, array) in input_names.iter().zip(input_arrays.iter()) {
                let key = NSString::from_str(name);
                // SAFETY: wraps the multi-array in an MLFeatureValue; the array
                // is retained for the duration of this scope.
                let value = unsafe { MLFeatureValue::featureValueWithMultiArray(array) };
                keys.push(key);
                values.push(value);
            }

            // NSDictionary construction wants borrowed key/value slices.
            let key_refs: Vec<&NSString> = keys.iter().map(|k| &**k).collect();
            let value_refs: Vec<&AnyObject> = values.iter().map(|v| v.as_ref() as &AnyObject).collect();
            let dict: Retained<NSDictionary<NSString, AnyObject>> =
                NSDictionary::from_slices::<NSString>(&key_refs, &value_refs);

            // SAFETY: alloc + init pattern required by the objc2 bindings.
            unsafe {
                MLDictionaryFeatureProvider::initWithDictionary_error(
                    MLDictionaryFeatureProvider::alloc(),
                    dict.as_ref(),
                )
            }
            .map_err(|e| CandleError::Msg(format!("CoreML initWithDictionary_error: {:?}", e)))
        })
    }
352
    /// Execute the loaded model against a prepared feature provider and return
    /// CoreML's output feature provider.
    #[cfg(target_os = "macos")]
    fn run_prediction(
        &self,
        provider: &MLDictionaryFeatureProvider,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        // SAFETY: the provider reference is valid for the duration of the call;
        // predictionFromFeatures_error is the bindings' fallible prediction API.
        autoreleasepool(|_| unsafe {
            // The model API takes the provider as an MLFeatureProvider protocol object.
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_error(protocol_provider)
                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {:?}", e)))
        })
    }
366
    /// Extract the named output feature from a prediction and convert it into
    /// a Candle tensor on `input_device`, preserving the MLMultiArray's shape.
    ///
    /// NOTE(review): the output bytes are reinterpreted as `f32`; if the model
    /// emits Double/Float16 multi-arrays this will misread the data — confirm
    /// the models used always produce Float32 outputs.
    #[cfg(target_os = "macos")]
    pub fn extract_output(
        &self,
        prediction: &ProtocolObject<dyn MLFeatureProvider>,
        output_name: &str,
        input_device: &Device,
    ) -> Result<Tensor, CandleError> {
        autoreleasepool(|_| unsafe {
            let name = NSString::from_str(output_name);
            let value = prediction
                .featureValueForName(&name)
                .ok_or_else(|| CandleError::Msg(format!("Output '{}' not found", output_name)))?;

            let marray = value.multiArrayValue().ok_or_else(|| {
                CandleError::Msg(format!("Output '{}' is not MLMultiArray", output_name))
            })?;

            // Pre-size the destination buffer from the array's element count.
            let count = marray.count() as usize;
            let mut buf = vec![0.0f32; count];

            // RefCell lets the Fn closure below mutate the buffer without
            // requiring FnMut, which the StackBlock signature does not allow.
            use std::cell::RefCell;
            let buf_cell = RefCell::new(&mut buf);

            // SAFETY: src points at `len` readable bytes; copy_elements is
            // clamped so the read and the write into buf stay in bounds.
            marray.getBytesWithHandler(&StackBlock::new(
                |ptr: std::ptr::NonNull<std::ffi::c_void>, len: isize| {
                    let src = ptr.as_ptr() as *const f32;
                    let copy_elements = count.min(len as usize / std::mem::size_of::<f32>());
                    if copy_elements > 0 && len as usize >= copy_elements * std::mem::size_of::<f32>() {
                        if let Ok(mut buf_ref) = buf_cell.try_borrow_mut() {
                            std::ptr::copy_nonoverlapping(src, buf_ref.as_mut_ptr(), copy_elements);
                        }
                    }
                },
            ));

            // Get shape from MLMultiArray
            let shape_nsarray = marray.shape();
            let shape_count = shape_nsarray.count();
            let mut shape = Vec::with_capacity(shape_count);

            for i in 0..shape_count {
                let dim_number = shape_nsarray.objectAtIndex(i);
                let dim_value = dim_number.integerValue() as usize;
                shape.push(dim_value);
            }

            // Create tensor with the same device as input
            Tensor::from_vec(buf, shape, input_device)
                .map_err(|e| CandleError::Msg(format!("Failed to create output tensor: {}", e)))
        })
    }
418}
419
420
/// Builder for CoreML models
///
/// This provides an interface for loading CoreML models with configuration
/// management and device selection.
pub struct CoreMLModelBuilder {
    // Configuration parsed from config.json or supplied by the caller.
    config: Config,
    // Path to the .mlmodelc / .mlpackage on disk.
    model_filename: PathBuf,
}
429
430impl CoreMLModelBuilder {
431    /// Create a new builder with the specified model path and config
432    pub fn new<P: AsRef<Path>>(model_path: P, config: Config) -> Self {
433        Self {
434            config,
435            model_filename: model_path.as_ref().to_path_buf(),
436        }
437    }
438    
    /// Load a CoreML model from HuggingFace or local files
    ///
    /// Resolves `config_filename` (default: "config.json") and
    /// `model_filename` (default: the first of "model.mlmodelc" /
    /// "model.mlpackage" that resolves) against the repo, then returns a
    /// builder; the model itself is not loaded until `build_model`.
    ///
    /// # Errors
    /// Returns `CandleError::Msg` on HF API failures, missing/unparseable
    /// config, or when no model file can be found.
    pub fn load_from_hub(
        model_id: &str,
        model_filename: Option<&str>,
        config_filename: Option<&str>,
    ) -> Result<Self, CandleError> {
        use crate::get_local_or_remote_file;
        use hf_hub::{api::sync::Api, Repo, RepoType};

        let api = Api::new().map_err(|e| CandleError::Msg(format!("Failed to create HF API: {}", e)))?;
        // Always resolves against the "main" revision.
        let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, "main".to_string()));

        // Load config
        let config_path = match config_filename {
            Some(filename) => get_local_or_remote_file(filename, &repo)
                .map_err(|e| CandleError::Msg(format!("Failed to get config file: {}", e)))?,
            None => get_local_or_remote_file("config.json", &repo)
                .map_err(|e| CandleError::Msg(format!("Failed to get config.json: {}", e)))?,
        };

        let config_str = std::fs::read_to_string(config_path)
            .map_err(|e| CandleError::Msg(format!("Failed to read config file: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| CandleError::Msg(format!("Failed to parse config: {}", e)))?;

        // Get model file
        let model_path = match model_filename {
            Some(filename) => get_local_or_remote_file(filename, &repo)
                .map_err(|e| CandleError::Msg(format!("Failed to get model file: {}", e)))?,
            None => {
                // Try common CoreML model filenames; return on first hit.
                for filename in &["model.mlmodelc", "model.mlpackage"] {
                    if let Ok(path) = get_local_or_remote_file(filename, &repo) {
                        return Ok(Self::new(path, config));
                    }
                }
                return Err(CandleError::Msg("No CoreML model file found".to_string()));
            }
        };

        Ok(Self::new(model_path, config))
    }
481
    /// Build the CoreML model by loading it from the stored path with the
    /// stored configuration.
    pub fn build_model(&self) -> Result<CoreMLModel, CandleError> {
        CoreMLModel::load_from_file(&self.model_filename, &self.config)
    }
486
    /// Get the config this builder will use when building the model
    pub fn config(&self) -> &Config {
        &self.config
    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    /// Smoke test for the load + forward interface.
    ///
    /// Requires a compiled model at models/test.mlmodelc; silently skipped
    /// when the fixture is absent so the suite passes without it.
    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        let model_path = "models/test.mlmodelc";
        if !std::path::Path::new(model_path).exists() {
            return;
        }

        let model = CoreMLModel::load_from_file(model_path, &Config::default())
            .expect("Failed to load model");

        // Default config exposes a single "input_ids" input.
        assert_eq!(model.config().input_names[0], "input_ids");

        let input = Tensor::ones((1, 10), candle_core::DType::F32, &Device::Cpu)
            .expect("Failed to create input tensor");

        // Interface check only — prediction may fail without a matching model.
        let _result = model.forward_single(&input);
    }
}
525