Skip to main content

hanzo_engine/pipeline/
auto.rs

1use super::hf::{hf_access_error, remote_issue_from_api_error, RemoteAccessIssue};
2use super::{
3    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoaderBuilder, EmbeddingLoaderType,
4    EmbeddingSpecificConfig, Loader, ModelKind, ModelPaths, MultimodalLoaderBuilder,
5    MultimodalLoaderType, MultimodalSpecificConfig, NormalLoaderBuilder, NormalLoaderType,
6    NormalSpecificConfig, SpeechLoader, TokenSource,
7};
8use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
9use crate::Ordering;
10use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
11use anyhow::Result;
12use hanzo_ml::Device;
13use hf_hub::{
14    api::sync::{ApiBuilder, ApiError, ApiRepo},
15    Cache, Repo, RepoType,
16};
17use serde::Deserialize;
18use std::io;
19use std::path::Path;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use tracing::{debug, info, warn};
24
25/// Automatically selects the appropriate loader based on repository/config metadata.
26pub struct AutoLoader {
27    model_id: String,
28    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
29    multimodal_builder: Mutex<Option<MultimodalLoaderBuilder>>,
30    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
31    loader: Mutex<Option<Box<dyn Loader>>>,
32    hf_cache_path: Option<PathBuf>,
33}
34
35pub struct AutoLoaderBuilder {
36    normal_cfg: NormalSpecificConfig,
37    multimodal_cfg: MultimodalSpecificConfig,
38    embedding_cfg: EmbeddingSpecificConfig,
39    chat_template: Option<String>,
40    tokenizer_json: Option<String>,
41    model_id: String,
42    jinja_explicit: Option<String>,
43    no_kv_cache: bool,
44    xlora_model_id: Option<String>,
45    xlora_order: Option<Ordering>,
46    tgt_non_granular_index: Option<usize>,
47    lora_adapter_ids: Option<Vec<String>>,
48    hf_cache_path: Option<PathBuf>,
49}
50
51impl AutoLoaderBuilder {
52    #[allow(clippy::too_many_arguments)]
53    pub fn new(
54        normal_cfg: NormalSpecificConfig,
55        multimodal_cfg: MultimodalSpecificConfig,
56        embedding_cfg: EmbeddingSpecificConfig,
57        chat_template: Option<String>,
58        tokenizer_json: Option<String>,
59        model_id: String,
60        no_kv_cache: bool,
61        jinja_explicit: Option<String>,
62    ) -> Self {
63        Self {
64            normal_cfg,
65            multimodal_cfg,
66            embedding_cfg,
67            chat_template,
68            tokenizer_json,
69            model_id,
70            jinja_explicit,
71            no_kv_cache,
72            xlora_model_id: None,
73            xlora_order: None,
74            tgt_non_granular_index: None,
75            lora_adapter_ids: None,
76            hf_cache_path: None,
77        }
78    }
79
80    pub fn with_xlora(
81        mut self,
82        model_id: String,
83        order: Ordering,
84        no_kv_cache: bool,
85        tgt_non_granular_index: Option<usize>,
86    ) -> Self {
87        self.xlora_model_id = Some(model_id);
88        self.xlora_order = Some(order);
89        self.no_kv_cache = no_kv_cache;
90        self.tgt_non_granular_index = tgt_non_granular_index;
91        self
92    }
93
94    pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
95        self.lora_adapter_ids = Some(adapters);
96        self
97    }
98
99    pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
100        self.hf_cache_path = Some(path);
101        self
102    }
103
104    pub fn build(self) -> Box<dyn Loader> {
105        let Self {
106            normal_cfg,
107            multimodal_cfg,
108            embedding_cfg,
109            chat_template,
110            tokenizer_json,
111            model_id,
112            jinja_explicit,
113            no_kv_cache,
114            xlora_model_id,
115            xlora_order,
116            tgt_non_granular_index,
117            lora_adapter_ids,
118            hf_cache_path,
119        } = self;
120
121        let mut normal_builder = NormalLoaderBuilder::new(
122            normal_cfg,
123            chat_template.clone(),
124            tokenizer_json.clone(),
125            Some(model_id.clone()),
126            no_kv_cache,
127            jinja_explicit.clone(),
128        );
129        if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
130            normal_builder =
131                normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
132        }
133        if let Some(ref adapters) = lora_adapter_ids {
134            normal_builder = normal_builder.with_lora(adapters.clone());
135        }
136        if let Some(ref path) = hf_cache_path {
137            normal_builder = normal_builder.hf_cache_path(path.clone());
138        }
139
140        let mut multimodal_builder = MultimodalLoaderBuilder::new(
141            multimodal_cfg,
142            chat_template,
143            tokenizer_json.clone(),
144            Some(model_id.clone()),
145            jinja_explicit,
146        );
147        if let Some(ref adapters) = lora_adapter_ids {
148            multimodal_builder = multimodal_builder.with_lora(adapters.clone());
149        }
150        if let Some(ref path) = hf_cache_path {
151            multimodal_builder = multimodal_builder.hf_cache_path(path.clone());
152        }
153
154        let mut embedding_builder =
155            EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
156        if let Some(ref adapters) = lora_adapter_ids {
157            embedding_builder = embedding_builder.with_lora(adapters.clone());
158        }
159        if let Some(ref path) = hf_cache_path {
160            embedding_builder = embedding_builder.hf_cache_path(path.clone());
161        }
162
163        Box::new(AutoLoader {
164            model_id,
165            normal_builder: Mutex::new(Some(normal_builder)),
166            multimodal_builder: Mutex::new(Some(multimodal_builder)),
167            embedding_builder: Mutex::new(Some(embedding_builder)),
168            loader: Mutex::new(None),
169            hf_cache_path,
170        })
171    }
172}
173
174#[derive(Deserialize)]
175struct AutoConfig {
176    #[serde(default)]
177    architectures: Vec<String>,
178}
179
180struct ConfigArtifacts {
181    contents: Option<String>,
182    sentence_transformers_present: bool,
183    repo_files: Vec<String>,
184    remote_access_issue: Option<RemoteAccessIssue>,
185}
186
187enum Detected {
188    Normal(NormalLoaderType),
189    Multimodal(MultimodalLoaderType),
190    Embedding(Option<EmbeddingLoaderType>),
191    Diffusion(DiffusionLoaderType),
192    Speech(crate::speech_models::SpeechLoaderType),
193}
194
195impl AutoLoader {
196    fn try_get_file(
197        api: &ApiRepo,
198        model_id: &Path,
199        file: &str,
200        revision: &str,
201    ) -> std::result::Result<Option<PathBuf>, ApiError> {
202        crate::pipeline::hf::try_get_file(api, model_id, file, revision)
203    }
204
205    fn list_local_repo_files(model_root: &Path) -> Vec<String> {
206        fn collect_files(root: &Path, dir: &Path, out: &mut Vec<String>) -> io::Result<()> {
207            for entry in std::fs::read_dir(dir)? {
208                let entry = entry?;
209                let path = entry.path();
210                if path.is_dir() {
211                    collect_files(root, &path, out)?;
212                } else if let Ok(rel) = path.strip_prefix(root) {
213                    out.push(rel.to_string_lossy().replace('\\', "/"));
214                }
215            }
216            Ok(())
217        }
218
219        if !model_root.is_dir() {
220            return Vec::new();
221        }
222
223        let mut files = Vec::new();
224        if collect_files(model_root, model_root, &mut files).is_err() {
225            return Vec::new();
226        }
227        files
228    }
229
230    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
231        let config_path = paths.get_config_filename();
232        let contents = match std::fs::read_to_string(config_path) {
233            Ok(contents) => Some(contents),
234            Err(err) if err.kind() == io::ErrorKind::NotFound => None,
235            Err(err) => return Err(err.into()),
236        };
237        let model_root = Path::new(&self.model_id);
238        let repo_files = if model_root.exists() {
239            Self::list_local_repo_files(model_root)
240        } else {
241            Vec::new()
242        };
243        let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path)
244            || repo_files
245                .iter()
246                .any(|f| f == "config_sentence_transformers.json");
247        Ok(ConfigArtifacts {
248            contents,
249            sentence_transformers_present,
250            repo_files,
251            remote_access_issue: None,
252        })
253    }
254
255    fn read_config_from_hf(
256        &self,
257        revision: Option<String>,
258        token_source: &TokenSource,
259        silent: bool,
260    ) -> Result<ConfigArtifacts> {
261        let cache = self
262            .hf_cache_path
263            .clone()
264            .map(Cache::new)
265            .unwrap_or_default();
266        let mut api = ApiBuilder::from_cache(cache)
267            .with_progress(!silent)
268            .with_token(get_token(token_source)?);
269        if let Some(cache_dir) = crate::hf_hub_cache_dir() {
270            api = api.with_cache_dir(cache_dir);
271        }
272        let api = api.build()?;
273        let revision = revision.unwrap_or_else(|| "main".to_string());
274        let api = api.repo(Repo::with_revision(
275            self.model_id.clone(),
276            RepoType::Model,
277            revision.clone(),
278        ));
279        let model_id = Path::new(&self.model_id);
280        let mut remote_access_issue = None;
281        let contents = match Self::try_get_file(&api, model_id, "config.json", &revision) {
282            Ok(Some(path)) => Some(std::fs::read_to_string(&path)?),
283            Ok(None) => None,
284            Err(err) => {
285                let issue = remote_issue_from_api_error(model_id, Some("config.json"), &err);
286                warn!(
287                    "Auto loader could not fetch `config.json` for `{}`: {}",
288                    self.model_id, issue.message
289                );
290                remote_access_issue = Some(issue);
291                None
292            }
293        };
294        let sentence_transformers_present =
295            model_id.join("config_sentence_transformers.json").exists()
296                || Self::fetch_sentence_transformers_config(&api, model_id, &revision);
297        let repo_files = if model_id.exists() {
298            Self::list_local_repo_files(model_id)
299        } else {
300            crate::api_dir_list!(api, model_id, false, &revision).collect::<Vec<_>>()
301        };
302        Ok(ConfigArtifacts {
303            contents,
304            sentence_transformers_present,
305            repo_files,
306            remote_access_issue,
307        })
308    }
309
310    fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
311        config_path
312            .parent()
313            .map(|parent| parent.join("config_sentence_transformers.json").exists())
314            .unwrap_or(false)
315    }
316
317    fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path, revision: &str) -> bool {
318        match crate::pipeline::hf::try_get_file(
319            api,
320            model_id,
321            "config_sentence_transformers.json",
322            revision,
323        ) {
324            Ok(Some(_)) => true,
325            Ok(None) => false,
326            Err(err) => {
327                debug!(
328                    "No `config_sentence_transformers.json` found for `{}`: {err}",
329                    model_id.display()
330                );
331                false
332            }
333        }
334    }
335
336    fn detect(&self, artifacts: &ConfigArtifacts) -> Result<Detected> {
337        if let Some(tp) = DiffusionLoaderType::auto_detect_from_files(&artifacts.repo_files) {
338            return Ok(Detected::Diffusion(tp));
339        }
340
341        if let Some(ref config) = artifacts.contents {
342            if let Some(tp) =
343                crate::speech_models::SpeechLoaderType::auto_detect_from_config(config)
344            {
345                return Ok(Detected::Speech(tp));
346            }
347        }
348
349        if artifacts.sentence_transformers_present {
350            if let Some(ref config) = artifacts.contents {
351                let cfg: AutoConfig = serde_json::from_str(config)?;
352                if let Some(name) = cfg.architectures.first() {
353                    if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
354                        info!(
355                            "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
356                        );
357                        return Ok(Detected::Embedding(Some(tp)));
358                    }
359                }
360            }
361            if artifacts.contents.is_none() {
362                if let Some(issue) = artifacts.remote_access_issue.as_ref() {
363                    return Err(hf_access_error(Path::new(&self.model_id), issue));
364                }
365            }
366            info!(
367                "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
368            );
369            return Ok(Detected::Embedding(None));
370        }
371
372        // Detect Mistral-native models that use params.json instead of config.json
373        if artifacts.contents.is_none() && artifacts.repo_files.iter().any(|f| f == "params.json") {
374            // Voxtral uses params.json with a "multimodal" key containing "whisper_model_args"
375            info!("Detected `params.json` in repo; routing as Voxtral.");
376            return Ok(Detected::Multimodal(MultimodalLoaderType::Voxtral));
377        }
378
379        let config = artifacts.contents.as_ref().ok_or_else(|| {
380            if let Some(issue) = artifacts.remote_access_issue.as_ref() {
381                hf_access_error(Path::new(&self.model_id), issue)
382            } else {
383                anyhow::anyhow!(
384                    "Auto loader could not determine model type: missing `config.json` and no diffusion/speech markers found."
385                )
386            }
387        })?;
388        let cfg: AutoConfig = serde_json::from_str(config)?;
389        if cfg.architectures.len() != 1 {
390            anyhow::bail!("Expected exactly one architecture in config");
391        }
392        let name = &cfg.architectures[0];
393        if let Ok(tp) = MultimodalLoaderType::from_causal_lm_name(name) {
394            return Ok(Detected::Multimodal(tp));
395        }
396        let tp = NormalLoaderType::from_causal_lm_name(name)?;
397        Ok(Detected::Normal(tp))
398    }
399
400    fn ensure_loader(&self, artifacts: &ConfigArtifacts) -> Result<()> {
401        let mut guard = self.loader.lock().unwrap();
402        if guard.is_some() {
403            return Ok(());
404        }
405        match self.detect(artifacts)? {
406            Detected::Normal(tp) => {
407                let builder = self
408                    .normal_builder
409                    .lock()
410                    .unwrap()
411                    .take()
412                    .expect("builder taken");
413                let loader = builder.build(Some(tp)).expect("build normal");
414                *guard = Some(loader);
415            }
416            Detected::Multimodal(tp) => {
417                let builder = self
418                    .multimodal_builder
419                    .lock()
420                    .unwrap()
421                    .take()
422                    .expect("builder taken");
423                let loader = builder.build(Some(tp));
424                *guard = Some(loader);
425            }
426            Detected::Embedding(tp) => {
427                let builder = self
428                    .embedding_builder
429                    .lock()
430                    .unwrap()
431                    .take()
432                    .expect("builder taken");
433                let loader = builder.build(tp);
434                *guard = Some(loader);
435            }
436            Detected::Diffusion(tp) => {
437                let loader = DiffusionLoaderBuilder::new(Some(self.model_id.clone())).build(tp);
438                *guard = Some(loader);
439            }
440            Detected::Speech(tp) => {
441                let loader: Box<dyn Loader> = Box::new(SpeechLoader {
442                    model_id: self.model_id.clone(),
443                    dac_model_id: None,
444                    arch: tp,
445                    cfg: None,
446                });
447                *guard = Some(loader);
448            }
449        }
450        Ok(())
451    }
452}
453
454impl Loader for AutoLoader {
455    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
456    fn load_model_from_hf(
457        &self,
458        revision: Option<String>,
459        token_source: TokenSource,
460        dtype: &dyn TryIntoDType,
461        device: &Device,
462        silent: bool,
463        mapper: DeviceMapSetting,
464        in_situ_quant: Option<IsqType>,
465        paged_attn_config: Option<PagedAttentionConfig>,
466    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
467        let _progress_guard = ProgressScopeGuard::new(silent);
468        let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
469        self.ensure_loader(&config)?;
470        self.loader
471            .lock()
472            .unwrap()
473            .as_ref()
474            .unwrap()
475            .load_model_from_hf(
476                revision,
477                token_source,
478                dtype,
479                device,
480                silent,
481                mapper,
482                in_situ_quant,
483                paged_attn_config,
484            )
485    }
486
487    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
488    fn load_model_from_path(
489        &self,
490        paths: &Box<dyn ModelPaths>,
491        dtype: &dyn TryIntoDType,
492        device: &Device,
493        silent: bool,
494        mapper: DeviceMapSetting,
495        in_situ_quant: Option<IsqType>,
496        paged_attn_config: Option<PagedAttentionConfig>,
497    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
498        let _progress_guard = ProgressScopeGuard::new(silent);
499        let config = self.read_config_from_path(paths.as_ref())?;
500        self.ensure_loader(&config)?;
501        self.loader
502            .lock()
503            .unwrap()
504            .as_ref()
505            .unwrap()
506            .load_model_from_path(
507                paths,
508                dtype,
509                device,
510                silent,
511                mapper,
512                in_situ_quant,
513                paged_attn_config,
514            )
515    }
516
517    fn get_id(&self) -> String {
518        self.model_id.clone()
519    }
520
521    fn get_kind(&self) -> ModelKind {
522        self.loader
523            .lock()
524            .unwrap()
525            .as_ref()
526            .map(|l| l.get_kind())
527            .unwrap_or(ModelKind::Normal)
528    }
529}