1use crate::config::basic::Config;
4use crate::state::CoreMLState;
5
6#[cfg(target_os = "macos")]
7use crate::conversion::{
8 create_multi_feature_provider, extract_all_outputs, extract_output, tensor_to_mlmultiarray,
9};
10use candle_core::{Device, Error as CandleError, Tensor};
11use std::path::Path;
12
13#[cfg(target_os = "macos")]
14use tracing::{debug, info};
15
16#[cfg(target_os = "macos")]
17use objc2::rc::{autoreleasepool, Retained};
18#[cfg(target_os = "macos")]
19use objc2::runtime::ProtocolObject;
20#[cfg(target_os = "macos")]
21use objc2_core_ml::{
22 MLDictionaryFeatureProvider, MLFeatureProvider, MLModel, MLModelConfiguration,
23};
24#[cfg(target_os = "macos")]
25use objc2_foundation::{NSString, NSURL};
26
/// A loaded CoreML model plus the I/O configuration needed to drive it.
///
/// On macOS this owns a retained `MLModel`; on other platforms the struct is a
/// stub whose operations return an error at call time.
pub struct CoreMLModel {
    // Retained Objective-C handle to the underlying MLModel (macOS only).
    #[cfg(target_os = "macos")]
    pub(crate) inner: Retained<MLModel>,
    // Keeps the struct non-empty / layout-stable on non-macOS builds.
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    // Input/output naming and related settings used to build feature providers.
    pub(crate) config: Config,
    // Optional CoreML function name (multi-function models); None = default function.
    pub(crate) function_name: Option<String>,
}
36
37impl std::fmt::Debug for CoreMLModel {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("CoreMLModel")
40 .field("config", &self.config)
41 .field("function_name", &self.function_name)
42 .finish_non_exhaustive()
43 }
44}
45
impl CoreMLModel {
    /// Loads a model from `path` using a default [`Config`].
    ///
    /// Convenience wrapper over [`Self::load_from_file`]; errors if the path
    /// does not exist or CoreML refuses to load/compile the artifact.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
        let config = Config::default();
        Self::load_from_file(path, &config)
    }

    /// Loads a model and selects a named CoreML function.
    ///
    /// Useful for multi-function model packages; `function_name` is applied
    /// via `MLModelConfiguration.setFunctionName` during loading.
    pub fn load_with_function<P: AsRef<Path>>(
        path: P,
        config: &Config,
        function_name: &str,
    ) -> Result<Self, CandleError> {
        Self::load_from_file_with_function(path, config, Some(function_name))
    }

    /// Loads a model from disk with an explicit configuration and the model's
    /// default function.
    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
        Self::load_from_file_with_function(path, config, None)
    }

    /// Core loading routine.
    ///
    /// macOS flow:
    /// 1. Verify the path exists.
    /// 2. If it looks like an already-compiled `.mlmodelc`, load it directly
    ///    (with a friendlier error for compiler-version mismatches).
    /// 3. Otherwise try a direct load; on failure, fall back to a cached
    ///    compiled copy, then to compiling the model on the fly (caching the
    ///    result for next time).
    ///
    /// On non-macOS targets this always returns an error.
    ///
    /// # Errors
    /// Returns `CandleError::Msg` for missing files, load failures, and
    /// compile failures (the compile-failure message also carries the
    /// original load error for diagnosis).
    pub fn load_from_file_with_function<P: AsRef<Path>>(
        path: P,
        config: &Config,
        function_name: Option<&str>,
    ) -> Result<Self, CandleError> {
        #[cfg(target_os = "macos")]
        {
            let path = path.as_ref();
            if !path.exists() {
                return Err(CandleError::Msg(format!(
                    "Model file not found: {}",
                    path.display()
                )));
            }

            // All Objective-C objects created below are scoped to this pool so
            // temporaries are released promptly.
            autoreleasepool(|_| {
                // SAFETY: NSString is built from a valid UTF-8 lossy conversion
                // of the path; fileURLWithPath only reads it.
                let url =
                    unsafe { NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy())) };

                // Local helper: load either through MLModelConfiguration (when
                // a function name was requested) or the plain URL loader.
                // Marked `unsafe` because it calls the raw objc2 bindings.
                unsafe fn load_with_config(
                    url: &NSURL,
                    function_name: Option<&str>,
                ) -> Result<Retained<MLModel>, CandleError> {
                    if let Some(func) = function_name {
                        let ml_cfg = MLModelConfiguration::new();
                        let ns_name = NSString::from_str(func);
                        ml_cfg.setFunctionName(Some(&ns_name));
                        MLModel::modelWithContentsOfURL_configuration_error(url, &ml_cfg).map_err(
                            |e| {
                                CandleError::Msg(format!(
                                    "Failed to load CoreML model with configuration: {e:?}"
                                ))
                            },
                        )
                    } else {
                        MLModel::modelWithContentsOfURL_error(url).map_err(|e| {
                            CandleError::Msg(format!("Failed to load CoreML model: {e:?}"))
                        })
                    }
                }

                // Classify the on-disk artifact. Extension alone is not enough:
                // bundles are directories whose extension check can miss, so we
                // also inspect the path suffix and the package's inner layout.
                let is_dir = path.is_dir();
                let ext = path
                    .extension()
                    .and_then(|s| s.to_str())
                    .unwrap_or_default()
                    .to_ascii_lowercase();
                let looks_like_modelc =
                    ext == "mlmodelc" || (is_dir && path.to_string_lossy().ends_with(".mlmodelc"));
                let looks_like_package = ext == "mlpackage"
                    || (is_dir && path.to_string_lossy().ends_with(".mlpackage"));
                // A well-formed .mlpackage has a Manifest.json; when it is
                // missing but the raw inner .mlmodel exists, we compile that
                // inner file directly (see compile fallback below).
                let manifest_json_exists = path.join("Manifest.json").exists();
                let inner_mlmodel_path = path.join("Data/com.apple.CoreML/model.mlmodel");
                let has_inner_mlmodel = inner_mlmodel_path.exists();

                info!("Loading CoreML model at {}", path.display());
                let load_start = std::time::Instant::now();

                // Fast path: already-compiled .mlmodelc loads directly; no
                // compile fallback makes sense for it, so errors are terminal.
                if looks_like_modelc {
                    // SAFETY: `url` is a live NSURL created above.
                    match unsafe { load_with_config(&url, function_name) } {
                        Ok(model) => {
                            info!("Model loaded in {:.1}s", load_start.elapsed().as_secs_f32());
                            return Ok(CoreMLModel {
                                inner: model,
                                config: config.clone(),
                                function_name: function_name.map(|s| s.to_string()),
                            });
                        }
                        Err(err) => {
                            let msg = format!("{err}");
                            // Detect the "model compiled by a newer toolchain"
                            // failure by substring match on CoreML's message
                            // and surface actionable advice.
                            if msg.contains("compiler major version")
                                && msg.contains("more recent than this framework")
                            {
                                return Err(CandleError::Msg(format!(
                                    "CoreML version compatibility issue: {msg}\n\
                                     Update macOS or use a model compiled for this framework version."
                                )));
                            }
                            return Err(err);
                        }
                    }
                }

                // General path: try a direct load first (CoreML can sometimes
                // load packages as-is), then fall back to cache / compile.
                // SAFETY: `url` is a live NSURL created above.
                match unsafe { load_with_config(&url, function_name) } {
                    Ok(model) => {
                        info!("Model loaded in {:.1}s", load_start.elapsed().as_secs_f32());
                        Ok(CoreMLModel {
                            inner: model,
                            config: config.clone(),
                            function_name: function_name.map(|s| s.to_string()),
                        })
                    }
                    Err(load_err) => {
                        // Only sources that can be compiled get the fallback:
                        // packages, raw .mlmodel files, or any plain file.
                        if looks_like_package || ext == "mlmodel" || !is_dir {
                            debug!("Direct load failed, attempting compilation: {load_err}");

                            // Cheapest recovery: a previously compiled copy in
                            // the cache directory (freshness-checked inside).
                            if let Ok(cached_model) = Self::try_load_cached_compiled_model(
                                path,
                                &load_start,
                                config,
                                function_name,
                            ) {
                                return Ok(cached_model);
                            }

                            // compileModelAtURL_error is deprecated upstream in
                            // favor of the async variant; kept here for the
                            // synchronous flow.
                            #[allow(deprecated)]
                            let compile_result = unsafe {
                                // SAFETY: both URL variants point at existing
                                // paths verified earlier in this function.
                                if looks_like_package && !manifest_json_exists && has_inner_mlmodel
                                {
                                    // Malformed package (no Manifest.json):
                                    // compile the inner .mlmodel directly.
                                    let inner_url = NSURL::fileURLWithPath(&NSString::from_str(
                                        &inner_mlmodel_path.to_string_lossy(),
                                    ));
                                    MLModel::compileModelAtURL_error(&inner_url)
                                } else {
                                    MLModel::compileModelAtURL_error(&url)
                                }
                            };

                            match compile_result {
                                Ok(compiled_url) => {
                                    debug!("Compilation completed, caching and loading compiled model");

                                    // Caching is best-effort: a failure here
                                    // only costs a recompile next time.
                                    if let Err(e) = Self::cache_compiled_model(path, &compiled_url) {
                                        debug!("Failed to cache compiled model: {e}");
                                    }
                                    // SAFETY: compiled_url was just returned by
                                    // the framework and is a valid NSURL.
                                    match unsafe { load_with_config(&compiled_url, function_name) } {
                                        Ok(model) => {
                                            info!(
                                                "Compiled model loaded in {:.1}s total",
                                                load_start.elapsed().as_secs_f32()
                                            );
                                            Ok(CoreMLModel {
                                                inner: model,
                                                config: config.clone(),
                                                function_name: function_name.map(|s| s.to_string()),
                                            })
                                        }
                                        Err(err) => Err(CandleError::Msg(format!(
                                            "Failed to load compiled CoreML model: {err}"
                                        ))),
                                    }
                                }
                                // Preserve the original load error alongside
                                // the compile error for debugging.
                                Err(compile_err) => Err(CandleError::Msg(format!(
                                    "Failed to compile CoreML model: {compile_err}. Original load error: {load_err}"
                                ))),
                            }
                        } else {
                            Err(load_err)
                        }
                    }
                }
            })
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = (path, config, function_name);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }

    /// Runs the model on a single input tensor; see [`Self::forward`].
    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
        self.forward(&[input])
    }

    /// Runs the model and returns the single output named by
    /// `config.output_name`.
    ///
    /// Validates that the number of inputs matches `config.input_names` and
    /// that no input lives on a CUDA device (CoreML accepts CPU/Metal only).
    ///
    /// # Errors
    /// Arity mismatch, CUDA-resident input, or a CoreML prediction failure.
    pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
        if inputs.len() != self.config.input_names.len() {
            return Err(CandleError::Msg(format!(
                "Expected {} inputs, got {}. Input names: {:?}",
                self.config.input_names.len(),
                inputs.len(),
                self.config.input_names
            )));
        }

        // NOTE(review): this device check is duplicated in forward_all and
        // predict_with_state; candidate for a shared private helper.
        for (i, input) in inputs.iter().enumerate() {
            match input.device() {
                Device::Cpu | Device::Metal(_) => {
                    // Supported: CoreML reads from host/Metal memory.
                }
                Device::Cuda(_) => {
                    return Err(CandleError::Msg(format!(
                        "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
                        i, self.config.input_names[i]
                    )));
                }
            }
        }

        #[cfg(target_os = "macos")]
        {
            self.forward_impl(inputs)
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = inputs;
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }

    /// Runs the model and returns every output feature, keyed by name.
    ///
    /// Same validation as [`Self::forward`]; use this when the model has
    /// multiple outputs or the output name is not known up front.
    pub fn forward_all(
        &self,
        inputs: &[&Tensor],
    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
        if inputs.len() != self.config.input_names.len() {
            return Err(CandleError::Msg(format!(
                "Expected {} inputs, got {}. Input names: {:?}",
                self.config.input_names.len(),
                inputs.len(),
                self.config.input_names
            )));
        }

        for (i, input) in inputs.iter().enumerate() {
            match input.device() {
                Device::Cpu | Device::Metal(_) => {
                    // Supported: CoreML reads from host/Metal memory.
                }
                Device::Cuda(_) => {
                    return Err(CandleError::Msg(format!(
                        "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
                        i, self.config.input_names[i]
                    )));
                }
            }
        }

        #[cfg(target_os = "macos")]
        {
            self.forward_all_impl(inputs)
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = inputs;
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }

    /// Returns the model's I/O configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Exposes the retained `MLModel` for advanced/FFI use (macOS only).
    #[cfg(target_os = "macos")]
    pub fn inner_model(&self) -> &Retained<MLModel> {
        &self.inner
    }

    /// Wraps an already-loaded `MLModel` (default function) in a
    /// `CoreMLModel` (macOS only).
    #[cfg(target_os = "macos")]
    pub fn from_mlmodel(inner: Retained<MLModel>, config: Config) -> Self {
        CoreMLModel {
            inner,
            config,
            function_name: None,
        }
    }

    /// Creates a fresh mutable state object for stateful prediction.
    ///
    /// On non-macOS targets this delegates to the stub `CoreMLState::new`,
    /// which is expected to report the platform error itself.
    pub fn make_state(&self) -> Result<CoreMLState, CandleError> {
        #[cfg(target_os = "macos")]
        {
            CoreMLState::new(&self.inner)
        }

        #[cfg(not(target_os = "macos"))]
        {
            CoreMLState::new(&())
        }
    }

    /// Stateful prediction: like [`Self::forward`] but threads `state`
    /// through CoreML so the model can carry recurrent/KV-cache state
    /// between calls.
    ///
    /// # Errors
    /// Arity mismatch, CUDA-resident input, or a CoreML prediction failure.
    pub fn predict_with_state(
        &self,
        inputs: &[&Tensor],
        state: &mut CoreMLState,
    ) -> Result<Tensor, CandleError> {
        if inputs.len() != self.config.input_names.len() {
            return Err(CandleError::Msg(format!(
                "Expected {} inputs, got {}. Input names: {:?}",
                self.config.input_names.len(),
                inputs.len(),
                self.config.input_names
            )));
        }

        for (i, input) in inputs.iter().enumerate() {
            match input.device() {
                Device::Cpu | Device::Metal(_) => {
                    // Supported: CoreML reads from host/Metal memory.
                }
                Device::Cuda(_) => {
                    return Err(CandleError::Msg(format!(
                        "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
                        i, self.config.input_names[i]
                    )));
                }
            }
        }

        #[cfg(target_os = "macos")]
        {
            // Trace-level shape logging helps debug stateful shape mismatches.
            tracing::trace!("predict_with_state function={:?}", self.function_name);
            for (i, t) in inputs.iter().enumerate() {
                tracing::trace!(
                    "predict_with_state input {} '{}' shape={:?}",
                    i,
                    self.config.input_names[i],
                    t.dims()
                );
            }
            self.predict_with_state_impl(inputs, state)
        }

        #[cfg(not(target_os = "macos"))]
        {
            let _ = (inputs, state);
            Err(CandleError::Msg(
                "CoreML is only available on macOS".to_string(),
            ))
        }
    }

    /// macOS implementation of [`Self::forward`]: convert tensors to
    /// MLMultiArrays, build a feature provider, predict, and extract the
    /// configured output onto the first input's device.
    // NOTE(review): `inputs[0]` assumes at least one input; with an empty
    // `config.input_names` an empty slice passes validation and this would
    // panic — confirm whether zero-input models are possible.
    #[cfg(target_os = "macos")]
    fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
        autoreleasepool(|_| {
            let mut ml_arrays = Vec::with_capacity(inputs.len());
            for input in inputs {
                let ml_array = tensor_to_mlmultiarray(input)?;
                ml_arrays.push(ml_array);
            }

            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;

            let prediction = self.run_prediction(&provider)?;

            let output_tensor =
                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;

            Ok(output_tensor)
        })
    }

    /// macOS implementation of [`Self::forward_all`]: same pipeline as
    /// `forward_impl` but extracts every output feature.
    #[cfg(target_os = "macos")]
    fn forward_all_impl(
        &self,
        inputs: &[&Tensor],
    ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
        autoreleasepool(|_| {
            let mut ml_arrays = Vec::with_capacity(inputs.len());
            for input in inputs {
                let ml_array = tensor_to_mlmultiarray(input)?;
                ml_arrays.push(ml_array);
            }

            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;

            let prediction = self.run_prediction(&provider)?;

            extract_all_outputs(&prediction, inputs[0].device())
        })
    }

    /// Executes a stateless CoreML prediction over the given feature provider.
    #[cfg(target_os = "macos")]
    fn run_prediction(
        &self,
        provider: &MLDictionaryFeatureProvider,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        // SAFETY: `provider` is a live MLDictionaryFeatureProvider; the
        // protocol cast and prediction call only borrow it for the call.
        autoreleasepool(|_| unsafe {
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_error(protocol_provider)
                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {e:?}")))
        })
    }

    /// macOS implementation of [`Self::predict_with_state`]: identical to
    /// `forward_impl` except the prediction call carries `state`.
    #[cfg(target_os = "macos")]
    fn predict_with_state_impl(
        &self,
        inputs: &[&Tensor],
        state: &mut CoreMLState,
    ) -> Result<Tensor, CandleError> {
        autoreleasepool(|_| {
            let mut ml_arrays = Vec::with_capacity(inputs.len());
            for input in inputs {
                let ml_array = tensor_to_mlmultiarray(input)?;
                ml_arrays.push(ml_array);
            }

            let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;

            let prediction = self.run_prediction_with_state(&provider, state)?;

            let output_tensor =
                extract_output(&prediction, &self.config.output_name, inputs[0].device())?;

            Ok(output_tensor)
        })
    }

    /// Executes a stateful CoreML prediction, mutating `state` in place.
    #[cfg(target_os = "macos")]
    fn run_prediction_with_state(
        &self,
        provider: &MLDictionaryFeatureProvider,
        state: &mut CoreMLState,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        // SAFETY: `provider` and `state.inner()` are live objects for the
        // duration of this synchronous prediction call.
        autoreleasepool(|_| unsafe {
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_usingState_error(protocol_provider, state.inner())
                .map_err(|e| CandleError::Msg(format!("CoreML stateful prediction error: {e:?}")))
        })
    }

    /// Tries to load a previously cached compiled model for `source_path`.
    ///
    /// The cache entry is used only if its mtime is at least as new as the
    /// source's; otherwise (or on any load/metadata failure) an error is
    /// returned so the caller recompiles.
    #[cfg(target_os = "macos")]
    fn try_load_cached_compiled_model(
        source_path: &Path,
        load_start: &std::time::Instant,
        config: &Config,
        function_name: Option<&str>,
    ) -> Result<CoreMLModel, CandleError> {
        let cache_path = Self::get_compiled_cache_path(source_path)?;

        if cache_path.exists() {
            debug!("Found cached compiled model at: {}", cache_path.display());

            // All metadata/mtime failures fall through to the final Err —
            // stale-looking caches are simply ignored, never deleted here.
            if let (Ok(cache_meta), Ok(source_meta)) =
                (cache_path.metadata(), source_path.metadata())
            {
                if let (Ok(cache_modified), Ok(source_modified)) =
                    (cache_meta.modified(), source_meta.modified())
                {
                    if cache_modified >= source_modified {
                        // SAFETY: NSString built from a valid lossy path string.
                        let url = unsafe {
                            NSURL::fileURLWithPath(&NSString::from_str(
                                &cache_path.to_string_lossy(),
                            ))
                        };

                        // Same load-with-optional-function logic as the inline
                        // helper in load_from_file_with_function.
                        match unsafe {
                            if let Some(func) = function_name {
                                let ml_cfg = MLModelConfiguration::new();
                                let ns_name = NSString::from_str(func);
                                ml_cfg.setFunctionName(Some(&ns_name));
                                MLModel::modelWithContentsOfURL_configuration_error(&url, &ml_cfg)
                            } else {
                                MLModel::modelWithContentsOfURL_error(&url)
                            }
                        } {
                            Ok(model) => {
                                info!(
                                    "Cached compiled model loaded in {:.1}s",
                                    load_start.elapsed().as_secs_f32()
                                );
                                return Ok(CoreMLModel {
                                    inner: model,
                                    config: config.clone(),
                                    function_name: function_name.map(|s| s.to_string()),
                                });
                            }
                            Err(e) => {
                                // Corrupt/incompatible cache: log and let the
                                // caller recompile.
                                debug!("Failed to load cached compiled model: {e}");
                            }
                        }
                    } else {
                        debug!("Cached compiled model is older than source, will recompile");
                    }
                }
            }
        }

        Err(CandleError::Msg(
            "No valid cached compiled model found".to_string(),
        ))
    }

    /// Copies the freshly compiled model bundle into the cache directory,
    /// replacing any existing entry for this source path.
    #[cfg(target_os = "macos")]
    fn cache_compiled_model(source_path: &Path, compiled_url: &NSURL) -> Result<(), CandleError> {
        let cache_path = Self::get_compiled_cache_path(source_path)?;

        if let Some(parent) = cache_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| CandleError::Msg(format!("Failed to create cache directory: {e}")))?;
        }

        // SAFETY: `compiled_url` is a live NSURL; `path()` only reads it.
        let compiled_path_str = unsafe { compiled_url.path() };
        // NOTE(review): is_none() + unwrap() below would read better as a
        // `let Some(..) = .. else` guard.
        if compiled_path_str.is_none() {
            return Err(CandleError::Msg("Invalid compiled model URL".to_string()));
        }

        let compiled_path = std::path::PathBuf::from(compiled_path_str.unwrap().to_string());

        if compiled_path.exists() {
            // Replace-not-merge: remove any stale cache entry first so the
            // copy below starts from a clean directory.
            if cache_path.exists() {
                std::fs::remove_dir_all(&cache_path).map_err(|e| {
                    CandleError::Msg(format!("Failed to remove old cached model: {e}"))
                })?;
            }

            Self::copy_recursive(&compiled_path, &cache_path)
                .map_err(|e| CandleError::Msg(format!("Failed to cache compiled model: {e}")))?;

            debug!("Cached compiled model at: {}", cache_path.display());
        } else {
            return Err(CandleError::Msg(
                "Compiled model path does not exist".to_string(),
            ));
        }

        Ok(())
    }

    /// Derives the cache location for a source model path:
    /// `<cache root>/compiled_models/compiled_<hash>.mlmodelc`.
    ///
    /// The hash is over the path string, not file contents, so a changed
    /// model at the same path relies on the mtime check for invalidation.
    fn get_compiled_cache_path(source_path: &Path) -> Result<std::path::PathBuf, CandleError> {
        use crate::CacheManager;
        let cache_manager = CacheManager::new()
            .map_err(|e| CandleError::Msg(format!("Failed to initialize cache manager: {e}")))?;

        // NOTE(review): parent() unwrap assumes models_dir always has a
        // parent — holds for normal cache layouts, but a mapped error would
        // be safer than a panic.
        let cache_dir = cache_manager.models_dir().parent().unwrap().to_path_buf();

        let source_hash = {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut hasher = DefaultHasher::new();
            source_path.hash(&mut hasher);
            hasher.finish()
        };

        let cache_name = format!("compiled_{source_hash:x}.mlmodelc");
        Ok(cache_dir.join("compiled_models").join(cache_name))
    }

    /// Recursively copies `from` into `to` (directories created as needed).
    /// Compiled CoreML models are directory bundles, hence the recursion.
    fn copy_recursive(from: &Path, to: &Path) -> std::io::Result<()> {
        if from.is_dir() {
            std::fs::create_dir_all(to)?;
            for entry in std::fs::read_dir(from)? {
                let entry = entry?;
                let from_path = entry.path();
                let to_path = to.join(entry.file_name());
                Self::copy_recursive(&from_path, &to_path)?;
            }
        } else {
            if let Some(parent) = to.parent() {
                std::fs::create_dir_all(parent)?;
            }
            std::fs::copy(from, to)?;
        }
        Ok(())
    }
}
754
#[cfg(test)]
mod tests {
    #[cfg(target_os = "macos")]
    use super::*;

    /// Smoke test: if the fixture model is present on disk, load it with a
    /// default config and run a single forward pass. Skips silently when the
    /// fixture is absent so CI without the model still passes.
    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        let fixture = std::path::Path::new("models/test.mlmodelc");
        if !fixture.exists() {
            return;
        }

        let cfg = Config::default();
        let model = CoreMLModel::load_from_file(fixture, &cfg).expect("Failed to load model");

        assert_eq!(model.config().input_names[0], "input_ids");

        let cpu = Device::Cpu;
        let input = Tensor::ones((1, 10), candle_core::DType::F32, &cpu)
            .expect("Failed to create input tensor");

        // Output values are model-dependent; only exercise the call path.
        let _result = model.forward_single(&input);
    }
}
785}