candle_coreml/
model.rs

1//! Core CoreML model implementation
2
3use crate::config::basic::Config;
4use crate::state::CoreMLState;
5
6#[cfg(target_os = "macos")]
7use crate::conversion::{
8    create_multi_feature_provider, extract_all_outputs, extract_output, tensor_to_mlmultiarray,
9};
10use candle_core::{Device, Error as CandleError, Tensor};
11use std::path::Path;
12
13#[cfg(target_os = "macos")]
14use tracing::{debug, info};
15
16#[cfg(target_os = "macos")]
17use objc2::rc::{autoreleasepool, Retained};
18#[cfg(target_os = "macos")]
19use objc2::runtime::ProtocolObject;
20#[cfg(target_os = "macos")]
21use objc2_core_ml::{
22    MLDictionaryFeatureProvider, MLFeatureProvider, MLModel, MLModelConfiguration,
23};
24#[cfg(target_os = "macos")]
25use objc2_foundation::{NSString, NSURL};
26
/// CoreML model wrapper that provides Candle tensor integration
pub struct CoreMLModel {
    // Underlying CoreML model handle (only exists on macOS builds).
    #[cfg(target_os = "macos")]
    pub(crate) inner: Retained<MLModel>,
    // Zero-sized placeholder so the struct compiles on non-macOS targets.
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    // Input/output names and related settings used to drive predictions.
    pub(crate) config: Config,
    // Optional function name applied via MLModelConfiguration at load time.
    pub(crate) function_name: Option<String>,
}
36
37impl std::fmt::Debug for CoreMLModel {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("CoreMLModel")
40            .field("config", &self.config)
41            .field("function_name", &self.function_name)
42            .finish_non_exhaustive()
43    }
44}
45
46impl CoreMLModel {
47    /// Load a CoreML model from a .mlmodelc directory with default configuration
48    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
49        let config = Config::default();
50        Self::load_from_file(path, &config)
51    }
52
    /// Load a CoreML model with a specific function name
    ///
    /// The function name is applied through `MLModelConfiguration` during
    /// loading; see [`load_from_file_with_function`](Self::load_from_file_with_function).
    pub fn load_with_function<P: AsRef<Path>>(
        path: P,
        config: &Config,
        function_name: &str,
    ) -> Result<Self, CandleError> {
        Self::load_from_file_with_function(path, config, Some(function_name))
    }
61
    /// Load a CoreML model from a .mlmodelc directory following standard Candle patterns
    ///
    /// Note: Unlike other Candle models, CoreML models are pre-compiled and don't use VarBuilder.
    /// This method provides a Candle-compatible interface while loading from CoreML files.
    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
        // No function name: loads the model's default function.
        Self::load_from_file_with_function(path, config, None)
    }
69
    /// Load a CoreML model with optional function name specification
    ///
    /// Handles three artifact kinds: compiled `.mlmodelc` bundles (loaded
    /// directly, never compiled), source `.mlpackage`/`.mlmodel` artifacts
    /// (loaded directly if possible, otherwise compiled — with a cache of
    /// compiled results), and "typo-fixer style" packages that carry an inner
    /// `Data/com.apple.CoreML/model.mlmodel` without a `Manifest.json`.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to a `.mlmodelc`, `.mlpackage`, or `.mlmodel` artifact
    /// * `config` - Input/output configuration stored on the returned model
    /// * `function_name` - Optional function selected via `MLModelConfiguration`
    ///
    /// # Errors
    ///
    /// Fails when the path does not exist, when loading (and, for source
    /// artifacts, compilation) fails, or always on non-macOS targets.
    pub fn load_from_file_with_function<P: AsRef<Path>>(
        path: P,
        config: &Config,
        function_name: Option<&str>,
    ) -> Result<Self, CandleError> {
        #[cfg(target_os = "macos")]
        {
            let path = path.as_ref();
            if !path.exists() {
                return Err(CandleError::Msg(format!(
                    "Model file not found: {}",
                    path.display()
                )));
            }

            autoreleasepool(|_| {
                let url =
                    unsafe { NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy())) };

                // Helper: load from URL with or without configuration (preserve function_name)
                unsafe fn load_with_config(
                    url: &NSURL,
                    function_name: Option<&str>,
                ) -> Result<Retained<MLModel>, CandleError> {
                    if let Some(func) = function_name {
                        let ml_cfg = MLModelConfiguration::new();
                        let ns_name = NSString::from_str(func);
                        ml_cfg.setFunctionName(Some(&ns_name));
                        MLModel::modelWithContentsOfURL_configuration_error(url, &ml_cfg).map_err(
                            |e| {
                                CandleError::Msg(format!(
                                    "Failed to load CoreML model with configuration: {e:?}"
                                ))
                            },
                        )
                    } else {
                        MLModel::modelWithContentsOfURL_error(url).map_err(|e| {
                            CandleError::Msg(format!("Failed to load CoreML model: {e:?}"))
                        })
                    }
                }

                // Determine the artifact type to avoid compiling compiled bundles
                let is_dir = path.is_dir();
                let ext = path
                    .extension()
                    .and_then(|s| s.to_str())
                    .unwrap_or_default()
                    .to_ascii_lowercase();
                let looks_like_modelc =
                    ext == "mlmodelc" || (is_dir && path.to_string_lossy().ends_with(".mlmodelc"));
                let looks_like_package = ext == "mlpackage"
                    || (is_dir && path.to_string_lossy().ends_with(".mlpackage"));
                // Special-case: some packages are folders with Data/com.apple.CoreML/model.mlmodel
                // and no Manifest.json ("typo-fixer style"). For those, compile the inner .mlmodel.
                let manifest_json_exists = path.join("Manifest.json").exists();
                let inner_mlmodel_path = path.join("Data/com.apple.CoreML/model.mlmodel");
                let has_inner_mlmodel = inner_mlmodel_path.exists();

                // Show loading progress for large models
                info!("Loading CoreML model at {}", path.display());
                let load_start = std::time::Instant::now();

                // If it's a compiled .mlmodelc bundle, never attempt compilation. Just load.
                if looks_like_modelc {
                    match unsafe { load_with_config(&url, function_name) } {
                        Ok(model) => {
                            info!("Model loaded in {:.1}s", load_start.elapsed().as_secs_f32());
                            return Ok(CoreMLModel {
                                inner: model,
                                config: config.clone(),
                                function_name: function_name.map(|s| s.to_string()),
                            });
                        }
                        Err(err) => {
                            // Common CoreML version mismatch message handling
                            let msg = format!("{err}");
                            if msg.contains("compiler major version")
                                && msg.contains("more recent than this framework")
                            {
                                return Err(CandleError::Msg(format!(
                                    "CoreML version compatibility issue: {msg}\n\
                                     Update macOS or use a model compiled for this framework version."
                                )));
                            }
                            return Err(err);
                        }
                    }
                }

                // Otherwise, attempt direct load first. If it fails and the artifact is a
                // source (.mlmodel/.mlpackage), try compiling then load the compiled URL.
                match unsafe { load_with_config(&url, function_name) } {
                    Ok(model) => {
                        info!("Model loaded in {:.1}s", load_start.elapsed().as_secs_f32());
                        Ok(CoreMLModel {
                            inner: model,
                            config: config.clone(),
                            function_name: function_name.map(|s| s.to_string()),
                        })
                    }
                    Err(load_err) => {
                        // Only try to compile for non-compiled artifacts
                        if looks_like_package || ext == "mlmodel" || !is_dir {
                            debug!("Direct load failed, attempting compilation: {load_err}");

                            // Try to use cached compiled model first
                            if let Ok(cached_model) = Self::try_load_cached_compiled_model(
                                path,
                                &load_start,
                                config,
                                function_name,
                            ) {
                                return Ok(cached_model);
                            }

                            #[allow(deprecated)]
                            // Choose compile target: for 'typo-fixer style' packages, compile the inner model.mlmodel
                            let compile_result = unsafe {
                                if looks_like_package && !manifest_json_exists && has_inner_mlmodel
                                {
                                    let inner_url = NSURL::fileURLWithPath(&NSString::from_str(
                                        &inner_mlmodel_path.to_string_lossy(),
                                    ));
                                    MLModel::compileModelAtURL_error(&inner_url)
                                } else {
                                    MLModel::compileModelAtURL_error(&url)
                                }
                            };

                            match compile_result {
                                Ok(compiled_url) => {
                                    debug!("Compilation completed, caching and loading compiled model");

                                    // Cache the compiled model for future use
                                    // (best-effort: a caching failure only logs, load continues).
                                    if let Err(e) = Self::cache_compiled_model(path, &compiled_url) {
                                        debug!("Failed to cache compiled model: {e}");
                                    }
                                    match unsafe { load_with_config(&compiled_url, function_name) } {
                                        Ok(model) => {
                                            info!(
                                                "Compiled model loaded in {:.1}s total",
                                                load_start.elapsed().as_secs_f32()
                                            );
                                            Ok(CoreMLModel {
                                                inner: model,
                                                config: config.clone(),
                                                function_name: function_name.map(|s| s.to_string()),
                                            })
                                        }
                                        Err(err) => Err(CandleError::Msg(format!(
                                            "Failed to load compiled CoreML model: {err}"
                                        ))),
                                    }
                                }
                                Err(compile_err) => Err(CandleError::Msg(format!(
                                    "Failed to compile CoreML model: {compile_err}. Original load error: {load_err}"
                                ))),
                            }
                        } else {
                            // Not a compilable artifact and load failed
                            Err(load_err)
                        }
                    }
                }
            })
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = (path, config, function_name);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }
247
    /// Run forward pass through the model with a single input tensor.
    ///
    /// Convenience wrapper around [`forward`](Self::forward) for single-input
    /// models (kept for backward compatibility).
    ///
    /// Accepts tensors from CPU or Metal devices, rejects CUDA tensors.
    /// Returns the output tensor on the same device as the input tensor.
    ///
    /// # Arguments
    /// * `input` - The single tensor matching the model's one configured input name
    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
        self.forward(&[input])
    }
260
261    pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
262        // Validate we have the expected number of inputs
263        if inputs.len() != self.config.input_names.len() {
264            return Err(CandleError::Msg(format!(
265                "Expected {} inputs, got {}. Input names: {:?}",
266                self.config.input_names.len(),
267                inputs.len(),
268                self.config.input_names
269            )));
270        }
271
272        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
273        for (i, input) in inputs.iter().enumerate() {
274            match input.device() {
275                Device::Cpu | Device::Metal(_) => {
276                    // Valid devices for CoreML
277                }
278                Device::Cuda(_) => {
279                    return Err(CandleError::Msg(format!(
280                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
281                            i, self.config.input_names[i]
282                        )));
283                }
284            }
285        }
286
287        #[cfg(target_os = "macos")]
288        {
289            self.forward_impl(inputs)
290        }
291
292        #[cfg(not(target_os = "macos"))]
293        {
294            let _ = inputs;
295            Err(CandleError::Msg(
296                "CoreML is only available on macOS".to_string(),
297            ))
298        }
299    }
300
301    /// Forward pass returning all outputs as a HashMap
302    ///
303    /// This is useful for models that have multiple outputs, such as the Qwen LM head
304    /// which produces 16 different logits chunks that need to be concatenated.
305    pub fn forward_all(
306        &self,
307        inputs: &[&Tensor],
308    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
309        // Validate we have the expected number of inputs
310        if inputs.len() != self.config.input_names.len() {
311            return Err(CandleError::Msg(format!(
312                "Expected {} inputs, got {}. Input names: {:?}",
313                self.config.input_names.len(),
314                inputs.len(),
315                self.config.input_names
316            )));
317        }
318
319        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
320        for (i, input) in inputs.iter().enumerate() {
321            match input.device() {
322                Device::Cpu | Device::Metal(_) => {
323                    // Valid devices for CoreML
324                }
325                Device::Cuda(_) => {
326                    return Err(CandleError::Msg(format!(
327                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
328                            i, self.config.input_names[i]
329                        )));
330                }
331            }
332        }
333
334        #[cfg(target_os = "macos")]
335        {
336            self.forward_all_impl(inputs)
337        }
338
339        #[cfg(not(target_os = "macos"))]
340        {
341            let _ = inputs;
342            Err(CandleError::Msg(
343                "CoreML is only available on macOS".to_string(),
344            ))
345        }
346    }
347
    /// Get the model configuration (borrowed; clone if ownership is needed)
    pub fn config(&self) -> &Config {
        &self.config
    }
352
    /// Get access to the inner MLModel for advanced usage (testing only)
    ///
    /// Only available on macOS, where the wrapped `MLModel` handle exists.
    #[cfg(target_os = "macos")]
    pub fn inner_model(&self) -> &Retained<MLModel> {
        &self.inner
    }
358
    /// Create a CoreMLModel from an existing MLModel (for testing)
    ///
    /// The resulting wrapper has no function name selected.
    #[cfg(target_os = "macos")]
    pub fn from_mlmodel(inner: Retained<MLModel>, config: Config) -> Self {
        CoreMLModel {
            inner,
            config,
            function_name: None,
        }
    }
368
    /// Create a fresh state object for this model.
    ///
    /// This enables efficient autoregressive generation by maintaining
    /// persistent KV-cache across multiple prediction calls.
    ///
    /// # Returns
    ///
    /// A new `CoreMLState` instance that can be used with `predict_with_state()`.
    /// For stateless models, this returns an empty state object that can still
    /// be used with stateful prediction methods (resulting in stateless behavior).
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use candle_core::{Device, Tensor};
    /// use candle_coreml::{CoreMLModel, Config};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = CoreMLModel::load("model.mlmodelc")?;
    ///
    /// // Create state for efficient token generation
    /// let mut state = model.make_state()?;
    ///
    /// // Use state with predict_with_state() for streaming inference
    /// # Ok(())
    /// # }
    /// ```
    pub fn make_state(&self) -> Result<CoreMLState, CandleError> {
        #[cfg(target_os = "macos")]
        {
            // State is derived from the underlying MLModel handle.
            CoreMLState::new(&self.inner)
        }

        #[cfg(not(target_os = "macos"))]
        {
            // Off macOS there is no MLModel; the state module accepts a unit placeholder.
            CoreMLState::new(&())
        }
    }
407
    /// Run forward pass through the model with persistent state.
    ///
    /// This method enables efficient autoregressive generation by maintaining
    /// KV-cache state across multiple prediction calls. Unlike the stateless
    /// `forward()` method, this preserves computation state between calls.
    ///
    /// # Arguments
    ///
    /// * `inputs` - Slice of tensors corresponding to input_names in config order
    /// * `state` - Mutable reference to the model state (will be updated)
    ///
    /// # Returns
    ///
    /// Output tensor on the same device as the input tensors.
    ///
    /// # Device Compatibility
    ///
    /// Accepts tensors from CPU or Metal devices, rejects CUDA tensors.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use candle_core::{Device, Tensor};
    /// use candle_coreml::{CoreMLModel, Config};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = CoreMLModel::load("model.mlmodelc")?;
    /// let device = Device::Cpu;
    ///
    /// let mut state = model.make_state()?;
    ///
    /// // Generate tokens with persistent KV-cache
    /// for i in 0..10 {
    ///     let input = Tensor::ones((1, 1), candle_core::DType::I64, &device)?;
    ///     let output = model.predict_with_state(&[&input], &mut state)?;
    ///     println!("Token {}: {:?}", i, output);
    /// }
    /// # Ok(())
    /// # }
    /// ```
    pub fn predict_with_state(
        &self,
        inputs: &[&Tensor],
        state: &mut CoreMLState,
    ) -> Result<Tensor, CandleError> {
        // Validate we have the expected number of inputs (same contract as `forward()`).
        if inputs.len() != self.config.input_names.len() {
            return Err(CandleError::Msg(format!(
                "Expected {} inputs, got {}. Input names: {:?}",
                self.config.input_names.len(),
                inputs.len(),
                self.config.input_names
            )));
        }

        // Validate all input devices are compatible - accept CPU/Metal, reject CUDA
        for (i, input) in inputs.iter().enumerate() {
            match input.device() {
                Device::Cpu | Device::Metal(_) => {
                    // Valid devices for CoreML
                }
                Device::Cuda(_) => {
                    return Err(CandleError::Msg(format!(
                            "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
                            i, self.config.input_names[i]
                        )));
                }
            }
        }

        #[cfg(target_os = "macos")]
        {
            // Verbose print of input shapes and names moved to trace level
            tracing::trace!("predict_with_state function={:?}", self.function_name);
            for (i, t) in inputs.iter().enumerate() {
                tracing::trace!(
                    "predict_with_state input {} '{}' shape={:?}",
                    i,
                    self.config.input_names[i],
                    t.dims()
                );
            }
            // State mutation happens inside CoreML's stateful prediction API.
            self.predict_with_state_impl(inputs, state)
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = (inputs, state);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }
501
502    #[cfg(target_os = "macos")]
503    fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
504        autoreleasepool(|_| {
505            // Convert all Candle tensors to MLMultiArrays
506            let mut ml_arrays = Vec::with_capacity(inputs.len());
507            for input in inputs {
508                let ml_array = tensor_to_mlmultiarray(input)?;
509                ml_arrays.push(ml_array);
510            }
511
512            // Create feature provider with all named inputs
513            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
514
515            // Run prediction
516            let prediction = self.run_prediction(&provider)?;
517
518            // Extract output with configured output name (use first input device for output)
519            let output_tensor =
520                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
521
522            Ok(output_tensor)
523        })
524    }
525
526    #[cfg(target_os = "macos")]
527    fn forward_all_impl(
528        &self,
529        inputs: &[&Tensor],
530    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
531        autoreleasepool(|_| {
532            // Convert all Candle tensors to MLMultiArrays
533            let mut ml_arrays = Vec::with_capacity(inputs.len());
534            for input in inputs {
535                let ml_array = tensor_to_mlmultiarray(input)?;
536                ml_arrays.push(ml_array);
537            }
538
539            // Create feature provider with all named inputs
540            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
541
542            // Run prediction
543            let prediction = self.run_prediction(&provider)?;
544
545            // Extract all outputs
546            extract_all_outputs(&prediction, inputs[0].device())
547        })
548    }
549
    // Execute a stateless CoreML prediction for an already-built feature provider.
    #[cfg(target_os = "macos")]
    fn run_prediction(
        &self,
        provider: &MLDictionaryFeatureProvider,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        autoreleasepool(|_| unsafe {
            // SAFETY: `provider` is a live MLDictionaryFeatureProvider, which conforms to
            // MLFeatureProvider, so the protocol cast and the Objective-C call are valid.
            let protocol_provider = ProtocolObject::from_ref(provider);

            // Function name is now handled during model loading via MLModelConfiguration
            self.inner
                .predictionFromFeatures_error(protocol_provider)
                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {e:?}")))
        })
    }
564
565    #[cfg(target_os = "macos")]
566    fn predict_with_state_impl(
567        &self,
568        inputs: &[&Tensor],
569        state: &mut CoreMLState,
570    ) -> Result<Tensor, CandleError> {
571        autoreleasepool(|_| {
572            // Convert all Candle tensors to MLMultiArrays (reuse existing logic)
573            let mut ml_arrays = Vec::with_capacity(inputs.len());
574            for input in inputs {
575                let ml_array = tensor_to_mlmultiarray(input)?;
576                ml_arrays.push(ml_array);
577            }
578
579            // Create feature provider with all named inputs (reuse existing logic)
580            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
581
582            // Run stateful prediction
583            let prediction = self.run_prediction_with_state(&provider, state)?;
584
585            // Extract output with configured output name (use first input device for output)
586            let output_tensor =
587                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
588
589            Ok(output_tensor)
590        })
591    }
592
    // Execute a stateful CoreML prediction, threading `state` into the call so the
    // model's persistent (e.g. KV-cache) state is read and updated.
    #[cfg(target_os = "macos")]
    fn run_prediction_with_state(
        &self,
        provider: &MLDictionaryFeatureProvider,
        state: &mut CoreMLState,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        autoreleasepool(|_| unsafe {
            // SAFETY: `provider` is a live MLDictionaryFeatureProvider, which conforms to
            // MLFeatureProvider, so the protocol cast and the Objective-C call are valid.
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_usingState_error(protocol_provider, state.inner())
                .map_err(|e| CandleError::Msg(format!("CoreML stateful prediction error: {e:?}")))
        })
    }
607
    /// Try to load a cached compiled model if it exists
    ///
    /// Loads the cache entry derived from `source_path` only when the cached
    /// bundle's modification time is at least as new as the source's.
    ///
    /// # Errors
    ///
    /// Returns an error when no valid (present, up-to-date, loadable) cached
    /// model is found; the caller treats any error as "compile from source".
    #[cfg(target_os = "macos")]
    fn try_load_cached_compiled_model(
        source_path: &Path,
        load_start: &std::time::Instant,
        config: &Config,
        function_name: Option<&str>,
    ) -> Result<CoreMLModel, CandleError> {
        let cache_path = Self::get_compiled_cache_path(source_path)?;

        if cache_path.exists() {
            debug!("Found cached compiled model at: {}", cache_path.display());

            // Check if cached version is newer than source
            // (any metadata/mtime read failure falls through to the final error).
            if let (Ok(cache_meta), Ok(source_meta)) =
                (cache_path.metadata(), source_path.metadata())
            {
                if let (Ok(cache_modified), Ok(source_modified)) =
                    (cache_meta.modified(), source_meta.modified())
                {
                    if cache_modified >= source_modified {
                        let url = unsafe {
                            NSURL::fileURLWithPath(&NSString::from_str(
                                &cache_path.to_string_lossy(),
                            ))
                        };

                        // Mirror the loader used for fresh models: honor the optional
                        // function name via MLModelConfiguration.
                        match unsafe {
                            if let Some(func) = function_name {
                                let ml_cfg = MLModelConfiguration::new();
                                let ns_name = NSString::from_str(func);
                                ml_cfg.setFunctionName(Some(&ns_name));
                                MLModel::modelWithContentsOfURL_configuration_error(&url, &ml_cfg)
                            } else {
                                MLModel::modelWithContentsOfURL_error(&url)
                            }
                        } {
                            Ok(model) => {
                                info!(
                                    "Cached compiled model loaded in {:.1}s",
                                    load_start.elapsed().as_secs_f32()
                                );
                                return Ok(CoreMLModel {
                                    inner: model,
                                    config: config.clone(),
                                    function_name: function_name.map(|s| s.to_string()),
                                });
                            }
                            Err(e) => {
                                debug!("Failed to load cached compiled model: {e}");
                                // Continue to recompilation
                            }
                        }
                    } else {
                        debug!("Cached compiled model is older than source, will recompile");
                    }
                }
            }
        }

        Err(CandleError::Msg(
            "No valid cached compiled model found".to_string(),
        ))
    }
672
673    /// Cache a compiled model for future use
674    #[cfg(target_os = "macos")]
675    fn cache_compiled_model(source_path: &Path, compiled_url: &NSURL) -> Result<(), CandleError> {
676        let cache_path = Self::get_compiled_cache_path(source_path)?;
677
678        // Create cache directory if it doesn't exist
679        if let Some(parent) = cache_path.parent() {
680            std::fs::create_dir_all(parent)
681                .map_err(|e| CandleError::Msg(format!("Failed to create cache directory: {e}")))?;
682        }
683
684        // Get the path from the compiled URL
685        let compiled_path_str = unsafe { compiled_url.path() };
686        if compiled_path_str.is_none() {
687            return Err(CandleError::Msg("Invalid compiled model URL".to_string()));
688        }
689
690        let compiled_path = std::path::PathBuf::from(compiled_path_str.unwrap().to_string());
691
692        // Copy the compiled model to the cache location
693        if compiled_path.exists() {
694            if cache_path.exists() {
695                std::fs::remove_dir_all(&cache_path).map_err(|e| {
696                    CandleError::Msg(format!("Failed to remove old cached model: {e}"))
697                })?;
698            }
699
700            Self::copy_recursive(&compiled_path, &cache_path)
701                .map_err(|e| CandleError::Msg(format!("Failed to cache compiled model: {e}")))?;
702
703            debug!("Cached compiled model at: {}", cache_path.display());
704        } else {
705            return Err(CandleError::Msg(
706                "Compiled model path does not exist".to_string(),
707            ));
708        }
709
710        Ok(())
711    }
712
713    /// Get the cache path for a compiled model
714    fn get_compiled_cache_path(source_path: &Path) -> Result<std::path::PathBuf, CandleError> {
715        // Use the CacheManager to get a consistent cache directory
716        use crate::CacheManager;
717        let cache_manager = CacheManager::new()
718            .map_err(|e| CandleError::Msg(format!("Failed to initialize cache manager: {e}")))?;
719
720        let cache_dir = cache_manager.models_dir().parent().unwrap().to_path_buf();
721
722        // Create a unique cache key based on the source path
723        let source_hash = {
724            use std::collections::hash_map::DefaultHasher;
725            use std::hash::{Hash, Hasher};
726            let mut hasher = DefaultHasher::new();
727            source_path.hash(&mut hasher);
728            hasher.finish()
729        };
730
731        let cache_name = format!("compiled_{source_hash:x}.mlmodelc");
732        Ok(cache_dir.join("compiled_models").join(cache_name))
733    }
734
735    /// Recursively copy a directory
736    fn copy_recursive(from: &Path, to: &Path) -> std::io::Result<()> {
737        if from.is_dir() {
738            std::fs::create_dir_all(to)?;
739            for entry in std::fs::read_dir(from)? {
740                let entry = entry?;
741                let from_path = entry.path();
742                let to_path = to.join(entry.file_name());
743                Self::copy_recursive(&from_path, &to_path)?;
744            }
745        } else {
746            if let Some(parent) = to.parent() {
747                std::fs::create_dir_all(parent)?;
748            }
749            std::fs::copy(from, to)?;
750        }
751        Ok(())
752    }
753}
754
#[cfg(test)]
mod tests {
    #[cfg(target_os = "macos")]
    use super::*;

    /// Smoke-test model loading and the forward interface.
    ///
    /// Requires a real compiled model at `models/test.mlmodelc`; silently
    /// skipped when that fixture is absent.
    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        // This test requires an actual .mlmodelc file
        // Skip if file doesn't exist
        let model_path = "models/test.mlmodelc";
        if !std::path::Path::new(model_path).exists() {
            return;
        }

        let config = Config::default();
        let device = Device::Cpu;

        let model = CoreMLModel::load_from_file(model_path, &config).expect("Failed to load model");

        // Test config access
        assert_eq!(model.config().input_names[0], "input_ids");

        // Test with dummy input tensor on CPU device
        let input = Tensor::ones((1, 10), candle_core::DType::F32, &device)
            .expect("Failed to create input tensor");

        // This will fail without a real model but tests the interface
        let _result = model.forward_single(&input);
    }
}
785}