1use crate::config::Config;
4use crate::state::CoreMLState;
5
6#[cfg(target_os = "macos")]
7use crate::conversion::{
8 create_multi_feature_provider, extract_all_outputs, extract_output, tensor_to_mlmultiarray,
9};
10use candle_core::{Device, Error as CandleError, Tensor};
11use std::path::Path;
12
13#[cfg(target_os = "macos")]
14use tracing::{debug, info};
15
16#[cfg(target_os = "macos")]
17use objc2::rc::{autoreleasepool, Retained};
18#[cfg(target_os = "macos")]
19use objc2::runtime::ProtocolObject;
20#[cfg(target_os = "macos")]
21use objc2_core_ml::{
22 MLDictionaryFeatureProvider, MLFeatureProvider, MLModel, MLModelConfiguration,
23};
24#[cfg(target_os = "macos")]
25use objc2_foundation::{NSString, NSURL};
26
/// A loaded CoreML model plus the Candle-side metadata needed to run it.
///
/// On non-macOS targets the Objective-C handle is replaced by a zero-sized
/// placeholder so the type still compiles; every operation then returns an
/// error at runtime.
pub struct CoreMLModel {
    /// Retained Objective-C handle to the underlying `MLModel` (macOS only).
    #[cfg(target_os = "macos")]
    pub(crate) inner: Retained<MLModel>,
    #[cfg(not(target_os = "macos"))]
    _phantom: std::marker::PhantomData<()>,
    /// Input/output naming and related settings used to drive predictions.
    pub(crate) config: Config,
    /// Function name bound at load time for multi-function models, if any.
    pub(crate) function_name: Option<String>,
}
36
// Manual Debug impl: the CoreML handle (`inner`) is an Objective-C object
// with no useful Debug representation (and is absent off-macOS), so only
// `config` and `function_name` are printed; `finish_non_exhaustive` signals
// the omitted field.
impl std::fmt::Debug for CoreMLModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CoreMLModel")
            .field("config", &self.config)
            .field("function_name", &self.function_name)
            .finish_non_exhaustive()
    }
}
45
46impl CoreMLModel {
47 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, CandleError> {
49 let config = Config::default();
50 Self::load_from_file(path, &config)
51 }
52
53 pub fn load_with_function<P: AsRef<Path>>(
55 path: P,
56 config: &Config,
57 function_name: &str,
58 ) -> Result<Self, CandleError> {
59 Self::load_from_file_with_function(path, config, Some(function_name))
60 }
61
    /// Loads a model from `path` with an explicit [`Config`] and no bound
    /// function name (single-function models).
    pub fn load_from_file<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self, CandleError> {
        Self::load_from_file_with_function(path, config, None)
    }
69
70 pub fn load_from_file_with_function<P: AsRef<Path>>(
72 path: P,
73 config: &Config,
74 function_name: Option<&str>,
75 ) -> Result<Self, CandleError> {
76 #[cfg(target_os = "macos")]
77 {
78 let path = path.as_ref();
79 if !path.exists() {
80 return Err(CandleError::Msg(format!(
81 "Model file not found: {}",
82 path.display()
83 )));
84 }
85
86 autoreleasepool(|_| {
87 let url =
88 unsafe { NSURL::fileURLWithPath(&NSString::from_str(&path.to_string_lossy())) };
89
90 info!("Loading and compiling CoreML model at {}", path.display());
92 let load_start = std::time::Instant::now();
93
94 let model_result = if let Some(func_name) = function_name {
96 let config = unsafe { MLModelConfiguration::new() };
97 let ns_func_name = NSString::from_str(func_name);
98 unsafe { config.setFunctionName(Some(&ns_func_name)) };
99 unsafe { MLModel::modelWithContentsOfURL_configuration_error(&url, &config) }
100 } else {
101 unsafe { MLModel::modelWithContentsOfURL_error(&url) }
102 };
103
104 let load_time = load_start.elapsed();
105
106 match model_result {
108 Ok(model) => {
109 info!(
110 "Model loaded and compiled in {:.1}s",
111 load_time.as_secs_f32()
112 );
113 Ok(CoreMLModel {
114 inner: model,
115 config: config.clone(),
116 function_name: function_name.map(|s| s.to_string()),
117 })
118 }
119 Err(err) => {
120 let err_msg = format!("{err:?}");
122 if err_msg.contains("Compile the model") {
123 debug!("Model requires compilation, compiling now");
124 #[allow(deprecated)]
125 match unsafe { MLModel::compileModelAtURL_error(&url) } {
126 Ok(compiled_url) => {
127 debug!("Compilation completed, loading compiled model");
128 match unsafe {
130 MLModel::modelWithContentsOfURL_error(&compiled_url)
131 } {
132 Ok(model) => {
133 info!(
134 "Compiled model loaded in {:.1}s total",
135 load_time.as_secs_f32()
136 );
137 Ok(CoreMLModel {
138 inner: model,
139 config: config.clone(),
140 function_name: function_name.map(|s| s.to_string()),
141 })
142 }
143 Err(compile_err) => Err(CandleError::Msg(format!(
144 "Failed to load compiled CoreML model: {compile_err:?}"
145 ))),
146 }
147 }
148 Err(compile_err) => Err(CandleError::Msg(format!(
149 "Failed to compile CoreML model: {compile_err:?}. Original error: {err:?}"
150 ))),
151 }
152 } else {
153 let err_msg = format!("{err:?}");
155 if err_msg.contains("compiler major version")
156 && err_msg.contains("more recent than this framework")
157 {
158 Err(CandleError::Msg(format!(
159 "CoreML version compatibility issue: {err_msg}\n\
160 This model was compiled with a newer CoreML compiler than this system supports.\n\
161 Solutions:\n\
162 • Update to a newer macOS version\n\
163 • Use models compiled for your CoreML framework version\n\
164 • Set RUST_LOG=debug for more details"
165 )))
166 } else {
167 Err(CandleError::Msg(format!(
168 "Failed to load CoreML model: {err:?}"
169 )))
170 }
171 }
172 }
173 }
174 })
175 }
176
177 #[cfg(not(target_os = "macos"))]
178 {
179 let _ = (path, config, function_name);
180 Err(CandleError::Msg(
181 "CoreML is only available on macOS".to_string(),
182 ))
183 }
184 }
185
    /// Convenience wrapper around [`CoreMLModel::forward`] for models with a
    /// single input tensor.
    pub fn forward_single(&self, input: &Tensor) -> Result<Tensor, CandleError> {
        self.forward(&[input])
    }
198
199 pub fn forward(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
200 if inputs.len() != self.config.input_names.len() {
202 return Err(CandleError::Msg(format!(
203 "Expected {} inputs, got {}. Input names: {:?}",
204 self.config.input_names.len(),
205 inputs.len(),
206 self.config.input_names
207 )));
208 }
209
210 for (i, input) in inputs.iter().enumerate() {
212 match input.device() {
213 Device::Cpu | Device::Metal(_) => {
214 }
216 Device::Cuda(_) => {
217 return Err(CandleError::Msg(format!(
218 "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
219 i, self.config.input_names[i]
220 )));
221 }
222 }
223 }
224
225 #[cfg(target_os = "macos")]
226 {
227 self.forward_impl(inputs)
228 }
229
230 #[cfg(not(target_os = "macos"))]
231 {
232 let _ = inputs;
233 Err(CandleError::Msg(
234 "CoreML is only available on macOS".to_string(),
235 ))
236 }
237 }
238
239 pub fn forward_all(
244 &self,
245 inputs: &[&Tensor],
246 ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
247 if inputs.len() != self.config.input_names.len() {
249 return Err(CandleError::Msg(format!(
250 "Expected {} inputs, got {}. Input names: {:?}",
251 self.config.input_names.len(),
252 inputs.len(),
253 self.config.input_names
254 )));
255 }
256
257 for (i, input) in inputs.iter().enumerate() {
259 match input.device() {
260 Device::Cpu | Device::Metal(_) => {
261 }
263 Device::Cuda(_) => {
264 return Err(CandleError::Msg(format!(
265 "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
266 i, self.config.input_names[i]
267 )));
268 }
269 }
270 }
271
272 #[cfg(target_os = "macos")]
273 {
274 self.forward_all_impl(inputs)
275 }
276
277 #[cfg(not(target_os = "macos"))]
278 {
279 let _ = inputs;
280 Err(CandleError::Msg(
281 "CoreML is only available on macOS".to_string(),
282 ))
283 }
284 }
285
    /// Returns the model's [`Config`] (input names, output name, etc.).
    pub fn config(&self) -> &Config {
        &self.config
    }
290
    /// Exposes the retained `MLModel` handle for callers that need direct
    /// access to the CoreML API (macOS only).
    #[cfg(target_os = "macos")]
    pub fn inner_model(&self) -> &Retained<MLModel> {
        &self.inner
    }
296
    /// Wraps an already-loaded `MLModel` in a [`CoreMLModel`] (macOS only).
    ///
    /// No function name is bound; use the loader APIs if one is needed.
    #[cfg(target_os = "macos")]
    pub fn from_mlmodel(inner: Retained<MLModel>, config: Config) -> Self {
        CoreMLModel {
            inner,
            config,
            function_name: None,
        }
    }
306
    /// Creates a fresh [`CoreMLState`] for stateful prediction with
    /// [`CoreMLModel::predict_with_state`]. Each state is independent;
    /// create one per sequence/stream.
    pub fn make_state(&self) -> Result<CoreMLState, CandleError> {
        #[cfg(target_os = "macos")]
        {
            CoreMLState::new(&self.inner)
        }

        #[cfg(not(target_os = "macos"))]
        {
            // Non-macOS stub: presumably CoreMLState::new(&()) returns an
            // "unsupported platform" error here — TODO confirm in state.rs.
            CoreMLState::new(&())
        }
    }
345
346 pub fn predict_with_state(
387 &self,
388 inputs: &[&Tensor],
389 state: &mut CoreMLState,
390 ) -> Result<Tensor, CandleError> {
391 if inputs.len() != self.config.input_names.len() {
393 return Err(CandleError::Msg(format!(
394 "Expected {} inputs, got {}. Input names: {:?}",
395 self.config.input_names.len(),
396 inputs.len(),
397 self.config.input_names
398 )));
399 }
400
401 for (i, input) in inputs.iter().enumerate() {
403 match input.device() {
404 Device::Cpu | Device::Metal(_) => {
405 }
407 Device::Cuda(_) => {
408 return Err(CandleError::Msg(format!(
409 "CoreML models do not support CUDA tensors. Input {} '{}' is on CUDA device. Please move tensor to CPU or Metal device first.",
410 i, self.config.input_names[i]
411 )));
412 }
413 }
414 }
415
416 #[cfg(target_os = "macos")]
417 {
418 tracing::debug!("predict_with_state function={:?}", self.function_name);
420 for (i, t) in inputs.iter().enumerate() {
421 tracing::debug!(
422 "predict_with_state input {} '{}' shape={:?}",
423 i,
424 self.config.input_names[i],
425 t.dims()
426 );
427 }
428 self.predict_with_state_impl(inputs, state)
429 }
430
431 #[cfg(not(target_os = "macos"))]
432 {
433 let _ = (inputs, state);
434 Err(CandleError::Msg(
435 "CoreML is only available on macOS".to_string(),
436 ))
437 }
438 }
439
440 #[cfg(target_os = "macos")]
441 fn forward_impl(&self, inputs: &[&Tensor]) -> Result<Tensor, CandleError> {
442 autoreleasepool(|_| {
443 let mut ml_arrays = Vec::with_capacity(inputs.len());
445 for input in inputs {
446 let ml_array = tensor_to_mlmultiarray(input)?;
447 ml_arrays.push(ml_array);
448 }
449
450 let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
452
453 let prediction = self.run_prediction(&provider)?;
455
456 let output_tensor =
458 extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
459
460 Ok(output_tensor)
461 })
462 }
463
464 #[cfg(target_os = "macos")]
465 fn forward_all_impl(
466 &self,
467 inputs: &[&Tensor],
468 ) -> Result<std::collections::HashMap<String, Tensor>, CandleError> {
469 autoreleasepool(|_| {
470 let mut ml_arrays = Vec::with_capacity(inputs.len());
472 for input in inputs {
473 let ml_array = tensor_to_mlmultiarray(input)?;
474 ml_arrays.push(ml_array);
475 }
476
477 let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
479
480 let prediction = self.run_prediction(&provider)?;
482
483 extract_all_outputs(&prediction, inputs[0].device())
485 })
486 }
487
    /// Runs a stateless CoreML prediction for the given feature provider,
    /// inside an autorelease pool so Objective-C temporaries are reclaimed
    /// promptly.
    #[cfg(target_os = "macos")]
    fn run_prediction(
        &self,
        provider: &MLDictionaryFeatureProvider,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        // SAFETY: `provider` is a valid MLDictionaryFeatureProvider for the
        // duration of the call; the returned feature provider is retained by
        // objc2, so it outlives the autorelease pool scope.
        autoreleasepool(|_| unsafe {
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_error(protocol_provider)
                .map_err(|e| CandleError::Msg(format!("CoreML prediction error: {e:?}")))
        })
    }
502
503 #[cfg(target_os = "macos")]
504 fn predict_with_state_impl(
505 &self,
506 inputs: &[&Tensor],
507 state: &mut CoreMLState,
508 ) -> Result<Tensor, CandleError> {
509 autoreleasepool(|_| {
510 let mut ml_arrays = Vec::with_capacity(inputs.len());
512 for input in inputs {
513 let ml_array = tensor_to_mlmultiarray(input)?;
514 ml_arrays.push(ml_array);
515 }
516
517 let provider = create_multi_feature_provider(&self.config.input_names, &ml_arrays)?;
519
520 let prediction = self.run_prediction_with_state(&provider, state)?;
522
523 let output_tensor =
525 extract_output(&prediction, &self.config.output_name, inputs[0].device())?;
526
527 Ok(output_tensor)
528 })
529 }
530
    /// Runs a stateful CoreML prediction (model state passed alongside the
    /// feature provider), inside an autorelease pool.
    #[cfg(target_os = "macos")]
    fn run_prediction_with_state(
        &self,
        provider: &MLDictionaryFeatureProvider,
        state: &mut CoreMLState,
    ) -> Result<Retained<ProtocolObject<dyn MLFeatureProvider>>, CandleError> {
        // SAFETY: `provider` and `state.inner()` are valid for the duration of
        // the call; the returned feature provider is retained by objc2.
        autoreleasepool(|_| unsafe {
            let protocol_provider = ProtocolObject::from_ref(provider);

            self.inner
                .predictionFromFeatures_usingState_error(protocol_provider, state.inner())
                .map_err(|e| CandleError::Msg(format!("CoreML stateful prediction error: {e:?}")))
        })
    }
545}
546
#[cfg(test)]
mod tests {
    #[cfg(target_os = "macos")]
    use super::*;

    // Smoke test: loads a local compiled model fixture (if present) and runs
    // one forward pass. Skips silently when the fixture is missing so machines
    // without the model still pass.
    #[test]
    #[cfg(target_os = "macos")]
    fn test_model_creation() {
        let model_path = "models/test.mlmodelc";
        // Optional fixture — bail out rather than fail when absent.
        if !std::path::Path::new(model_path).exists() {
            return;
        }

        let config = Config::default();
        let device = Device::Cpu;

        let model = CoreMLModel::load_from_file(model_path, &config).expect("Failed to load model");

        assert_eq!(model.config().input_names[0], "input_ids");

        let input = Tensor::ones((1, 10), candle_core::DType::F32, &device)
            .expect("Failed to create input tensor");

        // Result intentionally ignored: the fixture's schema may not match the
        // default Config, so this only exercises the call path end to end.
        let _result = model.forward_single(&input);
    }
}
577}