Skip to main content

mistralrs_core/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Match files against these. None of the patterns carries `^`/`$` anchors, so a pattern
// matches when it occurs anywhere inside the tested file name.

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file weights named `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Single-file weights named `consolidated.safetensors`.
const CONSOLIDATED_SAFETENSOR_MATCH: &str = r"consolidated\.safetensors\b";
/// Sharded pickle weights with five-digit shard counters and a `.pth`, `.pt` or `.bin`
/// extension, e.g. `pytorch_model-00001-of-00002.bin`.
///
/// The dot before the extension is escaped so that it only matches a literal `.` rather than
/// an arbitrary character.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
34
35#[derive(Clone, Debug)]
36pub struct LoraAdapterPaths {
37    pub lora_config: mistralrs_quant::LoraConfig,
38    pub adapter_path: PathBuf,
39}
40
41#[allow(clippy::large_enum_variant)]
42#[derive(Clone, Debug)]
43pub enum AdapterPaths {
44    XLora {
45        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
46        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
47        classifier_path: Option<PathBuf>,
48        xlora_order: Option<Ordering>,
49        xlora_config: Option<XLoraConfig>,
50        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
51    },
52    Lora(Vec<LoraAdapterPaths>),
53    None,
54}
55
56pub fn get_xlora_paths(
57    base_model_id: String,
58    xlora_model_id: Option<&String>,
59    lora_adapter_ids: Option<&Vec<String>>,
60    token_source: &TokenSource,
61    revision: String,
62    xlora_order: Option<&Ordering>,
63) -> Result<AdapterPaths> {
64    match (lora_adapter_ids, xlora_model_id, xlora_order) {
65        (None, Some(xlora_id), Some(xlora_order)) => {
66            let api = {
67                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
68                let mut api = ApiBuilder::from_cache(cache)
69                    .with_progress(true)
70                    .with_token(get_token(token_source)?);
71                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
72                    api = api.with_cache_dir(cache_dir);
73                }
74                api.build().map_err(candle_core::Error::msg)?
75            };
76            let api = api.repo(Repo::with_revision(
77                xlora_id.clone(),
78                RepoType::Model,
79                revision,
80            ));
81            let model_id = Path::new(&xlora_id);
82            let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
83            // Get the path for the xlora classifier
84            let xlora_classifier = &dir_list
85                .clone()
86                .into_iter()
87                .filter(|x| x.contains("xlora_classifier.safetensors"))
88                .collect::<Vec<_>>();
89            if xlora_classifier.len() > 1 {
90                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
91                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
92            }
93            let xlora_classifier = xlora_classifier.first();
94
95            let classifier_path = xlora_classifier
96                .map(|xlora_classifier| -> candle_core::Result<_> {
97                    Ok(api_get_file!(api, xlora_classifier, model_id))
98                })
99                .transpose()?;
100
101            // Get the path for the xlora config by checking all for valid versions.
102            // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
103            let xlora_configs = &dir_list
104                .clone()
105                .into_iter()
106                .filter(|x| x.contains("xlora_config.json"))
107                .collect::<Vec<_>>();
108            if xlora_configs.len() > 1 {
109                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
110            }
111
112            let mut xlora_config: Option<XLoraConfig> = None;
113            let mut last_err: Option<serde_json::Error> = None;
114            for (i, config_path) in xlora_configs.iter().enumerate() {
115                if xlora_configs.len() != 1 {
116                    warn!("Selecting config: `{}`", config_path);
117                }
118                let config_path = api_get_file!(api, config_path, model_id);
119                let conf = fs::read_to_string(config_path)?;
120                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
121                match deser {
122                    Ok(conf) => {
123                        xlora_config = Some(conf);
124                        break;
125                    }
126                    Err(e) => {
127                        if i != xlora_configs.len() - 1 {
128                            warn!("Config is broken with error `{e}`");
129                        }
130                        last_err = Some(e);
131                    }
132                }
133            }
134            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
135                if let Some(last_err) = last_err {
136                    panic!("Unable to derserialize any configs. Last error: {last_err}")
137                } else {
138                    None
139                }
140            });
141
142            // If there are adapters in the ordering file, get their names and remote paths
143            let adapter_files = dir_list
144                .into_iter()
145                .filter_map(|name| {
146                    if let Some(ref adapters) = xlora_order.adapters {
147                        for adapter_name in adapters {
148                            if name.contains(adapter_name) {
149                                return Some((name, adapter_name.clone()));
150                            }
151                        }
152                    }
153                    None
154                })
155                .collect::<Vec<_>>();
156            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
157                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
158            }
159
160            // Get the local paths for each adapter
161            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
162            for (file, name) in adapter_files {
163                if let Some(paths) = adapters_paths.get_mut(&name) {
164                    paths.push(api_get_file!(api, &file, model_id));
165                } else {
166                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
167                }
168            }
169
170            // Sort local paths for the adapter configs and safetensors files
171            let mut adapters_configs = Vec::new();
172            let mut adapters_safetensors = Vec::new();
173            if let Some(ref adapters) = xlora_order.adapters {
174                for (i, name) in adapters.iter().enumerate() {
175                    let paths = adapters_paths
176                        .get(name)
177                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
178                    for path in paths {
179                        if path.extension().unwrap() == "safetensors" {
180                            adapters_safetensors.push((name.clone(), path.to_owned()));
181                        } else {
182                            let conf = fs::read_to_string(path)?;
183                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
184                            adapters_configs
185                                .push((((i + 1).to_string(), name.clone()), lora_config));
186                        }
187                    }
188                }
189            }
190
191            // Make sure they all match
192            if xlora_order.base_model_id
193                != *xlora_config
194                    .as_ref()
195                    .map(|cfg| &cfg.base_model_id)
196                    .unwrap_or(&base_model_id)
197                || xlora_config
198                    .as_ref()
199                    .map(|cfg| &cfg.base_model_id)
200                    .unwrap_or(&base_model_id)
201                    != &base_model_id
202            {
203                anyhow::bail!(
204                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
205                    xlora_order.base_model_id,
206                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
207                    base_model_id
208                );
209            }
210
211            let lora_preload_adapter_info =
212                // If preload adapters are specified, get their metadata like above
213                if let Some(preload_adapters) = &xlora_order.preload_adapters {
214                    let mut output = HashMap::new();
215                    for adapter in preload_adapters {
216                        // Get the names and remote paths of the files associated with this adapter
217                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
218                            .filter_map(|f| {
219                                if f.contains(&adapter.name) {
220                                    Some((f, adapter.name.clone()))
221                                } else {
222                                    None
223                                }
224                            })
225                            .collect::<Vec<_>>();
226                        if adapter_files.is_empty() {
227                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
228                        }
229                        // Get local paths for this adapter
230                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
231                        for (file, name) in adapter_files {
232                            if let Some(paths) = adapters_paths.get_mut(&name) {
233                                paths.push(api_get_file!(api, &file, model_id));
234                            } else {
235                                adapters_paths
236                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
237                            }
238                        }
239
240                        let mut config = None;
241                        let mut safetensor = None;
242
243                        // Sort local paths for the adapter configs and safetensors files
244                        let paths = adapters_paths
245                            .get(&adapter.name)
246                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
247                        for path in paths {
248                            if path.extension().unwrap() == "safetensors" {
249                                safetensor = Some(path.to_owned());
250                            } else {
251                                let conf = fs::read_to_string(path)?;
252                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
253                                config = Some(lora_config);
254                            }
255                        }
256
257                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
258                        output.insert(adapter.name.clone(), (safetensor, config));
259                    }
260                    Some(output)
261                } else {
262                    None
263                };
264
265            Ok(AdapterPaths::XLora {
266                adapter_configs: Some(adapters_configs),
267                adapter_safetensors: Some(adapters_safetensors),
268                classifier_path,
269                xlora_order: Some(xlora_order.clone()),
270                xlora_config,
271                lora_preload_adapter_info,
272            })
273        }
274        (Some(adapter_ids), None, None) => {
275            let mut lora_adapter_paths = Vec::new();
276            for adapter_id in adapter_ids {
277                info!("Loading adapter at `{adapter_id}`");
278
279                let api = {
280                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
281                    let mut api = ApiBuilder::from_cache(cache)
282                        .with_progress(true)
283                        .with_token(get_token(token_source)?);
284                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
285                        api = api.with_cache_dir(cache_dir);
286                    }
287                    api.build().map_err(candle_core::Error::msg)?
288                };
289                let api = api.repo(Repo::with_revision(
290                    adapter_id.clone(),
291                    RepoType::Model,
292                    revision.clone(),
293                ));
294
295                let config_path = api.get("adapter_config.json")?;
296                let adapter_path = api.get("adapter_model.safetensors")?;
297                let lora_config: mistralrs_quant::LoraConfig =
298                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
299
300                lora_adapter_paths.push(LoraAdapterPaths {
301                    lora_config,
302                    adapter_path,
303                });
304            }
305
306            Ok(AdapterPaths::Lora(lora_adapter_paths))
307        }
308        (None, None, None) => Ok(AdapterPaths::None),
309        _ => anyhow::bail!(
310            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
311        ),
312    }
313}
314
315pub fn get_model_paths(
316    revision: String,
317    token_source: &TokenSource,
318    quantized_model_id: Option<&String>,
319    quantized_filename: Option<&Vec<String>>,
320    api: &ApiRepo,
321    model_id: &Path,
322    loading_from_uqff: bool,
323) -> Result<Vec<PathBuf>> {
324    match quantized_filename {
325        Some(names) => {
326            let id = quantized_model_id.unwrap();
327            let mut files = Vec::new();
328
329            for name in names {
330                let qapi = {
331                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
332                    let mut api = ApiBuilder::from_cache(cache)
333                        .with_progress(true)
334                        .with_token(get_token(token_source)?);
335                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
336                        api = api.with_cache_dir(cache_dir);
337                    }
338                    api.build().map_err(candle_core::Error::msg)?
339                };
340                let qapi = qapi.repo(Repo::with_revision(
341                    id.to_string(),
342                    RepoType::Model,
343                    revision.clone(),
344                ));
345                let model_id = Path::new(&id);
346                files.push(api_get_file!(qapi, name, model_id));
347            }
348            Ok(files)
349        }
350        None => {
351            // We only match these patterns for model names
352            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
353            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
354            let consolidated_safetensor_match = Regex::new(CONSOLIDATED_SAFETENSOR_MATCH)?;
355            let pickle_match = Regex::new(PICKLE_MATCH)?;
356
357            let mut filenames = vec![];
358            let listing = api_dir_list!(api, model_id, true).filter(|x| {
359                safetensor_match.is_match(x)
360                    || pickle_match.is_match(x)
361                    || quant_safetensor_match.is_match(x)
362                    || consolidated_safetensor_match.is_match(x)
363                    || x == UQFF_RESIDUAL_SAFETENSORS
364            });
365            let safetensors = listing
366                .clone()
367                .filter(|x| x.ends_with(".safetensors"))
368                .collect::<Vec<_>>();
369            let pickles = listing
370                .clone()
371                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
372                .collect::<Vec<_>>();
373            let uqff_residual = listing
374                .clone()
375                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
376                .collect::<Vec<_>>();
377            let files = if !safetensors.is_empty() {
378                // Always prefer safetensors
379                safetensors
380            } else if !pickles.is_empty() {
381                // Fall back to pickle
382                pickles
383            } else if !uqff_residual.is_empty() && loading_from_uqff {
384                uqff_residual
385            } else {
386                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
387            };
388            info!(
389                "Found model weight filenames {:?}",
390                files
391                    .iter()
392                    .map(|x| x.split('/').next_back().unwrap())
393                    .collect::<Vec<_>>()
394            );
395            for rfilename in files {
396                filenames.push(api_get_file!(api, &rfilename, model_id));
397            }
398            Ok(filenames)
399        }
400    }
401}
402
403/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
404/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
405/// have a `chat_template`, use the provided one.
406///
407/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
408/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
409///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
410///   is used.*
411///
412/// THE FOLLOWING IS IGNORED:
413/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
414///  the chat template is overwritten with this chat template.
415#[allow(clippy::borrowed_box)]
416pub(crate) fn get_chat_template(
417    paths: &Box<dyn ModelPaths>,
418    jinja_explicit: Option<&String>,
419    chat_template_explicit: Option<&String>,
420    chat_template_fallback: Option<&String>,
421    chat_template_ovrd: Option<String>,
422) -> ChatTemplate {
423    // Get template content, this may be overridden.
424    let template_content = if let Some(template_filename) = paths.get_template_filename() {
425        if !["jinja", "json"].contains(
426            &template_filename
427                .extension()
428                .expect("Template filename must be a file")
429                .to_string_lossy()
430                .to_string()
431                .as_str(),
432        ) {
433            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
434        }
435        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
436    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
437        // User specified a file
438        let template_filename = chat_template_fallback
439            .expect("A tokenizer config or chat template file path must be specified.");
440        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
441    } else if chat_template_ovrd.is_some() {
442        None
443    } else {
444        info!("No chat template file found. Chat template may be set via `chat_template.json` or processor config.");
445        None
446    };
447    let mut template: ChatTemplate = match chat_template_ovrd {
448        Some(chat_template) => {
449            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
450            info!("Using literal chat template.");
451            let mut template = ChatTemplate::default();
452            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
453            template
454        }
455        None => {
456            if let Some(ref content) = template_content {
457                // Check if template_filename is a .jinja file
458                if let Some(template_filename) = paths.get_template_filename() {
459                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
460                        info!("Using chat template from .jinja file.");
461                        // Load special tokens (bos/eos/unk) from tokenizer_config.json
462                        // in the same directory, matching HF's behavior where
463                        // apply_chat_template passes self.special_tokens_map to the template.
464                        let mut template = template_filename
465                            .parent()
466                            .map(|dir| dir.join("tokenizer_config.json"))
467                            .filter(|p| p.exists())
468                            .and_then(|p| fs::read_to_string(p).ok())
469                            .and_then(|s| serde_json::from_str::<ChatTemplate>(&s).ok())
470                            .unwrap_or_default();
471                        template.chat_template =
472                            Some(ChatTemplateValue(Either::Left(content.clone())));
473                        template
474                    } else {
475                        serde_json::from_str(content).unwrap()
476                    }
477                } else {
478                    serde_json::from_str(content).unwrap()
479                }
480            } else {
481                // No template content available; downstream code may fill in from
482                // chat_template.json, processor_config, or jinja_explicit.
483                ChatTemplate::default()
484            }
485        }
486    };
487    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
488    if template.chat_template.is_none() {
489        if let Some(chat_template_explicit) = chat_template_explicit {
490            let ct =
491                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
492
493            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
494                ct
495            } else {
496                #[derive(Debug, serde::Deserialize)]
497                struct AutomaticTemplate {
498                    chat_template: String,
499                }
500                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
501                deser.chat_template
502            };
503
504            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
505        }
506    }
507
508    // JINJA explicit
509    if let Some(jinja_explicit) = jinja_explicit {
510        if !jinja_explicit.ends_with(".jinja") {
511            panic!("jinja_explicit must end with .jinja!");
512        }
513
514        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
515
516        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
517    }
518
519    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
520        .get_processor_config()
521        .as_ref()
522        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
523    if let Some(processor_conf) = processor_conf {
524        if processor_conf.chat_template.is_some() {
525            template.chat_template = processor_conf
526                .chat_template
527                .map(|x| ChatTemplateValue(Either::Left(x)));
528        }
529    }
530
531    #[derive(Debug, serde::Deserialize)]
532    struct SpecifiedTemplate {
533        chat_template: String,
534        bos_token: Option<String>,
535        eos_token: Option<String>,
536        unk_token: Option<String>,
537    }
538
539    if template.chat_template.is_some() {
540        return template;
541    };
542
543    match &template.chat_template {
544        Some(_) => template,
545        None => {
546            if let Some(template_content) = template_content {
547                info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
548                let mut deser: HashMap<String, Value> =
549                    serde_json::from_str(&template_content).unwrap();
550
551                match chat_template_fallback.cloned() {
552                    Some(t) => {
553                        info!("Loading specified loading chat template file at `{t}`.");
554                        let templ: SpecifiedTemplate =
555                            serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
556                        deser.insert(
557                            "chat_template".to_string(),
558                            Value::String(templ.chat_template),
559                        );
560                        if let Some(bos_token) = templ.bos_token {
561                            deser.insert("bos_token".to_string(), Value::String(bos_token));
562                        }
563                        if let Some(eos_token) = templ.eos_token {
564                            deser.insert("eos_token".to_string(), Value::String(eos_token));
565                        }
566                        if let Some(unk_token) = templ.unk_token {
567                            deser.insert("unk_token".to_string(), Value::String(unk_token));
568                        }
569                    }
570                    None => {
571                        warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
572                        deser.insert("chat_template".to_string(), Value::Null);
573                    }
574                }
575
576                let ser = serde_json::to_string_pretty(&deser)
577                    .expect("Serialization of modified chat template failed.");
578                serde_json::from_str(&ser).unwrap()
579            } else {
580                warn!("No chat template source found. No chat template will be used. Only prompts will be accepted, not messages.");
581                template
582            }
583        }
584    }
585}
586
/// Unit tests for the weight file-name patterns.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{
        CONSOLIDATED_SAFETENSOR_MATCH, PICKLE_MATCH, QUANT_SAFETENSOR_MATCH, SAFETENSOR_MATCH,
    };

    /// Compile `pattern` and assert that it matches every name in `positive_ids` and none in
    /// `negative_ids`.
    fn check_pattern(
        pattern: &str,
        positive_ids: &[&str],
        negative_ids: &[&str],
    ) -> anyhow::Result<()> {
        let regex = Regex::new(pattern)?;
        for id in positive_ids {
            assert!(regex.is_match(*id), "`{id}` should match `{pattern}`");
        }
        for id in negative_ids {
            assert!(!regex.is_match(*id), "`{id}` should not match `{pattern}`");
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        check_pattern(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &[
                "model-0000a-of-00002.safetensors",
                "consolidated.safetensors",
            ],
        )
    }

    #[test]
    fn match_quant_safetensors() -> anyhow::Result<()> {
        check_pattern(
            QUANT_SAFETENSOR_MATCH,
            &["model.safetensors"],
            &[
                "model-00001-of-00002.safetensors",
                "consolidated.safetensors",
                "model.safetensor",
            ],
        )
    }

    #[test]
    fn match_consolidated_safetensors() -> anyhow::Result<()> {
        check_pattern(
            CONSOLIDATED_SAFETENSOR_MATCH,
            &["consolidated.safetensors"],
            &["model.safetensors", "model-00001-of-00002.safetensors"],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        check_pattern(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                "pytorch_model-000001-of-00001.bin",
                "pytorch_model-0000a-of-00002.bin",
                "pytorch_model-000-of-00003.bin",
                "pytorch_consolidated.bin",
            ],
        )
    }
}