1use crate::config::basic::Config;
4use crate::CoreMLModel;
5use candle_core::Error as CandleError;
6use std::path::{Path, PathBuf};
7
8pub struct CoreMLModelBuilder {
13 config: Config,
14 model_filename: PathBuf,
15}
16
17impl CoreMLModelBuilder {
18 pub fn new<P: AsRef<Path>>(model_path: P, config: Config) -> Self {
20 Self {
21 config,
22 model_filename: model_path.as_ref().to_path_buf(),
23 }
24 }
25
26 pub fn load_from_hub(
28 model_id: &str,
29 model_filename: Option<&str>,
30 config_filename: Option<&str>,
31 ) -> Result<Self, CandleError> {
32 use crate::get_local_or_remote_file;
33 use hf_hub::{api::sync::Api, Repo, RepoType};
34
35 let api =
36 Api::new().map_err(|e| CandleError::Msg(format!("Failed to create HF API: {e}")))?;
37 let repo = api.repo(Repo::with_revision(
38 model_id.to_string(),
39 RepoType::Model,
40 "main".to_string(),
41 ));
42
43 let config_path = match config_filename {
45 Some(filename) => get_local_or_remote_file(filename, &repo)
46 .map_err(|e| CandleError::Msg(format!("Failed to get config file: {e}")))?,
47 None => get_local_or_remote_file("config.json", &repo)
48 .map_err(|e| CandleError::Msg(format!("Failed to get config.json: {e}")))?,
49 };
50
51 let config_str = std::fs::read_to_string(config_path)
52 .map_err(|e| CandleError::Msg(format!("Failed to read config file: {e}")))?;
53 let config: Config = serde_json::from_str(&config_str)
54 .map_err(|e| CandleError::Msg(format!("Failed to parse config: {e}")))?;
55
56 let model_path = match model_filename {
58 Some(filename) => get_local_or_remote_file(filename, &repo)
59 .map_err(|e| CandleError::Msg(format!("Failed to get model file: {e}")))?,
60 None => {
61 for filename in &["model.mlmodelc", "model.mlpackage"] {
63 if let Ok(path) = get_local_or_remote_file(filename, &repo) {
64 return Ok(Self::new(path, config));
65 }
66 }
67 return Err(CandleError::Msg("No CoreML model file found".to_string()));
68 }
69 };
70
71 Ok(Self::new(model_path, config))
72 }
73
74 pub fn build_model(&self) -> Result<CoreMLModel, CandleError> {
76 CoreMLModel::load_from_file(&self.model_filename, &self.config)
77 }
78
79 pub fn config(&self) -> &Config {
81 &self.config
82 }
83}