candle_coreml/
builder.rs

1//! `CoreML` model builder for convenient model loading
2
3use crate::config::basic::Config;
4use crate::CoreMLModel;
5use candle_core::Error as CandleError;
6use std::path::{Path, PathBuf};
7
/// Builder for `CoreML` models
///
/// This provides an interface for loading `CoreML` models with configuration
/// management and device selection.
///
/// Construct directly with [`CoreMLModelBuilder::new`] when the model path and
/// config are already known, or use [`CoreMLModelBuilder::load_from_hub`] to
/// resolve them from a `HuggingFace` repo (or local cache), then call
/// [`CoreMLModelBuilder::build_model`] to load the model itself.
pub struct CoreMLModelBuilder {
    // Parsed model configuration (deserialized from a JSON config file
    // in `load_from_hub`, or supplied directly by the caller in `new`).
    config: Config,
    // Filesystem path to the CoreML artifact (e.g. a `model.mlmodelc` or
    // `model.mlpackage` — see the fallback list in `load_from_hub`).
    model_filename: PathBuf,
}
16
17impl CoreMLModelBuilder {
18    /// Create a new builder with the specified model path and config
19    pub fn new<P: AsRef<Path>>(model_path: P, config: Config) -> Self {
20        Self {
21            config,
22            model_filename: model_path.as_ref().to_path_buf(),
23        }
24    }
25
26    /// Load a `CoreML` model from `HuggingFace` or local files
27    pub fn load_from_hub(
28        model_id: &str,
29        model_filename: Option<&str>,
30        config_filename: Option<&str>,
31    ) -> Result<Self, CandleError> {
32        use crate::get_local_or_remote_file;
33        use hf_hub::{api::sync::Api, Repo, RepoType};
34
35        let api =
36            Api::new().map_err(|e| CandleError::Msg(format!("Failed to create HF API: {e}")))?;
37        let repo = api.repo(Repo::with_revision(
38            model_id.to_string(),
39            RepoType::Model,
40            "main".to_string(),
41        ));
42
43        // Load config
44        let config_path = match config_filename {
45            Some(filename) => get_local_or_remote_file(filename, &repo)
46                .map_err(|e| CandleError::Msg(format!("Failed to get config file: {e}")))?,
47            None => get_local_or_remote_file("config.json", &repo)
48                .map_err(|e| CandleError::Msg(format!("Failed to get config.json: {e}")))?,
49        };
50
51        let config_str = std::fs::read_to_string(config_path)
52            .map_err(|e| CandleError::Msg(format!("Failed to read config file: {e}")))?;
53        let config: Config = serde_json::from_str(&config_str)
54            .map_err(|e| CandleError::Msg(format!("Failed to parse config: {e}")))?;
55
56        // Get model file
57        let model_path = match model_filename {
58            Some(filename) => get_local_or_remote_file(filename, &repo)
59                .map_err(|e| CandleError::Msg(format!("Failed to get model file: {e}")))?,
60            None => {
61                // Try common CoreML model filenames
62                for filename in &["model.mlmodelc", "model.mlpackage"] {
63                    if let Ok(path) = get_local_or_remote_file(filename, &repo) {
64                        return Ok(Self::new(path, config));
65                    }
66                }
67                return Err(CandleError::Msg("No CoreML model file found".to_string()));
68            }
69        };
70
71        Ok(Self::new(model_path, config))
72    }
73
74    /// Build the CoreML model
75    pub fn build_model(&self) -> Result<CoreMLModel, CandleError> {
76        CoreMLModel::load_from_file(&self.model_filename, &self.config)
77    }
78
79    /// Get the config
80    pub fn config(&self) -> &Config {
81        &self.config
82    }
83}