candle_coreml/
model.rs

1//! Core CoreML model implementation
2
3use crate::config::Config;
4use crate::state::CoreMLState;
5
6#[cfg(target_os = "macos")]
7use crate::conversion::{
8    create_multi_feature_provider, extract_all_outputs, extract_output, tensor_to_mlmultiarray,
9};
10use candle_core::{Device, Error as CandleError, Tensor};
11use std::path::Path;
12
13#[cfg(target_os = "macos")]
14use tracing::{debug, info};
15
16#[cfg(target_os = "macos")]
17use objc2::rc::{autoreleasepool, Retained};
18#[cfg(target_os = "macos")]
19use objc2::runtime::ProtocolObject;
20#[cfg(target_os = "macos")]
21use objc2_core_ml::{
22    MLDictionaryFeatureProvider, MLFeatureProvider, MLModel, MLModelConfiguration,
23};
24#[cfg(target_os = "macos")]
25use objc2_foundation::{NSString, NSURL};
26
/// CoreML model wrapper that provides Candle tensor integration
pub struct CoreMLModel {
    // Handle to the loaded Objective-C MLModel (macOS builds only).
    #[cfg(target_os = "macos")]
    pub(crate) inner: Retained<MLModel>,
    // Keeps the struct constructible on non-macOS targets where no
    // MLModel handle exists; carries no data.
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    // Input/output names and related metadata used to validate calls.
    pub(crate) config: Config,
    // Function name selected at load time for multi-function models;
    // `None` means the model's default entry point is used.
    pub(crate) function_name: Option<String>,
}
36
37impl std::fmt::Debug for CoreMLModel {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("CoreMLModel")
40            .field("config", &self.config)
41            .field("function_name", &self.function_name)
42            .finish_non_exhaustive()
43    }
44}
45
46impl CoreMLModel {
47    /// Load a CoreML model from a .mlmodelc directory with default configuration
48    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
49        let config = Config::default();
50        Self::load_from_file(path, &config)
51    }
52
53    /// Load a CoreML model with a specific function name
54    pub fn load_with_function<P: AsRef<Path>>(
55        path: P,
56        config: &Config,
57        function_name: &str,
58    ) -> Result<Self, CandleError> {
59        Self::load_from_file_with_function(path, config, Some(function_name))
60    }
61
62    /// Load a CoreML model from a .mlmodelc directory following standard Candle patterns
63    ///
64    /// Note: Unlike other Candle models, CoreML models are pre-compiled and don't use VarBuilder.
65    /// This method provides a Candle-compatible interface while loading from CoreML files.
66    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
67        Self::load_from_file_with_function(path, config, None)
68    }
69
70    /// Load a CoreML model with optional function name specification
71    pub fn load_from_file_with_function<P: AsRef<Path>>(
72        path: P,
73        config: &Config,
74        function_name: Option<&str>,
75    ) -> Result<Self, CandleError> {
76        #[cfg(target_os = "macos")]
77        {
78            let path = path.as_ref();
79            if !path.exists() {
80                return Err(CandleError::Msg(format!(
81                    "Model file not found: {}",
82                    path.display()
83                )));
84            }
85
86            autoreleasepool(|_| {
87                let url =
88                    unsafe { NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy())) };
89
90                // Show loading progress for large models
91                info!("Loading and compiling CoreML model at {}", path.display());
92                let load_start = std::time::Instant::now();
93
94                // Create configuration with function name if provided
95                let model_result = if let Some(func_name) = function_name {
96                    let config = unsafe { MLModelConfiguration::new() };
97                    let ns_func_name = NSString::from_str(func_name);
98                    unsafe { config.setFunctionName(Some(&ns_func_name)) };
99                    unsafe { MLModel::modelWithContentsOfURL_configuration_error(&url, &config) }
100                } else {
101                    unsafe { MLModel::modelWithContentsOfURL_error(&url) }
102                };
103
104                let load_time = load_start.elapsed();
105
106                // Try to load the model with function name support
107                match model_result {
108                    Ok(model) => {
109                        info!(
110                            "Model loaded and compiled in {:.1}s",
111                            load_time.as_secs_f32()
112                        );
113                        Ok(CoreMLModel {
114                            inner: model,
115                            config: config.clone(),
116                            function_name: function_name.map(|s| s.to_string()),
117                        })
118                    }
119                    Err(err) => {
120                        // If direct loading fails, try compiling first
121                        let err_msg = format!("{err:?}");
122                        if err_msg.contains("Compile the model") {
123                            debug!("Model requires compilation, compiling now");
124                            #[allow(deprecated)]
125                            match unsafe { MLModel::compileModelAtURL_error(&url) } {
126                                Ok(compiled_url) => {
127                                    debug!("Compilation completed, loading compiled model");
128                                    // Try loading the compiled model
129                                    match unsafe {
130                                        MLModel::modelWithContentsOfURL_error(&compiled_url)
131                                    } {
132                                        Ok(model) => {
133                                            info!(
134                                                "Compiled model loaded in {:.1}s total",
135                                                load_time.as_secs_f32()
136                                            );
137                                            Ok(CoreMLModel {
138                                                inner: model,
139                                                config: config.clone(),
140                                                function_name: function_name.map(|s| s.to_string()),
141                                            })
142                                        }
143                                        Err(compile_err) => Err(CandleError::Msg(format!(
144                                            "Failed to load compiled CoreML model: {compile_err:?}"
145                                        ))),
146                                    }
147                                }
148                                Err(compile_err) => Err(CandleError::Msg(format!(
149                                    "Failed to compile CoreML model: {compile_err:?}. Original error: {err:?}"
150                                ))),
151                            }
152                        } else {
153                            // Check for common CoreML version compatibility issues
154                            let err_msg = format!("{err:?}");
155                            if err_msg.contains("compiler major version")
156                                && err_msg.contains("more recent than this framework")
157                            {
158                                Err(CandleError::Msg(format!(
159                                    "CoreML version compatibility issue: {err_msg}\n\
160                                    This model was compiled with a newer CoreML compiler than this system supports.\n\
161                                    Solutions:\n\
162                                    • Update to a newer macOS version\n\
163                                    • Use models compiled for your CoreML framework version\n\
164                                    • Set RUST_LOG=debug for more details"
165                                )))
166                            } else {
167                                Err(CandleError::Msg(format!(
168                                    "Failed to load CoreML model: {err:?}"
169                                )))
170                            }
171                        }
172                    }
173                }
174            })
175        }
176
177        #[cfg(not(target_os = "macos"))]
178        {
179            let _ = (path, config, function_name);
180            Err(CandleError::Msg(
181                "CoreML is only available on macOS".to_string(),
182            ))
183        }
184    }
185
186    /// Run forward pass through the model with multiple inputs
187    ///
188    /// Accepts tensors from CPU or Metal devices, rejects CUDA tensors.
189    /// Returns output tensor on the same device as the input tensors.
190    ///
191    /// # Arguments
192    /// * `inputs` - Slice of tensors corresponding to the input_names in config order
193    ///
194    /// Convenience method for single-input models (backward compatibility)
195    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
196        self.forward(&[input])
197    }
198
199    pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
200        // Validate we have the expected number of inputs
201        if inputs.len() != self.config.input_names.len() {
202            return Err(CandleError::Msg(format!(
203                "Expected {} inputs, got {}. Input names: {:?}",
204                self.config.input_names.len(),
205                inputs.len(),
206                self.config.input_names
207            )));
208        }
209
210        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
211        for (i, input) in inputs.iter().enumerate() {
212            match input.device() {
213                Device::Cpu | Device::Metal(_) => {
214                    // Valid devices for CoreML
215                }
216                Device::Cuda(_) => {
217                    return Err(CandleError::Msg(format!(
218                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
219                            i, self.config.input_names[i]
220                        )));
221                }
222            }
223        }
224
225        #[cfg(target_os = "macos")]
226        {
227            self.forward_impl(inputs)
228        }
229
230        #[cfg(not(target_os = "macos"))]
231        {
232            let _ = inputs;
233            Err(CandleError::Msg(
234                "CoreML is only available on macOS".to_string(),
235            ))
236        }
237    }
238
239    /// Forward pass returning all outputs as a HashMap
240    ///
241    /// This is useful for models that have multiple outputs, such as the Qwen LM head
242    /// which produces 16 different logits chunks that need to be concatenated.
243    pub fn forward_all(
244        &self,
245        inputs: &[&Tensor],
246    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
247        // Validate we have the expected number of inputs
248        if inputs.len() != self.config.input_names.len() {
249            return Err(CandleError::Msg(format!(
250                "Expected {} inputs, got {}. Input names: {:?}",
251                self.config.input_names.len(),
252                inputs.len(),
253                self.config.input_names
254            )));
255        }
256
257        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
258        for (i, input) in inputs.iter().enumerate() {
259            match input.device() {
260                Device::Cpu | Device::Metal(_) => {
261                    // Valid devices for CoreML
262                }
263                Device::Cuda(_) => {
264                    return Err(CandleError::Msg(format!(
265                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
266                            i, self.config.input_names[i]
267                        )));
268                }
269            }
270        }
271
272        #[cfg(target_os = "macos")]
273        {
274            self.forward_all_impl(inputs)
275        }
276
277        #[cfg(not(target_os = "macos"))]
278        {
279            let _ = inputs;
280            Err(CandleError::Msg(
281                "CoreML is only available on macOS".to_string(),
282            ))
283        }
284    }
285
286    /// Get the model configuration
287    pub fn config(&self) -> &Config {
288        &self.config
289    }
290
291    /// Get access to the inner MLModel for advanced usage (testing only)
292    #[cfg(target_os = "macos")]
293    pub fn inner_model(&self) -> &Retained<MLModel> {
294        &self.inner
295    }
296
297    /// Create a CoreMLModel from an existing MLModel (for testing)
298    #[cfg(target_os = "macos")]
299    pub fn from_mlmodel(inner: Retained<MLModel>, config: Config) -> Self {
300        CoreMLModel {
301            inner,
302            config,
303            function_name: None,
304        }
305    }
306
307    /// Create a fresh state object for this model.
308    ///
309    /// This enables efficient autoregressive generation by maintaining
310    /// persistent KV-cache across multiple prediction calls.
311    ///
312    /// # Returns
313    ///
314    /// A new `CoreMLState` instance that can be used with `predict_with_state()`.
315    /// For stateless models, this returns an empty state object that can still
316    /// be used with stateful prediction methods (resulting in stateless behavior).
317    ///
318    /// # Example
319    ///
320    /// ```rust,no_run
321    /// use candle_core::{Device, Tensor};
322    /// use candle_coreml::{CoreMLModel, Config};
323    ///
324    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
325    /// let model = CoreMLModel::load("model.mlmodelc")?;
326    ///
327    /// // Create state for efficient token generation
328    /// let mut state = model.make_state()?;
329    ///
330    /// // Use state with predict_with_state() for streaming inference
331    /// # Ok(())
332    /// # }
333    /// ```
334    pub fn make_state(&self) -> Result<CoreMLState, CandleError> {
335        #[cfg(target_os = "macos")]
336        {
337            CoreMLState::new(&self.inner)
338        }
339
340        #[cfg(not(target_os = "macos"))]
341        {
342            CoreMLState::new(&())
343        }
344    }
345
346    /// Run forward pass through the model with persistent state.
347    ///
348    /// This method enables efficient autoregressive generation by maintaining
349    /// KV-cache state across multiple prediction calls. Unlike the stateless
350    /// `forward()` method, this preserves computation state between calls.
351    ///
352    /// # Arguments
353    ///
354    /// * `inputs` - Slice of tensors corresponding to input_names in config order
355    /// * `state` - Mutable reference to the model state (will be updated)
356    ///
357    /// # Returns
358    ///
359    /// Output tensor on the same device as the input tensors.
360    ///
361    /// # Device Compatibility
362    ///
363    /// Accepts tensors from CPU or Metal devices, rejects CUDA tensors.
364    ///
365    /// # Example
366    ///
367    /// ```rust,no_run
368    /// use candle_core::{Device, Tensor};
369    /// use candle_coreml::{CoreMLModel, Config};
370    ///
371    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
372    /// let model = CoreMLModel::load("model.mlmodelc")?;
373    /// let device = Device::Cpu;
374    ///
375    /// let mut state = model.make_state()?;
376    ///
377    /// // Generate tokens with persistent KV-cache
378    /// for i in 0..10 {
379    ///     let input = Tensor::ones((1, 1), candle_core::DType::I64, &device)?;
380    ///     let output = model.predict_with_state(&[&input], &mut state)?;
381    ///     println!("Token {}: {:?}", i, output);
382    /// }
383    /// # Ok(())
384    /// # }
385    /// ```
386    pub fn predict_with_state(
387        &self,
388        inputs: &[&Tensor],
389        state: &mut CoreMLState,
390    ) -> Result<Tensor, CandleError> {
391        // Validate we have the expected number of inputs
392        if inputs.len() != self.config.input_names.len() {
393            return Err(CandleError::Msg(format!(
394                "Expected {} inputs, got {}. Input names: {:?}",
395                self.config.input_names.len(),
396                inputs.len(),
397                self.config.input_names
398            )));
399        }
400
401        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
402        for (i, input) in inputs.iter().enumerate() {
403            match input.device() {
404                Device::Cpu | Device::Metal(_) => {
405                    // Valid devices for CoreML
406                }
407                Device::Cuda(_) => {
408                    return Err(CandleError::Msg(format!(
409                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
410                            i, self.config.input_names[i]
411                        )));
412                }
413            }
414        }
415
416        #[cfg(target_os = "macos")]
417        {
418            // Debug print input shapes and names
419            tracing::debug!("predict_with_state function={:?}", self.function_name);
420            for (i, t) in inputs.iter().enumerate() {
421                tracing::debug!(
422                    "predict_with_state input {} '{}' shape={:?}",
423                    i,
424                    self.config.input_names[i],
425                    t.dims()
426                );
427            }
428            self.predict_with_state_impl(inputs, state)
429        }
430
431        #[cfg(not(target_os = "macos"))]
432        {
433            let _ = (inputs, state);
434            Err(CandleError::Msg(
435                "CoreML is only available on macOS".to_string(),
436            ))
437        }
438    }
439
440    #[cfg(target_os = "macos")]
441    fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
442        autoreleasepool(|_| {
443            // Convert all Candle tensors to MLMultiArrays
444            let mut ml_arrays = Vec::with_capacity(inputs.len());
445            for input in inputs {
446                let ml_array = tensor_to_mlmultiarray(input)?;
447                ml_arrays.push(ml_array);
448            }
449
450            // Create feature provider with all named inputs
451            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
452
453            // Run prediction
454            let prediction = self.run_prediction(&provider)?;
455
456            // Extract output with configured output name (use first input device for output)
457            let output_tensor =
458                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
459
460            Ok(output_tensor)
461        })
462    }
463
464    #[cfg(target_os = "macos")]
465    fn forward_all_impl(
466        &self,
467        inputs: &[&Tensor],
468    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
469        autoreleasepool(|_| {
470            // Convert all Candle tensors to MLMultiArrays
471            let mut ml_arrays = Vec::with_capacity(inputs.len());
472            for input in inputs {
473                let ml_array = tensor_to_mlmultiarray(input)?;
474                ml_arrays.push(ml_array);
475            }
476
477            // Create feature provider with all named inputs
478            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
479
480            // Run prediction
481            let prediction = self.run_prediction(&provider)?;
482
483            // Extract all outputs
484            extract_all_outputs(&prediction, inputs[0].device())
485        })
486    }
487
488    #[cfg(target_os = "macos")]
489    fn run_prediction(
490        &self,
491        provider: &MLDictionaryFeatureProvider,
492    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
493        autoreleasepool(|_| unsafe {
494            let protocol_provider = ProtocolObject::from_ref(provider);
495
496            // Function name is now handled during model loading via MLModelConfiguration
497            self.inner
498                .predictionFromFeatures_error(protocol_provider)
499                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {e:?}")))
500        })
501    }
502
503    #[cfg(target_os = "macos")]
504    fn predict_with_state_impl(
505        &self,
506        inputs: &[&Tensor],
507        state: &mut CoreMLState,
508    ) -> Result<Tensor, CandleError> {
509        autoreleasepool(|_| {
510            // Convert all Candle tensors to MLMultiArrays (reuse existing logic)
511            let mut ml_arrays = Vec::with_capacity(inputs.len());
512            for input in inputs {
513                let ml_array = tensor_to_mlmultiarray(input)?;
514                ml_arrays.push(ml_array);
515            }
516
517            // Create feature provider with all named inputs (reuse existing logic)
518            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
519
520            // Run stateful prediction
521            let prediction = self.run_prediction_with_state(&provider, state)?;
522
523            // Extract output with configured output name (use first input device for output)
524            let output_tensor =
525                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
526
527            Ok(output_tensor)
528        })
529    }
530
531    #[cfg(target_os = "macos")]
532    fn run_prediction_with_state(
533        &self,
534        provider: &MLDictionaryFeatureProvider,
535        state: &mut CoreMLState,
536    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
537        autoreleasepool(|_| unsafe {
538            let protocol_provider = ProtocolObject::from_ref(provider);
539
540            self.inner
541                .predictionFromFeatures_usingState_error(protocol_provider, state.inner())
542                .map_err(|e| CandleError::Msg(format!("CoreML stateful prediction error: {e:?}")))
543        })
544    }
545}
546
#[cfg(test)]
mod tests {
    #[cfg(target_os = "macos")]
    use super::*;

    /// Smoke test for model loading and the forward interface.
    ///
    /// Silently skips when the fixture model is absent so environments
    /// without model assets still pass.
    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        let model_path = "models/test.mlmodelc";
        if !std::path::Path::new(model_path).exists() {
            // No fixture available — nothing to verify.
            return;
        }

        let config = Config::default();
        let model = CoreMLModel::load_from_file(model_path, &config).expect("Failed to load model");

        // Config should round-trip through the loaded model.
        assert_eq!(model.config().input_names[0], "input_ids");

        // Exercise the forward interface with a dummy CPU tensor; the call
        // itself may fail without a real model, which is acceptable here.
        let device = Device::Cpu;
        let input = Tensor::ones((1, 10), candle_core::DType::F32, &device)
            .expect("Failed to create input tensor");
        let _result = model.forward_single(&input);
    }
}
577}