1use super::hf::{hf_access_error, remote_issue_from_api_error, RemoteAccessIssue};
2use super::{
3 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoaderBuilder, EmbeddingLoaderType,
4 EmbeddingSpecificConfig, Loader, ModelKind, ModelPaths, NormalLoaderBuilder, NormalLoaderType,
5 NormalSpecificConfig, SpeechLoader, TokenSource, VisionLoaderBuilder, VisionLoaderType,
6 VisionSpecificConfig,
7};
8use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
9use crate::Ordering;
10use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
11use anyhow::Result;
12use candle_core::Device;
13use hf_hub::{
14 api::sync::{ApiBuilder, ApiError, ApiRepo},
15 Cache, Repo, RepoType,
16};
17use serde::Deserialize;
18use std::io;
19use std::path::Path;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use tracing::{debug, info, warn};
24
/// Loader that defers choosing between text (normal), vision, embedding,
/// diffusion, and speech pipelines until the model's `config.json` and/or
/// repo file list have been inspected (see `AutoLoader::detect`).
pub struct AutoLoader {
    // Hugging Face repo id, or a local directory path (checked with
    // `Path::exists` before any remote access is attempted).
    model_id: String,
    // Candidate builders, one per modality. Exactly one of them is `take`n
    // in `ensure_loader` once the modality is known; the others stay unused.
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    // Concrete loader, created lazily on the first load call and reused after.
    loader: Mutex<Option<Box<dyn Loader>>>,
    // Optional override for the Hugging Face hub cache location.
    hf_cache_path: Option<PathBuf>,
}
34
/// Builder collecting every setting needed to construct an `AutoLoader`.
///
/// The shared settings are fanned out to the text, vision, and embedding
/// builders in `build`; X-LoRA options are forwarded to the text (normal)
/// path only.
pub struct AutoLoaderBuilder {
    normal_cfg: NormalSpecificConfig,
    vision_cfg: VisionSpecificConfig,
    embedding_cfg: EmbeddingSpecificConfig,
    // Optional chat-template override (forwarded to text and vision builders).
    chat_template: Option<String>,
    // Optional tokenizer.json override (forwarded to all three builders).
    tokenizer_json: Option<String>,
    // HF repo id or local directory path.
    model_id: String,
    // Explicit Jinja template override, if any.
    jinja_explicit: Option<String>,
    no_kv_cache: bool,
    // X-LoRA settings, populated by `with_xlora`; only the normal path uses them.
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    tgt_non_granular_index: Option<usize>,
    // LoRA adapter ids, populated by `with_lora`; forwarded to all builders.
    lora_adapter_ids: Option<Vec<String>>,
    // Optional HF hub cache directory override, populated by `hf_cache_path`.
    hf_cache_path: Option<PathBuf>,
}
50
51impl AutoLoaderBuilder {
52 #[allow(clippy::too_many_arguments)]
53 pub fn new(
54 normal_cfg: NormalSpecificConfig,
55 vision_cfg: VisionSpecificConfig,
56 embedding_cfg: EmbeddingSpecificConfig,
57 chat_template: Option<String>,
58 tokenizer_json: Option<String>,
59 model_id: String,
60 no_kv_cache: bool,
61 jinja_explicit: Option<String>,
62 ) -> Self {
63 Self {
64 normal_cfg,
65 vision_cfg,
66 embedding_cfg,
67 chat_template,
68 tokenizer_json,
69 model_id,
70 jinja_explicit,
71 no_kv_cache,
72 xlora_model_id: None,
73 xlora_order: None,
74 tgt_non_granular_index: None,
75 lora_adapter_ids: None,
76 hf_cache_path: None,
77 }
78 }
79
80 pub fn with_xlora(
81 mut self,
82 model_id: String,
83 order: Ordering,
84 no_kv_cache: bool,
85 tgt_non_granular_index: Option<usize>,
86 ) -> Self {
87 self.xlora_model_id = Some(model_id);
88 self.xlora_order = Some(order);
89 self.no_kv_cache = no_kv_cache;
90 self.tgt_non_granular_index = tgt_non_granular_index;
91 self
92 }
93
94 pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
95 self.lora_adapter_ids = Some(adapters);
96 self
97 }
98
99 pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
100 self.hf_cache_path = Some(path);
101 self
102 }
103
104 pub fn build(self) -> Box<dyn Loader> {
105 let Self {
106 normal_cfg,
107 vision_cfg,
108 embedding_cfg,
109 chat_template,
110 tokenizer_json,
111 model_id,
112 jinja_explicit,
113 no_kv_cache,
114 xlora_model_id,
115 xlora_order,
116 tgt_non_granular_index,
117 lora_adapter_ids,
118 hf_cache_path,
119 } = self;
120
121 let mut normal_builder = NormalLoaderBuilder::new(
122 normal_cfg,
123 chat_template.clone(),
124 tokenizer_json.clone(),
125 Some(model_id.clone()),
126 no_kv_cache,
127 jinja_explicit.clone(),
128 );
129 if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
130 normal_builder =
131 normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
132 }
133 if let Some(ref adapters) = lora_adapter_ids {
134 normal_builder = normal_builder.with_lora(adapters.clone());
135 }
136 if let Some(ref path) = hf_cache_path {
137 normal_builder = normal_builder.hf_cache_path(path.clone());
138 }
139
140 let mut vision_builder = VisionLoaderBuilder::new(
141 vision_cfg,
142 chat_template,
143 tokenizer_json.clone(),
144 Some(model_id.clone()),
145 jinja_explicit,
146 );
147 if let Some(ref adapters) = lora_adapter_ids {
148 vision_builder = vision_builder.with_lora(adapters.clone());
149 }
150 if let Some(ref path) = hf_cache_path {
151 vision_builder = vision_builder.hf_cache_path(path.clone());
152 }
153
154 let mut embedding_builder =
155 EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
156 if let Some(ref adapters) = lora_adapter_ids {
157 embedding_builder = embedding_builder.with_lora(adapters.clone());
158 }
159 if let Some(ref path) = hf_cache_path {
160 embedding_builder = embedding_builder.hf_cache_path(path.clone());
161 }
162
163 Box::new(AutoLoader {
164 model_id,
165 normal_builder: Mutex::new(Some(normal_builder)),
166 vision_builder: Mutex::new(Some(vision_builder)),
167 embedding_builder: Mutex::new(Some(embedding_builder)),
168 loader: Mutex::new(None),
169 hf_cache_path,
170 })
171 }
172}
173
/// Minimal deserialized view of `config.json`: only the `architectures`
/// list is needed to route to a concrete loader type.
#[derive(Deserialize)]
struct AutoConfig {
    // A missing `architectures` key deserializes as an empty list instead of
    // failing, so the absence can be reported with a clearer error later.
    #[serde(default)]
    architectures: Vec<String>,
}
179
/// Everything that could be gathered about a repo (local or remote) before
/// deciding which loader to use.
struct ConfigArtifacts {
    // Raw `config.json` text, when the file was found/fetched.
    contents: Option<String>,
    // Whether `config_sentence_transformers.json` exists next to the config
    // (locally or on the Hub) — the marker for embedding models.
    sentence_transformers_present: bool,
    // Repo-relative file names ('/'-separated); used for diffusion detection.
    repo_files: Vec<String>,
    // Set when `config.json` could not be fetched from the Hub (e.g. auth or
    // network failure), so later detection errors can surface the root cause.
    remote_access_issue: Option<RemoteAccessIssue>,
}
186
/// Outcome of model-type detection, carrying the concrete loader type
/// to instantiate for the detected modality.
enum Detected {
    Normal(NormalLoaderType),
    Vision(VisionLoaderType),
    // `None` means the concrete embedding architecture could not be resolved
    // from the config; the embedding loader is left to auto-detect it.
    Embedding(Option<EmbeddingLoaderType>),
    Diffusion(DiffusionLoaderType),
    Speech(crate::speech_models::SpeechLoaderType),
}
194
impl AutoLoader {
    /// Resolves `file` either from a local model directory or from the Hub.
    ///
    /// Note the asymmetry: for a local `model_id` a missing file yields
    /// `Ok(None)`, whereas a failed remote fetch yields `Err` (the caller
    /// decides whether to downgrade that to a recorded access issue).
    fn try_get_file(
        api: &ApiRepo,
        model_id: &Path,
        file: &str,
    ) -> std::result::Result<Option<PathBuf>, ApiError> {
        if model_id.exists() {
            let path = model_id.join(file);
            if path.exists() {
                info!("Loading `{}` locally at `{}`", file, path.display());
                Ok(Some(path))
            } else {
                Ok(None)
            }
        } else {
            api.get(file).map(Some)
        }
    }

    /// Recursively lists files under `model_root`, returning paths relative
    /// to the root with '/' separators (backslashes normalized for Windows).
    ///
    /// Best-effort: any I/O failure (or a non-directory root) yields an
    /// empty list rather than an error.
    fn list_local_repo_files(model_root: &Path) -> Vec<String> {
        // Depth-first walk; `root` stays fixed so entries can be relativized.
        fn collect_files(root: &Path, dir: &Path, out: &mut Vec<String>) -> io::Result<()> {
            for entry in std::fs::read_dir(dir)? {
                let entry = entry?;
                let path = entry.path();
                if path.is_dir() {
                    collect_files(root, &path, out)?;
                } else if let Ok(rel) = path.strip_prefix(root) {
                    out.push(rel.to_string_lossy().replace('\\', "/"));
                }
            }
            Ok(())
        }

        if !model_root.is_dir() {
            return Vec::new();
        }

        let mut files = Vec::new();
        if collect_files(model_root, model_root, &mut files).is_err() {
            return Vec::new();
        }
        files
    }

    /// Builds `ConfigArtifacts` from already-resolved local paths
    /// (the `load_model_from_path` route). Never touches the network,
    /// so `remote_access_issue` is always `None` here.
    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
        let config_path = paths.get_config_filename();
        // Only a missing config is tolerated (treated as "no config");
        // any other read error is propagated.
        let contents = match std::fs::read_to_string(config_path) {
            Ok(contents) => Some(contents),
            Err(err) if err.kind() == io::ErrorKind::NotFound => None,
            Err(err) => return Err(err.into()),
        };
        let model_root = Path::new(&self.model_id);
        let repo_files = if model_root.exists() {
            Self::list_local_repo_files(model_root)
        } else {
            Vec::new()
        };
        // The sentence-transformers marker may sit next to the config file
        // or anywhere in the listed repo files.
        let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path)
            || repo_files
                .iter()
                .any(|f| f == "config_sentence_transformers.json");
        Ok(ConfigArtifacts {
            contents,
            sentence_transformers_present,
            repo_files,
            remote_access_issue: None,
        })
    }

    /// Builds `ConfigArtifacts` via the Hugging Face hub API (falling back to
    /// local files whenever `self.model_id` is an existing directory).
    ///
    /// A failed `config.json` fetch is downgraded to a warning plus a recorded
    /// `RemoteAccessIssue` instead of an immediate error, so detection can
    /// still succeed via file-list markers (e.g. diffusion models).
    fn read_config_from_hf(
        &self,
        revision: Option<String>,
        token_source: &TokenSource,
        silent: bool,
    ) -> Result<ConfigArtifacts> {
        // Honor the explicit cache path if one was configured.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        let mut api = ApiBuilder::from_cache(cache)
            .with_progress(!silent)
            .with_token(get_token(token_source)?);
        // Global cache-dir override takes precedence over the default cache.
        if let Some(cache_dir) = crate::hf_hub_cache_dir() {
            api = api.with_cache_dir(cache_dir);
        }
        let api = api.build()?;
        let revision = revision.unwrap_or_else(|| "main".to_string());
        let api = api.repo(Repo::with_revision(
            self.model_id.clone(),
            RepoType::Model,
            revision,
        ));
        let model_id = Path::new(&self.model_id);
        let mut remote_access_issue = None;
        let contents = match Self::try_get_file(&api, model_id, "config.json") {
            Ok(Some(path)) => Some(std::fs::read_to_string(&path)?),
            Ok(None) => None,
            Err(err) => {
                // Remember why the fetch failed; `detect` surfaces this if it
                // cannot classify the model without the config.
                let issue = remote_issue_from_api_error(model_id, Some("config.json"), &err);
                warn!(
                    "Auto loader could not fetch `config.json` for `{}`: {}",
                    self.model_id, issue.message
                );
                remote_access_issue = Some(issue);
                None
            }
        };
        let sentence_transformers_present =
            model_id.join("config_sentence_transformers.json").exists()
                || Self::fetch_sentence_transformers_config(&api, model_id);
        let repo_files = if model_id.exists() {
            Self::list_local_repo_files(model_id)
        } else {
            // Presumably yields the remote repo's file names; the `false`
            // flag's meaning is defined by the macro — TODO confirm.
            crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>()
        };
        Ok(ConfigArtifacts {
            contents,
            sentence_transformers_present,
            repo_files,
            remote_access_issue,
        })
    }

    /// True when `config_sentence_transformers.json` sits in the same
    /// directory as the given config file.
    fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
        config_path
            .parent()
            .map(|parent| parent.join("config_sentence_transformers.json").exists())
            .unwrap_or(false)
    }

    /// Existence probe for `config_sentence_transformers.json` on the Hub
    /// (via `api.get`, which presumably also caches the file on success).
    /// Always `false` for local model directories — the caller checks those
    /// directly. Fetch failures are logged at debug level and treated as
    /// "not present".
    fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path) -> bool {
        if model_id.exists() {
            return false;
        }
        match api.get("config_sentence_transformers.json") {
            Ok(_) => true,
            Err(err) => {
                debug!(
                    "No `config_sentence_transformers.json` found for `{}`: {err}",
                    model_id.display()
                );
                false
            }
        }
    }

    /// Decides which loader family the repo belongs to.
    ///
    /// Detection order matters:
    /// 1. diffusion markers in the repo file list,
    /// 2. speech markers in the config contents,
    /// 3. the sentence-transformers marker file -> embedding loader
    ///    (with the concrete type when resolvable from `architectures`),
    /// 4. otherwise route by the single `architectures` entry in the config,
    ///    trying vision first, then plain text generation.
    fn detect(&self, artifacts: &ConfigArtifacts) -> Result<Detected> {
        if let Some(tp) = DiffusionLoaderType::auto_detect_from_files(&artifacts.repo_files) {
            return Ok(Detected::Diffusion(tp));
        }

        if let Some(ref config) = artifacts.contents {
            if let Some(tp) =
                crate::speech_models::SpeechLoaderType::auto_detect_from_config(config)
            {
                return Ok(Detected::Speech(tp));
            }
        }

        if artifacts.sentence_transformers_present {
            if let Some(ref config) = artifacts.contents {
                let cfg: AutoConfig = serde_json::from_str(config)?;
                if let Some(name) = cfg.architectures.first() {
                    if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
                        info!(
                            "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
                        );
                        return Ok(Detected::Embedding(Some(tp)));
                    }
                }
            }
            // No config at all: if that was caused by a remote access failure,
            // report that instead of silently auto-detecting.
            if artifacts.contents.is_none() {
                if let Some(issue) = artifacts.remote_access_issue.as_ref() {
                    return Err(hf_access_error(Path::new(&self.model_id), issue));
                }
            }
            info!(
                "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
            );
            return Ok(Detected::Embedding(None));
        }

        // From here on a config is mandatory; prefer surfacing a recorded
        // remote-access failure over a generic "could not determine" error.
        let config = artifacts.contents.as_ref().ok_or_else(|| {
            if let Some(issue) = artifacts.remote_access_issue.as_ref() {
                hf_access_error(Path::new(&self.model_id), issue)
            } else {
                anyhow::anyhow!(
                    "Auto loader could not determine model type: missing `config.json` and no diffusion/speech markers found."
                )
            }
        })?;
        let cfg: AutoConfig = serde_json::from_str(config)?;
        if cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }
        let name = &cfg.architectures[0];
        if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
            return Ok(Detected::Vision(tp));
        }
        let tp = NormalLoaderType::from_causal_lm_name(name)?;
        Ok(Detected::Normal(tp))
    }

    /// Runs detection (once) and materializes the matching concrete loader
    /// into `self.loader`. Subsequent calls are no-ops.
    ///
    /// The per-modality builders are single-use: the chosen one is `take`n
    /// here. The early `is_some()` return while holding the `loader` lock
    /// guarantees each builder is taken at most once, so the `expect`s on
    /// `take` cannot fire.
    fn ensure_loader(&self, artifacts: &ConfigArtifacts) -> Result<()> {
        let mut guard = self.loader.lock().unwrap();
        if guard.is_some() {
            return Ok(());
        }
        match self.detect(artifacts)? {
            Detected::Normal(tp) => {
                let builder = self
                    .normal_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp)).expect("build normal");
                *guard = Some(loader);
            }
            Detected::Vision(tp) => {
                let builder = self
                    .vision_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(Some(tp));
                *guard = Some(loader);
            }
            Detected::Embedding(tp) => {
                let builder = self
                    .embedding_builder
                    .lock()
                    .unwrap()
                    .take()
                    .expect("builder taken");
                let loader = builder.build(tp);
                *guard = Some(loader);
            }
            Detected::Diffusion(tp) => {
                // Diffusion/speech loaders need no pre-built builder state.
                let loader = DiffusionLoaderBuilder::new(Some(self.model_id.clone())).build(tp);
                *guard = Some(loader);
            }
            Detected::Speech(tp) => {
                let loader: Box<dyn Loader> = Box::new(SpeechLoader {
                    model_id: self.model_id.clone(),
                    dac_model_id: None,
                    arch: tp,
                    cfg: None,
                });
                *guard = Some(loader);
            }
        }
        Ok(())
    }
}
452
453impl Loader for AutoLoader {
454 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
455 fn load_model_from_hf(
456 &self,
457 revision: Option<String>,
458 token_source: TokenSource,
459 dtype: &dyn TryIntoDType,
460 device: &Device,
461 silent: bool,
462 mapper: DeviceMapSetting,
463 in_situ_quant: Option<IsqType>,
464 paged_attn_config: Option<PagedAttentionConfig>,
465 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
466 let _progress_guard = ProgressScopeGuard::new(silent);
467 let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
468 self.ensure_loader(&config)?;
469 self.loader
470 .lock()
471 .unwrap()
472 .as_ref()
473 .unwrap()
474 .load_model_from_hf(
475 revision,
476 token_source,
477 dtype,
478 device,
479 silent,
480 mapper,
481 in_situ_quant,
482 paged_attn_config,
483 )
484 }
485
486 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
487 fn load_model_from_path(
488 &self,
489 paths: &Box<dyn ModelPaths>,
490 dtype: &dyn TryIntoDType,
491 device: &Device,
492 silent: bool,
493 mapper: DeviceMapSetting,
494 in_situ_quant: Option<IsqType>,
495 paged_attn_config: Option<PagedAttentionConfig>,
496 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
497 let _progress_guard = ProgressScopeGuard::new(silent);
498 let config = self.read_config_from_path(paths.as_ref())?;
499 self.ensure_loader(&config)?;
500 self.loader
501 .lock()
502 .unwrap()
503 .as_ref()
504 .unwrap()
505 .load_model_from_path(
506 paths,
507 dtype,
508 device,
509 silent,
510 mapper,
511 in_situ_quant,
512 paged_attn_config,
513 )
514 }
515
516 fn get_id(&self) -> String {
517 self.model_id.clone()
518 }
519
520 fn get_kind(&self) -> ModelKind {
521 self.loader
522 .lock()
523 .unwrap()
524 .as_ref()
525 .map(|l| l.get_kind())
526 .unwrap_or(ModelKind::Normal)
527 }
528}