1use super::hf::{hf_access_error, remote_issue_from_api_error, RemoteAccessIssue};
2use super::{
3 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoaderBuilder, EmbeddingLoaderType,
4 EmbeddingSpecificConfig, Loader, ModelKind, ModelPaths, MultimodalLoaderBuilder,
5 MultimodalLoaderType, MultimodalSpecificConfig, NormalLoaderBuilder, NormalLoaderType,
6 NormalSpecificConfig, SpeechLoader, TokenSource,
7};
8use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
9use crate::Ordering;
10use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
11use anyhow::Result;
12use hanzo_ml::Device;
13use hf_hub::{
14 api::sync::{ApiBuilder, ApiError, ApiRepo},
15 Cache, Repo, RepoType,
16};
17use serde::Deserialize;
18use std::io;
19use std::path::Path;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use tracing::{debug, info, warn};
24
/// Loader that postpones the choice of a concrete loader until the model's
/// artifacts (its `config.json` and repository file list) have been inspected.
pub struct AutoLoader {
    /// Model identifier; it is also probed as a local directory path when
    /// repository files are listed.
    model_id: String,
    /// Pre-configured builder, taken out (left as `None`) by `ensure_loader`
    /// when a normal model is detected.
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    /// Pre-configured builder, taken out by `ensure_loader` when a multimodal
    /// model is detected.
    multimodal_builder: Mutex<Option<MultimodalLoaderBuilder>>,
    /// Pre-configured builder, taken out by `ensure_loader` when an embedding
    /// model is detected.
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    /// Concrete loader selected by `ensure_loader`; `None` until detection ran.
    loader: Mutex<Option<Box<dyn Loader>>>,
    /// Optional cache location used to construct the hub cache in
    /// `read_config_from_hf`.
    hf_cache_path: Option<PathBuf>,
}
34
/// Collects every option needed to pre-configure the normal, multimodal and
/// embedding loader builders that back an `AutoLoader`.
pub struct AutoLoaderBuilder {
    /// Configuration handed to `NormalLoaderBuilder::new`.
    normal_cfg: NormalSpecificConfig,
    /// Configuration handed to `MultimodalLoaderBuilder::new`.
    multimodal_cfg: MultimodalSpecificConfig,
    /// Configuration handed to `EmbeddingLoaderBuilder::new`.
    embedding_cfg: EmbeddingSpecificConfig,
    /// Forwarded to the normal and multimodal builders.
    chat_template: Option<String>,
    /// Forwarded to all three builders.
    tokenizer_json: Option<String>,
    /// Model identifier, forwarded to all three builders and to the `AutoLoader`.
    model_id: String,
    /// Forwarded to the normal and multimodal builders.
    jinja_explicit: Option<String>,
    /// Forwarded to the normal builder; overwritten by `with_xlora`.
    no_kv_cache: bool,
    /// X-LoRA model id; only applied (to the normal builder) when
    /// `xlora_order` is set as well.
    xlora_model_id: Option<String>,
    /// X-LoRA ordering; only applied when `xlora_model_id` is set as well.
    xlora_order: Option<Ordering>,
    /// Passed through to the normal builder's `with_xlora`.
    tgt_non_granular_index: Option<usize>,
    /// LoRA adapter ids, forwarded to all three builders when present.
    lora_adapter_ids: Option<Vec<String>>,
    /// Cache path, forwarded to all three builders and to the `AutoLoader`.
    hf_cache_path: Option<PathBuf>,
}
50
51impl AutoLoaderBuilder {
52 #[allow(clippy::too_many_arguments)]
53 pub fn new(
54 normal_cfg: NormalSpecificConfig,
55 multimodal_cfg: MultimodalSpecificConfig,
56 embedding_cfg: EmbeddingSpecificConfig,
57 chat_template: Option<String>,
58 tokenizer_json: Option<String>,
59 model_id: String,
60 no_kv_cache: bool,
61 jinja_explicit: Option<String>,
62 ) -> Self {
63 Self {
64 normal_cfg,
65 multimodal_cfg,
66 embedding_cfg,
67 chat_template,
68 tokenizer_json,
69 model_id,
70 jinja_explicit,
71 no_kv_cache,
72 xlora_model_id: None,
73 xlora_order: None,
74 tgt_non_granular_index: None,
75 lora_adapter_ids: None,
76 hf_cache_path: None,
77 }
78 }
79
80 pub fn with_xlora(
81 mut self,
82 model_id: String,
83 order: Ordering,
84 no_kv_cache: bool,
85 tgt_non_granular_index: Option<usize>,
86 ) -> Self {
87 self.xlora_model_id = Some(model_id);
88 self.xlora_order = Some(order);
89 self.no_kv_cache = no_kv_cache;
90 self.tgt_non_granular_index = tgt_non_granular_index;
91 self
92 }
93
94 pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
95 self.lora_adapter_ids = Some(adapters);
96 self
97 }
98
99 pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
100 self.hf_cache_path = Some(path);
101 self
102 }
103
104 pub fn build(self) -> Box<dyn Loader> {
105 let Self {
106 normal_cfg,
107 multimodal_cfg,
108 embedding_cfg,
109 chat_template,
110 tokenizer_json,
111 model_id,
112 jinja_explicit,
113 no_kv_cache,
114 xlora_model_id,
115 xlora_order,
116 tgt_non_granular_index,
117 lora_adapter_ids,
118 hf_cache_path,
119 } = self;
120
121 let mut normal_builder = NormalLoaderBuilder::new(
122 normal_cfg,
123 chat_template.clone(),
124 tokenizer_json.clone(),
125 Some(model_id.clone()),
126 no_kv_cache,
127 jinja_explicit.clone(),
128 );
129 if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
130 normal_builder =
131 normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
132 }
133 if let Some(ref adapters) = lora_adapter_ids {
134 normal_builder = normal_builder.with_lora(adapters.clone());
135 }
136 if let Some(ref path) = hf_cache_path {
137 normal_builder = normal_builder.hf_cache_path(path.clone());
138 }
139
140 let mut multimodal_builder = MultimodalLoaderBuilder::new(
141 multimodal_cfg,
142 chat_template,
143 tokenizer_json.clone(),
144 Some(model_id.clone()),
145 jinja_explicit,
146 );
147 if let Some(ref adapters) = lora_adapter_ids {
148 multimodal_builder = multimodal_builder.with_lora(adapters.clone());
149 }
150 if let Some(ref path) = hf_cache_path {
151 multimodal_builder = multimodal_builder.hf_cache_path(path.clone());
152 }
153
154 let mut embedding_builder =
155 EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
156 if let Some(ref adapters) = lora_adapter_ids {
157 embedding_builder = embedding_builder.with_lora(adapters.clone());
158 }
159 if let Some(ref path) = hf_cache_path {
160 embedding_builder = embedding_builder.hf_cache_path(path.clone());
161 }
162
163 Box::new(AutoLoader {
164 model_id,
165 normal_builder: Mutex::new(Some(normal_builder)),
166 multimodal_builder: Mutex::new(Some(multimodal_builder)),
167 embedding_builder: Mutex::new(Some(embedding_builder)),
168 loader: Mutex::new(None),
169 hf_cache_path,
170 })
171 }
172}
173
/// Minimal view of a model's `config.json`: only the `architectures` list is
/// deserialized.
#[derive(Deserialize)]
struct AutoConfig {
    /// Architecture names; defaults to an empty list when the key is absent.
    #[serde(default)]
    architectures: Vec<String>,
}
179
/// Everything gathered about a model that `AutoLoader::detect` needs in order
/// to pick a loader kind.
struct ConfigArtifacts {
    /// Raw text of `config.json`, or `None` when it could not be obtained.
    contents: Option<String>,
    /// Whether a `config_sentence_transformers.json` was found.
    sentence_transformers_present: bool,
    /// File names found in the model directory or remote repository.
    repo_files: Vec<String>,
    /// Problem recorded when fetching `config.json` from the hub failed;
    /// always `None` for local paths.
    remote_access_issue: Option<RemoteAccessIssue>,
}
186
/// Outcome of `AutoLoader::detect`: the loader family to use together with
/// its concrete loader type.
enum Detected {
    /// Text model handled by the normal loader.
    Normal(NormalLoaderType),
    /// Model handled by the multimodal loader.
    Multimodal(MultimodalLoaderType),
    /// Embedding model; `None` when the concrete type could not be derived
    /// from the config and is left to the embedding builder.
    Embedding(Option<EmbeddingLoaderType>),
    /// Diffusion model, recognised from the repository file list.
    Diffusion(DiffusionLoaderType),
    /// Speech model, recognised from the config contents.
    Speech(crate::speech_models::SpeechLoaderType),
}
194
195impl AutoLoader {
/// Thin wrapper that forwards to `crate::pipeline::hf::try_get_file` with the
/// same arguments and returns its result unchanged.
///
/// NOTE(review): callers in this file treat `Ok(None)` as "file not available"
/// and `Err` as a remote access problem — confirm against the helper's contract.
fn try_get_file(
    api: &ApiRepo,
    model_id: &Path,
    file: &str,
    revision: &str,
) -> std::result::Result<Option<PathBuf>, ApiError> {
    crate::pipeline::hf::try_get_file(api, model_id, file, revision)
}
204
/// Lists every file below `model_root`, recursively, as `/`-separated paths
/// relative to `model_root`, sorted lexicographically.
///
/// Returns an empty list when `model_root` is not a directory. The listing is
/// best-effort: a directory that cannot be read is skipped instead of
/// discarding everything collected so far, so a single unreadable
/// subdirectory no longer hides marker files found elsewhere in the tree.
fn list_local_repo_files(model_root: &Path) -> Vec<String> {
    /// Appends the files below `dir` to `out`, relative to `root`.
    fn collect_files(root: &Path, dir: &Path, out: &mut Vec<String>) {
        let entries = match std::fs::read_dir(dir) {
            Ok(entries) => entries,
            // Unreadable directory: keep what other directories provided.
            Err(_) => return,
        };
        for entry in entries {
            let path = match entry {
                Ok(entry) => entry.path(),
                // Stop reading this directory on an iteration error rather
                // than retrying a stream that may keep failing.
                Err(_) => break,
            };
            if path.is_dir() {
                collect_files(root, &path, out);
            } else if let Ok(rel) = path.strip_prefix(root) {
                // Normalise Windows separators so names compare equal to
                // repository-style paths.
                out.push(rel.to_string_lossy().replace('\\', "/"));
            }
        }
    }

    let mut files = Vec::new();
    if model_root.is_dir() {
        collect_files(model_root, model_root, &mut files);
    }
    // `read_dir` order is platform dependent; sort for a deterministic result.
    files.sort_unstable();
    files
}
229
230 fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
231 let config_path = paths.get_config_filename();
232 let contents = match std::fs::read_to_string(config_path) {
233 Ok(contents) => Some(contents),
234 Err(err) if err.kind() == io::ErrorKind::NotFound => None,
235 Err(err) => return Err(err.into()),
236 };
237 let model_root = Path::new(&self.model_id);
238 let repo_files = if model_root.exists() {
239 Self::list_local_repo_files(model_root)
240 } else {
241 Vec::new()
242 };
243 let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path)
244 || repo_files
245 .iter()
246 .any(|f| f == "config_sentence_transformers.json");
247 Ok(ConfigArtifacts {
248 contents,
249 sentence_transformers_present,
250 repo_files,
251 remote_access_issue: None,
252 })
253 }
254
255 fn read_config_from_hf(
256 &self,
257 revision: Option<String>,
258 token_source: &TokenSource,
259 silent: bool,
260 ) -> Result<ConfigArtifacts> {
261 let cache = self
262 .hf_cache_path
263 .clone()
264 .map(Cache::new)
265 .unwrap_or_default();
266 let mut api = ApiBuilder::from_cache(cache)
267 .with_progress(!silent)
268 .with_token(get_token(token_source)?);
269 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
270 api = api.with_cache_dir(cache_dir);
271 }
272 let api = api.build()?;
273 let revision = revision.unwrap_or_else(|| "main".to_string());
274 let api = api.repo(Repo::with_revision(
275 self.model_id.clone(),
276 RepoType::Model,
277 revision.clone(),
278 ));
279 let model_id = Path::new(&self.model_id);
280 let mut remote_access_issue = None;
281 let contents = match Self::try_get_file(&api, model_id, "config.json", &revision) {
282 Ok(Some(path)) => Some(std::fs::read_to_string(&path)?),
283 Ok(None) => None,
284 Err(err) => {
285 let issue = remote_issue_from_api_error(model_id, Some("config.json"), &err);
286 warn!(
287 "Auto loader could not fetch `config.json` for `{}`: {}",
288 self.model_id, issue.message
289 );
290 remote_access_issue = Some(issue);
291 None
292 }
293 };
294 let sentence_transformers_present =
295 model_id.join("config_sentence_transformers.json").exists()
296 || Self::fetch_sentence_transformers_config(&api, model_id, &revision);
297 let repo_files = if model_id.exists() {
298 Self::list_local_repo_files(model_id)
299 } else {
300 crate::api_dir_list!(api, model_id, false, &revision).collect::<Vec<_>>()
301 };
302 Ok(ConfigArtifacts {
303 contents,
304 sentence_transformers_present,
305 repo_files,
306 remote_access_issue,
307 })
308 }
309
/// Reports whether a `config_sentence_transformers.json` exists in the same
/// directory as `config_path`.
///
/// Returns `false` when `config_path` has no parent directory.
fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
    match config_path.parent() {
        Some(dir) => dir.join("config_sentence_transformers.json").exists(),
        None => false,
    }
}
316
317 fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path, revision: &str) -> bool {
318 match crate::pipeline::hf::try_get_file(
319 api,
320 model_id,
321 "config_sentence_transformers.json",
322 revision,
323 ) {
324 Ok(Some(_)) => true,
325 Ok(None) => false,
326 Err(err) => {
327 debug!(
328 "No `config_sentence_transformers.json` found for `{}`: {err}",
329 model_id.display()
330 );
331 false
332 }
333 }
334 }
335
336 fn detect(&self, artifacts: &ConfigArtifacts) -> Result<Detected> {
337 if let Some(tp) = DiffusionLoaderType::auto_detect_from_files(&artifacts.repo_files) {
338 return Ok(Detected::Diffusion(tp));
339 }
340
341 if let Some(ref config) = artifacts.contents {
342 if let Some(tp) =
343 crate::speech_models::SpeechLoaderType::auto_detect_from_config(config)
344 {
345 return Ok(Detected::Speech(tp));
346 }
347 }
348
349 if artifacts.sentence_transformers_present {
350 if let Some(ref config) = artifacts.contents {
351 let cfg: AutoConfig = serde_json::from_str(config)?;
352 if let Some(name) = cfg.architectures.first() {
353 if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
354 info!(
355 "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
356 );
357 return Ok(Detected::Embedding(Some(tp)));
358 }
359 }
360 }
361 if artifacts.contents.is_none() {
362 if let Some(issue) = artifacts.remote_access_issue.as_ref() {
363 return Err(hf_access_error(Path::new(&self.model_id), issue));
364 }
365 }
366 info!(
367 "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
368 );
369 return Ok(Detected::Embedding(None));
370 }
371
372 if artifacts.contents.is_none() && artifacts.repo_files.iter().any(|f| f == "params.json") {
374 info!("Detected `params.json` in repo; routing as Voxtral.");
376 return Ok(Detected::Multimodal(MultimodalLoaderType::Voxtral));
377 }
378
379 let config = artifacts.contents.as_ref().ok_or_else(|| {
380 if let Some(issue) = artifacts.remote_access_issue.as_ref() {
381 hf_access_error(Path::new(&self.model_id), issue)
382 } else {
383 anyhow::anyhow!(
384 "Auto loader could not determine model type: missing `config.json` and no diffusion/speech markers found."
385 )
386 }
387 })?;
388 let cfg: AutoConfig = serde_json::from_str(config)?;
389 if cfg.architectures.len() != 1 {
390 anyhow::bail!("Expected exactly one architecture in config");
391 }
392 let name = &cfg.architectures[0];
393 if let Ok(tp) = MultimodalLoaderType::from_causal_lm_name(name) {
394 return Ok(Detected::Multimodal(tp));
395 }
396 let tp = NormalLoaderType::from_causal_lm_name(name)?;
397 Ok(Detected::Normal(tp))
398 }
399
400 fn ensure_loader(&self, artifacts: &ConfigArtifacts) -> Result<()> {
401 let mut guard = self.loader.lock().unwrap();
402 if guard.is_some() {
403 return Ok(());
404 }
405 match self.detect(artifacts)? {
406 Detected::Normal(tp) => {
407 let builder = self
408 .normal_builder
409 .lock()
410 .unwrap()
411 .take()
412 .expect("builder taken");
413 let loader = builder.build(Some(tp)).expect("build normal");
414 *guard = Some(loader);
415 }
416 Detected::Multimodal(tp) => {
417 let builder = self
418 .multimodal_builder
419 .lock()
420 .unwrap()
421 .take()
422 .expect("builder taken");
423 let loader = builder.build(Some(tp));
424 *guard = Some(loader);
425 }
426 Detected::Embedding(tp) => {
427 let builder = self
428 .embedding_builder
429 .lock()
430 .unwrap()
431 .take()
432 .expect("builder taken");
433 let loader = builder.build(tp);
434 *guard = Some(loader);
435 }
436 Detected::Diffusion(tp) => {
437 let loader = DiffusionLoaderBuilder::new(Some(self.model_id.clone())).build(tp);
438 *guard = Some(loader);
439 }
440 Detected::Speech(tp) => {
441 let loader: Box<dyn Loader> = Box::new(SpeechLoader {
442 model_id: self.model_id.clone(),
443 dac_model_id: None,
444 arch: tp,
445 cfg: None,
446 });
447 *guard = Some(loader);
448 }
449 }
450 Ok(())
451 }
452}
453
454impl Loader for AutoLoader {
455 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
456 fn load_model_from_hf(
457 &self,
458 revision: Option<String>,
459 token_source: TokenSource,
460 dtype: &dyn TryIntoDType,
461 device: &Device,
462 silent: bool,
463 mapper: DeviceMapSetting,
464 in_situ_quant: Option<IsqType>,
465 paged_attn_config: Option<PagedAttentionConfig>,
466 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
467 let _progress_guard = ProgressScopeGuard::new(silent);
468 let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
469 self.ensure_loader(&config)?;
470 self.loader
471 .lock()
472 .unwrap()
473 .as_ref()
474 .unwrap()
475 .load_model_from_hf(
476 revision,
477 token_source,
478 dtype,
479 device,
480 silent,
481 mapper,
482 in_situ_quant,
483 paged_attn_config,
484 )
485 }
486
487 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
488 fn load_model_from_path(
489 &self,
490 paths: &Box<dyn ModelPaths>,
491 dtype: &dyn TryIntoDType,
492 device: &Device,
493 silent: bool,
494 mapper: DeviceMapSetting,
495 in_situ_quant: Option<IsqType>,
496 paged_attn_config: Option<PagedAttentionConfig>,
497 ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
498 let _progress_guard = ProgressScopeGuard::new(silent);
499 let config = self.read_config_from_path(paths.as_ref())?;
500 self.ensure_loader(&config)?;
501 self.loader
502 .lock()
503 .unwrap()
504 .as_ref()
505 .unwrap()
506 .load_model_from_path(
507 paths,
508 dtype,
509 device,
510 silent,
511 mapper,
512 in_situ_quant,
513 paged_attn_config,
514 )
515 }
516
517 fn get_id(&self) -> String {
518 self.model_id.clone()
519 }
520
521 fn get_kind(&self) -> ModelKind {
522 self.loader
523 .lock()
524 .unwrap()
525 .as_ref()
526 .map(|l| l.get_kind())
527 .unwrap_or(ModelKind::Normal)
528 }
529}