//! Automatic pipeline loader selection (`mistralrs_core/pipeline/auto.rs`).
1use super::hf::{hf_access_error, remote_issue_from_api_error, RemoteAccessIssue};
2use super::{
3    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoaderBuilder, EmbeddingLoaderType,
4    EmbeddingSpecificConfig, Loader, ModelKind, ModelPaths, NormalLoaderBuilder, NormalLoaderType,
5    NormalSpecificConfig, SpeechLoader, TokenSource, VisionLoaderBuilder, VisionLoaderType,
6    VisionSpecificConfig,
7};
8use crate::utils::{progress::ProgressScopeGuard, tokens::get_token};
9use crate::Ordering;
10use crate::{DeviceMapSetting, IsqType, PagedAttentionConfig, Pipeline, TryIntoDType};
11use anyhow::Result;
12use candle_core::Device;
13use hf_hub::{
14    api::sync::{ApiBuilder, ApiError, ApiRepo},
15    Cache, Repo, RepoType,
16};
17use serde::Deserialize;
18use std::io;
19use std::path::Path;
20use std::path::PathBuf;
21use std::sync::Arc;
22use std::sync::Mutex;
23use tracing::{debug, info, warn};
24
/// Automatically selects the appropriate loader based on repository/config metadata.
pub struct AutoLoader {
    // Model identifier: a Hugging Face repo id or a local directory path
    // (treated as local whenever the path exists on disk).
    model_id: String,
    // Pre-configured text-model builder; taken (consumed) if detection picks text.
    normal_builder: Mutex<Option<NormalLoaderBuilder>>,
    // Pre-configured vision-model builder; taken if detection picks vision.
    vision_builder: Mutex<Option<VisionLoaderBuilder>>,
    // Pre-configured embedding-model builder; taken if detection picks embedding.
    embedding_builder: Mutex<Option<EmbeddingLoaderBuilder>>,
    // Concrete loader chosen after detection; populated lazily by `ensure_loader`.
    loader: Mutex<Option<Box<dyn Loader>>>,
    // Optional override for the Hugging Face hub cache location.
    hf_cache_path: Option<PathBuf>,
}
34
/// Builder collecting all modality-specific configs and shared options needed
/// to construct an [`AutoLoader`].
pub struct AutoLoaderBuilder {
    // Config forwarded to the text (normal) loader builder.
    normal_cfg: NormalSpecificConfig,
    // Config forwarded to the vision loader builder.
    vision_cfg: VisionSpecificConfig,
    // Config forwarded to the embedding loader builder.
    embedding_cfg: EmbeddingSpecificConfig,
    // Optional chat template override (shared by text and vision builders).
    chat_template: Option<String>,
    // Optional tokenizer.json override (shared by all three builders).
    tokenizer_json: Option<String>,
    model_id: String,
    // Optional explicit Jinja chat-template path/content.
    jinja_explicit: Option<String>,
    no_kv_cache: bool,
    // X-LoRA settings; only applied to the text builder (set via `with_xlora`).
    xlora_model_id: Option<String>,
    xlora_order: Option<Ordering>,
    tgt_non_granular_index: Option<usize>,
    // LoRA adapters; applied to text, vision, and embedding builders alike.
    lora_adapter_ids: Option<Vec<String>>,
    // Optional HF cache override propagated to every sub-builder.
    hf_cache_path: Option<PathBuf>,
}
50
51impl AutoLoaderBuilder {
52    #[allow(clippy::too_many_arguments)]
53    pub fn new(
54        normal_cfg: NormalSpecificConfig,
55        vision_cfg: VisionSpecificConfig,
56        embedding_cfg: EmbeddingSpecificConfig,
57        chat_template: Option<String>,
58        tokenizer_json: Option<String>,
59        model_id: String,
60        no_kv_cache: bool,
61        jinja_explicit: Option<String>,
62    ) -> Self {
63        Self {
64            normal_cfg,
65            vision_cfg,
66            embedding_cfg,
67            chat_template,
68            tokenizer_json,
69            model_id,
70            jinja_explicit,
71            no_kv_cache,
72            xlora_model_id: None,
73            xlora_order: None,
74            tgt_non_granular_index: None,
75            lora_adapter_ids: None,
76            hf_cache_path: None,
77        }
78    }
79
80    pub fn with_xlora(
81        mut self,
82        model_id: String,
83        order: Ordering,
84        no_kv_cache: bool,
85        tgt_non_granular_index: Option<usize>,
86    ) -> Self {
87        self.xlora_model_id = Some(model_id);
88        self.xlora_order = Some(order);
89        self.no_kv_cache = no_kv_cache;
90        self.tgt_non_granular_index = tgt_non_granular_index;
91        self
92    }
93
94    pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
95        self.lora_adapter_ids = Some(adapters);
96        self
97    }
98
99    pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
100        self.hf_cache_path = Some(path);
101        self
102    }
103
104    pub fn build(self) -> Box<dyn Loader> {
105        let Self {
106            normal_cfg,
107            vision_cfg,
108            embedding_cfg,
109            chat_template,
110            tokenizer_json,
111            model_id,
112            jinja_explicit,
113            no_kv_cache,
114            xlora_model_id,
115            xlora_order,
116            tgt_non_granular_index,
117            lora_adapter_ids,
118            hf_cache_path,
119        } = self;
120
121        let mut normal_builder = NormalLoaderBuilder::new(
122            normal_cfg,
123            chat_template.clone(),
124            tokenizer_json.clone(),
125            Some(model_id.clone()),
126            no_kv_cache,
127            jinja_explicit.clone(),
128        );
129        if let (Some(id), Some(ord)) = (xlora_model_id.clone(), xlora_order.clone()) {
130            normal_builder =
131                normal_builder.with_xlora(id, ord, no_kv_cache, tgt_non_granular_index);
132        }
133        if let Some(ref adapters) = lora_adapter_ids {
134            normal_builder = normal_builder.with_lora(adapters.clone());
135        }
136        if let Some(ref path) = hf_cache_path {
137            normal_builder = normal_builder.hf_cache_path(path.clone());
138        }
139
140        let mut vision_builder = VisionLoaderBuilder::new(
141            vision_cfg,
142            chat_template,
143            tokenizer_json.clone(),
144            Some(model_id.clone()),
145            jinja_explicit,
146        );
147        if let Some(ref adapters) = lora_adapter_ids {
148            vision_builder = vision_builder.with_lora(adapters.clone());
149        }
150        if let Some(ref path) = hf_cache_path {
151            vision_builder = vision_builder.hf_cache_path(path.clone());
152        }
153
154        let mut embedding_builder =
155            EmbeddingLoaderBuilder::new(embedding_cfg, tokenizer_json, Some(model_id.clone()));
156        if let Some(ref adapters) = lora_adapter_ids {
157            embedding_builder = embedding_builder.with_lora(adapters.clone());
158        }
159        if let Some(ref path) = hf_cache_path {
160            embedding_builder = embedding_builder.hf_cache_path(path.clone());
161        }
162
163        Box::new(AutoLoader {
164            model_id,
165            normal_builder: Mutex::new(Some(normal_builder)),
166            vision_builder: Mutex::new(Some(vision_builder)),
167            embedding_builder: Mutex::new(Some(embedding_builder)),
168            loader: Mutex::new(None),
169            hf_cache_path,
170        })
171    }
172}
173
/// Minimal deserialized view of `config.json`: only the `architectures` list
/// is needed to route between text/vision/embedding loaders.
#[derive(Deserialize)]
struct AutoConfig {
    // Empty when the `architectures` key is absent (serde default).
    #[serde(default)]
    architectures: Vec<String>,
}
179
/// Everything gathered about a repository (local or remote) that the
/// model-type detection step needs.
struct ConfigArtifacts {
    // Raw `config.json` text, if the file could be read or fetched.
    contents: Option<String>,
    // Whether `config_sentence_transformers.json` exists alongside the config.
    sentence_transformers_present: bool,
    // Relative file listing of the repo (local walk or hub API listing).
    repo_files: Vec<String>,
    // Error recorded when the remote `config.json` fetch failed; surfaced
    // later only if detection cannot proceed without the config.
    remote_access_issue: Option<RemoteAccessIssue>,
}
186
/// Outcome of model-type detection: which loader family to use, plus the
/// architecture-specific variant where one could be determined.
enum Detected {
    Normal(NormalLoaderType),
    Vision(VisionLoaderType),
    // `None` means "embedding model, let the embedding loader auto-detect".
    Embedding(Option<EmbeddingLoaderType>),
    Diffusion(DiffusionLoaderType),
    Speech(crate::speech_models::SpeechLoaderType),
}
194
195impl AutoLoader {
196    fn try_get_file(
197        api: &ApiRepo,
198        model_id: &Path,
199        file: &str,
200    ) -> std::result::Result<Option<PathBuf>, ApiError> {
201        if model_id.exists() {
202            let path = model_id.join(file);
203            if path.exists() {
204                info!("Loading `{}` locally at `{}`", file, path.display());
205                Ok(Some(path))
206            } else {
207                Ok(None)
208            }
209        } else {
210            api.get(file).map(Some)
211        }
212    }
213
214    fn list_local_repo_files(model_root: &Path) -> Vec<String> {
215        fn collect_files(root: &Path, dir: &Path, out: &mut Vec<String>) -> io::Result<()> {
216            for entry in std::fs::read_dir(dir)? {
217                let entry = entry?;
218                let path = entry.path();
219                if path.is_dir() {
220                    collect_files(root, &path, out)?;
221                } else if let Ok(rel) = path.strip_prefix(root) {
222                    out.push(rel.to_string_lossy().replace('\\', "/"));
223                }
224            }
225            Ok(())
226        }
227
228        if !model_root.is_dir() {
229            return Vec::new();
230        }
231
232        let mut files = Vec::new();
233        if collect_files(model_root, model_root, &mut files).is_err() {
234            return Vec::new();
235        }
236        files
237    }
238
239    fn read_config_from_path(&self, paths: &dyn ModelPaths) -> Result<ConfigArtifacts> {
240        let config_path = paths.get_config_filename();
241        let contents = match std::fs::read_to_string(config_path) {
242            Ok(contents) => Some(contents),
243            Err(err) if err.kind() == io::ErrorKind::NotFound => None,
244            Err(err) => return Err(err.into()),
245        };
246        let model_root = Path::new(&self.model_id);
247        let repo_files = if model_root.exists() {
248            Self::list_local_repo_files(model_root)
249        } else {
250            Vec::new()
251        };
252        let sentence_transformers_present = Self::has_sentence_transformers_sibling(config_path)
253            || repo_files
254                .iter()
255                .any(|f| f == "config_sentence_transformers.json");
256        Ok(ConfigArtifacts {
257            contents,
258            sentence_transformers_present,
259            repo_files,
260            remote_access_issue: None,
261        })
262    }
263
264    fn read_config_from_hf(
265        &self,
266        revision: Option<String>,
267        token_source: &TokenSource,
268        silent: bool,
269    ) -> Result<ConfigArtifacts> {
270        let cache = self
271            .hf_cache_path
272            .clone()
273            .map(Cache::new)
274            .unwrap_or_default();
275        let mut api = ApiBuilder::from_cache(cache)
276            .with_progress(!silent)
277            .with_token(get_token(token_source)?);
278        if let Some(cache_dir) = crate::hf_hub_cache_dir() {
279            api = api.with_cache_dir(cache_dir);
280        }
281        let api = api.build()?;
282        let revision = revision.unwrap_or_else(|| "main".to_string());
283        let api = api.repo(Repo::with_revision(
284            self.model_id.clone(),
285            RepoType::Model,
286            revision,
287        ));
288        let model_id = Path::new(&self.model_id);
289        let mut remote_access_issue = None;
290        let contents = match Self::try_get_file(&api, model_id, "config.json") {
291            Ok(Some(path)) => Some(std::fs::read_to_string(&path)?),
292            Ok(None) => None,
293            Err(err) => {
294                let issue = remote_issue_from_api_error(model_id, Some("config.json"), &err);
295                warn!(
296                    "Auto loader could not fetch `config.json` for `{}`: {}",
297                    self.model_id, issue.message
298                );
299                remote_access_issue = Some(issue);
300                None
301            }
302        };
303        let sentence_transformers_present =
304            model_id.join("config_sentence_transformers.json").exists()
305                || Self::fetch_sentence_transformers_config(&api, model_id);
306        let repo_files = if model_id.exists() {
307            Self::list_local_repo_files(model_id)
308        } else {
309            crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>()
310        };
311        Ok(ConfigArtifacts {
312            contents,
313            sentence_transformers_present,
314            repo_files,
315            remote_access_issue,
316        })
317    }
318
319    fn has_sentence_transformers_sibling(config_path: &Path) -> bool {
320        config_path
321            .parent()
322            .map(|parent| parent.join("config_sentence_transformers.json").exists())
323            .unwrap_or(false)
324    }
325
326    fn fetch_sentence_transformers_config(api: &ApiRepo, model_id: &Path) -> bool {
327        if model_id.exists() {
328            return false;
329        }
330        match api.get("config_sentence_transformers.json") {
331            Ok(_) => true,
332            Err(err) => {
333                debug!(
334                    "No `config_sentence_transformers.json` found for `{}`: {err}",
335                    model_id.display()
336                );
337                false
338            }
339        }
340    }
341
342    fn detect(&self, artifacts: &ConfigArtifacts) -> Result<Detected> {
343        if let Some(tp) = DiffusionLoaderType::auto_detect_from_files(&artifacts.repo_files) {
344            return Ok(Detected::Diffusion(tp));
345        }
346
347        if let Some(ref config) = artifacts.contents {
348            if let Some(tp) =
349                crate::speech_models::SpeechLoaderType::auto_detect_from_config(config)
350            {
351                return Ok(Detected::Speech(tp));
352            }
353        }
354
355        if artifacts.sentence_transformers_present {
356            if let Some(ref config) = artifacts.contents {
357                let cfg: AutoConfig = serde_json::from_str(config)?;
358                if let Some(name) = cfg.architectures.first() {
359                    if let Ok(tp) = EmbeddingLoaderType::from_causal_lm_name(name) {
360                        info!(
361                            "Detected `config_sentence_transformers.json`; using embedding loader `{tp}`."
362                        );
363                        return Ok(Detected::Embedding(Some(tp)));
364                    }
365                }
366            }
367            if artifacts.contents.is_none() {
368                if let Some(issue) = artifacts.remote_access_issue.as_ref() {
369                    return Err(hf_access_error(Path::new(&self.model_id), issue));
370                }
371            }
372            info!(
373                "Detected `config_sentence_transformers.json`; routing via auto embedding loader."
374            );
375            return Ok(Detected::Embedding(None));
376        }
377
378        let config = artifacts.contents.as_ref().ok_or_else(|| {
379            if let Some(issue) = artifacts.remote_access_issue.as_ref() {
380                hf_access_error(Path::new(&self.model_id), issue)
381            } else {
382                anyhow::anyhow!(
383                    "Auto loader could not determine model type: missing `config.json` and no diffusion/speech markers found."
384                )
385            }
386        })?;
387        let cfg: AutoConfig = serde_json::from_str(config)?;
388        if cfg.architectures.len() != 1 {
389            anyhow::bail!("Expected exactly one architecture in config");
390        }
391        let name = &cfg.architectures[0];
392        if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
393            return Ok(Detected::Vision(tp));
394        }
395        let tp = NormalLoaderType::from_causal_lm_name(name)?;
396        Ok(Detected::Normal(tp))
397    }
398
399    fn ensure_loader(&self, artifacts: &ConfigArtifacts) -> Result<()> {
400        let mut guard = self.loader.lock().unwrap();
401        if guard.is_some() {
402            return Ok(());
403        }
404        match self.detect(artifacts)? {
405            Detected::Normal(tp) => {
406                let builder = self
407                    .normal_builder
408                    .lock()
409                    .unwrap()
410                    .take()
411                    .expect("builder taken");
412                let loader = builder.build(Some(tp)).expect("build normal");
413                *guard = Some(loader);
414            }
415            Detected::Vision(tp) => {
416                let builder = self
417                    .vision_builder
418                    .lock()
419                    .unwrap()
420                    .take()
421                    .expect("builder taken");
422                let loader = builder.build(Some(tp));
423                *guard = Some(loader);
424            }
425            Detected::Embedding(tp) => {
426                let builder = self
427                    .embedding_builder
428                    .lock()
429                    .unwrap()
430                    .take()
431                    .expect("builder taken");
432                let loader = builder.build(tp);
433                *guard = Some(loader);
434            }
435            Detected::Diffusion(tp) => {
436                let loader = DiffusionLoaderBuilder::new(Some(self.model_id.clone())).build(tp);
437                *guard = Some(loader);
438            }
439            Detected::Speech(tp) => {
440                let loader: Box<dyn Loader> = Box::new(SpeechLoader {
441                    model_id: self.model_id.clone(),
442                    dac_model_id: None,
443                    arch: tp,
444                    cfg: None,
445                });
446                *guard = Some(loader);
447            }
448        }
449        Ok(())
450    }
451}
452
453impl Loader for AutoLoader {
454    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
455    fn load_model_from_hf(
456        &self,
457        revision: Option<String>,
458        token_source: TokenSource,
459        dtype: &dyn TryIntoDType,
460        device: &Device,
461        silent: bool,
462        mapper: DeviceMapSetting,
463        in_situ_quant: Option<IsqType>,
464        paged_attn_config: Option<PagedAttentionConfig>,
465    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
466        let _progress_guard = ProgressScopeGuard::new(silent);
467        let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
468        self.ensure_loader(&config)?;
469        self.loader
470            .lock()
471            .unwrap()
472            .as_ref()
473            .unwrap()
474            .load_model_from_hf(
475                revision,
476                token_source,
477                dtype,
478                device,
479                silent,
480                mapper,
481                in_situ_quant,
482                paged_attn_config,
483            )
484    }
485
486    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
487    fn load_model_from_path(
488        &self,
489        paths: &Box<dyn ModelPaths>,
490        dtype: &dyn TryIntoDType,
491        device: &Device,
492        silent: bool,
493        mapper: DeviceMapSetting,
494        in_situ_quant: Option<IsqType>,
495        paged_attn_config: Option<PagedAttentionConfig>,
496    ) -> Result<Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>> {
497        let _progress_guard = ProgressScopeGuard::new(silent);
498        let config = self.read_config_from_path(paths.as_ref())?;
499        self.ensure_loader(&config)?;
500        self.loader
501            .lock()
502            .unwrap()
503            .as_ref()
504            .unwrap()
505            .load_model_from_path(
506                paths,
507                dtype,
508                device,
509                silent,
510                mapper,
511                in_situ_quant,
512                paged_attn_config,
513            )
514    }
515
516    fn get_id(&self) -> String {
517        self.model_id.clone()
518    }
519
520    fn get_kind(&self) -> ModelKind {
521        self.loader
522            .lock()
523            .unwrap()
524            .as_ref()
525            .map(|l| l.get_kind())
526            .unwrap_or(ModelKind::Normal)
527    }
528}